R/evalHeatmap.R

Defines functions .scaledLegend .getCellFn .getReducedNames .prepRes .ds_anno colCenterScale .defaultColorMapping .safescale .dosort .doscale evalHeatmap

Documented in colCenterScale evalHeatmap

#' evalHeatmap
#' 
#' General heatmap representation of aggregated evaluation results. By default,
#' the actual metric values are printed in the cells, and while the coloring is
#' determined by \code{\link{colCenterScale}} (number of matrix-median absolute
#'  deviations from the column means). Unless the total number of analyses is
#'  small, it is strongly recommended to use the `agg.by` argument to limit the
#'  size and improve the readability of the heatmap.
#'
#' @param res Aggregated pipeline results (i.e. the output of `runPipeline` or
#' `aggregateResults`)
#' @param step Name of the step for which to plot the evaluation results. If 
#' unspecified, will use the latest step that has evaluation results.
#' @param what What metric to plot.
#' @param what2 If the step has more than one benchmark data.frame, which one
#' to use. The function will attempt to guess that automatically based on 
#' `what`, and will notify in case of ambiguity.
#' @param agg.by Aggregate results by these columns (default no aggregation)
#' @param agg.fn Function for aggregation (default mean)
#' @param filterExpr An optional filtering expression based on the columns of 
#' the target dataframe, (e.g. `filterExpr=param1=="value1"`).
#' @param scale Controls the scaling of the columns for the color mapping. Can 
#' either be a logical (TRUE will use NA-safe column z-scores, FALSE will not 
#' scale) or a function performing the scaling. The default uses the
#'  `colCenterScale` function (per-column centering, but per-matrix variance
#'  scaling).
#' @param value_format Format for displaying cells' values (use 
#' `value_format=""` to disable)
#' @param reorder_rows Logical; whether to sort rows (default FALSE). The row 
#' names themselves can also be passed to specify an order, or a 
#' `ComplexHeatmap`.
#' @param row_split Optional column (included in `agg.by`) by which to split
#' the rows. Alternatively, an expression using the columns (retained after
#' aggregation) can be passed.
#' @param show_heatmap_legend Passed to `Heatmap` (default FALSE)
#' @param show_column_names Passed to `Heatmap` (default FALSE)
#' @param col Colors for the heatmap, or a color-mapping function as produced by
#' `colorRamp2`. If passing a vector of colors and the data is scaled, there 
#' should be an odd number of colors. By default, will apply linear mapping (if
#' the data is not scaled) or signed sqrt mapping (if scaled) on the 
#' `viridisLite::inferno` palette.
#' @param font_factor A scaling factor applied to fontsizes (default 1)
#' @param value_cols A vector of length 2 indicating the colors of the values
#' (above and below the mean), if printed
#' @param title Plot title
#' @param shortNames Logical; whether to use short row names (with only
#' the parameter values instead of the parameter name and value pairs), default
#' TRUE.
#' @param name Heatmap name (e.g. used for the legend)
#' @param anno_legend Logical; whether to plot the legend for the datasets
#' @param ... Passed to `Heatmap`
#' 
#' @return A Heatmap
#' @export
#' @importFrom viridisLite inferno
#' @importFrom grid gpar
#' @import ComplexHeatmap
#' @examples
#' data("exampleResults", package="pipeComp")
#' evalHeatmap( exampleResults, step="clustering", what=c("ARI","MI","min_pr"), 
#'              agg.by=c("filt", "norm"), row_split = "norm" ) +
#' evalHeatmap( exampleResults, step="clustering", what="ARI", 
#'              agg.by=c("filt", "norm"), filterExpr=n_clus==true.nbClusts, 
#'              name="ARI at true k", title="ARI at
#' true K" )
evalHeatmap <- function( res, step=NULL, what, what2=NULL, agg.by=NULL, 
                         agg.fn=mean, filterExpr=NULL, scale="colCenterScale", 
                         value_format="%.2f", reorder_rows=FALSE, 
                         show_heatmap_legend=FALSE, show_column_names=FALSE, 
                         col=NULL, font_factor=0.9, 
                         row_split=NULL, shortNames=TRUE,
                         value_cols=c("black","white"), title=NULL, 
                         name=NULL, anno_legend=TRUE, ...){
  pd <- NULL
  if(is(res,"SimpleList")) pd <- metadata(res)$PipelineDefinition
  #if(is.null(pd)) stop("Could not find the PipelineDefinition.")
  if(is.null(step)){
    step <- rev(names(res$evaluation))[1]
    if(length(res$evaluation)>1) message("Using step '",step,"'.")
  }
  step <- match.arg(step, names(res$evaluation))
  if(length(what)>1){
    title <- lapply(seq_along(what), FUN=function(i){
      tryCatch( title[[i]], error=function(e) what[[i]] )
    })
    fcall <- match.call()
    H <- NULL
    for(i in seq_along(what)){
      fcall$what <- what[[i]]
      fcall$title <- title[[i]]
      H <- H + eval(fcall)
    }
    return(H)
  }
  if(is(res,"SimpleList") && "evaluation" %in% names(res)) res <- res$evaluation
  if(is(res,"list") && step %in% names(res)) res <- res[[step]]
  if(is(res,"list")){
    if(is.null(what2)){
      # try to guess the table
      test <- which(vapply( res, FUN=function(x) what %in% colnames(x), 
                            logical(1) ))
      if(length(test)==0) stop("`what2` is undefined and could not be guessed.")
      if(length(test)>1) message("`what2` is undefined, and `what` is present ",
        "in more than one results data.frame (",
        paste(paste0("`",names(res)[test],"`"), collapse=", "), "). ",
        names(res)[test[1]], "' will be used.")
      what2 <- names(res)[test[1]]
    }
    what2 <- match.arg(what2, names(res))
    res <- res[[what2]]
  }
  if(!is.null(pd)){
    what_options <- setdiff( colnames(res), 
                           c("dataset",unlist(arguments(pd))) )
    what <- match.arg(what, choices=what_options)
  }
  res <- .prepRes(res, what=what, agg.by=agg.by, pipDef=pd, 
                  filt=substitute(filterExpr), shortNames=shortNames, 
                  returnParams=TRUE)
  pp <- res$pp
  res <- res$res
  res2 <- .doscale(res, param=scale)
  ro <- .dosort(res2, reorder_rows)
  res2 <- as.matrix(res2)
  cellfn <- .getCellFn(res, res2, value_format, value_cols, font_factor, 
                       scaled=!(is.logical(scale) && !scale) )
  if(is.null(title)) title <- gsub("\\.","\n",what)
  suppressWarnings({
    if(!tryCatch(is.null(row_split), error=function(e) FALSE)){
      if(tryCatch(is.character(row_split), error=function(e) FALSE)){
        if(row_split %in% colnames(pp)){
          row_split <- pp[,row_split]
        }else{
          warning("`row_split` wasn't found and will be ignored.")
          row_split <- NULL
        }
      }else{
        row_split <- eval(substitute(row_split), envir=pp)
      }
    }
  })
  if(is.null(name)) name <- what
  if(is.null(col) || !is.function(col))
    col <- .defaultColorMapping(res2, center=!(is.logical(scale) && !scale),
                                cols=col)
  Heatmap( res2, name=name, cluster_rows=FALSE, cluster_columns=FALSE, 
           show_heatmap_legend=show_heatmap_legend, row_order=ro,
           bottom_annotation=.ds_anno(colnames(res),anno_legend,font_factor), 
           show_column_names=show_column_names, cell_fun=cellfn, col=col,
           column_title=title, row_split=row_split,
           row_title_gp = gpar(fontsize = (14*font_factor)),
           column_title_gp = gpar(fontsize = (14*font_factor)),
           row_names_gp = gpar(fontsize = (12*font_factor)),
           column_names_gp = gpar(fontsize = (12*font_factor)), ... )
}

