The Geneformer authors provide a cardiomyopathy-related scRNA-seq dataset containing non-failing (NF), hypertrophic cardiomyopathy (HCM), and dilated cardiomyopathy (DCM) heart samples, together with a gene list indicating which transcription factors are dosage sensitive. The pretrained model is fine-tuned on these data and then used to predict whether a gene is a dosage-sensitive transcription factor.
Fine-tuning data: scRNA-seq data and gene labels;
Downstream task: classifying TFs as dosage sensitive or insensitive.
Module imports
import os

GPU_NUMBER = [0]
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join([str(s) for s in GPU_NUMBER])
os.environ["NCCL_DEBUG"] = "INFO"
# imports
import datetime
import subprocess
import math
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from datasets import load_from_disk
from sklearn import preprocessing
from sklearn.metrics import accuracy_score, auc, confusion_matrix, ConfusionMatrixDisplay, roc_curve
from sklearn.model_selection import StratifiedKFold
import torch
from transformers import BertForTokenClassification
from transformers import Trainer
from transformers.training_args import TrainingArguments
from tqdm.notebook import tqdm

from geneformer import DataCollatorForGeneClassification
from geneformer.pretrainer import token_dictionary
(Importing geneformer pulls in loompy, which emits several NumbaDeprecationWarning messages for twobit_to_dna, dna_to_twobit, and twobit_1hamming about the 'nopython' keyword default changing in Numba 0.59.0; these warnings can be ignored.)
# table of corresponding Ensembl IDs, gene names, and gene types (e.g. coding, miRNA, etc.)
gene_info = pd.read_csv("D:/jupyterNote/Geneformer/Genecorpus-30M/example_input_files/gene_info_table.csv", index_col=0)
# create dictionaries for corresponding attributes
gene_id_type_dict = dict(zip(gene_info["ensembl_id"], gene_info["gene_type"]))
gene_name_id_dict = dict(zip(gene_info["gene_name"], gene_info["ensembl_id"]))
gene_id_name_dict = {v: k for k, v in gene_name_id_dict.items()}
# first 5 key:value pairs
{k: gene_id_name_dict[k] for k in list(gene_id_name_dict)[:5]}
# function for preparing targets and labels
def prep_inputs(genegroup1, genegroup2, id_type):
    if id_type == "gene_name":
        targets1 = [gene_name_id_dict[gene] for gene in genegroup1 if gene_name_id_dict.get(gene) in token_dictionary]
        targets2 = [gene_name_id_dict[gene] for gene in genegroup2 if gene_name_id_dict.get(gene) in token_dictionary]
    elif id_type == "ensembl_id":
        targets1 = [gene for gene in genegroup1 if gene in token_dictionary]
        targets2 = [gene for gene in genegroup2 if gene in token_dictionary]
    targets1_id = [token_dictionary[gene] for gene in targets1]
    targets2_id = [token_dictionary[gene] for gene in targets2]
    targets = np.array(targets1_id + targets2_id)
    labels = np.array([0]*len(targets1_id) + [1]*len(targets2_id))
    nsplits = min(5, min(len(targets1_id), len(targets2_id)) - 1)
    assert nsplits > 2
    print(f"# targets1: {len(targets1_id)}\n# targets2: {len(targets2_id)}\n# splits: {nsplits}")
    return targets, labels, nsplits
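The cell that builds targets and labels from the provided TF list is not reproduced here; a usage sketch is shown below, in which the file name, format, and column layout are assumptions rather than the actual Genecorpus-30M file.

# Hedged usage sketch: file name and column layout are assumptions.
tf_list = pd.read_csv("path/to/dosage_sensitivity_TF_list.csv")   # hypothetical path
sensitive_tfs = tf_list.iloc[:, 0].dropna().tolist()              # assumed: column 0 = dosage-sensitive TFs
insensitive_tfs = tf_list.iloc[:, 1].dropna().tolist()            # assumed: column 1 = dosage-insensitive TFs
# pass "gene_name" or "ensembl_id" depending on how the list encodes genes
targets, labels, nsplits = prep_inputs(sensitive_tfs, insensitive_tfs, "gene_name")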
# first 5 key:value pairs of the token dictionary
{k: token_dictionary[k] for k in list(token_dictionary)[:5]}
# load training dataset
train_dataset = load_from_disk("D:/jupyterNote/Geneformer/Genecorpus-30M/example_input_files/cell_classification/disease_classification/human_dcm_hcm_nf.dataset")
shuffled_train_dataset = train_dataset.shuffle(seed=42)
subsampled_train_dataset = shuffled_train_dataset.select([i for i in range(50_000)])
Loading cached shuffled indices for dataset at D:\jupyterNote\Geneformer\Genecorpus-30M\example_input_files\cell_classification\disease_classification\human_dcm_hcm_nf.dataset\cache-54b519f110fa07f1.arrow
def preprocess_classifier_batch(cell_batch, max_len):
    if max_len is None:
        max_len = max([len(i) for i in cell_batch["input_ids"]])
    def pad_label_example(example):
        example["labels"] = np.pad(example["labels"],
                                   (0, max_len - len(example["input_ids"])),
                                   mode='constant', constant_values=-100)
        example["input_ids"] = np.pad(example["input_ids"],
                                      (0, max_len - len(example["input_ids"])),
                                      mode='constant', constant_values=token_dictionary.get("<pad>"))
        example["attention_mask"] = (example["input_ids"] != token_dictionary.get("<pad>")).astype(int)
        return example
    padded_batch = cell_batch.map(pad_label_example)
    return padded_batch
# forward batch size is batch size for model inference (e.g. 200)
def classifier_predict(model, evalset, forward_batch_size, mean_fpr):
    predict_logits = []
    predict_labels = []
    model.eval()
    # ensure there are at least 2 examples in each batch to avoid incorrect tensor dims
    evalset_len = len(evalset)
    max_divisible = find_largest_div(evalset_len, forward_batch_size)
    if len(evalset) - max_divisible == 1:
        evalset_len = max_divisible
    max_evalset_len = max(evalset.select([i for i in range(evalset_len)])["length"])
    for i in range(0, evalset_len, forward_batch_size):
        max_range = min(i + forward_batch_size, evalset_len)
        batch_evalset = evalset.select([i for i in range(i, max_range)])
        padded_batch = preprocess_classifier_batch(batch_evalset, max_evalset_len)
        padded_batch.set_format(type="torch")
        input_data_batch = padded_batch["input_ids"]
        attn_msk_batch = padded_batch["attention_mask"]
        label_batch = padded_batch["labels"]
        with torch.no_grad():
            outputs = model(
                input_ids=input_data_batch.to("cuda"),
                attention_mask=attn_msk_batch.to("cuda"),
                labels=label_batch.to("cuda"),
            )
            predict_logits += [torch.squeeze(outputs.logits.to("cpu"))]
            predict_labels += [torch.squeeze(label_batch.to("cpu"))]
    logits_by_cell = torch.cat(predict_logits)
    all_logits = logits_by_cell.reshape(-1, logits_by_cell.shape[2])
    labels_by_cell = torch.cat(predict_labels)
    all_labels = torch.flatten(labels_by_cell)
    logit_label_paired = [item for item in list(zip(all_logits.tolist(), all_labels.tolist())) if item[1] != -100]
    y_pred = [vote(item[0]) for item in logit_label_paired]
    y_true = [item[1] for item in logit_label_paired]
    logits_list = [item[0] for item in logit_label_paired]
    # probability of class 1
    y_score = [py_softmax(item)[1] for item in logits_list]
    conf_mat = confusion_matrix(y_true, y_pred)
    fpr, tpr, _ = roc_curve(y_true, y_score)
    # plot roc_curve for this split
    plt.plot(fpr, tpr)
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('ROC')
    plt.show()
    # interpolate to graph
    interp_tpr = np.interp(mean_fpr, fpr, tpr)
    interp_tpr[0] = 0.0
    return fpr, tpr, interp_tpr, conf_mat
def vote(logit_pair):
    a, b = logit_pair
    if a > b:
        return 0
    elif b > a:
        return 1
    elif a == b:
        return "tie"

def py_softmax(vector):
    e = np.exp(vector)
    return e / e.sum()

# get cross-validated mean and sd metrics
def get_cross_valid_metrics(all_tpr, all_roc_auc, all_tpr_wt):
    wts = [count / sum(all_tpr_wt) for count in all_tpr_wt]
    print(wts)
    all_weighted_tpr = [a * b for a, b in zip(all_tpr, wts)]
    mean_tpr = np.sum(all_weighted_tpr, axis=0)
    mean_tpr[-1] = 1.0
    all_weighted_roc_auc = [a * b for a, b in zip(all_roc_auc, wts)]
    roc_auc = np.sum(all_weighted_roc_auc)
    # convert the list of per-split AUCs to an array before subtracting the pooled AUC
    roc_auc_sd = math.sqrt(np.average((np.array(all_roc_auc) - roc_auc) ** 2, weights=wts))
    return mean_tpr, roc_auc, roc_auc_sd
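As a quick sanity check with made-up numbers (purely illustrative, not output from this notebook): two splits whose interpolated TPR curves carry 100 and 300 points receive weights 0.25 and 0.75, so per-split AUCs of 0.80 and 0.90 pool to 0.25 * 0.80 + 0.75 * 0.90 = 0.875.

# Purely illustrative call with made-up inputs
demo_fpr = np.linspace(0, 1, 100)
demo_tpr = [demo_fpr ** 0.5, demo_fpr ** 0.3]   # two fake interpolated TPR curves
demo_auc = [0.80, 0.90]                         # their per-split AUCs
demo_wt = [100, 300]                            # number of TPR points per split
_, pooled_auc, pooled_sd = get_cross_valid_metrics(demo_tpr, demo_auc, demo_wt)
# pooled_auc == 0.875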
# Function to find the largest number smaller
# than or equal to N that is divisible by K
def find_largest_div(N, K):
    rem = N % K
    if rem == 0:
        return N
    else:
        return N - rem
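For example (illustrative numbers):

find_largest_div(1003, 20)   # -> 1000; remainder 3, so the final batch simply has 3 cells
find_largest_div(1001, 20)   # -> 1000; remainder 1, so classifier_predict drops the lone leftover cell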
The cross_validate function defined below wraps the data splitting (80% training set, 10% evaluation set, 10% hold-out evaluation set), training, and prediction steps.
# load model
model = BertForTokenClassification.from_pretrained(
    "D:/jupyterNote/Geneformer",   # change to local path to the model
    num_labels=2,
    output_attentions=False,
    output_hidden_states=False
)
Next, this part of the code fine-tunes the model with the defined parameters (excerpted from the cross_validate function shown in full below).
# add output directory to training args and initiate
training_args["output_dir"] = ksplit_output_dir
training_args_init = TrainingArguments(**training_args)

# ... training, model saving, and evaluation happen in between (see the full function below) ...

# append to tpr and roc lists
confusion = confusion + conf_mat
all_tpr.append(interp_tpr)
all_roc_auc.append(auc(fpr, tpr))
# append number of eval examples by which to weight tpr in averaged graphs
all_tpr_wt.append(len(tpr))
# cross-validate gene classifier
def cross_validate(data, targets, labels, nsplits, subsample_size, training_args, freeze_layers, output_dir, num_proc):
    # check if output directory already written to
    # ensure not overwriting previously saved model
    model_dir_test = os.path.join(output_dir, "ksplit0/models/pytorch_model.bin")
    if os.path.isfile(model_dir_test) == True:
        raise Exception("Model already saved to this directory.")
    # initiate eval metrics to return
    num_classes = len(set(labels))
    mean_fpr = np.linspace(0, 1, 100)
    all_tpr = []
    all_roc_auc = []
    all_tpr_wt = []
    label_dicts = []
    confusion = np.zeros((num_classes, num_classes))
    # set up cross-validation splits
    skf = StratifiedKFold(n_splits=nsplits, random_state=0, shuffle=True)
    # train and evaluate
    iteration_num = 0
    for train_index, eval_index in tqdm(skf.split(targets, labels)):
        if len(labels) > 500:
            print("early stopping activated due to large # of training examples")
            nsplits = 3
            if iteration_num == 3:
                break
        print(f"****** Crossval split: {iteration_num}/{nsplits-1} ******\n")
        # generate cross-validation splits
        targets_train, targets_eval = targets[train_index], targets[eval_index]
        labels_train, labels_eval = labels[train_index], labels[eval_index]
        label_dict_train = dict(zip(targets_train, labels_train))
        label_dict_eval = dict(zip(targets_eval, labels_eval))
        label_dicts += (iteration_num, targets_train, targets_eval, labels_train, labels_eval)

        # function to filter by whether contains train or eval labels
        def if_contains_train_label(example):
            a = label_dict_train.keys()
            b = example['input_ids']
            return not set(a).isdisjoint(b)

        def if_contains_eval_label(example):
            a = label_dict_eval.keys()
            b = example['input_ids']
            return not set(a).isdisjoint(b)

        # filter dataset for examples containing classes for this split
        print(f"Filtering training data")
        trainset = data.filter(if_contains_train_label, num_proc=num_proc)
        print(f"Filtered {round((1-len(trainset)/len(data))*100)}%; {len(trainset)} remain\n")
        print(f"Filtering evaluation data")
        evalset = data.filter(if_contains_eval_label, num_proc=num_proc)
        print(f"Filtered {round((1-len(evalset)/len(data))*100)}%; {len(evalset)} remain\n")

        # minimize to smaller training sample
        training_size = min(subsample_size, len(trainset))
        trainset_min = trainset.select([i for i in range(training_size)])
        eval_size = min(training_size, len(evalset))
        half_training_size = round(eval_size/2)
        evalset_train_min = evalset.select([i for i in range(half_training_size)])
        evalset_oos_min = evalset.select([i for i in range(half_training_size, eval_size)])

        # label conversion functions
        def generate_train_labels(example):
            example["labels"] = [label_dict_train.get(token_id, -100) for token_id in example["input_ids"]]
            return example

        def generate_eval_labels(example):
            example["labels"] = [label_dict_eval.get(token_id, -100) for token_id in example["input_ids"]]
            return example

        # label datasets
        print(f"Labeling training data")
        trainset_labeled = trainset_min.map(generate_train_labels)
        print(f"Labeling evaluation data")
        evalset_train_labeled = evalset_train_min.map(generate_eval_labels)
        print(f"Labeling evaluation OOS data")
        evalset_oos_labeled = evalset_oos_min.map(generate_eval_labels)

        # create output directories
        ksplit_output_dir = os.path.join(output_dir, f"ksplit{iteration_num}")
        ksplit_model_dir = os.path.join(ksplit_output_dir, "models/")

        # ensure not overwriting previously saved model
        model_output_file = os.path.join(ksplit_model_dir, "pytorch_model.bin")
        if os.path.isfile(model_output_file) == True:
            raise Exception("Model already saved to this directory.")

        # make training and model output directories
        subprocess.call(f'mkdir {ksplit_output_dir}', shell=True)
        subprocess.call(f'mkdir {ksplit_model_dir}', shell=True)

        # load model
        model = BertForTokenClassification.from_pretrained(
            "D:/jupyterNote/Geneformer",   # change to local path to the model
            num_labels=2,
            output_attentions=False,
            output_hidden_states=False
        )
        if freeze_layers is not None:
            modules_to_freeze = model.bert.encoder.layer[:freeze_layers]
            for module in modules_to_freeze:
                for param in module.parameters():
                    param.requires_grad = False
        model = model.to("cuda:0")

        # add output directory to training args and initiate
        training_args["output_dir"] = ksplit_output_dir
        training_args_init = TrainingArguments(**training_args)

        # create the trainer
        trainer = Trainer(
            model=model,
            args=training_args_init,
            data_collator=DataCollatorForGeneClassification(),
            train_dataset=trainset_labeled,
            eval_dataset=evalset_train_labeled
        )

        # train the gene classifier
        trainer.train()
        # save model
        trainer.save_model(ksplit_model_dir)
        # evaluate model
        fpr, tpr, interp_tpr, conf_mat = classifier_predict(trainer.model, evalset_oos_labeled, 20, mean_fpr)  # forward_batch_size: 20
        # append to tpr and roc lists
        confusion = confusion + conf_mat
        all_tpr.append(interp_tpr)
        all_roc_auc.append(auc(fpr, tpr))
        # append number of eval examples by which to weight tpr in averaged graphs
        all_tpr_wt.append(len(tpr))
        iteration_num = iteration_num + 1

    # get overall metrics for cross-validation
    mean_tpr, roc_auc, roc_auc_sd = get_cross_valid_metrics(all_tpr, all_roc_auc, all_tpr_wt)
    return all_roc_auc, roc_auc, roc_auc_sd, mean_fpr, mean_tpr, confusion, label_dicts
Generally, in our experience, applications that are more relevant to the pretraining objective benefit from more layers being frozen to prevent overfitting to the limited task-specific data, whereas applications that are more distant from the pretraining objective benefit from fine-tuning of more layers to optimize performance on the new task.
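As a minimal illustration of what this looks like in code (mirroring the freeze_layers logic already used inside cross_validate, for a BertForTokenClassification-style model):

# Minimal sketch: freeze the first `freeze_layers` encoder layers so that only the
# remaining layers and the classification head are updated during fine-tuning.
def freeze_first_layers(model, freeze_layers):
    if freeze_layers is not None:
        for module in model.bert.encoder.layer[:freeze_layers]:
            for param in module.parameters():
                param.requires_grad = False
    return model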
# set model parameters
# max input size
max_input_size = 2 ** 11  # 2048
# set training hyperparameters
# max learning rate
max_lr = 5e-5
# how many pretrained layers to freeze
freeze_layers = 4
# number gpus
num_gpus = 1
# number cpu cores
num_proc = 6
# batch size for training and eval
geneformer_batch_size = 2
# learning schedule
lr_schedule_fn = "linear"
# warmup steps
warmup_steps = 500
# number of epochs
epochs = 1
# optimizer
optimizer = "adamw"
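The cell that collects these hyperparameters into the training_args dictionary (later unpacked into TrainingArguments inside cross_validate) and that defines training_output_dir is not shown above. A plausible sketch follows; the keys are standard transformers TrainingArguments names, but the exact values and the output path are assumptions rather than the author's settings.

# Hedged sketch: values and the output path are assumptions and should be adapted.
subsample_size = 10_000   # cells used for training in each cross-validation split (from the text below)
training_args = {
    "learning_rate": max_lr,
    "do_train": True,
    "save_strategy": "epoch",
    "logging_steps": 100,
    "group_by_length": True,
    "length_column_name": "length",
    "disable_tqdm": False,
    "lr_scheduler_type": lr_schedule_fn,
    "warmup_steps": warmup_steps,
    "weight_decay": 0.001,
    "per_device_train_batch_size": geneformer_batch_size,
    "per_device_eval_batch_size": geneformer_batch_size,
    "num_train_epochs": epochs,
}
# output directory for this run (hypothetical path)
training_output_dir = "D:/jupyterNote/Geneformer/gene_class_output/"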
# ensure not overwriting previously saved model
ksplit_model_test = os.path.join(training_output_dir, "ksplit0/models/pytorch_model.bin")
if os.path.isfile(ksplit_model_test) == True:
    raise Exception("Model already saved to this directory.")
# make output directory
subprocess.call(f'mkdir {training_output_dir}', shell=True)
0
# clear GPU memory after pytorch training
import torch
torch.cuda.empty_cache()
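If memory is still not released, a fuller cleanup pattern (a sketch, assuming the model and trainer objects from the previous run are still in scope) is to drop the Python references and run the garbage collector before emptying the cache:

# Hedged sketch: release Python references before clearing the CUDA cache.
import gc
import torch

del model, trainer        # assumes these still exist from the previous training run
gc.collect()
torch.cuda.empty_cache()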
# this did not work
#!set 'PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:512'  # limit each allocation split to 512 MB
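The !set line fails because ! runs in a separate subshell, so the variable never reaches the notebook's Python process. A common alternative (a sketch, not something verified in this walkthrough) is to set the variable from Python before the CUDA context is created, or in the shell before launching Jupyter:

# Hedged sketch: must run before the first CUDA allocation, e.g. at the very top
# of the notebook next to the CUDA_VISIBLE_DEVICES setting.
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:512"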
We fine-tune on subsampled_train_dataset, which contains 50,000 cells; each cross-validation split uses up to 10,000 of these cells for training, and 5 splits are run in total (nsplits = 5). Likewise, within each split the input targets and labels are divided into an 80% training set (n = 392) and a 20% evaluation set (n = 98). The split is stratified, i.e. each fold preserves the proportion of the two label classes.
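The call that launches the cross-validation is not shown above; based on the settings described here, it would look roughly like the sketch below (argument values are taken from the surrounding text, not from the author's exact cell).

# Hedged sketch of the cross-validation call (50,000-cell subsample, 10,000 cells
# per split, 5 splits, 4 frozen layers, as described in the text above).
all_roc_auc, roc_auc, roc_auc_sd, mean_fpr, mean_tpr, confusion, label_dicts = cross_validate(
    subsampled_train_dataset,    # 50,000-cell training corpus
    targets, labels, nsplits,    # gene targets/labels and number of splits from prep_inputs
    10_000,                      # subsample_size: cells used for training in each split
    training_args,
    freeze_layers,
    training_output_dir,
    num_proc,
)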
Loading cached processed dataset at D:\jupyterNote\Geneformer\Genecorpus-30M\example_input_files\cell_classification\disease_classification\human_dcm_hcm_nf.dataset\cache-509acb05b140c462.arrow
****** Crossval split: 0/4 ******
Filtering training data
Filtered 0%; 49994 remain
Filtering evaluation data
Split 0 training info...
****** Crossval split: 1/4 ******
Filtering training data
Filtered 0%; 49992 remain
Filtering evaluation data
Filtered 4%; 47913 remain
Labeling training data
Split 1 training info...
****** Crossval split: 2/4 ******
Filtering training data
Filtered 0%; 49993 remain
Filtering evaluation data
Filtered 4%; 47886 remain
Labeling training data
Split 2 training info...
****** Crossval split: 3/4 ******
Filtering training data
Filtered 0%; 49991 remain
Filtering evaluation data
Filtered 4%; 48025 remain
Labeling training data
Split 3 training info...
****** Crossval split: 4/4 ******
Filtering training data
Filtered 0%; 49977 remain
Filtering evaluation data
Filtered 2%; 48951 remain
Labeling training data
# # numbers of TFs (genes with 0/1 label) in out-of-sample evaluation set
# tf_num = [len([v for v in i if v >= 0]) for i in evalset_oos_labeled['labels']]
# sum(tf_num)
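The cells below rely on three objects whose construction is not shown in this walkthrough: token_id (evaluation-cell index to that cell's token IDs), and eval_pred / eval_label, which from the indexing below appear to be lookups keyed by token ID that return the model's predicted label and the true label for that gene. A hedged sketch for token_id follows, assuming the out-of-sample evaluation set from the last split (evalset_oos_labeled) has been kept available.

# Hedged sketch (assumptions, not the author's code):
# token_id: evaluation-cell index -> list of token IDs in that cell
token_id = {i: list(ex["input_ids"]) for i, ex in enumerate(evalset_oos_labeled)}
# eval_pred / eval_label are assumed to map token ID -> predicted / true label for
# that gene (eval_label would correspond to label_dict_eval from the last split);
# their construction is not reproduced here.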
# frequencies of tokens
from collections import Counter   # needed for the token/label counts in this section
token_freq = Counter()
for tks in token_id.values():
    token_freq.update(tks)
# append all tokens into one list
token_id_list = [tk for tks in token_id.values() for tk in tks]
# successful prediction
# get prediction of gene (token = 9061)
target_pred1 = [eval_pred[i] for i in token_id_list if i == 9061]
print("Predicted label of gene 9061: ")
print(Counter(target_pred1))
target_label1 = [eval_label[i] for i in token_id_list if i == 9061]
print("True label of gene 9061: ")
print(Counter(target_label1))
# failed prediction
# get prediction of gene (token = 16425)
target_pred2 = [eval_pred[i] for i in token_id_list if i == 16425]
print("Predicted label of gene 16425: ")
print(Counter(target_pred2))
target_label2 = [eval_label[i] for i in token_id_list if i == 16425]
print("True label of gene 16425: ")
print(Counter(target_label2))
Predicted label of gene 9061:
Counter({1: 2755})
True label of gene 9061:
Counter({1: 2755})
Predicted label of gene 16425:
Counter({0: 6})
True label of gene 16425:
Counter({1: 6})