#' Visualization functions for FRASER
#' The FRASER package provides multiple functions to visualize
#' the data and the results of a full data set analysis.
#' This is the list of all plotting functions provided by FRASER:
#' \itemize{
#'   \item plotAberrantPerSample()
#'   \item plotVolcano()
#'   \item plotExpression()
#'   \item plotQQ()
#'   \item plotExpectedVsObservedPsi()
#'   \item plotCountCorHeatmap()
#'   \item plotFilterExpression()
#'   \item plotFilterVariability()
#'   \item plotEncDimSearch()
#' }
#' For a detailed description of each plot function please see the details.
#' Most of the functions share the same parameters.
#### Data specific parameters
#' @param object,fds An \code{\link{FraserDataSet}} object.
#' @param type The psi type: either psi5, psi3 or theta (for SE).
#' @param sampleID A sample ID which should be plotted. Can also be a vector.
#'             Integers are treated as indices.
#' @param idx,site A junction site ID or gene ID or one of both, which
#'             should be plotted. Can also be a vector. Integers are treated
#'             as indices.
#' @param padjCutoff,zScoreCutoff,deltaPsiCutoff Significance, Z-score or delta
#'             psi cutoff to mark outliers
#' @param global Flag to plot a global Q-Q plot, default FALSE
#' @param normalized If TRUE, the normalized psi values are used, the default,
#'             otherwise the raw psi values
#' @param aggregate If TRUE, the pvalues are aggregated by gene (default), 
#'             otherwise junction level pvalues are used (default for Q-Q plot).
#' @param result The result table to be used by the method.
#' @param label Indicates the genes or samples that will be labelled in the 
#'             plot (only for \code{basePlot=TRUE}). Setting 
#'             \code{label="aberrant"} will label all aberrant genes or 
#'             samples. Labelling can be turned off by setting 
#'             \code{label=NULL}. The user can also provide a custom 
#'             list of gene symbols or sampleIDs.
#' @param BPPARAM BiocParallel parameter to use.
#' @param Ncpus Number of cores to use.
#' @param plotType The type of plot that should be shown as character string. 
#'             For plotEncDimSearch, it has to be either \code{"auc"} 
#'             for a plot of the area under the curve (AUC) or 
#'             \code{"loss"} for the model loss. For the correlation heatmap,
#'             it can be either \code{"sampleCorrelation"} for a
#'             sample-sample correlation heatmap or \code{"junctionSample"}
#'             for a junction-sample correlation heatmap.
#' @param onlyVariableIntrons Logical value indicating whether to show only 
#'              introns that also pass the variability filter. Defaults to 
#'              FALSE.
#' @param onlyExpressedIntrons Logical value indicating whether to show only 
#'              introns that also pass the expression filter. Defaults to 
#'              FALSE.
#### Graphical parameters
#' @param main Title for the plot, if missing a default title will be used.
#' @param colGroup Group of samples that should be colored.
#' @param basePlot if \code{TRUE} (default), use the R base plot version, else
#'             use the plotly framework.
#' @param conf.alpha If set, a confidence interval is plotted, defaults to 0.05
#' @param samplingPrecision Plot only non overlapping points in Q-Q plot to 
#'             reduce number of points to plot. Defines the digits to round to. 
#' @param logit If TRUE, the default, psi values are plotted in logit space.
#' @param nClust Number of clusters to show in the row and
#'             column dendrograms.
#' @param sampleClustering A clustering of the samples that should be used as an
#'             annotation of the heatmap.
#' @param show_rownames,show_colnames Logical value indicating whether to show
#'             row or column names on the heatmap axes.
#' @param annotation_col,annotation_row Row or column annotations that should be
#'             plotted on the heatmap.
#' @param topN,topJ Top x most variable junctions that should be used in the
#'             heatmap. TopN is used for sample-sample correlation heatmaps and
#'             topJ for junction-sample correlation heatmaps.
#' @param minMedian,minCount,minDeltaPsi Minimal median (\eqn{m \ge 1}), 
#'             read count (\eqn{n \ge 10}), and delta psi
#'             (\eqn{|\Delta\psi| > 0.1}) value of a junction to be considered
#'             for the correlation heatmap.
#' @param border_color Sets the border color of the heatmap
#' @param plotMeanPsi,plotCov If \code{TRUE}, then the heatmap is annotated with
#'             the mean psi values or the junction coverage.
#' @param bins Set the number of bins to be used in the histogram.
#' @param legend.position Set legend position (x and y coordinate), defaults to
#'             the top right corner.
#### Additional ... parameter
#' @param ... Additional parameters passed to plot() or plot_ly() if not stated
#'             otherwise in the details for each plot function
#' @details
#' \code{plotAberrantPerSample}: The number of aberrant events per sample is
#' plotted sorted by rank. The ... parameters are passed on to the
#' \code{\link{aberrant}} function.
#' \code{plotVolcano}: the volcano plot is sample-centric. It plots for a given
#' sample and psi type the negative log10 nominal P-values against the delta psi
#' values for all splice sites or aggregates by gene if requested.
#' \code{plotExpression}: This function plots for a given site the
#' read count at this site (i.e. K) against the total coverage (i.e. N) for the
#' given psi type (\eqn{\psi_5, \psi_3, or \theta}{\psi5, \psi3, or \theta}
#' (SE)) for all samples.
#' \code{plotQQ}: the quantile-quantile plot for a given gene or if
#' \code{global} is set to \code{TRUE} over the full data set. Here the
#' observed P-values are plotted against the expected ones in the negative
#' log10 space.
#' \code{plotExpectedVsObservedPsi}: A scatter plot of the observed psi
#' against the predicted psi for a given site.
#' \code{plotCountCorHeatmap}: The correlation heatmap of the count data either
#' of the full data set (i.e. sample-sample correlations) or of the top x most
#' variable junctions (i.e. junction-sample correlations). By default the values
#' are log transformed and row centered. The ... arguments are passed to the
#' \code{\link[pheatmap]{pheatmap}} function.
#' \code{plotFilterExpression}: The distribution of the mean junction
#' expression (geometric mean of the split-read counts plus one), shown
#' separately for the expressed introns and for the filtered introns. The
#' expression filter has to be computed beforehand.
#' \code{plotFilterVariability}: The distribution of maximal delta Psi values,
#' shown separately for the variable introns and for the filtered (i.e.
#' non-variable) introns. The variability filter has to be computed
#' beforehand.
#' \code{plotEncDimSearch}: Visualization of the hyperparameter optimization.
#' It plots the encoding dimension against either the area under the
#' precision-recall curve (\code{plotType="auc"}) or the model loss
#' (\code{plotType="loss"}). From this plot the optimum should be chosen as
#' the \code{q} used in the fitting process.
#' @return If base R graphics are used nothing is returned else the plotly or
#'             the ggplot object is returned.
#' @name plotFunctions
#' @rdname plotFunctions
#' @aliases plotFunctions plotAberrantPerSample plotVolcano plotQQ 
#'             plotExpression plotCountCorHeatmap plotFilterExpression 
#'             plotFilterVariability plotExpectedVsObservedPsi
#'             plotEncDimSearch
#' @examples
#' # create full FRASER object 
#' fds <- makeSimulatedFraserDataSet(m=40, j=200)
#' fds <- calculatePSIValues(fds)
#' fds <- filterExpressionAndVariability(fds, filter=FALSE)
#' # this step should be done for all splicing metrics and more dimensions
#' fds <- optimHyperParams(fds, "psi5", q_param=c(2,5,10,25))
#' fds <- FRASER(fds)
#' # QC plotting
#' plotFilterExpression(fds)
#' plotFilterVariability(fds)
#' plotCountCorHeatmap(fds, "theta")
#' plotCountCorHeatmap(fds, "theta", normalized=TRUE)
#' plotEncDimSearch(fds, type="psi5")
#' # extract results 
#' plotAberrantPerSample(fds)
#' plotVolcano(fds, "sample1", "psi5")
#' # dive into gene/sample level results
#' res <- results(fds)
#' res
#' plotExpression(fds, result=res[1])
#' plotQQ(fds, result=res[1])
#' plotExpectedVsObservedPsi(fds, type="psi5", res=res[1])

