Skip to content

Commit c1299e7

Browse files
authored
Merge pull request #804 from gangliao/check_avx
Add inline and bit manipulation in CpuId.h
2 parents 1adc6a2 + 081eb1c commit c1299e7

File tree

3 files changed

+101
-68
lines changed

3 files changed

+101
-68
lines changed

paddle/utils/CpuId.cpp

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ limitations under the License. */
1414

1515
#ifdef _WIN32
1616

17+
#include <intrin.h>
18+
1719
/// for MSVC
1820
#define CPUID(info, x) __cpuidex(info, x, 0)
1921

@@ -31,25 +33,27 @@ namespace paddle {
3133
SIMDFlags::SIMDFlags() {
3234
unsigned int cpuInfo[4];
3335
// CPUID: https://en.wikipedia.org/wiki/CPUID
36+
// clang-format off
3437
CPUID(cpuInfo, 0x00000001);
35-
simd_flags_ |= cpuInfo[3] & (1 << 25) ? SIMD_SSE : SIMD_NONE;
36-
simd_flags_ |= cpuInfo[3] & (1 << 26) ? SIMD_SSE2 : SIMD_NONE;
37-
simd_flags_ |= cpuInfo[2] & (1 << 0) ? SIMD_SSE3 : SIMD_NONE;
38-
simd_flags_ |= cpuInfo[2] & (1 << 9) ? SIMD_SSSE3 : SIMD_NONE;
38+
simd_flags_ |= cpuInfo[3] & (1 << 25) ? SIMD_SSE : SIMD_NONE;
39+
simd_flags_ |= cpuInfo[3] & (1 << 26) ? SIMD_SSE2 : SIMD_NONE;
40+
simd_flags_ |= cpuInfo[2] & (1 << 0) ? SIMD_SSE3 : SIMD_NONE;
41+
simd_flags_ |= cpuInfo[2] & (1 << 9) ? SIMD_SSSE3 : SIMD_NONE;
3942
simd_flags_ |= cpuInfo[2] & (1 << 19) ? SIMD_SSE41 : SIMD_NONE;
4043
simd_flags_ |= cpuInfo[2] & (1 << 20) ? SIMD_SSE42 : SIMD_NONE;
41-
simd_flags_ |= cpuInfo[2] & (1 << 12) ? SIMD_FMA3 : SIMD_NONE;
42-
simd_flags_ |= cpuInfo[2] & (1 << 28) ? SIMD_AVX : SIMD_NONE;
44+
simd_flags_ |= cpuInfo[2] & (1 << 12) ? SIMD_FMA3 : SIMD_NONE;
45+
simd_flags_ |= cpuInfo[2] & (1 << 28) ? SIMD_AVX : SIMD_NONE;
4346

4447
CPUID(cpuInfo, 0x00000007);
45-
simd_flags_ |= cpuInfo[1] & (1 << 5) ? SIMD_AVX2 : SIMD_NONE;
46-
simd_flags_ |= cpuInfo[1] & (1 << 16) ? SIMD_AVX512 : SIMD_NONE;
48+
simd_flags_ |= cpuInfo[1] & (1 << 5) ? SIMD_AVX2 : SIMD_NONE;
49+
simd_flags_ |= cpuInfo[1] & (1 << 16) ? SIMD_AVX512: SIMD_NONE;
4750

4851
CPUID(cpuInfo, 0x80000001);
49-
simd_flags_ |= cpuInfo[2] & (1 << 16) ? SIMD_FMA4 : SIMD_NONE;
52+
simd_flags_ |= cpuInfo[2] & (1 << 16) ? SIMD_FMA4 : SIMD_NONE;
53+
// clang-format on
5054
}
5155

52-
SIMDFlags* SIMDFlags::instance() {
56+
SIMDFlags const* SIMDFlags::instance() {
5357
static SIMDFlags instance;
5458
return &instance;
5559
}

paddle/utils/CpuId.h

Lines changed: 69 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -11,61 +11,90 @@ limitations under the License. */
1111

1212
#pragma once
1313

14-
#include <iostream>
1514
#include "DisableCopy.h"
1615

1716
namespace paddle {
1817

18+
// clang-format off
19+
enum simd_t {
20+
SIMD_NONE = 0, ///< None
21+
SIMD_SSE = 1 << 0, ///< SSE
22+
SIMD_SSE2 = 1 << 1, ///< SSE 2
23+
SIMD_SSE3 = 1 << 2, ///< SSE 3
24+
SIMD_SSSE3 = 1 << 3, ///< SSSE 3
25+
SIMD_SSE41 = 1 << 4, ///< SSE 4.1
26+
SIMD_SSE42 = 1 << 5, ///< SSE 4.2
27+
SIMD_FMA3 = 1 << 6, ///< FMA 3
28+
SIMD_FMA4 = 1 << 7, ///< FMA 4
29+
SIMD_AVX = 1 << 8, ///< AVX
30+
SIMD_AVX2 = 1 << 9, ///< AVX 2
31+
SIMD_AVX512 = 1 << 10, ///< AVX 512
32+
};
33+
// clang-format on
34+
1935
class SIMDFlags final {
2036
public:
2137
DISABLE_COPY(SIMDFlags);
2238

2339
SIMDFlags();
2440

25-
static SIMDFlags* instance();
41+
static SIMDFlags const* instance();
2642

27-
inline bool isSSE() const { return simd_flags_ & SIMD_SSE; }
28-
inline bool isSSE2() const { return simd_flags_ & SIMD_SSE2; }
29-
inline bool isSSE3() const { return simd_flags_ & SIMD_SSE3; }
30-
inline bool isSSSE3() const { return simd_flags_ & SIMD_SSSE3; }
31-
inline bool isSSE41() const { return simd_flags_ & SIMD_SSE41; }
32-
inline bool isSSE42() const { return simd_flags_ & SIMD_SSE42; }
33-
inline bool isFMA3() const { return simd_flags_ & SIMD_FMA3; }
34-
inline bool isFMA4() const { return simd_flags_ & SIMD_FMA4; }
35-
inline bool isAVX() const { return simd_flags_ & SIMD_AVX; }
36-
inline bool isAVX2() const { return simd_flags_ & SIMD_AVX2; }
37-
inline bool isAVX512() const { return simd_flags_ & SIMD_AVX512; }
43+
inline bool check(int flags) const {
44+
return !((simd_flags_ & flags) ^ flags);
45+
}
3846

3947
private:
40-
enum simd_t {
41-
SIMD_NONE = 0, ///< None
42-
SIMD_SSE = 1 << 0, ///< SSE
43-
SIMD_SSE2 = 1 << 1, ///< SSE 2
44-
SIMD_SSE3 = 1 << 2, ///< SSE 3
45-
SIMD_SSSE3 = 1 << 3, ///< SSSE 3
46-
SIMD_SSE41 = 1 << 4, ///< SSE 4.1
47-
SIMD_SSE42 = 1 << 5, ///< SSE 4.2
48-
SIMD_FMA3 = 1 << 6, ///< FMA 3
49-
SIMD_FMA4 = 1 << 7, ///< FMA 4
50-
SIMD_AVX = 1 << 8, ///< AVX
51-
SIMD_AVX2 = 1 << 9, ///< AVX 2
52-
SIMD_AVX512 = 1 << 10, ///< AVX 512
53-
};
54-
55-
/// simd flags
5648
int simd_flags_ = SIMD_NONE;
5749
};
5850

59-
#define HAS_SSE SIMDFlags::instance()->isSSE()
60-
#define HAS_SSE2 SIMDFlags::instance()->isSSE2()
61-
#define HAS_SSE3 SIMDFlags::instance()->isSSE3()
62-
#define HAS_SSSE3 SIMDFlags::instance()->isSSSE3()
63-
#define HAS_SSE41 SIMDFlags::instance()->isSSE41()
64-
#define HAS_SSE42 SIMDFlags::instance()->isSSE42()
65-
#define HAS_FMA3 SIMDFlags::instance()->isFMA3()
66-
#define HAS_FMA4 SIMDFlags::instance()->isFMA4()
67-
#define HAS_AVX SIMDFlags::instance()->isAVX()
68-
#define HAS_AVX2 SIMDFlags::instance()->isAVX2()
69-
#define HAS_AVX512 SIMDFlags::instance()->isAVX512()
51+
/**
52+
* @brief Check SIMD flags at runtime.
53+
*
54+
* For example.
55+
* @code{.cpp}
56+
*
57+
* if (HAS_SIMD(SIMD_AVX2 | SIMD_FMA4)) {
58+
* avx2_fma4_stub();
59+
* } else if (HAS_SIMD(SIMD_AVX)) {
60+
* avx_stub();
61+
* }
62+
*
63+
* @endcode
64+
*/
65+
#define HAS_SIMD(__flags) SIMDFlags::instance()->check(__flags)
66+
67+
/**
68+
* @brief Check SIMD flags at runtime.
69+
*
70+
* 1. Check that all of several SIMD flags are supported at runtime:
71+
*
72+
* @code{.cpp}
73+
* if (HAS_AVX && HAS_AVX2) {
74+
* avx2_stub();
75+
* }
76+
* @endcode
77+
*
78+
* 2. Check that at least one of several SIMD flags is supported at runtime:
79+
*
80+
* @code{.cpp}
81+
* if (HAS_SSE41 || HAS_SSE42) {
82+
* sse4_stub();
83+
* }
84+
* @endcode
85+
*/
86+
// clang-format off
87+
#define HAS_SSE HAS_SIMD(SIMD_SSE)
88+
#define HAS_SSE2 HAS_SIMD(SIMD_SSE2)
89+
#define HAS_SSE3 HAS_SIMD(SIMD_SSE3)
90+
#define HAS_SSSE3 HAS_SIMD(SIMD_SSSE3)
91+
#define HAS_SSE41 HAS_SIMD(SIMD_SSE41)
92+
#define HAS_SSE42 HAS_SIMD(SIMD_SSE42)
93+
#define HAS_FMA3 HAS_SIMD(SIMD_FMA3)
94+
#define HAS_FMA4 HAS_SIMD(SIMD_FMA4)
95+
#define HAS_AVX HAS_SIMD(SIMD_AVX)
96+
#define HAS_AVX2 HAS_SIMD(SIMD_AVX2)
97+
#define HAS_AVX512 HAS_SIMD(SIMD_AVX512)
98+
// clang-format on
7099

71100
} // namespace paddle

paddle/utils/tests/test_SIMDFlags.cpp

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -19,33 +19,33 @@ using namespace paddle; // NOLINT
1919

2020
TEST(SIMDFlags, gccTest) {
2121
#if (defined(__GNUC__) || defined(__GNUG__)) && !(defined(__clang__))
22-
CHECK(!__builtin_cpu_supports("sse") != HAS_SSE);
23-
CHECK(!__builtin_cpu_supports("sse2") != HAS_SSE2);
24-
CHECK(!__builtin_cpu_supports("sse3") != HAS_SSE3);
25-
CHECK(!__builtin_cpu_supports("ssse3") != HAS_SSSE3);
22+
// clang-format off
23+
CHECK(!__builtin_cpu_supports("sse") != HAS_SSE);
24+
CHECK(!__builtin_cpu_supports("sse2") != HAS_SSE2);
25+
CHECK(!__builtin_cpu_supports("sse3") != HAS_SSE3);
26+
CHECK(!__builtin_cpu_supports("ssse3") != HAS_SSSE3);
2627
CHECK(!__builtin_cpu_supports("sse4.1") != HAS_SSE41);
2728
CHECK(!__builtin_cpu_supports("sse4.2") != HAS_SSE42);
28-
CHECK(!__builtin_cpu_supports("avx") != HAS_AVX);
29-
CHECK(!__builtin_cpu_supports("avx2") != HAS_AVX2);
29+
CHECK(!__builtin_cpu_supports("avx") != HAS_AVX);
30+
CHECK(!__builtin_cpu_supports("avx2") != HAS_AVX2);
31+
// clang-format on
3032
#endif
3133
}
3234

3335
TEST(SIMDFlags, normalPrint) {
34-
auto simd = SIMDFlags::instance();
35-
LOG(INFO) << "Has SSE2: " << std::boolalpha << simd->isSSE2();
36-
LOG(INFO) << "Has SSE3: " << std::boolalpha << simd->isSSE3();
37-
LOG(INFO) << "Has SSSE3: " << std::boolalpha << simd->isSSSE3();
38-
LOG(INFO) << "Has SSE4.1: " << std::boolalpha << simd->isSSE41();
39-
LOG(INFO) << "Has SSE4.2: " << std::boolalpha << simd->isSSE42();
40-
LOG(INFO) << "Has FMA3: " << std::boolalpha << simd->isFMA3();
41-
LOG(INFO) << "Has FMA4: " << std::boolalpha << simd->isFMA4();
42-
LOG(INFO) << "Has AVX: " << std::boolalpha << simd->isAVX();
43-
LOG(INFO) << "Has AVX2: " << std::boolalpha << simd->isAVX2();
44-
LOG(INFO) << "Has AVX512: " << std::boolalpha << simd->isAVX512();
36+
LOG(INFO) << "Has SSE: " << std::boolalpha << HAS_SSE;
37+
LOG(INFO) << "Has SSE2: " << std::boolalpha << HAS_SSE2;
38+
LOG(INFO) << "Has SSE3: " << std::boolalpha << HAS_SSE3;
39+
LOG(INFO) << "Has SSSE3: " << std::boolalpha << HAS_SSSE3;
40+
// The disjunction must be parenthesized: `<<` binds tighter than `||`, so
// without the parentheses only HAS_SSE41 is streamed and HAS_SSE42 is ignored.
LOG(INFO) << "Has SSE4: " << std::boolalpha << (HAS_SSE41 || HAS_SSE42);
41+
LOG(INFO) << "Has FMA3: " << std::boolalpha << HAS_FMA3;
42+
LOG(INFO) << "Has FMA4: " << std::boolalpha << HAS_FMA4;
43+
LOG(INFO) << "Has AVX: " << std::boolalpha << HAS_AVX;
44+
LOG(INFO) << "Has AVX2: " << std::boolalpha << HAS_AVX2;
45+
LOG(INFO) << "Has AVX512: " << std::boolalpha << HAS_AVX512;
4546
}
4647

4748
int main(int argc, char** argv) {
4849
testing::InitGoogleTest(&argc, argv);
49-
paddle::initMain(argc, argv);
5050
return RUN_ALL_TESTS();
5151
}

0 commit comments

Comments
 (0)