add: fault detection in torch RK4
lorenzifrancesco committed Jan 21, 2025
1 parent 7b9c2f7 commit 2d118f9
Showing 5 changed files with 82 additions and 58 deletions.
2 changes: 1 addition & 1 deletion pynlin/raman/pytorch/_torch_ode.py
@@ -30,7 +30,7 @@ def torch_rk4(func, y0, t, *args, **kwargs):
    new_state = y + 1 / 6 * (k1 + 2 * k2 + 2 * k3 + k4)

    if torch.isnan(new_state).any() or torch.isinf(new_state).any():
        print("\033[41mWARN:\033[0m failure detected in RK4 propagation")
        return y

    y = new_state
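For context, the hunk above adds a guard that bails out of the integration as soon as a step produces NaN or Inf values and returns the last finite state instead of propagating the fault. Below is a minimal, self-contained sketch of the same pattern; the function name rk4_with_guard and the func(t, y, *args) call convention are illustrative assumptions and do not reproduce the repository's torch_rk4 signature.

import torch

def rk4_with_guard(func, y0, t, *args):
    # Fixed-step RK4 over the grid t, assuming func(t, y, *args) -> dy/dt.
    y = y0
    for i in range(len(t) - 1):
        h = t[i + 1] - t[i]
        k1 = h * func(t[i], y, *args)
        k2 = h * func(t[i] + h / 2, y + k1 / 2, *args)
        k3 = h * func(t[i] + h / 2, y + k2 / 2, *args)
        k4 = h * func(t[i] + h, y + k3, *args)
        new_state = y + 1 / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
        if torch.isnan(new_state).any() or torch.isinf(new_state).any():
            print("\033[41mWARN:\033[0m failure detected in RK4 propagation")
            return y  # hand back the last finite state
        y = new_state
    return y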
126 changes: 76 additions & 50 deletions pynlin/raman/pytorch/solvers.py
@@ -11,18 +11,22 @@
from pynlin.fiber import Fiber, MMFiber
from pynlin.raman.pytorch._torch_ode import torch_rk4
from pynlin.raman.response import impulse_response
from pynlin.utils import nu2lambda
from pynlin.utils import nu2lambda, watt2dBm

import matplotlib.pyplot as plt
import seaborn as sns

# tmp
import matplotlib.pyplot as plt
from matplotlib.cm import viridis


class RamanAmplifier(torch.nn.Module):
def __init__(
self,
fiber_length: float,
integration_steps: int,
num_pumps: int,
n_pumps: int,
signal_wavelengths: Union[list, NDArray],
power_per_channel: float,
fiber: Fiber,
@@ -38,7 +42,7 @@ def __init__(
The length of the fiber [m].
steps : int
The number of integration steps.
num_pumps : int
n_pumps : int
The number of Raman pumps.
signal_wavelength : torch.Tensor
The input signal wavelengths.
@@ -55,8 +59,8 @@ def __init__(
super(RamanAmplifier, self).__init__()
self.c0 = speed_of_light
self.power_per_channel = power_per_channel
self.num_pumps = num_pumps
self.num_channels = signal_wavelengths.shape[0]
self.n_pumps = n_pumps
self.n_channels = signal_wavelengths.shape[0]
self.signal_wavelengths = torch.from_numpy(signal_wavelengths)
self.length = fiber_length
self.steps = integration_steps
@@ -65,7 +69,7 @@ def __init__(
if isinstance(signal_wavelengths, np.ndarray):
signal_wavelengths = torch.from_numpy(signal_wavelengths).float()

signal_power = self.power_per_channel * torch.ones((1, self.num_channels))
signal_power = self.power_per_channel * torch.ones((1, self.n_channels))

# limit the polynomial fit of the attenuation spectrum to order 2
num_loss_coeffs = len(fiber.losses)
@@ -84,7 +88,7 @@ def __init__(
if isinstance(signal_loss, np.ndarray):
signal_loss = torch.from_numpy(signal_loss)

# signal_loss = signal_loss.repeat_interleave(self.modes).view(1, -1)
# signal_loss = signal_loss.repeat_interleave(self.n_modes).view(1, -1)

self.raman_coefficient = fiber.raman_coefficient

@@ -126,21 +130,21 @@ def __init__(
self.register_buffer("raman_response", raman_response)

# # Doesn't matter, the pumps are turned off
# pump_lambda = torch.linspace(1420, 1480, self.num_pumps) * 1e-9
# pump_power = torch.zeros((num_pumps * modes))
# pump_lambda = torch.linspace(1420, 1480, self.n_pumps) * 1e-9
# pump_power = torch.zeros((n_pumps * modes))
# x = torch.cat((pump_lambda, pump_power)).float().view(1, -1)

if np.isscalar(pump_direction):
pump_direction = np.ones((self.num_pumps,)) * pump_direction
pump_direction = np.ones((self.n_pumps,)) * pump_direction
else:
pump_direction = np.atleast_1d(pump_direction)
signal_direction = np.ones((self.num_channels,))
signal_direction = np.ones((self.n_channels,))
direction = np.concatenate((pump_direction, signal_direction))
self.direction = torch.from_numpy(direction).float()

# if counterpumping:
# self.counterpumping = True
# direction[: self.num_pumps * self.modes] = -1
# direction[: self.n_pumps * self.n_modes] = -1
# else:
# self.counterpumping = False

@@ -237,7 +241,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:

batch_size = x.shape[0]

num_freqs = self.num_channels + self.num_pumps
num_freqs = self.n_channels + self.n_pumps

# This will be the input to the interpolation function
interpolation_grid = torch.zeros(
@@ -246,7 +250,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
device=x.device,
)

pump_wavelengths = x[:, : self.num_pumps]
pump_wavelengths = x[:, : self.n_pumps]

# Compute the loss for each pump wavelength/mode
pump_loss = self._alpha_to_linear(
@@ -259,7 +263,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
losses = torch.cat(
(
pump_loss,
self.signal_loss.expand(batch_size, self.num_channels),
self.signal_loss.expand(batch_size, self.n_channels),
),
dim=1,
)
@@ -269,15 +273,15 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
pump_freqs = self._lambda2frequency(pump_wavelengths)

total_freqs = torch.cat(
(pump_freqs, self.signal_frequency.expand(batch_size, self.num_channels)),
(pump_freqs, self.signal_frequency.expand(batch_size, self.n_channels)),
dim=1,
)

# Concatenate input pump power and signal power, making sure power > 0
total_power = torch.cat(
(
x[:, self.num_pumps:],
self.signal_power.expand(batch_size, self.num_channels),
x[:, self.n_pumps:],
self.signal_power.expand(batch_size, self.n_channels),
),
1,
)
@@ -323,13 +327,13 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:

signal_spectrum = solution[
:,
self.num_pumps:,
self.n_pumps:,
].clone()

return signal_spectrum

# if self.counterpumping:
# pump_initial_power = solution[:, : self.num_pumps].clone()
# pump_initial_power = solution[:, : self.n_pumps].clone()
# return signal_spectrum, pump_initial_power
# else:
# return signal_spectrum
@@ -340,7 +344,7 @@ def __init__(
self,
length,
steps,
num_pumps,
n_pumps,
signal_wavelength,
power_per_channel,
fiber: MMFiber,
@@ -356,7 +360,7 @@ def __init__(
The length of the fiber [m].
steps : int
The number of integration steps.
num_pumps : int
n_pumps : int
The number of Raman pumps.
signal_wavelength : torch.Tensor
The input signal wavelengths.
@@ -373,13 +377,13 @@ def __init__(
super(MMFRamanAmplifier, self).__init__()
self.c0 = speed_of_light
self.power_per_channel = power_per_channel
self.num_pumps = num_pumps
self.num_channels = signal_wavelength.shape[0]
self.modes = fiber.n_modes
self.n_pumps = n_pumps
self.n_channels = signal_wavelength.shape[0]
self.n_modes = fiber.n_modes
self.length = length
self.steps = steps
self.fiber = fiber
self.overlap_integrals = fiber.overlap_integrals[:, :self.modes, :self.modes]
self.overlap_integrals = fiber.overlap_integrals[:, :self.n_modes, :self.n_modes]
overlap_integrals_tensor = torch.Tensor(fiber.overlap_integrals).float()

z = torch.linspace(0, self.length, self.steps)
@@ -388,7 +392,7 @@ def __init__(
signal_wavelength = torch.from_numpy(signal_wavelength).float()

signal_power = self.power_per_channel * torch.ones(
(1, self.num_channels * self.modes)
(1, self.n_channels * self.n_modes)
)

# limit the polynomial fit of the attenuation spectrum to order 2
@@ -404,7 +408,7 @@ def __init__(
if isinstance(signal_loss, np.ndarray):
signal_loss = torch.from_numpy(signal_loss)

signal_loss = signal_loss.repeat_interleave(self.modes).view(1, -1)
signal_loss = signal_loss.repeat_interleave(self.n_modes).view(1, -1)

self.raman_coefficient = fiber.raman_coefficient

@@ -446,17 +450,17 @@ def __init__(
self.register_buffer("raman_response", raman_response)

# Doesn't matter, the pumps are turned off
pump_lambda = torch.linspace(1420, 1480, self.num_pumps) * 1e-9
pump_power = torch.zeros((num_pumps * self.modes))
pump_lambda = torch.linspace(1420, 1480, self.n_pumps) * 1e-9
pump_power = torch.zeros((n_pumps * self.n_modes))
x = torch.cat((pump_lambda, pump_power)).float().view(1, -1)

direction = torch.ones(
((self.num_pumps + self.num_channels) * self.modes,)
((self.n_pumps + self.n_channels) * self.n_modes,)
).float()

if counterpumping:
self.counterpumping = True
direction[: self.num_pumps * self.modes] = -1
direction[: self.n_pumps * self.n_modes] = -1
else:
self.counterpumping = False

@@ -549,25 +553,24 @@ def forward(self, x):
signal_spectrum: torch.Tensor
signal powers on each mode (B, N_signals, N_modes)
"""

batch_size = x.shape[0]
num_freqs = self.num_channels + self.num_pumps
num_freqs = self.n_channels + self.n_pumps

interpolation_grid = torch.zeros(
(batch_size, 1, num_freqs ** 2, 2), dtype=x.dtype, device=x.device,
)
pump_wavelengths = x[:, : self.num_pumps]
pump_wavelengths = x[:, : self.n_pumps]

pump_loss = self._alpha_to_linear(
self.loss_coefficients[2]
+ self.loss_coefficients[1] * pump_wavelengths
+ self.loss_coefficients[0] * (pump_wavelengths) ** 2
).repeat_interleave(self.modes, dim=1)
).repeat_interleave(self.n_modes, dim=1)

losses = torch.cat(
(
pump_loss,
self.signal_loss.expand(batch_size, self.num_channels * self.modes),
self.signal_loss.expand(batch_size, self.n_channels * self.n_modes),
),
dim=1,
)
@@ -577,16 +580,16 @@ def forward(self, x):
pump_freqs = self._lambda2frequency(pump_wavelengths)

total_freqs = torch.cat(
(pump_freqs, self.signal_frequency.expand(batch_size, self.num_channels)),
(pump_freqs, self.signal_frequency.expand(batch_size, self.n_channels)),
dim=1,
)
total_wavelenghts = self._frequency2lambda(total_freqs)

# I don't want to allow for negative values of the pump power in the optimizer
total_power = torch.cat(
(
x[:, self.num_pumps:],
self.signal_power.expand(batch_size, self.num_channels * self.modes),
x[:, self.n_pumps:],
self.signal_power.expand(batch_size, self.n_channels * self.n_modes),
),
1,
)
@@ -634,18 +637,18 @@ def forward(self, x):
+---------+---------+
maintain the same topology?
gain = gain.repeat_interleave(self.modes, dim=1).repeat_interleave(
self.modes, dim=2
gain = gain.repeat_interleave(self.n_modes, dim=1).repeat_interleave(
self.n_modes, dim=2
)
beware, dim = 0 is the batch dimension
"""

gain = gain.repeat_interleave(self.modes, dim=1).repeat_interleave(
self.modes, dim=2
gain = gain.repeat_interleave(self.n_modes, dim=1).repeat_interleave(
self.n_modes, dim=2
).float()
# print("self.modes", self.modes)
# print("self.n_modes", self.n_modes)

# oi = torch.from_numpy(self.fiber.get_oi_matrix_torch(range(self.modes), 3e8 / total_freqs))
# oi = torch.from_numpy(self.fiber.get_oi_matrix_torch(range(self.n_modes), 3e8 / total_freqs))
oi = self.fiber.torch_oi.evaluate_oi_tensor(total_wavelenghts)
# oi_avg = torch.mean(oi)
# print(f"OI : {oi_avg.shape}")
Expand All @@ -655,16 +658,39 @@ def forward(self, x):
# oi = torch.from_numpy(self.overlap_integrals_avg[None, :, :].repeat(num_freqs, axis=1).repeat(num_freqs, axis=2)).float()
G = gain * oi
# G = gain
solution = torch_rk4(
raw_solution = torch_rk4(
MMFRamanAmplifier.ode, total_power, self.z, losses, G, self.direction,
).view(-1, num_freqs, self.modes)
signal_spectrum = solution[:, self.num_pumps:, :].clone()
)
solution=raw_solution.view(-1, num_freqs, self.n_modes)

# print("-----plotting signal solutions")
# plt.clf()
# plt.figure(figsize=(4, 3))
# cmap = viridis
# signals = watt2dBm(raw_solution.detach().numpy())
# print(signals)
# z_plot = np.linspace(0, self.fiber.length, len(signals[:, 0, 0])) * 1e-3
# # lss = ["-", "--", "-.", ":", "-"]
# mode_labels = ["LP01", "LP11", "LP21", "LP02"]
# for i in range(self.n_modes):
# plt.plot(z_plot,
# signals[0, :, :, i], color=cmap(i / self.n_modes + 0.2), alpha=0.3)
# plt.ylabel(r"$P$ [dBm]")
# plt.xlabel(r"$z$ [km]")
# # plt.legend()
# plt.tight_layout()
# plt.grid(False)
# plt.savefig("media/optimization/signal_INNER_profile.png")
# plt.clf()
# print("-----Done.")

signal_spectrum = solution[:, self.n_pumps:, :].clone()
# print("*"*30)
# print(solution)
# print("*"*30)

# if self.counterpumping:
# pump_initial_power = solution[:, : self.num_pumps, :].clone()
# pump_initial_power = solution[:, : self.n_pumps, :].clone()
# return signal_spectrum, pump_initial_power
# else:
# print("signal_spectrum", signal_spectrum)
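A note on the gain expansion in MMFRamanAmplifier.forward above: repeat_interleave over dim=1 and dim=2 tiles the per-frequency gain matrix so that every mode pair of a given frequency pair shares the same Raman gain before it is weighted element-wise by the overlap-integral tensor. A toy shape check (sizes are illustrative only, not the repository's dimensions):

import torch

# A (B, F, F) frequency-gain matrix becomes (B, F*M, F*M):
# each scalar gain[b, i, j] fills an M x M block of the expanded matrix.
B, F, M = 1, 3, 2
gain = torch.arange(B * F * F, dtype=torch.float32).view(B, F, F)
expanded = gain.repeat_interleave(M, dim=1).repeat_interleave(M, dim=2)
assert expanded.shape == (B, F * M, F * M)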
2 changes: 1 addition & 1 deletion scripts/modules/cfg.py
@@ -22,7 +22,7 @@ class Config(BaseModel):
def load_toml_to_struct(filepath: str) -> Config:
# with open(filepath, "rb") as f:
data = toml.load(filepath)
print(data)
# print(data)
return Config(**data)

# Serialize a Pydantic model into a TOML file
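For reference, a minimal sketch of the TOML-to-Pydantic load path touched above; the single field shown is a placeholder for illustration, not the repository's actual Config schema in scripts/modules/cfg.py.

import toml
from pydantic import BaseModel

class Config(BaseModel):
    n_modes: int = 4  # placeholder field, for illustration only

def load_toml_to_struct(filepath: str) -> Config:
    data = toml.load(filepath)  # parse the TOML file into a dict
    return Config(**data)       # validate and coerce it into the model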
4 changes: 1 addition & 3 deletions scripts/modules/plot_optimization.py
@@ -81,6 +81,4 @@ def analyze_optimization(
print(f"{'ASE':<30} | {avg_ase:.5e} dB")
print(f"{'Average pump power at z=0':<30} | {avg_pump_power_0:.5e} dBm")
print(f"{'Average pump power at z=L':<30} | {avg_pump_power_L:.5e} dBm")
return


return
6 changes: 3 additions & 3 deletions scripts/optimize.py
@@ -175,7 +175,7 @@ def ct_solver(power_per_pump,


if __name__ == "__main__":
recompute = False # Set to True to force re-computation
recompute = True # Set to True to force re-computation
signal_powers = [-10, -5, 0]

for ix in range(3):
@@ -190,8 +190,8 @@ def ct_solver(power_per_pump,
power_per_pump = 6,
pump_band_a = 1410e-9,
pump_band_b = 1520e-9,
learning_rate = 1e-2,
epochs = 1500,
learning_rate = 1e-3,
epochs = 4,
lock_wavelengths = 200,
batch_size = 1,
use_precomputed = False,
