Skip to content

Commit 19bfb04

Browse files
committed
compiler: fix buffering with multiple conditions
1 parent 87e0263 commit 19bfb04

File tree

3 files changed

+26
-14
lines changed

3 files changed

+26
-14
lines changed

devito/ir/support/guards.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -300,6 +300,7 @@ def xandg(self, d, guard):
300300

301301
def pairwise_or(self, d, *guards):
302302
m = dict(self)
303+
guards = list(guards)
303304

304305
if d in m:
305306
guards.append(m[d])

devito/passes/clusters/buffering.py

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -182,10 +182,10 @@ def callback(self, clusters, prefix):
182182
for _, v in descriptors.items():
183183
if not v.is_readonly:
184184
continue
185-
if c is not v.firstread:
185+
if c not in v.firstread:
186186
continue
187187

188-
idxf = v.last_idx
188+
idxf = v.last_idx[c]
189189
idxb = mds[(v.xd, idxf)]
190190

191191
lhs = v.b.indexify()._subs(v.xd, idxb)
@@ -225,10 +225,10 @@ def callback(self, clusters, prefix):
225225
for _, v in descriptors.items():
226226
if v.is_readonly:
227227
continue
228-
if c is not v.lastwrite:
228+
if c not in v.lastwrite:
229229
continue
230230

231-
idxf = v.last_idx
231+
idxf = v.last_idx[c]
232232
idxb = mds[(v.xd, idxf)]
233233

234234
lhs = v.f.indexify()._subs(v.dim, idxf)
@@ -508,17 +508,19 @@ def subdims_mapper(self):
508508

509509
@cached_property
510510
def firstread(self):
511+
first_c = {}
511512
for c in self.clusters:
512-
if c.scope.reads.get(self.f):
513-
return c
514-
return None
513+
if c.scope.reads.get(self.f) and c.guards not in first_c:
514+
first_c[c.guards] = c
515+
return tuple(first_c.values())
515516

516517
@cached_property
517518
def lastwrite(self):
519+
last_c = {}
518520
for c in reversed(self.clusters):
519-
if c.scope.writes.get(self.f):
520-
return c
521-
return None
521+
if c.scope.writes.get(self.f) and c.guards not in last_c:
522+
last_c[c.guards] = c
523+
return tuple(last_c.values())
522524

523525
@property
524526
def is_read(self):
@@ -529,7 +531,7 @@ def is_read(self):
529531

530532
@property
531533
def is_write(self):
532-
return self.lastwrite is not None
534+
return bool(self.lastwrite)
533535

534536
@property
535537
def is_readonly(self):
@@ -604,8 +606,16 @@ def last_idx(self):
604606
* `time-1` in the case of `foo(u[time-1], u[time], u[time+1])`
605607
with a backwards-propagating `time` Dimension.
606608
"""
609+
last_idxs = {}
607610
func = vmax if self.is_forward_buffering else vmin
608-
return func(*[Vector(i) for i in self.indices])[0]
611+
for c in self.lastwrite:
612+
idx = func(*[Vector(i) for i in extract_indices(self.f, self.dim, [c])])[0]
613+
last_idxs[c] = idx
614+
for c in self.firstread:
615+
idx = func(*[Vector(i) for i in extract_indices(self.f, self.dim, [c])])[0]
616+
last_idxs[c] = idx
617+
618+
return last_idxs
609619

610620
@cached_property
611621
def first_idx(self):

devito/types/parallel.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from devito.exceptions import InvalidArgument
1616
from devito.parameters import configuration
1717
from devito.tools import as_list, as_tuple, is_integer
18+
from devito.symbolics import search
1819
from devito.types.array import Array, ArrayObject
1920
from devito.types.basic import Scalar, Symbol
2021
from devito.types.dimension import CustomDimension
@@ -75,8 +76,8 @@ def _arg_defaults(self, **kwargs):
7576
raise InvalidArgument("Cannot determine `npthreads`") from None
7677

7778
# If a symbolic object, it must be resolved
78-
if isinstance(npthreads, NPThreads):
79-
npthreads = kwargs.get(npthreads.name, npthreads.size)
79+
for th in search(npthreads, NPThreads):
80+
npthreads = npthreads._subs(th, kwargs.get(th.name, th.size))
8081

8182
return {self.name: max(base_nthreads - npthreads, 1)}
8283

0 commit comments

Comments
 (0)