diff --git a/code_to_optimize/complex_activation.py b/code_to_optimize/complex_activation.py
index d9ed216d3..5fd521b40 100644
--- a/code_to_optimize/complex_activation.py
+++ b/code_to_optimize/complex_activation.py
@@ -1,4 +1,17 @@
 import torch
+
+
+def _identity_decorator(fn):
+    """No-op fallback decorator for torch builds lacking torch.compile."""
+    return fn
+
+
+# Fall back to the identity decorator when torch.compile is unavailable
+# (torch < 2.0), so `@_compile` never applies None — decorating with None
+# would raise TypeError at import time.
+_compile = getattr(torch, "compile", None) or _identity_decorator
+
+
+@_compile
 def complex_activation(x):
     """A custom activation with many small operations - compile makes a huge difference"""
     # Many sequential element-wise ops create kernel launch overhead
@@ -8,4 +21,4 @@ def complex_activation(x):
     x = x / (1 + x.pow(2))
     x = torch.tanh(x) * torch.sigmoid(x)
     x = x - 0.5 * x.pow(3)
-    return x
\ No newline at end of file
+    return x