⚡️ Speed up function complex_activation by 77%
#1172
Closed
📄 77% (0.77x) speedup for `complex_activation` in `code_to_optimize/complex_activation.py`
⏱️ Runtime: 2.77 milliseconds → 1.57 milliseconds (best of 134 runs)
📝 Explanation and details
The optimized code achieves a 77% speedup (from 2.77 ms to 1.57 ms) by applying `torch.compile` to the `complex_activation` function. This decorator enables kernel fusion, which is critical for this workload.

Why this optimization works:
The original implementation performs 6 sequential element-wise operations on tensors:

- `torch.sin(x)`
- `torch.cos(x)`
- `torch.exp(-x.abs())`
- `(1 + x.pow(2))`
- `torch.tanh(x) * torch.sigmoid(x)`
- `0.5 * x.pow(3)`

Without compilation, each operation launches a separate CUDA kernel (or CPU loop), incurring per-kernel launch overhead and a full memory round trip for each intermediate result. The line profiler shows these operations dominate the runtime, with `torch.exp(-x.abs())` and the division taking 40.4% and 35.2% respectively.
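The exact formula inside `complex_activation.py` is not shown in this PR, but a minimal sketch of a chain like the one described above could look as follows. The way the terms are combined (and the name `complex_activation_eager`) is an assumption for illustration only:

```python
import torch

def complex_activation_eager(x: torch.Tensor) -> torch.Tensor:
    # Hypothetical combination of the listed terms -- the real
    # complex_activation may combine them differently.
    # In eager mode each line below runs as a separate element-wise kernel,
    # so every intermediate tensor is written to and read back from memory.
    a = torch.sin(x) * torch.cos(x)
    b = torch.exp(-x.abs()) / (1 + x.pow(2))  # the two most expensive steps per the profiler
    c = torch.tanh(x) * torch.sigmoid(x)
    return a + b + c + 0.5 * x.pow(3)
```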
What `torch.compile` does:
By decorating the function with `@torch.compile`, PyTorch's compiler traces the function and fuses the chain of element-wise operations into a single kernel, removing the per-operation launch overhead and the intermediate memory traffic. The optimized line profiler shows the function now executes through a single `compile_wrapper` call, with the actual computation (`return fn(*args, **kwargs)`) taking the bulk of the time as one fused operation rather than 6 separate ones.

Fallback safety:
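A rough way to reproduce the effect is to time an eager and a compiled variant side by side. This is only a sketch: the `chain` function reuses the hypothetical formula above, the numbers will vary by hardware and PyTorch version, and the first compiled call pays a one-time compilation cost, so it is excluded from timing. For GPU tensors, `torch.cuda.synchronize()` would be needed around the timed loop.

```python
import time
import torch

def chain(x):
    # Same hypothetical element-wise chain as in the sketch above.
    return (torch.sin(x) * torch.cos(x)
            + torch.exp(-x.abs()) / (1 + x.pow(2))
            + torch.tanh(x) * torch.sigmoid(x)
            + 0.5 * x.pow(3))

compiled_chain = torch.compile(chain)  # requires PyTorch >= 2.0

x = torch.randn(1_000_000)
compiled_chain(x)  # warm-up call: triggers compilation once

def bench(fn, n=100):
    start = time.perf_counter()
    for _ in range(n):
        fn(x)
    return (time.perf_counter() - start) / n

print(f"eager:    {bench(chain) * 1e3:.2f} ms")
print(f"compiled: {bench(compiled_chain) * 1e3:.2f} ms")
```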
The code includes a compatibility check: if `torch.compile` is unavailable (PyTorch < 2.0), it falls back to an identity decorator that preserves the original behavior. This ensures backward compatibility without breaking existing deployments.
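A minimal sketch of such a decorator-with-fallback pattern is shown below; the exact guard used in the PR may differ, and the function body again stands in for the real chain of operations:

```python
import torch

# Use torch.compile when available (PyTorch >= 2.0); otherwise fall back
# to an identity decorator so the function keeps its original eager behavior.
if hasattr(torch, "compile"):
    _maybe_compile = torch.compile
else:
    def _maybe_compile(fn):
        return fn

@_maybe_compile
def complex_activation(x: torch.Tensor) -> torch.Tensor:
    # Hypothetical body -- placeholder for the real chain of element-wise ops.
    return (torch.sin(x) * torch.cos(x)
            + torch.exp(-x.abs()) / (1 + x.pow(2))
            + torch.tanh(x) * torch.sigmoid(x)
            + 0.5 * x.pow(3))
```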
Impact:
This optimization is particularly effective for workloads dominated by long chains of element-wise operations, where kernel-launch overhead and intermediate memory traffic, rather than arithmetic, account for most of the runtime, and for functions called repeatedly enough that the one-time compilation cost is amortized.
✅ Correctness verification report:
⚙️ Existing Unit Tests
- `test_complex_activation.py::TestComplexActivation.test_deterministic`
- `test_complex_activation.py::TestComplexActivation.test_gradient_flow`
- `test_complex_activation.py::TestComplexActivation.test_negative_input`
- `test_complex_activation.py::TestComplexActivation.test_output_bounded`
- `test_complex_activation.py::TestComplexActivation.test_output_device`
- `test_complex_activation.py::TestComplexActivation.test_output_dtype`
- `test_complex_activation.py::TestComplexActivation.test_output_is_finite`
- `test_complex_activation.py::TestComplexActivation.test_output_shape`
- `test_complex_activation.py::TestComplexActivation.test_positive_input`
- `test_complex_activation.py::TestComplexActivation.test_zero_input`
To edit these changes, run `git checkout codeflash/optimize-complex_activation-mkvhfuuu` and push.