Possible use of torch.multiprocessing #139

@djsaunde

Consider the simulation loop in the Network.run() function:

```python
# Simulate network activity for `time` timesteps.
for t in range(timesteps):
    for l in self.layers:  # ->
        # Update each layer of nodes.
        if isinstance(self.layers[l], AbstractInput):
            self.layers[l].step(inpts[l][t], self.dt)
        else:
            self.layers[l].step(inpts[l], self.dt)

        # Clamp neurons to spike.
        clamp = clamps.get(l, None)
        if clamp is not None:
            self.layers[l].s[clamp] = 1

    # Run synapse updates.
    for c in self.connections:  # ->
        self.connections[c].update(
            reward=reward, mask=masks.get(c, None), learning=self.learning
        )

    # Get input to all layers.
    inpts.update(self.get_inputs())

    # Record state variables of interest.
    for m in self.monitors:
        self.monitors[m].record()

# Re-normalize connections.
for c in self.connections:  # ->
    self.connections[c].normalize()
```

Where I've marked a `->`, there might be an opportunity to use `torch.multiprocessing`. Since updates at time t depend only on the network state at time t-1, the updates to all `Nodes` objects (and, separately, to all `Connection` objects) within a timestep are independent of one another, so each could run in its own process (or thread?) at the same time. Letting k be the number of layers and m the number of connections, given enough CPU / GPU resources, the loops marked with `->` would take O(1) wall-clock time instead of O(k) and O(m), respectively (ignoring the overhead of dispatching work to the workers).
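As a rough cost model for this claim (an idealization, not a measurement): let $c_l$ be the time to step layer $l$, $d_j$ the time to update connection $j$, and $o$ the overhead of dispatching tasks to workers and waiting for all of them to finish. With at least $\max(k, m)$ workers, each parallel phase lasts as long as its slowest task:

$$
T_{\text{seq}} = \sum_{l=1}^{k} c_l + \sum_{j=1}^{m} d_j,
\qquad
T_{\text{par}} \approx \max_{l} c_l + \max_{j} d_j + o.
$$

If the per-task costs are bounded by a constant, $T_{\text{seq}}$ grows as $O(k + m)$ while $T_{\text{par}}$ stays $O(1)$ apart from $o$, which itself tends to grow with the number of tasks submitted.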

I think it'd be good to keep two (?) `multiprocessing.Pool` objects around, one for `Nodes` objects and another for `Connection` objects. Instead of statements of the form:

```python
for l in self.layers:
    self.layers[l].step(...)
```

We might rewrite this as something like:

```python
self.nodes_pool.map(Nodes.step, self.layers)
```

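Here, `nodes_pool` would be an attribute created in the `Network` constructor. This last bit probably won't work straightaway: iterating over `self.layers` yields layer names rather than `Nodes` objects, and `step()` also needs each layer's input and `dt`, so we'd need to figure out the right syntax (if it exists).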
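A minimal sketch of the pool-as-attribute idea, assuming a thread pool (`concurrent.futures.ThreadPoolExecutor`) in place of `torch.multiprocessing`. `ToyNodes` and `ToyNetwork` are hypothetical stand-ins for BindsNET's `Nodes` and `Network`, not their real APIs. Threads sidestep a problem with process pools: each worker process receives a pickled copy of the `Nodes` object, so attributes it reassigns (such as a freshly computed voltage tensor) never reach the parent. `torch.multiprocessing` can share tensor storage across processes, but as far as we understand only in-place writes to already-shared tensors would be visible. Mapping over `self.layers.items()` gives each call both the layer and its name, so it can look up its own input.

```python
from concurrent.futures import ThreadPoolExecutor


class ToyNodes:
    """Hypothetical stand-in for a Nodes layer (leaky integrate-and-fire)."""

    def __init__(self, n, decay=0.9, thresh=1.0):
        self.v = [0.0] * n    # membrane voltages
        self.s = [False] * n  # spikes emitted on the last step
        self.decay = decay
        self.thresh = thresh

    def step(self, x, dt):
        # Leak, integrate the input, then spike and reset where v crosses thresh.
        for i in range(len(self.v)):
            self.v[i] = self.v[i] * self.decay + x[i] * dt
            self.s[i] = self.v[i] >= self.thresh
            if self.s[i]:
                self.v[i] = 0.0


class ToyNetwork:
    """Hypothetical network that keeps a worker pool as an attribute."""

    def __init__(self, layers, max_workers=None):
        self.layers = layers  # dict: layer name -> ToyNodes
        # Created once in the constructor and reused on every timestep.
        self.nodes_pool = ThreadPoolExecutor(max_workers=max_workers)

    def step_layers(self, inpts, dt):
        # Each task gets a (name, layer) pair so it can fetch its own input;
        # dt is shared by all layers.
        def step_one(item):
            name, layer = item
            layer.step(inpts[name], dt)

        # list() waits for every task and re-raises any worker exception.
        list(self.nodes_pool.map(step_one, self.layers.items()))

    def close(self):
        self.nodes_pool.shutdown()
```

Because the pure-Python toy holds the GIL, this shows the structure rather than a speedup; real gains with threads would depend on whether PyTorch releases the GIL inside the tensor operations that `step()` performs. Since the layer loop and the connection loop never run at the same time, a single executor could plausibly serve both phases instead of two separate pools.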

This same idea can also be applied in the Network's reset() and get_inputs() functions.
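For `get_inputs()`, a map-then-reduce split might work. This sketch assumes `get_inputs()` sums each connection's output into its target layer's input; `ToyConnection` and `get_inputs_parallel` are hypothetical, not BindsNET's real classes. Each connection's contribution reads only the previous timestep's spikes, so those are computed concurrently, while the per-target sum stays in the calling thread so that no two workers write to the same buffer. `reset()` would presumably be the simpler case: a plain map with no reduction step.

```python
from concurrent.futures import ThreadPoolExecutor


class ToyConnection:
    """Hypothetical dense connection from layer `source` to layer `target`."""

    def __init__(self, source, target, w):
        self.source = source  # name of the source layer
        self.target = target  # name of the target layer
        self.w = w  # w[i][j]: weight from source neuron i to target neuron j

    def compute(self, spikes):
        # Each target neuron receives the summed weights of source neurons that spiked.
        return [
            sum(self.w[i][j] for i, fired in enumerate(spikes) if fired)
            for j in range(len(self.w[0]))
        ]


def get_inputs_parallel(connections, spikes, sizes, pool=None):
    """Hypothetical parallel get_inputs(): returns {layer name: input list}."""
    own_pool = pool is None
    if own_pool:
        pool = ThreadPoolExecutor()
    try:
        conns = list(connections.values())
        # Map: every contribution reads only last step's spikes, so run them concurrently.
        outputs = list(pool.map(lambda c: c.compute(spikes[c.source]), conns))
    finally:
        if own_pool:
            pool.shutdown()

    # Reduce: sum contributions per target layer in this thread, so no buffer is shared.
    inpts = {name: [0.0] * n for name, n in sizes.items()}
    for c, out in zip(conns, outputs):
        for j, value in enumerate(out):
            inpts[c.target][j] += value
    return inpts
```

Passing an existing executor as `pool` would let the network reuse the same workers every timestep rather than creating a pool per call.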

Labels: enhancement (New feature or request)
