#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Thu May 5 23:18:26 2022
@author: hill103
this script stores functions related to ADMM framework for fitting model
ADMM framework based on https://github.com/cvxgrp/strat_models
"""
import numpy as np
from time import time
from local_fit_numba import update_theta, adaptive_lasso
from scipy.sparse.linalg import cg
from scipy import sparse
from utils import reparameterTheta, reportRMSE
from config import min_theta, print
from local_fit_numba import generate_log_heavytail_array
[docs]
def one_admm_fit(data, L, theta, e_alpha, gamma_g, sigma2, lambda_r=1.0, lasso_weight=None,
hv_x=None, hv_log_p=None, theta_mask=None,
abs_tol=1e-3, rel_tol=1e-3,
rho=1, mu=10, tau_incr=2, tau_decr=2, max_rho=1e1, min_rho=1e-1,
maxiter=100, max_cg_iterations=10,
dynamic_rho=True, queue_len=3, diff_threshold=0.05, rho_incr=2, rho_decr=2, diff_scale=5, diff_stop=5e-5,
opt_method='L-BFGS-B', global_optimize=False, hybrid_version=True, verbose=False):
"""
perform a whole ADMM iterations once as one fitting procedure in GLRM
Note: the output of the fit result is 3-Dimensional. The dimension transform will be performed outside this function if needed
Parameters
----------
data : Dict
a Dict contains all info need for modeling:
X: a 2-D numpy matrix of celltype specific marker gene expression (celltypes * genes).\n
Y: a 2-D numpy matrix of spatial gene expression (spots * genes).\n
A: a 2-D numpy matrix of Adjacency matrix (spots * spots), or is None. Adjacency matrix of spatial sptots (1: connected / 0: disconnected). All 0 in diagonal.\n
N: a 1-D numpy array of sequencing depth of all spots (length #spots). If it's None, use sum of observed marker gene expressions as sequencing depth.\n
non_zero_mtx: If it's None, then do not filter zeros during regression. If it's a bool 2-D numpy matrix (spots * genes) as False means genes whose nUMI=0 while True means genes whose nUMI>0 in corresponding spots. The bool indicators can be calculated based on either observerd raw nUMI counts in spatial data, or CVAE transformed nUMI counts.\n
spot_names: a list of string of spot barcodes. Only keep spots passed filtering.\n
gene_names: a list of string of gene symbols. Only keep actually used marker genes.\n
celltype_names: a list of string of celltype names.\n
initial_guess: initial guess of cell-type proportions of spatial spots.
L : scipy sparse matrix (spots * spots)
Laplacian matrix
theta : 3-D numpy array (spots * celltypes * 1)
initial guess of theta (celltype proportion).
e_alpha : 1-D numpy array
initial guess of e_alpha (spot-specific effect).
gamma_g : 1-D numpy array
gene-specific platform effect for all genes.
sigma2 : float
initial guess of variance paramter of the lognormal distribution of ln(lambda). All genes and spots share the same variance.
it may not be updated during the ADMM iterations, i.e. sigma2 is treated as an already optimized value.
lambda_r : float, optional
strength for Adaptive Lasso penalty. The default is 1.0.
lasso_weight : 3-D numpy array (spots * celltypes * 1), optional
calculated weight for adaptive lasso. The default is None.
hv_x : 1-D numpy array, optional
data points served as x for calculation of probability density values. Only used for heavy-tail.
hv_log_p : 1-D numpy array, optional
log density values of normal distribution N(0, sigma^2) + heavy-tail. Only used for heavy-tail.
theta_mask : 3-D numpy array (spots * celltypes * 1), optional
mask for cell-type proportions (1: present, 0: not present). Only used for stage 2 theta optmization.
abs_tol : float, optional
Absolute tolerance. The default is 1e-3.
rel_tol : float, optional
Relative tolerance. The default is 1e-3.
rho : float, optional
Initial penalty parameter. The default is 1. Actually used in code is 1/rho, and turned out to fasten the converge than using rho
mu : float, optional
Adaptive penalty parameters. The default is 10.
tau_incr : float, optional
Adaptive penalty parameters. The default is 2.
tau_decr : float, optional
Adaptive penalty parameters. The default is 2.
max_rho : float, optional
Adaptive penalty parameters. The default is 1e1.
min_rho : float, optional
Adaptive penalty parameters. The default is 1e-1.
maxiter : int, optional
Maximum number of ADMM iterations. The default is 100.
max_cg_iterations : int, optional
Max number of CG iterations for graph Laplacian contrain per ADMM iteration. The default is 10.
dynamic_rho : bool, optional
if True, dynamically increasing min_rho and max_rho. The default is True.
queue_len : int, optional
the length of queue to record the mean of theta-theta_tilde 'RMSE' and theta-theta_hat 'RMSE'. The default is 3.
diff_threshold : float, optional
the threshold of 'RMSE change' for rho adjustment. If all 'RMSE change' values in queue are less than or equal the threshold, then increasing min_rho and max_rho. The default is 0.05.
rho_incr : float, optional
the multiplier of min_rho and max_rho increasing. The default is 2.
rho_decr : float, optional
the divider of min_rho and max_rho decreasing. The default is 2.
diff_scale : float, optional
if current theta tilde and hat 'RMSE' > diff_scale * previous 'RMSE', which means current rho value is too large and cause unexpected oscillation, then decreasing min_rho and max_rho. The default is 5.
diff_stop : float, optional
if average of theta tilde and hat 'RMSE' <= diff_stop, stop ADMM iteration. The default is 5e-5.
opt_method : string, optional
specify method used in scipy.optimize.minimize for local model fitting. The default is 'L-BFGS-B', a default method in scipy for optimization with bounds. Another choice would be 'SLSQP', a default method in scipy for optimization with constrains and bounds.
global_optimize : bool, optional
if is True, use basin-hopping algorithm to find the global minimum. The default is False.
hybrid_version : bool, optional
if True, use the hybrid_version of GLRM, i.e. in ADMM local model loss function optimization for w but adaptive lasso constrain on theta. If False, local model loss function optimization and adaptive lasso will on the same w. The default is True.
verbose : bool, optional
if True, print information in each ADMM loop
Returns
-------
Dict
estimated model coefficients, including:
theta : celltype proportions (#spots * #celltypes * 1)\n
e_alpha : spot-specific effect (1-D array with length #spot)\n
sigma2 : variance paramter of the lognormal distribution (float)\n
gamma_g : gene-specific platform effect for all genes (1-D array with length #gene)\n
theta_tilde : celltype proportions for Adaptive Lasso (#spots * #celltypes * 1)\n
theta_hat : celltype proportions for Graph Laplacian constrain (#spots * #celltypes * 1)
"""
n_celltype = data['X'].shape[0]
n_spot = data['Y'].shape[0]
one_theta_shape = (n_celltype, 1)
start_time = time()
optimal = False
# initialize x array for calculation of heavy-tail
if hv_x is None:
from local_fit_numba import z_hv
hv_x = z_hv.copy()
if hv_log_p is None:
# initialize density values of heavy-tail with initial sigma^2
hv_log_p = generate_log_heavytail_array(hv_x, np.sqrt(sigma2))
# if theta_mask is not None, then the theta and e_alpha are already pre-processed with theta_mask
# theta_tilde for adaptive lasso
theta_tilde = theta.copy()
# theta_hat for graph laplacian constrain
theta_hat = theta.copy()
u = np.zeros(theta.shape)
u_tilde = np.zeros(theta.shape)
res_pri = np.zeros(theta.shape)
res_pri_tilde = np.zeros(theta.shape)
res_dual = np.zeros(theta.shape)
res_dual_tilde = np.zeros(theta.shape)
if dynamic_rho:
#print('[CAUTION] dynamic rho trick is turned on!')
rmse_queue = []
pre_tilde_rmse = 0
pre_hat_rmse = 0
if verbose:
print('\nstart ADMM iteration...')
if verbose:
print(f'{"iter" : >6} | {"res_pri_n": >10} | {"res_dual_n": >10} | {"eps_pri": >10} | {"eps_dual": >10} | {"rho": >10} | {"new_rho": >10} | {"time_opt": >8} | {"time_reg": >8} | {"time_lap": >8} | {"tilde_RMSE": >10} | {"hat_RMSE": >10}')
# Main ADMM loop
for t in range(maxiter):
# theta update
tmp_start = time()
# use the previous theta as the warm start for next iteration
if hybrid_version:
# ADMM loss function on theta
theta, e_alpha = update_theta(data, theta, e_alpha, gamma_g, sigma2, theta_hat-u, 1./rho, global_optimize=global_optimize, hybrid_version=hybrid_version, opt_method=opt_method, hv_x=hv_x, hv_log_p=hv_log_p, theta_mask=theta_mask, verbose=False)
else:
# ADMM loss function on w
theta, e_alpha = update_theta(data, theta, e_alpha, gamma_g, sigma2,
reparameterTheta(theta_hat-u, e_alpha), 1./rho,
global_optimize=global_optimize, hybrid_version=hybrid_version, opt_method=opt_method, hv_x=hv_x, hv_log_p=hv_log_p, theta_mask=theta_mask, verbose=False)
# theta has already been masked inside update_theta function
time_local_opt = time() - tmp_start
# in 2-stage implement, NO sigma2 update step in ADMM iterations
# update theta_tilde
tmp_start = time()
if hybrid_version:
# ADMM loss function on theta
theta_tilde = adaptive_lasso(theta_hat-u_tilde, 1./rho, lambda_r, lasso_weight)
else:
# ADMM loss function on w
theta_tilde = reparameterTheta(adaptive_lasso(reparameterTheta(theta_hat-u_tilde, e_alpha), 1./rho, lambda_r, lasso_weight),
1.0/e_alpha)
# mask theta tilde
if not theta_mask is None:
theta_tilde[theta_mask==0] = 0
time_reg = time() - tmp_start
# theta_hat update by Laplacian constrain
tmp_start = time()
# put the constrain for each element in the theta_hat
# each iteration deal with one element but across all spots
sys = L + 2 * rho * sparse.eye(n_spot)
M = sparse.diags(1. / sys.diagonal())
indices = np.ndindex(one_theta_shape)
rhs = rho * (theta.T + u.T + theta_tilde.T + u_tilde.T)
for i, ind in enumerate(indices):
index = ind[::-1]
# Use Conjugate Gradient iteration to solve Ax = b.
# M: Preconditioner for A. The preconditioner should approximate the inverse of A.
sol = cg(sys, rhs[index], M=M,
x0=theta_hat.T[index], maxiter=max_cg_iterations)[0]
res_dual.T[index] = -rho * (sol - theta_hat.T[index])
res_dual_tilde.T[index] = res_dual.T[index]
theta_hat.T[index] = sol
# avoid negative values
theta_hat[theta_hat<min_theta] = min_theta
# mask theta hat
if not theta_mask is None:
theta_hat[theta_mask==0] = 0
time_graph = time() - tmp_start
# difference between theta, theta_tilde, theta_hat
tilde_rmse = reportRMSE(np.squeeze(theta), np.squeeze(theta_tilde))
hat_rmse = reportRMSE(np.squeeze(theta), np.squeeze(theta_hat))
# u and u_tilde update
res_pri = theta - theta_hat
res_pri_tilde = theta_tilde - theta_hat
u += theta - theta_hat
u_tilde += theta_tilde - theta_hat
# calculate residual norms
# np.append by default will combine two input and flatten the output as array
# np.linalg.norm by default calculate 2-norm for array
res_pri_norm = np.linalg.norm(np.append(res_pri, res_pri_tilde))
res_dual_norm = np.linalg.norm(np.append(res_dual, res_dual_tilde))
eps_pri = np.sqrt(2 * n_spot * np.prod(one_theta_shape)) * abs_tol + \
rel_tol * max(res_pri_norm, res_dual_norm)
eps_dual = np.sqrt(2 * n_spot * np.prod(one_theta_shape)) * abs_tol + \
rel_tol * np.linalg.norm(rho * np.append(u, u_tilde))
if dynamic_rho:
# use theta_tilde+theta_hat RMSE or res_pri_norm
rmse_queue.append((abs(tilde_rmse-pre_tilde_rmse) + abs(hat_rmse-pre_hat_rmse))/2.0)
if len(rmse_queue) > queue_len:
rmse_queue.pop(0)
# check stopping condition
if res_pri_norm <= eps_pri and res_dual_norm <= eps_dual:
optimal = True
if verbose:
print(f'{t : >6} | {res_pri_norm:10.3f} | {res_dual_norm:10.3f} | {eps_pri:10.3f} | {eps_dual:10.3f} | {rho:10.2f} | {"/" : >10} | {time_local_opt:8.3f} | {time_reg:8.3f} | {time_graph:8.3f} | {tilde_rmse:10.6f} | {hat_rmse:10.6f}')
break
# if the change of theta_tilde and theta_hat "RMSE" are small, early stop iteration
if (tilde_rmse+hat_rmse)/2 <= diff_stop:
optimal = True
if verbose:
print(f'{t : >6} | {res_pri_norm:10.3f} | {res_dual_norm:10.3f} | {eps_pri:10.3f} | {eps_dual:10.3f} | {rho:10.2f} | {"/" : >10} | {time_local_opt:8.3f} | {time_reg:8.3f} | {time_graph:8.3f} | {tilde_rmse:10.6f} | {hat_rmse:10.6f}')
print('early stop!')
break
# dynamically adjust min_rho and max_rho
# insert a large value into the queue after each adjustment of rho to force several iterations with adjusted rho and without immediate further rho adjustment
if dynamic_rho:
# first check whether rho is too large and cause oscillation after first several iterations
if (tilde_rmse+hat_rmse)/2 > diff_scale * (pre_tilde_rmse+pre_hat_rmse)/2 and t>=2:
min_rho /= rho_decr
max_rho /= rho_decr
rmse_queue.append(1)
if len(rmse_queue) > queue_len:
rmse_queue.pop(0)
#if verbose:
#print('dynamic rho trick: decreasing rho in next ADMM iteration!')
# then check whether need to increase rho
elif all(num <= diff_threshold for num in rmse_queue) and t>=1:
if min_rho < rho:
min_rho = rho
min_rho *= rho_incr
max_rho *= rho_incr
rmse_queue.append(1)
if len(rmse_queue) > queue_len:
rmse_queue.pop(0)
#if verbose:
#print('dynamic rho trick: increasing rho in next ADMM iteration!')
pre_tilde_rmse = tilde_rmse
pre_hat_rmse = hat_rmse
# penalty parameter update
new_rho = rho
if res_pri_norm > mu * res_dual_norm:
new_rho = tau_incr * rho
elif res_dual_norm > mu * res_pri_norm:
new_rho = rho / tau_decr
new_rho = np.clip(new_rho, min_rho, max_rho)
if verbose:
print(f'{t : >6} | {res_pri_norm:10.3f} | {res_dual_norm:10.3f} | {eps_pri:10.3f} | {eps_dual:10.3f} | {rho:10.2f} | {new_rho:10.2f} | {time_local_opt:8.3f} | {time_reg:8.3f} | {time_graph:8.3f} | {tilde_rmse:10.6f} | {hat_rmse:10.6f}')
u *= rho / new_rho
u_tilde *= rho / new_rho
rho = new_rho
# ADMM loop finished
if verbose:
if optimal:
print(f"Terminated (optimal) in {t+1} iterations.")
else:
print("Terminated (reached max iterations).")
# construct result, DO NOT change theta back to 2-D array
# the dimension transforming is performed outside this function is needed
result = {
'theta': theta,
'theta_tilde': theta_tilde,
'theta_hat': theta_hat,
'e_alpha': e_alpha,
'sigma2': sigma2,
'gamma_g': gamma_g
}
if verbose:
print(f'One optimization by ADMM finished. Elapsed time: {(time()-start_time)/60.0:.2f} minutes.\n')
return result