Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions devito/ir/support/guards.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,7 @@ def xandg(self, d, guard):

def pairwise_or(self, d, *guards):
m = dict(self)
guards = list(guards)

if d in m:
guards.append(m[d])
Expand Down
32 changes: 20 additions & 12 deletions devito/passes/clusters/buffering.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,10 +182,10 @@ def callback(self, clusters, prefix):
for _, v in descriptors.items():
if not v.is_readonly:
continue
if c is not v.firstread:
if c not in v.firstread:
continue

idxf = v.last_idx
idxf = v.last_idx[c]
idxb = mds[(v.xd, idxf)]

lhs = v.b.indexify()._subs(v.xd, idxb)
Expand Down Expand Up @@ -225,10 +225,10 @@ def callback(self, clusters, prefix):
for _, v in descriptors.items():
if v.is_readonly:
continue
if c is not v.lastwrite:
if c not in v.lastwrite:
continue

idxf = v.last_idx
idxf = v.last_idx[c]
idxb = mds[(v.xd, idxf)]

lhs = v.f.indexify()._subs(v.dim, idxf)
Expand Down Expand Up @@ -508,17 +508,19 @@ def subdims_mapper(self):

@cached_property
def firstread(self):
mapper = {}
for c in self.clusters:
if c.scope.reads.get(self.f):
return c
return None
if c.scope.reads.get(self.f) and c.guards not in mapper:
mapper[c.guards] = c
return tuple(mapper.values())

@cached_property
def lastwrite(self):
mapper = {}
for c in reversed(self.clusters):
if c.scope.writes.get(self.f):
return c
return None
if c.scope.writes.get(self.f) and c.guards not in mapper:
mapper[c.guards] = c
return tuple(mapper.values())

@property
def is_read(self):
Expand All @@ -529,7 +531,7 @@ def is_read(self):

@property
def is_write(self):
return self.lastwrite is not None
return bool(self.lastwrite)

@property
def is_readonly(self):
Expand Down Expand Up @@ -604,8 +606,14 @@ def last_idx(self):
* `time-1` in the case of `foo(u[time-1], u[time], u[time+1])`
with a backwards-propagating `time` Dimension.
"""
mapper = {}
func = vmax if self.is_forward_buffering else vmin
return func(*[Vector(i) for i in self.indices])[0]
for c in self.lastwrite + self.firstread:
indices = extract_indices(self.f, self.dim, [c])
idx = func(*[Vector(i) for i in indices])[0]
mapper[c] = idx

return mapper
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

technically if you return a frozendict it'd be better but no big deal as long as we don't do crazy things at the caller site


@cached_property
def first_idx(self):
Expand Down
5 changes: 3 additions & 2 deletions devito/types/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from devito.exceptions import InvalidArgument
from devito.parameters import configuration
from devito.symbolics import search
from devito.tools import as_list, as_tuple, is_integer
from devito.types.array import Array, ArrayObject
from devito.types.basic import Scalar, Symbol
Expand Down Expand Up @@ -75,8 +76,8 @@ def _arg_defaults(self, **kwargs):
raise InvalidArgument("Cannot determine `npthreads`") from None

# If a symbolic object, it must be resolved
if isinstance(npthreads, NPThreads):
npthreads = kwargs.get(npthreads.name, npthreads.size)
for th in search(npthreads, NPThreads):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why search?

npthreads = npthreads._subs(th, kwargs.get(th.name, th.size))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why subs?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can be an expression, e.g nthreads0 + nthreads1 +.... And need to replace all

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can be an expression, e.g nthreads0 + nthreads1 +.... And need to replace all

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can be an expression, e.g nthreads0 + nthreads1 +.... And need to replace all

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

mh, what would the reproducer be?


return {self.name: max(base_nthreads - npthreads, 1)}

Expand Down
Loading