.doscale <- function(res, param){
  if(is.null(param)) param <- colCenterScale
  if(is.logical(param) && !param) return(res)
  if(is.logical(param) && param) return(.safescale(res))
  if(is.function(param)) return(param(res))
  if(is.character(param) && is.function(get(param))) return(get(param)(res))
  warning("Unknown scaling parameter, will use column z-scores.")
  .safescale(res)
}

#' @import ComplexHeatmap
.dosort <- function(res, reorder_rows){
  if(is(reorder_rows, "Heatmap")){
    ro <- row_order(reorder_rows)
  }else{
    if(length(reorder_rows)>1){
      ro <- reorder_rows
    }else{
      if(reorder_rows){
        ro <- order(rowMeans(res, na.rm=TRUE), decreasing=TRUE)
      }else{
        ro <- seq_len(nrow(res))
      }
    }
  }
  if(!is.numeric(ro))
    ro <- vapply(ro, FUN=function(x) which(row.names(res)==x), integer(1))
  ro
}

# NA-robust scale
.safescale <- function(x){
  if(!is.null(dim(x))){
    y <- apply(x,2,.safescale)
    row.names(y) <- row.names(x)
    return(y)
  }
  if(all(is.na(x))) return(x)
  if(sd(x,na.rm=TRUE)>0) return(base::scale(x))
  if(sum(!is.na(x))==0) return(base::scale(as.numeric(!is.na(x))))
  rep(0,length(x))
}

