R/translate_model.R

Defines functions gpuToCpuModel

Documented in gpuToCpuModel

#' GPU to CPU model
#'
#' Loads a Keras model saved in HDF5 format and rebuilds it as a sequential
#' model in which every \code{CuDNNLSTM} layer (also when wrapped in
#' \code{Bidirectional}) is replaced by an \code{LSTM} layer. The trained
#' weights are then loaded from the same file into the rebuilt model.
#'
#' Supported layer classes: \code{Conv1D}, \code{MaxPooling1D},
#' \code{BatchNormalization}, \code{CuDNNLSTM}, \code{Bidirectional},
#' \code{Dense} and \code{Activation}. \code{InputLayer} entries are ignored;
#' any other layer class raises a warning and is skipped.
#'
#' @param model.path Path to the GPU model (HDF5 file).
#' @param verbose Whether to print the GPU and CPU models.
#' @return The rebuilt keras sequential model with the trained weights loaded.
#' @export
gpuToCpuModel <- function(model.path, verbose = TRUE) {
  # Value of `key` in a layer config, or `default` when the key is absent.
  # A key that is present with a NULL value (Python None) stays NULL.
  cfg_get <- function(cfg, key, default) {
    if (key %in% names(cfg)) cfg[[key]] else default
  }

  gpu_model <- keras::load_model_hdf5(model.path)
  # Both statements belong under the condition so verbose = FALSE is silent.
  if (verbose) {
    cat("gpu model: \n")
    print(gpu_model)
  }

  # NOTE(review): assumes a 3-D input (batch, maxlen, vocabulary) and that
  # indexing the input shape object follows Python (0-based) positions, as
  # the variable names imply -- TODO confirm against the keras version used.
  maxlen <- gpu_model$input$shape[1]
  vocabulary_size <- gpu_model$input$shape[2]
  input_shape <- c(maxlen, vocabulary_size)

  config <- keras::get_config(gpu_model)
  # `[[` (not `[`) extracts the layer list itself. Some older Keras versions
  # return the bare list of layer configs instead of list(name, layers).
  layer_list <- if ("layers" %in% names(config)) config[["layers"]] else config

  cpu_model <- keras::keras_model_sequential()
  # keras layer_* functions add the layer to the sequential model in place.
  for (layer in layer_list) {
    cfg <- layer[["config"]]
    switch(
      layer[["class_name"]],
      Conv1D = keras::layer_conv_1d(
        cpu_model,
        filters = cfg_get(cfg, "filters", 81),
        kernel_size = cfg_get(cfg, "kernel_size", 3),
        strides = cfg_get(cfg, "strides", 1L),
        padding = cfg_get(cfg, "padding", "same"),
        activation = cfg_get(cfg, "activation", "relu"),
        use_bias = cfg_get(cfg, "use_bias", TRUE),
        input_shape = input_shape
      ),
      MaxPooling1D = keras::layer_max_pooling_1d(
        cpu_model,
        pool_size = cfg_get(cfg, "pool_size", 3),
        strides = cfg_get(cfg, "strides", NULL),
        padding = cfg_get(cfg, "padding", "valid")
      ),
      BatchNormalization = keras::layer_batch_normalization(
        cpu_model,
        momentum = cfg_get(cfg, "momentum", 0.8),
        epsilon = cfg_get(cfg, "epsilon", 0.001),
        center = cfg_get(cfg, "center", TRUE),
        scale = cfg_get(cfg, "scale", TRUE)
      ),
      # CuDNN LSTM kernels use standard sigmoid gates; older keras LSTM
      # defaults to hard_sigmoid, which would make the converted model's
      # outputs differ from the GPU model's.
      CuDNNLSTM = keras::layer_lstm(
        cpu_model,
        units = cfg[["units"]],
        input_shape = input_shape,
        return_sequences = cfg_get(cfg, "return_sequences", FALSE),
        recurrent_activation = "sigmoid"
      ),
      Bidirectional = {
        # The wrapped layer's own options live under config$layer$config.
        inner_cfg <- cfg[["layer"]][["config"]]
        keras::bidirectional(
          cpu_model,
          input_shape = input_shape,
          layer = keras::layer_lstm(
            units = inner_cfg[["units"]],
            return_sequences = cfg_get(inner_cfg, "return_sequences", FALSE),
            # Absent for CuDNNLSTM (fixed sigmoid); kept for a plain LSTM.
            recurrent_activation =
              cfg_get(inner_cfg, "recurrent_activation", "sigmoid")
          ),
          merge_mode = cfg_get(cfg, "merge_mode", "concat")
        )
      },
      Dense = keras::layer_dense(
        cpu_model,
        units = cfg[["units"]],
        activation = cfg_get(cfg, "activation", NULL),
        use_bias = cfg_get(cfg, "use_bias", TRUE)
      ),
      Activation = keras::layer_activation(
        cpu_model,
        activation = cfg_get(cfg, "activation", "softmax")
      ),
      InputLayer = NULL,
      warning(
        "Layer class '", layer[["class_name"]], "' is not supported and ",
        "was skipped.",
        call. = FALSE
      )
    )
  }

  # The source file is HDF5; Keras converts CuDNNLSTM weights to the LSTM
  # layout while loading HDF5 weights.
  keras::load_model_weights_hdf5(cpu_model, filepath = model.path)
  if (verbose) {
    cat("cpu model: \n")
    print(cpu_model)
  }
  cpu_model
}
hiddengenome/altum documentation built on April 22, 2020, 9:33 p.m.