From 2d118f9d40ee09572dc9625b5c2a83beaf8e4873 Mon Sep 17 00:00:00 2001 From: lorenzifrancesco Date: Tue, 21 Jan 2025 10:02:39 +0100 Subject: [PATCH] add: fault detection in torch RK4 --- pynlin/raman/pytorch/_torch_ode.py | 2 +- pynlin/raman/pytorch/solvers.py | 126 ++++++++++++++++----------- scripts/modules/cfg.py | 2 +- scripts/modules/plot_optimization.py | 4 +- scripts/optimize.py | 6 +- 5 files changed, 82 insertions(+), 58 deletions(-) diff --git a/pynlin/raman/pytorch/_torch_ode.py b/pynlin/raman/pytorch/_torch_ode.py index fceef01..aecef59 100644 --- a/pynlin/raman/pytorch/_torch_ode.py +++ b/pynlin/raman/pytorch/_torch_ode.py @@ -30,7 +30,7 @@ def torch_rk4(func, y0, t, *args, **kwargs): new_state = y + 1 / 6 * (k1 + 2 * k2 + 2 * k3 + k4) if torch.isnan(new_state).any() or torch.isinf(new_state).any(): - + print("\033[41mWARN:\033[0m failure detected in RK4 propagation") return y y = new_state diff --git a/pynlin/raman/pytorch/solvers.py b/pynlin/raman/pytorch/solvers.py index d012b4c..33a6e85 100644 --- a/pynlin/raman/pytorch/solvers.py +++ b/pynlin/raman/pytorch/solvers.py @@ -11,18 +11,22 @@ from pynlin.fiber import Fiber, MMFiber from pynlin.raman.pytorch._torch_ode import torch_rk4 from pynlin.raman.response import impulse_response -from pynlin.utils import nu2lambda +from pynlin.utils import nu2lambda, watt2dBm import matplotlib.pyplot as plt import seaborn as sns +# tmp +import matplotlib.pyplot as plt +from matplotlib.cm import viridis + class RamanAmplifier(torch.nn.Module): def __init__( self, fiber_length: float, integration_steps: int, - num_pumps: int, + n_pumps: int, signal_wavelengths: Union[list, NDArray], power_per_channel: float, fiber: Fiber, @@ -38,7 +42,7 @@ def __init__( The length of the fiber [m]. steps : int The number of integration steps. - num_pumps : int + n_pumps : int The number of Raman pumps. signal_wavelength : torch.Tensor The input signal wavelenghts. @@ -55,8 +59,8 @@ def __init__( super(RamanAmplifier, self).__init__() self.c0 = speed_of_light self.power_per_channel = power_per_channel - self.num_pumps = num_pumps - self.num_channels = signal_wavelengths.shape[0] + self.n_pumps = n_pumps + self.n_channels = signal_wavelengths.shape[0] self.signal_wavelengths = torch.from_numpy(signal_wavelengths) self.length = fiber_length self.steps = integration_steps @@ -65,7 +69,7 @@ def __init__( if isinstance(signal_wavelengths, np.ndarray): signal_wavelengths = torch.from_numpy(signal_wavelengths).float() - signal_power = self.power_per_channel * torch.ones((1, self.num_channels)) + signal_power = self.power_per_channel * torch.ones((1, self.n_channels)) # limit the polynomial fit of the attenuation spectrum to order 2 num_loss_coeffs = len(fiber.losses) @@ -84,7 +88,7 @@ def __init__( if isinstance(signal_loss, np.ndarray): signal_loss = torch.from_numpy(signal_loss) - # signal_loss = signal_loss.repeat_interleave(self.modes).view(1, -1) + # signal_loss = signal_loss.repeat_interleave(self.n_modes).view(1, -1) self.raman_coefficient = fiber.raman_coefficient @@ -126,21 +130,21 @@ def __init__( self.register_buffer("raman_response", raman_response) # # Doesn't matter, the pumps are turned off - # pump_lambda = torch.linspace(1420, 1480, self.num_pumps) * 1e-9 - # pump_power = torch.zeros((num_pumps * modes)) + # pump_lambda = torch.linspace(1420, 1480, self.n_pumps) * 1e-9 + # pump_power = torch.zeros((n_pumps * modes)) # x = torch.cat((pump_lambda, pump_power)).float().view(1, -1) if np.isscalar(pump_direction): - pump_direction = np.ones((self.num_pumps,)) * pump_direction + pump_direction = np.ones((self.n_pumps,)) * pump_direction else: pump_direction = np.atleast_1d(pump_direction) - signal_direction = np.ones((self.num_channels,)) + signal_direction = np.ones((self.n_channels,)) direction = np.concatenate((pump_direction, signal_direction)) self.direction = torch.from_numpy(direction).float() # if counterpumping: # self.counterpumping = True - # direction[: self.num_pumps * self.modes] = -1 + # direction[: self.n_pumps * self.n_modes] = -1 # else: # self.counterpumping = False @@ -237,7 +241,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: batch_size = x.shape[0] - num_freqs = self.num_channels + self.num_pumps + num_freqs = self.n_channels + self.n_pumps # This will be the input to the interpolation function interpolation_grid = torch.zeros( @@ -246,7 +250,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: device=x.device, ) - pump_wavelengths = x[:, : self.num_pumps] + pump_wavelengths = x[:, : self.n_pumps] # Compute the loss for each pump wavelength/mode pump_loss = self._alpha_to_linear( @@ -259,7 +263,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: losses = torch.cat( ( pump_loss, - self.signal_loss.expand(batch_size, self.num_channels), + self.signal_loss.expand(batch_size, self.n_channels), ), dim=1, ) @@ -269,15 +273,15 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: pump_freqs = self._lambda2frequency(pump_wavelengths) total_freqs = torch.cat( - (pump_freqs, self.signal_frequency.expand(batch_size, self.num_channels)), + (pump_freqs, self.signal_frequency.expand(batch_size, self.n_channels)), dim=1, ) # Concatenate input pump power and signal power, making sure power > 0 total_power = torch.cat( ( - x[:, self.num_pumps:], - self.signal_power.expand(batch_size, self.num_channels), + x[:, self.n_pumps:], + self.signal_power.expand(batch_size, self.n_channels), ), 1, ) @@ -323,13 +327,13 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: signal_spectrum = solution[ :, - self.num_pumps:, + self.n_pumps:, ].clone() return signal_spectrum # if self.counterpumping: - # pump_initial_power = solution[:, : self.num_pumps].clone() + # pump_initial_power = solution[:, : self.n_pumps].clone() # return signal_spectrum, pump_initial_power # else: # return signal_spectrum @@ -340,7 +344,7 @@ def __init__( self, length, steps, - num_pumps, + n_pumps, signal_wavelength, power_per_channel, fiber: MMFiber, @@ -356,7 +360,7 @@ def __init__( The length of the fiber [m]. steps : int The number of integration steps. - num_pumps : int + n_pumps : int The number of Raman pumps. signal_wavelength : torch.Tensor The input signal wavelenghts. @@ -373,13 +377,13 @@ def __init__( super(MMFRamanAmplifier, self).__init__() self.c0 = speed_of_light self.power_per_channel = power_per_channel - self.num_pumps = num_pumps - self.num_channels = signal_wavelength.shape[0] - self.modes = fiber.n_modes + self.n_pumps = n_pumps + self.n_channels = signal_wavelength.shape[0] + self.n_modes = fiber.n_modes self.length = length self.steps = steps self.fiber = fiber - self.overlap_integrals = fiber.overlap_integrals[:, :self.modes, :self.modes] + self.overlap_integrals = fiber.overlap_integrals[:, :self.n_modes, :self.n_modes] overlap_integrals_tensor = torch.Tensor(fiber.overlap_integrals).float() z = torch.linspace(0, self.length, self.steps) @@ -388,7 +392,7 @@ def __init__( signal_wavelength = torch.from_numpy(signal_wavelength).float() signal_power = self.power_per_channel * torch.ones( - (1, self.num_channels * self.modes) + (1, self.n_channels * self.n_modes) ) # limit the polynomial fit of the attenuation spectrum to order 2 @@ -404,7 +408,7 @@ def __init__( if isinstance(signal_loss, np.ndarray): signal_loss = torch.from_numpy(signal_loss) - signal_loss = signal_loss.repeat_interleave(self.modes).view(1, -1) + signal_loss = signal_loss.repeat_interleave(self.n_modes).view(1, -1) self.raman_coefficient = fiber.raman_coefficient @@ -446,17 +450,17 @@ def __init__( self.register_buffer("raman_response", raman_response) # Doesn't matter, the pumps are turned off - pump_lambda = torch.linspace(1420, 1480, self.num_pumps) * 1e-9 - pump_power = torch.zeros((num_pumps * self.modes)) + pump_lambda = torch.linspace(1420, 1480, self.n_pumps) * 1e-9 + pump_power = torch.zeros((n_pumps * self.n_modes)) x = torch.cat((pump_lambda, pump_power)).float().view(1, -1) direction = torch.ones( - ((self.num_pumps + self.num_channels) * self.modes,) + ((self.n_pumps + self.n_channels) * self.n_modes,) ).float() if counterpumping: self.counterpumping = True - direction[: self.num_pumps * self.modes] = -1 + direction[: self.n_pumps * self.n_modes] = -1 else: self.counterpumping = False @@ -549,25 +553,24 @@ def forward(self, x): signal_spectrum: torch.Tensor signal powers on each mode (B, N_signals, N_modes) """ - batch_size = x.shape[0] - num_freqs = self.num_channels + self.num_pumps + num_freqs = self.n_channels + self.n_pumps interpolation_grid = torch.zeros( (batch_size, 1, num_freqs ** 2, 2), dtype=x.dtype, device=x.device, ) - pump_wavelengths = x[:, : self.num_pumps] + pump_wavelengths = x[:, : self.n_pumps] pump_loss = self._alpha_to_linear( self.loss_coefficients[2] + self.loss_coefficients[1] * pump_wavelengths + self.loss_coefficients[0] * (pump_wavelengths) ** 2 - ).repeat_interleave(self.modes, dim=1) + ).repeat_interleave(self.n_modes, dim=1) losses = torch.cat( ( pump_loss, - self.signal_loss.expand(batch_size, self.num_channels * self.modes), + self.signal_loss.expand(batch_size, self.n_channels * self.n_modes), ), dim=1, ) @@ -577,7 +580,7 @@ def forward(self, x): pump_freqs = self._lambda2frequency(pump_wavelengths) total_freqs = torch.cat( - (pump_freqs, self.signal_frequency.expand(batch_size, self.num_channels)), + (pump_freqs, self.signal_frequency.expand(batch_size, self.n_channels)), dim=1, ) total_wavelenghts = self._frequency2lambda(total_freqs) @@ -585,8 +588,8 @@ def forward(self, x): # I don't want to allow for negative values of the pump power in the optimizer total_power = torch.cat( ( - x[:, self.num_pumps:], - self.signal_power.expand(batch_size, self.num_channels * self.modes), + x[:, self.n_pumps:], + self.signal_power.expand(batch_size, self.n_channels * self.n_modes), ), 1, ) @@ -634,18 +637,18 @@ def forward(self, x): +---------+---------+ mantain the same topology? - gain = gain.repeat_interleave(self.modes, dim=1).repeat_interleave( - self.modes, dim=2 + gain = gain.repeat_interleave(self.n_modes, dim=1).repeat_interleave( + self.n_modes, dim=2 ) beware, dim = 0 is the batch dimension """ - gain = gain.repeat_interleave(self.modes, dim=1).repeat_interleave( - self.modes, dim=2 + gain = gain.repeat_interleave(self.n_modes, dim=1).repeat_interleave( + self.n_modes, dim=2 ).float() - # print("self.modes", self.modes) + # print("self.n_modes", self.n_modes) - # oi = torch.from_numpy(self.fiber.get_oi_matrix_torch(range(self.modes), 3e8 / total_freqs)) + # oi = torch.from_numpy(self.fiber.get_oi_matrix_torch(range(self.n_modes), 3e8 / total_freqs)) oi = self.fiber.torch_oi.evaluate_oi_tensor(total_wavelenghts) # oi_avg = torch.mean(oi) # print(f"OI : {oi_avg.shape}") @@ -655,16 +658,39 @@ def forward(self, x): # oi = torch.from_numpy(self.overlap_integrals_avg[None, :, :].repeat(num_freqs, axis=1).repeat(num_freqs, axis=2)).float() G = gain * oi # G = gain - solution = torch_rk4( + raw_solution = torch_rk4( MMFRamanAmplifier.ode, total_power, self.z, losses, G, self.direction, - ).view(-1, num_freqs, self.modes) - signal_spectrum = solution[:, self.num_pumps:, :].clone() + ) + solution=raw_solution.view(-1, num_freqs, self.n_modes) + + # print("-----plotting signal solutions") + # plt.clf() + # plt.figure(figsize=(4, 3)) + # cmap = viridis + # signals = watt2dBm(raw_solution.detach().numpy()) + # print(signals) + # z_plot = np.linspace(0, self.fiber.length, len(signals[:, 0, 0])) * 1e-3 + # # lss = ["-", "--", "-.", ":", "-"] + # mode_labels = ["LP01", "LP11", "LP21", "LP02"] + # for i in range(self.n_modes): + # plt.plot(z_plot, + # signals[0, :, :, i], color=cmap(i / self.n_modes + 0.2), alpha=0.3) + # plt.ylabel(r"$P$ [dBm]") + # plt.xlabel(r"$z$ [km]") + # # plt.legend() + # plt.tight_layout() + # plt.grid(False) + # plt.savefig("media/optimization/signal_INNER_profile.png") + # plt.clf() + # print("-----Done.") + + signal_spectrum = solution[:, self.n_pumps:, :].clone() # print("*"*30) # print(solution) # print("*"*30) # if self.counterpumping: - # pump_initial_power = solution[:, : self.num_pumps, :].clone() + # pump_initial_power = solution[:, : self.n_pumps, :].clone() # return signal_spectrum, pump_initial_power # else: # print("signal_spectrum", signal_spectrum) diff --git a/scripts/modules/cfg.py b/scripts/modules/cfg.py index 5904705..6f6f9c2 100644 --- a/scripts/modules/cfg.py +++ b/scripts/modules/cfg.py @@ -22,7 +22,7 @@ class Config(BaseModel): def load_toml_to_struct(filepath: str) -> Config: # with open(filepath, "rb") as f: data = toml.load(filepath) - print(data) + # print(data) return Config(**data) # Serialize a Pydantic model into a TOML file diff --git a/scripts/modules/plot_optimization.py b/scripts/modules/plot_optimization.py index e046921..7ec5d72 100644 --- a/scripts/modules/plot_optimization.py +++ b/scripts/modules/plot_optimization.py @@ -81,6 +81,4 @@ def analyze_optimization( print(f"{'ASE':<30} | {avg_ase:.5e} dB") print(f"{'Average pump power at z=0':<30} | {avg_pump_power_0:.5e} dBm") print(f"{'Average pump power at z=L':<30} | {avg_pump_power_L:.5e} dBm") - return - - \ No newline at end of file + return \ No newline at end of file diff --git a/scripts/optimize.py b/scripts/optimize.py index 92cfa01..a63b5ee 100644 --- a/scripts/optimize.py +++ b/scripts/optimize.py @@ -175,7 +175,7 @@ def ct_solver(power_per_pump, if __name__ == "__main__": - recompute = False # Set to True to force re-computation + recompute = True # Set to True to force re-computation signal_powers = [-10, -5, 0] for ix in range(3): @@ -190,8 +190,8 @@ def ct_solver(power_per_pump, power_per_pump = 6, pump_band_a = 1410e-9, pump_band_b = 1520e-9, - learning_rate = 1e-2, - epochs = 1500, + learning_rate = 1e-3, + epochs = 4, lock_wavelengths = 200, batch_size = 1, use_precomputed = False,