Source code for cvae

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Thu May 19 21:43:08 2022

@author: hill103

this script stores functions to build a CVAE for platform effect adjustment
"""



import os
from config import print, diagnosis_path
import numpy as np
import pandas as pd
import umap
from utils import read_spatial_data, read_scRNA_data, run_DE
from time import time
import random
from sklearn.decomposition import PCA
from sklearn.neighbors import NearestNeighbors
from sklearn.preprocessing import MinMaxScaler
import scanpy as sc
sc.settings.verbosity = 0  # verbosity: errors (0), warnings (1), info (2), hints (3)

import tensorflow.keras.backend as K
from tensorflow.keras.layers import Lambda, Input, Dense, Activation, concatenate
from tensorflow.keras.layers import BatchNormalization
from tensorflow.keras.models import Model
from tensorflow.keras import optimizers
from tensorflow.keras.callbacks import ReduceLROnPlateau, EarlyStopping
from tensorflow.keras.utils import set_random_seed

# dealing with the keras symbolic tensor error
from tensorflow.python.framework.ops import disable_eager_execution
disable_eager_execution()



def celltype2props(celltype_anno, celltype_order):
    '''
    calculate cell-type proportions matrix given one cell-type annotation

    every row of the annotation is one-hot encoded against celltype_order and then divided by its row sum. The encoding is built positionally (not by index label), so it also works when the index of celltype_anno contains duplicated labels

    Parameters
    ----------
    celltype_anno : dataframe
        cell-type annotations. Only 1 column named <celltype>
    celltype_order : list
        already sorted unique cell-types. Its order matters, and will be the order in cell-type proportions (columns) and cell-type gene expression profile (rows)

    Returns
    -------
    celltype_prop : dataframe
        cell-type proportions, columns are already sorted cell-types, rows follow the index of celltype_anno. A row whose annotation is not included in celltype_order is all NaN (0 divided by 0)
    '''

    # column position of each cell-type in the output matrix
    column_of = {one_celltype: j for j, one_celltype in enumerate(celltype_order)}

    labels = celltype_anno['celltype'].tolist()

    # -1 marks an annotation which is not listed in celltype_order
    codes = np.fromiter((column_of.get(one_label, -1) for one_label in labels), dtype=np.intp, count=len(labels))
    known = codes >= 0

    # transform to one-hot matrix
    onehot = np.zeros((len(labels), len(celltype_order)), dtype=float)
    onehot[np.flatnonzero(known), codes[known]] = 1.0

    # calculate cell-type proportions, divides each element in a row by the sum of that row
    # rows without any known cell-type give 0/0 = NaN, silence the numpy warning for them
    with np.errstate(divide='ignore', invalid='ignore'):
        props = onehot / onehot.sum(axis=1, keepdims=True)

    return pd.DataFrame(props, columns=celltype_order, index=celltype_anno.index)
def transferProps(query, ref, ref_props, n_neighbors=10, sigma=1, use_embedding='PCA', pca_dimension=None):
    '''
    transfer cell-type proportions by select K Nearest Neighbors in ref and take Gaussian weighted average of ref proportions

    Parameters
    ----------
    query : 2-D numpy matrix
        encoder embeddings of spatial spots (spots * latent layer neurons).
    ref : 2-D numpy matrix
        encoder embeddings of scRNA-seq cells and pseudo-spots in scRNA-seq condition (cells+pseudo-spots * latent layer neurons).
    ref_props : 2-D numpy matrix
        cell-type proportion matrix of scRNA-seq cells and pseudo-spots (cells+pseudo-spots * cell-types).
    n_neighbors : int, optional
        Number of neighbors to use. The default is 10. If it is larger than the number of rows in ref, all rows in ref are used as neighbors.
    sigma : float, optional
        Standard deviation for the Gaussian weighting function. The default is 1.
    use_embedding : str, optional
        which embedding to use, either PCA, UMAP or none. The default is PCA.
    pca_dimension : int, optional
        specify the number of dimensions for PCA reduction. If set to None, the reduced dimension will be one-third of the input dimension.

    Returns
    -------
    query_props : 2-D numpy matrix
        cell-type proportion matrix for spatial spots. Each row sums to 1.

    Raises
    ------
    Exception
        if use_embedding is not one of PCA, UMAP or none.
    '''

    assert query.shape[1] == ref.shape[1]

    n_celltype = ref_props.shape[1]

    if (query.shape[1]<=2) and (use_embedding!='none'):
        print(f'[WARNING] original latent space dimension {query.shape[1]} <= 2, no need to use {use_embedding} embedding!')
        use_embedding = 'none'

    if use_embedding == 'PCA':
        # first take a PCA to avoid Curse of Dimensionality
        # we perform PCA without any normalization and scaling, and reduce the dimensionality to one-third of the original dimensions
        if pca_dimension is None:
            # orginal dimension: 3*#cell-types, reduced dimension: #cell-types
            reduced_pca_dimension = int(query.shape[1] / 3)
        else:
            reduced_pca_dimension = int(pca_dimension)

        principal_components = PCA(n_components=reduced_pca_dimension).fit_transform(np.vstack((query, ref)))
        # Split the principal components back into query and ref
        query_pc = principal_components[:query.shape[0], :]
        ref_pc = principal_components[query.shape[0]:, :]

    elif use_embedding == 'UMAP':
        all_umap = umap.UMAP(random_state=42).fit_transform(np.vstack((query, ref)))
        # Split the UMAP embedding back into query and ref
        query_pc = all_umap[:query.shape[0], :]
        ref_pc = all_umap[query.shape[0]:, :]

    elif use_embedding == 'none':
        query_pc = query
        ref_pc = ref

    else:
        raise Exception(f'unknow embedding {use_embedding}')

    print(f'embedding dimension: {query_pc.shape[1]}')

    # kneighbors can not return more neighbors than fitted samples, so clamp instead of failing
    if n_neighbors > ref_pc.shape[0]:
        print(f'[WARNING] n_neighbors {n_neighbors} > number of reference samples {ref_pc.shape[0]}, use all reference samples as neighbors!')
        n_neighbors = ref_pc.shape[0]

    # perform KNN on query data on reduced dimension
    nbrs = NearestNeighbors(n_neighbors=n_neighbors, algorithm='auto').fit(ref_pc)
    # find nearest neighbors
    distances, indices = nbrs.kneighbors(query_pc)

    # Calculate Gaussian weighted averages
    query_props = np.zeros((query.shape[0], ref_props.shape[1]))
    two_sigma_sq = 2 * sigma**2

    for i, (dists, inds) in enumerate(zip(distances, indices)):
        # Compute Gaussian weights
        # the weights are normalized to sum to 1 afterwards, so subtracting a constant in the exponent does not change them mathematically
        # we subtract the smallest squared distance: the nearest neighbor always gets weight exp(0)=1, and the weights can not all underflow to 0 when every neighbor is far away (distance >> sigma)
        sq_dists = dists**2
        weights = np.exp(-(sq_dists - sq_dists.min()) / two_sigma_sq)
        weights /= weights.sum()

        # Multiply weights with corresponding rows in ref_props (element-wise multiplication)
        # weights[:, np.newaxis] changes the shape of weights from (K,) to (K, 1), then numpy broadcasting expands it along the columns of the (K, m) selected rows of ref_props
        weighted_props = ref_props[inds] * weights[:, np.newaxis]

        # calculate weighted average for the current query row; axis=0 sums over the K neighbors
        avg_props = np.sum(weighted_props, axis=0)

        # normalize the proportions to sum to 1
        query_props[i] = avg_props / np.sum(avg_props)

        # NOTE if the averaged proportions are all 0 (e.g. the selected rows in ref_props are all 0), the normalization gives NaN
        if np.isnan(query_props[i]).any():
            # replace it as a vector with all elements identical
            query_props[i, :] = np.full((n_celltype,), 1.0/n_celltype)
            continue

        # post-process theta to set theta<0.01 as 0 then re-normalize remaining theta to sum to 1
        tmp_ind = query_props[i, :] < 0.01

        if tmp_ind.all():
            # all elements < threshold, just leave it unchanged
            continue

        if tmp_ind.any():
            query_props[i, tmp_ind] = 0
            query_props[i, :] = query_props[i, :] / np.sum(query_props[i, :])

    return query_props
def generate_pseudo_spots(exp_df, celltype_anno, n_spot, celltype_order, pseudo_spot_min_cell, pseudo_spot_max_cell):
    '''
    build pseudo-spots for CVAE training by averaging randomly drawn scRNA-seq cells

    for every pseudo-spot the number of cells is drawn first; then each cell is chosen by drawing a cell-type uniformly from celltype_order followed by drawing one cell of that cell-type (cells can be drawn repeatedly)

    the original scRNA-seq cells are NOT appended to the outputs. With n_spot=0 empty outputs are returned

    the global random generator is re-seeded with a fixed value, so repeated calls give the same pseudo-spots

    Parameters
    ----------
    exp_df : dataframe
        normalized gene expression (cell * genes)
    celltype_anno : dataframe
        cell-type annotations for cells in scRNA-seq data. Only 1 column named <celltype>
    n_spot : int
        number of pseudo spots need to be generated, including training and validation set
    celltype_order : list
        already sorted unique cell-types. Its order matters, and will be the order in cell-type proportions (columns) and cell-type gene expression profile (rows)
    pseudo_spot_min_cell : int
        minimum value of cells in pseudo-spot
    pseudo_spot_max_cell : int
        maximum value of cells in pseudo-spot

    Returns
    -------
    pseudo_spots_df : dataframe
        pseudo-spot gene expression (pseudo-spots * genes; NO original cells included)
    pseudo_spots_celltype_prop : dataframe
        pseudo-spot cell-type proportions (pseudo-spots * cell-types; NO original cells included)
    n_cell_in_spot : list
        number of cells in each pseudo-spot (NO original cells included)
    '''

    # nothing requested: empty outputs which still carry the expected columns
    if n_spot == 0:
        return pd.DataFrame(columns=exp_df.columns), pd.DataFrame(columns=celltype_order), []

    candidate_sizes = list(range(pseudo_spot_min_cell, pseudo_spot_max_cell+1))

    # cell barcodes grouped by cell-type
    cells_by_type = {
        one_celltype: celltype_anno[celltype_anno['celltype']==one_celltype].index.to_list()
        for one_celltype in celltype_order
    }

    # pseudo-spots are generated one-by-one in a single process (simple and safe, no big dataframe shared across subprocesses)
    # fixed seed to make the generated pseudo-spots reproducible
    random.seed(138)

    # spot counts at which a progress message is printed (every ~10%), never beyond the total number of spots
    milestones = {m for m in (round((n_spot / 10) * i) for i in range(1, 11)) if m <= n_spot}

    spot_profiles = []
    spot_type_counts = []
    n_cell_in_spot = []

    for spot_no in range(1, n_spot+1):
        if spot_no in milestones:
            # current percentage, capped at 100%
            cur_progress = min(round(spot_no / n_spot, 2), 1)
            print(f'{cur_progress:.0%}...', end='')

        # how many cells go into this pseudo-spot
        size = random.sample(candidate_sizes, 1)[0]
        n_cell_in_spot.append(size)

        members = []
        for _ in range(size):
            # pick a cell-type first, then one cell belonging to it
            picked_type = random.sample(celltype_order, 1)[0]
            members.append(random.sample(cells_by_type[picked_type], 1)[0])

        # pseudo-spot expression is the average of the selected cells
        spot_profiles.append(exp_df.loc[members].mean(axis=0))
        # cell-type composition of the selected cells
        spot_type_counts.append(celltype_anno.loc[members, 'celltype'].value_counts().to_dict())

    # make below prints in a newline
    print('\n')

    # Build pseudo-spots dataframe, one row per pseudo-spot
    pseudo_spots_df = pd.concat(spot_profiles, axis=1).transpose()
    pseudo_spots_df.index = ['scrna_pseudo' + str(idx) for idx in range(pseudo_spots_df.shape[0])]

    # counts -> proportions; cell-types absent from a pseudo-spot count as 0
    type_counts_df = pd.DataFrame(spot_type_counts, columns=celltype_order, index=pseudo_spots_df.index)
    type_counts_df.fillna(0, inplace=True)
    pseudo_spots_celltype_prop = type_counts_df.div(type_counts_df.sum(axis=1), axis=0)

    return pseudo_spots_df, pseudo_spots_celltype_prop, n_cell_in_spot
def combine_spatial_spots(exp_df, n_spot, pseudo_spot_min_cell, pseudo_spot_max_cell):
    '''
    we also generate "pseudo-spots" by combining spatial spots

    here we do not know the true cell-type proportions in spatial spots, so we also do not know the proportions in generated "pseudo-spots". We just ignore it, and only use the expressions for CVAE training

    spots are drawn without replacement for one pseudo-spot. Only when a pseudo-spot needs more spots than available in exp_df, spots are drawn with replacement instead (same strategy as augment_sc), so small datasets do not fail

    the global random generator is re-seeded with a fixed value, so repeated calls give the same pseudo-spots

    Parameters
    ----------
    exp_df : dataframe
        normalized gene expression (spots * genes)
    n_spot : int
        number of pseudo spots need to be generated, including training and validation set
    pseudo_spot_min_cell : int
        minimum value of spots in pseudo-spot
    pseudo_spot_max_cell : int
        maximum value of spots in pseudo-spot

    Returns
    -------
    pseudo_spots_df : dataframe
        pseudo-spot gene expression (pseudo-spots * genes; NO original spots included). Empty if n_spot=0

    Raises
    ------
    ValueError
        if spots are requested for a pseudo-spot but exp_df contains no spot at all.
    '''

    if n_spot == 0:
        return pd.DataFrame(columns=exp_df.columns)

    pseudo_spots = []
    n_cell_list = list(range(pseudo_spot_min_cell, pseudo_spot_max_cell+1))
    spot_index = exp_df.index.to_list()
    n_available = len(spot_index)

    # to reduce randomness, pre-set the seed value for random
    random.seed(154)

    # Set to track the milestones to print, no milestone can exceed the total number of spots
    milestones = {m for m in (round((n_spot / 10) * i) for i in range(1, 11)) if m <= n_spot}

    for i in range(n_spot):
        if i+1 in milestones:
            # Calculate the current percentage, ensure it does not exceed 100%
            cur_progress = min(round((i+1) / n_spot, 2), 1)
            print(f'{cur_progress:.0%}...', end='')

        # first determine how many spots in this pseudo-spot
        this_num = random.sample(n_cell_list, 1)[0]

        # then randomly select these number of spots
        if this_num <= n_available:
            # sample without replace
            this_spots = random.sample(spot_index, this_num)
        elif n_available == 0:
            raise ValueError('can not generate pseudo-spots: no spatial spot available in the input dataframe')
        else:
            # fewer spots than needed, sample with replace
            this_spots = random.choices(spot_index, k=this_num)

        # take average of selected spots
        pseudo_spots.append(exp_df.loc[this_spots].mean(axis=0))

    # make below prints in a newline
    print('\n')

    # Build pseudo-spots dataframe
    pseudo_spots_df = pd.concat(pseudo_spots, axis=1).transpose()
    pseudo_spots_df.reset_index(inplace=True, drop=True)
    pseudo_spots_df.index = ['spatial_pseudo' + str(idx) for idx in pseudo_spots_df.index]

    return pseudo_spots_df
def augment_sc(exp_df, celltype_anno, target_count, pseudo_spot_min_cell, pseudo_spot_max_cell):
    '''
    augment single cells and balance #cells of cell types

    generate pseudo-spots by randomly combining scRNA-seq cells within the same cell type, until each cell type has target_count entries (cell types which already have at least target_count cells are left as they are)

    the generated pseudo-spots are shuffled and put first, original single cells are put at the end

    if no cell type needs augmentation, copies of the inputs are returned unchanged

    the global random generator is re-seeded with a fixed value, so repeated calls give the same result

    Parameters
    ----------
    exp_df : dataframe
        normalized gene expression (cell * genes)
    celltype_anno : dataframe
        cell-type annotations for cells in scRNA-seq data. Only 1 column named <celltype>
    target_count : int
        target number of cells per cell type
    pseudo_spot_min_cell : int
        minimum value of cells in pseudo-spot
    pseudo_spot_max_cell : int
        maximum value of cells in pseudo-spot

    Returns
    -------
    combined_exp : dataframe
        gene expression (pseudo-spots+cells * genes; shuffled pseudo-spots first, original cells at the end)
    combined_ct : dataframe
        cell-type annotations with 1 column named <celltype> (same row order as combined_exp)
    combined_n_cell : list
        number of cells in each row of combined_exp (1 for original cells)
    '''

    pseudo_spots = []
    celltype_stats = []
    n_cell_in_spot = []

    n_cell_list = list(range(pseudo_spot_min_cell, pseudo_spot_max_cell+1))

    n_cell_per_ct = celltype_anno.celltype.value_counts().to_dict()
    all_cts = n_cell_per_ct.keys()

    # cell barcode separated by cell-types
    type_cell_index = {
        one_celltype: celltype_anno[celltype_anno['celltype']==one_celltype].index.to_list()
        for one_celltype in all_cts
    }

    # pseudo-spots are generated one-by-one in a single process (simple and safe, no big dataframe shared across subprocesses)
    # to reduce randomness, pre-set the seed value for random
    random.seed(169)

    for i, one_celltype in enumerate(all_cts):
        print(f'{i/len(all_cts):.0%}...', end='')

        this_pool = type_cell_index[one_celltype]
        if not this_pool:
            # e.g. an unused level of a categorical annotation: no cell to combine
            continue

        for j in range(target_count-n_cell_per_ct[one_celltype]):
            # first determine how many cells in this pseudo-spot
            this_num = random.sample(n_cell_list, 1)[0]
            n_cell_in_spot.append(this_num)

            # from this cell-type, randomly select needed cells
            if n_cell_per_ct[one_celltype] < this_num:
                # sample with replace
                this_cells = random.choices(this_pool, k=this_num)
            else:
                # sample without replace
                this_cells = random.sample(this_pool, this_num)

            # take average of selected cells
            pseudo_spots.append(exp_df.loc[this_cells].mean(axis=0))
            celltype_stats.append(one_celltype)

    # finish the progress line
    print('100%')

    if not pseudo_spots:
        # every cell type already reaches target_count, nothing to concatenate
        return exp_df.copy(), celltype_anno.copy(), [1] * exp_df.shape[0]

    # Build pseudo-spots dataframe
    pseudo_spots_df = pd.concat(pseudo_spots, axis=1).transpose()
    pseudo_spots_df.reset_index(inplace=True, drop=True)
    pseudo_spots_df.index = ['scrna_augment' + str(idx) for idx in pseudo_spots_df.index]

    pseudo_spots_celltypes = pd.DataFrame(celltype_stats, columns=['celltype'], index=pseudo_spots_df.index)
    n_cell_df = pd.DataFrame(n_cell_in_spot, columns=['ncell'], index=pseudo_spots_df.index)

    # shuffle all rows, otherwise pseudo-spots are grouped by cell-type
    tmp_index = pseudo_spots_df.index.to_list()
    random.shuffle(tmp_index)

    pseudo_spots_df = pseudo_spots_df.loc[tmp_index].copy()
    pseudo_spots_celltypes = pseudo_spots_celltypes.loc[tmp_index].copy()
    n_cell_df = n_cell_df.loc[tmp_index].copy()

    # combine original single cells at the end
    combined_exp = pd.concat([pseudo_spots_df, exp_df], axis=0)
    combined_ct = pd.concat([pseudo_spots_celltypes, celltype_anno], axis=0)
    combined_n_cell = n_cell_df.ncell.to_list() + [1] * exp_df.shape[0]

    return combined_exp, combined_ct, combined_n_cell
def CVAE_keras_model(p, p_cond, latent_dim, p_encoder_lst, p_decoder_lst, hidden_act='elu', output_act='relu', use_batch_norm=True, cvae_init_lr=0.01):
    '''
    define a standard CVAE model based on Keras

    need to build a decoder separately as can not extract it from the whole model

    every hidden/output layer is a Dense layer followed by an optional Batch Normalization and an Activation layer; the Dense layer has no bias when Batch Normalization is used. An empty list of hidden layers is allowed, then the input is connected directly to the next stage

    Parameters
    ----------
    p : int
        number of nodes in input layer
    p_cond : int
        number of conditional nodes in input layer
    latent_dim : int
        number of nodes in latent space
    p_encoder_lst : list of integers
        including number of nodes in each hidden layer of encoder, the length of list is the number of hidden layers
    p_decoder_lst : list of integers
        including number of nodes in each hidden layer of decoder, the length of list is the number of hidden layers
    hidden_act : string, optional
        activation function of hidden layers. Default is elu function
    output_act : string, optional
        activation function of output layer. Default is relu function
    use_batch_norm : bool, optional
        whether to use batch normalization. Default is True, i.e. use batch normalization
    cvae_init_lr : float
        initial learning rate for training CVAE

    Returns
    -------
    cvae : Keras model
        compiled CVAE model. Encoder can be extracted from it
    decoder : Keras model
        corresponding decoder (not compiled, freshly initialized) with the same layer names as the decoder part of the CVAE, reset its weights after CVAE training
    '''

    # Functions for CVAE

    # reparameterization trick
    # instead of sampling from Q(z|X), sample eps = N(0,I)
    # z = z_mean + sqrt(var)*eps
    def sampling(args):
        """Reparameterization trick by sampling from an isotropic unit Gaussian.

        # Arguments:
            args (tensor): mean and log of variance of Q(z|X)
        # Returns:
            z (tensor): sampled latent vector
        """
        z_mean, z_log_var = args
        batch = K.shape(z_mean)[0]
        dim = K.int_shape(z_mean)[1]
        # by default, random_normal has mean=0 and std=1.0
        epsilon = K.random_normal(shape=(batch, dim), mean=0., stddev=1)
        return z_mean + K.exp(0.5 * z_log_var) * epsilon

    # the losses return an array (one value per sample in the input batch) and leave the reducing to Keras, so sample_weight can be handled
    # NOTE the function names are also used by Keras as metric names, do not rename them
    def KL_loss(obs, pred):
        # z_mean and z_log_var are the latent tensors defined below (resolved when the loss is called)
        kl_loss = 1 + z_log_var - K.square(z_mean) - K.exp(z_log_var)
        kl_loss = K.sum(kl_loss, axis=-1)
        kl_loss *= -0.5
        return kl_loss

    def reconstruction_loss(obs, pred):
        return K.sum(K.square(obs - pred), axis=-1)

    def vae_loss(obs, pred, C=1):
        # currently we set weight C=1
        return C*reconstruction_loss(obs, pred) + KL_loss(obs, pred)

    # Building blocks shared by encoder, decoder and the standalone decoder

    def dense_block(tensor, units, activation, prefix):
        # Dense -> (BatchNormalization) -> Activation, layers named <prefix>_w, <prefix>_BN, <prefix>_act
        # bias is redundant when followed by Batch Normalization
        tensor = Dense(units, use_bias=not use_batch_norm, name=f'{prefix}_w')(tensor)
        if use_batch_norm:
            tensor = BatchNormalization(name=f'{prefix}_BN')(tensor)
        return Activation(activation, name=f'{prefix}_act')(tensor)

    def latent_head(tensor, name):
        # one linear latent output (z_mean or z_log_var); only the final Activation layer is named
        pre = Dense(latent_dim, use_bias=not use_batch_norm)(tensor)
        if use_batch_norm:
            pre = BatchNormalization()(pre)
        return Activation('linear', name=name)(pre)

    def decoder_stack(tensor):
        # hidden layers followed by the output layer; creates new layers on every call
        for i, num in enumerate(p_decoder_lst):
            tensor = dense_block(tensor, num, hidden_act, f'decoder_layer{i}')
        return dense_block(tensor, p, output_act, 'decoder_output')

    # Build encoder model
    X = Input(shape=(p,), name='encoder_input')
    cond = Input(shape=(p_cond,), name='cond_input')
    encoder_hidden = concatenate([X, cond])

    # add hidden layers
    for i, num in enumerate(p_encoder_lst):
        encoder_hidden = dense_block(encoder_hidden, num, hidden_act, f'encoder_layer{i}')

    # latent layer of z_mean and z_log_var
    z_mean = latent_head(encoder_hidden, 'z_mean')
    z_log_var = latent_head(encoder_hidden, 'z_log_var')

    # use reparameterization trick to push the sampling out as input
    # note that "output_shape" isn't necessary with the TensorFlow backend
    z = Lambda(sampling, output_shape=(latent_dim,), name='z')([z_mean, z_log_var])

    # Build decoder model
    decoder_output = decoder_stack(concatenate([z, cond]))

    # CVAE model = encoder + decoder
    # by using the Keras functional API, the variables will be created right away without needing to call .build()
    cvae = Model([X, cond], decoder_output, name='cvae')

    # Optimizer, use old optimizers in legacy namespace
    adam = optimizers.legacy.Adam(learning_rate=cvae_init_lr, clipnorm=1.0, decay=0.0)

    cvae.compile(optimizer=adam, loss=vae_loss, metrics=[reconstruction_loss, KL_loss], experimental_run_tf_function=True)

    # Subset the decoder (build another new decoder with identical architecture and layer names, weights need to be re-stored by the caller)
    latent_inputs = Input(shape=(latent_dim,), name='z_sampling')
    cond_input = Input(shape=(p_cond,), name='conditions')
    new_decoder_output = decoder_stack(concatenate([latent_inputs, cond_input]))
    new_decoder = Model([latent_inputs, cond_input], new_decoder_output, name='new_decoder')

    return cvae, new_decoder
[docs] def build_CVAE(spatial_df, scRNA_df, scRNA_celltype, n_marker_per_cmp, n_pseudo_spot, pseudo_spot_min_cell, pseudo_spot_max_cell, seq_depth_scaler, cvae_input_scaler, cvae_init_lr, num_hidden_layer, use_batch_norm, cvae_train_epoch, use_spatial_pseudo, use_fdr, p_val_cutoff, fc_cutoff, pct1_cutoff, pct2_cutoff, sortby_fc, diagnosis, rerun_DE=True, filter_gene=True): ''' build CVAE to adjust platform effect, return transformed spatial gene expression and scRNA-seq cell-type gene signature input gene expression in datasets only included genes needed for downstream analysis and already been normalized by sequencing depth Parameters ---------- spatial_df : dataframe normalized gene expression in spatial transcriptomic data (spots * genes). scRNA_df : dataframe normalized gene expression in scRNA-seq data (cells * genes). scRNA_celltype : dataframe cell-type annotations for cells in scRNA-seq data. Only 1 column named <celltype>. n_marker_per_cmp : int number of TOP marker genes for each comparison in DE. n_pseudo_spot : int number of pseudo-spots. pseudo_spot_min_cell : int minimum value of cells in pseudo-spot. pseudo_spot_max_cell : int maximum value of cells in pseudo-spot. seq_depth_scaler : int a scaler of scRNA-seq sequencing depth. cvae_input_scaler : int maximum value of the scaled input for CVAE. cvae_init_lr : float initial learning rate for training CVAE. num_hidden_layer : int number of hidden layers in encoder and decoder. use_batch_norm : bool whether to use Batch Normalization. cvae_train_epoch : int max number of training epochs for the CVAE. use_spatial_pseudo : int whether to generate "pseudo-spots" in spatial condition. use_fdr : bool whether to use FDR adjusted p value for filtering and sorting. p_val_cutoff : float threshold of p value (or FDR if --use_fdr is true) in marker genes filtering. fc_cutoff : float threshold of fold change (without log transform!) in marker genes filtering. 
pct1_cutoff : float threshold of pct.1 in marker genes filtering. pct2_cutoff : float threshold of pct.2 in marker genes filtering. sortby_fc : bool whether to sort marker genes by fold change. diagnosis : bool if True save more information to files for diagnosis CVAE and hyper-parameter selection. rerun_DE : bool, optional whether to rerun DE on the CVAE transformed scRNA-seq data, since the DE genes might be different with before CVAE transforming. filter_gene : bool whether to filter genes before DE. Returns ------- spatial_transformed_numi : dataframe CVAE transformed (platform effect adjusted) spatial spot gene raw nUMI counts (spots * genes). scRNA_decode_avg_df : dataframe CVAE decodered average gene expression (normalized) of cell-types in scRNA-seq data (cell-types * genes). new_markers : list or None marker genes from re-run DE on CVAE transformed scRNA-seq data. It will be None if not re-run DE (rerun_DE=False). cvae_pred : dataframe or None cell-type proportions of spatial spots predicted or transferred by CVAE. It will be None if no way to got initial guess of cell-type proportions (spots * cell-types). 
''' assert((scRNA_df.index == scRNA_celltype.index).all()) assert((spatial_df.columns == scRNA_df.columns).all()) if diagnosis: # first plot UMAP for raw input gene expressions from diagnosis_plots import defineColor, rawInputUMAP plot_colors = defineColor(spatial_df.shape[0], scRNA_celltype) rawInputUMAP(spatial_df, scRNA_df, scRNA_celltype, plot_colors) # some settings # max number of pseudo spots (training+validation, without training single cells) n_max_pseudo_spots = n_pseudo_spot # scaler to multiply the normalized gene values and transform back to raw nUMI counts depth_scaler = seq_depth_scaler # percentage of training pseudo spots training_pct = 0.8 # max value when scaling the input gene expression of CVAE, while min is 0 input_max = cvae_input_scaler # first log transform, then minmax scaling use_log_transform = True # whether to get initial guess of cell-type proportions do_initial_guess = True # whether to single cell augmentation and #cells of cell type balancing sc_augment = True # the order of celltypes matters, unify the order throughout whole pipeline, which will be determined here celltype_order = sorted(list(scRNA_celltype.celltype.unique())) n_celltype = len(celltype_order) celltype_count_dict = scRNA_celltype.celltype.value_counts().to_dict() # Randomly select cells into pseudo-spots, at most X pseudo-spots # total number of generated pseudo-spots (including training and validation pseudo-spots, NOT include scRNA-seq cells) n_pseudo_scrna = int(min(100 * spatial_df.shape[0] * n_celltype, n_max_pseudo_spots)) n_train_pseudo_scrna = int(np.floor(n_pseudo_scrna * training_pct)) n_valid_pseudo_scrna = int(n_pseudo_scrna - n_train_pseudo_scrna) print(f'generate {n_pseudo_scrna} pseudo-spots containing {pseudo_spot_min_cell} to {pseudo_spot_max_cell} cells from scRNA-seq cells...') # pseudo-spot gene expression (pseudo-spots * genes; NO scRNA-seq cells at the end) # pseudo-spot cell-type proportions (pseudo-spots * cell-types; NO scRNA-seq cells at 
the end) # number of cells in pseudo-spots (NO scRNA-seq cells at the end) pseudo_spots_df, pseudo_spots_celltype_prop, n_cell_in_spot = generate_pseudo_spots(scRNA_df, scRNA_celltype, n_pseudo_scrna, celltype_order, pseudo_spot_min_cell, pseudo_spot_max_cell) if sc_augment: print('\n#cells of cell types in reference scRNA-seq data:') for k, v in celltype_count_dict.items(): print(f'{k}: {v}') max_count = max(celltype_count_dict.values()) target_count = int(1.5 * max_count) print(f'HIGHLIGHT: augment single cells to {target_count} cells per cell type') scRNA_df, scRNA_celltype, scrna_n_cell = augment_sc(scRNA_df, scRNA_celltype, target_count, 2, 3) # update count celltype_count_dict = scRNA_celltype.celltype.value_counts().to_dict() # also update color palette if diagnosis: plot_colors = defineColor(spatial_df.shape[0], scRNA_celltype) # split training and validation n_train_scrna_cell = int(np.floor(scRNA_df.shape[0] * training_pct)) n_valid_scrna_cell = int(scRNA_df.shape[0] - n_train_scrna_cell) else: n_train_scrna_cell = scRNA_df.shape[0] n_valid_scrna_cell = 0 scrna_n_cell = [1] * n_train_pseudo_scrna # convert scRNA-seq cell-type annotation to proportions scrna_cell_celltype_prop = celltype2props(scRNA_celltype, celltype_order) # generate pseudo-spots by combining spatial spots if use_spatial_pseudo: n_pseudo_spatial = int(0.5 * n_pseudo_scrna) else: n_pseudo_spatial = 0 n_train_pseudo_spatial = int(np.floor(n_pseudo_spatial * training_pct)) n_valid_pseudo_spatial = int(n_pseudo_spatial - n_train_pseudo_spatial) print(f'generate {n_pseudo_spatial} pseudo-spots containing 2 to 6 spots from spatial spots...') pseudo_spatial_df = combine_spatial_spots(spatial_df, n_pseudo_spatial, 2, 6) if use_log_transform: # since the input dataframe is extracted from AnnData object, and will not be used in downstream analysis (we can extract from AnnData again), it's safe to modify them directly here print('\nHIGHLIGHT: first apply log transformation on sequencing depth 
normalized gene expressions, followed by Min-Max scaling') spatial_df = np.log1p(spatial_df) if pseudo_spots_df.shape[0] > 0: pseudo_spots_df = np.log1p(pseudo_spots_df) if pseudo_spatial_df.shape[0] > 0: pseudo_spatial_df = np.log1p(pseudo_spatial_df) # do not forget scRNA_df scRNA_df = np.log1p(scRNA_df) # Build training and validation data # first spots used for validation, rest spots used for training print(f'\n{"" : <24} | {"training": >9} | {"validation": >9}') print(f'{"spatial spots" : <24} | {spatial_df.shape[0]: >9} | {0: >9}') print(f'{"spatial pseudo-spots" : <24} | {n_train_pseudo_spatial: >9} | {n_valid_pseudo_spatial: >9}') print(f'{"scRNA-seq cells" : <24} | {n_train_scrna_cell: >9} | {n_valid_scrna_cell: >9}') print(f'{"scRNA-seq pseudo-spots" : <24} | {n_train_pseudo_scrna: >9} | {n_valid_pseudo_scrna: >9}\n') train_scrna_df = pd.concat([pseudo_spots_df.iloc[n_valid_pseudo_scrna:,:], scRNA_df.iloc[n_valid_scrna_cell:,:]], ignore_index=False) valid_scrna_df = pd.concat([pseudo_spots_df.iloc[:n_valid_pseudo_scrna,:], scRNA_df.iloc[:n_valid_scrna_cell,:]], ignore_index=False) train_spatial_df = pd.concat([pseudo_spatial_df.iloc[n_valid_pseudo_spatial:,:], spatial_df], ignore_index=False) valid_spatial_df = pseudo_spatial_df.iloc[:n_valid_pseudo_spatial,:] assert train_scrna_df.shape[0] == (n_train_pseudo_scrna + n_train_scrna_cell) assert valid_scrna_df.shape[0] == (n_valid_pseudo_scrna + n_valid_scrna_cell) assert train_spatial_df.shape[0] == (n_train_pseudo_spatial + spatial_df.shape[0]) assert valid_spatial_df.shape[0] == n_valid_pseudo_spatial # scaling to [0,input_max] with each dataset separately # use only spatial spots + spatial pseudo-spots for spatial dataset scaling print(f'scaling inputs to range 0 to {input_max}') spatial_min_max_scaler = MinMaxScaler(feature_range=[0, input_max]) train_spatial_data = spatial_min_max_scaler.fit_transform(train_spatial_df) if valid_spatial_df.shape[0] > 0: valid_spatial_data = 
spatial_min_max_scaler.transform(valid_spatial_df) else: valid_spatial_data = valid_spatial_df.values # use only training pseudo spots + single cells for scRNA-seq dataset scaling scRNA_min_max_scaler = MinMaxScaler(feature_range=[0, input_max]) train_scrna_data = scRNA_min_max_scaler.fit_transform(train_scrna_df) if valid_scrna_df.shape[0] > 0: valid_scrna_data = scRNA_min_max_scaler.transform(valid_scrna_df) else: valid_scrna_data = valid_scrna_df.values # first spatial pseudo-spots then spatial spots then scRNA-seq pseudo-spots and scRNA-seq cells in data # we also consider whether to duplicate spatial data since they are few # update: use training sample weights data = np.vstack((train_spatial_data, train_scrna_data)) labels = np.array([input_max]*train_spatial_data.shape[0] + [0.]*train_scrna_data.shape[0]) labels = labels.reshape((len(labels), 1)) # validation data valid_data = np.vstack((valid_spatial_data, valid_scrna_data)) valid_labels = np.array([input_max]*valid_spatial_data.shape[0] + [0.]*valid_scrna_data.shape[0]) valid_labels = valid_labels.reshape((len(valid_labels), 1)) # training sample weights weight_pseudo_scrna = np.ones((n_train_pseudo_scrna,)) weight_cell_scrna = np.ones((n_train_scrna_cell,)) weight_pseudo_spatial = np.ones((n_train_pseudo_spatial,)) weight_spot_spatial = np.ones((spatial_df.shape[0],)) # weight sum of scRNA-seq cells : sum of scRNA pseudo spots = 1 : 1 # always decrease the weights for cohort with more samples if n_train_pseudo_scrna > 0: if n_train_pseudo_scrna > n_train_scrna_cell: weight_pseudo_scrna *= n_train_scrna_cell / n_train_pseudo_scrna elif n_train_pseudo_scrna < n_train_scrna_cell: weight_cell_scrna *= n_train_pseudo_scrna / n_train_scrna_cell # weight sum of spatial spots : sum of spatial pseudo spots = 1 : 1 # always decrease the weights for cohort with more samples if n_train_pseudo_spatial > 0: if n_train_pseudo_spatial > spatial_df.shape[0]: weight_pseudo_spatial *= spatial_df.shape[0] / 
n_train_pseudo_spatial elif n_train_pseudo_spatial < spatial_df.shape[0]: weight_spot_spatial *= n_train_pseudo_spatial / spatial_df.shape[0] # Final Balancing, re-weight spatial data to make sure the sum of spatial : sum of scRNA-seq = 1 : 1 # since we have already adjusted the weights, here we can not rely on sample size any more, use the sum of weight instead # always decrease the weights for cohort with more samples if (np.sum(weight_pseudo_scrna)+np.sum(weight_cell_scrna)) < (np.sum(weight_pseudo_spatial)+np.sum(weight_spot_spatial)): # calculate factor beforehand to avoid update in weight causing factor change tmp_factor = (np.sum(weight_pseudo_scrna)+np.sum(weight_cell_scrna)) / (np.sum(weight_pseudo_spatial)+np.sum(weight_spot_spatial)) weight_pseudo_spatial *= tmp_factor weight_spot_spatial *= tmp_factor elif (np.sum(weight_pseudo_scrna)+np.sum(weight_cell_scrna)) > (np.sum(weight_pseudo_spatial)+np.sum(weight_spot_spatial)): tmp_factor = (np.sum(weight_pseudo_spatial)+np.sum(weight_spot_spatial)) / (np.sum(weight_pseudo_scrna)+np.sum(weight_cell_scrna)) weight_pseudo_scrna *= tmp_factor weight_cell_scrna *= tmp_factor sample_weight = np.concatenate([weight_pseudo_spatial, weight_spot_spatial, weight_pseudo_scrna, weight_cell_scrna]) del train_spatial_data, valid_spatial_data, train_scrna_data, valid_scrna_data del train_spatial_df, valid_spatial_df, train_scrna_df, valid_scrna_df # Define CVAE # number of nodes in input layer (equals number of celltypes) p = data.shape[1] # number of nodes in conditional node p_cond = 1 # Hyper-parameters latent_dim = n_celltype * 3 # use geometric mean of latent and input dimension (a geometric progression) hidden_dim = list(np.floor(np.geomspace(latent_dim, p, num_hidden_layer+2)[1:num_hidden_layer+1]).astype('int')) print('\nCVAE structure:') print(f'Encoder: {" - ".join([str(x) for x in ([p+p_cond] + hidden_dim[::-1] + [latent_dim])])}') print(f'Decoder: {" - ".join([str(x) for x in ([latent_dim+p_cond] + hidden_dim + 
[p])])}\n') # note hidden layer in encoder is a reverse of the hidden_dim variable cvae, new_decoder = CVAE_keras_model(p, p_cond, latent_dim, hidden_dim[::-1], hidden_dim, use_batch_norm=use_batch_norm, cvae_init_lr=cvae_init_lr) # learning rate decay lrate = ReduceLROnPlateau(monitor='val_loss', factor=0.9, patience=20, min_lr=5e-4, cooldown=10, verbose=False) # early stopping based on validation loss early_stop = EarlyStopping(monitor='val_loss', min_delta=0, patience=40, restore_best_weights=True, verbose=False) # change tensorflow seed value, set the same seed value for sampling samples from latent space to decoder before training # still has unknown randomness source even set seed here... set_random_seed(1154) # if not use Batch Normalization, we use all training data for one epoch if use_batch_norm: one_batch_size = 16384 do_shuffle = True else: one_batch_size = data.shape[0] do_shuffle = False # Train CVAE # note when there is no pseudo-spots, then there is no validation data if valid_data.shape[0] == 0: print('\nStart training without validation data...\n') history_callback = cvae.fit([data, labels], data, epochs=cvae_train_epoch, batch_size=one_batch_size, shuffle=do_shuffle, callbacks=[lrate, early_stop], sample_weight=sample_weight, verbose=True) else: print('\nStart training...\n') history_callback = cvae.fit([data, labels], data, validation_data=([valid_data, valid_labels], valid_data), epochs=cvae_train_epoch, batch_size=one_batch_size, shuffle=do_shuffle, callbacks=[lrate, early_stop], sample_weight=sample_weight, verbose=True) n_epoch = len(history_callback.history['loss']) if n_epoch < cvae_train_epoch: print(f'\ntraining finished in {n_epoch} epochs (early stop), transform data to adjust the platform effect...\n') else: print(f'\ntraining finished in {n_epoch} epochs (reach max pre-specified epoches), transform data to adjust the platform effect...\n') if diagnosis: # plot loss from diagnosis_plots import plotCVAELoss 
plotCVAELoss(history_callback) # postprocess the trained models # Subset the encoder encoder = Model([cvae.get_layer('encoder_input').input, cvae.get_layer('cond_input').input], [cvae.get_layer('z_mean').output, cvae.get_layer('z_log_var').output], name='encoder') # update layer weights for the decoder for layer in new_decoder.layers: if layer.name.endswith('w') or layer.name.endswith('BN'): new_decoder.get_layer(layer.name).set_weights(cvae.get_layer(layer.name).get_weights()) ''' # double check to ensure decoder is correct from utils import check_decoder check_decoder(cvae, new_decoder, data, labels) ''' # Transform data to adjust platform effect # note if log transform is applied, the data in spatial_df, pseudo_spatial_df, scRNA_df and pseudo_spots_df are already log transformed # Decoder output of transformed latent embedding of spatial spot-level data spatial_embed = encoder.predict([spatial_min_max_scaler.transform(spatial_df), np.full((spatial_df.shape[0],1), input_max)])[0] tmp_output = new_decoder.predict([spatial_embed, np.full((spatial_embed.shape[0],1), 0)]) spatial_transformed_df = pd.DataFrame(scRNA_min_max_scaler.inverse_transform(tmp_output), columns=spatial_df.columns, index=spatial_df.index) # log transformation back if use_log_transform: print('HIGHLIGHT: when transforming data, after reversed Min-Max Scaling, apply exp transformation then multiple the factor and round to integer') spatial_transformed_df = np.expm1(spatial_transformed_df) spatial_transformed_numi = np.rint(spatial_transformed_df * depth_scaler) # Decoder output of average marker gene expression of scRNA-seq cell-types # NOTE: here we also include augmented single cells into downstream analysis, which affects the cell-type marker gene profile scRNA_embed = encoder.predict([scRNA_min_max_scaler.transform(scRNA_df), np.full((scRNA_df.shape[0],1), 0)])[0] # decode it tmp_output = new_decoder.predict([scRNA_embed, np.full((scRNA_embed.shape[0],1), 0)]) scRNA_decode_df = 
pd.DataFrame(scRNA_min_max_scaler.inverse_transform(tmp_output), columns=scRNA_df.columns, index=scRNA_df.index) # log transformation back if use_log_transform: scRNA_decode_df = np.expm1(scRNA_decode_df) # whether to rerun DE if rerun_DE: print('\nre-run DE on CVAE transformed scRNA-seq data!') from utils import rerun_DE new_markers = rerun_DE(scRNA_decode_df, scRNA_celltype, n_marker_per_cmp, use_fdr, p_val_cutoff, fc_cutoff, pct1_cutoff, pct2_cutoff, sortby_fc, diagnosis, filter_gene) else: new_markers = None # take average for all genes scRNA_decode_avg_df = scRNA_decode_df.copy() scRNA_decode_avg_df['celltype'] = scRNA_celltype.celltype scRNA_decode_avg_df = scRNA_decode_avg_df.groupby(['celltype']).mean() # re-order the rows to match the previous defined cell-type order scRNA_decode_avg_df = scRNA_decode_avg_df.loc[celltype_order, :] # transfer cell-type proportions as an initial guess if n_pseudo_scrna == 0: pseudo_spot_embed = np.empty((0, scRNA_embed.shape[1])) else: pseudo_spot_embed = encoder.predict([scRNA_min_max_scaler.transform(pseudo_spots_df), np.full((pseudo_spots_df.shape[0],1), 0)])[0] if do_initial_guess: use_embedding = 'none' if use_embedding == 'none': print('HIGHLIGHT: got initial guess of cell type proportions based on original CVAE latent embedding') else: print(f'\nHIGHLIGHT: got initial guess of cell type proportions based on {use_embedding} embedding of CVAE latent space') tmp_pred = transferProps(spatial_embed, np.vstack((scRNA_embed, pseudo_spot_embed)), pd.concat([scrna_cell_celltype_prop, pseudo_spots_celltype_prop]).values, n_neighbors=10, sigma=1, use_embedding=use_embedding) cvae_pred = pd.DataFrame(tmp_pred, index=spatial_df.index, columns=celltype_order) if diagnosis: os.makedirs(os.path.join(diagnosis_path, 'initial_guess'), exist_ok=True) cvae_pred.to_csv(os.path.join(diagnosis_path, 'initial_guess', 'celltype_props_by_transferring.csv')) else: cvae_pred = None # whether to save models and transformed data if diagnosis: # 
also embed pseudo-spots for diagnosis if n_pseudo_spatial == 0: pseudo_spatial_embed = np.empty((0, spatial_embed.shape[1])) else: pseudo_spatial_embed = encoder.predict([spatial_min_max_scaler.transform(pseudo_spatial_df), np.full((pseudo_spatial_df.shape[0],1), input_max)])[0] from diagnosis_plots import diagnosisCVAE diagnosisCVAE(cvae, encoder, new_decoder, spatial_embed, spatial_transformed_df, spatial_transformed_numi, pseudo_spatial_embed, scRNA_celltype, celltype_order, celltype_count_dict, scrna_cell_celltype_prop, scRNA_embed, scrna_n_cell, pseudo_spots_celltype_prop, n_cell_in_spot, pseudo_spot_embed, scRNA_decode_df, scRNA_decode_avg_df, new_markers, plot_colors) #print(f'before CVAE building function return RAM usage: {psutil.Process().memory_info().rss/1024**2:.2f} MB') return spatial_transformed_numi, scRNA_decode_avg_df, new_markers, cvae_pred
def build_CVAE_whole(spatial_file, ref_file, ref_anno_file, marker_file, n_hv_gene, n_marker_per_cmp, n_pseudo_spot, pseudo_spot_min_cell, pseudo_spot_max_cell, seq_depth_scaler, cvae_input_scaler, cvae_init_lr, num_hidden_layer, use_batch_norm, cvae_train_epoch, use_spatial_pseudo, redo_de, use_fdr, p_val_cutoff, fc_cutoff, pct1_cutoff, pct2_cutoff, sortby_fc, diagnosis, filter_cell, filter_gene):
    '''
    read related CSV files, select the genes used for modelling, build CVAE to adjust platform effect, return transformed spatial gene expression and scRNA-seq cell-type gene signature

    The gene set passed to the CVAE is either all genes shared by the two datasets (when there are no more than <n_hv_gene> of them), or the union of highly variable genes found in the scRNA-seq data and cell-type marker genes (taken from <marker_file> if given, otherwise obtained by running DE).

    Parameters
    ----------
    spatial_file : string
        full path of input csv file of raw nUMI counts in spatial transcriptomic data (spots * genes).
    ref_file : string
        full path of input csv file of raw nUMI counts in scRNA-seq data (cells * genes).
    ref_anno_file : string
        full path of input csv file of cell-type annotations for all cells in scRNA-seq data.
    marker_file : string or None
        full path of input csv file of cell-type marker gene expression (cell-types * genes). If None, DE is performed on the scRNA-seq data to get marker genes.
    n_hv_gene : int
        number of highly variable genes for CVAE. If 0, no highly variable gene is selected and only marker genes are used.
    n_marker_per_cmp : int
        number of TOP marker genes for each comparison in DE.
    n_pseudo_spot : int
        number of pseudo-spots.
    pseudo_spot_min_cell : int
        minimum value of cells in pseudo-spot.
    pseudo_spot_max_cell : int
        maximum value of cells in pseudo-spot.
    seq_depth_scaler : int
        a scaler of scRNA-seq sequencing depth.
    cvae_input_scaler : int
        maximum value of the scaled input for CVAE.
    cvae_init_lr : float
        initial learning rate for training CVAE.
    num_hidden_layer : int
        number of hidden layers in encoder and decoder.
    use_batch_norm : bool
        whether to use Batch Normalization.
    cvae_train_epoch : int
        max number of training epochs for the CVAE.
    use_spatial_pseudo : bool
        whether to generate "pseudo-spots" in spatial condition (any truthy value enables it).
    redo_de : bool
        whether to redo DE after CVAE transformation.
    use_fdr : bool
        whether to use FDR adjusted p value for filtering and sorting.
    p_val_cutoff : float
        threshold of p value (or FDR if --use_fdr is true) in marker genes filtering.
    fc_cutoff : float
        threshold of fold change (without log transform!) in marker genes filtering.
    pct1_cutoff : float
        threshold of pct.1 in marker genes filtering.
    pct2_cutoff : float
        threshold of pct.2 in marker genes filtering.
    sortby_fc : bool
        whether to sort marker genes by fold change.
    diagnosis : bool
        if True save more information to files for diagnosis CVAE and hyper-parameter selection.
    filter_cell : bool
        whether to filter cells before DE.
    filter_gene : bool
        whether to filter genes before DE.

    Returns
    -------
    spatial_transformed_numi : dataframe
        CVAE transformed (platform effect adjusted) spatial spot gene raw nUMI counts (spots * genes).
    scRNA_decode_avg_df : dataframe
        CVAE decoded average gene expression (normalized) of cell-types in scRNA-seq data (cell-types * genes).
    new_markers : list or None
        marker genes from re-run DE on CVAE transformed scRNA-seq data. It will be None if not re-run DE.
    cvae_pred : dataframe or None
        cell-type proportions of spatial spots predicted or transferred by CVAE (spots * cell-types). It will be None if there is no way to get an initial guess of cell-type proportions.

    Raises
    ------
    ValueError
        if the spatial data and the scRNA-seq data share no gene, since nothing could be modelled in that case.
    '''

    start_time = time()

    # read spatial data; only the first element returned by the reader is needed here
    spatial_spot_obj = read_spatial_data(spatial_file, filter_gene)[0]

    # read scRNA-seq data
    scrna_obj = read_scRNA_data(ref_file, ref_anno_file, filter_cell, filter_gene)

    # Overlap of genes between scRNA cell-level and spatial spot-level data
    # NOTE: the overlap is sorted on purpose. Iteration order of a set of strings changes from one
    # Python process to another (hash randomization), which would make the gene order of the two
    # subset objects below, and therefore the inputs of HVG selection and DE, differ between runs
    overlap_genes = sorted(set(spatial_spot_obj.var_names).intersection(scrna_obj.var_names))
    print(f'total {len(overlap_genes)} overlapped genes')

    # fail early with a clear message instead of passing an empty gene list to the CVAE
    if not overlap_genes:
        raise ValueError('no overlapped genes found between spatial data and scRNA-seq data, please check the gene names in the input files')

    # subset overlapped gene
    spatial_spot_obj = spatial_spot_obj[:, overlap_genes].copy()
    scrna_obj = scrna_obj[:, overlap_genes].copy()

    # how many genes used for CVAE
    if len(overlap_genes) <= n_hv_gene:
        # use all genes for CVAE (overlap_genes is already sorted)
        final_gene_list = list(overlap_genes)
        print(f'\nuse all {len(final_gene_list)} genes for downstream analysis as there are less genes available than specified number {n_hv_gene}')
    else:
        # use selected highly variable genes + cell-type marker genes for CVAE

        # identify highly variable genes in scRNA-seq data, select TOP X HV genes
        # no need to consider highly variable genes in spatial data, as for cell-type deconvolution, we work on each spot independently
        print(f'\nidentify {n_hv_gene} highly variable genes from scRNA-seq data...')
        if n_hv_gene == 0:
            scrna_hv_genes = []
        else:
            # inplace=False returns the per-gene statistics instead of writing them into scrna_obj
            hv_stats = sc.pp.highly_variable_genes(scrna_obj, layer='raw_nUMI', flavor='seurat_v3', n_top_genes=n_hv_gene, inplace=False)
            scrna_hv_genes = hv_stats.loc[hv_stats['highly_variable']==True].index.to_list()

        # identify cell-type marker genes
        print('\nidentify cell-type marker genes...')

        if marker_file is not None:
            # directly use provide marker gene expression
            marker_df = pd.read_csv(marker_file, index_col=0)
            print('user provided marker gene profile, DE will be skipped...\n')
            print(f'read {marker_df.shape[1]} marker genes from user specified marker gene file')
            # extract marker gene overlapped with spatial data
            marker_genes = sorted(set(overlap_genes) & set(marker_df.columns))
            print(f'from user specified marker gene expression use {len(marker_genes)} marker genes overlapped with spatial + scRNA-seq data')
        else:
            # perform DE, return the marker gene expression
            print('no marker gene profile provided. Perform DE to get cell-type marker genes on scRNA-seq data...\n')
            marker_genes = run_DE(scrna_obj, n_marker_per_cmp, use_fdr, p_val_cutoff, fc_cutoff, pct1_cutoff, pct2_cutoff, sortby_fc, diagnosis, 'DE_celltype_markers.csv')

        # final gene list for downstream analysis, sorted so both dataframes below get identical column order
        final_gene_list = sorted(set(scrna_hv_genes).union(marker_genes))
        print(f'\nuse union of highly variable gene list and cell-type marker gene list derived from scRNA-seq data, finally get {len(final_gene_list)} genes for downstream analysis')

    print('\nstart CVAE building...\n')

    # cell-type annotation as a one-column dataframe of strings
    scrna_celltype = sc.get.obs_df(scrna_obj, keys='celltype').to_frame()
    scrna_celltype['celltype'] = scrna_celltype['celltype'].astype(str)

    # build CVAE
    # NOTE: build_CVAE takes these arguments positionally, keep the order unchanged
    (spatial_transformed_numi, scRNA_decode_avg_df, new_markers, cvae_pred) = build_CVAE(
        sc.get.obs_df(spatial_spot_obj, keys=final_gene_list),
        sc.get.obs_df(scrna_obj, keys=final_gene_list),
        scrna_celltype,
        n_marker_per_cmp,
        n_pseudo_spot,
        pseudo_spot_min_cell,
        pseudo_spot_max_cell,
        seq_depth_scaler,
        cvae_input_scaler,
        cvae_init_lr,
        num_hidden_layer,
        use_batch_norm,
        cvae_train_epoch,
        use_spatial_pseudo,
        use_fdr,
        p_val_cutoff,
        fc_cutoff,
        pct1_cutoff,
        pct2_cutoff,
        sortby_fc,
        diagnosis,
        rerun_DE=redo_de,
        filter_gene=filter_gene)

    print(f'\nplatform effect adjustment by CVAE finished. Elapsed time: {(time()-start_time)/60.0:.2f} minutes.\n\n')

    return spatial_transformed_numi, scRNA_decode_avg_df, new_markers, cvae_pred