diff --git a/devito/ir/clusters/cluster.py b/devito/ir/clusters/cluster.py index 7f3af30ee0f..ded648dd536 100644 --- a/devito/ir/clusters/cluster.py +++ b/devito/ir/clusters/cluster.py @@ -327,24 +327,16 @@ def dspace(self): # Construct the `intervals` of the DataSpace, that is a global, # Dimension-centric view of the data space intervals = IntervalGroup.generate('union', *parts.values()) + # E.g., `db0 -> time`, but `xi NOT-> x` intervals = intervals.promote(lambda d: not d.is_Sub) intervals = intervals.zero(set(intervals.dimensions) - oobs) - # Intersect with intervals from buffered dimensions. Unions of - # buffered dimension intervals may result in shrinking time size - try: - proc = [] - for f, v in parts.items(): - if f.save: - for i in v: - if i.dim.is_Time: - proc.append(intervals[i.dim].intersection(i)) - else: - proc.append(intervals[i.dim]) - intervals = IntervalGroup(proc) - except AttributeError: - pass + # Buffered TimeDimensions should not shirnk their upper time offset + for f, v in parts.items(): + if f.is_TimeFunction: + if f.save and not f.time_dim.is_Conditional: + intervals = intervals.ceil(v[f.time_dim]) return DataSpace(intervals, parts) diff --git a/devito/ir/support/space.py b/devito/ir/support/space.py index 7abc68ce081..00c454590ea 100644 --- a/devito/ir/support/space.py +++ b/devito/ir/support/space.py @@ -259,6 +259,11 @@ def negate(self): def zero(self): return Interval(self.dim, 0, 0, self.stamp) + def ceil(self, o): + if o.is_Null: + return self._rebuild() + return Interval(self.dim, self.lower, o.upper, self.stamp) + def flip(self): return Interval(self.dim, self.upper, self.lower, self.stamp) @@ -492,6 +497,11 @@ def zero(self, d=None): return IntervalGroup(intervals, relations=self.relations, mode=self.mode) + def ceil(self, o=None): + d = self.dimensions if o is None else as_tuple(o.dim) + return IntervalGroup([i.ceil(o) if i.dim in d else i for i in self], + relations=self.relations) + def lift(self, d=None, v=None): d = set(self.dimensions if d is None else as_tuple(d)) intervals = [i.lift(v) if i.dim._defines & d else i for i in self] diff --git a/tests/test_buffering.py b/tests/test_buffering.py index 9a12c800288..d3c86670252 100644 --- a/tests/test_buffering.py +++ b/tests/test_buffering.py @@ -752,24 +752,3 @@ def test_stencil_issue_1915_v2(subdomain): op1.apply(time_M=nt-2, u=u1) assert np.all(u.data == u1.data) - - -def test_default_timeM(): - """ - MFE for issue #2235 - """ - grid = Grid(shape=(4, 4)) - - u = TimeFunction(name='u', grid=grid) - usave = TimeFunction(name='usave', grid=grid, save=5) - - eqns = [Eq(u.forward, u + 1), - Eq(usave, u)] - - op = Operator(eqns) - - assert op.arguments()['time_M'] == 4 - - op.apply() - - assert all(np.all(usave.data[i] == i) for i in range(4)) diff --git a/tests/test_checkpointing.py b/tests/test_checkpointing.py index 75cca861cc3..0217f46d526 100644 --- a/tests/test_checkpointing.py +++ b/tests/test_checkpointing.py @@ -10,7 +10,7 @@ @switchconfig(log_level='WARNING') -def test_segmented_incremment(): +def test_segmented_increment(): """ Test for segmented operator execution of a one-sided first order function (increment). The corresponding set of stencil offsets in diff --git a/tests/test_dimension.py b/tests/test_dimension.py index e4446900b37..65c3fd9fad4 100644 --- a/tests/test_dimension.py +++ b/tests/test_dimension.py @@ -210,6 +210,25 @@ def test_modulo_dims_generation_v2(self): assert np.all(f.data[3] == 2) assert np.all(f.data[4] == 4) + def test_default_timeM(self): + """ + MFE for issue #2235 + """ + grid = Grid(shape=(4, 4)) + + u = TimeFunction(name='u', grid=grid) + usave = TimeFunction(name='usave', grid=grid, save=5) + + eqns = [Eq(u.forward, u + 1), + Eq(usave, u)] + + op = Operator(eqns) + + assert op.arguments()['time_M'] == 4 + op.apply() + + assert all(np.all(usave.data[i] == i) for i in range(4)) + class TestSubDimension(object): @@ -760,7 +779,7 @@ def test_basic(self): eqns = [Eq(u.forward, u + 1.), Eq(u2.forward, u2 + 1.), Eq(usave, u)] op = Operator(eqns) - op.apply() + op.apply(time_M=nt-2) assert np.all(np.allclose(u.data[(nt-1) % 3], nt-1)) assert np.all([np.allclose(u2.data[i], i) for i in range(nt)]) assert np.all([np.allclose(usave.data[i], i*factor) diff --git a/tests/test_operator.py b/tests/test_operator.py index 502ab1067df..9186058076a 100644 --- a/tests/test_operator.py +++ b/tests/test_operator.py @@ -2010,11 +2010,38 @@ def test_indirection(self): op = Operator(eqns) - assert op._dspace[time].lower == 1 + assert op._dspace[time].lower == 0 assert op._dspace[time].upper == 1 assert op.arguments()['time_M'] == nt - 2 - op() + op.apply() assert np.all(f.data[0] == 0.) assert np.all(f.data[i] == 3. for i in range(1, 10)) + + def test_indirection_v2(self): + nt = 10 + grid = Grid(shape=(4, 4)) + time = grid.time_dim + x, y = grid.dimensions + + f = TimeFunction(name='f', grid=grid, save=nt) + g = TimeFunction(name='g', grid=grid) + + idx = time + s = Indirection(name='ofs0', mapped=idx) + + eqns = [ + Eq(s, idx), + Eq(f[s, x, y], g + 3.) + ] + + op = Operator(eqns) + + assert op._dspace[time].lower == 0 + assert op._dspace[time].upper == 0 + assert op.arguments()['time_M'] == nt - 1 + + op.apply() + + assert np.all(f.data[i] == 3. for i in range(1, 10))