Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 6 additions & 53 deletions code_to_optimize/unoptimized_neural_net.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import nn


class UnoptimizedNeuralNet(nn.Module):
Expand All @@ -20,57 +21,9 @@ def forward(self, x):
batch_size = x.shape[0]
x = x.view(batch_size, -1)

hidden = torch.zeros(batch_size, self.hidden_size, dtype=x.dtype, device=x.device)
for b in range(batch_size):
for i in range(self.hidden_size):
neuron_sum = torch.tensor(0.0, dtype=x.dtype, device=x.device)
for j in range(self.input_size):
neuron_sum = neuron_sum + x[b, j] * self.fc1_weight[i, j]
neuron_sum = neuron_sum + self.fc1_bias[i]
hidden[b, i] = neuron_sum

activated = torch.zeros_like(hidden)
for b in range(batch_size):
for i in range(self.hidden_size):
val = hidden[b, i]
if val > 0:
activated[b, i] = val
else:
activated[b, i] = 0.0

output = torch.zeros(batch_size, self.num_classes, dtype=x.dtype, device=x.device)
for b in range(batch_size):
for i in range(self.num_classes):
neuron_sum = torch.tensor(0.0, dtype=x.dtype, device=x.device)
temp_values = torch.zeros(self.hidden_size, dtype=x.dtype, device=x.device)
for j in range(self.hidden_size):
temp_values[j] = activated[b, j]

for j in range(self.hidden_size):
neuron_sum = neuron_sum + temp_values[j] * self.fc2_weight[i, j]

bias_value = self.fc2_bias[i]
neuron_sum = neuron_sum + bias_value

output[b, i] = neuron_sum

softmax_output = torch.zeros_like(output)
for b in range(batch_size):
max_val = output[b, 0].clone()
for i in range(1, self.num_classes):
if output[b, i] > max_val:
max_val = output[b, i].clone()

exp_values = torch.zeros(self.num_classes, dtype=x.dtype, device=x.device)
for i in range(self.num_classes):
exp_val = torch.exp(output[b, i] - max_val)
exp_values[i] = exp_val

sum_exp = torch.tensor(0.0, dtype=x.dtype, device=x.device)
for i in range(self.num_classes):
sum_exp = sum_exp + exp_values[i]

for i in range(self.num_classes):
softmax_output[b, i] = exp_values[i] / sum_exp
hidden = F.linear(x, self.fc1_weight, self.fc1_bias)
activated = torch.clamp(hidden, min=0.0)
output = F.linear(activated, self.fc2_weight, self.fc2_bias)
softmax_output = torch.softmax(output, dim=1)

return softmax_output.detach()
Loading