-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
99 lines (72 loc) · 3.97 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
from DroppedNeurons.utils import kitty
from DroppedNeurons import inference
from DroppedNeurons import app
from art import tprint
import gc
import argparse
from huggingface_hub import snapshot_download
from datasets import load_dataset
def download_assets(model_name="iiiorg/piiranha-v1-detect-personal-information", dataset_name="ai4privacy/pii-masking-400k"):
    """Download the pretrained model checkpoint and dataset from the Hugging Face Hub.

    Args:
        model_name: Hugging Face model repo id; snapshot is cached under ./checkpoints.
        dataset_name: Hugging Face dataset repo id; files are cached under ./dataset.
    """
    print("Downloading pretrained model weights")
    snapshot_download(repo_id=model_name, cache_dir="checkpoints")
    print("Downloading dataset")
    # load_dataset is called only for its download/caching side effect;
    # the returned DatasetDict is intentionally discarded (was an unused local).
    load_dataset(dataset_name, cache_dir="dataset")
def preprocess_data(dataset_in, dataset_out):
    """Preprocess the dataset at ``dataset_in`` and write the result to ``dataset_out``.

    Placeholder — no preprocessing is implemented yet; calling it is a no-op.
    """
    pass
def run_train(run_name, dataset_in, model_arch, epochs, lr, bs):
    """Train ``model_arch`` on ``dataset_in`` under the run label ``run_name``.

    Placeholder — training (``epochs`` epochs, learning rate ``lr``,
    batch size ``bs``) is not implemented yet; calling it is a no-op.
    """
    pass
def run_inference(dataset_in, model_arch):
    """Run inference on ``dataset_in`` with ``model_arch``.

    Thin wrapper that delegates to ``inference.run`` — despite the original
    "dummy" label, this dispatches real work to the project's inference module.
    """
    inference.run(dataset_in, model_arch)
def webui(host, port, public):
    """Start the web UI by delegating to ``app.mainFunc``.

    Args:
        host: bind address for the UI (e.g. 'localhost').
        port: port number (int, converted by argparse).
        public: the string 'true' or 'false' — passed through unconverted;
            presumably ``app.mainFunc`` interprets it. TODO confirm expected type.
    """
    app.mainFunc(host, port, public)
def main():
    """Parse CLI arguments and dispatch to the requested subcommand.

    Subcommands: download_assets, preprocess_data, run_train, run_inference, webui.
    Prints the parser help when no (or an unknown) subcommand is given.
    """
    parser = argparse.ArgumentParser(description="Run various tasks related to machine learning and data processing.")
    subparsers = parser.add_subparsers(dest="command", help="Available commands")
    # Download assets
    download_parser = subparsers.add_parser("download_assets", help="Download model and dataset")
    download_parser.add_argument("model_name", help="Huggingface model repo")
    download_parser.add_argument("dataset_name", help="Huggingface dataset repo")
    # Subparser for data preprocessing
    preprocess_parser = subparsers.add_parser("preprocess_data", help="Preprocess the input dataset.")
    preprocess_parser.add_argument("dataset_in", help="Path to the input dataset file.")
    preprocess_parser.add_argument("dataset_out", help="Path to save the output preprocessed dataset.")
    # Subparser for training
    train_parser = subparsers.add_parser("run_train", help="Train a model with specified parameters.")
    train_parser.add_argument("run_name", help="Unique name for the training run.")
    train_parser.add_argument("dataset_in", help="Path to the input dataset file for training.")
    train_parser.add_argument("model_arch", help="Model architecture to use for training.")
    train_parser.add_argument("epochs", type=int, help="Number of epochs to train the model.")
    train_parser.add_argument("lr", type=float, help="Learning rate for the optimizer.")
    train_parser.add_argument("bs", type=int, help="Batch size for training.")
    # Subparser for inference
    inference_parser = subparsers.add_parser("run_inference", help="Run inference on a dataset.")
    inference_parser.add_argument("dataset_in", help="Path to the input dataset file for inference.")
    inference_parser.add_argument("model_arch", help="Model architecture to use for inference.")
    # Subparser for web UI
    webui_parser = subparsers.add_parser("webui", help="Start a web UI for interaction.")
    webui_parser.add_argument("host", help="Host address for the web UI (e.g., 'localhost').")
    webui_parser.add_argument("port", type=int, help="Port number for the web UI.")
    webui_parser.add_argument("public", help="Set to 'true' for public access or 'false' for private.")
    # Parse arguments
    args = parser.parse_args()
    # Pretty Title
    tprint("DroppedNeurons")
    kitty()
    # Call the appropriate function based on command
    if args.command == "preprocess_data":
        preprocess_data(args.dataset_in, args.dataset_out)
    elif args.command == "run_train":
        run_train(args.run_name, args.dataset_in, args.model_arch, args.epochs, args.lr, args.bs)
    elif args.command == "run_inference":
        run_inference(args.dataset_in, args.model_arch)
    elif args.command == "webui":
        webui(args.host, args.port, args.public)
    elif args.command == "download_assets":
        download_assets(args.model_name, args.dataset_name)
    else:
        # Fix: previously running with no subcommand silently did nothing
        # after printing the banner — show usage instead.
        parser.print_help()
# Script entry point: run the CLI only when executed directly, not on import.
if __name__ == "__main__":
    main()