Skip to content

Commit b5a3acc

Browse files
committed
perf: branchless square root implementation
1 parent fcb8a78 commit b5a3acc

File tree

3 files changed

+137
-90
lines changed

3 files changed

+137
-90
lines changed

src/Common.sol

Lines changed: 30 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -321,53 +321,23 @@ function exp2(uint256 x) pure returns (uint256 result) {
321321
/// @return result The index of the most significant bit as a uint256.
322322
/// @custom:smtchecker abstract-function-nondet
323323
function msb(uint256 x) pure returns (uint256 result) {
324-
// 2^128
325324
assembly ("memory-safe") {
326-
let factor := shl(7, gt(x, 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF))
327-
x := shr(factor, x)
328-
result := or(result, factor)
329-
}
330-
// 2^64
331-
assembly ("memory-safe") {
332-
let factor := shl(6, gt(x, 0xFFFFFFFFFFFFFFFF))
333-
x := shr(factor, x)
334-
result := or(result, factor)
335-
}
336-
// 2^32
337-
assembly ("memory-safe") {
338-
let factor := shl(5, gt(x, 0xFFFFFFFF))
339-
x := shr(factor, x)
340-
result := or(result, factor)
341-
}
342-
// 2^16
343-
assembly ("memory-safe") {
344-
let factor := shl(4, gt(x, 0xFFFF))
345-
x := shr(factor, x)
346-
result := or(result, factor)
347-
}
348-
// 2^8
349-
assembly ("memory-safe") {
350-
let factor := shl(3, gt(x, 0xFF))
351-
x := shr(factor, x)
352-
result := or(result, factor)
353-
}
354-
// 2^4
355-
assembly ("memory-safe") {
356-
let factor := shl(2, gt(x, 0xF))
357-
x := shr(factor, x)
358-
result := or(result, factor)
359-
}
360-
// 2^2
361-
assembly ("memory-safe") {
362-
let factor := shl(1, gt(x, 0x3))
363-
x := shr(factor, x)
364-
result := or(result, factor)
365-
}
366-
// 2^1
367-
// No need to shift x any more.
368-
assembly ("memory-safe") {
369-
let factor := gt(x, 0x1)
370-
result := or(result, factor)
325+
// 2^128
326+
result := shl(7, lt(0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF, x))
327+
// 2^64
328+
result := or(result, shl(6, lt(0xFFFFFFFFFFFFFFFF, shr(result, x))))
329+
// 2^32
330+
result := or(result, shl(5, lt(0xFFFFFFFF, shr(result, x))))
331+
// 2^16
332+
result := or(result, shl(4, lt(0xFFFF, shr(result, x))))
333+
// 2^8
334+
result := or(result, shl(3, lt(0xFF, shr(result, x))))
335+
// 2^4
336+
result := or(result, shl(2, lt(0xF, shr(result, x))))
337+
// 2^2
338+
result := or(result, shl(1, lt(0x3, shr(result, x))))
339+
// 2^1
340+
result := or(result, lt(0x1, shr(result, x)))
371341
}
372342
}
373343

@@ -596,10 +566,6 @@ function mulDivSigned(int256 x, int256 y, int256 denominator) pure returns (int2
596566
/// @return result The result as a uint256.
597567
/// @custom:smtchecker abstract-function-nondet
598568
function sqrt(uint256 x) pure returns (uint256 result) {
599-
if (x == 0) {
600-
return 0;
601-
}
602-
603569
// For our first guess, we calculate the biggest power of 2 which is smaller than or equal to the square root of x.
604570
//
605571
// We know that the "msb" (most significant bit) of x is a power of 2 such that we have:
@@ -623,53 +589,27 @@ function sqrt(uint256 x) pure returns (uint256 result) {
623589
// $$
624590
//
625591
// Consequently, $2^{log_2(x) /2}$ is a good first approximation of sqrt(x) with at least one correct bit.
626-
uint256 xAux = uint256(x);
627-
result = 1;
628-
if (xAux >= 2 ** 128) {
629-
xAux >>= 128;
630-
result <<= 64;
631-
}
632-
if (xAux >= 2 ** 64) {
633-
xAux >>= 64;
634-
result <<= 32;
635-
}
636-
if (xAux >= 2 ** 32) {
637-
xAux >>= 32;
638-
result <<= 16;
639-
}
640-
if (xAux >= 2 ** 16) {
641-
xAux >>= 16;
642-
result <<= 8;
643-
}
644-
if (xAux >= 2 ** 8) {
645-
xAux >>= 8;
646-
result <<= 4;
647-
}
648-
if (xAux >= 2 ** 4) {
649-
xAux >>= 4;
650-
result <<= 2;
651-
}
652-
if (xAux >= 2 ** 2) {
653-
result <<= 1;
592+
unchecked {
593+
// ideally, we should use arithmetic operators, but solc is not smart enough to optimize `2**(msb(x)/2)`
594+
/// forge-lint: disable-next-line(incorrect-shift)
595+
result = 1 << (msb(x) >> 1);
654596
}
655597

656598
// At this point, `result` is an estimation with at least one bit of precision. We know the true value has at
657599
// most 128 bits, since it is the square root of a uint256. Newton's method converges quadratically (precision
658600
// doubles at every iteration). We thus need at most 7 iterations to turn our partial result with one bit of
659601
// precision into the expected uint128 result.
660-
unchecked {
661-
result = (result + x / result) >> 1;
662-
result = (result + x / result) >> 1;
663-
result = (result + x / result) >> 1;
664-
result = (result + x / result) >> 1;
665-
result = (result + x / result) >> 1;
666-
result = (result + x / result) >> 1;
667-
result = (result + x / result) >> 1;
602+
assembly ("memory-safe") {
603+
// note: division by zero in EVM returns zero
604+
result := shr(1, add(result, div(x, result)))
605+
result := shr(1, add(result, div(x, result)))
606+
result := shr(1, add(result, div(x, result)))
607+
result := shr(1, add(result, div(x, result)))
608+
result := shr(1, add(result, div(x, result)))
609+
result := shr(1, add(result, div(x, result)))
610+
result := shr(1, add(result, div(x, result)))
668611

669612
// If x is not a perfect square, round the result toward zero.
670-
uint256 roundedResult = x / result;
671-
if (result >= roundedResult) {
672-
result = roundedResult;
673-
}
613+
result := sub(result, gt(result, div(x, result)))
674614
}
675615
}

test/fuzz/common/msb.t.sol

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,51 @@ import { msb } from "src/Common.sol";
55

66
import { Base_Test } from "../../Base.t.sol";
77

8+
/// @dev Previous implementation, for verifying regressions.
9+
///
10+
/// From https://github.com/PaulRBerg/prb-math/blob/v4.1.0/src/Common.sol#L297-L372
11+
function originalMsb(uint256 x) pure returns (uint256 result) {
12+
assembly ("memory-safe") {
13+
let factor := shl(7, gt(x, 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF))
14+
x := shr(factor, x)
15+
result := or(result, factor)
16+
}
17+
assembly ("memory-safe") {
18+
let factor := shl(6, gt(x, 0xFFFFFFFFFFFFFFFF))
19+
x := shr(factor, x)
20+
result := or(result, factor)
21+
}
22+
assembly ("memory-safe") {
23+
let factor := shl(5, gt(x, 0xFFFFFFFF))
24+
x := shr(factor, x)
25+
result := or(result, factor)
26+
}
27+
assembly ("memory-safe") {
28+
let factor := shl(4, gt(x, 0xFFFF))
29+
x := shr(factor, x)
30+
result := or(result, factor)
31+
}
32+
assembly ("memory-safe") {
33+
let factor := shl(3, gt(x, 0xFF))
34+
x := shr(factor, x)
35+
result := or(result, factor)
36+
}
37+
assembly ("memory-safe") {
38+
let factor := shl(2, gt(x, 0xF))
39+
x := shr(factor, x)
40+
result := or(result, factor)
41+
}
42+
assembly ("memory-safe") {
43+
let factor := shl(1, gt(x, 0x3))
44+
x := shr(factor, x)
45+
result := or(result, factor)
46+
}
47+
assembly ("memory-safe") {
48+
let factor := gt(x, 0x1)
49+
result := or(result, factor)
50+
}
51+
}
52+
853
/// @dev Collection of tests for the most significant bit function available in `Common.sol`.
954
contract Common_Sqrt_Test is Base_Test {
1055
function testFuzz_Msb_FitsUint8(uint256 x) external pure {
@@ -32,4 +77,8 @@ contract Common_Sqrt_Test is Base_Test {
3277
function testFuzz_Msb_Shifts2ToMoreThan(uint256 x) external pure whenShiftLeftDoesNotOverflow(x) {
3378
assertGt(2 << msb(x), x, "Common msb");
3479
}
80+
81+
function testFuzz_Msb_MatchesOriginalImplementation(uint256 x) external pure {
82+
assertEq(msb(x), originalMsb(x), "Common msb");
83+
}
3584
}

test/fuzz/common/sqrt.t.sol

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,60 @@ import { sqrt } from "src/Common.sol";
55

66
import { Base_Test } from "../../Base.t.sol";
77

8+
/// @dev Previous implementation, for verifying regressions.
9+
///
10+
/// From https://github.com/PaulRBerg/prb-math/blob/v4.1.0/src/Common.sol#L587-L675
11+
function originalSqrt(uint256 x) pure returns (uint256 result) {
12+
if (x == 0) {
13+
return 0;
14+
}
15+
16+
uint256 xAux = uint256(x);
17+
result = 1;
18+
if (xAux >= 2 ** 128) {
19+
xAux >>= 128;
20+
result <<= 64;
21+
}
22+
if (xAux >= 2 ** 64) {
23+
xAux >>= 64;
24+
result <<= 32;
25+
}
26+
if (xAux >= 2 ** 32) {
27+
xAux >>= 32;
28+
result <<= 16;
29+
}
30+
if (xAux >= 2 ** 16) {
31+
xAux >>= 16;
32+
result <<= 8;
33+
}
34+
if (xAux >= 2 ** 8) {
35+
xAux >>= 8;
36+
result <<= 4;
37+
}
38+
if (xAux >= 2 ** 4) {
39+
xAux >>= 4;
40+
result <<= 2;
41+
}
42+
if (xAux >= 2 ** 2) {
43+
result <<= 1;
44+
}
45+
46+
unchecked {
47+
result = (result + x / result) >> 1;
48+
result = (result + x / result) >> 1;
49+
result = (result + x / result) >> 1;
50+
result = (result + x / result) >> 1;
51+
result = (result + x / result) >> 1;
52+
result = (result + x / result) >> 1;
53+
result = (result + x / result) >> 1;
54+
55+
uint256 roundedResult = x / result;
56+
if (result >= roundedResult) {
57+
result = roundedResult;
58+
}
59+
}
60+
}
61+
862
/// @dev Collection of tests for the square root function available in `Common.sol`.
963
contract Common_Sqrt_Test is Base_Test {
1064
uint256 internal constant MAX_SQRT = type(uint128).max;
@@ -44,4 +98,8 @@ contract Common_Sqrt_Test is Base_Test {
4498
function testFuzz_Sqrt_OfPowerOfTwo(uint8 x) external pure whenEven(x) {
4599
vm.assertEq(sqrt(2 ** x), 2 ** (x / 2), "Common sqrt");
46100
}
101+
102+
function testFuzz_Sqrt_MatchesOriginalImplementation(uint256 x) external pure {
103+
assertEq(sqrt(x), originalSqrt(x), "Common sqrt");
104+
}
47105
}

0 commit comments

Comments
 (0)