Skip to content

Commit

Permalink
Merge pull request #2503 from devitocodes/patch-beta2-is-transient
Browse files Browse the repository at this point in the history
compiler: Avoid allocating Bundles on the host if transient
  • Loading branch information
FabioLuporini authored Dec 20, 2024
2 parents 5a15896 + 93e4323 commit cc04242
Showing 1 changed file with 62 additions and 6 deletions.
68 changes: 62 additions & 6 deletions devito/passes/iet/definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,16 @@
import numpy as np

from devito.ir import (Block, Call, Definition, DummyExpr, Return, EntryFunction,
FindSymbols, MapExprStmts, Transformer, make_callable)
FindNodes, FindSymbols, MapExprStmts, Transformer,
make_callable)
from devito.passes import is_gpu_create
from devito.passes.iet.engine import iet_pass
from devito.passes.iet.langbase import LangBB
from devito.symbolics import (Byref, DefFunction, FieldFromPointer, IndexedPointer,
SizeOf, VOID, Keyword, pow_to_mul)
from devito.tools import as_mapper, as_list, as_tuple, filter_sorted, flatten
from devito.types import Array, CustomDimension, DeviceMap, DeviceRM, Eq, Symbol
from devito.types import (Array, ComponentAccess, CustomDimension, DeviceMap,
DeviceRM, Eq, Symbol)

__all__ = ['DataManager', 'DeviceAwareDataManager', 'Storage']

Expand Down Expand Up @@ -214,6 +216,38 @@ def _alloc_mapped_array_on_high_bw_mem(self, site, obj, storage, *args):

storage.update(obj, site, allocs=alloc, frees=free, efuncs=(efunc0, efunc1))

def _alloc_bundle_struct_on_high_bw_mem(self, site, obj, storage):
    """
    Allocate a Bundle struct in the host high bandwidth memory.

    A dedicated allocator routine is generated whose body, in order:
    declares `obj`; invokes `self.lang['host-alloc']` with an amount of
    bytes equal to `SizeOf(obj._C_typedata)`; stores the routine's `nbytes`
    parameter into the `obj._C_field_nbytes` field; stores 0 into the
    `obj._C_field_size` field; returns `obj._C_symbol`.

    The Call to this routine, the matching `self.lang['host-free']` and the
    routine itself are then recorded in `storage` for `obj` at `site`.

    Parameters
    ----------
    site
        Forwarded untouched to `storage.update` as the second argument.
    obj
        The object for which the struct is allocated.
    storage
        The container updated with the generated allocs, frees and efuncs.
    """
    # Building blocks of the allocator routine
    declaration = Definition(obj)
    struct_alloc = self.lang['host-alloc'](
        VOID(Byref(obj._C_symbol), '**'),
        obj._data_alignment,
        SizeOf(obj._C_typedata)
    )

    # `nbytes` is a formal parameter of the routine; the actual value is
    # only supplied at the call site (see below)
    nbytes_param = Symbol(name='nbytes', dtype=np.uint64, is_const=True)
    nbytes_arg = SizeOf(obj.indexed._C_typedata)*obj.size

    set_nbytes = DummyExpr(
        FieldFromPointer(obj._C_field_nbytes, obj._C_symbol), nbytes_param
    )
    set_size = DummyExpr(
        FieldFromPointer(obj._C_field_size, obj._C_symbol), 0
    )

    release = self.lang['host-free'](obj._C_symbol)

    epilogue = Return(obj._C_symbol)

    # Assemble the routine
    name = self.sregistry.make_name(prefix='alloc')
    allocator = make_callable(
        name,
        (declaration, struct_alloc, set_nbytes, set_size, epilogue),
        retval=obj
    )

    # The Call takes the routine's own parameters, except that the formal
    # `nbytes` is swapped for the actual number of bytes
    call_args = list(allocator.parameters)
    position = call_args.index(nbytes_param)
    call_args[position] = nbytes_arg
    invocation = Call(name, call_args, retobj=obj)

    storage.update(obj, site, allocs=invocation, frees=release,
                   efuncs=allocator)

def _alloc_object_array_on_low_lat_mem(self, site, obj, storage):
"""
Allocate an Array of Objects in the low latency memory.
Expand Down Expand Up @@ -340,9 +374,22 @@ def place_definitions(self, iet, globs=None, **kwargs):
for i in FindSymbols().visit(iet):
if i in defines:
continue

elif i.is_LocalObject:
self._alloc_object_on_low_lat_mem(iet, i, storage)
elif i.is_Array or i.is_Bundle:

elif i.is_Bundle:
if i._mem_heap:
if i.is_transient:
self._alloc_bundle_struct_on_high_bw_mem(iet, i, storage)
elif i._mem_local:
self._alloc_local_array_on_high_bw_mem(iet, i, storage)
elif i._mem_mapped:
self._alloc_mapped_array_on_high_bw_mem(iet, i, storage)
elif i._mem_stack:
self._alloc_array_on_low_lat_mem(iet, i, storage)

elif i.is_Array:
if i._mem_heap:
if i._mem_host:
self._alloc_host_array_on_high_bw_mem(iet, i, storage)
Expand All @@ -355,8 +402,10 @@ def place_definitions(self, iet, globs=None, **kwargs):
elif globs is not None:
# Track, to be handled by the EntryFunction being a global obj!
globs.add(i)

elif i.is_ObjectArray:
self._alloc_object_array_on_low_lat_mem(iet, i, storage)

elif i.is_PointerArray:
self._alloc_pointed_array_on_high_bw_mem(iet, i, storage)

Expand Down Expand Up @@ -571,16 +620,23 @@ def make_zero_init(obj, rcompile, sregistry):
cdims.append(CustomDimension(name=d.name, parent=d,
symbolic_min=m, symbolic_max=M))

eq = Eq(obj[cdims], 0)
if obj.is_Bundle:
eqns = [Eq(ComponentAccess(obj[cdims], i), 0) for i in range(obj.ncomp)]
else:
eqns = [Eq(obj[cdims], 0)]

irs, byproduct = rcompile(eq)
irs, byproduct = rcompile(eqns)

init = irs.iet.body.body[0]

name = sregistry.make_name(prefix='init')
efunc = make_callable(name, init)
init = Call(name, efunc.parameters)

efuncs = [efunc] + [i.root for i in byproduct.funcs]
efuncs = [efunc]

# Also the called device kernels, if any
calls = [i.name for i in FindNodes(Call).visit(efunc)]
efuncs.extend([i.root for i in byproduct.funcs if i.root.name in calls])

return efuncs, init

0 comments on commit cc04242

Please sign in to comment.