plotVolcano.FRASER <- function(object, sampleID, 
                    type=c("psi3", "psi5", "theta"), basePlot=TRUE, 
                    aggregate=FALSE, main=NULL, label=NULL,
                    deltaPsiCutoff=0.3, padjCutoff=0.1, ...){
    # Volcano plot for a single sample (method body registered for
    # plotVolcano on FraserDataSet objects): -log10(pval) against deltaPsi
    # for every feature returned by getPlottingDT(), coloured by the
    # `aberrant` column.
    #
    # Arguments:
    #   object          FraserDataSet.
    #   sampleID        sample to plot; passed to getPlottingDT() as `idx`.
    #   type            splice metric, one of "psi3", "psi5", "theta".
    #   basePlot        forwarded to plotBasePlot(); per the package docs
    #                   TRUE gives the static ggplot, FALSE the plotly version.
    #   aggregate       forwarded to getPlottingDT() (gene-level values).
    #   main            plot title; a default title is built when not given.
    #   label           "aberrant" labels all aberrant features, a character
    #                   vector labels those featureIDs, NULL disables labels.
    #   deltaPsiCutoff  drawn as dashed vertical lines at +/- the cutoff and
    #                   forwarded to getPlottingDT().
    #   padjCutoff      forwarded to getPlottingDT() and used to place the
    #                   horizontal significance line.
    #   ...             forwarded to getPlottingDT().
    # Returns: the value of plotBasePlot(g, basePlot).
    #
    # NOTE(review): this copy of the function is incomplete and does not
    # parse: the `bquote(paste(Delta, ...))` x-axis label has no enclosing
    # `xlab(` call, several `if`/`else` branches are missing their opening
    # condition or closing brace, and the final `main <- as.expression(...)`
    # call is cut off. Restore the missing lines before relying on it.
    type <- match.arg(type)

    # one row per feature of this sample; the columns used below are
    # deltaPsi, pval, padj, k, n, featureID and aberrant
    dt <- getPlottingDT(object, axis="col", type=type, idx=sampleID,
            aggregate=aggregate, deltaPsiCutoff=deltaPsiCutoff, 
            padjCutoff=padjCutoff, ...)
    # the `text` aesthetic is an HTML tooltip (presumably shown by plotly)
    g <- ggplot(dt, aes(x=deltaPsi, y=-log10(pval), color=aberrant, 
                        label=featureID, text=paste0(
            "SampleID: ", sampleID, "<br>",
            "featureID: ", featureID, "<br>",
            "Raw count (K): ", k, "<br>",
            "Raw total count (N): ", n, "<br>",
            "p value: ", signif(pval, 5), "<br>",
            "delta Psi: ", round(deltaPsi, 2), "<br>",
            "Type: ", type))) +
        geom_point(aes(alpha=ifelse(aberrant == TRUE, 1, 0.8))) +
                bquote(paste(Delta, .(ggplotLabelPsi(type)[[1]]) ))
            )) +
        ylab(expression(paste(-log[10], "(P value)"))) +
        theme_bw() +
        theme(legend.position="none") +
        scale_color_manual(values=c("gray40", "firebrick"))
        # dashed vertical lines at the delta psi cutoffs
        g <- g + 
            geom_vline(xintercept=c(-deltaPsiCutoff, deltaPsiCutoff),
                    color="firebrick", linetype=2)
        # horizontal line at the least significant nominal P-value that still
        # passes padjCutoff; falls back to 6 when no feature passes, or the
        # value is NA or above 10
        if(dt[,any(padj < padjCutoff)]){
            padj_line <- min(dt[padj < padjCutoff, -log10(pval)])
        if(!"padj_line" %in% ls() || padj_line > 10 || is.na(padj_line)){
            padj_line <- 6
        g <- g + 
            geom_hline(yintercept=padj_line, color="firebrick", linetype=4)
        # plain-text axis labels and title (the branch condition selecting
        # them is not visible in this copy)
        g <- g + xlab(paste("delta", 
                            ggplotLabelPsi(type, asCharacter=TRUE)[[1]])) +
            ylab("-log[10](P value)")
            main <- paste0("Volcano plot: ", sampleID, ", ", 
                            ggplotLabelPsi(type, asCharacter=TRUE)[[1]])
    } else{
            # label either all aberrant features or the requested featureIDs,
            # warning about requested IDs that are not in the data
            if(isScalarCharacter(label) && label == "aberrant"){
                if(nrow(dt[aberrant == TRUE,]) > 0){
                    g <- g + geom_text_repel(data=dt[aberrant == TRUE,],
                                        fontface='bold', hjust=-.2, vjust=-.2)
                if(nrow(dt[featureID %in% label]) > 0){
                    g <- g + geom_text_repel(data=
                                        subset(dt, featureID %in% label),
                                    fontface='bold', hjust=-.2, vjust=-.2)
                if(any(!(label %in% dt[,featureID]))){
                    warning("Did not find gene(s) ", 
                            paste(label[!(label %in% dt[,featureID])], 
                                collapse=", "), " to label.")
            main <- as.expression(bquote(paste(
                bold("Volcano plot: "), .(sampleID), ", ",
    # add the title and hand the ggplot to the static/interactive renderer
    g <- g + ggtitle(main)
    plotBasePlot(g, basePlot)

#' Volcano plot
#' Plots the p values over the delta psi values, known as volcano plot.
#' Visualizes the outliers of a single sample for the given splice metric,
#' aggregated by gene if requested.
#' @rdname plotFunctions
#' @export
setMethod(f="plotVolcano", signature="FraserDataSet",
        # S4 dispatch: FraserDataSet objects use the FRASER implementation
        definition=plotVolcano.FRASER)

plotAberrantPerSample.FRASER <- function(object, main, 
                    type=c("psi3", "psi5", "theta"),
                    padjCutoff=0.1, zScoreCutoff=NA, deltaPsiCutoff=0.3,
                    aggregate=TRUE, BPPARAM=bpparam(), ...){
    # Plots, for each requested splice metric, the number of aberrant events
    # of every sample against the sample's rank (method body registered for
    # plotAberrantPerSample on FraserDataSet objects). A dashed line marks
    # each metric's median, and the y axis is log10 unless all counts are 0.
    #
    # Arguments:
    #   object          FraserDataSet.
    #   main            plot title; default 'Aberrant events per sample',
    #                   with " by gene" appended in a branch whose condition
    #                   is not visible here (presumably `aggregate`).
    #   type            one or more of "psi3", "psi5", "theta".
    #   padjCutoff, zScoreCutoff, deltaPsiCutoff, ...
    #                   forwarded to aberrant(..., by="sample").
    #   aggregate       NOTE(review): not forwarded to aberrant() in the
    #                   visible code -- confirm.
    #   BPPARAM         BiocParallel backend used to loop over `type`.
    #
    # NOTE(review): this copy is incomplete and does not parse: the guards
    # around the default title, the rest of the data.table() call (which
    # should add the `rank` column used as x below), closing braces and the
    # final return are missing.

    type <- match.arg(type, several.ok=TRUE)

        main <- paste('Aberrant events per sample')
            main <- paste(main, "by gene")

    # extract outliers
    outliers <- bplapply(type, aberrant, object=object, by="sample",
            padjCutoff=padjCutoff, zScoreCutoff=zScoreCutoff,
            deltaPsiCutoff=deltaPsiCutoff, ..., BPPARAM=BPPARAM)
    # one row per sample and metric: sorted counts plus the metric's median
    dt2p <- rbindlist(lapply(seq_along(outliers), function(idx){
        vals <- outliers[[idx]]
        data.table(type=type[idx], value=sort(vals), median=median(vals),

    # plot them
    # NOTE(review): theme_cowplot() is a complete theme and replaces the
    # preceding theme_bw() -- confirm which one is intended.
    g <- ggplot(dt2p, aes(x=rank, y=value, color=type)) +
        geom_line() +
        geom_hline(aes(yintercept=median, color=type, lty="Median")) +
        theme_bw() +
        theme_cowplot() +
        ggtitle(main) +
        xlab("Sample rank") +
        ylab("Number of outliers") +
        scale_color_brewer(palette="Dark2", name=element_blank(),
                labels=ggplotLabelPsi(dt2p[,unique(type)])) +
        scale_linetype_manual(name="", values=2, labels="Median")

    # log10 y axis unless every sample has zero outliers
    if(!all(dt2p[,value] == 0)){
        g <- g + scale_y_log10()

#' Number of aberrant events per sample
#' Plot the number of aberrant events per samples
#' @rdname plotFunctions
#' @export
setMethod("plotAberrantPerSample", signature="FraserDataSet",
        # the call previously ended after `signature=`; register the FRASER
        # implementation like the other setMethod calls in this file
        plotAberrantPerSample.FRASER)

#' Junction expression plot
#' Plots the observed split reads of the junction of interest over all reads
#' coming from the given donor/acceptor.
#' @rdname plotFunctions
#' @export
plotExpression <- function(fds, type=c("psi5", "psi3", "theta"),
                    site=NULL, result=NULL, colGroup=NULL, 
                    basePlot=TRUE, main=NULL, label="aberrant", ...){
    # Scatter plot, for one site, of the junction count K + 1 against the
    # total coverage N + 2 over all samples (both axes log10), coloured by
    # the `aberrant` column of getPlottingDT().
    #
    # Arguments:
    #   fds       FraserDataSet.
    #   type      splice metric; taken from `result$type` when `result` is
    #             used.
    #   site      site index/ID passed to getPlottingDT() as `idx`; taken
    #             from getIndexFromResultTable() when `result` is used.
    #   result    a row of the results table.
    #   colGroup  sample IDs are converted to a logical vector over
    #             samples(fds); how it is applied is not visible here.
    #   basePlot  forwarded to plotBasePlot(); labels are only added when
    #             TRUE.
    #   main      plot title; a default is built otherwise.
    #   label     "aberrant", a vector of sample IDs, or NULL.
    #   ...       forwarded to getPlottingDT().
    # Returns: the value of plotBasePlot(g, basePlot).
    #
    # NOTE(review): this copy is incomplete and does not parse: the
    # condition before the `result`-based branch, the colGroup/main
    # conditions, the `text=paste0(` opener of the tooltip aesthetic and
    # several closing braces are missing; `ggtitle(main) +` is followed by
    # an `if`, and `theme(..., title=)` has an empty argument -- confirm.
        type <- as.character(result$type)
        site <- getIndexFromResultTable(fds, result)
    } else {
        type <- match.arg(type)

    dt <- getPlottingDT(fds, axis="row", type=type, idx=site, ...)
        if(all(colGroup %in% samples(fds))){
            colGroup <- samples(fds) %in% colGroup
    # factor so the manual colour scale below maps "TRUE"/"FALSE" by name
    dt[,aberrant:=factor(aberrant, levels=c("TRUE", "FALSE"))]

            main <- as.expression(bquote(bold(paste(
                .(ggplotLabelPsi(type)[[1]]), " expression plot: ",
                " (site ", .(as.character(dt[,unique(idx)])), ")"))))
        } else{
            main <- paste0(ggplotLabelPsi(type, asCharacter=TRUE)[[1]], 
                        " expression plot: ", dt[,unique(featureID)], 
                        " (site ", dt[,unique(idx)], ")")

    # pseudocounts (+2 / +1) keep zero counts on the log10 axes
    g <- ggplot(dt, aes(x=n + 2, y=k + 1, color=aberrant, label=sampleID, 
            "Sample: ", sampleID, "<br>",
            "Counts (K): ", k, "<br>",
            "Total counts (N): ", n, "<br>",
            "p value: ", signif(pval, 5), "<br>",
            "padjust: ", signif(padj, 5), "<br>",
            "Observed Psi: ", round(obsPsi, 2), "<br>",
            "Predicted mu: ", round(predPsi, 2), "<br>"))) +
        geom_point(alpha=ifelse(as.character(dt$aberrant) == "TRUE", 1, 0.7)) +
        scale_x_log10() +
        scale_y_log10() +
        theme_bw() +
        theme(legend.position="none", title=) +
        xlab("Total junction coverage + 2 (N)") +
        ylab("Junction count + 1 (K)") +
        ggtitle(main) +
    # label all aberrant samples or the requested sample IDs, warning about
    # requested IDs that are not in the data
    if(isTRUE(basePlot) && !is.null(label)){
        if(isScalarCharacter(label) && label == "aberrant"){
            if(nrow(dt[aberrant == TRUE,]) > 0){
                g <- g + geom_text_repel(data=dt[aberrant == TRUE,], 
                                        fontface='bold', hjust=-.2, vjust=-.2)
            if(nrow(dt[sampleID %in% label]) > 0){
                g <- g + geom_text_repel(data=subset(dt, sampleID %in% label), 
                                        fontface='bold', hjust=-.2, vjust=-.2)
            if(any(!(label %in% dt[,sampleID]))){
                warning("Did not find sample(s) ", 
                        paste(label[!(label %in% dt[,sampleID])], 
                            collapse=", "), " to label.")

        g <- g + scale_colour_manual(
                values=c("FALSE"="gray70", "TRUE"="firebrick"))

    plotBasePlot(g, basePlot)

#' Observed versus expected psi plot
#' Plots the observed psi values against the expected (predicted) psi values
#' of the given junction.
#' @rdname plotFunctions
#' @export
plotExpectedVsObservedPsi <- function(fds, type=c("psi5", "psi3", "theta"),
                    idx=NULL, result=NULL, colGroup=NULL, main=NULL,
                    basePlot=TRUE, label="aberrant", ...){
    # Scatter plot of the observed psi (obsPsi, y) against the predicted psi
    # (predPsi, x) of one site over all samples, with the identity line.
    # Points are drawn in firebrick/opaque when aberrant, else gray70 at
    # alpha 0.5.
    #
    # Arguments:
    #   fds       FraserDataSet.
    #   type      splice metric; afterwards overwritten by the type found
    #             in the plotting data.
    #   idx       site index/ID; result: a results-table row. Both are
    #             forwarded to getPlottingDT().
    #   colGroup  used to set `aberrant` to TRUE for the selected rows; a
    #             branch whose condition is not visible warns
    #             "not implemented yet!".
    #   main      plot title; a default is built otherwise.
    #   basePlot  forwarded to plotBasePlot(); labels only when TRUE.
    #   label     "aberrant", a vector of sample IDs, or NULL.
    #   ...       forwarded to getPlottingDT().
    # Returns: the value of plotBasePlot(g, basePlot).
    #
    # NOTE(review): this copy is incomplete and does not parse: the
    # main/colGroup/basePlot conditions, the `text=paste0(` tooltip opener
    # and closing braces are missing, and `ylab(ylab) +` is followed by an
    # `if`. The point colour is fixed outside aes(), so the final
    # scale_colour_manual() presumably only affects the text labels.
    type <- match.arg(type)
    # get plotting data
    dt   <- getPlottingDT(fds, axis="row", type=type, result=result, 
            idx=idx, ...)
    type <- as.character(unique(dt$type))
    idx  <- unique(dt$idx)
            main <- as.expression(bquote(bold(paste(
                " observed expression vs prediction plot: ",
                " (site ", .(as.character(idx)), ")"))))
        } else{
            main <- paste0(ggplotLabelPsi(type, asCharacter=TRUE)[[1]], 
                            " observed expression vs prediction plot: ", 
                            dt[,unique(featureID)], " (site ", idx, ")")
            dt[colGroup, aberrant:=TRUE]
        } else {
            warning("not implemented yet!")

        # plotmath axis labels (static plot) vs plain strings (other branch)
        ylab <- bquote("Observed " ~ .(ggplotLabelPsi(type)[[1]]))
        xlab <- bquote("Predicted " ~ .(ggplotLabelPsi(type)[[1]]))
    } else{
        ylab <- paste("Observed", ggplotLabelPsi(type, asCharacter=TRUE)[[1]])
        xlab <- paste("Predicted", ggplotLabelPsi(type, asCharacter=TRUE)[[1]])
    g <- ggplot(dt, aes(y=obsPsi, x=predPsi, label=sampleID, color=aberrant, 
            "Sample: ", sampleID, "<br>",
            "Counts (K): ", k, "<br>",
            "Total counts (N): ", n, "<br>",
            "p value: ", signif(pval, 5), "<br>",
            "padjust: ", signif(padj, 5), "<br>",
            "Observed Psi: ", round(obsPsi, 2), "<br>",
            "Predicted mu: ", round(predPsi, 2), "<br>"))) +
        geom_point(alpha=ifelse(dt$aberrant, 1, 0.5),
                color=c("gray70", "firebrick")[dt$aberrant + 1]) +
        geom_abline(intercept = 0, slope=1) +
        theme_bw() +
        theme(legend.position="none") +
        xlab(xlab) +
        ylab(ylab) +
    # label all aberrant samples or the requested sample IDs, warning about
    # requested IDs that are not in the data
    if(isTRUE(basePlot) && !is.null(label)){
        if(isScalarCharacter(label) && label == "aberrant"){
            if(nrow(dt[aberrant == TRUE,]) > 0){
                g <- g + geom_text_repel(data=dt[aberrant == TRUE,], 
                                        fontface='bold', hjust=-.2, vjust=-.2)
            if(nrow(dt[sampleID %in% label]) > 0){
                g <- g + geom_text_repel(data=subset(dt, sampleID %in% label), 
                                    fontface='bold', hjust=-.2, vjust=-.2)
            if(any(!(label %in% dt[,sampleID]))){
                warning("Did not find sample(s) ", 
                        paste(label[!(label %in% dt[,sampleID])], 
                                collapse=", "), " to label.")

        g <- g + scale_colour_manual(
            values=c("FALSE"="gray70", "TRUE"="firebrick"))
    plotBasePlot(g, basePlot)

plotQQ.FRASER <- function(object, type=NULL, idx=NULL, result=NULL, 
                    aggregate=FALSE, global=FALSE, main=NULL, conf.alpha=0.05,
                    samplingPrecision=3, basePlot=TRUE, label="aberrant",
                    Ncpus=min(3, getDTthreads()), ...){
    # Q-Q plot of observed against expected -log10 P-values (method body
    # registered for plotQQ on FraserDataSet objects), either for a single
    # site/gene (idx or result) or, with global=TRUE, over all P-values of
    # the data set for each splice metric in `type` (default psiTypes).
    #
    # Arguments:
    #   object             FraserDataSet.
    #   type, idx, result  select the feature; forwarded to getPlottingDT().
    #   aggregate          TRUE/FALSE or column name(s) of mcols(object);
    #                      anything else stops with an error.
    #   global             plot all P-values (samples x features).
    #   main               title; "Global QQ plot" or a per-site default.
    #   conf.alpha         level of the beta-distribution confidence band.
    #   samplingPrecision  digits used to drop points that overlap after
    #                      rounding (TRUE is mapped to 3).
    #   basePlot           forwarded to plotBasePlot().
    #   label              sample labels, only for non-global static plots.
    #   Ncpus              forwarded to getPlottingDT() in the global case.
    #   ...                forwarded to getPlottingDT(); pvalLevel defaults
    #                      to "junction" in the single-site case.
    # Returns: the value of plotBasePlot(g, basePlot).
    #
    # NOTE(review): this copy is incomplete and does not parse: several
    # `if` conditions and closing braces are missing, and the `obs`, `exp`
    # and `plotPoint` columns are used but not created in the visible code.

    # check parameters
        aggregate <- isTRUE(global)
    } else if(!(is.logical(aggregate) |
                all(aggregate %in% colnames(mcols(object))))){
        stop("Please provide TRUE/FALSE or a ",
            "charactor matching a column name in mcols.")

            type <- psiTypes
        dt <- rbindlist(bplapply(type, getPlottingDT, fds=object, axis="col",
                idx=TRUE, aggregate=aggregate, Ncpus=Ncpus, ...))
        # remove duplicated entries donor/acceptor sites if not aggregated 
        # by a feature
            dt <- dt[!duplicated(dt, by=c("type", "spliceID", "sampleID"))]
    } else {
        dots <- list(...)
        if(!"pvalLevel" %in% names(dots)){
            dots[["pvalLevel"]] <- "junction"
        dots <- append(list(fds=object, axis="row", type=type, idx=idx, 
                            result=result, aggregate=aggregate), 
        dt <- do.call(getPlottingDT, args=dots)
            main <- "Global QQ plot"
        } else {
            type <- as.character(dt[,unique(type)])
            featureID <- as.character(dt[,unique(featureID)])
                main <- as.expression(bquote(bold(paste(
                        " Q-Q plot: ", bolditalic(.(featureID)),
                        " (site ", .(as.character(dt[,unique(idx)])), ")"))))
            } else{
                main <- paste0(ggplotLabelPsi(type, asCharacter=TRUE)[[1]],
                                " Q-Q plot: ", featureID, 
                                " (site ", dt[,unique(idx)], ")")

    # points
    # NA observations become 0 and infinite ones the largest finite value
    dt2p <- dt[order(type, pval)]
    dt2p[is.na(obs), obs:=0]
    dt2p[is.infinite(obs), obs:=dt2p[is.finite(obs),max(obs)]]
    if(dt2p[,length(unique(obs))] < 2 | nrow(dt2p) < 2){
        warning("There are no pvalues or all are NA or 1!")

    # down sample if requested
    # keep one point per (rounded obs, rounded exp, type) combination
    if(isTRUE(samplingPrecision) | isScalarNumeric(samplingPrecision)){
            samplingPrecision <- 3
        mypoints <- !duplicated(dt2p[,.(obs=round(obs, samplingPrecision), 
                        exp=round(exp, samplingPrecision), type)], 
                by=c("obs", "exp", "type"))

    # create qq-plot
    g <- ggplot(dt2p[plotPoint == TRUE,], aes(x=exp, y=obs, col=aberrant, 
                "<br>SampleID: ", sampleID, "<br>K: ", k, "<br>N: ", n))) +
        geom_point() +
        theme_bw() +
        theme(legend.position="none") +
        g <- g +
            xlab(expression(-log[10]~"(expected P)")) +
            ylab(expression(-log[10]~"(observed P)"))
    } else{
        g <- g +
            xlab("-log[10] (expected P)") +
            ylab("-log[10] (observed P)")

    # Set color scale for global/local
    # single site: colour by aberrant; global: colour by splice metric
        g <- g + scale_color_manual(values=c("black", "firebrick"),
    } else {
        g$mapping$colour <- quote(type)
        g <- g + scale_color_brewer(palette="Dark2", name="Splice metric",

    # add confidence band if requested
    # http://genome.sph.umich.edu/wiki/Code_Sample:_Generating_QQ_Plots_in_R
    # NOTE(review): the referenced recipe uses shape2 = N - rank + 1; here it
    # is max(rank) - rank, which is 0 for the last point -- confirm.
        dt2p[,rank:=seq_len(.N), by=type]
        dt2p[plotPoint == TRUE,upper:=-log10(
                qbeta(conf.alpha/2, rank, max(rank) - rank)), by=type]
        dt2p[plotPoint == TRUE,lower:=-log10(
                qbeta(1-conf.alpha/2, rank, max(rank) - rank)), by=type]
        # only plot one psiType if multiple are existing
        # (preference order: theta, psi5, psi3)
        if(length(unique(dt2p$type)) > 1){
            typeOrder <- c("theta", "psi5", "psi3")
            type2take <- min(which(typeOrder %in% unique(dt2p$type)))
            dt2p[type != typeOrder[type2take], 
                    c("upper", "lower"):=list(NA, NA)]
        g <- g + geom_ribbon(data=dt2p[plotPoint == TRUE & !is.na(upper),],
                aes(x=exp, ymin=lower, ymax=upper, text=NULL),
                alpha=0.2, color="gray")
    # add labels if requested
    if(isFALSE(global) && isTRUE(basePlot) && !is.null(label)){
        if(isScalarCharacter(label) && label == "aberrant"){
            if(nrow(dt2p[aberrant == TRUE,]) > 0){
                g <- g + geom_text_repel(data=dt2p[aberrant == TRUE,], 
                                        fontface='bold', hjust=-.2, vjust=-.2)
            if(nrow(dt2p[sampleID %in% label]) > 0){
                g <- g + geom_text_repel(data=subset(dt2p, sampleID %in% 
                                        fontface='bold', hjust=-.2, vjust=-.2)
            if(any(!(label %in% dt2p[,sampleID]))){
                warning("Did not find sample(s) ", 
                        paste(label[!(label %in% dt2p[,sampleID])], 
                                collapse=", "), " to label.")

    # add abline in the end
    g <- g + geom_abline(intercept=0, slope=1, col="firebrick")
        return(plotBasePlot(g, basePlot))

#' Q-Q plot
#' Plots the quantile-quantile plot
#' @rdname plotFunctions
#' @export
setMethod(f="plotQQ", signature="FraserDataSet",
        # S4 dispatch: FraserDataSet objects use the FRASER implementation
        definition=plotQQ.FRASER)

plotEncDimSearch.FRASER <- function(object, type=c("psi3", "psi5", "theta"), 
                    plotType=c("auc", "loss")){
    # Plots the hyperparameter search results returned by
    # hyperParams(object, type=type, all=TRUE): the encoding dimension q
    # against the area under the PR curve (plotType="auc", column `aroc`)
    # or the model loss (plotType="loss", column `eval`), coloured by
    # `nsubset` and with linetype by `noise`. Warns when no hyperparameters
    # were estimated for `type`.
    #
    # NOTE(review): this copy is incomplete and does not parse: the code
    # after the warning (presumably an early return), the body of the
    # `nsubset` check, closing braces and the end of both ggplot chains
    # (each ends with a dangling `+`) are missing. The loss plot calls
    # geom_smooth() without method/formula, unlike the AUC plot.
    type <- match.arg(type)
    plotType <- match.arg(plotType)
    data <- hyperParams(object, type=type, all=TRUE)
    if (is.null(data)) {
        warning(paste("no hyperparameters were estimated for", type, 
                        "\nPlease use `optimHyperParams` to compute them."))
    if(!"nsubset" %in% colnames(data)){

    if(plotType == "auc"){
        g1 <- ggplot(data, aes(q, aroc, col=nsubset, linetype=noise)) +
            geom_point() +
            geom_smooth(method="loess", formula=y~x) +
            ggtitle("Q estimation") +
            xlab("Estimated q") +
            ylab("Area under the PR curve") +

        g2 <- ggplot(data, aes(q, eval, col=nsubset, linetype=noise)) +
            geom_point() +
            geom_smooth() +
            ggtitle("Q estimation") +
            xlab("Estimated q") +
            ylab("Model loss") +

#' Plots the results from the hyperparameter optimization.
#' @rdname plotFunctions
#' @export
setMethod("plotEncDimSearch", signature="FraserDataSet",
        # the call previously ended after `signature=`; register the FRASER
        # implementation like the other setMethod calls in this file
        plotEncDimSearch.FRASER)

#' Plot filter expression
#' Histogram of the geometric mean per junction based on the filter status
#' @rdname plotFunctions
#' @export
plotFilterExpression <- function(fds, bins=200, legend.position=c(0.8, 0.8),
    # Histogram (log10 x and y axes, `bins` bins) of a per-junction mean
    # expression, filled by whether the junction passed the expression
    # filter (mcols(fds, type="j")$passedExpression; replaced by the
    # `passed` column in a branch whose condition is not visible here).
    # The statistic computed below is exp(rowMeans(log(K + 1))) over the
    # psi5 counts, i.e. the geometric mean of K + 1.
    # Stops if the expression filter has not been computed.
    #
    # NOTE(review): this copy is incomplete and does not parse: the end of
    # the signature, closing braces, the `value=` column of the data.table
    # (presumably rowlgm) and the final theme call after the dangling `+`
    # are missing; `legend.position` is not used in the visible code.
    # check that expression filter has been calculated
    if(!("passedExpression" %in% colnames(mcols(fds, type="j")))){
        stop("Please calculate the expression filter values first with the ",
                "filterExpression function.")
    # get mean count for all junctions in the fds object
    cts    <- K(fds, "psi5")
    rowlgm <- exp(rowMeans(log(cts + 1)))

    dt <- data.table(
            passed=mcols(fds, type="j")[['passedExpression']])
        dt[,passed:=mcols(fds, type="j")[['passed']]]
    # TRUE first so the fill colours/labels below line up with "True"/"False"
    dt[,passed:=factor(passed, levels=c(TRUE, FALSE))]
    colors <- brewer.pal(3, "Dark2")[seq_len(2)]
    ggplot(dt, aes(value, fill=passed)) +
        geom_histogram(bins=bins) +
        scale_x_log10() +
        scale_y_log10() +
        scale_fill_manual(values=colors, name="Passed",
                labels=c("True", "False")) +
        xlab("Mean Junction Expression") +
        ylab("Count") +
        ggtitle("Expression filtering") +
        theme_bw() +

#' Plot filter variability
#' Histogram of maximal delta psi per junction
#' @rdname plotFunctions
#' @export
plotFilterVariability <- function(fds, bins=200, legend.position=c(0.8, 0.8),
    # Histogram (log10 y axis, `bins` bins) of each junction's maximal delta
    # psi, the elementwise maximum of the maxDPsi3, maxDPsi5,
    # maxDThetaDonor and maxDThetaAcceptor columns of mcols(fds, type="j"),
    # filled by passedVariability (replaced by the `passed` column in a
    # branch whose condition is not visible here). Junctions removed
    # earlier are appended from an HDF5 SummarizedExperiment stored under
    # file.path(workingDir(fds), "savedObjects", nameNoSpace(fds), ...).
    # Stops if the variability filter has not been computed.
    #
    # NOTE(review): this copy is incomplete and does not parse: the end of
    # the signature, the rest of the nonVarDir path and its existence
    # check, the nonVar_dt columns, closing braces and the final theme call
    # are missing. The stop() message mentions the "expression" filter
    # although this checks the variability filter. pmax() is called without
    # na.rm, so one NA column yields NA. The x label only mentions psi5/psi3
    # although theta is included.
    # check that variability filter has been calculated
    if(!("passedVariability" %in% colnames(mcols(fds, type="j")))){
        stop("Please calculate the expression filter values first with the ",
                "filterVariability function.")
    # get plotting data
    dt <- data.table(
        value=pmax(mcols(fds, type="j")[['maxDPsi3']], 
                    mcols(fds, type="j")[['maxDPsi5']],
                    mcols(fds, type="j")[['maxDThetaDonor']],
                    mcols(fds, type="j")[['maxDThetaAcceptor']]),
        passed=mcols(fds, type="j")[['passedVariability']])
        dt[,passed:=mcols(fds, type="j")[['passed']]]
    # check if file with removed counts exists and add them when it exists
    nonVarDir <- file.path(workingDir(fds), "savedObjects", nameNoSpace(fds),

        nV_stored <- loadHDF5SummarizedExperiment(dir=nonVarDir) 
        nonVar_dt <- data.table(
        dt <- rbind(dt, nonVar_dt)
    # TRUE first so the fill colours/labels below line up with "True"/"False"
    dt[,passed:=factor(passed, levels=c(TRUE, FALSE))]
    colors <- brewer.pal(3, "Dark2")[seq_len(2)]
    ggplot(dt, aes(value, fill=passed)) +
        geom_histogram(bins=bins) +
        scale_y_log10() + 
        scale_fill_manual(values=colors, name="Passed",
                            labels=c("True", "False")) +
        xlab(bquote("Maximal Junction" ~ Delta*Psi[5] ~ "or" ~ Delta*Psi[3])) +
        ylab("Count") +
        ggtitle("Variability filtering") +
        theme_bw() +

plotCountCorHeatmap.FRASER <- function(object,
                    type=c("psi5", "psi3", "theta"), logit=FALSE, topN=50000, 
                    topJ=5000, minMedian=1, minCount=10, 
                    main=NULL, normalized=FALSE, show_rownames=FALSE,
                    show_colnames=FALSE, minDeltaPsi=0.1, annotation_col=NA,
                    annotation_row=NA, border_color=NA, nClust=5,
                    plotType=c("sampleCorrelation", "junctionSample"),
                    sampleClustering=NULL, plotMeanPsi=TRUE, plotCov=TRUE, ...){
    # pheatmap of either the sample-sample Spearman correlation of
    # row-centred psi values (plotType="sampleCorrelation", colour breaks
    # -1..1) or of the row-centred values themselves for the top variable
    # junctions (plotType="junctionSample", breaks -5..5).
    #
    # Steps visible below:
    #  * keep junctions with median K >= minMedian and max K >= minCount;
    #  * psi estimate (K + 1) / (N + 2), logit-transformed with
    #    qlogisWithCap() under a condition not visible here (presumably
    #    `logit`), set to NA where N < minCount, then row-centred;
    #  * keep the topN rows by row SD for the correlation;
    #  * subtract the row-centred predicted means, presumably when
    #    normalized=TRUE (condition not visible);
    #  * junctionSample: additionally keep variableJunctions(minDeltaPsi)
    #    with a 75% quantile of K >= 10, then the topJ most variable ones,
    #    annotated with mean psi and mean coverage bins;
    #  * samples can be annotated with colData columns (annotation_col),
    #    with nClust clusters cut from hclust(dist(cormatS)), or with
    #    `sampleClustering`.
    # `...` is forwarded to pheatmap(), whose result is presumably returned.
    #
    # NOTE(review): this copy is incomplete and does not parse: many
    # `if` conditions and closing braces, parts of the meanPsi/meanCoverage
    # computations and the end of the pheatmap() call are missing.

    type <- match.arg(type)
    plotType <- match.arg(plotType)

    # use counts as matrix, otherwise x(fds,...) does not work later on
    counts(object, type=type, side="other", HDF5=FALSE)      <-
        as.matrix(counts(object, type=type, side="other"))
    counts(object, type=type, side="ofInterest", HDF5=FALSE) <-
        as.matrix(counts(object, type=type, side="ofInterest"))

    kmat <- K(object, type=type)
    nmat <- N(object, type=type)

    # NOTE(review): the table() below is computed and discarded
    expRowsMedian <- rowMedians(kmat) >= minMedian
    expRowsMax    <- rowMax(kmat)     >= minCount
    table(expRowsMax & expRowsMedian)

    skmat <- kmat[expRowsMax & expRowsMedian,]
    snmat <- nmat[expRowsMax & expRowsMedian,]

    # pseudocounts keep the ratio strictly inside (0, 1)
    xmat <- (skmat + 1)/(snmat + 2)
        xmat <- qlogisWithCap(xmat)
    xmat[snmat < minCount] <- NA
    xmat_rc    <- xmat - rowMeans(xmat, na.rm=TRUE)

    # NOTE(review): `rank(sd) >= length(sd) - topN` keeps up to topN + 1
    # rows, and rank() places NA standard deviations last, i.e. treats them
    # as the most variable -- confirm both are intended.
    xmat_rc_sd <- rowSds(xmat_rc, na.rm=TRUE)
    plotIdx <- rank(xmat_rc_sd) >= length(xmat_rc_sd) - topN
    xmat_rc_2_plot <- xmat_rc[plotIdx,]
    cormatS <- cor(xmat_rc_2_plot, use="pairwise", method="spearman")
        pred_mu <- as.matrix(predictedMeans(object, type=type)[
            expRowsMax & expRowsMedian,][plotIdx,])
            pred_mu <- qlogisWithCap(pred_mu)
        pred_mu[(snmat < minCount)[plotIdx,]] <- NA
        lpred_mu_rc <- pred_mu - rowMeans(pred_mu, na.rm=TRUE)
        xmat_rc_2_plot <- xmat_rc_2_plot - lpred_mu_rc
    cormat <- cor(xmat_rc_2_plot, use="pairwise", method="spearman")
    breaks <- seq(-1, 1, length.out=50)

    if(plotType == "junctionSample"){

            pred_mu <- as.matrix(predictedMeans(object, type=type)[
                expRowsMax & expRowsMedian,])
                pred_mu <- qlogisWithCap(pred_mu)
            lpred_mu_rc <- pred_mu - rowMeans(pred_mu)
            xmat_rc <- xmat_rc - lpred_mu_rc

        # NOTE(review): the 75% quantile threshold is hard-coded to 10
        # rather than minCount -- confirm.
        object <- object[expRowsMax & expRowsMedian,,by=type]
        j2keepVa <- variableJunctions(object, type, minDeltaPsi)
        j2keepDP <- rowQuantiles(kmat[expRowsMax & expRowsMedian,],
                                    probs=0.75) >= 10
        j2keep <- j2keepDP & j2keepVa
        xmat_rc_2_plot <- xmat_rc[j2keep,]
        mostVarKeep <- subsetKMostVariableJunctions(object[j2keep,,by=type],
                                                    type, topJ)
        xmat_rc_2_plot <- xmat_rc_2_plot[mostVarKeep,]
        rownames(xmat_rc_2_plot) <- seq_len(nrow(xmat_rc_2_plot))
        breaks <- seq(-5, 5, length.out=50)


        annotation_col <- getColDataAsDFFactors(object, annotation_col)
        annotation_row <- getColDataAsDFFactors(object, annotation_row)

    # annotate with sample clusters
        # annotate samples with clusters from sample correlation heatmap
        nClust <- min(nClust, nrow(cormatS))
        clusters <- as.factor(cutree(hclust(dist(cormatS)), k=nClust))
    # NOTE(review): if reached with a vector `sampleClustering` of length
    # > 1, `!is.na(...)` is not a scalar and `if` errors in R >= 4.2.
    } else if(!is.na(sampleClustering)){
        clusters <- sampleClustering

            annotation_col$sampleCluster <- clusters
        } else {
            annotation_col <- data.frame(sampleCluster=clusters)

    if(plotType == "junctionSample"){

        # annotate junctions with meanPsi and meanCoverage
        xmat <- xmat[j2keep,]
        xmat <- xmat[mostVarKeep,]
        meanPsi <- if(isTRUE(logit)){
        } else{
        meanPsiBins <- cut(meanPsi, breaks = c(0, 0.33, 0.66, 1),
                annotation_row$meanPsi <- meanPsiBins
            } else{
                annotation_row <- data.frame(meanPsi=meanPsiBins)

        # coverage bins on a log10 grid spanning the observed mean coverage
        snmat <- snmat[j2keep,]
        snmat <- snmat[mostVarKeep,]
        meanCoverage <- rowMeans(snmat)
        cutpoints <- sort(unique(round(log10(meanCoverage))))
        if(max(cutpoints) < ceiling(log10(max(meanCoverage)))){
            cutpoints <- c(cutpoints, ceiling(log10(max(meanCoverage))))
        meanCoverage <- cut(meanCoverage, breaks=10^(cutpoints),

            annotation_row$meanCoverage <- meanCoverage
        if(isTRUE(nrow(annotation_row) > 0)){
            rownames(annotation_row) <- rownames(xmat_rc_2_plot)
        cormat <- xmat_rc_2_plot

        main <- ifelse(normalized, "Normalized intron-centered ", 
                        "Raw intron-centered ")
        if(plotType == "sampleCorrelation"){
                main <- paste0(main, "Logit(PSI) correlation (", type, ")")
            } else {
                main <- paste0(main, "PSI correlation (", type, ")")
        } else {
                main <- paste0(main, "Logit(PSI) data (", type, ", top ", topJ, 
            } else {
                main <- paste0(main, "PSI data (", type, ", top ", topJ, ")")

    # diverging blue-white-red palette with 50 colours matching `breaks`
    pheatmap(cormat, show_rownames=show_rownames, show_colnames=show_colnames,
            main=main, annotation_col=annotation_col, breaks=breaks,
            annotation_row=annotation_row, ..., border_color=border_color,
            color=colorRampPalette(colors=rev(brewer.pal(11, "RdBu")))(50)

#' Plot count correlation
#' Count correlation heatmap function
#' @rdname plotFunctions
#' @export
setMethod("plotCountCorHeatmap", signature="FraserDataSet",
        # the call previously ended after `signature=`; register the FRASER
        # implementation like the other setMethod calls in this file
        plotCountCorHeatmap.FRASER)

#' helper function to get the annotation as data frame from the col data object
#' @noRd
getColDataAsDFFactors <- function(fds, names){
    # Extract the requested colData columns as a plain data.frame whose row
    # names are the sample names (rownames of colData(fds)). A column is
    # converted to a factor when it contains NA values (NA becomes the level
    # "NA") or when it is an integer column with at most 10 distinct values,
    # so that heatmap annotations treat it as categorical.
    #
    # fds:   object providing colData() (a FraserDataSet in this file).
    # names: character vector of colData column names.
    # Returns a data.frame with one column per entry of `names`.
    tmpDF <- data.frame(colData(fds)[,names])
    colnames(tmpDF) <- names
    for(i in names){
        if(any(is.na(tmpDF[, i]))){
            # paste0 turns NA into the string "NA" so it gets its own level
            tmpDF[,i] <- as.factor(paste0("", tmpDF[,i]))
        }
        if(is.integer(tmpDF[,i]) &&
                length(levels(as.factor(tmpDF[,i]))) <= 10){
            tmpDF[,i] <- as.factor(paste0("", tmpDF[,i]))
        }
    }
    rownames(tmpDF) <- rownames(colData(fds))
    tmpDF
}

#' used to cap the qlogis for the correlation heatmap
#' @noRd
qlogisWithCap <- function(x, digits=2){
    # Logit-transform a matrix of proportions for the correlation heatmap.
    # Values are rounded to `digits` and clamped to
    # [10^-digits, 1 - 10^-digits] so that qlogis() stays finite. Remaining
    # NA entries are replaced by the maximum of their row.
    #
    # x:      numeric matrix of proportions (NA allowed).
    # digits: rounding precision; also defines the cap.
    # Returns a matrix with the same dimensions as x.
    x <- round(x, digits)
    # pmin/pmax keep the dim attribute of their first argument
    x <- pmin(pmax(x, 10^-digits), 1-10^-digits)
    ans <- qlogis(x)
    ans[is.infinite(ans)] <- NA
    rowm <- rowMaxs(ans, na.rm=TRUE)
    idx <- which(is.na(ans), arr.ind=TRUE)
    ans[idx] <- rowm[idx[,"row"]]
    # return the transformed matrix (previously the function ended on the
    # replacement assignment and never returned `ans`)
    ans
}

#' Helper to get nice Splice metric labels in ggplot
#' @noRd
ggplotLabelPsi <- function(type, asCharacter=FALSE){
    # Map splice metric names ("psi5", "psi3", "theta") to axis/title labels.
    #
    # type:        character (or factor) vector of splice metric names.
    # asCharacter: FALSE (default) returns a named list of plotmath
    #              expressions (psi[5], psi[3], theta) for use in
    #              bquote()/ggplot labels -- take `[[1]]` for a single one;
    #              TRUE returns a named character vector ("psi[5]",
    #              "psi[3]", "theta") for plain-text labels.
    # Unknown metric names raise an informative error.
    type <- as.character(type)
    if(isTRUE(asCharacter)){
        vapply(type, FUN=function(x)
            switch (x,
                    psi5 = "psi[5]",
                    psi3 = "psi[3]",
                    theta = "theta",
                    stop("Unknown splice metric type: ", x)),
            FUN.VALUE=character(1))
    } else{
        # c() on a language object yields a length-1 list, so vapply
        # returns a list of expressions
        vapply(type, FUN=function(x)
            switch (x,
                    psi5 = c(bquote(psi[5])),
                    psi3 = c(bquote(psi[3])),
                    theta = c(bquote(theta)),
                    stop("Unknown splice metric type: ", x)),
            FUN.VALUE=c(bquote(psi[3])))
    }
}

