# findMarkersTree: decision-tree marker identification for single-cell clusters
#' @title Generate marker decision tree from single-cell clustering output
#' @description Create a decision tree that identifies gene markers for given
#' cell populations. The algorithm uses a decision tree procedure to generate
#' a set of rules for each cell cluster defined by single-cell clustering.
#' Splits are determined by one of two metrics at each split: a one-off metric
#' to determine rules for identifying clusters by a single feature, and a
#' balanced metric to determine rules for identifying sets of similar clusters.
#' @param x A numeric \link{matrix} of counts or a
#' \linkS4class{SingleCellExperiment}
#' with the matrix located in the assay slot under \code{useAssay}.
#' Rows represent features and columns represent cells.
#' @param useAssay A string specifying which \link{assay}
#' slot to use if \code{x} is a
#' \link[SingleCellExperiment]{SingleCellExperiment} object. Default "counts".
#' @param altExpName The name for the \link{altExp} slot
#' to use. Default "featureSubset".
#' @param class Vector of cell cluster labels.
#' @param oneoffMetric A character string. What one-off metric to run, either
#' \code{"modified F1"} or \code{"pairwise AUC"}. Default is "modified F1".
#' @param metaclusters List where each element is a metacluster (e.g. known
#' cell type) and all the clusters within that metacluster (e.g. subtypes).
#' @param featureLabels Vector of feature assignments, e.g. which cluster
#' does each gene belong to? Useful when using clusters of features
#' (e.g. gene modules or Seurat PCs) and user wishes to expand tree results
#' to individual features (e.g. score individual genes within marker gene
#' modules).
#' @param counts Numeric counts matrix. Useful when using clusters
#' of features (e.g. gene modules) and user wishes to expand tree results to
#' individual features (e.g. score individual genes within marker gene
#' modules). Row names should be individual feature names. Ignored if
#' \code{x} is a \linkS4class{SingleCellExperiment} object.
#' @param celda A \emph{celda_CG} or \emph{celda_C} object.
#' Counts matrix has to be provided as well.
#' @param seurat A seurat object. Note that the seurat functions
#' \emph{RunPCA} and \emph{FindClusters} must have been run on the object.
#' @param threshold Numeric between 0 and 1. The threshold for the oneoff
#' metric. Smaller values will result in more one-off splits. Default is 0.90.
#' @param reuseFeatures Logical. Whether or not a feature can be used more than
#' once on the same cluster. Default is FALSE.
#' @param altSplit Logical. Whether or not to force a marker for clusters that
#' are solely defined by the absence of markers. Default is TRUE.
#' @param consecutiveOneoff Logical. Whether or not to allow one-off splits at
#' consecutive branches. Default is FALSE.
#' @param autoMetaclusters Logical. Whether to identify metaclusters prior to
#' creating the tree based on the distance between clusters in a UMAP
#' dimensionality reduction projection. A metacluster is simply a large
#' cluster that includes several clusters within it. Default is TRUE.
#' @param seed Numeric. Seed used to enable reproducible UMAP results
#' for identifying metaclusters. Default is 12345.
#' @param ... Ignored. Placeholder to prevent check warning.
#' @return A named list with six elements:
#' \itemize{
#' \item rules - A named list with one data frame for every label. Each
#' data frame has five columns and gives the set of rules for distinguishing
#' each label.
#' \itemize{
#' \item feature - Marker feature, e.g. marker gene name.
#' \item direction - Relationship to feature value. -1 if cluster is
#' down-regulated for this feature, 1 if cluster is up-regulated.
#' \item stat - The performance value returned by the splitting metric for
#' this split.
#' \item statUsed - Which performance metric was used. "Split" if information
#' gain and "One-off" if one-off.
#' \item level - The level of the tree at which this rule was defined. 1 is the
#' level of the first split of the tree.
#' \item metacluster - Optional. If metaclusters were used, the metacluster
#' this rule is applied to.
#' }
#' \item dendro - A dendrogram object of the decision tree output. Plot with
#' plotMarkerDendro()
#' \item classLabels - A vector of the class labels used in the model, i.e.
#' cell cluster labels.
#' \item metaclusterLabels - A vector of the metacluster labels
#' used in the model
#' \item prediction - A character vector of label of predictions of the
#' training data using the final model. "MISSING" if label prediction was
#' ambiguous.
#' \item performance - A named list denoting the training performance of the
#' model:
#' \itemize{
#' \item accuracy - (number correct/number of samples) for the whole set of
#' samples.
#' \item balAcc - mean sensitivity across all clusters
#' \item meanPrecision - mean precision across all clusters
#' \item correct - the number of correct predictions of each cluster
#' \item sizes - the number of actual counts of each cluster
#' \item sensitivity - the sensitivity of the prediction of each cluster
#' \item precision - the precision of the prediction of each cluster
#' }
#' }
#' @examples
#' \dontrun{
#' # Generate simulated single-cell dataset using celda
#' sim_counts <- simulateCells("celda_CG", K = 4, L = 10, G = 100)
#'
#' # Celda clustering into 5 clusters & 10 modules
#' cm <- celda_CG(sim_counts, K = 5, L = 10, verbose = FALSE)
#'
#' # Get features matrix and cluster assignments
#' factorized <- factorizeMatrix(cm)
#' features <- factorized$proportions$cell
#' class <- celdaClusters(cm)
#'
#' # Generate Decision Tree
#' DecTree <- findMarkersTree(features, class)
#'
#' # Plot dendrogram
#' plotMarkerDendro(DecTree)
#' }
#' @export
# Register the S4 generic; dispatch is on the class of `x`
# (matrix or SingleCellExperiment methods are defined below).
setGeneric(
    "findMarkersTree",
    function(x, ...) {
        standardGeneric("findMarkersTree")
    }
)
#' @rdname findMarkersTree
#' @export
setMethod("findMarkersTree",
    signature(x = "SingleCellExperiment"),
    function(x,
             useAssay = "counts",
             altExpName = "featureSubset",
             class,
             oneoffMetric = c("modified F1", "pairwise AUC"),
             metaclusters,
             featureLabels,
             counts,
             seurat,
             threshold = 0.90,
             reuseFeatures = FALSE,
             altSplit = TRUE,
             consecutiveOneoff = FALSE,
             autoMetaclusters = TRUE,
             seed = 12345) {
        # Fetch the altExp slot where celda stores its results
        altExp <- SingleCellExperiment::altExp(x, altExpName)
        if ("celda_parameters" %in% names(S4Vectors::metadata(altExp))) {
            # celda was run on this object: derive counts, module
            # proportions, cluster labels, and module feature labels
            counts <- SummarizedExperiment::assay(altExp, i = useAssay)
            # factorize matrix (proportion of each module in each cell)
            features <- factorizeMatrix(x,
                useAssay = useAssay,
                altExpName = altExpName)$proportions$cell
            # get class labels
            class <- celdaClusters(x, altExpName = altExpName)
            # get feature labels ("L" prefix marks celda modules)
            featureLabels <- paste0("L",
                celdaModules(x, altExpName = altExpName))
        } else if (methods::hasArg(seurat)) {
            # get counts matrix from seurat object
            counts <- as.matrix(seurat@assays$RNA@data)
            # get class labels from Seurat cluster identities
            class <- as.character(Seurat::Idents(seurat))
            # get feature labels: assign each gene to the PC on which it
            # has its maximum loading
            featureLabels <-
                unlist(apply(
                    seurat@reductions$pca@feature.loadings, 1,
                    function(x) {
                        return(names(x)[which(x == max(x))])
                    }
                ))
            # sum counts for each PC in each cell
            features <-
                matrix(
                    unlist(lapply(unique(featureLabels), function(pc) {
                        colSums(counts[featureLabels == pc, ])
                    })),
                    ncol = length(class),
                    byrow = TRUE,
                    dimnames = list(unique(featureLabels), colnames(counts))
                )
            # normalize column-wise (i.e. convert counts to proportions)
            features <- apply(features, 2, function(x) {
                x / sum(x)
            })
        }
        # NOTE(review): if neither branch above ran, `features` is
        # undefined and the check below errors — confirm intended.
        if (ncol(features) != length(class)) {
            stop("Number of columns of features must equal length of class")
        }
        if (any(is.na(class))) {
            stop("NA class values")
        }
        if (any(is.na(features))) {
            stop("NA feature values")
        }
        # Match the oneoffMetric argument against the allowed choices
        oneoffMetric <- match.arg(oneoffMetric)
        # Delegate to the internal implementation; optional arguments that
        # were not supplied stay "missing" and are probed there via hasArg
        branchPoints <- .findMarkersTree(features = features,
            class = class,
            oneoffMetric = oneoffMetric,
            metaclusters = metaclusters,
            featureLabels = featureLabels,
            counts = counts,
            seurat = seurat,
            threshold = threshold,
            reuseFeatures = reuseFeatures,
            altSplit = altSplit,
            consecutiveOneoff = consecutiveOneoff,
            autoMetaclusters = autoMetaclusters,
            seed = seed)
        return(branchPoints)
    }
)
#' @rdname findMarkersTree
#' @export
setMethod("findMarkersTree",
    signature(x = "matrix"),
    function(x,
             class,
             oneoffMetric = c("modified F1", "pairwise AUC"),
             metaclusters,
             featureLabels,
             counts,
             celda,
             seurat,
             threshold = 0.90,
             reuseFeatures = FALSE,
             altSplit = TRUE,
             consecutiveOneoff = FALSE,
             autoMetaclusters = TRUE,
             seed = 12345) {
        # By default the supplied matrix is used directly as the
        # feature matrix (features x cells)
        features <- x
        if (methods::hasArg(celda)) {
            # check that counts matrix is provided
            if (!methods::hasArg(counts)) {
                stop("Please provide counts matrix in addition to",
                    " celda object.")
            }
            # factorize matrix (proportion of each module in each cell)
            features <- factorizeMatrix(counts, celda)$proportions$cell
            # get class labels (celda cell clusters)
            class <- celdaClusters(celda)$z
            # get feature labels (celda module assignments, "L" prefix)
            featureLabels <- paste0("L", celdaClusters(celda)$y)
        } else if (methods::hasArg(seurat)) {
            # get counts matrix from seurat object
            counts <- as.matrix(seurat@assays$RNA@data)
            # get class labels from Seurat cluster identities
            class <- as.character(Seurat::Idents(seurat))
            # get feature labels: assign each gene to the PC on which it
            # has its maximum loading
            featureLabels <-
                unlist(apply(
                    seurat@reductions$pca@feature.loadings, 1,
                    function(x) {
                        return(names(x)[which(x == max(x))])
                    }
                ))
            # sum counts for each PC in each cell
            features <-
                matrix(
                    unlist(lapply(unique(featureLabels), function(pc) {
                        colSums(counts[featureLabels == pc, ])
                    })),
                    ncol = length(class),
                    byrow = TRUE,
                    dimnames = list(unique(featureLabels), colnames(counts))
                )
            # normalize column-wise (i.e. convert counts to proportions)
            features <- apply(features, 2, function(x) {
                x / sum(x)
            })
        }
        # Basic input validation
        if (ncol(features) != length(class)) {
            stop("Number of columns of features must equal length of class")
        }
        if (any(is.na(class))) {
            stop("NA class values")
        }
        if (any(is.na(features))) {
            stop("NA feature values")
        }
        # Match the oneoffMetric argument against the allowed choices
        oneoffMetric <- match.arg(oneoffMetric)
        # Delegate to the internal implementation; optional arguments that
        # were not supplied stay "missing" and are probed there via hasArg
        branchPoints <- .findMarkersTree(features = features,
            class = class,
            oneoffMetric = oneoffMetric,
            metaclusters = metaclusters,
            featureLabels = featureLabels,
            counts = counts,
            seurat = seurat,
            threshold = threshold,
            reuseFeatures = reuseFeatures,
            altSplit = altSplit,
            consecutiveOneoff = consecutiveOneoff,
            autoMetaclusters = autoMetaclusters,
            seed = seed)
        return(branchPoints)
    }
)
# Internal workhorse behind findMarkersTree().
#
# Builds the marker decision tree for the given feature matrix and class
# labels. Two paths exist:
#   1. No metaclusters (not supplied and autoMetaclusters FALSE): build a
#      single tree over all classes and summarize it.
#   2. Metaclusters supplied or auto-identified (UMAP + dbscan): build a
#      top-level tree across metaclusters, then one sub-tree per
#      metacluster containing more than one cluster, and graft the
#      sub-trees into the top-level dendrogram.
#
# @param features Numeric matrix, features x cells on entry (transposed
#   to cells x features immediately below).
# @param class Vector of cell cluster labels, one per cell.
# @param oneoffMetric One-off split metric; validated by the callers.
# @param metaclusters Optional named list: metacluster -> member clusters.
# @param featureLabels Optional per-gene feature/module assignments.
# @param counts Optional raw counts matrix (genes x cells) used to expand
#   tree results to individual genes.
# @param seurat Optional Seurat object (supplies UMAP embedding).
# @param threshold Numeric threshold for the one-off metric.
# @param reuseFeatures Logical; may a feature be reused within a branch?
# @param altSplit Logical; add a marker for down-regulated-only leaves.
# @param consecutiveOneoff Logical; must be FALSE when metaclusters used.
# @param autoMetaclusters Logical; discover metaclusters automatically.
# @param seed Seed for reproducible UMAP, or NULL to skip seeding.
# @return A list with rules, dendro, class/metacluster labels, prediction
#   and performance elements (see the exported generic's documentation).
.findMarkersTree <- function(features,
                             class,
                             oneoffMetric,
                             metaclusters,
                             featureLabels,
                             counts,
                             seurat,
                             threshold,
                             reuseFeatures,
                             altSplit,
                             consecutiveOneoff,
                             autoMetaclusters,
                             seed) {
    # Transpose features (cells become rows)
    features <- t(features)
    # If no detailed cell types are provided or to be identified
    if (!methods::hasArg(metaclusters) & (!autoMetaclusters)) {
        message("Building tree...")
        # Set class to factor
        class <- as.factor(class)
        # Generate list of tree levels
        tree <- .generateTreeList(
            features,
            class,
            oneoffMetric,
            threshold,
            reuseFeatures,
            consecutiveOneoff
        )
        # Add alternative node for the solely down-regulated leaf
        if (altSplit) {
            tree <- .addAlternativeSplit(tree, features, class)
        }
        message("Computing performance metrics...")
        # Format tree output for plotting and generate summary statistics
        DTsummary <- .summarizeTree(tree, features, class)
        # Remove confusing 'value' column
        DTsummary$rules <- lapply(DTsummary$rules, function(x) {
            x["value"] <- NULL
            x
        })
        # Add column to each rules table which specifies its class
        DTsummary$rules <- mapply(cbind,
            "class" = as.character(names(DTsummary$rules)),
            DTsummary$rules,
            SIMPLIFY = FALSE
        )
        # Generate table for each branch point in the tree
        DTsummary$branchPoints <-
            .createBranchPoints(DTsummary$rules)
        # Add class labels to output
        DTsummary$classLabels <- class
        return(DTsummary)
    } else {
        # If metaclusters are provided or to be identified
        # consecutive one-offs break the code(tricky to find 1st balanced split)
        if (consecutiveOneoff) {
            stop(
                "Cannot use metaclusters if consecutive one-offs are allowed.",
                " Please set the consecutiveOneoff parameter to FALSE."
            )
        }
        # Check if need to identify metaclusters
        if (autoMetaclusters & !methods::hasArg(metaclusters)) {
            message("Identifying metaclusters...")
            # if seurat object then use seurat's UMAP parameters
            if (methods::hasArg(seurat)) {
                suppressMessages(seurat <-
                    Seurat::RunUMAP(
                        seurat,
                        dims = seq(ncol(seurat@reductions$pca@feature.loadings))
                    ))
                umap <- seurat@reductions$umap@cell.embeddings
            } else {
                # Run UMAP on sqrt-transformed features; seed the RNG only
                # when a seed was supplied
                if (is.null(seed)) {
                    umap <- uwot::umap(
                        t(sqrt(t(features))),
                        n_neighbors = 15,
                        min_dist = 0.01,
                        spread = 1,
                        n_sgd_threads = 1
                    )
                } else {
                    withr::with_seed(
                        seed,
                        umap <- uwot::umap(
                            t(sqrt(t(features))),
                            n_neighbors = 15,
                            min_dist = 0.01,
                            spread = 1,
                            n_sgd_threads = 1
                        )
                    )
                }
            }
            # dbscan to find metaclusters
            dbscan <- dbscan::dbscan(umap, eps = 1)
            # place each population in the correct metacluster
            mapping <-
                unlist(lapply(
                    sort(as.integer(
                        unique(class)
                    )),
                    function(population) {
                        # get indexes of occurences of this population
                        indexes <-
                            which(class == population)
                        # get corresponding metaclusters
                        metaIndices <-
                            dbscan$cluster[indexes]
                        # return corresponding metacluster with majority vote
                        return(names(sort(table(
                            metaIndices
                        ), decreasing = TRUE)[1]))
                    }
                ))
            # create list which will contain subtypes of each metacluster
            metaclusters <- vector(mode = "list")
            # fill in list of populations for each metacluster
            for (i in unique(mapping)) {
                metaclusters[[i]] <-
                    sort(as.integer(unique(class)))[which(mapping == i)]
            }
            names(metaclusters) <- paste0("M", unique(mapping))
            message(paste("Identified", length(metaclusters), "metaclusters"))
        }
        # Check that cell types match class labels
        if (mean(unlist(metaclusters) %in% unique(class)) != 1) {
            stop(
                "Provided cell types do not match class labels. ",
                "Please check the 'metaclusters' argument."
            )
        }
        # Create vector with metacluster labels (one per cell)
        metaclusterLabels <- class
        for (i in names(metaclusters)) {
            metaclusterLabels[metaclusterLabels %in% metaclusters[[i]]] <- i
        }
        # Rename metaclusters with just one cluster as "meta(cluster)"
        oneCluster <-
            names(metaclusters[lengths(metaclusters) == 1])
        if (length(oneCluster) > 0) {
            oneClusterIndices <- which(metaclusterLabels %in% oneCluster)
            metaclusterLabels[oneClusterIndices] <-
                paste0(
                    metaclusterLabels[oneClusterIndices], "(",
                    class[oneClusterIndices], ")"
                )
            # NOTE(review): assigning names to a `[`-subset of a list may
            # not propagate back into `metaclusters` — verify this rename
            # actually takes effect.
            names(metaclusters[lengths(metaclusters) == 1]) <-
                paste0(
                    names(metaclusters[lengths(metaclusters) == 1]), "(",
                    unlist(metaclusters[lengths(metaclusters) == 1]), ")"
                )
        }
        # create temporary variables for top-level tree
        tmpThreshold <- threshold
        # create list to store split off classes at each threshold
        markerThreshold <- list()
        # Create top-level tree
        # while there is still a balanced split at the top-level,
        # progressively lower the threshold until every metacluster is
        # separable by one-off markers alone
        while (TRUE) {
            # create tree
            message("Building top-level tree across all metaclusters...")
            tree <-
                .generateTreeList(
                    features,
                    as.factor(metaclusterLabels),
                    oneoffMetric,
                    tmpThreshold,
                    reuseFeatures,
                    consecutiveOneoff
                )
            # Add alternative node for the solely down-regulated leaf
            tree <- .addAlternativeSplit(
                tree, features,
                as.factor(metaclusterLabels)
            )
            # store clusters with markers at current threshold
            topLevel <- tree[[1]][[1]]
            if (topLevel$statUsed == "One-off") {
                # the last 3 elements are bookkeeping (fUsed/dirs/statUsed),
                # the rest are per-marker split entries
                markerThreshold[[as.character(tmpThreshold)]] <-
                    unlist(lapply(
                        topLevel[seq(length(topLevel) - 3)],
                        function(marker) {
                            return(marker$group1Consensus)
                        }
                    ))
            }
            # if no more balanced split
            if (length(tree) == 1) {
                # if all clusters have positive markers
                if (length(tree[[1]][[1]]) == (length(metaclusters) + 3)) {
                    break
                } else {
                    # decrease threshold by 10%
                    tmpThreshold <- tmpThreshold * 0.9
                    message("Decreasing classifier threshold to ", tmpThreshold)
                    next
                }
            } else {
                # still balanced split:
                # get up-regulated clusters at first balanced split
                upClass <- tree[[2]][[1]][[1]]$group1Consensus
                # if only 2 clusters at the balanced split then merge them
                if ((length(upClass) == 1) &&
                    (length(tree[[2]][[1]][[1]]$group2Consensus) == 1)) {
                    upClass <- c(upClass, tree[[2]][[1]][[1]]$group2Consensus)
                }
                # update metacluster label of each cell: merge the
                # up-regulated metaclusters into one "A+B" metacluster
                tmpMeta <- metaclusterLabels
                tmpMeta[tmpMeta %in% upClass] <-
                    paste(upClass, sep = "", collapse = "+")
                # create top-level tree again
                tmpTree <-
                    .generateTreeList(
                        features,
                        as.factor(tmpMeta),
                        oneoffMetric,
                        tmpThreshold,
                        reuseFeatures,
                        consecutiveOneoff
                    )
                # Add alternative node for the solely down-regulated leaf
                tmpTree <- .addAlternativeSplit(
                    tmpTree, features,
                    as.factor(tmpMeta)
                )
                # if new tree still has balanced split/no markers for some
                if ((length(tmpTree) > 1) ||
                    (length(tree[[1]][[1]]) != (length(metaclusters) + 3))) {
                    # decrease threshold by 10%
                    tmpThreshold <- tmpThreshold * 0.9
                    message("Decreasing classifier threshold to ", tmpThreshold)
                } else {
                    # set final metacluster labels to new set of clusters
                    metaclusterLabels <- tmpMeta
                    # set final tree to current tree
                    tree <- tmpTree
                    ## update 'metaclusters' (list of metaclusters)
                    # get celda clusters in these metaclusters
                    newMetacluster <- unlist(metaclusters[upClass])
                    # remove old metaclusters
                    metaclusters[upClass] <- NULL
                    # add new metacluster to list of metaclusters
                    metaclusters[paste(upClass, sep = "", collapse = "+")] <-
                        list(unname(newMetacluster))
                    break
                }
            }
        }
        # re-format output: keep raw tree, expose rules table
        finalTree <- tree
        tree <- list(rules = .mapClass2features(
            finalTree,
            features,
            as.factor(metaclusterLabels),
            topLevelMeta = TRUE
        )$rules)
        # keep markers at first threshold they reached only
        markersToRemove <- c()
        for (thresh in names(markerThreshold)) {
            thresholdClasses <- markerThreshold[[thresh]]
            for (cl in thresholdClasses) {
                curRules <- tree$rules[[cl]]
                # up-regulated markers whose stat is below the threshold
                # at which this class first split off
                lowMarkerIndices <- which(curRules$direction == 1 &
                    curRules$stat < as.numeric(thresh))
                # only drop them if the class retains another up marker
                if (length(lowMarkerIndices) > 0 &
                    length(which(curRules$direction == 1)) > 1) {
                    markersToRemove <- c(
                        markersToRemove,
                        curRules[lowMarkerIndices, "feature"]
                    )
                }
            }
        }
        tree$rules <- lapply(tree$rules, function(rules) {
            return(rules[!rules$feature %in% markersToRemove, ])
        })
        # store final set of top-level markers (";"-joined per cluster)
        topLevelMarkers <-
            unlist(lapply(tree$rules, function(cluster) {
                markers <- cluster[cluster$direction == 1, "feature"]
                return(paste(markers, collapse = ";"))
            }))
        # create tree dendrogram
        tree$dendro <-
            .convertToDendrogram(finalTree, as.factor(metaclusterLabels),
                splitNames = topLevelMarkers
            )
        # add metacluster label to rules table
        for (metacluster in names(tree$rules)) {
            tree$rules[[metacluster]]$metacluster <- metacluster
        }
        # Store tree's dendrogram in a separate variable
        dendro <- tree$dendro
        # Find which metaclusters have more than one cluster
        largeMetaclusters <-
            names(metaclusters[lengths(metaclusters) > 1])
        # Update subtype labels for large metaclusters: "meta(cluster)"
        subtypeLabels <- metaclusterLabels
        subtypeLabels[subtypeLabels %in% largeMetaclusters] <-
            paste0(
                subtypeLabels[subtypeLabels %in% largeMetaclusters],
                "(",
                class[subtypeLabels %in% largeMetaclusters],
                ")"
            )
        # Update metaclusters list to use the same "meta(cluster)" names
        for (metacluster in names(metaclusters)) {
            subtypes <- metaclusters[metacluster]
            subtypes <- lapply(subtypes, function(subtype) {
                paste0(metacluster, "(", subtype, ")")
            })
            metaclusters[metacluster] <- subtypes
        }
        # Create separate trees for each cell type with more than one cluster
        newTrees <- lapply(largeMetaclusters, function(metacluster) {
            # Print current status
            message("Building tree for metacluster ", metacluster)
            # Remove used features (unless reuse is allowed)
            featUse <- colnames(features)
            if (!reuseFeatures) {
                tmpRules <- tree$rules[[metacluster]]
                featUse <-
                    featUse[!featUse %in%
                        tmpRules[tmpRules$direction == 1, "feature"]]
            }
            # Create new tree restricted to this metacluster's cells
            newTree <-
                .generateTreeList(
                    features[metaclusterLabels == metacluster, featUse],
                    as.factor(subtypeLabels[metaclusterLabels == metacluster]),
                    oneoffMetric,
                    threshold,
                    reuseFeatures,
                    consecutiveOneoff
                )
            # Add alternative node for the solely down-regulated leaf
            if (altSplit) {
                newTree <-
                    .addAlternativeSplit(
                        newTree,
                        features[metaclusterLabels == metacluster, featUse],
                        as.factor(subtypeLabels[metaclusterLabels == metacluster])
                    )
            }
            newTree <- list(
                rules = .mapClass2features(
                    newTree,
                    features[metaclusterLabels
                        == metacluster, ],
                    as.factor(subtypeLabels[metaclusterLabels == metacluster])
                )$rules,
                dendro = .convertToDendrogram(
                    newTree,
                    as.factor(subtypeLabels[metaclusterLabels ==
                        metacluster])
                )
            )
            # Adjust 'rules' table for new tree: offset levels past the
            # parent metacluster's rules and prepend those rules
            newTree$rules <- lapply(newTree$rules, function(rules) {
                rules$level <- rules$level +
                    max(tree$rules[[metacluster]]$level)
                rules$metacluster <- metacluster
                rules <- rbind(tree$rules[[metacluster]], rules)
            })
            return(newTree)
        })
        names(newTrees) <- largeMetaclusters
        # Fix max depth in original tree so grafted sub-trees fit below it
        if (length(newTrees) > 0) {
            maxDepth <- max(unlist(lapply(newTrees, function(newTree) {
                lapply(newTree$rules, function(ruleDF) {
                    ruleDF$level
                })
            })))
            addDepth <- maxDepth - attributes(dendro)$height
            dendro <- stats::dendrapply(dendro, function(node, addDepth) {
                if (attributes(node)$height > 1) {
                    attributes(node)$height <- attributes(node)$height +
                        addDepth + 1
                }
                return(node)
            }, addDepth)
        }
        # Find indices of cell type nodes in tree. Each index is a string
        # of "[[i]]" selections usable with eval(parse()) on `dendro`.
        indices <- lapply(
            largeMetaclusters,
            function(metacluster) {
                # Initialize sub trees, indices string, and flag
                dendSub <- dendro
                index <- ""
                flag <- TRUE
                while (flag) {
                    # Get the edge with the class of interest
                    whEdge <- which(unlist(
                        lapply(
                            dendSub,
                            function(edge) {
                                metacluster %in%
                                    attributes(edge)$classLabels
                            }
                        )
                    ))
                    # Add this as a string
                    index <-
                        paste0(index, "[[", whEdge, "]]")
                    # Move to this branch
                    dendSub <-
                        eval(parse(text = paste0("dendro", index)))
                    # Is this the only class in that branch
                    flag <- length(attributes(dendSub)$classLabels) > 1
                }
                return(index)
            }
        )
        names(indices) <- largeMetaclusters
        # Add each cell type tree
        for (metacluster in largeMetaclusters) {
            # Get current tree
            metaclusterDendro <- newTrees[[metacluster]]$dendro
            # Adjust labels, member count, and midpoint of nodes
            dendro <- stats::dendrapply(dendro, function(node) {
                # Check if in right branch
                if (metacluster %in%
                    as.character(attributes(node)$classLabels)) {
                    # Replace cell type label with subtype labels
                    # NOTE(review): grep() treats `metacluster` as a regex;
                    # names with metacharacters could over-match — verify.
                    labels <- attributes(node)$classLabels
                    labels <- as.character(labels)
                    labels <- labels[labels != metacluster]
                    labels <- c(labels, unique(subtypeLabels)
                        [grep(metacluster, unique(subtypeLabels))])
                    attributes(node)$classLabels <- labels
                    # Assign new member count for this branch
                    attributes(node)$members <-
                        length(attributes(node)$classLabels)
                    # Assign new midpoint for this branch
                    attributes(node)$midpoint <-
                        (attributes(node)$members - 1) / 2
                }
                return(node)
            })
            # Replace label at new tree's branch point
            branchPointAttr <- attributes(eval(parse(text = paste0(
                "dendro", indices[[metacluster]]
            ))))
            branchPointLabel <- branchPointAttr$label
            branchPointStatUsed <- branchPointAttr$statUsed
            if (!is.null(branchPointLabel)) {
                attributes(metaclusterDendro)$label <- branchPointLabel
                attributes(metaclusterDendro)$statUsed <-
                    branchPointStatUsed
            }
            # Fix height: shift sub-tree so it hangs just below its parent
            indLoc <-
                gregexpr("\\[\\[", indices[[metacluster]])[[1]]
            indLoc <- indLoc[length(indLoc)]
            parentIndexString <- substr(
                indices[[metacluster]],
                0,
                indLoc - 1
            )
            parentHeight <- attributes(eval(parse(
                text = paste0("dendro", parentIndexString)
            )))$height
            metaclusterHeight <-
                attributes(metaclusterDendro)$height
            metaclusterDendro <- stats::dendrapply(
                metaclusterDendro,
                function(node,
                         parentHeight,
                         metaclusterHeight) {
                    if (attributes(node)$height > 1) {
                        attributes(node)$height <-
                            parentHeight - 1 -
                            (metaclusterHeight -
                                attributes(node)$height)
                    }
                    return(node)
                }, parentHeight, metaclusterHeight
            )
            # Add new tree to original tree (graft at the stored index)
            eval(parse(text = paste0(
                "dendro", indices[[metacluster]], " <- metaclusterDendro"
            )))
            # Append new tree's 'rules' tables to original tree
            tree$rules <-
                append(tree$rules,
                    newTrees[[metacluster]]$rules,
                    after = which(names(tree$rules) == metacluster)
                )
            # Remove old tree's rules
            tree$rules <-
                tree$rules[-which(names(tree$rules) == metacluster)]
        }
        # Set final tree dendro
        tree$dendro <- dendro
        # Get performance statistics
        message("Computing performance statistics...")
        perfList <- .getPerformance(
            tree$rules,
            features,
            as.factor(subtypeLabels)
        )
        tree$prediction <- perfList$prediction
        tree$performance <- perfList$performance
        # Remove confusing 'value' column
        tree$rules <-
            lapply(tree$rules, function(x) {
                x["value"] <- NULL
                x
            })
        # add column to each rules table which specifies its class
        tree$rules <-
            mapply(cbind,
                "class" = as.character(names(tree$rules)),
                tree$rules,
                SIMPLIFY = FALSE
            )
        # create branch points table
        branchPoints <-
            .createBranchPoints(tree$rules, largeMetaclusters, metaclusters)
        # collapse all rules tables into one large table
        collapsed <- do.call("rbind", tree$rules)
        # get top-level rules
        topLevelRules <- collapsed[collapsed$level == 1, ]
        # add 'class' column
        topLevelRules$class <- topLevelRules$metacluster
        # add to branch point list
        branchPoints[["top_level"]] <- topLevelRules
        # check if need to expand features to gene-level
        if (methods::hasArg(featureLabels) &&
            methods::hasArg(counts)) {
            message("Computing scores for individual genes...")
            # make sure feature labels match those in the tree
            if (!all(unique(collapsed$feature) %in% unique(featureLabels))) {
                m <- "Provided feature labels don't match those in count matrix."
                stop(m)
            }
            # iterate over branch points
            branchPoints <- lapply(branchPoints, function(branch) {
                # iterate over unique features, computing per-gene AUCs
                featAUC <-
                    lapply(
                        unique(branch$feature),
                        .getGeneAUC,
                        branch,
                        subtypeLabels,
                        metaclusterLabels,
                        featureLabels,
                        counts
                    )
                # update branch table after merging genes data
                return(do.call("rbind", featAUC))
            })
            # simplify top-level in rules tables to only up-regulated markers
            tree$rules <- lapply(tree$rules, function(rule) {
                return(rule[-intersect(
                    which(rule$level == 1),
                    which(rule$direction == (-1))
                ), ])
            })
            ## add gene-level info to rules tables
            # collapse branch points tables into one
            collapsedBranches <- do.call("rbind", branchPoints)
            collapsedBranches$class <-
                as.character(collapsedBranches$class)
            # loop over rules tables and get relevant info
            tree$rules <- lapply(tree$rules, function(class) {
                # initialize table to return
                toReturn <- data.frame(NULL)
                # loop over rows of this class
                for (i in seq(nrow(class))) {
                    # extract relevant genes from branch points tables
                    genesAUC <- collapsedBranches[collapsedBranches$feature ==
                        class$feature[i] &
                        collapsedBranches$level == class$level[i] &
                        collapsedBranches$class == class$class[i], ]
                    # don't forget top-level (matched on metacluster instead)
                    if (class$level[i] == 1) {
                        genesAUC <- collapsedBranches[collapsedBranches$feature ==
                            class$feature[i] &
                            collapsedBranches$level == class$level[i] &
                            collapsedBranches$class == class$metacluster[i], ]
                    }
                    # merge table
                    toReturn <- rbind(toReturn, genesAUC)
                }
                return(toReturn)
            })
            # remove table row names
            tree$rules <- lapply(tree$rules, function(t) {
                rownames(t) <- NULL
                return(t)
            })
            # add feature labels to output
            tree$featureLabels <- featureLabels
        }
        # simplify top-level branch point to save memory
        branchPoints$top_level <-
            branchPoints$top_level[branchPoints$top_level$direction == 1, ]
        branchPoints$top_level <-
            branchPoints$top_level[!duplicated(branchPoints$top_level), ]
        # remove branch points row names
        branchPoints <- lapply(branchPoints, function(br) {
            rownames(br) <- NULL
            return(br)
        })
        # adjust subtype labels: strip the "meta(...)" wrapper, keeping
        # the inner cluster for `class` and the outer name for metacluster
        branchPoints <- lapply(branchPoints, function(br) {
            br$class <- as.character(br$class)
            br$class[grepl("\\(.*\\)", br$class)] <- regmatches(
                br$class[grepl("\\(.*\\)", br$class)],
                regexpr(
                    pattern = "(?<=\\().*?(?=\\)$)",
                    br$class[grepl("\\(.*\\)", br$class)],
                    perl = TRUE
                )
            )
            br$metacluster <- as.character(br$metacluster)
            br$metacluster[grepl("\\(.*\\)", br$metacluster)] <-
                gsub(
                    "\\(.*\\)", "",
                    br$metacluster[grepl("\\(.*\\)", br$metacluster)]
                )
            return(br)
        })
        # adjust subtype labels in the rules tables the same way
        tree$rules <-
            suppressWarnings(lapply(tree$rules, function(r) {
                r$class <- as.character(r$class)
                r$class[grepl("\\(.*\\)", r$class)] <- regmatches(
                    r$class[grepl("\\(.*\\)", r$class)],
                    regexpr(
                        pattern = "(?<=\\().*?(?=\\)$)",
                        r$class[grepl("\\(.*\\)", r$class)],
                        perl = TRUE
                    )
                )
                r$metacluster[grepl("\\(.*\\)", r$metacluster)] <-
                    gsub(
                        "\\(.*\\)", "",
                        r$metacluster[grepl("\\(.*\\)", r$metacluster)]
                    )
                return(r)
            }))
        # add to tree
        tree$branchPoints <- branchPoints
        # return class labels (inner cluster extracted from "meta(cluster)")
        tree$classLabels <- regmatches(
            subtypeLabels,
            regexpr(
                pattern = "(?<=\\().*?(?=\\)$)",
                subtypeLabels, perl = TRUE
            )
        )
        tree$metaclusterLabels <- metaclusterLabels
        tree$metaclusterLabels[grepl("\\(.*\\)", metaclusterLabels)] <-
            gsub(
                "\\(.*\\)", "",
                metaclusterLabels[grepl("\\(.*\\)", metaclusterLabels)]
            )
        # Final return
        return(tree)
    }
}
# helper function to create table for each branch point in the tree
# Helper: build one table per branch point in the tree.
#
# Starting from the per-class rules tables, groups rules by tree level
# (and, when metaclusters were used, by metacluster), then splits each
# level into its individual branch points: one table per balanced split
# (identified by a shared `stat` value) and one per one-off class group.
#
# @param rules Named list of per-class rules data frames.
# @param largeMetaclusters Optional character vector of metaclusters with
#   more than one cluster (only used when metaclusters was supplied).
# @param metaclusters Optional named list: metacluster -> subtype names.
# @return Named list of branch-point data frames.
.createBranchPoints <-
    function(rules, largeMetaclusters, metaclusters) {
        # First step differs if metaclusters were used
        if (methods::hasArg(metaclusters) &&
            (length(largeMetaclusters) > 0)) {
            # iterate over metaclusters and add the rules for each level
            branchPoints <-
                lapply(largeMetaclusters, function(metacluster) {
                    # get names of subtypes
                    subtypes <- metaclusters[[metacluster]]
                    # collapse rules tables of subtypes
                    subtypeRules <- do.call("rbind", rules[subtypes])
                    # get rules at each level (level 1 is the top-level
                    # tree, so start at level 2)
                    levels <-
                        lapply(seq(2, max(subtypeRules$level)), function(level) {
                            return(subtypeRules[subtypeRules$level == level, ])
                        })
                    names(levels) <- paste0(
                        metacluster, "_level_",
                        seq(max(subtypeRules$level) - 1)
                    )
                    return(levels)
                })
            branchPoints <- unlist(branchPoints, recursive = FALSE)
        } else {
            # collapse all rules into one table
            collapsed <- do.call("rbind", rules)
            # subset rules at each level
            branchPoints <-
                lapply(seq(max(collapsed$level)), function(level) {
                    return(collapsed[collapsed$level == level, ])
                })
            names(branchPoints) <-
                paste0("level_", seq(max(collapsed$level)))
        }
        # split each level into its branch points
        branchPoints <- lapply(branchPoints, function(level) {
            # check if need to split: if the first rule's feature/stat
            # already covers every class at this level, keep as one table
            firstFeat <- level$feature[1]
            firstStat <- level$stat[1]
            if (setequal(
                level[
                    level$feature == firstFeat &
                        level$stat == firstStat,
                    "class"
                ],
                unique(level$class)
            )) {
                return(level)
            }
            # initialize lists for new tables (NA = placeholder meaning
            # "this split type is absent", filtered out below)
            bSplits <- NA
            oSplits <- NA
            # get balanced split rows by themselves
            balS <- level[level$statUsed == "Split", ]
            # return table for each unique value of 'stat'
            if (nrow(balS) > 0) {
                # get unique splits (based on stat)
                unS <- unique(balS$stat)
                # return table for each unique split
                bSplits <- lapply(unS, function(s) {
                    balS[balS$stat == s, ]
                })
            }
            # get one-off rows by themselves
            oneS <- level[level$statUsed == "One-off", ]
            if (nrow(oneS) > 0) {
                # check if need to split
                firstFeat <- oneS$feature[1]
                if (setequal(
                    oneS[oneS$feature == firstFeat, "class"],
                    unique(oneS$class)
                )) {
                    oSplits <- oneS
                }
                # get class groups for each marker (classes sharing a
                # marker, space-joined as a grouping key)
                markers <- oneS[oneS$direction == 1, "feature"]
                groups <- unique(unlist(lapply(markers, function(m) {
                    return(paste(as.character(oneS[oneS$feature == m, "class"]),
                        collapse = " "
                    ))
                })))
                # return table for each class group
                oSplits <- lapply(groups, function(x) {
                    gr <- unlist(strsplit(x, split = " "))
                    oneS[as.character(oneS$class) %in% gr, ]
                })
            }
            # rename new tables
            # NOTE(review): LETTERS[seq(length(bSplits), 1)] labels the
            # splits in reverse order (last split gets "A") — confirm
            # this ordering is intentional.
            if (is.list(bSplits)) {
                names(bSplits) <- paste0(
                    "split_",
                    LETTERS[seq(length(bSplits), 1)]
                )
            }
            if (is.list(oSplits)) {
                names(oSplits) <- paste0(
                    "one-off_",
                    LETTERS[seq(length(oSplits), 1)]
                )
            }
            # return 2 sets of table; is.na drops the atomic NA
            # placeholders while keeping the list-valued entries
            toReturn <- list(oSplits, bSplits)
            toReturn <- toReturn[!is.na(toReturn)]
            toReturn <- unlist(toReturn, recursive = FALSE)
            return(toReturn)
        })
        # adjust for new tables: wrap bare data frames so the final
        # unlist() flattens everything one level uniformly
        branchPoints <- lapply(branchPoints, function(br) {
            if (inherits(br, "list")) {
                return(br)
            } else {
                return(list(br))
            }
        })
        branchPoints <- unlist(branchPoints, recursive = FALSE)
        # replace dots in names of new branches with underscores
        names(branchPoints) <- gsub(
            pattern = "\\.([^\\.]*)$",
            replacement = "_\\1",
            names(branchPoints)
        )
        return(branchPoints)
    }
