Skip to content

Commit

Permalink
debug binary classifier data preparation
Browse files Browse the repository at this point in the history
  • Loading branch information
kermitt2 committed Jan 28, 2024
1 parent 8641e73 commit 7061ff9
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions delft/applications/licenseClassifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ def train_and_eval_binary(embeddings_name, fold_count, architecture="gru", trans
y_train_class_rank = np.array(y_train_class_rank)
y_test_class_rank = np.array(y_test_class_rank)

list_classes_rank = [list_classes[class_rank], "not_"+list_classes[class_rank]]
list_classes_rank = [list_classes_copyright[class_rank], "not_"+list_classes_copyright[class_rank]]

model = Classifier(model_name, architecture=architecture, list_classes=list_classes_rank, max_epoch=max_epoch, fold_number=fold_count, patience=patience,
use_roc_auc=True, embeddings_name=embeddings_name, batch_size=batch_size, maxlen=maxlen, early_stop=early_stop,
Expand Down Expand Up @@ -246,7 +246,7 @@ def train_and_eval_binary(embeddings_name, fold_count, architecture="gru", trans
y_train_class_rank = np.array(y_train_class_rank)
y_test_class_rank = np.array(y_test_class_rank)

list_classes_rank = [list_classes[class_rank], "not_"+list_classes[class_rank]]
list_classes_rank = [list_classes_licenses[class_rank], "not_"+list_classes_licenses[class_rank]]

model = Classifier(model_name, architecture=architecture, list_classes=list_classes_rank, max_epoch=max_epoch, fold_number=fold_count, patience=patience,
use_roc_auc=True, embeddings_name=embeddings_name, batch_size=batch_size, maxlen=maxlen, early_stop=early_stop,
Expand Down

0 comments on commit 7061ff9

Please sign in to comment.