Skip to content

Commit

Permalink
Update utils.py
Browse files Browse the repository at this point in the history
Update exact_solution_data function in utils.py file
  • Loading branch information
dimerf99 committed Dec 17, 2024
1 parent 21e903a commit 0925ebe
Showing 1 changed file with 15 additions and 9 deletions.
24 changes: 15 additions & 9 deletions tedeous/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,27 +290,33 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
return x


def exact_solution_data(grid, datapath, n_dim_in, n_dim_out):
def exact_solution_data(grid, datapath, pde_dim_in, pde_dim_out, t_dim_flag=False):
device_origin = grid.device
grid = grid.to('cpu').detach()

test_data = np.loadtxt(datapath, comments="%", encoding='utf-8').astype(np.float32)
test_data = torch.from_numpy(test_data)
grid_data = torch.stack([coord for coord in test_data[:, :pde_dim_in - t_dim_flag]])

exact_func = test_data[:, n_dim_in:]
grid_data = torch.stack([coord for coord in test_data[:, :n_dim_in]])
exact_func = test_data[:, pde_dim_in - t_dim_flag:]

if t_dim_flag:
N_t = exact_func.shape[1]
exact_func = exact_func.reshape(-1, pde_dim_out)
t = torch.linspace(min(grid[:, pde_dim_in - 1]), max(grid[:, pde_dim_in - 1]), N_t).to('cpu').detach()
grid_data = torch.vstack([torch.cartesian_prod(coord, t) for coord in grid_data])

grid_data = grid_data.cpu().numpy()
exact_func = exact_func.cpu().numpy()
grid = grid.cpu().numpy()

if n_dim_out == 1:
exact = scipy.interpolate.griddata(grid_data, exact_func, grid, method='nearest').reshape(-1)
if pde_dim_out == 1:
exact_func = scipy.interpolate.griddata(grid_data, exact_func, grid, method='nearest').reshape(-1)
else:
exact = np.array(
exact_func = np.array(
[scipy.interpolate.griddata(grid_data, exact_func[:, i_dim], grid, method='nearest').reshape(-1)
for i_dim in range(n_dim_out)]
for i_dim in range(pde_dim_out)]
)

exact = torch.from_numpy(exact).to(device_origin)
return exact
exact_func = torch.from_numpy(exact_func).to(device_origin)
return exact_func

0 comments on commit 0925ebe

Please sign in to comment.