# helper function to get AUC for individual genes within feature
# Helper: one-vs-all AUC for the individual genes within a feature/module.
#
# For a given marker feature (e.g. a gene module), scores every gene
# assigned to it by computing the AUC that separates the rule's
# up-regulated class from its down-regulated classes, then expands the
# marker's rule rows to one row per gene with `gene` and `geneAUC` added.
#
# @param marker Feature name whose member genes are scored.
# @param table Branch-point rules table containing rows for `marker`.
# @param subtypeLabels Per-cell subtype labels (used for levels > 1).
# @param metaclusterLabels Per-cell metacluster labels (used for level 1).
# @param featureLabels Per-gene assignment of genes to features.
# @param counts Raw counts matrix, genes x cells.
# @return The marker's rule rows replicated per gene, sorted by AUC.
.getGeneAUC <- function(marker,
                        table,
                        subtypeLabels,
                        metaclusterLabels,
                        featureLabels,
                        counts) {
    # Classes on the positive (up) and negative (down) side of this rule
    upClass <- as.character(
        table[table$feature == marker & table$direction == 1, "class"])
    downClasses <- as.character(
        table[table$feature == marker & table$direction == (-1), "class"])
    # Level-1 rules are defined on metacluster labels; deeper levels use
    # subtype labels. Select the matching label vector once, then subset
    # the counts and labels together.
    cellLabels <- if (table$level[1] > 1) {
        subtypeLabels
    } else {
        metaclusterLabels
    }
    keep <- which(cellLabels %in% c(upClass, downClasses))
    subCounts <- counts[, keep]
    # Binary outcome: 1 for cells in the up-regulated class, 0 otherwise
    subLabels <- as.numeric(cellLabels[keep] %in% upClass)
    # Individual genes assigned to this marker feature
    geneNames <- rownames(counts)[which(featureLabels == marker)]
    # One-vs-all AUC for each gene
    auc <- unlist(lapply(geneNames, function(markerGene) {
        as.numeric(pROC::auc(
            pROC::roc(
                subLabels,
                subCounts[markerGene, ],
                direction = "<",
                quiet = TRUE
            )
        ))
    }))
    names(auc) <- geneNames
    # Sort genes from best to worst discriminator
    auc <- sort(auc, decreasing = TRUE)
    # Replicate each rule row once per gene, then attach gene name + AUC
    featTable <- table[table$feature == marker, ]
    featTable <- featTable[rep(seq_len(nrow(featTable)), each = length(auc)), ]
    featTable$gene <- rep(names(auc), length(c(upClass, downClasses)))
    featTable$geneAUC <- rep(auc, length(c(upClass, downClasses)))
    return(featTable)
}
# This function generates the decision tree by recursively separating classes.
# features: numeric matrix (samples x features); rownames are sample IDs.
# class: factor of class labels, one entry per row of `features`.
# oneoffMetric: "modified F1" or "pairwise AUC" (see .wrapSplitHybrid).
# threshold: minimum metric value for a one-off split to be accepted.
# reuseFeatures: may a feature be reused further down the tree?
# consecutiveOneoff: allow a one-off split directly after another one-off?
# Returns a list of tree levels; each level is a list of split objects.
.generateTreeList <- function(features,
                              class,
                              oneoffMetric,
                              threshold,
                              reuseFeatures,
                              consecutiveOneoff = FALSE) {
    # Initialize Tree
    treeLevel <- tree <- list()
    # Initialize the first split
    treeLevel[[1]] <- list()
    # Generate the first split at the first level
    treeLevel[[1]] <- .wrapSplitHybrid(
        features,
        class,
        threshold,
        oneoffMetric
    )
    # Record the set of features used at this split
    treeLevel[[1]]$fUsed <- unlist(lapply(
        treeLevel[[1]][names(treeLevel[[1]]) != "statUsed"],
        function(X) {
            X$featureName
        }
    ))
    # Initialize split directions
    treeLevel[[1]]$dirs <- 1
    # Add split list as first level
    tree[[1]] <- treeLevel
    # Initialize tree depth
    mDepth <- 1
    # Build tree until all leaves are of a single cluster
    while (length(unlist(treeLevel)) > 0) {
        # Create list of branches on this level
        outList <-
            lapply(treeLevel, function(split, features, class) {
                # Disallow a one-off split directly after another one-off
                # unless consecutiveOneoff is TRUE
                tryOneoff <- TRUE
                if (!consecutiveOneoff && split$statUsed == "One-off") {
                    tryOneoff <- FALSE
                }
                # If length(split) == 4 then this split is a binary node
                if (length(split) == 4 &&
                    length(split[[1]]$group1Consensus) > 1) {
                    # Create branch from this split.
                    branch1 <- .wrapBranchHybrid(
                        split[[1]]$group1,
                        features,
                        class,
                        split$fUsed,
                        threshold,
                        reuseFeatures,
                        oneoffMetric,
                        tryOneoff
                    )
                    if (!is.null(branch1)) {
                        # Add feature to list of used features.
                        branch1$fUsed <- c(split$fUsed, unlist(lapply(
                            branch1[names(branch1) != "statUsed"],
                            function(X) {
                                X$featureName
                            }
                        )))
                        # Add the split direction (always 1 when splitting group 1)
                        branch1$dirs <- c(split$dirs, 1)
                    }
                } else {
                    branch1 <- NULL
                }
                # If length(split) == 4 then this split is a binary node
                if (length(split) == 4 &&
                    length(split[[1]]$group2Consensus) > 1) {
                    # Create branch from this split
                    branch2 <- .wrapBranchHybrid(
                        split[[1]]$group2,
                        features,
                        class,
                        split$fUsed,
                        threshold,
                        reuseFeatures,
                        oneoffMetric,
                        tryOneoff
                    )
                    if (!is.null(branch2)) {
                        # Add feature to list of used features.
                        branch2$fUsed <- c(split$fUsed, unlist(lapply(
                            branch2[names(branch2) != "statUsed"],
                            function(X) {
                                X$featureName
                            }
                        )))
                        # Add the split direction (always 2 when splitting group 2)
                        branch2$dirs <- c(split$dirs, 2)
                    }
                    # If length(split) > 4 then this split has more than 2 edges.
                    # In this case group 1 will always denote leaves.
                } else if (length(split) > 4) {
                    # Get samples that are never in group 1 in this split
                    group1Samples <- unique(unlist(lapply(
                        split[!names(split) %in% c("statUsed", "fUsed", "dirs")],
                        function(X) {
                            X$group1
                        }
                    )))
                    group2Samples <- unique(unlist(lapply(
                        split[!names(split) %in% c("statUsed", "fUsed", "dirs")],
                        function(X) {
                            X$group2
                        }
                    )))
                    group2Samples <- group2Samples[!group2Samples %in%
                        group1Samples]
                    # Check that there is still more than one class
                    group2Classes <- levels(droplevels(class[rownames(features) %in%
                        group2Samples]))
                    if (length(group2Classes) > 1) {
                        # Create branch from this split
                        branch2 <- .wrapBranchHybrid(
                            group2Samples,
                            features,
                            class,
                            split$fUsed,
                            threshold,
                            reuseFeatures,
                            oneoffMetric,
                            tryOneoff
                        )
                        if (!is.null(branch2)) {
                            # Add multiple features
                            branch2$fUsed <-
                                c(split$fUsed, unlist(lapply(
                                    branch2[names(branch2) != "statUsed"],
                                    function(X) {
                                        X$featureName
                                    }
                                )))
                            # Instead of 2, this direction is 1 + the num. splits
                            branch2$dirs <- c(
                                split$dirs,
                                sum(!names(split) %in%
                                    c("statUsed", "fUsed", "dirs")) + 1
                            )
                        }
                    } else {
                        branch2 <- NULL
                    }
                } else {
                    branch2 <- NULL
                }
                # Combine these branches
                outBranch <- list(branch1, branch2)
                # Only keep non-null branches
                outBranch <-
                    outBranch[!unlist(lapply(outBranch, is.null))]
                if (length(outBranch) > 0) {
                    return(outBranch)
                } else {
                    return(NULL)
                }
            }, features, class)
        # Unlist outList so there is one list per 'treeLevel'
        treeLevel <- unlist(outList, recursive = FALSE)
        # Increase tree depth
        mDepth <- mDepth + 1
        # Add this level to the tree
        tree[[mDepth]] <- treeLevel
    }
    return(tree)
}
# Wrapper to subset the feature and class set for one branch of the tree,
# then attempt a split on that subset.
# groups: sample IDs belonging to this branch.
# fUsed: features already used higher in the tree (excluded unless
#     reuseFeatures is TRUE).
# Returns a split list from .wrapSplitHybrid, or NULL when fewer than two
# candidate features remain.
.wrapBranchHybrid <- function(groups,
                              features,
                              class,
                              fUsed,
                              threshold = 0.95,
                              reuseFeatures = FALSE,
                              oneoffMetric,
                              tryOneoff) {
    # Rows (samples) that fall in this branch
    inBranch <- rownames(features) %in% groups
    # Optionally exclude features consumed by ancestor splits
    if (!reuseFeatures) {
        fSub <-
            features[inBranch, !colnames(features) %in% fUsed, drop = FALSE]
    } else {
        fSub <- features[inBranch, ]
    }
    # Drop class levels absent from this branch
    cSub <- droplevels(class[inBranch])
    # Need more than one candidate feature to attempt a split
    if (ncol(fSub) > 1) {
        return(.wrapSplitHybrid(fSub, cSub, threshold, oneoffMetric, tryOneoff))
    }
    NULL
}
# Wrapper function to perform split metrics.
# Tries one-off (single-feature) splits first; if none reach `threshold`,
# falls back to the balanced information-gain split metric.
# Returns a named list of split objects (see .getSplit), with splits that
# isolate the same group-1 classes merged, plus a `statUsed` element
# ("One-off" or "Split").
.wrapSplitHybrid <- function(features,
                             class,
                             threshold = 0.95,
                             oneoffMetric,
                             tryOneoff = TRUE) {
    # Get best one-2-one splits
    ## Use modified f1 or pairwise auc?
    if (tryOneoff) {
        if (oneoffMetric == "modified F1") {
            splitMetric <- .splitMetricModF1
        } else {
            splitMetric <- .splitMetricPairwiseAUC
        }
        splitStats <- .splitMetricRecursive(features,
            class,
            splitMetric = splitMetric
        )
        # Keep only one-off candidates meeting the threshold
        splitStats <- splitStats[splitStats >= threshold]
        statUsed <- "One-off"
    } else {
        splitStats <- integer(0)
    }
    # If no one-2-one split meets threshold, run semi-supervised clustering
    if (length(splitStats) == 0) {
        splitMetric <- .splitMetricIGpIGd
        splitStats <- .splitMetricRecursive(features,
            class,
            splitMetric = splitMetric
        )[1] # Use top
        statUsed <- "Split"
    }
    # Get full split description for each selected feature
    splitList <- lapply(
        names(splitStats),
        .getSplit,
        splitStats,
        features,
        class,
        splitMetric
    )
    # Combine feature rules when same group1 class arises
    if (length(splitList) > 1) {
        group1Vec <- unlist(lapply(splitList, function(X) {
            X$group1Consensus
        }), recursive = FALSE)
        splitList <- lapply(
            unique(group1Vec),
            function(group1, splitList, group1Vec) {
                # Get subset with same group1
                splitListSub <- splitList[group1Vec == group1]
                # Get feature, value, and stat for these
                splitFeature <- unlist(lapply(
                    splitListSub,
                    function(X) {
                        X$featureName
                    }
                ))
                splitValue <- unlist(lapply(
                    splitListSub,
                    function(X) {
                        X$value
                    }
                ))
                splitStat <- unlist(lapply(
                    splitListSub,
                    function(X) {
                        X$stat
                    }
                ))
                # Create a single object carrying the combined vectors
                splitSingle <- splitListSub[[1]]
                splitSingle$featureName <- splitFeature
                splitSingle$value <- splitValue
                splitSingle$stat <- splitStat
                return(splitSingle)
            }, splitList, group1Vec
        )
    }
    # Name each split by its (possibly multiple) marker features
    names(splitList) <- unlist(lapply(
        splitList,
        function(X) {
            paste(X$featureName, collapse = ";")
        }
    ))
    # Add statUsed
    splitList$statUsed <- statUsed
    return(splitList)
}
# Evaluate the split metric on every feature column and return the scores
# sorted from best to worst, named by feature.
.splitMetricRecursive <- function(features, class, splitMetric) {
    featNames <- colnames(features)
    # One metric value per feature (performance mode)
    splitStats <- vapply(featNames, function(feat) {
        splitMetric(feat, class, features, rPerf = TRUE)
    }, FUN.VALUE = double(1))
    names(splitStats) <- featNames
    # Best-scoring features first
    sort(splitStats, decreasing = TRUE)
}
# Run pairwise AUC metric on a single feature.
# feat: feature (column) name; class: factor of labels;
# features: samples x features matrix.
# If rPerf = TRUE, returns the worst pairwise AUC (how well the feature
# separates the top class from its hardest competitor); otherwise returns
# the feature value at which to split.
.splitMetricPairwiseAUC <-
    function(feat, class, features, rPerf = FALSE) {
        # Get current feature
        currentFeature <- features[, feat]
        # Get unique classes
        classUnique <- sort(unique(class))
        # Do one-to-all to determine top cluster
        # For each class K1 determine best AUC
        auc1toAll <-
            vapply(classUnique, function(k1, class, currentFeature) {
                # Binary label: in k1 or not
                classK1 <- as.numeric(class == k1)
                # Get AUC value
                aucK1 <-
                    pROC::auc(pROC::roc(
                        classK1,
                        currentFeature,
                        direction = "<",
                        quiet = TRUE
                    ))
                # Return
                return(aucK1)
            }, class, currentFeature, FUN.VALUE = double(1))
        # Get class with best AUC (Class with generally highest values)
        classMax <- as.character(classUnique[which.max(auc1toAll)])
        # Get other classes
        classRest <- as.character(classUnique[classUnique != classMax])
        # For each second cluster k2, get the k1-vs-k2 AUC and the best
        # splitting threshold for that pair
        aucFram <- as.data.frame(do.call(
            rbind,
            lapply(
                classRest,
                function(k2, k1, class, currentFeature) {
                    # keep cells in k1 or k2 only
                    obsKeep <- class %in% c(k1, k2)
                    currentFeatureSubset <- currentFeature[obsKeep]
                    # update cluster assignments
                    currentClusters <- class[obsKeep]
                    # label cells whether they belong to k1 (0 or 1)
                    currentLabels <- as.integer(currentClusters == k1)
                    # get AUC value for this feat-cluster pair
                    rocK2 <-
                        pROC::roc(currentLabels,
                            currentFeatureSubset,
                            direction = "<",
                            quiet = TRUE
                        )
                    aucK2 <- rocK2$auc
                    coordK2 <-
                        pROC::coords(rocK2, "best", ret = "threshold", transpose = TRUE)[1]
                    # Concatenate vectors
                    statK2 <- c(threshold = coordK2, auc = aucK2)
                    return(statK2)
                }, classMax, class, currentFeature
            )
        ))
        # Get Min Value (the hardest pairwise comparison)
        aucMin <- min(aucFram$auc)
        # Get indices where this AUC occurs
        aucMinIndices <- which(aucFram$auc == aucMin)
        # The split value comes from the worst pair; use the maximum
        # threshold when several pairs tie at the minimum AUC.
        # FIX: previously max() was taken over ALL rows, ignoring
        # aucMinIndices, so the threshold could come from an unrelated pair.
        aucValue <- max(aucFram$threshold[aucMinIndices])
        # Return performance or value?
        if (rPerf) {
            return(aucMin)
        } else {
            return(aucValue)
        }
    }
