-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path hyparams.py
61 lines (48 loc) · 1.59 KB
/
hyparams.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
#!/usr/bin/env python
# hparams
# Prints out hyperparameters and their defaults for various models.
#
# Author: Benjamin Bengfort <benjamin@bengfort.com>
# Created: Wed Nov 27 10:45:19 2019 -0500
#
# Copyright (C) 2019 Georgetown Data Analytics (CCPE)
# For license information, see LICENSE.txt
#
# ID: hparams.py [] benjamin@bengfort.com $
"""
Prints out hyperparameters and their defaults for various models.
"""
##########################################################################
## Imports
##########################################################################
import pprint
import argparse
from sklearn.svm import SVC
from sklearn.linear_model import LogisticRegression
from sklearn.naive_bayes import GaussianNB, MultinomialNB, BernoulliNB, ComplementNB
# Registry of the model names accepted on the command line, each mapped to
# the scikit-learn estimator class whose default hyperparameters are printed.
MODELS = dict(
    svm=SVC,
    logistic=LogisticRegression,
    gaussiannb=GaussianNB,
    multinomialnb=MultinomialNB,
    bernoullinb=BernoulliNB,
    complementnb=ComplementNB,
)
##########################################################################
## Main Method
##########################################################################
def main(args, models=None):
    """
    Print the default hyperparameters of each requested model.

    Each requested estimator class is instantiated with no arguments and the
    result of its ``get_params()`` is pretty-printed under a header line of
    the form ``"<name> (<ClassName>):"``. The header lets output for several
    models be told apart. Entries are separated by a single blank line.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed command-line arguments; ``args.model`` is an iterable of model
        names, each of which must be a key of ``models``.
    models : dict, optional
        Mapping of model name to estimator class. Defaults to the module-level
        ``MODELS`` registry; pass a different mapping to report other
        estimators.

    Raises
    ------
    KeyError
        If a name in ``args.model`` is not a key of ``models`` (the CLI's
        ``choices`` prevents this when invoked from the command line).
    """
    if models is None:
        models = MODELS

    for name in args.model:
        estimator = models[name]
        # Label each dump so multi-model output is not an anonymous run of
        # dicts.
        print("{} ({}):".format(name, estimator.__name__))
        pprint.pprint(estimator().get_params())
        # A bare print() ends with exactly one newline, giving a single blank
        # separator line (print("\n") would produce two).
        print()
if __name__ == "__main__":
    # Command-line entry point: accept one or more registered model names and
    # hand the parsed namespace to main() for reporting.
    cli = argparse.ArgumentParser(description="prints out hyperparameters and their defaults")
    cli.add_argument(
        "model",
        nargs="+",
        choices=MODELS.keys(),
        help="the models for whom to print out the params and defaults",
    )
    main(cli.parse_args())