R/multinom_EM.R

Defines functions multinom_EM

Documented in multinom_EM

#' An EM algorithm to fit a multinomial with maximum likelihood
#'
#' @param X A vector of commopent sizes
#' @param simMM A matrix of floats (n_cluster, n_cluster) for the similarity
#'   matrix between clusters. simMM[i,j] means the proportion of cluster i will
#'   be assigned to cluster j, hence colSums(simMM) are ones.
#' @param max_iter integer(1). number of maximum iterations
#' @param min_iter integer(1). number of minimum iterations
#' @param logLik_threshold A float. The threshold of logLikelihood increase for
#'   detecting convergence
#' @param verbose A logic. Whether to print the iteration times and log likelihood.
#'
#' @return a list containing \code{mu}, a vector for estimated latent proportion
#'   of each cluster, \code{logLik}, a float for the estimated log likelihood,
#'   \code{simMM}, the input of simMM, code{X}, the input of X, \code{X_prop},
#'   the proportion of clusters in the input X, \code{predict_X_prop}, and the
#'   predicted proportion of clusters based on mu and simMM.
#'
#' @export
#'
#' @examples
#' X = c(100, 300, 1500, 500, 1000)
#' simMM = create_simMat(5, confuse_rate=0.2)
#' multinom_EM(X, simMM)
#' 
multinom_EM <- function(X, simMM, min_iter=10, max_iter=1000,
                        logLik_threshold=1e-2, verbose=TRUE) {
    # Be very careful on the shape of simMM; rowSums(simMM) = 1
    K = ncol(simMM)

    # initialization
    mu = sample(K)
    mu = mu / sum(mu)
    Z = matrix(NA, K, K)
    logLik_old <- logLik_new <- log(mu %*% simMM) %*% X

    if (verbose) {
        print(paste("Iteration 0 logLik:", round(logLik_new, 3)))
    }

    for (it in seq_len(max_iter)) {
        ## E step: expectation of count from each component
        # Z = (simMM * mu)
        # Z = t(t(Z) / colSums(Z))

        for (i in seq(K)) {
            for (j in seq(K)){
                Z[i, j] = simMM[i, j] * mu[i] / sum(mu * simMM[, j])
            }
        }

        ## M step: maximizing likelihood
        # mu = c(X %*% Z) # this is wrong

        ## v2
        mu = c(Z %*% X)
        mu = mu / sum(mu)

        ## Check convergence
        logLik_new <- log(mu %*% simMM) %*% X
        # sum(X * log(mu %*% t(simMM)))
        if (it > min_iter && logLik_new - logLik_old < logLik_threshold) {
            break
        } else {
            logLik_old <- logLik_new
        }
        if (verbose) {
            print(paste("Iteration", it, "logLik:", round(logLik_new, 3)))
        }

    }

    ## return values
    list("mu" = mu, "logLik" = logLik_new,
         "simMM" = simMM, "X" = X, "X_prop" = X / sum(X),
         "predict_X_prop" = mu %*% simMM)
}

# matrix(seq(6), 2, 3) * c(1, 2)
huangyh09/DCATS documentation built on Nov. 25, 2022, 7:02 a.m.