diff --git a/code_to_optimize/unoptimized_neural_net.py b/code_to_optimize/unoptimized_neural_net.py
index acd7f0a26..b5ea732b3 100644
--- a/code_to_optimize/unoptimized_neural_net.py
+++ b/code_to_optimize/unoptimized_neural_net.py
@@ -1,5 +1,6 @@
 import torch
-import torch.nn as nn
+import torch.nn.functional as F
+from torch import nn
 
 
 class UnoptimizedNeuralNet(nn.Module):
@@ -20,57 +21,9 @@ def forward(self, x):
         batch_size = x.shape[0]
         x = x.view(batch_size, -1)
 
-        hidden = torch.zeros(batch_size, self.hidden_size, dtype=x.dtype, device=x.device)
-        for b in range(batch_size):
-            for i in range(self.hidden_size):
-                neuron_sum = torch.tensor(0.0, dtype=x.dtype, device=x.device)
-                for j in range(self.input_size):
-                    neuron_sum = neuron_sum + x[b, j] * self.fc1_weight[i, j]
-                neuron_sum = neuron_sum + self.fc1_bias[i]
-                hidden[b, i] = neuron_sum
-
-        activated = torch.zeros_like(hidden)
-        for b in range(batch_size):
-            for i in range(self.hidden_size):
-                val = hidden[b, i]
-                if val > 0:
-                    activated[b, i] = val
-                else:
-                    activated[b, i] = 0.0
-
-        output = torch.zeros(batch_size, self.num_classes, dtype=x.dtype, device=x.device)
-        for b in range(batch_size):
-            for i in range(self.num_classes):
-                neuron_sum = torch.tensor(0.0, dtype=x.dtype, device=x.device)
-                temp_values = torch.zeros(self.hidden_size, dtype=x.dtype, device=x.device)
-                for j in range(self.hidden_size):
-                    temp_values[j] = activated[b, j]
-
-                for j in range(self.hidden_size):
-                    neuron_sum = neuron_sum + temp_values[j] * self.fc2_weight[i, j]
-
-                bias_value = self.fc2_bias[i]
-                neuron_sum = neuron_sum + bias_value
-
-                output[b, i] = neuron_sum
-
-        softmax_output = torch.zeros_like(output)
-        for b in range(batch_size):
-            max_val = output[b, 0].clone()
-            for i in range(1, self.num_classes):
-                if output[b, i] > max_val:
-                    max_val = output[b, i].clone()
-
-            exp_values = torch.zeros(self.num_classes, dtype=x.dtype, device=x.device)
-            for i in range(self.num_classes):
-                exp_val = torch.exp(output[b, i] - max_val)
-                exp_values[i] = exp_val
-
-            sum_exp = torch.tensor(0.0, dtype=x.dtype, device=x.device)
-            for i in range(self.num_classes):
-                sum_exp = sum_exp + exp_values[i]
-
-            for i in range(self.num_classes):
-                softmax_output[b, i] = exp_values[i] / sum_exp
+        hidden = F.linear(x, self.fc1_weight, self.fc1_bias)
+        activated = torch.clamp(hidden, min=0.0)
+        output = F.linear(activated, self.fc2_weight, self.fc2_bias)
+        softmax_output = torch.softmax(output, dim=1)
 
         return softmax_output.detach()
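
Reviewer note: the rewrite collapses three levels of Python loops into four vectorized calls: F.linear(x, W, b) computes x @ W.T + b, torch.clamp(hidden, min=0.0) is equivalent to a ReLU, and torch.softmax(output, dim=1) replaces the hand-rolled max-subtract-exp-normalize loop. Below is a minimal standalone sanity check that the vectorized path matches a plain-matmul reference. The sizes (batch_size, input_size, hidden_size, num_classes) are hypothetical placeholders, since the __init__ that defines the real values is outside this hunk; the weights are built directly rather than through the class for the same reason.

import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical sizes; the real values live in __init__, which this hunk does not touch.
batch_size, input_size, hidden_size, num_classes = 4, 16, 32, 10

x = torch.randn(batch_size, input_size)
fc1_weight = torch.randn(hidden_size, input_size)
fc1_bias = torch.randn(hidden_size)
fc2_weight = torch.randn(num_classes, hidden_size)
fc2_bias = torch.randn(num_classes)

# Vectorized path, mirroring the new forward().
hidden = F.linear(x, fc1_weight, fc1_bias)     # x @ fc1_weight.T + fc1_bias
activated = torch.clamp(hidden, min=0.0)       # same result as F.relu(hidden)
output = F.linear(activated, fc2_weight, fc2_bias)
fast = torch.softmax(output, dim=1)

# Reference built from explicit matmuls and the numerically stable softmax
# (subtract the row max before exponentiating), standing in for the old loops.
ref_hidden = x @ fc1_weight.T + fc1_bias
ref_out = ref_hidden.clamp(min=0.0) @ fc2_weight.T + fc2_bias
ref_exp = torch.exp(ref_out - ref_out.max(dim=1, keepdim=True).values)
ref = ref_exp / ref_exp.sum(dim=1, keepdim=True)

assert torch.allclose(fast, ref, atol=1e-6)
print("vectorized forward matches the reference")

One behavior to be aware of when reviewing: the returned tensor is still detached (return softmax_output.detach() is kept as-is), so gradients do not flow through this forward pass before or after the change.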