# Run modified F1 metric on a single feature.
# Scans candidate thresholds (where the sorted class labels change) and
# scores each by the harmonic mean of the best class's sensitivity and
# precision on the high side and the worst alternate class's sensitivity
# on the low side. If rPerf = TRUE returns the best score; otherwise
# returns the split value (midpoint between adjacent feature values).
.splitMetricModF1 <-
    function(feat, class, features, rPerf = FALSE) {
        # Get number of samples
        len <- length(class)
        # Get Values
        featValues <- features[, feat]
        # Get order of values (largest first)
        ord <- order(featValues, decreasing = TRUE)
        # Get sorted class and values
        featValuesSort <- featValues[ord]
        classSort <- class[ord]
        # Keep splits of the data where the class changes
        # (and where the feature value itself also changes)
        keep <- c(
            classSort[seq(1, (len - 1))] != classSort[seq(2, (len))] &
                featValuesSort[seq(1, (len - 1))] != featValuesSort[seq(2, (len))],
            FALSE
        )
        # Create data.matrix (one indicator column per class)
        X <- stats::model.matrix(~ 0 + classSort)
        # Get cumulative sums (per-class counts above each candidate cut)
        sRCounts <- apply(X, 2, cumsum)
        # Keep only values where the class changes
        sRCounts <- sRCounts[keep, , drop = FALSE]
        featValuesKeep <- featValuesSort[keep]
        # Number of each class
        Xsum <- colSums(X)
        # Remove impossible splits (no class has > 50% of their samples on one side)
        sRProbs <- sRCounts %*% diag(Xsum^-1)
        sKeepPossible <-
            rowSums(sRProbs >= 0.5) > 0 & rowSums(sRProbs < 0.5) > 0
        # Remove anything after a full prob (Doesn't always happen)
        maxCheck <-
            min(c(which(apply(sRProbs, 1, max) == 1), nrow(sRProbs)))
        sKeepCheck <- seq(1, nrow(sRProbs)) %in% seq(1, maxCheck)
        # Combine logical vectors
        sKeep <- sKeepPossible & sKeepCheck
        if (sum(sKeep) > 0) {
            # Remove these if they exist
            sRCounts <- sRCounts[sKeep, , drop = FALSE]
            featValuesKeep <- featValuesKeep[sKeep]
            # Get left counts
            sLCounts <- t(Xsum - t(sRCounts))
            # Calculate the harmonic mean of Sens, Prec, and Worst Alt Sens
            statModF1 <- vapply(seq(nrow(sRCounts)),
                function(i, Xsum, sRCounts, sLCounts) {
                    # Right Side
                    sRRowSens <-
                        sRCounts[i, ] / Xsum # Right sensitivities
                    sRRowPrec <-
                        sRCounts[i, ] / sum(sRCounts[i, ]) # Right prec
                    sRRowF1 <-
                        2 * (sRRowSens * sRRowPrec) / (sRRowSens + sRRowPrec)
                    sRRowF1[is.nan(sRRowF1)] <- 0 # Get right F1
                    bestF1Ind <- which.max(sRRowF1) # Which is the best?
                    bestSens <-
                        sRRowSens[bestF1Ind] # The corresponding sensitivity
                    bestPrec <-
                        sRRowPrec[bestF1Ind] # The corresponding precision
                    # Left Side
                    sLRowSens <-
                        sLCounts[i, ] / Xsum # Get left sensitivities
                    worstSens <-
                        min(sLRowSens[-bestF1Ind]) # Get the worst
                    # Get harmonic mean of best sens, best prec, and worst sens
                    HMout <- (3 * bestSens * bestPrec * worstSens) /
                        (bestSens * bestPrec + bestPrec * worstSens +
                            bestSens * worstSens)
                    return(HMout)
                }, Xsum, sRCounts, sLCounts,
                FUN.VALUE = double(1)
            )
            # Get Max Value
            ModF1Max <- max(statModF1)
            # Get index of the (first) best-scoring split
            ModF1Index <- which.max(statModF1)
            # Get value at this point
            ValueCeiling <- featValuesKeep[ModF1Index]
            ValueWhich <- which(featValuesSort == ValueCeiling)
            # Split value is the midpoint between adjacent sorted values
            ModF1Value <- mean(c(
                featValuesSort[ValueWhich],
                featValuesSort[ValueWhich + 1]
            ))
        } else {
            # No valid split: zero performance, NA split value
            ModF1Max <- 0
            ModF1Value <- NA
        }
        if (rPerf) {
            return(ModF1Max)
        } else {
            return(ModF1Value)
        }
    }
