Skip to content

Commit 8b1cb4e

Browse files
committed
Sparse batch_eth_mnist
1 parent 471d455 commit 8b1cb4e

File tree

4 files changed

+88
-25
lines changed

4 files changed

+88
-25
lines changed

bindsnet/learning/MCC_learning.py

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,10 @@ def update(self, **kwargs) -> None:
102102
if ((self.min is not None) or (self.max is not None)) and not isinstance(
103103
self, NoOp
104104
):
105-
self.feature_value.clamp_(self.min, self.max)
105+
if self.feature_value.is_sparse:
106+
self.feature_value = self.feature_value.to_dense().clamp_(self.min, self.max).to_sparse()
107+
else:
108+
self.feature_value.clamp_(self.min, self.max)
106109

107110
@abstractmethod
108111
def reset_state_variables(self) -> None:
@@ -247,10 +250,16 @@ def _connection_update(self, **kwargs) -> None:
247250
torch.mean(self.average_buffer_pre, dim=0) * self.connection.dt
248251
)
249252
else:
250-
self.feature_value -= (
251-
self.reduction(torch.bmm(source_s, target_x), dim=0)
252-
* self.connection.dt
253-
)
253+
if self.feature_value.is_sparse:
254+
self.feature_value -= (
255+
torch.bmm(source_s, target_x)
256+
* self.connection.dt
257+
).to_sparse()
258+
else:
259+
self.feature_value -= (
260+
self.reduction(torch.bmm(source_s, target_x), dim=0)
261+
* self.connection.dt
262+
)
254263
del source_s, target_x
255264

256265
# Post-synaptic update.
@@ -278,10 +287,16 @@ def _connection_update(self, **kwargs) -> None:
278287
torch.mean(self.average_buffer_post, dim=0) * self.connection.dt
279288
)
280289
else:
281-
self.feature_value += (
282-
self.reduction(torch.bmm(source_x, target_s), dim=0)
283-
* self.connection.dt
284-
)
290+
if self.feature_value.is_sparse:
291+
self.feature_value += (
292+
torch.bmm(source_x, target_s)
293+
* self.connection.dt
294+
).to_sparse()
295+
else:
296+
self.feature_value += (
297+
self.reduction(torch.bmm(source_x, target_s), dim=0)
298+
* self.connection.dt
299+
)
285300
del source_x, target_s
286301

287302
super().update()

bindsnet/models/models.py

Lines changed: 53 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,14 @@
44
import torch
55
from scipy.spatial.distance import euclidean
66
from torch.nn.modules.utils import _pair
7+
from torch import device
78

89
from bindsnet.learning import PostPre
10+
from bindsnet.learning.MCC_learning import PostPre as MMCPostPre
911
from bindsnet.network import Network
1012
from bindsnet.network.nodes import DiehlAndCookNodes, Input, LIFNodes
11-
from bindsnet.network.topology import Connection, LocalConnection
13+
from bindsnet.network.topology import Connection, LocalConnection, MulticompartmentConnection
14+
from bindsnet.network.topology_features import Weight
1215

1316

1417
class TwoLayerNetwork(Network):
@@ -94,6 +97,9 @@ class DiehlAndCook2015(Network):
9497
def __init__(
9598
self,
9699
n_inpt: int,
100+
device: device,
101+
batch_size: int,
102+
sparse: bool = False,
97103
n_neurons: int = 100,
98104
exc: float = 22.5,
99105
inh: float = 17.5,
@@ -169,28 +175,61 @@ def __init__(
169175
)
170176

171177
# Connections
172-
w = 0.3 * torch.rand(self.n_inpt, self.n_neurons)
173-
input_exc_conn = Connection(
178+
if sparse:
179+
w = 0.3 * torch.rand(batch_size, self.n_inpt, self.n_neurons)
180+
else:
181+
w = 0.3 * torch.rand(self.n_inpt, self.n_neurons)
182+
input_exc_conn = MulticompartmentConnection(
174183
source=input_layer,
175184
target=exc_layer,
176-
w=w,
177-
update_rule=PostPre,
178-
nu=nu,
179-
reduction=reduction,
180-
wmin=wmin,
181-
wmax=wmax,
182-
norm=norm,
185+
device=device,
186+
pipeline=[
187+
Weight(
188+
'weight',
189+
w,
190+
range=[wmin, wmax],
191+
norm=norm,
192+
reduction=reduction,
193+
nu=nu,
194+
learning_rule=MMCPostPre,
195+
sparse=sparse
196+
)
197+
]
183198
)
184199
w = self.exc * torch.diag(torch.ones(self.n_neurons))
185-
exc_inh_conn = Connection(
186-
source=exc_layer, target=inh_layer, w=w, wmin=0, wmax=self.exc
200+
if sparse:
201+
w = w.unsqueeze(0).expand(batch_size, -1, -1)
202+
exc_inh_conn = MulticompartmentConnection(
203+
source=exc_layer,
204+
target=inh_layer,
205+
device=device,
206+
pipeline=[
207+
Weight(
208+
'weight',
209+
w,
210+
range=[0, self.exc],
211+
sparse=sparse
212+
)
213+
]
187214
)
188215
w = -self.inh * (
189216
torch.ones(self.n_neurons, self.n_neurons)
190217
- torch.diag(torch.ones(self.n_neurons))
191218
)
192-
inh_exc_conn = Connection(
193-
source=inh_layer, target=exc_layer, w=w, wmin=-self.inh, wmax=0
219+
if sparse:
220+
w = w.unsqueeze(0).expand(batch_size, -1, -1)
221+
inh_exc_conn = MulticompartmentConnection(
222+
source=inh_layer,
223+
target=exc_layer,
224+
device=device,
225+
pipeline=[
226+
Weight(
227+
'weight',
228+
w,
229+
range=[-self.inh, 0],
230+
sparse=sparse
231+
)
232+
]
194233
)
195234

196235
# Add to network

bindsnet/network/topology.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -450,7 +450,11 @@ def compute(self, s: torch.Tensor) -> torch.Tensor:
450450
if conn_spikes.is_sparse:
451451
conn_spikes = conn_spikes.to_dense()
452452
conn_spikes = conn_spikes.view(s.size(0), self.source.n, self.target.n)
453-
out_signal = conn_spikes.sum(1)
453+
454+
if conn_spikes.is_sparse:
455+
out_signal = conn_spikes.to_dense().sum(1)
456+
else:
457+
out_signal = conn_spikes.sum(1)
454458

455459
if self.traces:
456460
self.activity = out_signal

examples/mnist/batch_eth_mnist.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,8 @@
4444
parser.add_argument("--test", dest="train", action="store_false")
4545
parser.add_argument("--plot", dest="plot", action="store_true")
4646
parser.add_argument("--gpu", dest="gpu", action="store_true")
47-
parser.set_defaults(plot=True, gpu=True)
47+
parser.add_argument("--sparse", dest="sparse", action="store_true")
48+
parser.set_defaults(gpu=True)
4849

4950
args = parser.parse_args()
5051

@@ -66,6 +67,7 @@
6667
train = args.train
6768
plot = args.plot
6869
gpu = args.gpu
70+
sparse = args.sparse
6971

7072
update_steps = int(n_train / batch_size / n_updates)
7173
update_interval = update_steps * batch_size
@@ -93,6 +95,9 @@
9395

9496
# Build network.
9597
network = DiehlAndCook2015(
98+
device=device,
99+
sparse=sparse,
100+
batch_size=batch_size,
96101
n_inpt=784,
97102
n_neurons=n_neurons,
98103
exc=exc,

0 commit comments

Comments
 (0)