-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathrun_SpaHDmap.py
109 lines (87 loc) · 5.58 KB
/
run_SpaHDmap.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import os
import json
import argparse
import numpy as np
import scanpy as sc
import torch
import SpaHDmap as hdmap
## -------------------------------------------------------------------------------------------------
def str2bool(v):
    """Parse a command-line flag value into a bool (for argparse ``type=``).

    A value that is already a ``bool`` is returned unchanged. Otherwise the
    lowercased value is matched against accepted spellings:
    'yes'/'true'/'t'/'y'/'1' -> True and 'no'/'false'/'f'/'n'/'0' -> False.

    Raises:
        argparse.ArgumentTypeError: if the value matches neither set, so
            argparse reports a clean usage error instead of a traceback.
    """
    if isinstance(v, bool):
        return v
    lowered = v.lower()
    truthy = ('yes', 'true', 't', 'y', '1')
    falsy = ('no', 'false', 'f', 'n', '0')
    if lowered in truthy:
        return True
    if lowered in falsy:
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')
# Command-line interface. Boolean options that default to True take an
# explicit value (e.g. `--visualize false`) via str2bool; --save_score and
# --verbose are plain on/off switches.
parser = argparse.ArgumentParser()
parser.add_argument('-c', '--config', type=str, default=None, help='Path to the configuration file.')
parser.add_argument('-r', '--rank', type=int, default=20, help='Rank for the model.')
parser.add_argument('-s', '--seed', type=int, default=123, help='Seed for random number generation.')
parser.add_argument('-d', '--device', type=int, default=0, help='Device ID for CUDA.')
parser.add_argument('--visualize', type=str2bool, default=True, help='Enable visualization.')
parser.add_argument('--create_mask', type=str2bool, default=True, help='Enable creating mask.')
parser.add_argument('--swap_coord', type=str2bool, default=True, help='Enable swapping coordinates.')
parser.add_argument('--select_svgs', type=str2bool, default=True, help='Enable selecting SVGs.')
parser.add_argument('--n_top_genes', type=int, default=3000, help='Number of top genes to select.')
parser.add_argument('--save_model', type=str2bool, default=True, help='Enable saving model.')
parser.add_argument('--save_score', action='store_true', help='Enable saving score.')
parser.add_argument('--verbose', action='store_true', help='Enable verbose output.')
args = parser.parse_args()

# Validate with parser.error instead of `assert`: asserts are stripped when
# Python runs with -O, and parser.error prints usage and exits with status 2.
if args.config is None:
    parser.error("Please specify the configuration file.")
if not args.config.endswith('.json'):
    parser.error("The configuration file should be in JSON format.")
if not os.path.isfile(args.config):
    # isfile (not exists) so a directory named '*.json' is rejected here
    # rather than failing later inside open().
    parser.error("The configuration file does not exist.")
# argparse's type=int already guarantees an int; only the sign needs checking.
if args.rank <= 0:
    parser.error("The rank should be a positive integer.")

# Restrict CUDA to the requested GPU. This is set before the first CUDA query
# below (torch.cuda.is_available) so the restriction can take effect.
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.device)
# NOTE(review): `device` is not referenced elsewhere in this script.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Seed NumPy and PyTorch RNGs for reproducibility.
print(f"Random seed: {args.seed}")
np.random.seed(args.seed)
torch.manual_seed(args.seed)
## -------------------------------------------------------------------------------------------------
print("Step 0: Load and preprocess data")
# 0. Load config from JSON file. An explicit UTF-8 encoding makes non-ASCII
# paths/section names decode the same way regardless of the platform locale.
with open(args.config, 'r', encoding='utf-8') as f:
    config = json.load(f)

# 0.1 parameter setting
# `radius` and `scale_rate` may each be one value shared by all sections or a
# list with one entry per section. A missing 'paras' object falls back to the
# defaults (radius=None, scale_rate=1, no reference).
paras = config.get('paras', {})
radius = paras.get('radius', None)
scale_rate = paras.get('scale_rate', 1)
n_sections = len(config['sections'])
if isinstance(radius, list) and len(radius) != n_sections:
    raise ValueError("The length of radius should match the number of sections.")
if isinstance(scale_rate, list) and len(scale_rate) != n_sections:
    raise ValueError("The length of scale_rate should match the number of sections.")
# Warn when a single numeric value will be broadcast to several sections.
# Floats are included: the previous `__class__ == int` test missed e.g. 0.5.
if n_sections > 1 and (isinstance(radius, (int, float)) or isinstance(scale_rate, (int, float))):
    print("Warning: The radius or scale_rate is not specified for each section, use the same value for all sections.")

# Optional mapping between section names; keys and values must both be known
# section names, and no section may appear on both sides.
reference = paras.get('reference', None)
all_section_names = [section['name'] for section in config['sections']]
if reference is not None:
    if not (set(reference.keys()).issubset(all_section_names)
            and set(reference.values()).issubset(all_section_names)):
        raise ValueError("The query or reference section should be in the section list.")
    if set(reference.keys()) & set(reference.values()):
        raise ValueError("No section should be both reference and query.")
# 0.2 path setting
root_path = config['settings']['root_path']
project = config['settings']['project']
results_path = f'{root_path}/{project}/Results_Rank{args.rank}/'

# 0.3 section setting
section_list = config['sections']

# 0.4 read section data
# Each section is loaded either from an AnnData file ('adata_path') or from
# visium / spot-coordinate / spot-expression paths.
sections = []
for idx, section in enumerate(section_list):
    section_name = section['name']
    tmp_radius = radius[idx] if isinstance(radius, list) else radius
    tmp_scale_rate = scale_rate[idx] if isinstance(scale_rate, list) else scale_rate

    # Options shared by every input format. create_mask used to be forwarded
    # only for AnnData sections, so --create_mask was silently ignored for
    # visium/spot-file sections.
    # NOTE(review): assumes prepare_stdata honours create_mask for these
    # inputs as it does for AnnData input.
    common_kwargs = dict(section_name=section_name,
                         image_path=section.get('image_path'),
                         scale_rate=tmp_scale_rate,
                         radius=tmp_radius,
                         swap_coord=args.swap_coord,
                         create_mask=args.create_mask)

    adata_path = section.get('adata_path')
    if adata_path is not None:
        adata = sc.read(adata_path)
        sections.append(hdmap.prepare_stdata(adata=adata, **common_kwargs))
        continue

    sections.append(hdmap.prepare_stdata(visium_path=section.get('visium_path'),
                                         spot_coord_path=section.get('spot_coord_path'),
                                         spot_exp_path=section.get('spot_exp_path'),
                                         **common_kwargs))

if args.select_svgs:
    hdmap.select_svgs(sections, n_top_genes=args.n_top_genes)
## -------------------------------------------------------------------------------------------------
# Build the SpaHDmap mapper over all prepared sections (with the optional
# reference mapping, the requested rank and the results path assembled
# above), then run the pipeline with the CLI-selected saving/plotting options.
mapper = hdmap.Mapper(
    section=sections,
    reference=reference,
    rank=args.rank,
    results_path=results_path,
    verbose=args.verbose,
)
mapper.run_SpaHDmap(
    save_model=args.save_model,
    save_score=args.save_score,
    visualize=args.visualize,
)