diff --git a/devito/ir/support/guards.py b/devito/ir/support/guards.py index deba5be148..dbb1fe9ca7 100644 --- a/devito/ir/support/guards.py +++ b/devito/ir/support/guards.py @@ -300,6 +300,7 @@ def xandg(self, d, guard): def pairwise_or(self, d, *guards): m = dict(self) + guards = list(guards) if d in m: guards.append(m[d]) diff --git a/devito/passes/clusters/buffering.py b/devito/passes/clusters/buffering.py index 61676d2e7f..2a8c78e7fc 100644 --- a/devito/passes/clusters/buffering.py +++ b/devito/passes/clusters/buffering.py @@ -182,10 +182,10 @@ def callback(self, clusters, prefix): for _, v in descriptors.items(): if not v.is_readonly: continue - if c is not v.firstread: + if c not in v.firstread: continue - idxf = v.last_idx + idxf = v.last_idx[c] idxb = mds[(v.xd, idxf)] lhs = v.b.indexify()._subs(v.xd, idxb) @@ -225,10 +225,10 @@ def callback(self, clusters, prefix): for _, v in descriptors.items(): if v.is_readonly: continue - if c is not v.lastwrite: + if c not in v.lastwrite: continue - idxf = v.last_idx + idxf = v.last_idx[c] idxb = mds[(v.xd, idxf)] lhs = v.f.indexify()._subs(v.dim, idxf) @@ -508,17 +508,19 @@ def subdims_mapper(self): @cached_property def firstread(self): + mapper = {} for c in self.clusters: if c.scope.reads.get(self.f): - return c - return None + mapper.setdefault(c.guards, c) + return tuple(mapper.values()) @cached_property def lastwrite(self): + mapper = {} for c in reversed(self.clusters): if c.scope.writes.get(self.f): - return c - return None + mapper.setdefault(c.guards, c) + return tuple(mapper.values()) @property def is_read(self): @@ -529,7 +531,7 @@ def is_read(self): @property def is_write(self): - return self.lastwrite is not None + return bool(self.lastwrite) @property def is_readonly(self): @@ -604,8 +606,14 @@ def last_idx(self): * `time-1` in the case of `foo(u[time-1], u[time], u[time+1])` with a backwards-propagating `time` Dimension. """ + mapper = {} func = vmax if self.is_forward_buffering else vmin - return func(*[Vector(i) for i in self.indices])[0] + for c in self.lastwrite + self.firstread: + indices = extract_indices(self.f, self.dim, [c]) + idx = func(*[Vector(i) for i in indices])[0] + mapper[c] = idx + + return frozendict(mapper) @cached_property def first_idx(self): diff --git a/devito/types/parallel.py b/devito/types/parallel.py index 891eaad176..ff9e1405a9 100644 --- a/devito/types/parallel.py +++ b/devito/types/parallel.py @@ -14,6 +14,7 @@ from devito.exceptions import InvalidArgument from devito.parameters import configuration +from devito.symbolics import search from devito.tools import as_list, as_tuple, is_integer from devito.types.array import Array, ArrayObject from devito.types.basic import Scalar, Symbol @@ -75,8 +76,8 @@ def _arg_defaults(self, **kwargs): raise InvalidArgument("Cannot determine `npthreads`") from None # If a symbolic object, it must be resolved - if isinstance(npthreads, NPThreads): - npthreads = kwargs.get(npthreads.name, npthreads.size) + for th in search(npthreads, NPThreads): + npthreads = npthreads._subs(th, kwargs.get(th.name, th.size)) return {self.name: max(base_nthreads - npthreads, 1)}