Skip to content

Commit

Permalink
refactor: plotting and analysis functions for optimization results
Browse files Browse the repository at this point in the history
  • Loading branch information
lorenzifrancesco committed Jan 21, 2025
1 parent 84721ab commit 7b9c2f7
Show file tree
Hide file tree
Showing 3 changed files with 145 additions and 87 deletions.
86 changes: 86 additions & 0 deletions scripts/modules/plot_optimization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
from .cfg import Config, get_next_filename
import numpy as np
from matplotlib import pyplot as plt
from matplotlib.cm import viridis
from pynlin.utils import watt2dBm


def plot_profiles(signal_wavelengths,
signal_solution,
ase_solution,
pump_wavelengths,
pump_solution,
cf: Config):
plt.clf()
plt.figure(figsize=(4, 3))
cmap = viridis
z_plot = np.linspace(0, cf.fiber_length, len(pump_solution[:, 0, 0])) * 1e-3
# lss = ["-", "--", "-.", ":", "-"]
mode_labels = ["LP01", "LP11", "LP21", "LP02"]
for i in range(cf.n_modes):
plt.plot(z_plot,
watt2dBm(signal_solution[:, :, i]), color=cmap(i / cf.n_modes + 0.2), alpha=0.3)
plt.plot(z_plot,
watt2dBm(ase_solution[:, :, i]), color=cmap(i / cf.n_modes + 0.2), alpha=0.3, ls="-.")
plt.ylabel(r"$P$ [dBm]")
plt.xlabel(r"$z$ [km]")
# plt.legend()
plt.tight_layout()
plt.grid(False)
plt.savefig(get_next_filename("media/optimization/signal_ase_profile", "pdf"))
plt.clf()
#
plt.figure(figsize=(4, 3))
cmap = viridis
z_plot = np.linspace(0, cf.fiber_length, len(pump_solution[:, 0, 0])) * 1e-3
#
for i in range(cf.n_modes):
plt.plot(z_plot,
watt2dBm(pump_solution[:, :, i]), color=cmap(i / cf.n_modes + 0.2), alpha=0.3)
plt.grid(False)
plt.ylabel(r"$P$ [dBm]")
plt.xlabel(r"$z$ [km]")
# plt.legend()
plt.tight_layout()
plt.savefig(get_next_filename("media/optimization/pump_profile", "pdf"))
#
loss = -0.2e-3 * cf.fiber_length
on_off_gain = -loss + cf.raman_gain
plt.clf()
plt.figure(figsize=(4, 3))
for i in range(cf.n_modes):
plt.plot(signal_wavelengths * 1e6,
watt2dBm(signal_solution[-1, :, i]) - cf.launch_power - loss,
label=mode_labels[i],
color=cmap(i / cf.n_modes + 0.2))
plt.legend()
plt.axhline(on_off_gain, ls="--", color="black")
plt.xlabel(r"Channel Wavelength [$\mu$ m]")
plt.ylabel("Gain [dB]")
plt.tight_layout()
plt.savefig(get_next_filename("media/optimization/flatness", "pdf"))
print(f"Plot saved.")
return

def analyze_optimization(
signal_wavelengths,
signal_solution,
ase_solution,
pump_wavelengths,
pump_solution,
cf):
flatness = np.max(signal_solution[-1, :, :]) - np.min(signal_solution[-1, :, :])
approx_loss = -0.2e-3 * cf.fiber_length
avg_ase = np.mean(ase_solution[-1, :, :])
avg_pump_power_0 = np.mean(pump_solution[0, :, :])
avg_pump_power_L = np.mean(pump_solution[-1, :, :])
print(f"{'Optimization metric':<30} | {'Value':>10}")
print("-" * 43)
print(f"{'Flatness':<30} | {flatness:.5e} dB")
print(f"{'Loss':<30} | {approx_loss:.5e} dB")
print(f"{'ASE':<30} | {avg_ase:.5e} dB")
print(f"{'Average pump power at z=0':<30} | {avg_pump_power_0:.5e} dBm")
print(f"{'Average pump power at z=L':<30} | {avg_pump_power_L:.5e} dBm")
return