# Run Information Gain (probability + density) on a single feature.
# Scores candidate split points by class-probability information gain plus
# a density term based on covariance pseudo-determinants (.infoGainDensity).
# If rPerf = TRUE returns the probability information gain of the best
# split; otherwise returns the feature value at which to split (midpoint
# between adjacent sorted values).
.splitMetricIGpIGd <- function(feat, class, features, rPerf = FALSE) {
    # Get number of samples
    len <- length(class)
    # Get Values
    featValues <- features[, feat]
    # Get order of values (largest first)
    ord <- order(featValues, decreasing = TRUE)
    # Get sorted class and values
    featValuesSort <- featValues[ord]
    classSort <- class[ord]
    # Keep splits of the data where the class changes
    keep <- c(
        classSort[seq(1, (len - 1))] != classSort[seq(2, (len))] &
            featValuesSort[seq(1, (len - 1))] != featValuesSort[seq(2, (len))],
        FALSE
    )
    # Create data.matrix (one indicator column per class)
    X <- stats::model.matrix(~ 0 + classSort)
    # Get cumulative sums (per-class counts above each candidate cut)
    sRCounts <- apply(X, 2, cumsum)
    # Keep only values where the class changes
    sRCounts <- sRCounts[keep, , drop = FALSE]
    featValuesKeep <- featValuesSort[keep]
    # Number of each class
    Xsum <- colSums(X)
    # Remove impossible splits (no class majority on either side)
    sRProbs <- sRCounts %*% diag(Xsum^-1)
    sKeep <-
        rowSums(sRProbs >= 0.5) > 0 & rowSums(sRProbs < 0.5) > 0
    if (sum(sKeep) > 0) {
        # Remove these if they exist
        sRCounts <- sRCounts[sKeep, , drop = FALSE]
        featValuesKeep <- featValuesKeep[sKeep]
        # Get left counts
        sLCounts <- t(Xsum - t(sRCounts))
        # Multiply them to get probabilities
        sRProbs <- t(t(sRCounts) %*%
            diag(rowSums(sRCounts)^-1, nrow = nrow(sRCounts)))
        sLProbs <- t(t(sLCounts) %*%
            diag(rowSums(sLCounts)^-1, nrow = nrow(sLCounts)))
        # Multiply them by their log (0 * log(0) cleaned up to 0 below)
        sRTrans <- sRProbs * log(sRProbs)
        sRTrans[is.na(sRTrans)] <- 0
        sLTrans <- sLProbs * log(sLProbs)
        sLTrans[is.na(sLTrans)] <- 0
        # Get entropies of each side
        HSR <- -rowSums(sRTrans)
        HSL <- -rowSums(sLTrans)
        # Get overall probabilities and entropy
        nProbs <- colSums(X) / len
        HS <- -sum(nProbs * log(nProbs))
        # Get split proportions
        sProps <- rowSums(sRCounts) / nrow(X)
        # Get information gain (Probability)
        IGprobs <- HS - (sProps * HSR + (1 - sProps) * HSL)
        IGprobs[is.nan(IGprobs)] <- 0
        # Normalize across candidate splits
        IGprobsQuantile <- IGprobs / max(IGprobs)
        IGprobsQuantile[is.nan(IGprobsQuantile)] <- 0
        # Get proportions at each split
        classProps <- sRCounts %*% diag(Xsum^-1)
        classSplit <- classProps >= 0.5
        # Initialize information gain density vector
        splitIGdensQuantile <- rep(0, nrow(classSplit))
        # Get unique splits of the data (dropping all-or-nothing splits)
        classSplitUnique <- unique(classSplit)
        classSplitUnique <-
            classSplitUnique[!rowSums(classSplitUnique) %in%
                c(0, ncol(classSplitUnique)), , drop = FALSE]
        # Get density information gain
        if (nrow(classSplitUnique) > 0) {
            # Get log(determinant of full matrix)
            DET <- .psdet(stats::cov(features))
            # Information gain of every observation
            IGdens <- apply(
                classSplitUnique,
                1,
                .infoGainDensity,
                X,
                features,
                DET
            )
            names(IGdens) <- apply(
                classSplitUnique * 1,
                1,
                function(X) {
                    paste(X, collapse = "")
                }
            )
            IGdens[is.nan(IGdens) | IGdens < 0] <- 0
            IGdensQuantile <- IGdens / max(IGdens)
            IGdensQuantile[is.nan(IGdensQuantile)] <- 0
            # Get ID of each class split
            splitsIDs <- apply(
                classSplit * 1,
                1,
                function(x) {
                    paste(x, collapse = "")
                }
            )
            # Append information gain density vector
            for (ID in names(IGdens)) {
                splitIGdensQuantile[splitsIDs == ID] <- IGdensQuantile[ID]
            }
        }
        # Combined score: probability + density information gain
        IG <- IGprobsQuantile + splitIGdensQuantile
        # Get IG(probability) of maximum value
        IGreturn <- IGprobs[which.max(IG)[1]]
        # Get maximum value
        maxVal <- featValuesKeep[which.max(IG)]
        wMax <- max(which(featValuesSort == maxVal))
        # Split value is the midpoint between adjacent sorted values
        IGvalue <-
            mean(c(featValuesSort[wMax], featValuesSort[wMax + 1]))
    } else {
        # No valid split: zero gain, NA split value
        IGreturn <- 0
        IGvalue <- NA
    }
    # Report maximum ID or value at maximum IG
    if (rPerf) {
        return(IGreturn)
    } else {
        return(IGvalue)
    }
}
# Log pseudo-determinant of a matrix: the sum of the logs of its non-zero
# singular values. Returns 0 when the matrix contains any NA (SVD undefined).
.psdet <- function(x) {
    if (anyNA(x)) {
        return(0)
    }
    # zapsmall() rounds near-zero singular values to exactly zero
    singular <- zapsmall(svd(x)$d)
    sum(log(singular[singular > 0]))
}
# Density information gain of a candidate split: how much the log
# pseudo-determinant of the feature covariance shrinks when the samples are
# divided into the two sides of the split.
# splitVector: logical vector over the class-indicator columns of X.
# X: samples x classes indicator matrix; DET: log pseudo-determinant of the
# full covariance matrix (precomputed by the caller).
.infoGainDensity <- function(splitVector, X, features, DET) {
    # Samples on each side of the split, via class membership
    rightRows <- as.logical(rowSums(X[, splitVector, drop = FALSE]))
    leftRows <- as.logical(rowSums(X[, !splitVector, drop = FALSE]))
    sRFeat <- features[rightRows, , drop = FALSE]
    sLFeat <- features[leftRows, , drop = FALSE]
    # Log pseudo-determinants of each side's covariance
    DETR <- .psdet(stats::cov(sRFeat))
    DETL <- .psdet(stats::cov(sLFeat))
    # Weight each side by its relative size
    sJ <- nrow(features)
    sJR <- nrow(sRFeat)
    sJL <- nrow(sLFeat)
    0.5 * (DET - (sJR / sJ * DETR + sJL / sJ * DETL))
}
# Build the full description of a split on one feature: threshold value,
# metric stat, and the samples/classes falling on each side.
# A class is assigned to group 1 when at least half its samples exceed the
# split value; otherwise it belongs to group 2.
.getSplit <-
    function(feat,
             splitStats,
             features,
             class,
             splitMetric) {
        stat <- splitStats[feat]
        # Threshold value for this feature
        cutoff <- splitMetric(feat, class, features, rPerf = FALSE)
        valuesForFeat <- features[, feat]
        # Classes of the samples above the threshold
        upperClasses <- class[valuesForFeat > cutoff]
        # Fraction of each class above / below the threshold
        propUp <- table(upperClasses) / table(class)
        propDown <- 1 - propUp
        # Majority vote determines each class's group
        isGroup1 <- propUp >= 0.5
        g1Names <- names(propUp)[isGroup1]
        g2Names <- names(propUp)[!isGroup1]
        # Sample membership masks for the two groups
        inG1 <- class %in% g1Names
        inG2 <- class %in% g2Names
        list(
            featureName = feat,
            value = cutoff,
            stat = stat,
            group1 = rownames(features)[inG1],
            group1Class = droplevels(class[inG1]),
            group1Consensus = g1Names,
            group1Prop = c(propUp),
            group2 = rownames(features)[inG2],
            group2Class = droplevels(class[inG2]),
            group2Consensus = g2Names,
            group2Prop = c(propDown)
        )
    }
