# -*- coding: utf-8 -*-
"""ScGPT.ipynb
Automatically generated by Colab.
Original file is located at
https://colab.research.google.com/drive/1-tx7QkMFyvUKhDz8gJ4UEaAnqNixZxgO
"""
from google.colab import drive
drive.mount("/content/gdrive")
# Commented out IPython magic to ensure Python compatibility.
# %cd /content/gdrive/MyDrive/
# Commented out IPython magic to ensure Python compatibility.
!mkdir poetry
# %cd poetry
# Install poetry from pip
!pip install poetry
# Configure poetry to create virtual environments in the project folder
!poetry config virtualenvs.in-project true
# Create the pyproject.toml file
!poetry init
!poetry add torch
VENV_PATH = "/content/gdrive/MyDrive/poetry/.venv"
!ls $VENV_PATH
# Configure poetry to create virtual environments in the project folder
!poetry config virtualenvs.in-project true
!poetry install --no-ansi
!ls $VENV_PATH
import glob, os, sys
LOCAL_VENV_PATH = "/content/venv"  # local alias for the venv stored on Drive
os.symlink(VENV_PATH, LOCAL_VENV_PATH)  # connect to the directory in Drive
# Put the venv's site-packages (not the venv root) on sys.path so its packages are importable
site_packages = glob.glob(os.path.join(LOCAL_VENV_PATH, "lib", "python3.*", "site-packages"))
sys.path.insert(0, site_packages[0] if site_packages else LOCAL_VENV_PATH)
import torch
print(torch.__version__)
# Commented out IPython magic to ensure Python compatibility.
# Move in your Drive
# %cd /content/gdrive/MyDrive/
!pip install scgpt "flash-attn<1.0.5" # optional, recommended
# As of 2023.09, pip install may not run with new versions of the google orbax package, if you encounter related issues, please use the following command instead:
# pip install scgpt "flash-attn<1.0.5" "orbax<0.1.8"
!pip install wandb
# Commented out IPython magic to ensure Python compatibility.
#!rm -rf poetry-scGPT/
!git clone https://github.com/bowang-lab/scGPT.git
# %cd scGPT
!poetry install
# Commented out IPython magic to ensure Python compatibility.
# %cd scGPT
!ls
!pip install scanpy
!pip install torchtext
import sys
from pathlib import Path
import numpy as np
import pandas as pd
from scipy.stats import mode
import scanpy as sc
import sklearn.metrics
import warnings
sys.path.insert(0, "../")
import scgpt as scg
# extra dependency for similarity search
try:
    import faiss

    faiss_imported = True
except ImportError:
    faiss_imported = False
    print(
        "faiss not installed! We highly recommend installing it for fast similarity search."
    )
    print("To install it, see https://github.com/facebookresearch/faiss/wiki/Installing-Faiss")
warnings.filterwarnings("ignore", category=ResourceWarning)
model_dir = "/data/weights/scgpt/scGPT_human"
adata = sc.read_h5ad("/data/annotation_pancreas/demo_train.h5ad")
cell_type_key = "Celltype"
gene_col = "index"
batch_size = 32
# output_dir is the path to which the results should be saved
output_dir = "../output/scgpt/scgpt_human/"
# path to where we will store the embeddings and other evaluation outputs
model_out = os.path.join(output_dir, "model_outputs")
# if you can use multithreading, specify num_workers
num_workers = 0
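# model_out is declared above but never created; creating it up front (a small
# addition, not in the original notebook) avoids a FileNotFoundError if later
# steps write embeddings or other evaluation outputs there.
os.makedirs(model_out, exist_ok=True)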
ref_embed_adata = scg.tasks.embed_data(
    adata,
    model_dir,
    gene_col=gene_col,
    obs_to_save=cell_type_key,  # optional arg, only for saving metainfo
    batch_size=64,
    return_new_adata=True,
)
# # running on cpu, not recommended since it is slow
# ref_embed_adata = scg.tasks.embed_data(
# adata,
# model_dir,
# gene_col=gene_col,
# obs_to_save=cell_type_key,
# batch_size=64,
# device="cpu",
# use_fast_transformer=False,
# return_new_adata=True,
# )
# Optional step to visualize the reference dataset using the embeddings
sc.pp.neighbors(ref_embed_adata, use_rep="X")
sc.tl.umap(ref_embed_adata)
sc.pl.umap(ref_embed_adata, color=cell_type_key, frameon=False, wspace=0.4)
test_adata = sc.read_h5ad("../data/annotation_pancreas/demo_test.h5ad")
test_embed_adata = scg.tasks.embed_data(
    test_adata,
    model_dir,
    gene_col=gene_col,
    obs_to_save=cell_type_key,  # optional arg, only for saving metainfo
    batch_size=64,
    return_new_adata=True,
)
ref_cell_embeddings = ref_embed_adata.X
test_embed = test_embed_adata.X
k = 10  # number of neighbors
if faiss_imported:
    # Declaring index, using most of the default parameters
    index = faiss.IndexFlatL2(ref_cell_embeddings.shape[1])
    index.add(ref_cell_embeddings)
    # Query dataset, k - number of closest elements (returns 2 numpy arrays)
    distances, labels = index.search(test_embed, k)
idx_list = [i for i in range(test_embed.shape[0])]
preds = []
for i in idx_list:
    if faiss_imported:
        idx = labels[i]
    else:
        # fall back to the brute-force nearest-neighbor helper defined above
        idx, sim = get_similar_vectors(test_embed[i][np.newaxis, ...], ref_cell_embeddings, k)
    # majority vote over the k nearest reference cells
    pred = ref_embed_adata.obs[cell_type_key].iloc[idx].value_counts()
    preds.append(pred.index[0])
gt = test_adata.obs[cell_type_key].to_numpy()
preds = np.array(preds)
print("Accuracy:", sklearn.metrics.accuracy_score(gt, preds))
ids_m = np.where(gt == "endothelial")[0]
print(f"Found {len(ids_m)} endothelial cells")
print(f"Predicted cell types: {preds[ids_m]}")
print(f"Annotated cell types: {gt[ids_m]}")