143 changes: 58 additions & 85 deletions scripts/optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from pynlin.raman.solvers import MMFRamanAmplifier as NumpyMMFRamanAmplifier
from pynlin.utils import dBm2watt, watt2dBm
import pynlin.constellations

from modules.plot_optimization import plot_profiles, analyze_optimization

def ct_solver(power_per_pump,
pump_band_a,
Expand Down Expand Up @@ -61,6 +61,7 @@ def ct_solver(power_per_pump,
oi_fit = oi_avg_complete

ref_bandwidth = cf.baud_rate
#
fiber = pynlin.fiber.MMFiber(
effective_area=80e-12,
n_modes=cf.n_modes,
Expand All @@ -72,7 +73,6 @@ def ct_solver(power_per_pump,
num_channels=cf.n_channels,
center_frequency=cf.center_frequency
)
#
integration_steps = 1000
z_max = np.linspace(0, fiber.length, integration_steps)
np.save("z_max.npy", z_max)
Expand Down Expand Up @@ -171,90 +171,63 @@ def ct_solver(power_per_pump,
print("_" * 35)
print(f"> final flatness: {flatness:.2f} dB | {flatness_percent:.2f} %")
print("_" * 35)
#
plot_profiles(signal_wavelengths, signal_solution, ase_solution,
pump_wavelengths, pump_solution, cf)
return pump_solution, signal_solution, ase_solution, pump_wavelengths, pump_powers


def plot_profiles(signal_wavelengths,
signal_solution,
ase_solution,
pump_wavelengths,
pump_solution,
cf: cfg.Config):
plt.clf()
plt.figure(figsize=(4, 3))
cmap = viridis
z_plot = np.linspace(0, cf.fiber_length, len(pump_solution[:, 0, 0])) * 1e-3
lss = ["-", "--", "-.", ":", "-"]
mode_labels = ["LP01", "LP11", "LP21", "LP02"]
for i in range(cf.n_modes):
print(i)
plt.plot(z_plot,
watt2dBm(signal_solution[:, :, i]), color=cmap(i / cf.n_modes + 0.2), alpha=0.3)
plt.plot(z_plot,
watt2dBm(ase_solution[:, :, i]), color=cmap(i / cf.n_modes + 0.2), alpha=0.3, ls="-.")
plt.ylabel(r"$P$ [dBm]")
plt.xlabel(r"$z$ [km]")
# plt.legend()
plt.tight_layout()
plt.grid(False)
plt.savefig(cfg.get_next_filename("media/signal_profile", "pdf"))
plt.clf()

plt.figure(figsize=(4, 3))
cmap = viridis
z_plot = np.linspace(0, cf.fiber_length, len(pump_solution[:, 0, 0])) * 1e-3

for i in range(cf.n_modes):
plt.plot(z_plot,
watt2dBm(pump_solution[:, :, i]), color=cmap(i / cf.n_modes + 0.2), alpha=0.3)
plt.grid(False)
plt.ylabel(r"$P$ [dBm]")
plt.xlabel(r"$z$ [km]")
# plt.legend()
plt.tight_layout()
plt.savefig(cfg.get_next_filename("media/pump_profile", "pdf"))

loss = -0.2e-3 * cf.fiber_length
on_off_gain = -loss + cf.raman_gain
plt.clf()
plt.figure(figsize=(4, 3))
for i in range(cf.n_modes):
plt.plot(signal_wavelengths * 1e6,
watt2dBm(signal_solution[-1, :, i]) - cf.launch_power - loss,
label=mode_labels[i],
color=cmap(i / cf.n_modes + 0.2))
plt.legend()
plt.axhline(on_off_gain, ls="--", color="black")
plt.xlabel(r"Channel Wavelength [$\mu$ m]")
plt.ylabel("Gain [dB]")
plt.tight_layout()
plt.savefig(cfg.get_next_filename("media/flatness", "pdf"))
return


if __name__ == "__main__":
signal_powers = [-10, -5, 0]
for ix in range(3):
# write the signal power inside the config.toml file
signal_power = signal_powers[ix]
cf = cfg.load_toml_to_struct("./input/config.toml")
cf.launch_power = signal_power
cfg.save_struct_to_toml("./input/config.toml", cf)
pump_sol, signal_sol, ase_sol, pump_wavelengths, pump_powers = ct_solver(power_per_pump = 11,
pump_band_a = 1410e-9,
pump_band_b = 1520e-9,
learning_rate = 2e-2,
epochs = 500,
lock_wavelengths = 200,
batch_size = 1,
use_precomputed = True,
optimize = True,
use_avg_oi = True
)
print(" -> average ASE at the end: ", np.mean(ase_sol[-1, :, :]))
variables_dict = {name: value for name, value in locals().items() if name in ['pump_sol', 'signal_sol', 'ase_sol', 'pump_wavelengths', 'pump_powers']}
np.save("results/ct_solution" + str(signal_power) + "_gain_" + str(cf.raman_gain) + ".npy", variables_dict)
print("Results saved to file: ", "results/ct_solution" + str(signal_power) + "_gain_" + str(cf.raman_gain) + ".npy")
recompute = False # Set to True to force re-computation
signal_powers = [-10, -5, 0]

for ix in range(3):
signal_power = signal_powers[ix]
cf = cfg.load_toml_to_struct("./input/config.toml")
cf.launch_power = signal_power
cfg.save_struct_to_toml("./input/config.toml", cf)
output_file = f"results/ct_solution{signal_power}_gain_{cf.raman_gain}.npy"

if not os.path.exists(output_file) or recompute:
pump_sol, signal_sol, ase_sol, pump_wavelengths, pump_powers = ct_solver(
power_per_pump = 6,
pump_band_a = 1410e-9,
pump_band_b = 1520e-9,
learning_rate = 1e-2,
epochs = 1500,
lock_wavelengths = 200,
batch_size = 1,
use_precomputed = False,
optimize = True,
use_avg_oi = True
)
print(" -> average ASE at the end: ", np.mean(ase_sol[-1, :, :]))
variables_dict = {
name: value
for name, value in locals().items()
if name in ['pump_sol', 'signal_sol', 'ase_sol', 'pump_wavelengths', 'pump_powers']
}
np.save(output_file, variables_dict)
print("Results saved to file: ", output_file)
else:
print(f"File {output_file} already exists. Loading data...")
variables_dict = np.load(output_file, allow_pickle=True).item()

wdm = pynlin.wdm.WDM(
spacing=cf.channel_spacing,
num_channels=cf.n_channels,
center_frequency=cf.center_frequency
)
plot_profiles(
signal_wavelengths = wdm.wavelength_grid(),
signal_solution = variables_dict['signal_sol'],
ase_solution = variables_dict['ase_sol'],
pump_wavelengths = variables_dict['pump_wavelengths'],
pump_solution = variables_dict['pump_sol'],
cf = cf
)
analyze_optimization(
signal_wavelengths = wdm.wavelength_grid(),
signal_solution = variables_dict['signal_sol'],
ase_solution = variables_dict['ase_sol'],
pump_wavelengths = variables_dict['pump_wavelengths'],
pump_solution = variables_dict['pump_sol'],
cf = cf
)
3 changes: 1 addition & 2 deletions scripts/watch_convergence.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ def __init__(self, file_path, n_modes=2):
def update_plot(self):
try:
signals = np.load(self.file_path)
print(signals)
for i, line in enumerate(self.lines):
if i == self.n_modes:
line.set_data(range(signals.shape[0]), np.ones(signals.shape[0]) * np.mean(signals, axis=(0, 1)) + self.ref)
Expand Down Expand Up @@ -65,4 +64,4 @@ def monitor_file(file_path, n_modes=2):

if __name__ == "__main__":
file_path = "results/gain_walker.npy"
monitor_file(file_path, n_modes=1)
monitor_file(file_path, n_modes=4)

0 comments on commit 7b9c2f7

Please sign in to comment.