Skip to content

Commit

Permalink
Merge pull request #35 from dfki-ric/bugfix/grid-sampler-bug
Browse files Browse the repository at this point in the history
Grid Sampler Bugfix
  • Loading branch information
mlaux1 authored May 23, 2024
2 parents 8aa2c5c + 9c695f1 commit d16ff4f
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 5 deletions.
2 changes: 1 addition & 1 deletion deformable_gym/envs/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def __init__(
points_per_axis = [np.linspace(
low[i], high[i], n_points_per_axis[i]) for i in range(self.n_dims)]

self.grid = np.array(np.meshgrid(*points_per_axis)).T.reshape(-1, 3)
self.grid = np.array(np.meshgrid(*points_per_axis)).T.reshape(-1, self.n_dims)
self.n_samples = len(self.grid)
self.n_calls = 0

Expand Down
8 changes: 4 additions & 4 deletions tests/envs/test_samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def uniform_target_pose() -> npt.NDArray:

@pytest.fixture
def grid_target_pose() -> npt.NDArray:
target = np.array([1, 2, 3])
target = np.array([1, 2, 3, 4])
return target


Expand Down Expand Up @@ -60,9 +60,9 @@ def uniform_sampler() -> UniformSampler:
@pytest.fixture
def grid_sampler() -> GridSampler:
return GridSampler(
low=np.array([1, 2, 3]),
high=np.array([2, 3, 4]),
n_points_per_axis=np.array([5, 3, 1])
low=np.array([1, 2, 3, 4]),
high=np.array([2, 3, 4, 5]),
n_points_per_axis=np.array([5, 3, 1, 1])
)


Expand Down

0 comments on commit d16ff4f

Please sign in to comment.