diff --git a/code_to_optimize/complex_activation.py b/code_to_optimize/complex_activation.py
index d9ed216d3..9f68e4562 100644
--- a/code_to_optimize/complex_activation.py
+++ b/code_to_optimize/complex_activation.py
@@ -1,4 +1,7 @@
 import torch
+
+
+@torch.compile()
 def complex_activation(x):
     """A custom activation with many small operations - compile makes a huge difference"""
     # Many sequential element-wise ops create kernel launch overhead
@@ -8,4 +11,4 @@ def complex_activation(x):
     x = x / (1 + x.pow(2))
     x = torch.tanh(x) * torch.sigmoid(x)
     x = x - 0.5 * x.pow(3)
-    return x
\ No newline at end of file
+    return x