Description
Consider the simulation loop in the Network.run() function:
# Simulate network activity for `time` timesteps.
for t in range(timesteps):
->  for l in self.layers:
        # Update each layer of nodes.
        if isinstance(self.layers[l], AbstractInput):
            self.layers[l].step(inpts[l][t], self.dt)
        else:
            self.layers[l].step(inpts[l], self.dt)

        # Clamp neurons to spike.
        clamp = clamps.get(l, None)
        if clamp is not None:
            self.layers[l].s[clamp] = 1

    # Run synapse updates.
->  for c in self.connections:
        self.connections[c].update(
            reward=reward, mask=masks.get(c, None), learning=self.learning
        )

    # Get input to all layers.
    inpts.update(self.get_inputs())

    # Record state variables of interest.
    for m in self.monitors:
        self.monitors[m].record()

# Re-normalize connections.
->  for c in self.connections:
        self.connections[c].normalize()
Where I've marked a ->, there may be an opportunity to use torch.multiprocessing. Since updates at time t depend only on the network state at time t-1, all of the Nodes / Connections updates within a timestep can be performed by separate processes (threads?) at once. Letting k = the number of layers and m = the number of connections, and given enough CPU / GPU resources, the loops marked with -> would have time complexity O(1) instead of O(k) and O(m), respectively.
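As a rough sketch of the first marked loop: torch.multiprocessing mirrors the standard multiprocessing API, so the per-layer step() calls could be farmed out to a Pool. The helper name _step_layer is hypothetical, and this assumes each layer's state tensors have been placed in shared memory (e.g. via .share_memory_()) so that in-place updates made by the workers are visible to the parent process:

import torch.multiprocessing as mp

def _step_layer(args):
    # Hypothetical worker: advance a single layer by one timestep.
    layer, inpt, dt = args
    layer.step(inpt, dt)

# Replaces the inner `for l in self.layers:` loop inside run(); the
# clamping logic would still run afterwards (or be folded into the worker).
with mp.Pool(processes=len(self.layers)) as pool:
    pool.map(
        _step_layer,
        [
            (
                self.layers[l],
                inpts[l][t] if isinstance(self.layers[l], AbstractInput) else inpts[l],
                self.dt,
            )
            for l in self.layers
        ],
    )

Creating a fresh Pool on every timestep would swamp any speedup, which is the motivation for the persistent pools suggested below.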
I think it'd be good to keep two (?) multiprocessing.Pool objects around, one for Nodes objects and another for Connection objects. Instead of statements of the form:
for l in self.layers:
    self.layers[l].step(...)
We might rewrite this as something like:
self.nodes_pool.map(Nodes.step, self.layers)
Here, nodes_pool is defined as an attribute in the Network constructor. This last bit probably won't work straightaway; we'd need to figure out the right syntax (if it exists).
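One workable variant might be to create the pools once in the Network constructor and map a picklable, module-level helper over the layer objects rather than over the dictionary keys. The n_workers argument and pool sizes below are hypothetical:

import torch.multiprocessing as mp

class Network:
    def __init__(self, dt=1.0, n_workers=4):
        self.dt = dt
        self.layers = {}
        self.connections = {}
        # Persistent pools, created once so the process start-up cost
        # is not paid on every timestep. Sizes are illustrative.
        self.nodes_pool = mp.Pool(processes=n_workers)
        self.connections_pool = mp.Pool(processes=n_workers)

The second marked loop could go through the connections pool in the same way, reusing the keyword arguments from the original update() call:

def _update_connection(args):
    # Hypothetical worker: apply one connection's update rule.
    connection, reward, mask, learning = args
    connection.update(reward=reward, mask=mask, learning=learning)

# Inside run(), replacing the serial `for c in self.connections:` update loop.
self.connections_pool.map(
    _update_connection,
    [
        (self.connections[c], reward, masks.get(c, None), self.learning)
        for c in self.connections
    ],
)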
This same idea can also be applied in the Network's reset() and get_inputs() functions.
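For reset(), for example, the same fan-out could look like the following; _reset_layer is a hypothetical helper, and the per-layer _reset() call is an assumption about the Nodes interface. The same shared-memory caveat applies: state mutated in a worker is only visible to the parent if the underlying tensors are shared.

def _reset_layer(layer):
    # Hypothetical worker: clear a single layer's state variables.
    layer._reset()  # assumes Nodes exposes a per-layer reset method

def reset(self):
    # Reset all layers in parallel rather than looping over them serially.
    self.nodes_pool.map(_reset_layer, list(self.layers.values()))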