diff --git a/loki/transformations/inline/procedures.py b/loki/transformations/inline/procedures.py index 01de41929..522536e13 100644 --- a/loki/transformations/inline/procedures.py +++ b/loki/transformations/inline/procedures.py @@ -95,12 +95,24 @@ def _offset_lbound(lbound, v): indices = [index for index, dim in enumerate(val.dimensions) if isinstance(dim, sym.Range)] - for index, dim in enumerate(var.dimensions): + lbounds_diff = [sym.IntLiteral(0) for _ in var.shape] + if var.shape and val.shape: + decl_lbounds = [(getattr(val.shape[i], 'lower', sym.IntLiteral(1)), + getattr(dim, 'lower', sym.IntLiteral(1))) for i, dim in enumerate(var.shape)] + var_ubounds = [getattr(dim, 'upper', dim) for dim in var.shape] + + for i, (lb_val, lb_var) in enumerate(decl_lbounds): + # we can't simply check if lb_val here as that would return a false negative if lb_val == 0 + if lb_val is not None and lb_var is not None: + lbounds_diff[i] = simplify(sym.Sum((lb_val, sym.Product((lb_var, sym.IntLiteral(-1)))))) + + for (index, dim), lbdiff in zip(enumerate(var.dimensions), lbounds_diff): # if the argument contains an array range, we must map the bounds accordingly if isinstance(val.dimensions[index], sym.Range) and (lower := val.dimensions[index].lower): + lower = simplify(sym.Sum((lower, lbdiff))) if isinstance(dim, sym.Range): - _lower = dim.lower or sym.IntLiteral(1) - _upper = dim.upper or val.dimensions[index].upper + _lower = dim.lower or decl_lbounds[index][1] + _upper = dim.upper or var_ubounds[index] _lower = _offset_lbound(lower, _lower) _upper = _offset_lbound(lower, _upper) @@ -109,7 +121,7 @@ def _offset_lbound(lbound, v): else: new_dimensions[indices[index]] = _offset_lbound(lower, dim) else: - new_dimensions[indices[index]] = dim + new_dimensions[indices[index]] = simplify(sym.Sum((dim, lbdiff))) return val.clone(dimensions=tuple(new_dimensions)) diff --git a/loki/transformations/inline/tests/test_inline_transformation.py b/loki/transformations/inline/tests/test_inline_transformation.py index 0cbca4b73..b25bf2ad8 100644 --- a/loki/transformations/inline/tests/test_inline_transformation.py +++ b/loki/transformations/inline/tests/test_inline_transformation.py @@ -318,19 +318,19 @@ def test_inline_transformation_adjust_imports(frontend, tmp_path): """ fcode_outer = """ -subroutine test_inline_outer(a, b) +subroutine test_inline_outer(a, b, f) use bnds_module, only: n use test_inline_mod, only: test_inline_inner use test_inline_another_mod, only: test_inline_another_inner implicit none - real(kind=8), intent(inout) :: a(n), b(n) - real(kind=8) :: c(8) + real(kind=8), intent(inout) :: a(n), b(n), f(n) + real(kind=8) :: c(12) !$loki inline call test_inline_another_inner() !$loki inline - call test_inline_inner(a, b, c(1:4), c(5:8)) + call test_inline_inner(a, b, c(1:4), c(5:8), c(9:12), f) end subroutine test_inline_outer """ @@ -339,12 +339,12 @@ def test_inline_transformation_adjust_imports(frontend, tmp_path): implicit none contains -subroutine test_inline_inner(a, b, c, d) +subroutine test_inline_inner(a, b, c, d, e, f) use BNDS_module, only: n, m use another_module, only: x - real(kind=8), intent(inout) :: a(n), b(n) - real(kind=8), intent(out) :: c(4), d(4) + real(kind=8), intent(inout) :: a(n), b(n), f(0:n-1) + real(kind=8), intent(out) :: c(4), d(4), e(0:3) real(kind=8) :: tmp(m) integer :: i @@ -354,9 +354,16 @@ def test_inline_transformation_adjust_imports(frontend, tmp_path): end do do i=1,4 c(i) = 0. + d(i) = 0. + e(i-1) = 0. enddo c(:) = 1. - d(1:4) = 0. + d(1:4) = 1. + e(0:3) = 1. + e(:) = 2. + do i=0, n-1 + f(i) = 2. + end do end subroutine test_inline_inner end module test_inline_mod """ @@ -388,17 +395,27 @@ def test_inline_transformation_adjust_imports(frontend, tmp_path): # Check that the inlining has happened assign = FindNodes(ir.Assignment).visit(outer.body) - assert len(assign) == 5 + assert len(assign) == 10 assert assign[0].lhs == 'tmp(1:m)' assert assign[0].rhs == 'x' assert assign[1].lhs == 'a(i)' assert assign[1].rhs == 'b(i) + sum(tmp)' assert assign[2].lhs == 'c(i)' assert assign[2].rhs == '0.' - assert assign[3].lhs == 'c(1:4)' - assert assign[3].rhs == '1.' - assert assign[4].lhs == 'c(5:8)' + assert assign[3].lhs == 'c(4 + i)' + assert assign[3].rhs == '0.' + assert assign[4].lhs == 'c(8 + i)' assert assign[4].rhs == '0.' + assert assign[5].lhs == 'c(1:4)' + assert assign[5].rhs == '1.' + assert assign[6].lhs == 'c(5:8)' + assert assign[6].rhs == '1.' + assert assign[7].lhs == 'c(9:12)' + assert assign[7].rhs == '1.' + assert assign[8].lhs == 'c(9:12)' + assert assign[8].rhs == '2.' + assert assign[9].lhs == 'f(1 + i)' + assert assign[9].rhs == '2.' # Now check that the right modules have been moved, # and the import of the call has been removed