.defaultColorMapping <- function(x, center=TRUE, cols=NULL){
  if(is.null(cols)) cols <- viridisLite::inferno(101)
  if(!center) return(cols)
  q <- max(abs(range(x, na.rm=TRUE)))
  b <- c( -seq(from=sqrt(q), to=0, length.out=floor(length(cols)/2)+1)^2,
          seq(from=0, to=sqrt(q), length.out=floor(length(cols)/2)+1)[-1]^2 )
  colorRamp2(b, cols)
}

#' colCenterScale
#'
#' Matrix scaling by centering columns separately and then performing variance 
#' scaling on the whole matrix, in a NA-robust fashion. With the default 
#' arguments, the output will be the number of (matrix-)median absolute 
#' deviations from the column-median.
#'
#' @param x A numeric matrix.
#' @param centerFn The function for calculating centers. Should accept the 
#' `na.rm` argument. E.g. `centerFn=mean` or `centerFn=median`.
#' @param scaleFn The function for calculating the (matrix-wise) scaling
#' factor. Should accept the `na.rm` argument. Default `median(abs(x))`.
#'
#' @return A scaled matrix of the same dimensions as `x`.
#' @export
#'
#' @examples
#' # random data with column mean differences
#' d <- cbind(A=rnorm(5, 10, 2), B=rnorm(5, 20, 2), C=rnorm(5,30, 2))
#' colCenterScale(d)
colCenterScale <- function(x, centerFn=median, 
                           scaleFn=function(x,na.rm) median(abs(x),na.rm=na.rm)
                           ){
  if(is.null(dim(x))) stop("`x` should be a numeric matrix or data.frame.")
  centers <- apply(x, MARGIN=2, na.rm=TRUE, FUN=centerFn)
  x2 <- t(t(x)-centers)
  x2 <- x2/scaleFn(x2, na.rm=TRUE)
  if(all(is.na(x2))) x2[is.na(x2)] <- 0
  x2
}

#' @importFrom ComplexHeatmap HeatmapAnnotation
.ds_anno <- function(x, legend=TRUE, font_factor=1){
  y <- vapply( strsplit(gsub("stepElapsed\\.","",x)," "),
               FUN=function(x) x[[1]], FUN.VALUE=character(1))
  cols <- getQualitativePalette(length(unique(y)))
  names(cols) <- sort(unique(y))
  HeatmapAnnotation(dataset=y, col=list(dataset=cols), 
                    annotation_legend_param=list(
                      labels_gp=gpar(fontsize=10*font_factor),
                      title_gp=gpar(fontsize=10*font_factor, fontface = "bold")
                    ),
                    show_annotation_name = FALSE, show_legend=legend)
}