# Function to annotate an alternate split for a solely down-regulated
# terminal node: if some cluster is only ever identified by down-regulated
# markers, try to find one extra up-regulated rule that isolates it.
# Returns the tree with the alternative split inserted (when found).
.addAlternativeSplit <- function(tree, features, class) {
    # Unlist decision tree (one element per split)
    DecTree <- unlist(tree, recursive = FALSE)
    # Collect the up- (group1) and down- (group2) regulated classes per split
    groupList <- lapply(DecTree, function(split) {
        # Remove directions
        split <-
            split[!names(split) %in% c("statUsed", "fUsed", "dirs")]
        # Get groups
        group1 <- unique(unlist(lapply(
            split,
            function(node) {
                node$group1Consensus
            }
        )))
        group2 <- unique(unlist(lapply(
            split,
            function(node) {
                node$group2Consensus
            }
        )))
        return(list(
            group1 = group1,
            group2 = group2
        ))
    })
    # Get vector of each group
    group1Vec <-
        unique(unlist(lapply(groupList, function(g) {
            g$group1
        })))
    group2Vec <-
        unique(unlist(lapply(groupList, function(g) {
            g$group2
        })))
    # Get group that is never up-regulated
    group2only <- group2Vec[!group2Vec %in% group1Vec]
    # Check whether there are solely down-regulated splits
    AltSplitInd <-
        which(unlist(lapply(groupList, function(g, group2only) {
            group2only %in% g$group2
        }, group2only)))
    if (length(AltSplitInd) > 0) {
        # Use the deepest split that contains the down-regulated-only cluster
        AltDec <-
            max(which(unlist(
                lapply(groupList, function(g, group2only) {
                    group2only %in% g$group2
                }, group2only)
            )))
        # Get split
        downSplit <- DecTree[[AltDec]]
        downNode <- downSplit[[1]]
        # Get classes to rerun
        branchClasses <- names(downNode$group1Prop)
        # Get samples from these classes and features from this cluster
        sampKeep <- class %in% branchClasses
        featKeep <- !colnames(features) %in% downSplit$fUsed
        # Subset class and features
        cSub <- droplevels(class[sampKeep])
        fSub <- features[sampKeep, featKeep, drop = FALSE]
        # Score every remaining feature as a candidate alternative split
        # (harmonic mean of sensitivity, precision, and worst alternate
        # class sensitivity for the cluster of interest)
        altStats <- do.call(
            rbind,
            lapply(
                colnames(fSub),
                function(feat,
                         splitMetric,
                         features,
                         class,
                         cInt) {
                    Val <- splitMetric(feat, cSub, fSub, rPerf = FALSE)
                    # Get node1 classes
                    node1Class <- class[features[, feat] > Val]
                    # Get sensitivity/precision/altSens
                    Sens <- sum(node1Class == cInt) / sum(class == cInt)
                    Prec <- mean(node1Class == cInt)
                    # Get Sensitivity of Alternate Classes
                    AltClasses <- unique(class)[unique(class) != cInt]
                    AltSizes <- vapply(AltClasses,
                        function(cAlt, class) {
                            sum(class == cAlt)
                        }, class,
                        FUN.VALUE = double(1)
                    )
                    AltWrong <- vapply(AltClasses,
                        function(cAlt, node1Class) {
                            sum(node1Class == cAlt)
                        }, node1Class,
                        FUN.VALUE = double(1)
                    )
                    AltSens <- min(1 - (AltWrong / AltSizes))
                    # Get harmonic mean of the three statistics
                    HM <- (3 * Sens * Prec * AltSens) /
                        (Sens * Prec + Prec * AltSens + Sens * AltSens)
                    HM[is.nan(HM)] <- 0
                    # Return
                    return(data.frame(
                        feat = feat,
                        val = Val,
                        stat = HM,
                        stringsAsFactors = FALSE
                    ))
                }, .splitMetricModF1, fSub, cSub, group2only
            )
        )
        # Best-scoring candidates first
        altStats <-
            altStats[order(altStats$stat, decreasing = TRUE), ]
        # Get alternative split for the top candidate feature
        splitStats <- altStats$stat[1]
        names(splitStats) <- altStats$feat[1]
        altSplit <- .getSplit(
            altStats$feat[1],
            splitStats,
            fSub,
            cSub,
            .splitMetricModF1
        )
        # Check that this split isolates the group2 cluster of interest
        if (length(altSplit$group1Consensus) == 1) {
            # Add it to split
            downSplit[[length(downSplit) + 1]] <- altSplit
            names(downSplit)[length(downSplit)] <- altStats$feat[1]
            # Keep bookkeeping elements (statUsed/fUsed/dirs) at the end
            downSplit <- downSplit[c(
                which(!names(downSplit) %in% c("statUsed", "fUsed", "dirs")),
                which(names(downSplit) %in% c("statUsed", "fUsed", "dirs"))
            )]
            # Get index of split to add it to
            branchLengths <- unlist(lapply(tree, length))
            branchCum <- cumsum(branchLengths)
            wBranch <- min(which(branchCum >= AltDec))
            if (wBranch == 1) {
                wSplit <- 1
            }
            else {
                wSplit <- which(seq(
                    (branchCum[(wBranch - 1)] + 1),
                    branchCum[wBranch]
                ) == AltDec)
            }
            # Add it to decision tree
            tree[[wBranch]][[wSplit]] <- downSplit
        } else {
            cat(
                "No non-ambiguous rule to separate",
                group2only,
                "from",
                branchClasses,
                ". No alternative split added."
            )
        }
    } else {
        # print("No solely down-regulated cluster to add alternative split.")
    }
    return(tree)
}
#' @title Gets cluster estimates using rules generated by
#' `celda::findMarkersTree`
#' @description Get decisions for a matrix of features. Estimate cell
#' cluster membership using feature matrix input.
#' @param rules List object. The `rules` element from `findMarkersTree`
#' output. Returns NA if cluster estimation was ambiguous.
#' @param features A L(features) by N(samples) numeric matrix.
#' @return A character vector of label predictions.
getDecisions <- function(rules, features) {
    # Transpose so each row is one sample, then classify each independently
    sampleMatrix <- t(features)
    apply(sampleMatrix, 1, .predictClass, rules)
}
# Function to predict class from list of rules.
# samp: named numeric vector of feature values for one sample.
# rules: per-class rule tables (see .mapClass2features).
# Returns the single class whose rules are satisfied at every level,
# or NA when the prediction is ambiguous (zero or multiple classes left).
.predictClass <- function(samp, rules) {
    # Initialize possible classes and level
    classes <- names(rules)
    level <- 1
    # Set maximum level possible to prevent an infinite loop
    maxLevel <- max(unlist(lapply(rules, function(ruleSet) {
        ruleSet$level
    })))
    while (length(classes) > 1 & level <= maxLevel) {
        # Get possible classes
        clLogical <-
            unlist(lapply(classes, function(cl, rules, level, samp) {
                # Get the rules for this class
                ruleClass <- rules[[cl]]
                # Get the rules for this level
                ruleClass <-
                    ruleClass[ruleClass$level == level, , drop = FALSE]
                # Subset class for the features at this level
                ruleClass$sample <- samp[ruleClass$feature]
                # For multiple direction == 1, use one with the top stat
                # (direction == 1 rows sort first, so which.max indices align)
                if (sum(ruleClass$direction == 1) > 1) {
                    ruleClass <- ruleClass[order(ruleClass$direction,
                        decreasing = TRUE
                    ), ]
                    ruleClass <- ruleClass[c(
                        which.max(ruleClass$stat[ruleClass$direction == 1]),
                        which(ruleClass$direction == -1)
                    ), , drop = FALSE]
                }
                # Check for followed rules (down-regulated rules inverted)
                ruleClass$check <- ruleClass$sample >= ruleClass$value
                ruleClass$check[ruleClass$direction == -1] <-
                    !ruleClass$check[ruleClass$direction == -1]
                # Rule set passes if any up-regulated rule holds,
                # or if every rule at this level holds
                ruleFollowed <- mean(ruleClass$check &
                    ruleClass$direction == 1) > 0 |
                    mean(ruleClass$check) == 1
                return(ruleFollowed)
            }, rules, level, samp))
        # Subset possible classes
        classes <- classes[clLogical]
        # Add level
        level <- level + 1
    }
    # Return if only one class selected
    if (length(classes) == 1) {
        return(classes)
    } else {
        return(NA)
    }
}
# Summarize and format the tree list produced by .generateTreeList:
# dendrogram representation, per-class rule tables, and training-set
# prediction/performance statistics.
.summarizeTree <- function(tree, features, class) {
    # Dendrogram representation of the tree
    dendro <- .convertToDendrogram(tree, class)
    # Rule list mapping each class to its marker features
    class2features <- .mapClass2features(tree, features, class)
    # Training-set predictions and accuracy metrics
    perfList <- .getPerformance(class2features$rules, features, class)
    list(
        rules = class2features$rules,
        dendro = dendro,
        prediction = perfList$prediction,
        performance = perfList$performance
    )
}
# Function to reformat raw tree output to a `dendrogram` object.
# tree: nested list from .generateTreeList; class: vector of class labels;
# splitNames: optional pre-computed edge labels (defaults to feature names).
# NOTE(review): builds the nested structure via eval(parse(...)); consider
# programmatic `[[` path assignment if this is ever refactored.
.convertToDendrogram <- function(tree, class, splitNames = NULL) {
    # Unlist decision tree (one element for each split)
    DecTree <- unlist(tree, recursive = FALSE)
    if (is.null(splitNames)) {
        # Name split by gene and threshold
        splitNames <- lapply(DecTree, function(split) {
            # Remove non-split elements
            dirs <- paste0(split$dirs, collapse = "_")
            split <-
                split[!names(split) %in% c("statUsed", "fUsed", "dirs")]
            # Get set of features and values for each
            featuresplits <- lapply(split, function(node) {
                nodeFeature <- node$featureName
                nodeStrings <- paste(nodeFeature, collapse = ";")
            })
            # Get split directions
            names(featuresplits) <- paste(dirs,
                seq(length(featuresplits)),
                sep = "_"
            )
            return(featuresplits)
        })
        splitNames <- unlist(splitNames)
        # Strip the root direction prefix
        names(splitNames) <- sub("1_", "", names(splitNames))
    }
    else {
        names(splitNames) <- seq(length(DecTree[[1]]) - 3)
    }
    # Get Stat Used, replicated once per edge of each split
    statUsed <- unlist(lapply(DecTree, function(split) {
        split$statUsed
    }))
    statRep <- unlist(lapply(
        DecTree,
        function(split) {
            length(split[!names(split) %in% c("statUsed", "fUsed", "dirs")])
        }
    ))
    statUsed <- unlist(lapply(
        seq(length(statUsed)),
        function(i) {
            rep(statUsed[i], statRep[i])
        }
    ))
    names(statUsed) <- names(splitNames)
    # Create matrix of results: one row per split, one column per class;
    # entries give the branch index each class takes (0 = class absent)
    mat <-
        matrix(0, nrow = length(DecTree), ncol = length(unique(class)))
    colnames(mat) <- unique(class)
    for (i in seq(1, length(DecTree))) {
        # Binary split: group 1 -> branch 1, group 2 -> branch 2
        split <- DecTree[[i]]
        split <-
            split[!names(split) %in% c("statUsed", "fUsed", "dirs")]
        if (length(split) == 1) {
            mat[i, split[[1]]$group1Consensus] <- 1
            mat[i, split[[1]]$group2Consensus] <- 2
            # Otherwise we need to assign > 2 splits for different higher groups
        } else {
            # Get classes in group 1
            group1classUnique <- unique(lapply(
                split,
                function(X) {
                    X$group1Consensus
                }
            ))
            group1classVec <- unlist(group1classUnique)
            # Get classes always in group 2
            group2classUnique <- unique(unlist(lapply(
                split,
                function(X) {
                    X$group2Consensus
                }
            )))
            group2classUnique <-
                group2classUnique[!group2classUnique %in%
                    group1classVec]
            # Assign one branch per distinct group-1 set
            for (j in seq(length(group1classUnique))) {
                mat[i, group1classUnique[[j]]] <- j
            }
            # NOTE(review): relies on the loop variable `j` persisting
            # after the for-loop (valid R, but easy to miss)
            mat[i, group2classUnique] <- j + 1
        }
    }
    ## Collapse matrix to get set of direction to include in dendrogram
    matCollapse <- sort(apply(
        mat,
        2,
        function(x) {
            paste(x[x != 0], collapse = "_")
        }
    ))
    matUnique <- unique(matCollapse)
    # Get branchlist: every path prefix, ordered shallow to deep
    bList <- c()
    j <- 1
    for (i in seq(max(ncharX(matUnique)))) {
        sLength <- matUnique[ncharX(matUnique) >= i]
        sLength <- unique(subUnderscore(sLength, i))
        for (k in sLength) {
            bList[j] <- k
            j <- j + 1
        }
    }
    # Initialize dendrogram list
    val <- max(ncharX(matUnique)) + 1
    dendro <- list()
    attributes(dendro) <- list(
        members = length(matCollapse),
        classLabels = unique(class),
        height = val,
        midpoint = (length(matCollapse) - 1) / 2,
        label = NULL,
        name = NULL
    )
    for (i in bList) {
        # Add element (nested path built from the underscore-separated ID)
        iSplit <- unlist(strsplit(i, "_"))
        iPaste <- paste0(
            "dendro",
            paste(paste0("[[", iSplit, "]]"), collapse = "")
        )
        eval(parse(
            text =
                paste0(iPaste, "<-list()")
        ))
        # Add attributes: classes reachable through this branch
        classLabels <- names(matCollapse[subUnderscore(
            matCollapse,
            ncharX(i)
        ) == i])
        members <- length(classLabels)
        # Add height, set to one if leaf
        height <- val - ncharX(i)
        # Check that this isn't a terminal split
        if (members == 1) {
            height <- 1
        }
        # Add labels and stat used
        if (i %in% names(splitNames)) {
            lab <- splitNames[i]
            statUsedI <- statUsed[i]
        } else {
            lab <- NULL
            statUsedI <- NULL
        }
        att <- list(
            members = members,
            classLabels = classLabels,
            edgetext = lab,
            height = height,
            midpoint = (members - 1) / 2,
            label = lab,
            statUsed = statUsedI,
            name = i
        )
        eval(parse(text = paste0("attributes(", iPaste, ") <- att")))
        # Add leaves
        leaves <- matCollapse[matCollapse == i]
        if (length(leaves) > 0) {
            for (l in seq(1, length(leaves))) {
                # Add element
                lPaste <- paste0(iPaste, "[[", l, "]]")
                eval(parse(text = paste0(lPaste, "<-list()")))
                # Add attributes
                members <- 1
                leaf <- names(leaves)[l]
                height <- 0
                att <- list(
                    members = members,
                    classLabels = leaf,
                    height = height,
                    label = leaf,
                    leaf = TRUE,
                    name = i
                )
                eval(parse(text = paste0("attributes(", lPaste, ") <- att")))
            }
        }
    }
    class(dendro) <- "dendrogram"
    return(dendro)
}
# Count the underscore-separated tokens in each element of `x`.
# Uses base `lengths()` instead of the hand-rolled unlist(lapply(..., length));
# also returns integer(0) (not NULL) for empty input.
ncharX <- function(x) {
    lengths(strsplit(x, "_"))
}
# Keep only the first `n` underscore-separated tokens of each element of `x`,
# re-joined with underscores.
subUnderscore <- function(x, n) {
    vapply(strsplit(x, "_"), function(tokens) {
        paste(tokens[seq(n)], collapse = "_")
    }, FUN.VALUE = character(1))
}
# Compute training-set performance of the tree's rules: overall accuracy,
# balanced accuracy, mean precision, and per-class sensitivity/precision.
# Returns a list with the per-sample predictions and the metric list.
.getPerformance <- function(rules, features, class) {
    ## Predict a label per sample; unresolved samples become "MISSING"
    votes <- getDecisions(rules, t(features))
    votes[is.na(votes)] <- "MISSING"
    ## Work with plain character labels throughout
    class <- as.character(class)
    classLevels <- unique(class)
    ## Overall accuracy
    acc <- mean(votes == class)
    ## Correct calls per class
    classCorrect <- vapply(classLevels,
        function(lab) {
            sum(votes == lab & class == lab)
        },
        FUN.VALUE = double(1)
    )
    ## Per-class sensitivity and balanced accuracy
    classCount <- c(table(class))[classLevels]
    sens <- classCorrect / classCount
    balacc <- mean(sens)
    ## Per-class precision and its mean
    voteCount <- c(table(votes))[classLevels]
    prec <- classCorrect / voteCount
    meanPrecision <- mean(prec)
    ## Bundle performance metrics
    performance <- list(
        accuracy = acc,
        balAcc = balacc,
        meanPrecision = meanPrecision,
        correct = classCorrect,
        sizes = classCount,
        sensitivity = sens,
        precision = prec
    )
    list(
        prediction = votes,
        performance = performance
    )
}
# Create rules of classes and features sequences.
# Flattens the tree into a data.frame with one row per (class, feature)
# rule -- direction (1 up / -1 down), split value, stat, metric used, and
# tree level -- then returns a per-class list of rule tables.
.mapClass2features <-
    function(tree, features, class, topLevelMeta = FALSE) {
        # Get class to feature indices
        class2featuresIndices <- do.call(rbind, lapply(
            seq(length(tree)),
            function(i) {
                treeLevel <- tree[[i]]
                c2fsub <- as.data.frame(do.call(rbind, lapply(
                    treeLevel,
                    function(split) {
                        # Keep track of stat used for rule list
                        statUsed <- split$statUsed
                        # Keep only split information
                        split <- split[!names(split) %in%
                            c("statUsed", "fUsed", "dirs")]
                        # Create data frame of split rules
                        edgeFram <-
                            do.call(rbind, lapply(split, function(edge) {
                                # Create data.frame of groups, split-dirs, feature IDs
                                groups <-
                                    c(edge$group1Consensus, edge$group2Consensus)
                                sdir <- c(
                                    rep(1, length(edge$group1Consensus)),
                                    rep(-1, length(edge$group2Consensus))
                                )
                                feat <- edge$featureName
                                val <- edge$value
                                stat <- edge$stat
                                # Cross every class with every feature rule
                                data.frame(
                                    class = rep(groups, length(feat)),
                                    feature = rep(feat, each = length(groups)),
                                    direction = rep(sdir, length(feat)),
                                    value = rep(val, each = length(groups)),
                                    stat = rep(stat, each = length(groups)),
                                    stringsAsFactors = FALSE
                                )
                            }))
                        # Add stat used
                        edgeFram$statUsed <- statUsed
                        return(edgeFram)
                    }
                )))
                c2fsub$level <- i
                return(c2fsub)
            }
        ))
        rownames(class2featuresIndices) <- NULL
        # Generate list of rules for each class
        if (topLevelMeta) {
            # Order classes by their first up-regulated appearance
            orderedClass <- unique(class2featuresIndices[
                class2featuresIndices$direction == 1, "class"
            ])
        }
        else {
            orderedClass <- levels(class)
        }
        rules <-
            lapply(orderedClass, function(cl, class2featuresIndices) {
                class2featuresIndices[
                    class2featuresIndices$class == cl,
                    colnames(class2featuresIndices) != "class"
                ]
            }, class2featuresIndices)
        names(rules) <- orderedClass
        return(list(rules = rules))
    }
