
geneformer-gene-class

Gene Classification

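The previous post showed how to do cell classification with Geneformer. This one records the process of gene classification with Geneformer, for example predicting whether a transcription factor (TF) is dosage-sensitive.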

First, download the scRNA-seq dataset used to fine-tune the gene classifier:

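mkdir -p Genecorpus-30M/example_input_files/cell_classification/disease_classification/human_dcm_hcm_nf.dataset
cd Genecorpus-30M/example_input_files/cell_classification/disease_classification/human_dcm_hcm_nf.dataset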
wget https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/resolve/main/example_input_files/cell_classification/disease_classification/human_dcm_hcm_nf.dataset/dataset.arrow

wget https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/resolve/main/example_input_files/cell_classification/disease_classification/human_dcm_hcm_nf.dataset/dataset_info.json

wget https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/resolve/main/example_input_files/cell_classification/disease_classification/human_dcm_hcm_nf.dataset/state.json

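The Geneformer authors provide a cardiomyopathy scRNA-seq dataset with cells from non-failing (NF), hypertrophic (HCM), and dilated (DCM) hearts. They also provide a list of dosage-sensitive and dosage-insensitive TFs. We fine-tune on these data and then predict whether a gene is a dosage-sensitive TF.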

Fine-tuning data: scRNA-seq data and gene labels.

Downstream task: classify TFs as dosage-sensitive or dosage-insensitive.

Import modules

import os
GPU_NUMBER = [0]
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join([str(s) for s in GPU_NUMBER])
os.environ["NCCL_DEBUG"] = "INFO"
# imports
import datetime
import subprocess
import math
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from datasets import load_from_disk
from sklearn import preprocessing
from sklearn.metrics import accuracy_score, auc, confusion_matrix, ConfusionMatrixDisplay, roc_curve
from sklearn.model_selection import StratifiedKFold
import torch
from transformers import BertForTokenClassification
from transformers import Trainer
from transformers.training_args import TrainingArguments
from tqdm.notebook import tqdm

from geneformer import DataCollatorForGeneClassification
from geneformer.pretrainer import token_dictionary

Load Gene Attribute Information

Read the gene information table provided by the authors. It lists each gene's Ensembl ID, gene name, and gene type. Store these attributes in three dictionaries: gene_id_type_dict, gene_name_id_dict, and gene_id_name_dict.

# table of corresponding Ensembl IDs, gene names, and gene types (e.g. coding, miRNA, etc.)
gene_info = pd.read_csv("D:/jupyterNote/Geneformer/Genecorpus-30M/example_input_files/gene_info_table.csv", index_col=0)

# create dictionaries for corresponding attributes
gene_id_type_dict = dict(zip(gene_info["ensembl_id"],gene_info["gene_type"]))
gene_name_id_dict = dict(zip(gene_info["gene_name"],gene_info["ensembl_id"]))
gene_id_name_dict = {v: k for k,v in gene_name_id_dict.items()}

# first 5 key:value pairs
{k: gene_id_name_dict[k] for k in list(gene_id_name_dict)[:5]}
{'ENSG00000000003': 'TSPAN6',
 'ENSG00000000005': 'TNMD',
 'ENSG00000000419': 'DPM1',
 'ENSG00000000457': 'SCYL3',
 'ENSG00000000460': 'C1orf112'}
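One caveat about these dictionaries: gene_id_name_dict is built by inverting gene_name_id_dict. If two Ensembl IDs shared a gene name, only the last ID would survive, and the others would be missing from gene_id_name_dict. I have not checked whether gene_info_table.csv contains such cases.

The stdlib sketch below uses hypothetical rows (the two ENSG_HYPOTHETICAL IDs are made up) to show the effect. It also shows a direct ID-to-name mapping that avoids the problem.

```python
import csv
import io

# Hypothetical stand-in for gene_info_table.csv; the last two IDs are made up
# and deliberately share one gene name.
toy_csv = """ensembl_id,gene_name,gene_type
ENSG00000000003,TSPAN6,protein_coding
ENSG00000000005,TNMD,protein_coding
ENSG_HYPOTHETICAL_A,DUPNAME,lncRNA
ENSG_HYPOTHETICAL_B,DUPNAME,lncRNA
"""
rows = list(csv.DictReader(io.StringIO(toy_csv)))

# same construction as in the notebook, using plain dicts
id_to_type = {r["ensembl_id"]: r["gene_type"] for r in rows}
name_to_id = {r["gene_name"]: r["ensembl_id"] for r in rows}  # a repeated name keeps only its last ID
id_to_name = {v: k for k, v in name_to_id.items()}            # inverse of the name -> ID map

# building ID -> name directly keeps every Ensembl ID
id_to_name_direct = {r["ensembl_id"]: r["gene_name"] for r in rows}

print(len(id_to_name), len(id_to_name_direct))
```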

Load Training Data and Class Labels

Next, load the fine-tuning data:

- the cardiomyopathy scRNA-seq dataset ("human_dcm_hcm_nf.dataset");
- the list of dosage-sensitive and dosage-insensitive TFs ("dosage_sens_tf_labels.csv").

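To process dosage_sens_tf_labels, we define prep_inputs. It:

- keeps only the genes present in the model vocabulary;
- converts their gene IDs to token IDs;
- generates labels of matching length for genegroup1 and genegroup2 (group1 is labeled 0, group2 is labeled 1);
- sets the number of cross-validation splits.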

token_dictionary maps each Ensembl ID to its token ID.

# function for preparing targets and labels
def prep_inputs(genegroup1, genegroup2, id_type):
    if id_type == "gene_name":
        targets1 = [gene_name_id_dict[gene] for gene in genegroup1 if gene_name_id_dict.get(gene) in token_dictionary]
        targets2 = [gene_name_id_dict[gene] for gene in genegroup2 if gene_name_id_dict.get(gene) in token_dictionary]
    elif id_type == "ensembl_id":
        targets1 = [gene for gene in genegroup1 if gene in token_dictionary]
        targets2 = [gene for gene in genegroup2 if gene in token_dictionary]
    else:
        raise ValueError('id_type must be "gene_name" or "ensembl_id"')

    # convert Ensembl IDs to token IDs
    targets1_id = [token_dictionary[gene] for gene in targets1]
    targets2_id = [token_dictionary[gene] for gene in targets2]

    # group1 -> label 0, group2 -> label 1
    targets = np.array(targets1_id + targets2_id)
    labels = np.array([0]*len(targets1_id) + [1]*len(targets2_id))
    # at most 5 folds, and fewer folds than genes in the smaller group
    nsplits = min(5, min(len(targets1_id), len(targets2_id))-1)
    assert nsplits > 2
    print(f"# targets1: {len(targets1_id)}\n# targets2: {len(targets2_id)}\n# splits: {nsplits}")
    return targets, labels, nsplits
{k: token_dictionary[k] for k in list(token_dictionary)[:5]}
{'<pad>': 0,
 '<mask>': 1,
 'ENSG00000000003': 2,
 'ENSG00000000005': 3,
 'ENSG00000000419': 4}
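The core of prep_inputs fits in a few lines of plain Python. A few notes on this sketch:

- The vocabulary is a hypothetical toy copy of the first token_dictionary entries.
- prep_inputs_sketch is an illustrative name, not part of geneformer.
- It returns plain lists instead of numpy arrays and skips the assert.

The fold formula caps nsplits at 5 and keeps it below the size of the smaller group. With 122 and 368 TFs it gives 5. The toy lists give 1, which the notebook's `assert nsplits > 2` would reject.

```python
# Hypothetical toy vocabulary mirroring the first entries printed above;
# the real token_dictionary comes from geneformer.
toy_token_dictionary = {"<pad>": 0, "<mask>": 1,
                        "ENSG00000000003": 2, "ENSG00000000005": 3,
                        "ENSG00000000419": 4, "ENSG00000000457": 5,
                        "ENSG00000000460": 6}

def prep_inputs_sketch(group1_ids, group2_ids, token_dictionary, max_splits=5):
    # keep only Ensembl IDs that have a token, then map them to token IDs
    ids1 = [token_dictionary[g] for g in group1_ids if g in token_dictionary]
    ids2 = [token_dictionary[g] for g in group2_ids if g in token_dictionary]
    targets = ids1 + ids2
    labels = [0] * len(ids1) + [1] * len(ids2)  # group1 -> 0, group2 -> 1
    # at most max_splits folds, and one fewer than the smaller group size
    nsplits = min(max_splits, min(len(ids1), len(ids2)) - 1)
    return targets, labels, nsplits

toy_targets, toy_labels, toy_nsplits = prep_inputs_sketch(
    ["ENSG00000000003", "ENSG00000000005", "ENSG00000000419", "ENSG_NOT_IN_VOCAB"],
    ["ENSG00000000457", "ENSG00000000460"],
    toy_token_dictionary)
print(toy_targets, toy_labels, toy_nsplits)  # the unknown ID is dropped
print(min(5, min(122, 368) - 1))             # fold count for the real 122 vs 368 TFs
```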

Read the authors' dosage-sensitive TF list. It contains 122 dosage-sensitive TFs (label 0) and 368 dosage-insensitive TFs (label 1). prep_inputs converts the TF gene IDs to tokens and sets nsplits = 5 for the 5-fold cross-validation that follows.

from collections import Counter

# preparing targets and labels for dosage sensitive vs insensitive TFs
dosage_tfs = pd.read_csv("D:/jupyterNote/Geneformer/Genecorpus-30M/example_input_files/gene_classification/dosage_sensitive_tfs/dosage_sens_tf_labels.csv", header=0)
sensitive = dosage_tfs["dosage_sensitive"].dropna()
insensitive = dosage_tfs["dosage_insensitive"].dropna()
targets, labels, nsplits = prep_inputs(sensitive, insensitive, "ensembl_id")
print(targets[0:5])
print(Counter(labels))
# targets1: 122
# targets2: 368
# splits: 5
[208 223 275 295 487]
Counter({1: 368, 0: 122})

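We fine-tune on the authors' cardiomyopathy scRNA-seq data, which contains 579,159 cells from 21 cell types.

The cells come from three disease groups: 1. NF (non-failing), 2. HCM (hypertrophic cardiomyopathy), and 3. DCM (dilated cardiomyopathy).

After shuffling the cells (seed=42), the first 50,000 cells are taken as the training set.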

# load training dataset
train_dataset=load_from_disk("D:/jupyterNote/Geneformer/Genecorpus-30M/example_input_files/cell_classification/disease_classification/human_dcm_hcm_nf.dataset")
shuffled_train_dataset = train_dataset.shuffle(seed=42)
subsampled_train_dataset = shuffled_train_dataset.select([i for i in range(50_000)])

Loading cached shuffled indices for dataset at D:\jupyterNote\Geneformer\Genecorpus-30M\example_input_files\cell_classification\disease_classification\human_dcm_hcm_nf.dataset\cache-54b519f110fa07f1.arrow
#import pandas as pd
print(train_dataset)
print("\nCelltype: ")
print(Counter(train_dataset['cell_type']))
print("\nSubgroups: ")
print(Counter(train_dataset['disease']))

print(subsampled_train_dataset)
Dataset({
    features: ['input_ids', 'length', 'cell_type', 'individual', 'age', 'sex', 'disease', 'lvef'],
    num_rows: 579159
})

Celltype: 
Counter({'Fibroblast1': 141725, 'Cardiomyocyte1': 136167, 'Endothelial1': 78375, 'Pericyte1': 67600, 'Macrophage': 54714, 'Endothelial2': 18394, 'VSMC': 18137, 'Lymphocyte': 16246, 'Endocardial': 6489, 'Cardiomyocyte2': 5445, 'Adipocyte': 5298, 'ActivatedFibroblast': 5210, 'LymphaticEndothelial': 5181, 'Endothelial3': 4538, 'MastCell': 4465, 'Neuronal': 4292, 'Cardiomyocyte3': 3350, 'Pericyte2': 1704, 'ProliferatingMacrophage': 1276, 'Fibroblast2': 284, 'Epicardial': 269})

Subgroups: 
Counter({'hcm': 230652, 'nf': 182317, 'dcm': 166190})

Define Functions for Training and Cross-Validating Classifier

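Geneformer converts each cell's gene expression into a rank value encoding. Cells express different numbers of genes, so these encodings have different lengths, but the model needs input tensors of equal length within a batch.

preprocess_classifier_batch therefore brings every example in a batch to a common length:

- input_ids are padded with the <pad> token;
- labels are padded with -100, which is ignored when computing the loss and the metrics;
- an attention_mask is built that is 0 at the padded positions.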

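classifier_predict splits the input dataset into batches of forward_batch_size cells. It runs the fine-tuned model on each batch to predict whether each labeled gene is dosage-sensitive or insensitive. It then compares predicted and true labels to compute evaluation metrics: the confusion matrix and the ROC curve (FPR and TPR).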

Note: if you run on a GPU with little memory, lower forward_batch_size accordingly. I set forward_batch_size=20 here.

def preprocess_classifier_batch(cell_batch, max_len):
    if max_len is None:
        max_len = max([len(i) for i in cell_batch["input_ids"]])
    def pad_label_example(example):
        # pad labels with -100 (ignored) and input_ids with the <pad> token
        example["labels"] = np.pad(example["labels"],
                                   (0, max_len-len(example["input_ids"])),
                                   mode='constant', constant_values=-100)
        example["input_ids"] = np.pad(example["input_ids"],
                                      (0, max_len-len(example["input_ids"])),
                                      mode='constant', constant_values=token_dictionary.get("<pad>"))
        # attend to real tokens only
        example["attention_mask"] = (example["input_ids"] != token_dictionary.get("<pad>")).astype(int)
        return example
    padded_batch = cell_batch.map(pad_label_example)
    return padded_batch

# forward batch size is batch size for model inference (e.g. 200)
def classifier_predict(model, evalset, forward_batch_size, mean_fpr):
    predict_logits = []
    predict_labels = []
    model.eval()

    # ensure there is at least 2 examples in each batch to avoid incorrect tensor dims
    evalset_len = len(evalset)
    max_divisible = find_largest_div(evalset_len, forward_batch_size)
    if len(evalset) - max_divisible == 1:
        evalset_len = max_divisible

    max_evalset_len = max(evalset.select(list(range(evalset_len)))["length"])

    for i in range(0, evalset_len, forward_batch_size):
        max_range = min(i+forward_batch_size, evalset_len)
        batch_evalset = evalset.select(list(range(i, max_range)))
        padded_batch = preprocess_classifier_batch(batch_evalset, max_evalset_len)
        padded_batch.set_format(type="torch")

        input_data_batch = padded_batch["input_ids"]
        attn_msk_batch = padded_batch["attention_mask"]
        label_batch = padded_batch["labels"]
        with torch.no_grad():
            outputs = model(
                input_ids = input_data_batch.to("cuda"),
                attention_mask = attn_msk_batch.to("cuda"),
                labels = label_batch.to("cuda"),
            )
            predict_logits += [torch.squeeze(outputs.logits.to("cpu"))]
            predict_labels += [torch.squeeze(label_batch.to("cpu"))]

    # flatten to one row per token and keep only labeled (TF) positions
    logits_by_cell = torch.cat(predict_logits)
    all_logits = logits_by_cell.reshape(-1, logits_by_cell.shape[2])
    labels_by_cell = torch.cat(predict_labels)
    all_labels = torch.flatten(labels_by_cell)
    logit_label_paired = [item for item in list(zip(all_logits.tolist(), all_labels.tolist())) if item[1]!=-100]
    y_pred = [vote(item[0]) for item in logit_label_paired]
    y_true = [item[1] for item in logit_label_paired]
    logits_list = [item[0] for item in logit_label_paired]
    # probability of class 1
    y_score = [py_softmax(item)[1] for item in logits_list]
    conf_mat = confusion_matrix(y_true, y_pred)
    fpr, tpr, _ = roc_curve(y_true, y_score)
    # plot roc_curve for this split
    plt.plot(fpr, tpr)
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('ROC')
    plt.show()
    # interpolate to graph
    interp_tpr = np.interp(mean_fpr, fpr, tpr)
    interp_tpr[0] = 0.0
    return fpr, tpr, interp_tpr, conf_mat

def vote(logit_pair):
    a, b = logit_pair
    if a > b:
        return 0
    elif b > a:
        return 1
    elif a == b:
        return "tie"

def py_softmax(vector):
    e = np.exp(vector)
    return e / e.sum()

# get cross-validated mean and sd metrics
def get_cross_valid_metrics(all_tpr, all_roc_auc, all_tpr_wt):
    wts = [count/sum(all_tpr_wt) for count in all_tpr_wt]
    print(wts)
    all_weighted_tpr = [a*b for a,b in zip(all_tpr, wts)]
    mean_tpr = np.sum(all_weighted_tpr, axis=0)
    mean_tpr[-1] = 1.0
    all_weighted_roc_auc = [a*b for a,b in zip(all_roc_auc, wts)]
    roc_auc = np.sum(all_weighted_roc_auc)
    roc_auc_sd = math.sqrt(np.average((np.array(all_roc_auc)-roc_auc)**2, weights=wts))
    return mean_tpr, roc_auc, roc_auc_sd

# Function to find the largest number smaller
# than or equal to N that is divisible by k
def find_largest_div(N, K):
    rem = N % K
    if rem == 0:
        return N
    else:
        return N - rem
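The padding rule in preprocess_classifier_batch can be checked without datasets or torch. This plain-Python sketch uses toy token IDs rather than real cells. It does three things:

- right-pads input_ids with the <pad> ID;
- pads labels with -100;
- sets the attention mask to 0 at padded positions.

-100 is also the default ignore_index of PyTorch's CrossEntropyLoss. So the padded label positions do not contribute to the loss, and classifier_predict drops them before computing metrics.

```python
PAD_ID = 0           # token_dictionary["<pad>"] is 0 in the dictionary printed earlier
IGNORE_INDEX = -100  # label value skipped by the loss and by the evaluation code

def pad_example(input_ids, labels, max_len, pad_id=PAD_ID):
    """Right-pad one cell to max_len and build its attention mask."""
    n_pad = max_len - len(input_ids)
    padded_ids = list(input_ids) + [pad_id] * n_pad
    padded_labels = list(labels) + [IGNORE_INDEX] * n_pad
    attention_mask = [int(tok != pad_id) for tok in padded_ids]
    return padded_ids, padded_labels, attention_mask

# two hypothetical cells of different lengths; only labeled TF tokens carry 0/1
cells = [([208, 9061, 40], [-100, 1, -100]),
         ([275, 12], [0, -100])]
max_len = max(len(ids) for ids, _ in cells)
batch = [pad_example(ids, labs, max_len) for ids, labs in cells]
for padded_ids, padded_labels, mask in batch:
    print(padded_ids, padded_labels, mask)
```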

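The function cross_validate wraps data splitting, training, and prediction. In each fold:

- the TF targets are split 80/20 into training genes and evaluation genes;
- cells containing training genes form the training set;
- cells containing evaluation genes are split in half, into an evaluation set used during training and a held-out out-of-sample (OOS) evaluation set.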

When loading the pretrained model, replace the path with your local Geneformer directory or the Hugging Face repository name ("ctheodoris/Geneformer").

# load model
model = BertForTokenClassification.from_pretrained(
    "D:/jupyterNote/Geneformer",  # change to local path to the model
    num_labels=2,
    output_attentions = False,
    output_hidden_states = False
)

Next, this part of the code fine-tunes the model with the defined training arguments:

# add output directory to training args and initiate
training_args["output_dir"] = ksplit_output_dir
training_args_init = TrainingArguments(**training_args)

# create the trainer
trainer = Trainer(
    model=model,
    args=training_args_init,
    data_collator=DataCollatorForGeneClassification(),
    train_dataset=trainset_labeled,
    eval_dataset=evalset_train_labeled
)

# train the gene classifier
trainer.train()

This part uses the fine-tuned model to predict on, and evaluate, the out-of-sample dataset (evalset_oos_labeled).

Adjust forward_batch_size here to fit your hardware.

# evaluate model
fpr, tpr, interp_tpr, conf_mat = classifier_predict(trainer.model, evalset_oos_labeled, 20, mean_fpr) # forward_batch_size: 20

# append to tpr and roc lists
confusion = confusion + conf_mat
all_tpr.append(interp_tpr)
all_roc_auc.append(auc(fpr, tpr))
# append number of eval examples by which to weight tpr in averaged graphs
all_tpr_wt.append(len(tpr))
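classifier_predict uses three steps:

- sklearn's roc_curve computes the ROC points;
- sklearn's auc integrates them;
- np.interp resamples TPR onto mean_fpr, so splits can be averaged point by point.

The pure-Python sketch below mimics those steps on toy scores. It is not sklearn's code: sklearn's roc_curve also drops collinear points by default (drop_intermediate=True), so its arrays can be shorter.

The sketch also shows what the weight appended to all_tpr_wt measures. len(tpr) is the number of ROC points, not the number of evaluation examples.

```python
def roc_points(y_true, y_score):
    """(fpr, tpr) lists with one point per distinct score threshold, highest threshold first."""
    pos = sum(1 for y in y_true if y == 1)
    neg = len(y_true) - pos
    pairs = sorted(zip(y_score, y_true), reverse=True)
    fpr, tpr = [0.0], [0.0]
    tp = fp = 0
    for i, (score, y) in enumerate(pairs):
        if y == 1:
            tp += 1
        else:
            fp += 1
        # emit a point only after the last example that shares this score
        if i == len(pairs) - 1 or pairs[i + 1][0] != score:
            fpr.append(fp / neg)
            tpr.append(tp / pos)
    return fpr, tpr

def trapezoid_auc(x, y):
    """Area under a piecewise-linear curve with non-decreasing x."""
    return sum((x[i + 1] - x[i]) * (y[i + 1] + y[i]) / 2 for i in range(len(x) - 1))

def interp(x_new, xp, fp):
    """Linear interpolation in the spirit of np.interp, for non-decreasing xp."""
    out = []
    for x in x_new:
        if x <= xp[0]:
            out.append(fp[0])
        elif x >= xp[-1]:
            out.append(fp[-1])
        else:
            k = next(j for j in range(1, len(xp)) if xp[j] >= x)  # xp[k-1] < x <= xp[k]
            t = (x - xp[k - 1]) / (xp[k] - xp[k - 1])
            out.append(fp[k - 1] + t * (fp[k] - fp[k - 1]))
    return out

# toy scores: probability of class 1 for four hypothetical genes
toy_y_true = [0, 0, 1, 1]
toy_y_score = [0.1, 0.4, 0.35, 0.8]
toy_fpr, toy_tpr = roc_points(toy_y_true, toy_y_score)
print(toy_fpr, toy_tpr, trapezoid_auc(toy_fpr, toy_tpr))
print(len(toy_tpr), interp([0.25, 0.75], toy_fpr, toy_tpr))
```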

# cross-validate gene classifier
def cross_validate(data, targets, labels, nsplits, subsample_size, training_args, freeze_layers, output_dir, num_proc):
    # check if output directory already written to
    # ensure not overwriting previously saved model
    model_dir_test = os.path.join(output_dir, "ksplit0/models/pytorch_model.bin")
    if os.path.isfile(model_dir_test):
        raise Exception("Model already saved to this directory.")

    # initiate eval metrics to return
    num_classes = len(set(labels))
    mean_fpr = np.linspace(0, 1, 100)
    all_tpr = []
    all_roc_auc = []
    all_tpr_wt = []
    label_dicts = []
    confusion = np.zeros((num_classes,num_classes))

    # set up cross-validation splits
    skf = StratifiedKFold(n_splits=nsplits, random_state=0, shuffle=True)
    # train and evaluate
    iteration_num = 0
    for train_index, eval_index in tqdm(skf.split(targets, labels)):
        if len(labels) > 500:
            print("early stopping activated due to large # of training examples")
            nsplits = 3
            if iteration_num == 3:
                break
        print(f"****** Crossval split: {iteration_num}/{nsplits-1} ******\n")
        # generate cross-validation splits
        targets_train, targets_eval = targets[train_index], targets[eval_index]
        labels_train, labels_eval = labels[train_index], labels[eval_index]
        label_dict_train = dict(zip(targets_train, labels_train))
        label_dict_eval = dict(zip(targets_eval, labels_eval))
        label_dicts += (iteration_num, targets_train, targets_eval, labels_train, labels_eval)

        # function to filter by whether contains train or eval labels
        def if_contains_train_label(example):
            a = label_dict_train.keys()
            b = example['input_ids']
            return not set(a).isdisjoint(b)

        def if_contains_eval_label(example):
            a = label_dict_eval.keys()
            b = example['input_ids']
            return not set(a).isdisjoint(b)

        # filter dataset for examples containing classes for this split
        print(f"Filtering training data")
        trainset = data.filter(if_contains_train_label, num_proc=num_proc)
        print(f"Filtered {round((1-len(trainset)/len(data))*100)}%; {len(trainset)} remain\n")
        print(f"Filtering evalation data")
        evalset = data.filter(if_contains_eval_label, num_proc=num_proc)
        print(f"Filtered {round((1-len(evalset)/len(data))*100)}%; {len(evalset)} remain\n")

        # minimize to smaller training sample
        training_size = min(subsample_size, len(trainset))
        trainset_min = trainset.select([i for i in range(training_size)])
        eval_size = min(training_size, len(evalset))
        half_training_size = round(eval_size/2)
        evalset_train_min = evalset.select([i for i in range(half_training_size)])
        evalset_oos_min = evalset.select([i for i in range(half_training_size, eval_size)])

        # label conversion functions
        def generate_train_labels(example):
            example["labels"] = [label_dict_train.get(token_id, -100) for token_id in example["input_ids"]]
            return example

        def generate_eval_labels(example):
            example["labels"] = [label_dict_eval.get(token_id, -100) for token_id in example["input_ids"]]
            return example

        # label datasets
        print(f"Labeling training data")
        trainset_labeled = trainset_min.map(generate_train_labels)
        print(f"Labeling evaluation data")
        evalset_train_labeled = evalset_train_min.map(generate_eval_labels)
        print(f"Labeling evaluation OOS data")
        evalset_oos_labeled = evalset_oos_min.map(generate_eval_labels)

        # create output directories
        ksplit_output_dir = os.path.join(output_dir, f"ksplit{iteration_num}")
        ksplit_model_dir = os.path.join(ksplit_output_dir, "models/")

        # ensure not overwriting previously saved model
        model_output_file = os.path.join(ksplit_model_dir, "pytorch_model.bin")
        if os.path.isfile(model_output_file):
            raise Exception("Model already saved to this directory.")

        # make training and model output directories (also creates ksplit_output_dir)
        os.makedirs(ksplit_model_dir, exist_ok=True)

        # load model
        model = BertForTokenClassification.from_pretrained(
            "D:/jupyterNote/Geneformer",  # change as the path to the model
            num_labels=2,
            output_attentions = False,
            output_hidden_states = False
        )
        # freeze the first freeze_layers transformer layers
        if freeze_layers is not None:
            modules_to_freeze = model.bert.encoder.layer[:freeze_layers]
            for module in modules_to_freeze:
                for param in module.parameters():
                    param.requires_grad = False

        model = model.to("cuda:0")

        # add output directory to training args and initiate
        training_args["output_dir"] = ksplit_output_dir
        training_args_init = TrainingArguments(**training_args)

        # create the trainer
        trainer = Trainer(
            model=model,
            args=training_args_init,
            data_collator=DataCollatorForGeneClassification(),
            train_dataset=trainset_labeled,
            eval_dataset=evalset_train_labeled
        )

        # train the gene classifier
        trainer.train()

        # save model
        trainer.save_model(ksplit_model_dir)

        # evaluate model
        fpr, tpr, interp_tpr, conf_mat = classifier_predict(trainer.model, evalset_oos_labeled, 20, mean_fpr)  # forward_batch_size: 20

        # append to tpr and roc lists
        confusion = confusion + conf_mat
        all_tpr.append(interp_tpr)
        all_roc_auc.append(auc(fpr, tpr))
        # weight for the averaged graphs: len(tpr) is the number of ROC points of this split
        all_tpr_wt.append(len(tpr))

        iteration_num = iteration_num + 1

    # get overall metrics for cross-validation
    mean_tpr, roc_auc, roc_auc_sd = get_cross_valid_metrics(all_tpr, all_roc_auc, all_tpr_wt)
    return all_roc_auc, roc_auc, roc_auc_sd, mean_fpr, mean_tpr, confusion, label_dicts
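Without datasets and the model, the per-split data handling in cross_validate comes down to three list operations:

- keep cells that express at least one gene of the split;
- label each token with its class, or -100;
- split the evaluation cells in dataset order into two halves.

The sketch below uses plain lists, hypothetical token IDs, and illustrative helper names.

Both halves contain the same evaluation genes but different cells. A training cell may also express evaluation genes; those tokens get -100 in the training labels, because only label_dict_train is used there.

```python
def contains_any(input_ids, label_dict):
    # same test as if_contains_eval_label: does the cell express any gene of this split?
    return not set(label_dict).isdisjoint(input_ids)

def label_cell(input_ids, label_dict):
    # genes of this split get their class label, every other token gets -100
    return [label_dict.get(tok, -100) for tok in input_ids]

def split_eval_cells(evalset, training_size):
    # first half -> evaluation during training, second half -> out-of-sample set
    eval_size = min(training_size, len(evalset))
    half = round(eval_size / 2)
    return evalset[:half], evalset[half:eval_size]

toy_label_dict = {9061: 1, 16425: 1, 208: 0}                          # hypothetical token -> label
toy_cells = [[5, 9061, 7], [3, 4], [208, 16425], [11, 208], [9061]]   # hypothetical cells

toy_evalset = [c for c in toy_cells if contains_any(c, toy_label_dict)]
eval_train, eval_oos = split_eval_cells(toy_evalset, training_size=10_000)
print(len(toy_evalset), len(eval_train), len(eval_oos))
print([label_cell(c, toy_label_dict) for c in eval_oos])
```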

Define Functions for Plotting Results

Define plot_ROC to plot ROC curves and plot_confusion_matrix to plot the confusion matrix.

# plot ROC curve
def plot_ROC(bundled_data, title):
    plt.figure()
    lw = 2
    for roc_auc, roc_auc_sd, mean_fpr, mean_tpr, sample, color in bundled_data:
        # raw string so that "\pm" reaches matplotlib's mathtext unchanged
        plt.plot(mean_fpr, mean_tpr, color=color,
                 lw=lw, label=r"{0} (AUC {1:0.2f} $\pm$ {2:0.2f})".format(sample, roc_auc, roc_auc_sd))
    plt.plot([0, 1], [0, 1], color='black', lw=lw, linestyle='--')
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title(title)
    plt.legend(loc="lower right")
    plt.show()

# plot confusion matrix
def plot_confusion_matrix(classes_list, conf_mat, title):
    display_labels = []
    i = 0
    for label in classes_list:
        display_labels += ["{0}\nn={1:.0f}".format(label, sum(conf_mat[:,i]))]
        i = i + 1
    display = ConfusionMatrixDisplay(confusion_matrix=preprocessing.normalize(conf_mat, norm="l1"),
                                     display_labels=display_labels)
    display.plot(cmap="Blues",values_format=".2g")
    plt.title(title)
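In sklearn's confusion_matrix, rows are true classes and columns are predicted classes. preprocessing.normalize(..., norm="l1") works row-wise by default (axis=1), so the plotted diagonal is per-class recall.

The n printed under each class name, however, is sum(conf_mat[:, i]). That is a column total: how many occurrences were predicted as that class, not how many truly belong to it.

A stdlib sketch with a hypothetical matrix shows the difference:

```python
def row_normalize(conf_mat):
    """L1-normalise each row, as sklearn.preprocessing.normalize(conf_mat, norm="l1") does (axis=1)."""
    out = []
    for row in conf_mat:
        total = sum(abs(v) for v in row)
        out.append([v / total if total else 0.0 for v in row])
    return out

def column_totals(conf_mat):
    """Column sums, i.e. what plot_confusion_matrix prints as n under each class name."""
    return [sum(row[j] for row in conf_mat) for j in range(len(conf_mat[0]))]

def row_totals(conf_mat):
    """Row sums, i.e. the number of true examples of each class."""
    return [sum(row) for row in conf_mat]

# hypothetical pooled confusion matrix: rows = true class, columns = predicted class
toy_confusion = [[80, 20],
                 [30, 270]]
print(row_normalize(toy_confusion))
print(column_totals(toy_confusion), row_totals(toy_confusion))
```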

Fine-Tune With Gene Classification Learning Objective and Quantify Predictive Performance

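Define the fine-tuning parameters, adjusting num_gpus, num_proc, and geneformer_batch_size to your hardware. The other hyperparameters keep their preset values; in principle they could be tuned further.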

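On choosing freeze_layers, the authors note that the closer the downstream task is to the pretraining objective, the more layers can be frozen. In other words, more of the pretrained weights are kept ("remembered"):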

Generally, in our experience, applications that are more relevant to the pretraining objective benefit from more layers being frozen to prevent overfitting to the limited task-specific data, whereas applications that are more distant from the pretraining objective benefit from fine-tuning of more layers to optimize performance on the new task.

# set model parameters
# max input size
max_input_size = 2 ** 11 # 2048

# set training hyperparameters
# max learning rate
max_lr = 5e-5
# how many pretrained layers to freeze
freeze_layers = 4
# number gpus
num_gpus = 1
# number cpu cores
num_proc = 6
# batch size for training and eval
geneformer_batch_size = 2
# learning schedule
lr_schedule_fn = "linear"
# warmup steps
warmup_steps = 500
# number of epochs
epochs = 1
# optimizer
optimizer = "adamw"
# set training arguments
subsample_size = 10_000
training_args = {
    "learning_rate": max_lr,
    "do_train": True,
    "evaluation_strategy": "no",  # newer transformers releases name this argument "eval_strategy"
    "save_strategy": "epoch",
    "logging_steps": 100,
    "group_by_length": True,
    "length_column_name": "length",
    "disable_tqdm": False,
    "lr_scheduler_type": lr_schedule_fn,
    "warmup_steps": warmup_steps,
    "weight_decay": 0.001,
    "per_device_train_batch_size": geneformer_batch_size,
    "per_device_eval_batch_size": geneformer_batch_size,
    "num_train_epochs": epochs,
}
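With lr_scheduler_type="linear" and warmup_steps=500, the Trainer uses a linear warmup followed by linear decay to zero. This is what transformers documents for get_linear_schedule_with_warmup. The sketch below computes that schedule by hand.

The total step count is an assumption. It supposes one GPU, no gradient accumulation, and a full 10,000-cell training subset per split, which gives ceil(10,000 / 2) = 5,000 optimizer steps for one epoch.

```python
import math

def linear_warmup_lr(step, max_lr, warmup_steps, total_steps):
    """Learning rate at an optimizer step: linear warmup from 0, then linear decay to 0."""
    if step < warmup_steps:
        return max_lr * step / max(1, warmup_steps)
    return max_lr * max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))

# assumptions: 1 GPU, no gradient accumulation, 10,000 training cells, batch size 2, 1 epoch
n_cells, per_device_batch, n_epochs = 10_000, 2, 1
total_steps = math.ceil(n_cells / per_device_batch) * n_epochs   # 5000
for step in (0, 250, 500, 2750, 5000):
    print(step, linear_warmup_lr(step, 5e-5, 500, total_steps))
```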
# define output directory path
current_date = datetime.datetime.now()
datestamp = f"{str(current_date.year)[-2:]}{current_date.month:02d}{current_date.day:02d}"
training_output_dir = f"D:\\jupyterNote\\Geneformer\\examples\\gene_class_test\\{datestamp}_geneformer_GeneClassifier_dosageTF_L{max_input_size}_B{geneformer_batch_size}_LR{max_lr}_LS{lr_schedule_fn}_WU{warmup_steps}_E{epochs}_O{optimizer}_n{subsample_size}_F{freeze_layers}\\"

# ensure not overwriting previously saved model
ksplit_model_test = os.path.join(training_output_dir, "ksplit0/models/pytorch_model.bin")
if os.path.isfile(ksplit_model_test):
    raise Exception("Model already saved to this directory.")

# make output directory
os.makedirs(training_output_dir, exist_ok=True)
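training_output_dir encodes the hyperparameters in the folder name. Re-creating the name for a fixed date shows how the ksplit0 model path loaded later (gene_class_test/230724_...) was produced.

The date 2023-07-24 is inferred from that path's 230724 prefix. training_dir_name is a hypothetical helper that repeats the notebook's f-string.

```python
import datetime

def training_dir_name(date, max_input_size, batch_size, max_lr, lr_schedule_fn,
                      warmup_steps, epochs, optimizer, subsample_size, freeze_layers):
    # same fields and order as training_output_dir in the notebook
    datestamp = f"{str(date.year)[-2:]}{date.month:02d}{date.day:02d}"
    return (f"{datestamp}_geneformer_GeneClassifier_dosageTF_L{max_input_size}"
            f"_B{batch_size}_LR{max_lr}_LS{lr_schedule_fn}_WU{warmup_steps}"
            f"_E{epochs}_O{optimizer}_n{subsample_size}_F{freeze_layers}")

dir_name = training_dir_name(datetime.date(2023, 7, 24), 2 ** 11, 2, 5e-5, "linear",
                             500, 1, "adamw", 10_000, 4)
print(dir_name)
```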
# clear GPU memory after pytorch training 
import torch
torch.cuda.empty_cache()
# did not work: "!set" runs in a child shell, so the variable never reaches the running Python kernel
#!set 'PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:512' # do not split cached blocks larger than 512 MB

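We fine-tune on subsampled_train_dataset, which contains 50,000 cells. Each split trains on up to 10,000 cells, for 5 splits in total (nsplits=5).

In each split, the targets and labels are divided into 80% training genes (n = 392) and 20% evaluation genes (n = 98). The split is stratified, so each fold keeps roughly the same ratio of sensitive to insensitive TFs. The evaluation genes differ between folds, while the training genes of different folds overlap.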
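StratifiedKFold keeps the class ratio (122:368, roughly 1:3) in every fold. The sketch below is a simple stand-in for it, not scikit-learn's algorithm: it shuffles each class, then deals the examples to folds round-robin.

With this scheme, every fold has 392 training and 98 evaluation genes, and each evaluation fold holds 24 or 25 sensitive TFs. stratified_folds is a hypothetical helper name.

```python
import random
from collections import Counter

def stratified_folds(labels, n_splits, seed=0):
    """Assign each example a fold id by shuffling within each class and dealing round-robin."""
    rng = random.Random(seed)
    fold_of = [0] * len(labels)
    counter = 0  # keeps counting across classes so that fold sizes stay balanced
    for cls in sorted(set(labels)):
        idx = [i for i, y in enumerate(labels) if y == cls]
        rng.shuffle(idx)
        for i in idx:
            fold_of[i] = counter % n_splits
            counter += 1
    return fold_of

toy_labels = [0] * 122 + [1] * 368   # 122 sensitive (0), 368 insensitive (1)
folds = stratified_folds(toy_labels, n_splits=5)
for k in range(5):
    eval_idx = [i for i, f in enumerate(folds) if f == k]
    class_counts = Counter(toy_labels[i] for i in eval_idx)
    print(f"fold {k}: train={len(toy_labels) - len(eval_idx)} eval={len(eval_idx)} "
          f"eval class 0={class_counts[0]} class 1={class_counts[1]}")
```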

The split targets and labels are stored in label_dicts, where every five consecutive elements form one group: iteration_num, targets_train, targets_eval, labels_train, labels_eval.
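Because `label_dicts += (...)` extends one flat list, the record for split k sits at positions 5k to 5k+4. The sketch below uses a hypothetical two-split list, and split_record is an illustrative helper name.

For split 0 it reproduces the indexing dict(zip(label_dicts[2], label_dicts[4])) that is used when rebuilding the evaluation set.

```python
def split_record(label_dicts, k):
    """The five entries stored for split k: (iteration_num, targets_train, targets_eval, labels_train, labels_eval)."""
    return tuple(label_dicts[5 * k: 5 * k + 5])

# hypothetical flat list, shaped like the one built by `label_dicts += (...)` for two splits
toy_label_dicts = [0, [2, 3], [4], [0, 1], [1],
                   1, [2, 4], [3], [0, 1], [0]]

iteration_num, targets_train, targets_eval, labels_train, labels_eval = split_record(toy_label_dicts, 1)
eval_map = dict(zip(targets_eval, labels_eval))
print(iteration_num, eval_map)
# for split 0 this equals dict(zip(label_dicts[2], label_dicts[4]))
print(dict(zip(toy_label_dicts[2], toy_label_dicts[4])))
```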

cross_validate prints training information for each split, including the training loss, learning_rate, epoch, and the ROC curve.

# cross-validate gene classifier
all_roc_auc, roc_auc, roc_auc_sd, mean_fpr, mean_tpr, confusion, label_dicts \
= cross_validate(subsampled_train_dataset, targets, labels, nsplits, subsample_size, training_args, freeze_layers, training_output_dir, 1)
0it [00:00, ?it/s]


Loading cached processed dataset at D:\jupyterNote\Geneformer\Genecorpus-30M\example_input_files\cell_classification\disease_classification\human_dcm_hcm_nf.dataset\cache-509acb05b140c462.arrow


****** Crossval split: 0/4 ******

Filtering training data
Filtered 0%; 49994 remain

Filtering evalation data
[Figure: split 0 training info]

****** Crossval split: 1/4 ******

Filtering training data
Filtered 0%; 49992 remain

Filtering evalation data
Filtered 4%; 47913 remain

Labeling training data
[Figure: split 1 training info]

****** Crossval split: 2/4 ******

Filtering training data
Filtered 0%; 49993 remain

Filtering evalation data
Filtered 4%; 47886 remain

Labeling training data
[Figure: split 2 training info]

****** Crossval split: 3/4 ******

Filtering training data
Filtered 0%; 49991 remain

Filtering evalation data
Filtered 4%; 48025 remain

Labeling training data
[Figure: split 3 training info]

****** Crossval split: 4/4 ******

Filtering training data
Filtered 0%; 49977 remain

Filtering evalation data
Filtered 2%; 48951 remain

Labeling training data
[Figure: split 4 training info]

[0.25172310458495656, 0.18719408650484468, 0.1628708420737189, 0.2369393666966337, 0.16127260013984618]
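The list printed at the end is wts from get_cross_valid_metrics. Each entry is a split's all_tpr_wt value (its number of ROC points) divided by the total.

The cross-validated AUC shown in the ROC legend is the weighted mean of the per-split AUCs. The ± value is the weighted standard deviation.

A stdlib sketch with hypothetical AUCs and weights, for illustration only:

```python
import math

def weighted_auc_summary(all_roc_auc, all_tpr_wt):
    """Weighted mean and weighted SD of per-split AUCs, weights proportional to all_tpr_wt."""
    total = sum(all_tpr_wt)
    wts = [w / total for w in all_tpr_wt]
    mean_auc = sum(a * w for a, w in zip(all_roc_auc, wts))
    # same as sqrt(np.average((auc - mean)**2, weights=wts)) because wts sum to 1
    sd = math.sqrt(sum(w * (a - mean_auc) ** 2 for a, w in zip(all_roc_auc, wts)))
    return wts, mean_auc, sd

# hypothetical values for two splits, for illustration only
wts, mean_auc, sd = weighted_auc_summary([0.90, 0.80], [3, 1])
print(wts, round(mean_auc, 3), round(sd, 4))
```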
# bundle data for plotting
bundled_data = []
bundled_data += [(roc_auc, roc_auc_sd, mean_fpr, mean_tpr, "Geneformer", "red")]
# plot ROC curve
plot_ROC(bundled_data, 'Dosage Sensitive vs Insensitive TFs')

# plot confusion matrix
classes_list = ["Dosage Sensitive", "Dosage Insensitive"]
plot_confusion_matrix(classes_list, confusion, "Geneformer")


These are the 5-fold CV results. Next, we take the model fine-tuned on one split's 10,000 cells and use it to classify genes in that split's out-of-sample evaluation set.

We first load the fine-tuned model from the first split and move it to the GPU. Its classifier layer has out_features=2, i.e. it performs binary classification.

# reload fine-tuned model
ft_model = BertForTokenClassification.from_pretrained("gene_class_test/230724_geneformer_GeneClassifier_dosageTF_L2048_B2_LR5e-05_LSlinear_WU500_E1_Oadamw_n10000_F4/ksplit0/models/")

ft_model.to('cuda:0')
print(ft_model)
BertForTokenClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(25426, 256, padding_idx=0)
      (position_embeddings): Embedding(2048, 256)
      (token_type_embeddings): Embedding(2, 256)
      (LayerNorm): LayerNorm((256,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.02, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-5): 6 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=256, out_features=256, bias=True)
              (key): Linear(in_features=256, out_features=256, bias=True)
              (value): Linear(in_features=256, out_features=256, bias=True)
              (dropout): Dropout(p=0.02, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=256, out_features=256, bias=True)
              (LayerNorm): LayerNorm((256,), eps=1e-12, elementwise_affine=True)
              (dropout): Dropout(p=0.02, inplace=False)
            )
          )
          (intermediate): BertIntermediate(
            (dense): Linear(in_features=256, out_features=512, bias=True)
            (intermediate_act_fn): ReLU()
          )
          (output): BertOutput(
            (dense): Linear(in_features=512, out_features=256, bias=True)
            (LayerNorm): LayerNorm((256,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.02, inplace=False)
          )
        )
      )
    )
  )
  (dropout): Dropout(p=0.02, inplace=False)
  (classifier): Linear(in_features=256, out_features=2, bias=True)
)

We take the evaluation targets and labels of the first split and rebuild the corresponding out-of-sample evaluation set (evalset_oos_labeled).

# out-of-sample evaluation set
# for set 0
label_dict_eval = dict(zip(label_dicts[2], label_dicts[4]))

def if_contains_eval_label(example, label_dict):
    a = label_dict.keys()
    b = example['input_ids']
    return not set(a).isdisjoint(b)

evalset0 = subsampled_train_dataset.filter(if_contains_eval_label, num_proc=2, fn_kwargs={"label_dict": label_dict_eval})
eval_size0 = min(10000, len(evalset0))
half_training_size = round(eval_size0/2)
evalset_oos_min = evalset0.select([i for i in range(half_training_size, eval_size0)])

def generate_eval_labels(example, label_dict):
    example["labels"] = [label_dict.get(token_id, -100) for token_id in example["input_ids"]]
    return example

evalset_oos_labeled = evalset_oos_min.map(generate_eval_labels, fn_kwargs={"label_dict": label_dict_eval})
evalset_oos_labeled
Dataset({
    features: ['input_ids', 'length', 'cell_type', 'individual', 'age', 'sex', 'disease', 'lvef', 'labels'],
    num_rows: 5000
})

Here we modify classifier_predict so that it returns:

- the predicted labels (y_pred);
- the true labels (y_true);
- the model outputs (logits_list);
- the cell indices (cell_id);
- the TF tokens in each cell (token_id_dict).

# return prediction results
def get_classifier_predict(model, evalset, forward_batch_size):
    predict_logits = []
    predict_labels = []
    model.eval()
    cell_id = []
    token_id_dict = {}

    # ensure there is at least 2 examples in each batch to avoid incorrect tensor dims
    evalset_len = len(evalset)
    max_divisible = find_largest_div(evalset_len, forward_batch_size)
    if len(evalset) - max_divisible == 1:
        evalset_len = max_divisible

    max_evalset_len = max(evalset.select(list(range(evalset_len)))["length"])

    for i in range(0, evalset_len, forward_batch_size):
        max_range = min(i+forward_batch_size, evalset_len)
        batch_evalset = evalset.select(list(range(i, max_range)))
        padded_batch = preprocess_classifier_batch(batch_evalset, max_evalset_len)
        padded_batch.set_format(type="torch")

        # cell id
        cell_id += list(range(i, max_range))
        # store the labeled (TF) tokens of each cell, in the same order as the predictions
        batch_labels = batch_evalset['labels']
        for j, tokens in enumerate(batch_evalset['input_ids']):
            cell_idx = i + j
            token_id_dict[cell_idx] = [tki for k, tki in enumerate(tokens) if batch_labels[j][k] > -1]

        input_data_batch = padded_batch["input_ids"]
        attn_msk_batch = padded_batch["attention_mask"]
        label_batch = padded_batch["labels"]
        with torch.no_grad():
            outputs = model(
                input_ids = input_data_batch.to("cuda"),
                attention_mask = attn_msk_batch.to("cuda"),
                labels = label_batch.to("cuda"),
            )
            predict_logits += [torch.squeeze(outputs.logits.to("cpu"))]
            predict_labels += [torch.squeeze(label_batch.to("cpu"))]

    logits_by_cell = torch.cat(predict_logits)
    all_logits = logits_by_cell.reshape(-1, logits_by_cell.shape[2])
    labels_by_cell = torch.cat(predict_labels)
    all_labels = torch.flatten(labels_by_cell)
    logit_label_paired = [item for item in list(zip(all_logits.tolist(), all_labels.tolist())) if item[1]!=-100]
    y_pred = [vote(item[0]) for item in logit_label_paired]
    y_true = [item[1] for item in logit_label_paired]
    logits_list = [item[0] for item in logit_label_paired]
    return y_pred, y_true, logits_list, cell_id, token_id_dict
eval_pred, eval_label, eval_logits, cell_id, token_id = get_classifier_predict(model=ft_model, evalset=evalset_oos_labeled, forward_batch_size=20)
[Figure: model prediction info]

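The model outputs a logit for each of the two classes, and each gene's label is the class with the larger value. Every TF occurrence in every cell is predicted (n = 27,939).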

print(Counter(eval_pred))
print(Counter(eval_label))

print(eval_logits[0:3])
print(eval_pred[0:3])
Counter({1: 16492, 0: 11447})
Counter({0: 14673, 1: 13266})
[[4.6540117263793945, -4.643155574798584], [5.055752277374268, -4.894111156463623], [0.701909065246582, -0.6132677793502808]]
[0, 0, 0]
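vote picks the class with the larger logit, and py_softmax turns a logit pair into probabilities. classifier_predict uses the class-1 probability as the ROC score.

Applying the same two steps to the first three logit pairs printed above reproduces the predictions [0, 0, 0]. Label 1 marks dosage-insensitive TFs, so the class-1 probability is the probability of being insensitive.

Subtracting the max inside softmax is a common stability tweak that is not in the original py_softmax. It does not change the result.

```python
import math

def softmax(logits):
    m = max(logits)  # subtracting the max avoids overflow and leaves the result unchanged
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]

def vote(logit_pair):
    a, b = logit_pair
    if a > b:
        return 0
    if b > a:
        return 1
    return "tie"

# the first three logit pairs printed above
first_logits = [[4.6540117263793945, -4.643155574798584],
                [5.055752277374268, -4.894111156463623],
                [0.701909065246582, -0.6132677793502808]]
preds = [vote(pair) for pair in first_logits]
p_insensitive = [softmax(pair)[1] for pair in first_logits]   # class 1 = dosage-insensitive
print(preds, [round(p, 3) for p in p_insensitive])
```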

Next, we count how often each TF token appears.

# # numbers of tfs (genes with 0/1 label) in out-of-sample evaluation set
# tf_num = [len([v for v in i if v >= 0]) for i in evalset_oos_labeled['labels']]
# sum(tf_num)

# frequencies of tokens
token_freq = Counter()

for tks in token_id.values():
token_freq.update(tks)

token_freq
Counter({1636: 2636,
9061: 2755,
6754: 475,
16718: 204,
275: 1445,
15866: 600,
5084: 805,
3361: 272,
2410: 108,
1757: 550,
18597: 82,
10422: 305,
14481: 197,
8218: 766,
16619: 138,
4071: 434,
6931: 1052,
14023: 468,
7445: 699,
4445: 157,
17672: 983,
3982: 547,
5944: 552,
5357: 359,
20144: 237,
6257: 137,
6456: 185,
16597: 437,
2774: 216,
15781: 553,
20018: 386,
23967: 427,
21561: 218,
12006: 116,
20989: 339,
15753: 199,
487: 387,
16016: 530,
998: 496,
8972: 382,
6492: 269,
14410: 180,
14286: 228,
12961: 228,
8725: 26,
2707: 82,
17085: 262,
15375: 72,
13606: 313,
10804: 317,
12959: 527,
12435: 202,
16713: 359,
12674: 184,
20959: 88,
16535: 348,
21035: 131,
11880: 34,
23100: 347,
21079: 114,
20581: 284,
15553: 249,
14677: 63,
954: 171,
17147: 47,
12995: 51,
20962: 74,
12165: 46,
17092: 66,
15717: 54,
9024: 118,
16555: 67,
7705: 78,
13722: 44,
18778: 100,
9831: 41,
5789: 40,
14124: 59,
13954: 31,
10534: 50,
16425: 6,
20787: 3,
9367: 44,
14578: 1,
15180: 1,
12243: 4,
11443: 1,
13066: 1})
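token_freq is keyed by token IDs, which are hard to read. Inverting token_dictionary and chaining it with gene_id_name_dict turns them back into gene names.

The sketch below uses only entries copied from the outputs printed earlier, prefixed toy_ so they do not overwrite the real dictionaries, plus hypothetical counts. In the notebook you would invert the real token_dictionary instead.

```python
from collections import Counter

# entries copied from the token_dictionary and gene_id_name_dict outputs printed earlier
toy_token_dictionary = {"<pad>": 0, "<mask>": 1,
                        "ENSG00000000003": 2, "ENSG00000000005": 3, "ENSG00000000419": 4}
toy_gene_id_name_dict = {"ENSG00000000003": "TSPAN6", "ENSG00000000005": "TNMD",
                         "ENSG00000000419": "DPM1"}

# invert the vocabulary once: token ID -> Ensembl ID
token_to_id = {tok: gene_id for gene_id, tok in toy_token_dictionary.items()}

def token_to_gene_name(token):
    gene_id = token_to_id.get(token)
    # fall back to the Ensembl ID (or None for unknown tokens) when no name is known
    return toy_gene_id_name_dict.get(gene_id, gene_id)

toy_freq = Counter({2: 10, 4: 3})   # hypothetical counts, same shape as token_freq above
print([(token_to_gene_name(t), n) for t, n in toy_freq.most_common()])
```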

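Here we look at two TF tokens, 9061 and 16425; in this post label 1 means dosage-insensitive.

Caution: the output shown below was produced by indexing eval_pred and eval_label with the token value itself (eval_pred[9061]). That returns the prediction at list position 9061, repeated once per occurrence, rather than the predictions for token 9061.

It therefore does not show whether gene 9061 or gene 16425 was classified correctly. To get per-gene counts, pair each token with its prediction by position, e.g. zip(token_id_list, eval_pred).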

# append all tokens into one list; its order matches eval_pred and eval_label
# (cells in order, labeled positions within each cell in order)
token_id_list = [tk for tks in token_id.values() for tk in tks]

# predictions and true labels for gene (token = 9061), paired by position
target_pred1 = [pred for tk, pred in zip(token_id_list, eval_pred) if tk == 9061]
print("Predicted label of gene 9061: ")
print(Counter(target_pred1))
target_label1 = [label for tk, label in zip(token_id_list, eval_label) if tk == 9061]
print("True label of gene 9061: ")
print(Counter(target_label1))

# predictions and true labels for gene (token = 16425), paired by position
target_pred2 = [pred for tk, pred in zip(token_id_list, eval_pred) if tk == 16425]
print("Predicted label of gene 16425: ")
print(Counter(target_pred2))
target_label2 = [label for tk, label in zip(token_id_list, eval_label) if tk == 16425]
print("True label of gene 16425: ")
print(Counter(target_label2))
Predicted label of gene 9061: 
Counter({1: 2755})
True label of gene 9061:
Counter({1: 2755})
Predicted label of gene 16425:
Counter({0: 6})
True label of gene 16425:
Counter({1: 6})
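eval_pred, eval_label and the flattened token_id_list are aligned position by position. get_classifier_predict walks the cells in order and, within each cell, the labeled positions in order, so per-gene results can be gathered with zip.

The sketch below uses hypothetical aligned lists and an illustrative helper name. It reports each gene's label, its majority prediction, and the share of its occurrences predicted correctly.

```python
from collections import Counter, defaultdict

def per_gene_summary(token_id_list, eval_pred, eval_label):
    """Group predictions by token, pairing the three lists by position."""
    preds = defaultdict(Counter)
    truth = {}
    for tok, p, y in zip(token_id_list, eval_pred, eval_label):
        preds[tok][p] += 1
        truth[tok] = y  # a gene has a single label within a split
    summary = {}
    for tok, counts in preds.items():
        total = sum(counts.values())
        summary[tok] = {"label": truth[tok],
                        "majority_pred": counts.most_common(1)[0][0],
                        "accuracy": counts[truth[tok]] / total,
                        "n": total}
    return summary

# hypothetical aligned lists, for illustration only
toy_tokens = [9061, 16425, 9061, 9061, 16425]
toy_pred   = [1,    0,     1,    0,    0]
toy_true   = [1,    1,     1,    1,    1]
summary = per_gene_summary(toy_tokens, toy_pred, toy_true)
for tok, info in summary.items():
    print(tok, info)
```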

Summary

To fine-tune Geneformer for gene classification, we need to:

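  1. Obtain a fine-tuning dataset together with gene labels, e.g. whether a TF is dosage-sensitive or a drug target;

    Regarding dataset size: among the authors' examples, the smallest used 884 cells, while the other downstream tasks all used more than 10k cells.

  2. Load the pretrained model with BertForTokenClassification and set num_labels to the number of classes;
  3. Train on the fine-tuning dataset with a task-specific output layer on top (the linear classifier at the end of the printed model), and evaluate the fine-tuned model's predictive performance;
  4. Apply the fine-tuned model to new data for prediction.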

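In addition, the authors recently uploaded a cell classification model fine-tuned on the cardiomyopathy single-cell data (https://huggingface.co/ctheodoris/Geneformer/tree/main/fine_tuned_models/geneformer-6L-30M_CellClassifier_cardiomyopathies_220224). You can also download it and use it directly.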

Ref:

https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/tree/main/example_input_files/cell_classification/disease_classification/human_dcm_hcm_nf.dataset

Transfer learning enables predictions in network biology: https://doi.org/10.1038/s41586-023-06139-9