.prepRes <- function(res, what=NULL, agg.by=NULL, agg.fn=mean, pipDef=NULL, 
                     filt=NULL, long=FALSE, shortNames=FALSE, 
                     returnParams=FALSE){
  if(is.null(what) && !long) stop("`what` should be defined.")
  if(!is.null(filt) && filt!=TRUE){
    if(is.null(pipDef)) stop("Filtering requires providing a `pipDef`.")
    res <- res[which(eval(filt, res)),]
  }
  if(!is.null(agg.by)){
    agg.by2 <- c(agg.by, intersect(colnames(res),c("dataset","subpopulation")))
    if(is.null(what)){
      what <- colnames(res)[which(
        vapply(res, function(x) is.integer(x) | is.numeric(x), logical(1))
      )]
      what <- setdiff(what, agg.by)
    }
    res <- aggregate(res[,what,drop=FALSE], by=res[,agg.by2,drop=FALSE], 
                     FUN=agg.fn)
    pp <- res[,agg.by,drop=FALSE]
  }else{
    if(is.null(pipDef)) stop("Either `agg.by` or `pipDef` should be given.")
    pp <- res[,intersect(colnames(res), unlist(arguments(pipDef))),drop=FALSE]
  }
  pp$method <- res$method <- .getReducedNames(pp, short=shortNames)
  if(long) return(res)
  
  if("subpopulation" %in% colnames(res))
    res$dataset <- paste(res$dataset, res$subpopulation)
  r2 <- reshape2::dcast( res, method~dataset, value.var=what,  
                         drop=FALSE, fun.aggregate=agg.fn)
  pp <- pp[!duplicated(pp$method),,drop=FALSE]
  row.names(pp) <- pp$method
  row.names(r2) <- r2[,1]
  res <- as.matrix(r2[,-1,drop=FALSE])
  if(returnParams) return(list(res=res, pp=pp[row.names(res),,drop=FALSE]))
  res
}

.getReducedNames <- function(res, short=FALSE){
  if(is.character(res)) res <- parsePipNames(res)
  pp <- vapply(res, FUN=function(x) length(unique(x))>1, logical(1))
  pp <- res[,pp,drop=FALSE]
  if(!short && ncol(pp)>1){
    y <- apply(pp,1,FUN=function(x){
      x <- paste0(colnames(pp),"=",x)
      paste(x, collapse="; ")
    })
  }else{
    y <- apply(pp,1,collapse=" > ",FUN=paste)
  }
  y
}

#' @importFrom grid grid.text
.getCellFn  <- function( res, res2, value_format, cols=c("black","white"), 
                         font_factor=1, scaled=TRUE ){
  if(scaled){
    resmid <- median(res2, na.rm=TRUE)
  }else{
    resmid <- range(res2, na.rm=TRUE)
    resmid <- resmid[1]+(resmid[2]-resmid[1])/2
  }
  if(is.na(resmid)) resmid <- Inf
  function(j, i, x, y, width, height, fill){
    if(value_format=="" || is.null(value_format) || is.na(value_format))
      return(NULL)
    lab <- sprintf(value_format, res[i,j])
    if(!any(abs(res)>1, na.rm=TRUE)){
      lab <- gsub("^0\\.",".",lab)
      lab <- gsub("^-0\\.","-.",lab)
    } 
    lab <- gsub("^1.00$","1",lab)
    lab <- gsub("^.00$","0",lab)
    cols <- ifelse(res2[i,j]>=resmid,cols[1],cols[2])
    grid.text(lab, x, y, gp = gpar(fontsize = 10*font_factor, col=cols))
  }
}

.scaledLegend <- function(){
  ComplexHeatmap::Legend(
    col_fun=circlize::colorRamp2(seq(1,100), viridisLite::inferno(100)), 
    at=c(1,50,100), title="MADs", labels=c("worst","median","best"), 
    direction="horizontal" )
}
plger/pipeComp documentation built on Nov. 2, 2021, 8:40 p.m.