
geneformer-cell-class

Geneformer for cell type annotation

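Geneformer is a Transformer model pretrained on about 30 million single-cell transcriptomes (scRNA-seq) covering a wide range of human tissues and organs. It can be used for cell-level classification and for gene-level classification (for example, predicting whether a gene is a drug-resistance gene). Here we follow the tutorial and walk through the steps for cell type prediction.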

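Geneformer first converts each cell's gene expression values into a rank value encoding, which is the input passed to the Transformer architecture for pretraining. For a downstream task, the pretrained model is then fine-tuned on a task-specific dataset with a final output layer added on top.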

Both Geneformer and its training dataset, Genecorpus-30M, are available on Hugging Face.

Setting up the Geneformer environment

We first set up the environment needed for the Geneformer analysis.

Create a conda environment

conda create --name geneformer
conda activate geneformer

To use this environment in Jupyter Notebook, we need to install ipykernel.

https://blog.csdn.net/mighty13/article/details/119859242

conda install -c anaconda ipykernel
python -m ipykernel install --user --name geneformer

Installation

Next, we download Geneformer and the associated example datasets.

Clone project

git clone https://huggingface.co/ctheodoris/Geneformer
cd Geneformer
pip install .
git lfs install
git lfs pull

Download associated dataset

git clone https://huggingface.co/datasets/ctheodoris/Genecorpus-30M

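The .dataset files obtained by cloning are only about 1 KB each, so they are placeholders rather than the real data, and we need to download the training data manually.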

cd Genecorpus-30M/example_input_files/cell_classification/cell_type_annotation/cell_type_train_data.dataset
wget https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/resolve/main/example_input_files/cell_classification/cell_type_annotation/cell_type_train_data.dataset/dataset.arrow
wget https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/resolve/main/example_input_files/cell_classification/cell_type_annotation/cell_type_train_data.dataset/dataset_info.json
wget https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/resolve/main/example_input_files/cell_classification/cell_type_annotation/cell_type_train_data.dataset/state.json

Install related modules

pip3 install seaborn
pip3 install datasets
pip3 install -U scikit-learn
pip3 install transformers
# gpu version of torch
pip3 install torch -f https://download.pytorch.org/whl/torch_stable.html
pip3 install statsmodels
pip3 install -U accelerate

Import modules

import os
GPU_NUMBER = [0]
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join([str(s) for s in GPU_NUMBER])
os.environ["NCCL_DEBUG"] = "INFO"
# imports
from collections import Counter
import datetime
import pickle
import subprocess
import seaborn as sns; sns.set()
from datasets import load_from_disk
from sklearn.metrics import accuracy_score, f1_score
from transformers import BertForSequenceClassification
from transformers import Trainer
from transformers.training_args import TrainingArguments
from geneformer import DataCollatorForCellClassification

Prepare training and evaluation datasets

# load cell type dataset (includes all tissues)
train_dataset=load_from_disk("D:/jupyterNote/Geneformer/Genecorpus-30M/example_input_files/cell_classification/cell_type_annotation/cell_type_train_data.dataset")

We load the example dataset provided with the paper in the Genecorpus-30M repository. The dataset is stored in Apache Arrow format.

Data Fields

  • cell_type
  • organ_major
  • input_ids: rank value encoding for an example cell
  • length: length of rank value encoding for that example cell

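How the rank values are computed

  1. Compute the nonzero median expression of each detected gene across all cells;
  2. Divide each gene's read counts in a cell by that cell's total read counts to correct for sequencing depth;
  3. Divide each gene's value in each cell by that gene's nonzero median to obtain the normalized expression;
  4. Rank the genes within each cell by normalized expression to obtain the rank values.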
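
The four steps above can be sketched on a toy count matrix. This is only a minimal illustration, not the tokenizer shipped with Geneformer: the function name `rank_value_encode` and the toy matrix are hypothetical, and the medians are computed here on the raw toy counts in the order the steps are listed, which is an assumption.

```python
import numpy as np

def rank_value_encode(counts):
    """Toy rank value encoding for a (cells x genes) count matrix.

    Returns, for each cell, the indices of its detected genes ordered
    from highest to lowest normalized expression.
    """
    counts = np.asarray(counts, dtype=float)

    # 1. nonzero median of each gene across all cells (zeros are ignored)
    masked = np.where(counts > 0, counts, np.nan)
    nonzero_median = np.nanmedian(masked, axis=0)

    # 2. correct for sequencing depth: divide by each cell's total counts
    depth_norm = counts / counts.sum(axis=1, keepdims=True)

    # 3. divide each gene by its nonzero median
    normalized = depth_norm / nonzero_median

    # 4. rank the detected genes of each cell by normalized expression
    encodings = []
    for cell in normalized:
        detected = np.flatnonzero(cell > 0)
        order = detected[np.argsort(-cell[detected], kind="stable")]
        encodings.append(order.tolist())
    return encodings

# hypothetical toy data: 3 cells x 3 genes
toy_counts = [[10, 0, 5],
              [2, 4, 0],
              [6, 2, 2]]
encodings = rank_value_encode(toy_counts)
print(encodings)  # [[0, 2], [1, 0], [0, 1, 2]]
```

Note that undetected genes (zero counts) do not appear in a cell's encoding, so the encoding length equals the number of detected genes in that cell.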

The rank value encodings for each single cell transcriptome were then tokenized based on a total vocabulary of 25,424 protein-coding or miRNA genes detected within Geneformer-30M. The token dictionary mapping each token ID to special tokens (pad and mask) or Ensembl IDs for each gene is included within the repository as a pickle file (token_dictionary.pkl).

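Why don't the rank values range from 1 to the number of genes in that cell? Because input_ids holds the vocabulary token IDs of the genes, listed in rank order: a gene's rank is given by its position in the list, not by the stored value.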
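
A small sketch may help with the question above. It assumes a hypothetical toy vocabulary (`toy_token_dictionary`, with made-up gene names rather than real Ensembl IDs); the real mapping is the one stored in token_dictionary.pkl. The values in the result are token IDs, while the position in the list carries the rank.

```python
# Hypothetical toy vocabulary: special tokens first, then placeholder gene IDs.
toy_token_dictionary = {
    "<pad>": 0,
    "<mask>": 1,
    "ENSG_A": 2,
    "ENSG_B": 3,
    "ENSG_C": 4,
    "ENSG_D": 5,
}

def tokenize_ranked_genes(ranked_genes, token_dictionary):
    """Map genes (already sorted from highest to lowest rank) to token IDs.

    Genes that are not in the vocabulary are skipped.
    """
    return [token_dictionary[g] for g in ranked_genes if g in token_dictionary]

ranked = ["ENSG_D", "ENSG_B", "ENSG_C"]  # highest-ranked gene first
input_ids = tokenize_ranked_genes(ranked, toy_token_dictionary)
print(input_ids)       # [5, 3, 4]
print(len(input_ids))  # 3
```

Here the cell has three detected genes, yet the stored values are 5, 3 and 4 rather than 1, 2 and 3.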

# elements of train_dataset
# Includes 249,556 cells in total
print(train_dataset)

# number of cell for each celltypes
print("Celltypes:")
print(Counter(train_dataset['cell_type']))

# number of cells from different organs
print("\nOrgans:")
print(Counter(train_dataset['organ_major']))

# rank value encoding for each cell
print(len(train_dataset['input_ids'][1]))

# number of genes of each cells
print(train_dataset['length'][1])
Dataset({
    features: ['cell_type', 'input_ids', 'length', 'organ_major'],
    num_rows: 249556
})
Counter({'B cell (Plasmocyte)': 20728, 'T cell': 16695, 'Enterocyte progenitor': 15441, 'Fetal epithelial progenitor': 14580, 'Fetal neuron': 12287, 'Fetal mesenchymal progenitor': 11905, 'Erythroid progenitor cell (RP high)': 10819, 'Hepatocyte/Endodermal cell': 9781, 'Fetal enterocyte ': 9613, 'Erythroid cell': 9089, 'Loop of Henle': 8439, 'Macrophage': 7854, 'Monocyte': 7541, 'Epithelial cell': 7458, 'AT2 cell': 7333, 'Neutrophil': 7276, 'Fibroblast': 6980, 'Dendritic cell': 5960, 'Pancreas exocrine cell': 5538, 'Endothelial cell (APC)': 5431, 'M2 Macrophage': 5373, 'Endothelial cell': 4738, 'Antigen presenting cell (RPS high)': 4658, 'Intercalated cell': 4414, 'Sinusoidal endothelial cell': 3844, 'B cell': 3783, 'Endothelial cell (endothelial to mesenchymal transition)': 3647, 'Fetal acinar cell': 3214, 'Ureteric bud cell': 2472, 'Enterocyte': 1994, 'Proximal tubule progenitor': 1846, 'Smooth muscle cell': 1794, 'Fetal stromal cell': 1061, 'Stromal cell': 1029, 'Mast cell': 968, 'Fetal endocrine cell': 834, 'Neutrophil (RPS high)': 604, 'Intermediated cell': 529, 'Proliferating T cell': 528, 'CB CD34+': 423, 'Basal cell': 236, 'Primordial germ cell': 230, 'Fetal fibroblast': 125, 'Fetal Neuron': 114, 'Stratified epithelial cell': 113, 'Fetal skeletal muscle cell': 57, 'Fetal chondrocyte': 49, 'Mesothelial cell': 30, 'Goblet cell': 19, 'Chondrocyte': 19, 'hESC': 18, 'Fasciculata cell': 13, 'Gastric endocrine cell': 7, 'Myeloid cell': 7, 'Epithelial cell (intermediated)': 7, 'Astrocyte': 4, 'Kidney intercalated cell': 3, 'Ventricle cardiomyocyte': 2, 'Immature sertoli cell (Pre-Sertoli cell)': 2})
Counter({'large_intestine': 50363, 'kidney': 45059, 'lung': 33309, 'liver': 28376, 'pancreas': 28116, 'immune': 17110, 'spleen': 15614, 'brain': 13440, 'placenta': 9509, 'bone_marrow': 8660})
569
569
print(train_dataset['length'][2])
print(min(train_dataset['input_ids'][2]))
print(max(train_dataset['input_ids'][2]))
625
9
25414

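Next, we split the cells of each tissue into an 80% training set and a 20% evaluation set for the subsequent model fine-tuning.

All cells here carry cell type labels, which are converted to numeric IDs (for example, "B cell" -> 1) before being passed to the downstream fine-tuning.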
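
Before the full loop, the core idea (drop rare cell types, map names to numeric IDs, split 80/20) can be sketched on plain Python lists. This is a simplified illustration under stated assumptions: `prepare_labels_and_split` is a hypothetical helper, the toy labels are made up, and the real code operates on Hugging Face datasets instead.

```python
from collections import Counter
import random

def prepare_labels_and_split(cell_types, min_fraction=0.005,
                             train_fraction=0.8, seed=42):
    """Filter rare cell types, encode labels as integers, split train/eval."""
    # keep only cell types that make up more than min_fraction of all cells
    counts = Counter(cell_types)
    total = len(cell_types)
    kept = [c for c in cell_types if counts[c] > min_fraction * total]

    # shuffle reproducibly before splitting
    rng = random.Random(seed)
    rng.shuffle(kept)

    # cell type name -> label id, in order of first appearance
    label_ids = {name: i for i, name in enumerate(dict.fromkeys(kept))}
    encoded = [label_ids[c] for c in kept]

    # first 80% for training, remaining 20% for evaluation
    cut = round(len(encoded) * train_fraction)
    return label_ids, encoded[:cut], encoded[cut:]

# hypothetical toy labels: 'Astrocyte' is 0.3% of cells and gets dropped
toy_cell_types = ["B cell"] * 600 + ["T cell"] * 397 + ["Astrocyte"] * 3
label_ids, train_labels, eval_labels = prepare_labels_and_split(toy_cell_types)
print(sorted(label_ids))                     # ['B cell', 'T cell']
print(len(train_labels), len(eval_labels))   # 798 199
```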

dataset_list = []
evalset_list = []
organ_list = []
target_dict_list = []

for organ in Counter(train_dataset["organ_major"]).keys():
    # collect list of tissues for fine-tuning (immune and bone marrow are included together)
    if organ in ["bone_marrow"]:
        continue
    elif organ == "immune":
        organ_ids = ["immune", "bone_marrow"]
        organ_list += ["immune"]
    else:
        organ_ids = [organ]
        organ_list += [organ]

    print(organ)

    # filter datasets for given organ
    # if_organ cannot access the outside variable, so organ_ids is passed in via fn_kwargs
    def if_organ(example, organ_ids):
        return example["organ_major"] in organ_ids
    trainset_organ = train_dataset.filter(if_organ, num_proc=4, fn_kwargs={"organ_ids": organ_ids})  # `num_proc` is the number of worker processes

    # per scDeepsort published method, drop cell types representing < 0.5% of cells
    celltype_counter = Counter(trainset_organ["cell_type"])
    total_cells = sum(celltype_counter.values())
    # keep only the cell types whose count is greater than 0.5% of the cells in this organ
    cells_to_keep = [k for k, v in celltype_counter.items() if v > (0.005*total_cells)]

    def if_not_rare_celltype(example, cells_to_keep):
        return example["cell_type"] in cells_to_keep
    trainset_organ_subset = trainset_organ.filter(if_not_rare_celltype, num_proc=6, fn_kwargs={"cells_to_keep": cells_to_keep})  # filter rare cell types (< 0.5% of cells)

    # shuffle datasets and rename columns
    trainset_organ_shuffled = trainset_organ_subset.shuffle(seed=42)
    trainset_organ_shuffled = trainset_organ_shuffled.rename_column("cell_type", "label")
    trainset_organ_shuffled = trainset_organ_shuffled.remove_columns("organ_major")

    # create dictionary of cell types : label ids
    # Counter returns a dictionary with each element as key and its number of occurrences as value
    target_names = list(Counter(trainset_organ_shuffled["label"]).keys())
    target_name_id_dict = dict(zip(target_names, [i for i in range(len(target_names))]))
    target_dict_list += [target_name_id_dict]

    # change labels to numerical ids
    def classes_to_ids(example, target_name_id_dict):
        example["label"] = target_name_id_dict[example["label"]]
        return example
    labeled_trainset = trainset_organ_shuffled.map(classes_to_ids, num_proc=4, fn_kwargs={"target_name_id_dict": target_name_id_dict})

    # create 80/20 train/eval splits
    labeled_train_split = labeled_trainset.select([i for i in range(0, round(len(labeled_trainset)*0.8))])
    labeled_eval_split = labeled_trainset.select([i for i in range(round(len(labeled_trainset)*0.8), len(labeled_trainset))])

    # filter dataset for cell types in corresponding training set
    trained_labels = list(Counter(labeled_train_split["label"]).keys())
    def if_trained_label(example, trained_labels):
        return example["label"] in trained_labels
    labeled_eval_split_subset = labeled_eval_split.filter(if_trained_label, num_proc=4, fn_kwargs={"trained_labels": trained_labels})

    dataset_list += [labeled_train_split]
    evalset_list += [labeled_eval_split_subset]
(The repeated "Loading cached processed dataset at ...\cache-*.arrow" and "Loading cached shuffled indices" messages are omitted; the loop prints one organ name per iteration.)

spleen
kidney
lung
brain
placenta
immune
large_intestine
pancreas
liver

The mapping between cell types and label IDs for each tissue is stored in the list target_dict_list.

# cell type -> label id mapping for each organ
# (Counter is used only for display: the values shown are label ids, not cell counts)
for i in range(0, len(organ_list)):
    print(organ_list[i])
    print(Counter(target_dict_list[i]))

spleen
Counter({'Endothelial cell (APC)': 5, 'Macrophage': 4, 'Neutrophil': 3, 'T cell': 2, 'B cell': 1, 'B cell (Plasmocyte)': 0})

kidney
Counter({'T cell': 14, 'Dendritic cell': 13, 'Fetal stromal cell': 12, 'Smooth muscle cell': 11, 'Endothelial cell': 10, 'Macrophage': 9, 'Proximal tubule progenitor': 8, 'Intermediated cell': 7, 'Fetal mesenchymal progenitor': 6, 'Intercalated cell': 5, 'Ureteric bud cell': 4, 'Loop of Henle': 3, 'Fetal epithelial progenitor': 2, 'Epithelial cell': 1, 'Endothelial cell (APC)': 0})

lung
Counter({'Basal cell': 15, 'Monocyte': 14, 'Endothelial cell': 13, 'Proliferating T cell': 12, 'Dendritic cell': 11, 'Fetal epithelial progenitor': 10, 'B cell (Plasmocyte)': 9, 'Mast cell': 8, 'Endothelial cell (APC)': 7, 'T cell': 6, 'Endothelial cell (endothelial to mesenchymal transition)': 5, 'M2 Macrophage': 4, 'Macrophage': 3, 'Fetal mesenchymal progenitor': 2, 'AT2 cell': 1, 'Smooth muscle cell': 0})

brain
Counter({'Fetal epithelial progenitor': 5, 'Fetal endocrine cell': 4, 'Erythroid cell': 3, 'Macrophage': 2, 'Fetal mesenchymal progenitor': 1, 'Fetal neuron': 0})

placenta
Counter({'Macrophage': 2, 'Epithelial cell': 1, 'Fibroblast': 0})

immune
Counter({'B cell': 9, 'Neutrophil (RPS high)': 8, 'B cell (Plasmocyte)': 7, 'Dendritic cell': 6, 'Erythroid progenitor cell (RP high)': 5, 'Erythroid cell': 4, 'Neutrophil': 3, 'T cell': 2, 'Monocyte': 1, 'Antigen presenting cell (RPS high)': 0})

large_intestine
Counter({'Endothelial cell': 15, 'Smooth muscle cell': 14, 'Fetal stromal cell': 13, 'B cell': 12, 'Stromal cell': 11, 'Enterocyte': 10, 'T cell': 9, 'Macrophage': 8, 'Dendritic cell': 7, 'Epithelial cell': 6, 'Fetal neuron': 5, 'Fetal mesenchymal progenitor': 4, 'B cell (Plasmocyte)': 3, 'Fetal enterocyte ': 2, 'Hepatocyte/Endodermal cell': 1, 'Enterocyte progenitor': 0})

pancreas
Counter({'Endothelial cell (APC)': 14, 'Smooth muscle cell': 13, 'Dendritic cell': 12, 'Fetal epithelial progenitor': 11, 'Erythroid cell': 10, 'Fetal endocrine cell': 9, 'Endothelial cell': 8, 'Enterocyte progenitor': 7, 'Fetal neuron': 6, 'Macrophage': 5, 'Fetal mesenchymal progenitor': 4, 'Pancreas exocrine cell': 3, 'Fetal acinar cell': 2, 'T cell': 1, 'B cell': 0})

liver
Counter({'CB CD34+': 11, 'B cell': 10, 'Neutrophil (RPS high)': 9, 'B cell (Plasmocyte)': 8, 'T cell': 7, 'Neutrophil': 6, 'Monocyte': 5, 'Dendritic cell': 4, 'Macrophage': 3, 'Sinusoidal endothelial cell': 2, 'Erythroid cell': 1, 'Erythroid progenitor cell (RP high)': 0})
trainset_dict = dict(zip(organ_list, dataset_list))
traintargetdict_dict = dict(zip(organ_list, target_dict_list))
evalset_dict = dict(zip(organ_list, evalset_list))

Fine-Tune With Cell Classification Learning Objective and Quantify Predictive Performance

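Next, we train with preset hyperparameters; the authors recommend tuning the hyperparameters for each downstream task.

We also define a function, compute_metrics, that evaluates the model's predictive performance.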

def compute_metrics(pred):
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)
    # calculate accuracy and macro f1 using sklearn's function
    acc = accuracy_score(labels, preds)
    macro_f1 = f1_score(labels, preds, average='macro')
    return {
        'accuracy': acc,
        'macro_f1': macro_f1
    }
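
To make the macro F1 metric concrete, here is a hand-rolled version on toy labels. It is meant only as an illustration of what macro averaging does (one F1 per class, then an unweighted mean); the function `macro_f1` and the toy labels are hypothetical, and the walkthrough itself relies on sklearn's f1_score.

```python
def macro_f1(labels, preds):
    """Unweighted mean of the per-class F1 scores."""
    classes = sorted(set(labels) | set(preds))
    f1_scores = []
    for c in classes:
        tp = sum(1 for y, p in zip(labels, preds) if y == c and p == c)
        fp = sum(1 for y, p in zip(labels, preds) if y != c and p == c)
        fn = sum(1 for y, p in zip(labels, preds) if y == c and p != c)
        denom = 2 * tp + fp + fn
        # F1 = 2*TP / (2*TP + FP + FN); defined as 0 when the class never occurs
        f1_scores.append(2 * tp / denom if denom else 0.0)
    return sum(f1_scores) / len(f1_scores)

# hypothetical toy labels: class 0 is twice as frequent as class 1
toy_labels = [0, 0, 0, 0, 1, 1]
toy_preds = [0, 0, 0, 1, 1, 0]
toy_accuracy = sum(y == p for y, p in zip(toy_labels, toy_preds)) / len(toy_labels)
toy_macro_f1 = macro_f1(toy_labels, toy_preds)
print(round(toy_accuracy, 3))  # 0.667
print(toy_macro_f1)            # 0.625
```

Because every class contributes equally, macro F1 is more sensitive than accuracy to errors on rare cell types.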

# set model parameters
# max input size
max_input_size = 2 ** 11  # 2048

# set training hyperparameters
# max learning rate
max_lr = 5e-5
# how many pretrained layers to freeze
freeze_layers = 0
# number gpus
num_gpus = 1
# number cpu cores
num_proc = 6
# batch size for training and eval
# (reduced here because of limited GPU memory)
geneformer_batch_size = 4
# learning schedule
lr_schedule_fn = "linear"
# warmup steps
warmup_steps = 500
# number of epochs
epochs = 10
# optimizer
optimizer = "adamw"
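
As a rough illustration of how max_lr, warmup_steps and the "linear" schedule interact, the sketch below computes the learning rate at a given step. It assumes the usual linear-warmup-then-linear-decay rule used by the transformers library; `linear_warmup_lr` is a hypothetical helper, and the total of 26640 steps is taken from the last checkpoint of the brain run listed further down in this post.

```python
def linear_warmup_lr(step, total_steps, max_lr=5e-5, warmup_steps=500):
    """Learning rate under linear warmup followed by linear decay to zero."""
    if step < warmup_steps:
        # ramp up linearly from 0 to max_lr
        return max_lr * step / max(1, warmup_steps)
    # decay linearly from max_lr at the end of warmup to 0 at total_steps
    remaining = max(0, total_steps - step)
    return max_lr * remaining / max(1, total_steps - warmup_steps)

total_steps = 26640  # e.g. 10 epochs x 2664 steps per epoch
for step in (0, 250, 500, 13570, 26640):
    print(step, linear_warmup_lr(step, total_steps))
```

The rate peaks at max_lr exactly when warmup ends (step 500 here) and reaches zero at the final step.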

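Next, we fine-tune one cell type classifier per tissue: brain, immune, kidney, large intestine, liver, lung, pancreas, placenta, and spleen.

This step takes a long time.

The model is loaded with BertForSequenceClassification.from_pretrained. num_labels is the number of classes the model outputs; here it is set to the number of cell types in the corresponding tissue.

We then create a Trainer, run the training, and call .predict() to obtain predictions.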

for organ in organ_list:
    print(organ)
    organ_trainset = trainset_dict[organ]
    organ_evalset = evalset_dict[organ]
    organ_label_dict = traintargetdict_dict[organ]

    # set logging steps
    logging_steps = round(len(organ_trainset)/geneformer_batch_size/10)

    # reload pretrained model
    model = BertForSequenceClassification.from_pretrained("D:\\jupyterNote\\Geneformer",
                                                          num_labels=len(organ_label_dict.keys()),
                                                          output_attentions=False,
                                                          output_hidden_states=False).to("cuda")

    # define output directory path
    current_date = datetime.datetime.now()
    datestamp = f"{str(current_date.year)[-2:]}{current_date.month:02d}{current_date.day:02d}"
    output_dir = f"D:\\jupyterNote\\Geneformer\\examples\\cell_class_test\\{datestamp}_geneformer_CellClassifier_{organ}_L{max_input_size}_B{geneformer_batch_size}_LR{max_lr}_LS{lr_schedule_fn}_WU{warmup_steps}_E{epochs}_O{optimizer}_F{freeze_layers}\\"

    # ensure not overwriting previously saved model
    saved_model_test = os.path.join(output_dir, "pytorch_model.bin")
    if os.path.isfile(saved_model_test):
        raise Exception("Model already saved to this directory.")

    # make output directory (os.makedirs instead of a shell `mkdir` call)
    os.makedirs(output_dir, exist_ok=True)

    # set training arguments
    training_args = {
        "learning_rate": max_lr,
        "do_train": True,
        "do_eval": True,
        "evaluation_strategy": "epoch",
        "save_strategy": "epoch",
        "logging_steps": logging_steps,
        "group_by_length": True,
        "length_column_name": "length",
        "disable_tqdm": False,
        "lr_scheduler_type": lr_schedule_fn,
        "warmup_steps": warmup_steps,
        "weight_decay": 0.001,
        "per_device_train_batch_size": geneformer_batch_size,
        "per_device_eval_batch_size": geneformer_batch_size,
        "num_train_epochs": epochs,
        "load_best_model_at_end": True,
        "output_dir": output_dir,
    }

    training_args_init = TrainingArguments(**training_args)

    # create the trainer
    trainer = Trainer(
        model=model,
        args=training_args_init,
        data_collator=DataCollatorForCellClassification(),
        train_dataset=organ_trainset,
        eval_dataset=organ_evalset,
        compute_metrics=compute_metrics
    )

    # train the cell type classifier
    trainer.train()
    # predict on the evaluation set and save predictions, metrics and model
    predictions = trainer.predict(organ_evalset)
    with open(f"{output_dir}predictions.pickle", "wb") as fp:
        pickle.dump(predictions, fp)
    trainer.save_metrics("eval", predictions.metrics)
    trainer.save_model(output_dir)
    Some weights of the model checkpoint at D:\jupyterNote\Geneformer were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at D:\jupyterNote\Geneformer and are newly initialized: ['bert.pooler.dense.weight', 'classifier.bias', 'classifier.weight', 'bert.pooler.dense.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


<...training message...>

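After training, the fine-tuned models and their predictions are written to the configured output_dir.

Each folder contains the fine-tuned model files (training_args.bin, config.json, pytorch_model.bin) as well as the prediction output (predictions.pickle).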

$ ls 230719_geneformer_CellClassifier_brain_L2048_B4_LR5e-05_LSlinear_WU500_E10_Oadamw_F0/

all_results.json checkpoint-15984 checkpoint-23976 checkpoint-5328 eval_results.json training_args.bin
checkpoint-10656 checkpoint-18648 checkpoint-2664 checkpoint-7992 predictions.pickle
checkpoint-13320 checkpoint-21312 checkpoint-26640 config.json pytorch_model.bin
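
The object pickled into predictions.pickle is what trainer.predict returned, whose predictions field holds one row of logits per cell (this is what compute_metrics applies argmax to). The sketch below shows how such logits could be turned back into cell type names by inverting a label dictionary. The logits, the label mapping and the helper `decode_predictions` are all hypothetical toy stand-ins, not values from the runs above.

```python
import numpy as np

# Hypothetical stand-ins for one organ's label mapping and the saved logits.
toy_label_dict = {"B cell": 0, "T cell": 1, "Monocyte": 2}
toy_logits = np.array([[2.0, 0.1, -1.0],
                       [0.2, 0.3, 1.5],
                       [-0.5, 3.0, 0.0]])

def decode_predictions(logits, label_dict):
    """Convert per-class logits to predicted cell type names."""
    # invert the cell type -> id mapping
    id_to_name = {i: name for name, i in label_dict.items()}
    # the predicted class is the index of the largest logit in each row
    pred_ids = np.asarray(logits).argmax(axis=-1)
    return [id_to_name[int(i)] for i in pred_ids]

decoded = decode_predictions(toy_logits, toy_label_dict)
print(decoded)  # ['B cell', 'Monocyte', 'T cell']
```

With the real files, the logits would come from the unpickled object and the mapping from the corresponding entry of traintargetdict_dict.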
# clear GPU memory after pytorch training 
import torch
torch.cuda.empty_cache()
# The model still in memory is the classifier from the last loop iteration (liver, 12 classes)
model
BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(25426, 256, padding_idx=0)
      (position_embeddings): Embedding(2048, 256)
      (token_type_embeddings): Embedding(2, 256)
      (LayerNorm): LayerNorm((256,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.02, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-5): 6 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=256, out_features=256, bias=True)
              (key): Linear(in_features=256, out_features=256, bias=True)
              (value): Linear(in_features=256, out_features=256, bias=True)
              (dropout): Dropout(p=0.02, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=256, out_features=256, bias=True)
              (LayerNorm): LayerNorm((256,), eps=1e-12, elementwise_affine=True)
              (dropout): Dropout(p=0.02, inplace=False)
            )
          )
          (intermediate): BertIntermediate(
            (dense): Linear(in_features=256, out_features=512, bias=True)
            (intermediate_act_fn): ReLU()
          )
          (output): BertOutput(
            (dense): Linear(in_features=512, out_features=256, bias=True)
            (LayerNorm): LayerNorm((256,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.02, inplace=False)
          )
        )
      )
    )
    (pooler): BertPooler(
      (dense): Linear(in_features=256, out_features=256, bias=True)
      (activation): Tanh()
    )
  )
  (dropout): Dropout(p=0.02, inplace=False)
  (classifier): Linear(in_features=256, out_features=12, bias=True)
)

Next, we apply the model fine-tuned on the immune data to the 3k PBMCs scRNA-seq dataset for cell type prediction.

First, we need to convert the raw sequencing counts into rank value encodings (tk.tokenize_data).

# applying fine-tuned model on new datasets
# using the 3k PBMCs dataset
# 1. transform scRNA-seq expression data to rank value .dataset format
from geneformer import TranscriptomeTokenizer

tk = TranscriptomeTokenizer({"cell_type": "cell_type", "organ_major": "organ_major"}, nproc=4)
tk.tokenize_data("D:/jupyterNote/pySC/output/", output_directory="token_data/", output_prefix="tk_pbmc3k")
Tokenizing D:\jupyterNote\pySC\output\pbmc3k.loom
D:\jupyterNote\pySC\output\pbmc3k.loom has no column attribute 'filter_pass'; tokenizing all cells.

Load the tokenized dataset

# 2. load new dataset
new_dataset = load_from_disk("D:/jupyterNote/Geneformer/examples/token_data/tk_pbmc3k.dataset/")
new_dataset
Dataset({
    features: ['input_ids', 'cell_type', 'organ_major', 'length'],
    num_rows: 2638
})
import pandas as pd

# input_ids represent rank encodings
pd.DataFrame(new_dataset)

input_ids cell_type organ_major length
0 [19693, 10551, 2362, 1869, 4658, 18585, 9039, ... CD4 T immune 139
1 [5307, 10632, 8073, 3539, 19629, 516, 18552, 1... B immune 247
2 [6729, 226, 9004, 11666, 4621, 1433, 8198, 327... CD4 T immune 200
3 [3057, 12015, 17654, 6179, 2522, 2770, 417, 62... FCGR3A Monocytes immune 185
4 [12623, 18649, 10321, 2313, 7245, 13219, 242, ... NK immune 91
... ... ... ... ...
2633 [2522, 13139, 3539, 449, 19629, 488, 2770, 695... CD14 Monocytes immune 241
2634 [13492, 6556, 12734, 3078, 3539, 643, 695, 195... B immune 189
2635 [14848, 16634, 19629, 2987, 5214, 11314, 17576... B immune 108
2636 [19629, 19899, 14408, 3078, 7380, 1868, 4472, ... B immune 93
2637 [4658, 18585, 18828, 14377, 10632, 13116, 1481... CD4 T immune 118

2638 rows × 4 columns

Because the model requires input tensors (the rank encoding of each cell) of equal length, we pad them to a common length: the largest number of genes found in any cell.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import numpy as np
from geneformer.pretrainer import token_dictionary

def preprocess_classifier_batch(cell_batch, max_len):
    # Default to the longest encoding in the batch when no length is given
    if max_len is None:
        max_len = max([len(i) for i in cell_batch["input_ids"]])

    def pad_label_example(example):
        #example["labels"] = np.pad(example["labels"],
        #                           (0, max_len-len(example["input_ids"])),
        #                           mode='constant', constant_values=-100)
        # Right-pad the rank encoding with the <pad> token up to max_len
        example["input_ids"] = np.pad(example["input_ids"],
                                      (0, max_len - len(example["input_ids"])),
                                      mode='constant', constant_values=token_dictionary.get("<pad>"))
        # Attention mask: 1 for real tokens, 0 for padding
        example["attention_mask"] = (example["input_ids"] != token_dictionary.get("<pad>")).astype(int)
        return example

    padded_batch = cell_batch.map(pad_label_example)
    return padded_batch

# Find the largest number smaller than or equal to N that is divisible by K
def find_largest_div(N, K):
    rem = N % K
    if rem == 0:
        return N
    else:
        return N - rem
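
To see what the padding step produces without loading a dataset, here is a small standalone sketch using only NumPy. It assumes the `<pad>` token id is 0, which is consistent with `padding_idx=0` in the model summary printed later in this post. The helper name `pad_and_mask` and the two toy cells are made up for illustration.

```python
import numpy as np

PAD_ID = 0  # assumption: "<pad>" maps to token id 0

def pad_and_mask(input_ids, max_len, pad_id=PAD_ID):
    """Right-pad one token list to max_len and build its attention mask."""
    ids = np.pad(np.asarray(input_ids),
                 (0, max_len - len(input_ids)),
                 mode="constant", constant_values=pad_id)
    # 1 marks a real gene token, 0 marks padding the model should ignore.
    mask = (ids != pad_id).astype(int)
    return ids, mask

# Two toy "cells" with rank encodings of different lengths.
toy_cells = [[19693, 10551, 2362], [5307, 10632]]
max_len = max(len(cell) for cell in toy_cells)
padded = [pad_and_mask(cell, max_len) for cell in toy_cells]

for ids, mask in padded:
    print(ids.tolist(), mask.tolist())
```

The shorter cell gains one trailing pad token and a matching 0 in its mask, while the longest cell is left unchanged.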

# padded to be the same length.
set_len=len(new_dataset)
max_set_len = max(new_dataset.select([i for i in range(set_len)])["length"])
padded_dataset = preprocess_classifier_batch(new_dataset, max_set_len)
Loading cached processed dataset at D:\jupyterNote\Geneformer\examples\token_data\tk_pbmc3k.dataset\cache-4214517ba3d677b2.arrow
pd.DataFrame(padded_dataset)

      input_ids                                           cell_type         organ_major  length  attention_mask
0     [19693, 10551, 2362, 1869, 4658, 18585, 9039, ...   CD4 T             immune       139     [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...
1     [5307, 10632, 8073, 3539, 19629, 516, 18552, 1...   B                 immune       247     [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...
2     [6729, 226, 9004, 11666, 4621, 1433, 8198, 327...   CD4 T             immune       200     [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...
3     [3057, 12015, 17654, 6179, 2522, 2770, 417, 62...   FCGR3A Monocytes  immune       185     [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...
4     [12623, 18649, 10321, 2313, 7245, 13219, 242, ...   NK                immune       91      [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...
...   ...                                                 ...               ...          ...     ...
2633  [2522, 13139, 3539, 449, 19629, 488, 2770, 695...   CD14 Monocytes    immune       241     [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...
2634  [13492, 6556, 12734, 3078, 3539, 643, 695, 195...   B                 immune       189     [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...
2635  [14848, 16634, 19629, 2987, 5214, 11314, 17576...   B                 immune       108     [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...
2636  [19629, 19899, 14408, 3078, 7380, 1868, 4472, ...   B                 immune       93      [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...
2637  [4658, 18585, 18828, 14377, 10632, 13116, 1481...   CD4 T             immune       118     [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...

2638 rows × 5 columns

Next, load the fine-tuned model and run the prediction.

from transformers import BertForSequenceClassification, Trainer

# 3. Reload the fine-tuned model
ft_model = BertForSequenceClassification.from_pretrained("cell_class_test/230719_geneformer_CellClassifier_immune_L2048_B4_LR5e-05_LSlinear_WU500_E10_Oadamw_F0/")
# The immune training set contains only 10 cell types, so the final output layer has out_features=10
print(ft_model)

ft_trainer = Trainer(model=ft_model)

# 4. Run the prediction; .predictions holds one score per class for each cell
ct_predictions = ft_trainer.predict(padded_dataset)
ct_pred = ct_predictions.predictions
BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(25426, 256, padding_idx=0)
      (position_embeddings): Embedding(2048, 256)
      (token_type_embeddings): Embedding(2, 256)
      (LayerNorm): LayerNorm((256,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.02, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-5): 6 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=256, out_features=256, bias=True)
              (key): Linear(in_features=256, out_features=256, bias=True)
              (value): Linear(in_features=256, out_features=256, bias=True)
              (dropout): Dropout(p=0.02, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=256, out_features=256, bias=True)
              (LayerNorm): LayerNorm((256,), eps=1e-12, elementwise_affine=True)
              (dropout): Dropout(p=0.02, inplace=False)
            )
          )
          (intermediate): BertIntermediate(
            (dense): Linear(in_features=256, out_features=512, bias=True)
            (intermediate_act_fn): ReLU()
          )
          (output): BertOutput(
            (dense): Linear(in_features=512, out_features=256, bias=True)
            (LayerNorm): LayerNorm((256,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.02, inplace=False)
          )
        )
      )
    )
    (pooler): BertPooler(
      (dense): Linear(in_features=256, out_features=256, bias=True)
      (activation): Tanh()
    )
  )
  (dropout): Dropout(p=0.02, inplace=False)
  (classifier): Linear(in_features=256, out_features=10, bias=True)
)

We use the fine-tuned model for classification. Each cell is assigned the cell type whose predicted score is largest.

# Mapping from cell type name to class index
immune_label_idx_dict = target_dict_list[5]

# Predicted cell type = the class with the largest predicted value
ct_pred_id = ct_pred.argmax(-1)  # NumPy array of class indices
ct_pred_label = [k for idx in ct_pred_id for k, v in immune_label_idx_dict.items() if v == idx]
print(len(ct_pred_label))
2638
print(ct_pred_id[0:10])
print(ct_pred_label[0:10])
print(immune_label_idx_dict)
[2 9 2 1 2 9 2 2 2 1]
['T cell', 'B cell', 'T cell', 'Monocyte', 'T cell', 'B cell', 'T cell', 'T cell', 'T cell', 'Monocyte']
{'Antigen presenting cell (RPS high)': 0, 'Monocyte': 1, 'T cell': 2, 'Neutrophil': 3, 'Erythroid cell': 4, 'Erythroid progenitor cell (RP high)': 5, 'Dendritic cell': 6, 'B cell (Plasmocyte)': 7, 'Neutrophil (RPS high)': 8, 'B cell': 9}
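
The argmax step can also be written with an inverted index-to-label dictionary, and a softmax over the scores gives a rough per-cell confidence. The sketch below reuses the label dictionary printed above; the logits are invented toy values, and treating the softmax maximum as a confidence score is an illustrative choice rather than something the tutorial does.

```python
import numpy as np

# Label dictionary copied from the printed output above.
label_to_idx = {
    'Antigen presenting cell (RPS high)': 0, 'Monocyte': 1, 'T cell': 2,
    'Neutrophil': 3, 'Erythroid cell': 4,
    'Erythroid progenitor cell (RP high)': 5, 'Dendritic cell': 6,
    'B cell (Plasmocyte)': 7, 'Neutrophil (RPS high)': 8, 'B cell': 9,
}
# Invert once so each lookup is a plain dictionary access.
idx_to_label = {idx: label for label, idx in label_to_idx.items()}

# Hypothetical logits for three cells over the ten classes.
logits = np.zeros((3, 10))
logits[0, 2] = 4.0
logits[1, 9] = 3.0
logits[2, 1] = 5.0

def softmax(x):
    """Numerically stable softmax over the last axis."""
    shifted = x - x.max(axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)

probs = softmax(logits)
pred_ids = probs.argmax(axis=-1)
pred_labels = [idx_to_label[int(i)] for i in pred_ids]
confidence = probs.max(axis=-1)

print(pred_labels)
print(confidence.round(3))
```

Softmax does not change which class wins, so `pred_ids` matches what `argmax` on the raw scores would return.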

Visualize the cell classification results with UMAP

import numpy as np
import pandas as pd
import scanpy as sc
import anndata

adata = anndata.read_h5ad("D:/jupyterNote/pySC/output/pbmc3k.h5ad")
adata
AnnData object with n_obs × n_vars = 2638 × 1838
    obs: 'n_genes', 'n_genes_by_counts', 'n_counts', 'total_counts_mt', 'pct_counts_mt', 'leiden', 'cell_type', 'organ_major'
    var: 'ensembl_id', 'n_cells', 'mt', 'n_cells_by_counts', 'mean_counts', 'pct_dropout_by_counts', 'total_counts', 'highly_variable', 'means', 'dispersions', 'dispersions_norm', 'mean', 'std', 'gene_name'
    uns: 'hvg', 'leiden', 'leiden_colors', 'log1p', 'neighbors', 'pca', 'rank_genes_groups', 'umap'
    obsm: 'X_pca', 'X_umap'
    varm: 'PCs'
    layers: 'counts', 'data', 'scaled'
    obsp: 'connectivities', 'distances'

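Although Geneformer's cell type names differ from those in the original data, its annotations are broadly consistent with the original ones. Overall, Geneformer can serve as a cell type prediction tool, but it is best to fine-tune the pretrained model first, which requires a relevant single-cell dataset for fine-tuning.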

adata.obs['geneformer_pred'] = ct_pred_label
sc.pl.umap(adata, color='geneformer_pred')
sc.pl.umap(adata, color='cell_type')
E:\miniconda3\envs\geneformer\lib\site-packages\scanpy\plotting\_tools\scatterplots.py:392: UserWarning: No data for colormapping provided via 'c'. Parameters 'cmap' will be ignored
  cax = scatter(

[Figure: UMAP plot colored by geneformer_pred (image-20230802100536752)]

E:\miniconda3\envs\geneformer\lib\site-packages\scanpy\plotting\_tools\scatterplots.py:392: UserWarning: No data for colormapping provided via 'c'. Parameters 'cmap' will be ignored
  cax = scatter(

adata.obs

                  n_genes  n_genes_by_counts  n_counts  total_counts_mt  pct_counts_mt  leiden            cell_type         organ_major
AAACATACAACCAC-1  781      779                2419.0    73.0             3.017776       CD4 T             CD4 T             immune
AAACATTGAGCTAC-1  1352     1352               4903.0    186.0            3.793596       B                 B                 immune
AAACATTGATCAGC-1  1131     1129               3147.0    28.0             0.889736       CD4 T             CD4 T             immune
AAACCGTGCTTCCG-1  960      960                2639.0    46.0             1.743085       FCGR3A Monocytes  FCGR3A Monocytes  immune
AAACCGTGTATGCG-1  522      521                980.0     12.0             1.224490       NK                NK                immune
...               ...      ...                ...       ...              ...            ...               ...               ...
TTTCGAACTCTCAT-1  1155     1153               3459.0    73.0             2.110436       CD14 Monocytes    CD14 Monocytes    immune
TTTCTACTGAGGCA-1  1227     1224               3443.0    32.0             0.929422       B                 B                 immune
TTTCTACTTCCTCG-1  622      622                1684.0    37.0             2.197150       B                 B                 immune
TTTGCATGAGAGGC-1  454      452                1022.0    21.0             2.054795       B                 B                 immune
TTTGCATGCCTCAC-1  724      723                1984.0    16.0             0.806452       CD4 T             CD4 T             immune

2638 rows × 8 columns
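
Beyond comparing the two UMAP plots by eye, the agreement between the original annotation and the Geneformer prediction can be tabulated. The sketch below uses the first five cells shown in this post (their `cell_type` values and predicted labels). The `coarse_map` from pbmc3k labels to the model's classes is a hypothetical mapping for illustration; labels with no counterpart, such as NK, are left out of the agreement score.

```python
import pandas as pd

# Hypothetical mapping from the original pbmc3k labels to the coarser
# classes of the fine-tuned model; adjust it to your own annotation.
coarse_map = {
    "CD4 T": "T cell",
    "CD8 T": "T cell",
    "B": "B cell",
    "CD14 Monocytes": "Monocyte",
    "FCGR3A Monocytes": "Monocyte",
}

# First five cells from the tables and prediction output shown above.
toy = pd.DataFrame({
    "cell_type": ["CD4 T", "B", "CD4 T", "FCGR3A Monocytes", "NK"],
    "geneformer_pred": ["T cell", "B cell", "T cell", "Monocyte", "T cell"],
})

# Rows: original labels; columns: predicted labels; values: cell counts.
confusion = pd.crosstab(toy["cell_type"], toy["geneformer_pred"])
print(confusion)

# Agreement over the cells whose original label has a mapped counterpart.
expected = toy["cell_type"].map(coarse_map)
comparable = expected.notna()
agreement = (expected[comparable] == toy.loc[comparable, "geneformer_pred"]).mean()
print(f"agreement on mapped labels: {agreement:.2f}")
```

With the real data, the same table would come from `pd.crosstab(adata.obs["cell_type"], adata.obs["geneformer_pred"])` once the prediction column has been added to `adata.obs`.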

Summary

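To fine-tune for cell classification, we need to:

  1. Obtain a fine-tuning dataset for the tissue of interest that carries a label for each cell, such as its cell type;

    On dataset size: judging from the examples the authors provide, the smallest uses 884 cells, while the other downstream tasks all use more than 10k cells.

  2. Load the pretrained model with BertForSequenceClassification and set num_labels to the number of classes;
  3. Train on the fine-tuning dataset, adding the final output layer (task-specific transformer layer), and evaluate the predictive performance of the fine-tuned model;
  4. Apply the fine-tuned model to a new dataset for prediction.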
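
The role of num_labels in step 2 can be checked without downloading any weights. The sketch below builds a randomly initialized model whose sizes are read off the architecture printed earlier in this post, then confirms that the classification head has one output per class. It is only an illustration of the config mechanics: `num_attention_heads=4` is an assumption (the printout does not show it), and in practice the weights would come from `from_pretrained` on the pretrained Geneformer directory rather than from a fresh config.

```python
from transformers import BertConfig, BertForSequenceClassification

# Sizes taken from the printed model summary; randomly initialized weights.
config = BertConfig(
    vocab_size=25426,
    hidden_size=256,
    num_hidden_layers=6,
    num_attention_heads=4,          # assumption: not visible in the printout
    intermediate_size=512,
    hidden_act="relu",
    hidden_dropout_prob=0.02,
    attention_probs_dropout_prob=0.02,
    max_position_embeddings=2048,
    pad_token_id=0,
    num_labels=10,                  # one output per cell type in the immune set
)
model = BertForSequenceClassification(config)

# The classifier head is a linear layer from the pooled embedding to num_labels.
print(model.classifier)
```

Changing `num_labels` is the only edit needed to adapt the head to a dataset with a different number of cell types.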

Ref:

Transfer learning enables predictions in network biology: https://doi.org/10.1038/s41586-023-06139-9

https://huggingface.co/ctheodoris/Geneformer