Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable inlining of a subroutine with an array subrange in an argument #484

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 35 additions & 3 deletions loki/transformations/inline/procedures.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
FindNodes, FindVariables, FindInlineCalls, SubstituteExpressions,
pragmas_attached, is_loki_pragma, Interface, Pragma, AttachScopes
)
from loki.expression import symbols as sym
from loki.expression import symbols as sym, simplify
from loki.types import BasicType
from loki.tools import as_tuple, CaseInsensitiveDict
from loki.logging import error
Expand Down Expand Up @@ -87,12 +87,44 @@ def _map_unbound_dims(var, val):
For example, mapping the passed array ``m(:,j)`` to the local
expression ``a(i)`` yields ``m(i,j)``.
"""

def _offset_lbound(local_lbound, decl_lbound, v):
_sum = sym.Product((-1, decl_lbound))
_sum = sym.Sum((_sum, local_lbound, v))
return simplify(_sum)

new_dimensions = list(val.dimensions)

indices = [index for index, dim in enumerate(val.dimensions) if isinstance(dim, sym.Range)]

for index, dim in enumerate(var.dimensions):
new_dimensions[indices[index]] = dim
lbounds_diff = [sym.IntLiteral(0) for _ in var.shape]
var_ubounds = [getattr(dim, 'upper', dim) for dim in var.shape]
if var.shape and val.shape:
decl_lbounds = [(getattr(val.shape[i], 'lower', sym.IntLiteral(1)),
getattr(dim, 'lower', sym.IntLiteral(1))) for i, dim in enumerate(var.shape)]

for i, (lb_val, lb_var) in enumerate(decl_lbounds):
# we can't simply check if lb_val here as that would return a false negative if lb_val == 0
if lb_val is not None and lb_var is not None:
lbounds_diff[i] = simplify(sym.Sum((lb_val, sym.Product((lb_var, sym.IntLiteral(-1))))))

for (index, dim), lbdiff in zip(enumerate(var.dimensions), lbounds_diff):
# if the argument contains an array range, we must map the bounds accordingly
if isinstance(val.dimensions[index], sym.Range) and (lower := val.dimensions[index].lower):
lower = simplify(sym.Sum((lower, lbdiff)))
decl_lbound = decl_lbounds[index][0]
if isinstance(dim, sym.Range):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[no action] Note to self, there's a similar index-shifting problem in the associate resolver. We could possibly re-use a shared utility for this if we were to externalise this? This is beyond this PR tho.

_lower = dim.lower or decl_lbounds[index][1]
_upper = dim.upper or var_ubounds[index]

_lower = _offset_lbound(lower, decl_lbound, _lower)
_upper = _offset_lbound(lower, decl_lbound, _upper)

new_dimensions[indices[index]] = sym.Range((_lower, _upper))
else:
new_dimensions[indices[index]] = _offset_lbound(lower, decl_lbound, dim)
else:
new_dimensions[indices[index]] = simplify(sym.Sum((dim, lbdiff)))

return val.clone(dimensions=tuple(new_dimensions))

Expand Down
42 changes: 36 additions & 6 deletions loki/transformations/inline/tests/test_inline_transformation.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,18 +318,19 @@ def test_inline_transformation_adjust_imports(frontend, tmp_path):
"""

fcode_outer = """
subroutine test_inline_outer(a, b)
subroutine test_inline_outer(a, b, f)
use bnds_module, only: n
use test_inline_mod, only: test_inline_inner
use test_inline_another_mod, only: test_inline_another_inner
implicit none

real(kind=8), intent(inout) :: a(n), b(n)
real(kind=8), intent(inout) :: a(n), b(n), f(0:n-1)
real(kind=8) :: c(12)

!$loki inline
call test_inline_another_inner()
!$loki inline
call test_inline_inner(a, b)
call test_inline_inner(a, b, c(1:4), c(5:8), c(9:12), f)
end subroutine test_inline_outer
"""

Expand All @@ -338,18 +339,31 @@ def test_inline_transformation_adjust_imports(frontend, tmp_path):
implicit none
contains

subroutine test_inline_inner(a, b)
subroutine test_inline_inner(a, b, c, d, e, f)
use BNDS_module, only: n, m
use another_module, only: x

real(kind=8), intent(inout) :: a(n), b(n)
real(kind=8), intent(inout) :: a(n), b(n), f(2:n+1)
real(kind=8), intent(out) :: c(4), d(4), e(0:3)
real(kind=8) :: tmp(m)
integer :: i

tmp(1:m) = x
do i=1, n
a(i) = b(i) + sum(tmp)
end do
do i=1,4
c(i) = 0.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we add an assignment d(i) = 1. here too, to check that this index is shifted too?

d(i) = 0.
e(i-1) = 0.
enddo
c(:) = 1.
d(1:4) = 1.
e(0:3) = 1.
e(:) = 2.
do i=2, n+1
f(i) = 2.
end do
end subroutine test_inline_inner
end module test_inline_mod
"""
Expand Down Expand Up @@ -381,11 +395,27 @@ def test_inline_transformation_adjust_imports(frontend, tmp_path):

# Check that the inlining has happened
assign = FindNodes(ir.Assignment).visit(outer.body)
assert len(assign) == 2
assert len(assign) == 10
assert assign[0].lhs == 'tmp(1:m)'
assert assign[0].rhs == 'x'
assert assign[1].lhs == 'a(i)'
assert assign[1].rhs == 'b(i) + sum(tmp)'
assert assign[2].lhs == 'c(i)'
assert assign[2].rhs == '0.'
assert assign[3].lhs == 'c(4 + i)'
assert assign[3].rhs == '0.'
assert assign[4].lhs == 'c(8 + i)'
assert assign[4].rhs == '0.'
assert assign[5].lhs == 'c(1:4)'
assert assign[5].rhs == '1.'
assert assign[6].lhs == 'c(5:8)'
assert assign[6].rhs == '1.'
assert assign[7].lhs == 'c(9:12)'
assert assign[7].rhs == '1.'
assert assign[8].lhs == 'c(9:12)'
assert assign[8].rhs == '2.'
assert assign[9].lhs == 'f(-2 + i)'
assert assign[9].rhs == '2.'

# Now check that the right modules have been moved,
# and the import of the call has been removed
Expand Down
Loading