#' @title Trains a (mostly) LSTM model on genomic data, designed for developing genome-based language models (GenomeNet)
#'
#' @description
#' Depth and number of neurons per layer of the network can be specified. The first layer can be a convolutional neural network (CNN) designed to capture codons.
#' If a path to a folder where FASTA files are located is provided, batches will be generated by an external generator, which
#' is recommended for big training sets. Alternatively, a dataset can be supplied that holds the preprocessed batches (generated by \code{preprocessSemiRedundant()})
#' and keeps them in RAM. Also supports training on instances with multiple GPUs, scaling linearly with the number of GPUs present. See the example below for a minimal training call.
#' @param train_type Either "lm" for language model, "label_header" or "label_folder". A language model is trained to predict the next character in a sequence.
#' label_header/label_folder models are trained to predict a corresponding class, given a sequence as input. If "label_header", the class will be read from the FASTA headers.
#' If "label_folder", the class will be read from the folder, i.e. all FASTA files in one folder must belong to the same class.
#' @param model A keras model.
#' @param model_path Path to a pretrained model.
#' @param path Path to folder where individual or multiple FASTA files are located for training. If \code{train_type} is \code{label_folder}, should be a vector
#' containing a path for each class.
#' @param path.val Path to folder where individual or multiple FASTA files are located for validation. If \code{train_type} is \code{label_folder}, should be a vector
#' containing a path for each class.
#' @param dataset Data frame holding training samples in RAM instead of using a generator.
#' @param checkpoint_path Path to checkpoints folder.
#' @param validation.split Fraction of the batches that will be used for validation, relative to the size of the training data.
#' @param run.name Name of the run (without file ending). Name will be used to identify output from callbacks.
#' @param batch.size Number of samples that are used for one network update.
#' @param epochs Number of training epochs.
#' @param max.queue.size Maximum size of the generator queue passed to \code{fit_generator()}.
#' @param lr.plateau.factor Factor by which the learning rate is reduced when a plateau is reached.
#' @param patience Number of epochs to wait for a decrease in loss before reducing the learning rate.
#' @param cooldown Number of epochs to wait after a learning rate reduction before resuming normal operation.
#' @param steps.per.epoch Number of batches to finish one epoch.
#' @param step Distance between the start positions of two successive samples.
#' @param randomFiles Logical; if \code{FALSE}, files are processed sequentially, if \code{TRUE}, file order is shuffled beforehand.
#' @param vocabulary Vector of allowed characters; characters outside the vocabulary get encoded as zero-vectors.
#' @param initial_epoch Epoch at which to start training; set to 0 if no \code{model_path} argument is given. Note that the network
#' will run for (\code{epochs} - \code{initial_epoch}) rounds and not \code{epochs} rounds.
#' @param tensorboard.log Path to tensorboard log directory.
#' @param save_best_only Only save the model if it improved on the best val_loss score.
#' @param compile Whether to compile the model after loading.
#' @param solver Optimization method, options are "adam", "adagrad", "rmsprop" or "sgd". Only used when pretrained model is given (\code{model_path} is not NULL) and compile is FALSE.
#' Otherwise solver is determined when model is created.
#' @param learning.rate Learning rate for optimizer. Only used when pretrained model is given (\code{model_path} is not NULL) and compile is FALSE.
#' Otherwise learning rate is determined when model is created.
#' @param max_iter Stop after \code{max_iter} iterations that failed to produce a new sample.
#' @param seed Sets the seed for the \code{set.seed} function, for reproducible results when using \code{randomFiles} or \code{shuffleFastaEntries}.
#' @param shuffleFastaEntries Logical, shuffle entries in file.
#' @param output List of optional outputs; if \code{none} is \code{TRUE}, no output is written.
#' @param format File format, "fasta" or "fastq".
#' @param fileLog If a path is specified, write the names of the processed files to a CSV file.
#' @param labelVocabulary Character vector of possible targets. Targets outside \code{labelVocabulary} get discarded.
#' @param numberOfFiles Use only the specified number of files; ignored if greater than the number of files in \code{path}.
#' @param reverseComplements Logical; if \code{TRUE}, one half of each batch contains the original sequences and the other half their reverse complements. The reverse complement
#' is obtained by reversing the order of the sequence and swapping A/T and C/G (for example, the reverse complement of "aacg" is "cgtt"). The \code{batch.size} argument has to be even, otherwise 1 will be added
#' to \code{batch.size}.
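#' @examples
#' \dontrun{
#' # Minimal sketch of a language-model run. The FASTA directories and the
#' # small keras model below are placeholders; any compiled model whose input
#' # is a one-hot encoded sequence of shape (maxlen, vocabulary size) works.
#' library(keras)
#' maxlen <- 80
#' vocabulary <- c("a", "c", "g", "t")
#' model <- keras_model_sequential() %>%
#'   layer_lstm(units = 16, input_shape = c(maxlen, length(vocabulary))) %>%
#'   layer_dense(units = length(vocabulary), activation = "softmax")
#' model %>% compile(loss = "categorical_crossentropy",
#'                   optimizer = optimizer_adam(), metrics = c("acc"))
#' history <- trainNetwork(train_type = "lm",
#'                         model = model,
#'                         path = "training_fasta/",
#'                         path.val = "validation_fasta/",
#'                         checkpoint_path = "checkpoints",
#'                         tensorboard.log = "tensorboard",
#'                         run.name = "example_run",
#'                         batch.size = 64,
#'                         epochs = 3,
#'                         steps.per.epoch = 50)
#' }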
#' @export
trainNetwork <- function(train_type = "lm",
model_path = NULL,
model = NULL,
path = NULL,
path.val = NULL,
dataset = NULL,
checkpoint_path,
validation.split = 0.2,
run.name = "run",
batch.size = 64,
epochs = 10,
max.queue.size = 100,
lr.plateau.factor = 0.9,
patience = 5,
cooldown = 5,
steps.per.epoch = 1000,
step = 1,
randomFiles = FALSE,
initial_epoch = NULL,
vocabulary = c("a", "c", "g", "t"),
tensorboard.log,
save_best_only = TRUE,
compile = TRUE,
learning.rate = NULL,
solver = NULL,
max_iter = 1000,
seed = c(1234, 4321),
shuffleFastaEntries = FALSE,
output = list(none = FALSE,
checkpoints = TRUE,
tensorboard = TRUE,
log = TRUE,
serialize_model = TRUE,
full_model = TRUE
),
format = "fasta",
fileLog = NULL,
labelVocabulary = NULL,
numberOfFiles = NULL,
reverseComplements = FALSE) {
stopifnot(train_type %in% c("lm", "label_header", "label_folder"))
if (train_type == "lm"){
labelGen <- FALSE
labelByFolder <- FALSE
}
if (train_type == "label_header"){
labelGen <- TRUE
labelByFolder <- FALSE
stopifnot(!is.null(labelVocabulary))
}
if (train_type == "label_folder"){
labelGen <- TRUE
labelByFolder <- TRUE
}
if (output$none){
output$checkpoints <- FALSE
output$tensorboard <- FALSE
output$log <- FALSE
output$serialize_model <- FALSE
output$full_model <- FALSE
}
label.vocabulary.size <- length(labelVocabulary)
vocabulary.size <- length(vocabulary)
# extract maxlen from model
maxlen <- model$input$shape[1]
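# (the input tensor is assumed to have shape (batch, maxlen, vocabulary size),
# so index 1 is the sequence-length axis)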
if (labelByFolder){
if (length(path) == 1) warning("Training with just one label")
}
if (output$checkpoints){
## create folder for checkpoints using run.name
## filenames contain epoch, validation loss and validation accuracy
checkpoint_dir <- paste0(checkpoint_path, "/", run.name, "_checkpoints")
dir.create(checkpoint_dir, showWarnings = FALSE)
filepath_checkpoints <- file.path(checkpoint_dir, "Ep.{epoch:03d}-val_loss{val_loss:.2f}-val_acc{val_acc:.3f}.hdf5")
}
# Check if run.name is unique
if (dir.exists(file.path(tensorboard.log, run.name)) && output$tensorboard) {
stop(paste0("Tensorboard entry '", run.name , "' is already present. Please give your run a unique name."))
}
# Load pretrained model
if (!is.null(model_path)){
# the epochs argument can be misleading
if (!missing(initial_epoch)){
if (initial_epoch >= epochs){
stop("Network trains for (epochs - initial_epoch) rounds overall, NOT epochs rounds. Increase epochs or decrease initial_epoch.")
}
}
# extract initial_epoch from filename if no argument is given
if (is.null(initial_epoch)){
epochFromFilename <- stringr::str_extract(model_path, "Ep.\\d+")
initial_epoch <- as.integer(substring(epochFromFilename, 4, nchar(epochFromFilename)))
if (initial_epoch >= epochs){
stop("Networks trains (epochs - initial_epochs) rounds overall, NOT epochs rounds. Increase epochs or decrease initial_epoch.")
}
}
# load model
model <- keras::load_model_hdf5(model_path, compile = compile)
summary(model)
# extract maxlen
maxlen <- model$input$shape[1]
if (compile && (!is.null(learning.rate) || !is.null(solver))){
message("Arguments for solver and learning rate will be ignored. Set compile to FALSE to use custom solver and learning rate.")
}
if (!compile){
# choose optimization method; solver and learning.rate must both be supplied
stopifnot(!is.null(solver), !is.null(learning.rate))
optimizer <- switch(solver,
adam = keras::optimizer_adam(lr = learning.rate),
adagrad = keras::optimizer_adagrad(lr = learning.rate),
rmsprop = keras::optimizer_rmsprop(lr = learning.rate),
sgd = keras::optimizer_sgd(lr = learning.rate),
stop("solver must be one of 'adam', 'adagrad', 'rmsprop' or 'sgd'"))
model %>% keras::compile(loss = "categorical_crossentropy",
optimizer = optimizer, metrics = c("acc"))
}
} else {
initial_epoch <- 0
}
if (output$tensorboard){
hp <- reticulate::import("tensorboard.plugins.hparams.api")
model_hparam <- get_hyper_param(model)
# list of hyperparameters
hparams <- reticulate::dict(
HP_VOCABULARY = paste(vocabulary, collapse = ","),
HP_PATH = paste(path, collapse = ", "),
HP_REVERSE_COMP = reverseComplements,
HP_LABEL.VOC = paste(labelVocabulary, collapse = ", "),
HP_LAYER.SIZE = model_hparam$HP_LAYER.SIZE,
HP_OPTIMIZER = model_hparam$HP_OPTIMIZER,
HP_MAXLEN = maxlen,
HP_USE.CUDNN = model_hparam$HP_USE.CUDNN,
HP_EPOCHS = epochs,
HP_MAX.QUEUE.SIZE = max.queue.size,
HP_LR.PLATEAU.FACTOR = lr.plateau.factor,
HP_NUM_LAYERS = model_hparam$HP_NUM_LAYERS,
HP_BATCH.SIZE = batch.size,
HP_LEARNING.RATE = model_hparam$HP_LEARNING.RATE,
HP_DROPOUT = model_hparam$HP_DROPOUT,
HP_USE.CODON.CNN = model_hparam$HP_USE.CODON.CNN,
HP_PATIENCE = patience,
HP_COOLDOWN = cooldown,
HP_STEPS.PER.EPOCH = steps.per.epoch,
HP_STEP = step,
HP_RANDOM.FILES = randomFiles,
HP_BIDIRECTIONAL = model_hparam$HP_BIDIRECTIONAL
)
}
# if no dataset is supplied, external fasta generator will generate batches
if (is.null(dataset)) {
message("Starting fasta generator...")
if (!labelGen){
# generator for training
gen <- fastaFileGenerator(corpus.dir = path, batch.size = batch.size,
maxlen = maxlen, step = step, randomFiles = randomFiles,
vocabulary = vocabulary, max_iter = max_iter, seed = seed[1],
shuffleFastaEntries = shuffleFastaEntries, format = format,
fileLog = fileLog, reverseComplements = reverseComplements)
# generator for validation
gen.val <- fastaFileGenerator(corpus.dir = path.val, batch.size = batch.size,
maxlen = maxlen, step = step, randomFiles = randomFiles,
vocabulary = vocabulary, max_iter = max_iter, seed = seed[2],
shuffleFastaEntries = shuffleFastaEntries, format = format,
fileLog = NULL, reverseComplements = FALSE)
# label generator
} else {
# label by folder
if (labelByFolder){
# initialize training generators
initializeGenerators(directories = path,
format = format,
batch.size = batch.size,
maxlen = maxlen,
max_iter = max_iter,
vocabulary = vocabulary,
verbose = FALSE,
randomFiles = randomFiles,
step = step,
showWarnings = FALSE,
seed = seed[1],
shuffleFastaEntries = shuffleFastaEntries,
numberOfFiles = numberOfFiles,
fileLog = fileLog,
reverseComplements = reverseComplements,
val = FALSE)
# initialize validation generators
initializeGenerators(directories = path.val,
format = format,
batch.size = batch.size,
maxlen = maxlen,
max_iter = max_iter,
vocabulary = vocabulary,
verbose = FALSE,
randomFiles = randomFiles,
step = step,
showWarnings = FALSE,
seed = seed[2],
shuffleFastaEntries = shuffleFastaEntries,
numberOfFiles = NULL,
fileLog = fileLog,
reverseComplements = FALSE,
val = TRUE)
gen <- labelByFolderGeneratorWrapper(val = FALSE, path = path)
gen.val <- labelByFolderGeneratorWrapper(val = TRUE, path = path.val)
} else {
# generator for training
gen <- fastaLabelGenerator(corpus.dir = path,
format = format,
batch.size = batch.size,
maxlen = maxlen,
max_iter = max_iter,
vocabulary = vocabulary,
verbose = FALSE,
randomFiles = randomFiles,
step = step,
showWarnings = FALSE,
seed = seed[1],
shuffleFastaEntries = shuffleFastaEntries,
fileLog = fileLog,
labelVocabulary = labelVocabulary,
reverseComplements = reverseComplements
)
gen.val <- fastaLabelGenerator(corpus.dir = path.val,
format = format,
batch.size = batch.size,
maxlen = maxlen,
max_iter = max_iter,
vocabulary = vocabulary,
verbose = FALSE,
randomFiles = randomFiles,
step = step,
showWarnings = FALSE,
seed = seed[2],
shuffleFastaEntries = shuffleFastaEntries,
fileLog = NULL,
labelVocabulary = labelVocabulary,
reverseComplements = FALSE
)
}
}
# callback list
callbacks <- list(keras::callback_reduce_lr_on_plateau(
monitor = "loss",
factor = lr.plateau.factor,
patience = patience,
cooldown = cooldown
))
# add optional callbacks
list_index <- 2
if (output$checkpoints){
callbacks[[list_index]] <- keras::callback_model_checkpoint(filepath = filepath_checkpoints,
save_weights_only = FALSE,
save_best_only = save_best_only,
verbose = 1)
list_index <- list_index + 1
}
if (output$tensorboard){
callbacks[[list_index]] <- keras::callback_tensorboard(file.path(tensorboard.log, run.name),
write_graph = TRUE,
histogram_freq = 1,
write_images = TRUE,
write_grads = TRUE)
# log hparams
callbacks[[list_index + 1]] <- hp$KerasCallback(file.path(tensorboard.log, run.name), hparams, trial_id = run.name)
list_index <- list_index + 2
# create string with function arguments
argumentList <- as.list(match.call(expand.dots=FALSE))
argAsChar <- as.character(argumentList)
argText <- vector("character")
argsInQuotes <- c("model_path", "path", "path.val", "checkpoint_path", "run.name", "solver",
"tensorboard.log", "fileLog", "train_type")
argText[1] <- "trainNetwork("
for (i in 2:(length(argumentList) - 1)){
arg <- argAsChar[[i]]
if (names(argumentList)[i] %in% argsInQuotes){
argText[i] <- paste0(names(argumentList)[i], " = ", '\"', arg, '\"', " ,")
} else {
argText[i] <- paste0(names(argumentList)[i], " = ", arg, " ,")
}
}
i <- length(argumentList)
if (names(argumentList)[i] %in% argsInQuotes){
argText[i] <- paste0(names(argumentList)[i], " = ", '\"', argAsChar[[i]], '\"', ")")
} else {
argText[i] <- paste0(names(argumentList)[i], " = ", argAsChar[[i]], ")")
}
# write function arguments as text in tensorboard
trainNetworkArguments <- keras::callback_lambda(
on_train_begin = function(logs){
file.writer <- tensorflow::tf$summary$create_file_writer(file.path(tensorboard.log, run.name))
file.writer$set_as_default()
tensorflow::tf$summary$text(name="Arguments", data = argText, step = 0L)
file.writer$flush()
}
)
callbacks[[list_index]] <- trainNetworkArguments
list_index <- list_index + 1
# confusion matrix callback
confMat <- keras::callback_lambda(
on_epoch_end = function(epoch, logs) {
file.writer <- tensorflow::tf$summary$create_file_writer(file.path(tensorboard.log, run.name))
file.writer$set_as_default()
if (labelGen | labelByFolder){
num_classes <- label.vocabulary.size
confMatLabels <- labelVocabulary
} else {
num_classes <- vocabulary.size
confMatLabels <- vocabulary
}
df_true_pred <- data.frame(
true = integer(0),
pred = integer(0)
)
for (i in 1:ceiling(steps.per.epoch * validation.split)){
z <- gen.val()
y_true <- apply(z[[2]], 1, which.max) - 1
y_pred <- keras::predict_classes(model, z[[1]])
df_true_pred <- rbind(df_true_pred, data.frame(true = y_true, pred = y_pred))
}
# use the accumulated predictions from all batches, not just the last one
df_true_pred$true <- factor(df_true_pred$true, levels = 0:(length(confMatLabels) - 1), labels = confMatLabels)
df_true_pred$pred <- factor(df_true_pred$pred, levels = 0:(length(confMatLabels) - 1), labels = confMatLabels)
cm <- yardstick::conf_mat(df_true_pred, true, pred)
suppressMessages(
cm_plot <- ggplot2::autoplot(cm, type = "heatmap") +
ggplot2::scale_fill_gradient(low="#D6EAF8", high = "#2E86C1")
)
plot_path <- paste0(getwd(), "/", run.name, ".png")
suppressMessages(ggplot2::ggsave(filename = plot_path, plot = cm_plot))
# convert saved image to array
np <- reticulate::import("numpy", convert = FALSE)
# python module pillow needs to be installed
PIL <- reticulate::import("PIL", convert = FALSE)
im <- np$asarray(PIL$Image$open(plot_path))
im_R <- reticulate::py_to_r(im)
im_R <- array(im_R, dim = c(1, dim(im_R)))
tensorflow::tf$summary$image(name = "confusion matrix", data = im_R/255, step = epoch)
file.writer$flush()
}
)
callbacks[[list_index]] <- confMat
list_index <- list_index + 1
}
if (output$log){
callbacks[[list_index]] <- keras::callback_csv_logger(
paste0(run.name, "_log.csv"),
separator = ";",
append = TRUE)
}
# training
message("Start training ...")
history <-
model %>% keras::fit_generator(
generator = gen,
validation_data = gen.val,
validation_steps = ceiling(steps.per.epoch * validation.split),
steps_per_epoch = steps.per.epoch,
max_queue_size = max.queue.size,
epochs = epochs,
initial_epoch = initial_epoch,
callbacks = callbacks
)
} else {
message("Start training ...")
history <- model %>% keras::fit(
dataset$X,
dataset$Y,
batch_size = batch.size,
validation_split = validation.split,
epochs = epochs)
}
# save final model
message("Training done.\nSave model.")
if (output$serialize_model){
Rmodel <-
keras::serialize_model(model, include_optimizer = TRUE)
save(Rmodel, file = paste0(run.name, "_full_model.Rdata"))
}
if (output$full_model){
keras::save_model_hdf5(
model,
paste0(run.name, "_full_model.hdf5"),
overwrite = TRUE,
include_optimizer = TRUE
)
}
return(history)
}
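# To resume or reuse a trained run (a sketch; the filenames below are
# placeholders that follow the naming patterns used in trainNetwork above):
# model <- keras::load_model_hdf5("checkpoints/example_run_checkpoints/Ep.003-val_loss0.50-val_acc0.900.hdf5")
# or, from the serialized R object:
# load("example_run_full_model.Rdata"); model <- keras::unserialize_model(Rmodel)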