Skip to content

Commit

Permalink
in plotting val/train curve:
Browse files Browse the repository at this point in the history
ValueError: object __array__ method not producing an array
  • Loading branch information
rvankoert committed Feb 4, 2025
1 parent ee24977 commit beb1a30
Showing 1 changed file with 19 additions and 2 deletions.
21 changes: 19 additions & 2 deletions src/modes/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,10 +99,24 @@ def train_model(model: tf.keras.Model,
def plot_metric(metric, history, title, output_path, plot_validation_metric):
plt.style.use("ggplot")
plt.figure()

# Print the history dictionary to debug
print(f"History dictionary: {history.history}")

# Check if the metric exists in the history
if metric not in history.history:
raise ValueError(f"Metric '{metric}' not found in history")

# Plot the training metric
plt.plot(history.history[metric], label='Training ' + metric)

# Plot the validation metric if requested
if plot_validation_metric:
plt.plot(history.history[f"val_{metric}"],
label=f"Validation {metric}")
val_metric = f"val_{metric}"
if val_metric not in history.history:
raise ValueError(f"Validation metric '{val_metric}' not found in history")
plt.plot(history.history[val_metric], label=f"Validation {metric}")

plt.title(title)
plt.xlabel("Epoch #")
plt.ylabel(metric)
Expand Down Expand Up @@ -134,6 +148,9 @@ def plot_training_history(history: tf.keras.callbacks.History,
other for Character Error Rate (CER).
"""

if not os.path.exists(output_path):
os.makedirs(output_path)

plot_metric(metric="loss",
history=history,
title="Training Loss",
Expand Down

0 comments on commit beb1a30

Please sign in to comment.