diff --git a/devito/arch/archinfo.py b/devito/arch/archinfo.py index 54274c1729..24ae71c86a 100644 --- a/devito/arch/archinfo.py +++ b/devito/arch/archinfo.py @@ -1140,7 +1140,7 @@ def supports(self, query, language=None): warning(f"Couldn't establish if `query={query}` is supported on this " "system. Assuming it is not.") return False - elif query == 'async-loads' and cc >= 80: + elif query == 'async-pipe' and cc >= 80: # Asynchronous pipeline loads -- introduced in Ampere return True elif query in ('tma', 'thread-block-cluster') and cc >= 90: @@ -1157,7 +1157,7 @@ class Volta(NvidiaDevice): class Ampere(Volta): def supports(self, query, language=None): - if query == 'async-loads': + if query == 'async-pipe': return True else: return super().supports(query, language) diff --git a/devito/ir/equations/equation.py b/devito/ir/equations/equation.py index f83dc39c94..85162b8a5f 100644 --- a/devito/ir/equations/equation.py +++ b/devito/ir/equations/equation.py @@ -83,9 +83,11 @@ def __repr__(self): if not self.is_Reduction: return super().__repr__() elif self.operation is OpInc: - return '%s += %s' % (self.lhs, self.rhs) + return f'Inc({self.lhs}, {self.rhs})' else: - return '%s = %s(%s)' % (self.lhs, self.operation, self.rhs) + return f'Eq({self.lhs}, {self.operation}({self.rhs}))' + + __str__ = __repr__ # Pickling support __reduce_ex__ = Pickable.__reduce_ex__ diff --git a/devito/ir/iet/visitors.py b/devito/ir/iet/visitors.py index 8f6ae8f02f..d67661456f 100644 --- a/devito/ir/iet/visitors.py +++ b/devito/ir/iet/visitors.py @@ -305,6 +305,9 @@ def _gen_value(self, obj, mode=1, masked=()): qualifiers = [v for k, v in self._qualifiers_mapper.items() if getattr(obj.function, k, False) and v not in masked] + if obj.is_LocalObject and mode == 2: + qualifiers.extend(as_tuple(obj._C_tag)) + if (obj._mem_stack or obj._mem_constant) and mode == 1: strtype = self.ccode(obj._C_typedata) strshape = ''.join(f'[{self.ccode(i)}]' for i in obj.symbolic_shape) diff --git a/devito/ir/support/guards.py b/devito/ir/support/guards.py index a014db8abb..73740cb44b 100644 --- a/devito/ir/support/guards.py +++ b/devito/ir/support/guards.py @@ -291,6 +291,7 @@ def xandg(self, d, guard): def pairwise_or(self, d, *guards): m = dict(self) + guards = list(guards) if d in m: guards.append(m[d]) @@ -490,7 +491,9 @@ def pairwise_or(*guards): # Analysis for guard in guards: - if guard is true or guard is None: + if guard is true: + return true + elif guard is None: continue elif isinstance(guard, And): components = guard.args diff --git a/devito/ir/support/space.py b/devito/ir/support/space.py index f43351001c..4cfac8bac5 100644 --- a/devito/ir/support/space.py +++ b/devito/ir/support/space.py @@ -775,17 +775,19 @@ def __init__(self, intervals, sub_iterators=None, directions=None): super().__init__(intervals) # Normalize sub-iterators - sub_iterators = dict([(k, tuple(filter_ordered(as_tuple(v)))) - for k, v in (sub_iterators or {}).items()]) + sub_iterators = sub_iterators or {} + sub_iterators = {d: tuple(filter_ordered(as_tuple(v))) + for d, v in sub_iterators.items() if d in self.intervals} sub_iterators.update({i.dim: () for i in self.intervals if i.dim not in sub_iterators}) self._sub_iterators = frozendict(sub_iterators) # Normalize directions - if directions is None: - self._directions = frozendict([(i.dim, Any) for i in self.intervals]) - else: - self._directions = frozendict(directions) + directions = directions or {} + directions = {d: v for d, v in directions.items() if d in self.intervals} + directions.update({i.dim: Any for i in self.intervals + if i.dim not in directions}) + self._directions = frozendict(directions) def __repr__(self): ret = ', '.join(["%s%s" % (repr(i), repr(self.directions[i.dim])) @@ -807,8 +809,7 @@ def __lt__(self, other): return len(self.itintervals) < len(other.itintervals) def __hash__(self): - return hash((super().__hash__(), self.sub_iterators, - self.directions)) + return hash((super().__hash__(), self.sub_iterators, self.directions)) def __contains__(self, d): try: diff --git a/devito/passes/clusters/cse.py b/devito/passes/clusters/cse.py index 509490a2e1..8c09a2ee45 100644 --- a/devito/passes/clusters/cse.py +++ b/devito/passes/clusters/cse.py @@ -12,7 +12,7 @@ from devito.finite_differences.differentiable import IndexDerivative from devito.ir import Cluster, Scope, cluster_pass -from devito.symbolics import estimate_cost, q_leaf, q_terminal +from devito.symbolics import Reserved, estimate_cost, q_leaf, q_terminal from devito.symbolics.search import search from devito.symbolics.manipulation import _uxreplace from devito.tools import DAG, as_list, as_tuple, frozendict, extract_dtype @@ -401,6 +401,7 @@ def _(expr): @_catch.register(Indexed) @_catch.register(Symbol) +@_catch.register(Reserved) def _(expr): """ Handler for objects preventing CSE to propagate through their arguments. diff --git a/devito/passes/clusters/derivatives.py b/devito/passes/clusters/derivatives.py index d244704058..a51acebecc 100644 --- a/devito/passes/clusters/derivatives.py +++ b/devito/passes/clusters/derivatives.py @@ -3,7 +3,7 @@ from sympy import S import numpy as np -from devito.finite_differences import IndexDerivative +from devito.finite_differences import IndexDerivative, Weights from devito.ir import Backward, Forward, Interval, IterationSpace, Queue from devito.passes.clusters.misc import fuse from devito.symbolics import BasicWrapperMixin, reuse_if_untouched, uxreplace @@ -94,17 +94,39 @@ def _core(expr, c, ispace, weights, reusables, mapper, **kwargs): @_core.register(Symbol) -@_core.register(Indexed) @_core.register(BasicWrapperMixin) def _(expr, c, ispace, weights, reusables, mapper, **kwargs): return expr, [] +@_core.register(Indexed) +def _(expr, c, ispace, weights, reusables, mapper, **kwargs): + if not isinstance(expr.function, Weights): + return expr, [] + + # Lower or reuse a previously lowered Weights array + sregistry = kwargs['sregistry'] + subs_user = kwargs['subs'] + + w0 = expr.function + k = tuple(w0.weights) + try: + w = weights[k] + except KeyError: + name = sregistry.make_name(prefix='w') + dtype = infer_dtype([w0.dtype, c.dtype]) # At least np.float32 + initvalue = tuple(i.subs(subs_user) for i in k) + w = weights[k] = w0._rebuild(name=name, dtype=dtype, initvalue=initvalue) + + rebuilt = expr._subs(w0.indexed, w.indexed) + + return rebuilt, [] + + @_core.register(IndexDerivative) def _(expr, c, ispace, weights, reusables, mapper, **kwargs): sregistry = kwargs['sregistry'] options = kwargs['options'] - subs_user = kwargs['subs'] try: cbk0 = deriv_schedule_registry[options['deriv-schedule']] @@ -117,18 +139,10 @@ def _(expr, c, ispace, weights, reusables, mapper, **kwargs): # Create the concrete Weights array, or reuse an already existing one # if possible - name = sregistry.make_name(prefix='w') - w0 = ideriv.weights.function - dtype = infer_dtype([w0.dtype, c.dtype]) # At least np.float32 - k = tuple(w0.weights) - try: - w = weights[k] - except KeyError: - initvalue = tuple(i.subs(subs_user) for i in k) - w = weights[k] = w0._rebuild(name=name, dtype=dtype, initvalue=initvalue) + w, _ = _core(ideriv.weights, c, ispace, weights, reusables, mapper, **kwargs) # Replace the abstract Weights array with the concrete one - subs = {w0.indexed: w.indexed} + subs = {ideriv.weights.base: w.base} init = uxreplace(init, subs) ideriv = uxreplace(ideriv, subs) @@ -155,13 +169,13 @@ def _(expr, c, ispace, weights, reusables, mapper, **kwargs): ispace1 = IterationSpace.union(ispace, ispace0, relations=extra) # The Symbol that will hold the result of the IndexDerivative computation - # NOTE: created before recurring so that we ultimately get a sound ordering + # NOTE: created before recursing so that we ultimately get a sound ordering try: s = reusables.pop() - assert np.can_cast(s.dtype, dtype) + assert np.can_cast(s.dtype, w.dtype) except KeyError: name = sregistry.make_name(prefix='r') - s = Symbol(name=name, dtype=dtype) + s = Symbol(name=name, dtype=w.dtype) # Go inside `expr` and recursively lower any nested IndexDerivatives expr, processed = _core(expr, c, ispace1, weights, reusables, mapper, **kwargs) diff --git a/devito/passes/iet/definitions.py b/devito/passes/iet/definitions.py index 978c093eed..7b16663a9b 100644 --- a/devito/passes/iet/definitions.py +++ b/devito/passes/iet/definitions.py @@ -5,7 +5,6 @@ from collections import OrderedDict from ctypes import c_uint64 -from functools import singledispatch from operator import itemgetter import numpy as np @@ -98,17 +97,29 @@ def _alloc_object_on_low_lat_mem(self, site, obj, storage): """ decl = Definition(obj) - if obj._C_init: - definition = (decl, obj._C_init) + init = obj._C_init + if not init: + definition = decl + efuncs = () + elif isinstance(init, (list, tuple)): + assert len(init) == 2, "Expected (efunc, call)" + init, definition = init + efuncs = (init,) + elif init.is_Callable: + definition = Call(init.name, init.parameters, + retobj=obj if init.retval else None) + efuncs = (init,) else: - definition = (decl) + definition = (decl, init) + efuncs = () frees = obj._C_free if obj.free_symbols - {obj}: - storage.update(obj, site, objs=definition, frees=frees) + storage.update(obj, site, objs=definition, efuncs=efuncs, frees=frees) else: - storage.update(obj, site, standalones=definition, frees=frees) + storage.update(obj, site, standalones=definition, efuncs=efuncs, + frees=frees) def _alloc_array_on_low_lat_mem(self, site, obj, storage): """ @@ -555,7 +566,7 @@ class DeviceAwareDataManager(DataManager): def __init__(self, options=None, **kwargs): self.gpu_fit = options['gpu-fit'] self.gpu_create = options['gpu-create'] - self.pmode = options.get('place-transfers') + self.gpu_place_transfers = options.get('place-transfers') super().__init__(**kwargs) @@ -588,7 +599,8 @@ def _map_array_on_high_bw_mem(self, site, obj, storage): storage.update(obj, site, maps=mmap, unmaps=unmap) - def _map_function_on_high_bw_mem(self, site, obj, storage, devicerm, read_only=False): + def _map_function_on_high_bw_mem(self, site, obj, storage, devicerm, + read_only=False, **kwargs): """ Map a Function already defined in the host memory in to the device high bandwidth memory. @@ -621,42 +633,41 @@ def _map_function_on_high_bw_mem(self, site, obj, storage, devicerm, read_only=F storage.update(obj, site, maps=mmap, unmaps=unmap, efuncs=efuncs) @iet_pass - def place_transfers(self, iet, data_movs=None, **kwargs): + def place_transfers(self, iet, data_movs=None, ctx=None, **kwargs): """ Create a new IET with host-device data transfers. This requires mapping symbols to the suitable memory spaces. """ - if not self.pmode: + if not self.gpu_place_transfers: return iet, {} - @singledispatch - def _place_transfers(iet, data_movs): + if not isinstance(iet, EntryFunction): return iet, {} - @_place_transfers.register(EntryFunction) - def _(iet, data_movs): - reads, writes = data_movs + reads, writes = data_movs - # Special symbol which gives user code control over data deallocations - devicerm = DeviceRM() + # Special symbol which gives user code control over data deallocations + devicerm = DeviceRM() - storage = Storage() - for i in filter_sorted(writes): - if i.is_Array: - self._map_array_on_high_bw_mem(iet, i, storage) - else: - self._map_function_on_high_bw_mem(iet, i, storage, devicerm) - for i in filter_sorted(reads - writes): - if i.is_Array: - self._map_array_on_high_bw_mem(iet, i, storage) - else: - self._map_function_on_high_bw_mem(iet, i, storage, devicerm, True) - - iet, efuncs = self._inject_definitions(iet, storage) + storage = Storage() + for i in filter_sorted(writes): + if i.is_Array: + self._map_array_on_high_bw_mem(iet, i, storage) + else: + self._map_function_on_high_bw_mem( + iet, i, storage, devicerm, ctx=ctx + ) + for i in filter_sorted(reads - writes): + if i.is_Array: + self._map_array_on_high_bw_mem(iet, i, storage) + else: + self._map_function_on_high_bw_mem( + iet, i, storage, devicerm, read_only=True, ctx=ctx + ) - return iet, {'efuncs': efuncs} + iet, efuncs = self._inject_definitions(iet, storage) - return _place_transfers(iet, data_movs=data_movs) + return iet, {'efuncs': efuncs} @iet_pass def place_devptr(self, iet, **kwargs): diff --git a/devito/passes/iet/engine.py b/devito/passes/iet/engine.py index 724ccf5c84..0bc7ea3889 100644 --- a/devito/passes/iet/engine.py +++ b/devito/passes/iet/engine.py @@ -17,7 +17,7 @@ search) from devito.tools import DAG, as_tuple, filter_ordered, sorted_priority, timed_pass from devito.types import ( - Array, Bundle, ComponentAccess, CompositeObject, Lock, IncrDimension, + Array, Bundle, ComponentAccess, CompositeObject, IncrDimension, FunctionMap, ModuloDimension, Indirection, Pointer, SharedData, ThreadArray, Symbol, Temp, NPThreads, NThreadsBase, Wildcard ) @@ -555,12 +555,19 @@ def _(i, mapper, sregistry): @abstract_object.register(Array) def _(i, mapper, sregistry): - if isinstance(i, Lock): - name = sregistry.make_name(prefix='lock') + name = sregistry.make_name(prefix=i._symbol_prefix) + + if i.initvalue is not None: + initvalue = [] + for v in i.initvalue: + try: + initvalue.append(v.xreplace(mapper)) + except AttributeError: + initvalue.append(v) else: - name = sregistry.make_name(prefix='a') + initvalue = None - v = i._rebuild(name=name, alias=True) + v = i._rebuild(name=name, initvalue=initvalue, alias=True) mapper.update({ i: v, @@ -667,6 +674,16 @@ def _(i, mapper, sregistry): mapper[i] = i._rebuild(name=sregistry.make_name(prefix='ptr')) +@abstract_object.register(FunctionMap) +def _(i, mapper, sregistry): + name = sregistry.make_name(prefix=i._symbol_prefix) + tensor = mapper.get(i.tensor, i.tensor) + + v = i._rebuild(name, tensor) + + mapper[i] = v + + @abstract_object.register(NPThreads) def _(i, mapper, sregistry): mapper[i] = i._rebuild(name=sregistry.make_name(prefix='npthreads')) diff --git a/devito/passes/iet/parpragma.py b/devito/passes/iet/parpragma.py index c8443ceb79..9f64843e28 100644 --- a/devito/passes/iet/parpragma.py +++ b/devito/passes/iet/parpragma.py @@ -427,7 +427,7 @@ def _make_parallel(self, iet, sync_mapper=None): return iet, {'includes': [self.langbb['header']]} - def make_parallel(self, graph): + def make_parallel(self, graph, **kwargs): return self._make_parallel(graph, sync_mapper=graph.sync_mapper) diff --git a/devito/symbolics/extended_dtypes.py b/devito/symbolics/extended_dtypes.py index eda71a0b74..23f5c33bf0 100644 --- a/devito/symbolics/extended_dtypes.py +++ b/devito/symbolics/extended_dtypes.py @@ -7,8 +7,8 @@ from devito.tools.dtypes_lowering import dtype_mapper __all__ = ['cast', 'CustomType', 'limits_mapper', 'INT', 'FLOAT', 'BaseCast', # noqa - 'DOUBLE', 'VOID', 'NoDeclStruct', 'c_complex', 'c_double_complex', - 'LONG'] + 'DOUBLE', 'VOID', 'LONG', 'ULONG', 'NoDeclStruct', 'c_complex', + 'c_double_complex'] limits_mapper = { diff --git a/devito/symbolics/extended_sympy.py b/devito/symbolics/extended_sympy.py index 3a6e61742c..59a28bbd1d 100644 --- a/devito/symbolics/extended_sympy.py +++ b/devito/symbolics/extended_sympy.py @@ -18,12 +18,12 @@ from devito.types.basic import Basic __all__ = ['CondEq', 'CondNe', 'BitwiseNot', 'BitwiseXor', 'BitwiseAnd', # noqa - 'LeftShift', 'RightShift', 'IntDiv', 'CallFromPointer', + 'LeftShift', 'RightShift', 'IntDiv', 'Terminal', 'CallFromPointer', 'CallFromComposite', 'FieldFromPointer', 'FieldFromComposite', 'ListInitializer', 'Byref', 'IndexedPointer', 'Cast', 'DefFunction', - 'MathFunction', 'InlineIf', 'ReservedWord', 'Keyword', 'String', - 'Macro', 'Class', 'MacroArgument', 'Deref', 'Namespace', 'Rvalue', - 'Null', 'SizeOf', 'rfunc', 'BasicWrapperMixin', 'ValueLimit', + 'MathFunction', 'InlineIf', 'Reserved', 'ReservedWord', 'Keyword', + 'String', 'Macro', 'Class', 'MacroArgument', 'Deref', 'Namespace', + 'Rvalue', 'Null', 'SizeOf', 'rfunc', 'BasicWrapperMixin', 'ValueLimit', 'VectorAccess'] @@ -147,6 +147,17 @@ def __mul__(self, other): return super().__mul__(other) +class Terminal: + + """ + Abstract base class for all terminal objects, that is, those objects + collected by `retrieve_terminals` in addition to all other SymPy atoms + such as `Symbol`, `Number`, etc. + """ + + pass + + class BasicWrapperMixin: """ @@ -188,7 +199,7 @@ def _sympystr(self, printer): return str(self) -class CallFromPointer(sympy.Expr, Pickable, BasicWrapperMixin): +class CallFromPointer(Expr, Pickable, BasicWrapperMixin, Terminal): """ Symbolic representation of the C notation ``pointer->call(params)``. @@ -256,7 +267,7 @@ def free_symbols(self): __reduce_ex__ = Pickable.__reduce_ex__ -class CallFromComposite(CallFromPointer, Pickable): +class CallFromComposite(CallFromPointer): """ Symbolic representation of the C notation ``composite.call(params)``. @@ -269,7 +280,7 @@ def __str__(self): __repr__ = __str__ -class FieldFromPointer(CallFromPointer, Pickable): +class FieldFromPointer(CallFromPointer): """ Symbolic representation of the C notation ``pointer->field``. @@ -290,7 +301,7 @@ def field(self): __repr__ = __str__ -class FieldFromComposite(CallFromPointer, Pickable): +class FieldFromComposite(CallFromPointer): """ Symbolic representation of the C notation ``composite.field``, @@ -322,10 +333,14 @@ class ListInitializer(sympy.Expr, Pickable): Symbolic representation of the C++ list initializer notation ``{a, b, ...}``. """ - __rargs__ = ('params',) + __rargs__ = ('*params',) __rkwargs__ = ('dtype',) - def __new__(cls, params, dtype=None): + def __new__(cls, *params, dtype=None, evaluate=False): + # Legacy API: allow a single list/tuple as argument + if len(params) == 1 and isinstance(params[0], (list, tuple, np.ndarray)): + params = params[0] + args = [] for p in as_tuple(params): try: @@ -352,7 +367,7 @@ def is_numeric(self): __reduce_ex__ = Pickable.__reduce_ex__ -class UnaryOp(sympy.Expr, Pickable, BasicWrapperMixin): +class UnaryOp(Expr, Pickable, BasicWrapperMixin): """ Symbolic representation of a unary C operator. @@ -490,7 +505,7 @@ def __str__(self): return f"{self._op}{self.base}" -class IndexedPointer(sympy.Expr, Pickable, BasicWrapperMixin): +class IndexedPointer(Expr, Pickable, BasicWrapperMixin, Terminal): """ Symbolic representation of the C notation ``symbol[...]`` @@ -537,7 +552,21 @@ def __str__(self): __reduce_ex__ = Pickable.__reduce_ex__ -class ReservedWord(sympy.Atom, Pickable): +class Reserved(Pickable): + + """ + A base class for all reserved words used throughout the lowering process, + including the final stage of code generation itself. + + Reserved objects have the following properties: + + * `estimate_cost(o) = 0`, where `o` is an instance of Reserved + """ + + pass + + +class ReservedWord(sympy.Atom, Reserved): """ A `ReservedWord` carries a value that has special meaning in the diff --git a/devito/symbolics/inspection.py b/devito/symbolics/inspection.py index 85b7b8bfa9..1f177af690 100644 --- a/devito/symbolics/inspection.py +++ b/devito/symbolics/inspection.py @@ -10,8 +10,9 @@ from devito.finite_differences.differentiable import IndexDerivative from devito.logger import warning from devito.symbolics.extended_dtypes import INT -from devito.symbolics.extended_sympy import (CallFromPointer, Cast, - DefFunction, ReservedWord) +from devito.symbolics.extended_sympy import ( + CallFromPointer, Cast, DefFunction, Reserved +) from devito.symbolics.queries import q_routine from devito.tools import as_tuple, prod, is_integer from devito.tools.dtypes_lowering import infer_dtype @@ -175,7 +176,7 @@ def _(expr, estimate, seen): @_estimate_cost.register(ImaginaryUnit) @_estimate_cost.register(Number) -@_estimate_cost.register(ReservedWord) +@_estimate_cost.register(Reserved) def _(expr, estimate, seen): return 0, False diff --git a/devito/symbolics/queries.py b/devito/symbolics/queries.py index 2496a0aeb9..a52fa16aaf 100644 --- a/devito/symbolics/queries.py +++ b/devito/symbolics/queries.py @@ -1,7 +1,6 @@ from sympy import Eq, IndexedBase, Mod, S, diff, nan -from devito.symbolics.extended_sympy import (FieldFromComposite, FieldFromPointer, - IndexedPointer, IntDiv) +from devito.symbolics.extended_sympy import IntDiv, Terminal from devito.tools import as_tuple, is_integer from devito.types.basic import AbstractFunction from devito.types.constant import Constant @@ -16,13 +15,9 @@ 'q_dimension', 'q_positive', 'q_negative'] -# The following SymPy objects are considered tree leaves: -# -# * Number -# * Symbol -# * Indexed -extra_leaves = (FieldFromPointer, FieldFromComposite, IndexedBase, AbstractObject, - IndexedPointer) +# The following SymPy objects are considered tree leaves in addition to the classic +# SymPy atoms such as Number, Symbol, Indexed, etc +extra_leaves = (IndexedBase, AbstractObject, Terminal) def q_symbol(expr): diff --git a/devito/types/array.py b/devito/types/array.py index cb1815f644..75a69ba559 100644 --- a/devito/types/array.py +++ b/devito/types/array.py @@ -127,6 +127,8 @@ class Array(ArrayBasic): is_Array = True + _symbol_prefix = 'a' + __rkwargs__ = (ArrayBasic.__rkwargs__ + ('dimensions', 'scope', 'initvalue')) diff --git a/devito/types/basic.py b/devito/types/basic.py index db455e8924..4ec34e66d9 100644 --- a/devito/types/basic.py +++ b/devito/types/basic.py @@ -1921,8 +1921,17 @@ def _mem_internal_lazy(self): return self._liveness == 'lazy' """ - A modifier added to the subclass C declaration when it appears - in a function signature. For example, a subclass might define `_C_modifier = '&'` + A modifier added to the declaration of the LocalType when it appears in a + function signature. For example, a subclass might define `_C_modifier = '&'` to impose pass-by-reference semantics. """ _C_modifier = None + + """ + One or more optional keywords added to the declaration of the LocalType + in between the type and the variable name when it appears in a function + signature. For example, some languages support these to modify the way + the compiler generates code for passing the parameter and how the + runtime accesses it. + """ + _C_tag = None diff --git a/devito/types/misc.py b/devito/types/misc.py index 8cdad91b07..473ac57b42 100644 --- a/devito/types/misc.py +++ b/devito/types/misc.py @@ -14,8 +14,8 @@ __all__ = ['Timer', 'Pointer', 'VolatileInt', 'FIndexed', 'Wildcard', 'Fence', 'Global', 'Hyperplane', 'Indirection', 'Temp', 'TempArray', 'Jump', - 'nop', 'WeakFence', 'CriticalRegion', 'Auto', 'AutoRef', 'auto', - 'size_t'] + 'nop', 'WeakFence', 'CriticalRegion', 'Auto', 'AutoRef', 'FunctionMap', + 'auto', 'size_t'] class Timer(CompositeObject): @@ -358,6 +358,30 @@ def closing(self): """ +class FunctionMap(LocalObject): + + """ + Wrap a Function in a LocalObject. + """ + + __rargs__ = ('name', 'tensor') + + def __init__(self, name, tensor, **kwargs): + super().__init__(name, **kwargs) + self.tensor = tensor + + def _hashable_content(self): + return super()._hashable_content() + (self.tensor,) + + @property + def free_symbols(self): + """ + The free symbols of a FunctionMap are the free symbols of the + underlying Function. + """ + return super().free_symbols | {self.tensor} + + # *** C/CXX support types size_t = CustomDtype('size_t') diff --git a/devito/types/object.py b/devito/types/object.py index 637e19dea0..a883fd8d51 100644 --- a/devito/types/object.py +++ b/devito/types/object.py @@ -176,10 +176,10 @@ class LocalObject(AbstractObject, LocalType): """ __rargs__ = ('name',) - __rkwargs__ = ('cargs', 'initvalue', 'liveness', 'is_global') + __rkwargs__ = ('cargs', 'initvalue', 'liveness', 'scope') def __init__(self, name, cargs=None, initvalue=None, liveness='lazy', - is_global=False, **kwargs): + scope='stack', **kwargs): self.name = name self.cargs = as_tuple(cargs) @@ -191,16 +191,17 @@ def __init__(self, name, cargs=None, initvalue=None, liveness='lazy', assert liveness in ['eager', 'lazy'] self._liveness = liveness - self._is_global = is_global + assert scope in ['stack', 'shared', 'global'] + self._scope = scope def _hashable_content(self): return (super()._hashable_content() + self.cargs + - (self.initvalue, self.liveness, self.is_global)) + (self.initvalue, self.liveness, self.scope)) @property - def is_global(self): - return self._is_global + def scope(self): + return self._scope @property def free_symbols(self): @@ -236,6 +237,10 @@ def _C_free(self): """ return None + @property + def _mem_shared(self): + return self._scope == 'shared' + @property def _mem_global(self): - return self._is_global + return self._scope == 'global' diff --git a/devito/types/parallel.py b/devito/types/parallel.py index 4383e6c208..26a357cc4d 100644 --- a/devito/types/parallel.py +++ b/devito/types/parallel.py @@ -227,6 +227,8 @@ class Lock(Array): is_volatile = True + _symbol_prefix = 'lock' + # Not a performance-sensitive object _data_alignment = False diff --git a/tests/test_iet.py b/tests/test_iet.py index b843d12f9f..a644590e59 100644 --- a/tests/test_iet.py +++ b/tests/test_iet.py @@ -17,9 +17,12 @@ from devito.passes.iet.engine import Graph from devito.passes.iet.languages.C import CDataManager from devito.symbolics import (Byref, FieldFromComposite, InlineIf, Macro, Class, - String, FLOAT) + String, ListInitializer, SizeOf, FLOAT) from devito.tools import CustomDtype, as_tuple, dtype_to_ctype -from devito.types import CustomDimension, Array, LocalObject, Symbol, Pointer +from devito.types import ( + CustomDimension, Array, LocalObject, Symbol, Pointer +) +from devito.types.misc import FunctionMap @pytest.fixture @@ -298,6 +301,52 @@ def _C_free(self): }""" +def test_make_cuda_tensor_map(): + + class CUTensorMap(FunctionMap): + + dtype = CustomDtype('CUtensorMap') + + @property + def _C_init(self): + symsizes = list(reversed(self.tensor.symbolic_shape)) + sizeof_dtype = SizeOf(self.tensor.dmap._C_typedata) + + sizes = ListInitializer(symsizes) + strides = ListInitializer([ + np.prod(symsizes[:i])*sizeof_dtype for i in range(1, len(symsizes)) + ]) + + arguments = [ + Byref(self), + Macro('CU_TENSOR_MAP_DATA_TYPE_FLOAT32'), + 4, self.tensor.dmap, sizes, strides, + ] + call = Call('cuTensorMapEncodeTiled', arguments) + + return call + + grid = Grid(shape=(10, 10, 10)) + + u = TimeFunction(name='u', grid=grid) + + tmap = CUTensorMap('tmap', u) + + iet = Call('foo', tmap) + iet = ElementalFunction('foo', iet, parameters=()) + dm = CDataManager(sregistry=None) + iet = CDataManager.place_definitions.__wrapped__(dm, iet)[0] + + assert str(iet) == """\ +static void foo() +{ + CUtensorMap tmap; + cuTensorMapEncodeTiled(&tmap,CU_TENSOR_MAP_DATA_TYPE_FLOAT32,4,d_u,{u_vec->size[3], u_vec->size[2], u_vec->size[1], u_vec->size[0]},{sizeof(float)*u_vec->size[3], sizeof(float)*u_vec->size[2]*u_vec->size[3], sizeof(float)*u_vec->size[1]*u_vec->size[2]*u_vec->size[3]}); + + foo(tmap); +}""" # noqa + + def test_cpp_local_object(): """ Test C++ support for LocalObjects. @@ -310,7 +359,7 @@ class MyObject(LocalObject): lo0 = MyObject('obj0') # Globally-scoped objects must not be declared in the function body - lo1 = MyObject('obj1', is_global=True) + lo1 = MyObject('obj1', scope='global') # A LocalObject using both a template and a modifier class SpecialObject(LocalObject): diff --git a/tests/test_symbolics.py b/tests/test_symbolics.py index e77adb0c26..090b40389e 100644 --- a/tests/test_symbolics.py +++ b/tests/test_symbolics.py @@ -424,6 +424,31 @@ def test_namespace(): assert not ns0.free_symbols +def test_list_initializer(): + # Legacy interface + init0 = ListInitializer((1, 2, 3)) + assert str(init0) == '{1, 2, 3}' + + init1 = ListInitializer(1, 2, 3) + assert str(init1) == '{1, 2, 3}' + + # Test hashing and equality + assert init0 == init1 + assert hash(init0) == hash(init1) + init2 = ListInitializer(1, 2) + assert init0 != init2 + assert hash(init0) != hash(init2) + assert hash(init0) == hash(init1) + + # Reconstruction + assert init0 == init0._rebuild() + assert init1 == init1._rebuild() + assert str(init1._rebuild(4, 5)) == '{4, 5}' + + # Accept `evaluate` but gently ignore it + assert str(ListInitializer((1, 2), evaluate=True)) == '{1, 2}' + + def test_rvalue(): ctype = ReservedWord('dummytype') ns = Namespace(['my', 'namespace'])