#' @title Plots dendrogram of \emph{findMarkersTree} output
#' @description Generates a dendrogram of the rules and performance
#' (optional) of the decision tree generated by findMarkersTree().
#' @param tree List object. The output of findMarkersTree()
#' @param classLabel A character value. The name of a specific label to draw
#' the path and rules. If NULL (default), the tree for all clusters is shown.
#' @param addSensPrec Logical. Print training sensitivities and precisions
#' for each cluster below leaf label? Default is FALSE.
#' @param maxFeaturePrint Numeric value. Maximum number of markers to print
#' at a given split. Default is 4.
#' @param leafSize Numeric value. Size of text below each leaf. Default is 10.
#' @param boxSize Numeric value. Size of rule labels. Default is 2.
#' @param boxColor Character value. Color of rule labels. Default is black.
#' @examples
#' \dontrun{
#' # Generate simulated single-cell dataset using celda
#' sim_counts <- celda::simulateCells("celda_CG", K = 4, L = 10, G = 100)
#'
#' # Celda clustering into 5 clusters & 10 modules
#' cm <- celda_CG(sim_counts$counts, K = 5, L = 10, verbose = FALSE)
#'
#' # Get features matrix and cluster assignments
#' factorized <- factorizeMatrix(sim_counts$counts, cm)
#' features <- factorized$proportions$cell
#' class <- celdaClusters(cm)
#'
#' # Generate Decision Tree
#' DecTree <- findMarkersTree(features, class, threshold = 1)
#'
#' # Plot dendrogram
#' plotMarkerDendro(DecTree)
#' }
#' @return A ggplot2 object
#' @export
plotMarkerDendro <- function(tree,
classLabel = NULL,
addSensPrec = FALSE,
maxFeaturePrint = 4,
leafSize = 10,
boxSize = 2,
boxColor = "black") {
# Extract the dendrogram built by findMarkersTree()
# Get necessary elements
dendro <- tree$dendro
# Get performance information (training or CV based)
if (addSensPrec) {
performance <- tree$performance
# Create vector of per class performance
# (one "Sens. x.xx / Prec. x.xx" string per class, shown under each leaf)
perfVec <- paste0(
"Sens. ",
format(round(performance$sensitivity, 2), nsmall = 2),
"\n Prec. ",
format(round(performance$precision, 2), nsmall = 2)
)
names(perfVec) <- names(performance$sensitivity)
}
# Get dendrogram segments
dendSegs <-
ggdendro::dendro_data(dendro, type = "rectangle")$segments
# Get necessary coordinates to add labels to
# These will have y > 1
dendSegs <-
unique(dendSegs[dendSegs$y > 1, c("x", "y", "yend", "xend")])
# Labeled splits will be vertical (x != xend) or
# Length 0 (x == xend & y == yend)
dendSegsAlt <- dendSegs[
dendSegs$x != dendSegs$xend |
(dendSegs$x == dendSegs$xend &
dendSegs$y == dendSegs$yend),
c("x", "xend", "y")
]
colnames(dendSegsAlt)[1] <- "xalt"
# Label names will be at nodes, these will
# Occur at the end of segments
segs <- as.data.frame(dendextend::get_nodes_xy(dendro))
colnames(segs) <- c("xend", "yend")
# Add labels to nodes
# (rule labels are stored ";"-separated; display one feature per line)
segs$label <-
gsub(";", "\n", dendextend::get_nodes_attr(dendro, "label"))
# Subset for max
# (truncate each label to at most maxFeaturePrint features)
segs$label <-
sapply(segs$label, function(lab, maxFeaturePrint) {
loc <- gregexpr("\n", lab)[[1]][maxFeaturePrint]
if (!is.na(loc)) {
lab <- substr(lab, 1, loc - 1)
}
return(lab)
}, maxFeaturePrint)
segs$statUsed <- dendextend::get_nodes_attr(dendro, "statUsed")
# If highlighting a class label, remove non-class specific rules
if (!is.null(classLabel)) {
if (!classLabel %in% names(tree$rules)) {
stop("classLabel not a valid class ID.")
}
# .highlightClassLabel marks the path to classLabel with keepLabel = TRUE
dendro <- .highlightClassLabel(dendro, classLabel)
keepLabel <- dendextend::get_nodes_attr(dendro, "keepLabel")
keepLabel[is.na(keepLabel)] <- FALSE
segs$label[!keepLabel] <- NA
}
# Remove non-labelled nodes &
# leaf nodes (yend == 0)
segs <- segs[!is.na(segs$label) & segs$yend != 0, ]
# Merge to full set of coordinates
dendSegsLabelled <- merge(dendSegs, segs)
# Remove duplicated labels
dendSegsLabelled <- dendSegsLabelled[order(dendSegsLabelled$y,
decreasing = TRUE
), ]
dendSegsLabelled <- dendSegsLabelled[!duplicated(dendSegsLabelled[
,
c(
"xend", "x", "yend",
"label", "statUsed"
)
]), ]
# Merge with alternative x-coordinates for alternative split
dendSegsLabelled <- merge(dendSegsLabelled, dendSegsAlt)
# Order by height and coordinates
dendSegsLabelled <-
dendSegsLabelled[order(dendSegsLabelled$x), ]
# Find information gain splits
igSplits <- dendSegsLabelled$statUsed == "Split" &
!duplicated(dendSegsLabelled[, c("xalt", "y")])
# Set xend for IG splits
dendSegsLabelled$xend[igSplits] <-
dendSegsLabelled$xalt[igSplits]
# Set y for non-IG splits
# (nudge downward slightly so labels do not overlap the split point)
dendSegsLabelled$y[!igSplits] <-
dendSegsLabelled$y[!igSplits] - 0.2
# Get index of leaf labels
leafLabels <- dendextend::get_leaves_attr(dendro, "label")
# Adjust leaf labels if there are metacluster labels
# (extract the cluster name from the trailing "(...)" of each leaf label)
if (!is.null(tree$metaclusterLabels)) {
leafLabels <- regmatches(
leafLabels,
regexpr(
pattern = "(?<=\\().*?(?=\\)$)",
leafLabels, perl = TRUE
)
)
}
# Add sensitivity and precision measurements
# (horizontal labels when performance text is shown, vertical otherwise)
if (addSensPrec) {
leafLabels <- paste(leafLabels, perfVec, sep = "\n")
leafAngle <- 0
leafHJust <- 0.5
leafVJust <- -1
} else {
leafAngle <- 90
leafHJust <- 1
leafVJust <- 0.5
}
# Create plot of dendrogram
# NOTE(review): referencing dendSegsLabelled$... inside aes() is
# discouraged ggplot2 practice (bare column names with data= preferred);
# left unchanged here to preserve behavior exactly.
suppressMessages(
dendroP <- ggdendro::ggdendrogram(dendro) +
ggplot2::geom_label(
data = dendSegsLabelled,
ggplot2::aes(
x = dendSegsLabelled$xend,
y = dendSegsLabelled$y,
label = dendSegsLabelled$label
),
size = boxSize,
label.size = 1,
fontface = "bold",
vjust = 1,
nudge_y = 0.1,
color = boxColor
) +
ggplot2::theme_bw() +
ggplot2::scale_x_reverse(
breaks =
seq(length(leafLabels)),
label = leafLabels
) +
ggplot2::scale_y_continuous(expand = c(0, 0)) +
ggplot2::theme(
panel.grid.major.y = ggplot2::element_blank(),
legend.position = "none",
panel.grid.minor.y = ggplot2::element_blank(),
panel.grid.minor.x = ggplot2::element_blank(),
panel.grid.major.x = ggplot2::element_blank(),
panel.border = ggplot2::element_blank(),
axis.title = ggplot2::element_blank(),
axis.ticks = ggplot2::element_blank(),
axis.text.x = ggplot2::element_text(
hjust = leafHJust,
angle = leafAngle,
size = leafSize,
family = "Palatino",
face = "bold",
vjust = leafVJust
),
axis.text.y = ggplot2::element_blank()
)
)
# Check if need to add metacluster labels
if (!is.null(tree$metaclusterLabels)) {
# store metacluster labels to add
newLabels <- unique(tree$branchPoints$top_level$metacluster)
# adjust labels for metaclusters of size one
# (strip the trailing "(cluster)" suffix from singleton metaclusters)
newLabels <- unlist(lapply(newLabels, function(curMeta) {
if (substr(curMeta, nchar(curMeta), nchar(curMeta)) == ")") {
return(gsub(
pattern = "\\(.*\\)$",
replacement = "",
x = curMeta
))
}
else {
return(curMeta)
}
}))
# Create table for metacluster labels
# (placed at the topmost labelled split positions)
metaclusterText <- dendSegsLabelled[
dendSegsLabelled$y ==
max(dendSegsLabelled$y),
c("xend", "y", "label")
]
metaclusterText$label <- newLabels
# Add metacluster labels to top of plot
dendroP <- dendroP +
ggplot2::geom_text(
data = metaclusterText,
ggplot2::aes(
x = metaclusterText$xend,
y = metaclusterText$y,
label = metaclusterText$label,
fontface = 2
),
angle = 90,
nudge_y = 0.5,
family = "Palatino",
size = leafSize / 3
)
# adjust coordinates of plot to show labels
dendroP <- dendroP + ggplot2::coord_cartesian(
ylim =
c(
0,
max(dendSegsLabelled$y +
1)
)
)
}
# Increase line width slightly for aesthetic purposes
# NOTE(review): assumes the segment layer created by ggdendrogram() is
# layers[[2]] — confirm if the ggdendro version changes.
dendroP$layers[[2]]$aes_params$size <- 1.3
return(dendroP)
}
# Reformat the dendrogram to draw the path to a specific class.
#
# Walks from the root towards the subtree whose "classLabels" attribute
# contains `classLabel`, setting keepLabel = TRUE on every edge of each
# visited branch. plotMarkerDendro() later blanks the labels of all nodes
# without keepLabel.
#
# Rewritten to use plain recursion instead of building index expressions
# with eval(parse(text = ...)), which is fragile and hard to audit.
#
# @param dendro A dendrogram whose nodes carry "classLabels" attributes.
# @param classLabel The class whose path should stay labelled.
# @return The dendrogram with keepLabel attributes set along the path.
.highlightClassLabel <- function(dendro, classLabel) {
    mark <- function(branch) {
        # Preserve this branch's attributes (lapply strips them)
        att <- attributes(branch)
        # Which child subtree contains the class of interest?
        wSplit <- which(vapply(
            branch,
            function(split) classLabel %in% attributes(split)$classLabels,
            logical(1)
        ))
        # Keep labels for every edge of this branch
        branch <- lapply(branch, function(edge) {
            attributes(edge)$keepLabel <- TRUE
            edge
        })
        # Make it a dendrogram again and restore the original attributes
        class(branch) <- "dendrogram"
        attributes(branch) <- att
        # Descend into the matching child until a leaf (members <= 1)
        w <- wSplit[1]
        if (attributes(branch[[w]])$members > 1) {
            branch[[w]] <- mark(branch[[w]])
        }
        branch
    }
    return(mark(dendro))
}
#' @title Generate heatmap for a marker decision tree
#' @description Creates heatmap for a specified branch point in a marker tree.
#' @param tree A decision tree returned from \link{findMarkersTree} function.
#' @param counts Numeric matrix. Gene-by-cell counts matrix.
#' @param branchPoint Character. Name of branch point to plot heatmap for.
#' Name should match those in \emph{tree$branchPoints}.
#' @param featureLabels List of feature cluster assignments. Length should
#' be equal to number of rows in counts matrix, and formatting should match
#' that used in \emph{findMarkersTree()}. Required when using clusters
#' of features and not previously provided to \emph{findMarkersTree()}
#' @param topFeatures Integer. Number of genes to plot per marker module.
#' Genes are sorted based on their AUC for their respective cluster.
#' Default is 10.
#' @param silent Logical. Whether to avoid plotting heatmap to screen.
#' Default is FALSE.
#' @return A heatmap visualizing the counts matrix for the cells and genes at
#' the specified branch point.
#' @examples
#' \dontrun{
#' # Generate simulated single-cell dataset using celda
#' sim_counts <- simulateCells("celda_CG", K = 4, L = 10, G = 100)
#'
#' # Celda clustering into 5 clusters & 10 modules
#' cm <- celda_CG(sim_counts, K = 5, L = 10, verbose = FALSE)
#'
#' # Get features matrix and cluster assignments
#' factorized <- factorizeMatrix(cm)
#' features <- factorized$proportions$cell
#' class <- celdaClusters(cm)
#'
#' # Generate Decision Tree
#' DecTree <- findMarkersTree(features, class, threshold = 1)
#'
#' # Plot example heatmap
#' plotMarkerHeatmap(DecTree, assay(sim_counts),
#' branchPoint = "top_level",
#' featureLabels = paste0("L", celdaModules(cm)))
#' }
#' @export
plotMarkerHeatmap <- function(tree,
                              counts,
                              branchPoint,
                              featureLabels,
                              topFeatures = 10,
                              silent = FALSE) {
    # Get the branch point to plot
    branch <- tree$branchPoints[[branchPoint]]
    # Check that user entered a valid branch point name
    if (is.null(branch)) {
        stop(
            "Invalid branch point.",
            " Branch point name should match one of those in tree$branchPoints."
        )
    }
    # Convert counts matrix to base matrix (e.g. from dgCMatrix)
    counts <- as.matrix(counts)
    # Get marker features for this branch point
    marker <- unique(branch$feature)
    # Prefer feature labels stored in the tree when available
    if ("featureLabels" %in% names(tree)) {
        featureLabels <- tree$featureLabels
    }
    # Check that feature labels are available from somewhere.
    # (&& instead of & — these are scalar conditions)
    if (missing(featureLabels) &&
        !("featureLabels" %in% names(tree)) &&
        !all(marker %in% rownames(counts))) {
        stop("Please provide feature labels, i.e. gene cluster labels")
    } else if (missing(featureLabels) &&
        !("featureLabels" %in% names(tree)) &&
        all(marker %in% rownames(counts))) {
        # BUG FIX: the original used `==` (a no-op comparison) instead of
        # `<-`, so featureLabels was never assigned in this branch and
        # downstream code failed with "object not found".
        featureLabels <- rownames(counts)
    }
    # Make sure feature labels match the table
    if (!all(branch$feature %in% featureLabels)) {
        stop(
            "Provided feature labels don't match those in the tree.",
            " Please check the feature names in the tree's rules' table."
        )
    }
    # Top-level branch point of a metaclusters tree
    if (branchPoint == "top_level") {
        # Unique metaclusters at the top level
        metaclusters <- unique(branch$metacluster)
        # Indices of the final set of genes shown in the heatmap
        whichFeatures <- c()
        for (meta in metaclusters) {
            # Subset rules table to this metacluster
            curMeta <- branch[branch$metacluster == meta, ]
            if ("gene" %in% names(branch)) {
                # Gene-level info available: sort by per-gene AUC and
                # keep the top N genes
                curMeta <-
                    curMeta[order(curMeta$geneAUC, decreasing = TRUE), ]
                genes <- utils::head(unique(curMeta$gene), topFeatures)
                markerGenes <- which(rownames(counts) %in% genes)
            } else {
                # No gene-level info: expand marker feature clusters to
                # their member genes
                curMarker <- unique(curMeta$feature)
                markerGenes <- which(featureLabels %in% curMarker)
            }
            # Keep only features with non-zero variance to avoid a
            # clustering error in the heatmap (same call for both cases)
            markerGenes <- .removeZeroVariance(
                counts,
                cells = which(
                    tree$metaclusterLabels %in% unique(curMeta$metacluster)
                ),
                markers = markerGenes
            )
            whichFeatures <- c(whichFeatures, markerGenes)
        }
        # Order the metaclusters by size (largest first)
        colOrder <- data.frame(
            groupName = names(sort(
                table(tree$metaclusterLabels),
                decreasing = TRUE
            )),
            groupIndex = seq_along(unique(tree$metaclusterLabels))
        )
        # Order the markers to follow the metacluster ordering
        allMarkers <- stats::setNames(
            as.list(colOrder$groupName),
            colOrder$groupName
        )
        allMarkers <- lapply(allMarkers, function(x) {
            unique(branch[branch$metacluster == x, "feature"])
        })
        rowOrder <- data.frame(
            groupName = unlist(allMarkers),
            groupIndex = seq_along(unlist(allMarkers))
        )
        # Drop markers whose features were all removed as zero-variance
        toRemove <-
            which(!rowOrder$groupName %in% featureLabels[whichFeatures])
        if (length(toRemove) > 0) {
            rowOrder <- rowOrder[-toRemove, ]
        }
        # Sort cells according to metacluster size
        sortedCells <- seq_len(ncol(counts))[
            order(match(tree$metaclusterLabels, colOrder$groupName))
        ]
        # Create heatmap with only the markers
        return(
            plotHeatmap(
                counts = counts,
                z = tree$metaclusterLabels,
                y = featureLabels,
                featureIx = whichFeatures,
                cellIx = sortedCells,
                showNamesFeature = TRUE,
                main = "Top-level",
                silent = silent,
                treeheightFeature = 0,
                colGroupOrder = colOrder,
                rowGroupOrder = rowOrder,
                treeheightCell = 0
            )
        )
    }
    # Balanced-split branch point
    if (branch$statUsed[1] == "Split") {
        # Keep entries for the balanced split only (in case of alt. split)
        split <- branch$feature[1]
        branch <- branch[branch$feature == split, ]
        # Up-regulated and down-regulated classes at this split
        upClasses <- unique(branch[branch$direction == 1, "class"])
        downClasses <-
            unique(branch[branch$direction == (-1), "class"])
        # Re-order cells to keep up and down separate on the heatmap
        reorderedCells <- c(
            (which(tree$classLabels %in% upClasses)
            [order(tree$classLabels[tree$classLabels %in% upClasses])]),
            (which(tree$classLabels %in% downClasses)
            [order(tree$classLabels[tree$classLabels %in% downClasses])])
        )
        # Cell annotation based on split direction
        cellAnno <- data.frame(
            split = rep("Down-regulated", ncol(counts)),
            stringsAsFactors = FALSE
        )
        cellAnno$split[which(tree$classLabels %in% upClasses)] <-
            "Up-regulated"
        rownames(cellAnno) <- colnames(counts)
        if ("gene" %in% names(branch)) {
            # Gene-level info: keep the top N genes for the split feature
            genes <- utils::head(unique(branch$gene), topFeatures)
            whichFeatures <- which(rownames(counts) %in% genes)
            # Keep only features with non-zero variance
            whichFeatures <- .removeZeroVariance(
                counts,
                cells = which(tree$classLabels %in% unique(branch$class)),
                markers = whichFeatures
            )
        } else {
            # Expand the split feature cluster; keep non-zero variance only
            whichFeatures <- .removeZeroVariance(
                counts,
                cells = reorderedCells,
                markers = which(featureLabels == branch$feature[1])
            )
        }
        # Create heatmap with only the split feature and split classes
        # (identical call for both cases above)
        return(
            plotHeatmap(
                counts = counts,
                z = tree$classLabels,
                y = featureLabels,
                featureIx = whichFeatures,
                cellIx = reorderedCells,
                clusterCell = FALSE,
                showNamesFeature = TRUE,
                main = branchPoint,
                silent = silent,
                treeheightFeature = 0,
                treeheightCell = 0,
                annotationCell = cellAnno
            )
        )
    }
    # One-off split branch point
    if (branch$statUsed[1] == "One-off") {
        # Unique classes at this branch point
        classes <- unique(branch$class)
        # Indices of the final set of genes shown in the heatmap
        whichFeatures <- c()
        for (class in classes) {
            # Rules where this class is up-regulated
            curClass <-
                branch[branch$class == class & branch$direction == 1, ]
            if ("gene" %in% names(branch)) {
                # Gene-level info: keep the top N genes
                genes <- utils::head(unique(curClass$gene), topFeatures)
                markerGenes <- which(rownames(counts) %in% genes)
            } else {
                # Expand marker feature clusters to their member genes
                markerGenes <-
                    which(featureLabels %in% unique(curClass$feature))
            }
            # Keep only features with non-zero variance
            markerGenes <- .removeZeroVariance(
                counts,
                cells = which(tree$classLabels %in% unique(curClass$class)),
                markers = markerGenes
            )
            whichFeatures <- c(whichFeatures, markerGenes)
        }
        # Order the clusters such that up-regulated come first
        colOrder <- data.frame(
            groupName = unique(branch[
                order(branch$direction, decreasing = TRUE),
                "class"
            ]),
            groupIndex = seq_along(unique(branch$class))
        )
        # Order the markers to follow the cluster ordering
        allMarkers <- stats::setNames(
            as.list(colOrder$groupName),
            colOrder$groupName
        )
        allMarkers <- lapply(allMarkers, function(x) {
            unique(branch[branch$class == x & branch$direction == 1, "feature"])
        })
        rowOrder <- data.frame(
            groupName = unlist(allMarkers),
            groupIndex = seq_along(unlist(allMarkers))
        )
        # Drop markers whose features were all removed as zero-variance
        toRemove <-
            which(!rowOrder$groupName %in% featureLabels[whichFeatures])
        if (length(toRemove) > 0) {
            rowOrder <- rowOrder[-toRemove, ]
        }
        # Sort cells by class order; keep only cells from the split classes
        sortedCells <- seq_len(ncol(counts))[
            order(match(tree$classLabels, colOrder$groupName))
        ]
        sortedCells <-
            sortedCells[seq_len(sum(tree$classLabels %in% classes))]
        # Create heatmap with only the split features and split classes
        return(
            plotHeatmap(
                counts = counts,
                z = tree$classLabels,
                y = featureLabels,
                featureIx = whichFeatures,
                cellIx = sortedCells,
                showNamesFeature = TRUE,
                main = branchPoint,
                silent = silent,
                treeheightFeature = 0,
                colGroupOrder = colOrder,
                rowGroupOrder = rowOrder,
                treeheightCell = 0
            )
        )
    }
}
# Helper: drop marker indices whose expression has zero variance across
# the selected cells (scaling such genes yields NA/NaN, which breaks
# hierarchical clustering in the heatmap).
#
# @param counts Numeric genes-by-cells matrix.
# @param cells Column indices of the cells to consider.
# @param markers Row indices of candidate marker genes.
# @return `markers` with zero-variance rows removed (possibly empty).
.removeZeroVariance <- function(counts, cells, markers) {
    # Subset to the cells of interest; drop = FALSE keeps the matrix
    # shape even when a single cell is selected
    counts <- counts[, cells, drop = FALSE]
    # Z-score each gene across cells; zero-variance genes become NA/NaN
    counts <- t(scale(t(counts)))
    # Rows containing any NA/NaN are the zero-variance genes
    zeroVarianceGenes <- which(!stats::complete.cases(counts))
    # Logical mask handles the empty-overlap case without a length guard
    # (original used markers[-which(...)] behind an if, which is the
    # classic negative-empty-index footgun)
    markers[!markers %in% zeroVarianceGenes]
}
# NOTE: the lines below are web-scrape boilerplate (package-documentation
# site footer text), not R code; commented out so the file parses.
# Any scripts or data that you put into this service are public.
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.