add and run pre-commit
wenh06 committed Oct 7, 2024
1 parent af55734 commit 51a9313
Showing 25 changed files with 504 additions and 1,288 deletions.
1 change: 0 additions & 1 deletion .gitignore
@@ -154,4 +154,3 @@ log
 *.pth.tar

 !saved_models/*.pth.tar
-
25 changes: 25 additions & 0 deletions .pre-commit-config.yaml
@@ -0,0 +1,25 @@
# ignore svg, txt, json, html files and the folders final_results
exclude: '^.*\.(svg|txt|json|html)$|torch_ecg|references|official_sample_code|signal_processing/pantompkins\.py|\.ipynb_checkpoints|fast-test.*\.ipynb'
fail_fast: false

repos:
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v5.0.0
    hooks:
      - id: end-of-file-fixer
      - id: trailing-whitespace
  - repo: https://github.com/psf/black
    rev: 24.8.0
    hooks:
      - id: black
        args: [--line-length=128, --verbose]
  - repo: https://github.com/PyCQA/flake8
    rev: 7.1.1
    hooks:
      - id: flake8
        args: [--max-line-length=128, '--exclude=./.*,build,dist,official*,torch_ecg,references,*.ipynb', '--ignore=E501,W503,E203,F841,E402,E231,E731', --count, --statistics, --show-source]
  - repo: https://github.com/pycqa/isort
    rev: 5.13.2
    hooks:
      - id: isort
        args: [--profile=black, --line-length=128]
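The `exclude` value above is a Python regular expression that pre-commit matches against each candidate file path (roughly `re.search` semantics); any path that matches is skipped by every hook. A minimal sketch of that filtering, where the `is_excluded` helper and the sample paths are illustrative and not from the repository:

```python
import re

# Same pattern as in .pre-commit-config.yaml above.
EXCLUDE = (
    r"^.*\.(svg|txt|json|html)$|torch_ecg|references|official_sample_code"
    r"|signal_processing/pantompkins\.py|\.ipynb_checkpoints|fast-test.*\.ipynb"
)

def is_excluded(path: str) -> bool:
    """Rough approximation of how pre-commit applies the exclude pattern."""
    return re.search(EXCLUDE, path) is not None

# Illustrative paths: the first three are skipped, the last one is checked by the hooks.
for p in ["results/plot.svg", "torch_ecg/utils.py", "official_sample_code/demo.py", "cfg.py"]:
    print(p, "->", "excluded" if is_excluded(p) else "checked")
```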
69 changes: 28 additions & 41 deletions cfg.py
@@ -8,36 +8,35 @@

 from torch_ecg.torch_ecg.model_configs import (  # noqa: F401
     ECG_SEQ_LAB_NET_CONFIG,
-    RR_LSTM_CONFIG,
+    ECG_SUBTRACT_UNET_CONFIG,
+    ECG_UNET_VANILLA_CONFIG,
     RR_AF_CRF_CONFIG,
     RR_AF_VANILLA_CONFIG,
-    ECG_UNET_VANILLA_CONFIG,
-    ECG_SUBTRACT_UNET_CONFIG,
-    vgg_block_basic,
-    vgg_block_mish,
-    vgg_block_swish,
-    vgg16,
-    vgg16_leadwise,
-    resnet_block_stanford,
-    resnet_stanford,
-    resnet_block_basic,
-    resnet_bottle_neck,
-    resnet,
-    resnet_leadwise,
-    multi_scopic_block,
-    multi_scopic,
-    multi_scopic_leadwise,
-    dense_net_leadwise,
-    xception_leadwise,
-    lstm,
+    RR_LSTM_CONFIG,
     attention,
+    dense_net_leadwise,
+    global_context,
     linear,
+    lstm,
+    multi_scopic,
+    multi_scopic_block,
+    multi_scopic_leadwise,
     non_local,
+    resnet,
+    resnet_block_basic,
+    resnet_block_stanford,
+    resnet_bottle_neck,
+    resnet_leadwise,
+    resnet_stanford,
     squeeze_excitation,
-    global_context,
+    vgg16,
+    vgg16_leadwise,
+    vgg_block_basic,
+    vgg_block_mish,
+    vgg_block_swish,
+    xception_leadwise,
 )


 __all__ = [
     "BaseCfg",
     "TrainCfg",
@@ -72,13 +71,9 @@
"paroxysmal atrial fibrillation": 2,
"persistent atrial fibrillation": 1,
}
BaseCfg.class_abbr_map = {
k: BaseCfg.class_fn_map[v] for k, v in BaseCfg.class_abbr2fn.items()
}
BaseCfg.class_abbr_map = {k: BaseCfg.class_fn_map[v] for k, v in BaseCfg.class_abbr2fn.items()}

BaseCfg.bias_thr = (
0.15 * BaseCfg.fs
) # rhythm change annotations onsets or offset of corresponding R peaks
BaseCfg.bias_thr = 0.15 * BaseCfg.fs # rhythm change annotations onsets or offset of corresponding R peaks
BaseCfg.beat_ann_bias_thr = 0.1 * BaseCfg.fs # half width of broad qrs complex
BaseCfg.beat_winL = 250 * BaseCfg.fs // 1000 # corr. to 250 ms
BaseCfg.beat_winR = 250 * BaseCfg.fs // 1000 # corr. to 250 ms
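For concreteness, taking fs = 200 Hz (the CPSC2021 sampling frequency; BaseCfg.fs itself is set above this hunk and is assumed here for illustration), these thresholds translate into sample counts as follows. This is a standalone calculation, not code from the repository:

```python
fs = 200  # Hz; assumed here, matching the CPSC2021 recordings

bias_thr = 0.15 * fs          # 30.0 samples (~150 ms) around rhythm-change onsets/offsets
beat_ann_bias_thr = 0.1 * fs  # 20.0 samples (~100 ms), half width of a broad QRS complex
beat_winL = 250 * fs // 1000  # 50 samples, i.e. 250 ms to the left of an R peak
beat_winR = 250 * fs // 1000  # 50 samples, i.e. 250 ms to the right of an R peak
print(bias_thr, beat_ann_bias_thr, beat_winL, beat_winR)  # 30.0 20.0 50 50
```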
@@ -129,9 +124,7 @@

 TrainCfg.label_smoothing = 0.1
 TrainCfg.random_mask = int(TrainCfg.fs * 0.0)  # 1.0s, 0 for no masking
-TrainCfg.stretch_compress = (
-    5  # stretch or compress in time axis, units in percentage (0 - inf)
-)
+TrainCfg.stretch_compress = 5  # stretch or compress in time axis, units in percentage (0 - inf)
 TrainCfg.stretch_compress_prob = 0.3  # probability of performing stretch or compress
 TrainCfg.random_normalize = True  # (re-)normalize to random mean and std
 # valid segments has
@@ -160,9 +153,7 @@
 # [0.0, 0.01],
 # ])

-TrainCfg.flip = [-1] + [
-    1
-] * 4  # making the signal upside down, with probability 1/(1+4)
+TrainCfg.flip = [-1] + [1] * 4  # making the signal upside down, with probability 1/(1+4)
 # TODO: explore and add more data augmentations

 # configs of training epochs, batch, etc.
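As the inline comment says, `TrainCfg.flip` encodes the flip probability through the makeup of the list: drawing one element uniformly from `[-1, 1, 1, 1, 1]` yields -1 (signal turned upside down) one time in five. A tiny illustration of that sampling; the helper code is hypothetical, not from the repository:

```python
import random

flip = [-1] + [1] * 4  # same construction as TrainCfg.flip

# Multiplying a signal by -1 flips it; drawing uniformly from `flip`
# does so with probability 1 / (1 + 4) = 0.2.
n_trials = 100_000
n_flipped = sum(random.choice(flip) == -1 for _ in range(n_trials))
print(n_flipped / n_trials)  # ~0.2
```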
@@ -228,9 +219,7 @@
 TrainCfg.rr_lstm.model_name = "lstm"  # "lstm", "lstm_crf"
 TrainCfg.rr_lstm.input_len = 30  # number of rr intervals ( number of rpeaks - 1)
 TrainCfg.rr_lstm.overlap_len = 15  # number of rr intervals ( number of rpeaks - 1)
-TrainCfg.rr_lstm.critical_overlap_len = (
-    25  # number of rr intervals ( number of rpeaks - 1)
-)
+TrainCfg.rr_lstm.critical_overlap_len = 25  # number of rr intervals ( number of rpeaks - 1)
 TrainCfg.rr_lstm.classes = [
     "af",
 ]
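With input_len = 30 and overlap_len = 15, consecutive RR-interval windows fed to the LSTM would advance by input_len - overlap_len = 15 intervals. A rough sketch of that windowing, using a hypothetical helper rather than the repository's dataset code (how critical_overlap_len is applied is not shown in this hunk):

```python
def rr_windows(n_rr: int, input_len: int = 30, overlap_len: int = 15):
    """Yield (start, end) index pairs over a sequence of n_rr RR intervals."""
    stride = input_len - overlap_len  # 15 intervals between window starts
    for start in range(0, max(n_rr - input_len + 1, 1), stride):
        yield start, start + input_len

# e.g. for a record with 100 RR intervals:
print(list(rr_windows(100)))  # [(0, 30), (15, 45), (30, 60), (45, 75), (60, 90)]
```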
@@ -352,12 +341,10 @@
 ModelCfg.main.unet.reduction = 1
 ModelCfg.main.unet.init_num_filters = 16  # keep the same with n_classes
 ModelCfg.main.unet.down_num_filters = [
-    ModelCfg.main.unet.init_num_filters * (2**idx)
-    for idx in range(1, ModelCfg.main.unet.down_up_block_num + 1)
+    ModelCfg.main.unet.init_num_filters * (2**idx) for idx in range(1, ModelCfg.main.unet.down_up_block_num + 1)
 ]
 ModelCfg.main.unet.up_num_filters = [
-    ModelCfg.main.unet.init_num_filters * (2**idx)
-    for idx in range(ModelCfg.main.unet.down_up_block_num - 1, -1, -1)
+    ModelCfg.main.unet.init_num_filters * (2**idx) for idx in range(ModelCfg.main.unet.down_up_block_num - 1, -1, -1)
 ]
 ModelCfg.main.unet.up_mode = "deconv"
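To make the two comprehensions concrete: with init_num_filters = 16 and, say, down_up_block_num = 4 (the latter is set above this hunk and is assumed here purely for illustration), the encoder widths double at each level while the decoder mirrors them back down to 16. A standalone sketch:

```python
init_num_filters = 16   # from the config above
down_up_block_num = 4   # hypothetical value; the real one is set earlier in cfg.py

down_num_filters = [init_num_filters * (2**idx) for idx in range(1, down_up_block_num + 1)]
up_num_filters = [init_num_filters * (2**idx) for idx in range(down_up_block_num - 1, -1, -1)]

print(down_num_filters)  # [32, 64, 128, 256]
print(up_num_filters)    # [128, 64, 32, 16]
```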
