-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
executable file
·87 lines (73 loc) · 3.71 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
from data_pipeline import load_and_preprocess_data, augment_data
from feature_selection import select_features
from model import train_model_cv, evaluate_model, save_model, compare_models, optimize_hyperparameters
from interpretation import explain_model
from ensemble import create_ensemble
from visualization import plot_feature_importance, plot_confusion_matrix, plot_correlation_matrix, visualize_results
#!/usr/bin/env python3
import argparse
import sys
from Bio import SeqIO
from Bio.Seq import Seq
from features import extract_features
from model import load_model, predict_new_data, train_model_cv, save_model
from data_pipeline import load_and_preprocess_data, augment_data
import numpy as np
def generate_model(
    positive_file: str,
    negative_file: str,
    model_path: str = "protein_interaction_model.joblib",
) -> None:
    """Train, visualize, and save the protein interaction model.

    Steps:
    1. Load and preprocess the positive/negative FASTA inputs.
    2. Augment the resulting samples.
    3. Train via ``train_model_cv``.
    4. Plot results.
    5. Save the trained model together with its scaler.

    Args:
        positive_file: Path to the FASTA file of interacting sequences.
        negative_file: Path to the FASTA file of non-interacting sequences.
        model_path: Where to save the model/scaler pair. Defaults to
            ``"protein_interaction_model.joblib"``, which is also the
            file the prediction branch of ``main`` loads.
    """
    print("Loading and preprocessing data...")
    X, y, feature_names = load_and_preprocess_data(positive_file, negative_file)

    print("Augmenting data...")
    X_augmented, y_augmented = augment_data(X, y)

    print("Training model...")
    model, scaler = train_model_cv(X_augmented, y_augmented)

    print("Generating visualization plots...")
    # NOTE(review): predictions are made on the same (augmented) data the
    # model was trained on, so any scores shown in these plots are
    # training-set figures and are likely optimistic.
    y_pred = predict_new_data(model, scaler, X_augmented)
    visualize_results(model, X_augmented, y_augmented, y_pred, feature_names)

    print("Saving model...")
    save_model(model, scaler, model_path)
    print(f"Model generation complete. Saved as '{model_path}'")
    print("Visualization plots have been generated and saved.")
def _fail(message: str):
    """Print *message* to stderr and terminate with exit status 1."""
    print(message, file=sys.stderr)
    sys.exit(1)


def _read_first_sequence(path: str):
    """Return the sequence of the first FASTA record in *path*.

    Only the first record is used; any further records are ignored.
    Exits with status 1 (message on stderr) if the file is missing,
    cannot be opened, or contains no FASTA records.
    """
    try:
        with open(path, "r") as handle:
            record = next(SeqIO.parse(handle, "fasta"))
    except FileNotFoundError:
        _fail(f"Error: File '{path}' not found.")
    except StopIteration:
        _fail(f"Error: No sequences found in '{path}'.")
    except OSError as err:
        # Directories, permission problems, etc. (FileNotFoundError is an
        # OSError subclass and is handled above): report, don't traceback.
        _fail(f"Error: Could not read '{path}': {err}")
    return record.seq


def main():
    """Command-line entry point: build a model or classify one sequence.

    With ``--generate``, trains and saves a model from the ``--positive`` and
    ``--negative`` FASTA files via ``generate_model``. Otherwise reads the
    first record of ``sequence_file``, extracts its features, and classifies
    it with the model stored in ``protein_interaction_model.joblib`` in the
    current directory. Usage errors are reported on stderr with exit status 1.
    """
    parser = argparse.ArgumentParser(description="Predict protein interactions or generate model.")
    parser.add_argument("--generate", action="store_true", help="Generate new model")
    parser.add_argument("--positive", help="Path to positive interaction FASTA file (for model generation)")
    parser.add_argument("--negative", help="Path to negative interaction FASTA file (for model generation)")
    parser.add_argument("sequence_file", nargs="?", help="Path to the protein sequence file (FASTA format) for prediction")
    args = parser.parse_args()

    if args.generate:
        if not args.positive or not args.negative:
            _fail("Error: Both --positive and --negative files are required for model generation.")
        generate_model(args.positive, args.negative)
        return

    if not args.sequence_file:
        _fail("Error: Please provide either --generate with --positive and --negative files, or a sequence file for prediction.")

    sequence = _read_first_sequence(args.sequence_file)

    # Build a single-row 2-D feature matrix from the feature dict's values.
    # NOTE(review): column order follows the dict's insertion order; it must
    # match the feature order used at training time — confirm.
    features = extract_features(sequence)
    X_new = np.array([list(features.values())])

    try:
        model, scaler = load_model("protein_interaction_model.joblib")
    except FileNotFoundError:
        _fail("Error: Pre-trained model not found. Please ensure 'protein_interaction_model.joblib' is in the current directory.")

    prediction = predict_new_data(model, scaler, X_new)
    result = "likely to interact" if prediction[0] == 1 else "unlikely to interact"
    print(f"The protein sequence in '{args.sequence_file}' is {result}.")


if __name__ == "__main__":
    main()