Skip to content

Commit 33caa61

Browse files
committed
Fix sizing of hash functions
1 parent e306e2a commit 33caa61

File tree

1 file changed

+18
-15
lines changed

1 file changed

+18
-15
lines changed

pandas/_libs/new_vector.cpp

Lines changed: 18 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -112,7 +112,9 @@ template <typename T> class PandasVector {
112112

113113
template <typename T, bool IsMasked> class PandasHashTable {
114114
public:
115-
using HashValueT = decltype(PandasHashFunction<T>()(T()));
115+
// If the hash function's return type is 4 bytes or smaller, use
116+
// uint32_t as the khash "int" type; otherwise use uint64_t.
117+
using HashValueT = typename std::conditional<sizeof(decltype(PandasHashFunction<T>()(T()))) <= 4, uint32_t, uint64_t>::type;
116118
explicit PandasHashTable<T, IsMasked>() = default;
117119
explicit PandasHashTable<T, IsMasked>(size_t new_size) {
118120
// TODO: C++20 std::in_range would be great to safely check cast
@@ -135,7 +137,7 @@ template <typename T, bool IsMasked> class PandasHashTable {
135137
auto SizeOf() const noexcept {
136138
constexpr auto overhead = 4 * sizeof(uint32_t) + 3 * sizeof(uint32_t *);
137139
const auto for_flags =
138-
std::max(static_cast<HashValueT>(1), hash_map_.n_buckets() >> 5) * sizeof(uint32_t);
140+
std::max(1UL, hash_map_.n_buckets() >> 5) * sizeof(uint32_t);
139141
const auto for_pairs =
140142
hash_map_.n_buckets() * (sizeof(T) + sizeof(Py_ssize_t));
141143

@@ -196,8 +198,8 @@ template <typename T, bool IsMasked> class PandasHashTable {
196198
nb::call_guard<nb::gil_scoped_release>();
197199
const auto keys_v = keys.view();
198200
const auto values_v = values.view();
199-
200-
for (decltype(values_v.shape(0)) i = 0; i < values_v.shape(0); i++) {
201+
const auto n = values_v.shape(0);
202+
for (auto i = decltype(n){0}; i < n; i++) {
201203
hash_map_[keys_v(i)] = values_v(i);
202204
}
203205
}
@@ -212,12 +214,13 @@ template <typename T, bool IsMasked> class PandasHashTable {
212214

213215
nb::call_guard<nb::gil_scoped_release>();
214216
const auto values_v = values.view();
217+
const auto n = values_v.shape(0);
215218
if constexpr (IsMasked) {
216219
const auto mask_base =
217220
nb::cast<nb::ndarray<const uint8_t, nb::ndim<1>>>(mask);
218221
const auto mask_v = mask_base.view();
219222
auto na_position = na_position_; // pandas uses int8_t here - why?
220-
for (decltype(values_v.shape(0)) i = 0; i < values_v.shape(0); i++) {
223+
for (auto i = decltype(n){0}; i < n; i++) {
221224
if (mask_v(i)) {
222225
na_position = i;
223226
} else {
@@ -226,7 +229,7 @@ template <typename T, bool IsMasked> class PandasHashTable {
226229
}
227230
na_position_ = na_position;
228231
} else {
229-
for (decltype(values_v.shape(0)) i = 0; i < values_v.shape(0); i++) {
232+
for (auto i = decltype(n){0}; i < n; i++) {
230233
const auto key = values_v(i);
231234
hash_map_[key] = i;
232235
}
@@ -252,7 +255,7 @@ template <typename T, bool IsMasked> class PandasHashTable {
252255
const auto mask_base =
253256
nb::cast<nb::ndarray<const uint8_t, nb::ndim<1>>>(mask);
254257
const auto mask_v = mask_base.view();
255-
for (decltype(values.shape(0)) i = 0; i < n; i++) {
258+
for (auto i = decltype(n){0}; i < n; i++) {
256259
if (mask_v(i)) {
257260
locs[i] = na_position_;
258261
} else {
@@ -266,7 +269,7 @@ template <typename T, bool IsMasked> class PandasHashTable {
266269
}
267270
}
268271
} else {
269-
for (decltype(values.shape(n)) i = 0; i < n; i++) {
272+
for (auto i = decltype(n){0}; i < n; i++) {
270273
const auto val = values_v(i);
271274
const auto position = hash_map_.get(val);
272275
if (position == hash_map_.end()) {
@@ -326,7 +329,7 @@ template <typename T, bool IsMasked> class PandasHashTable {
326329
auto *labels = new Py_ssize_t[n];
327330
Py_ssize_t count = 0;
328331

329-
for (decltype(values.shape(0)) i = 0; i < n; i++) {
332+
for (auto i = decltype(n){0}; i < n; i++) {
330333
const auto val = values_v(i);
331334

332335
// specific for groupby
@@ -460,7 +463,7 @@ template <typename T, bool IsMasked> class PandasHashTable {
460463
nb::call_guard<nb::gil_scoped_release>();
461464
const auto mask_v = mask.view();
462465

463-
for (decltype(values_v.shape(0)) i = 0; i < n; i++) {
466+
for (auto i = decltype(n){0}; i < n; i++) {
464467
if constexpr (IgnoreNA) {
465468
if (mask_v(i)) {
466469
labels[i] = na_sentinel;
@@ -485,7 +488,7 @@ template <typename T, bool IsMasked> class PandasHashTable {
485488
}
486489
} else {
487490
nb::call_guard<nb::gil_scoped_release>();
488-
for (decltype(values_v.shape(0)) i = 0; i < n; i++) {
491+
for (auto i = decltype(n){0}; i < n; i++) {
489492
const auto val = values_v(i);
490493

491494
if constexpr (IgnoreNA) {
@@ -549,7 +552,7 @@ template <typename T, bool IsMasked> class PandasHashTable {
549552
const auto mask_v = mask.view();
550553

551554
bool seen_na = false;
552-
for (decltype(values.shape(0)) i = 0; i < n; i++) {
555+
for (auto i = decltype(n){0}; i < n; i++) {
553556
const auto val = values_v(i);
554557

555558
if constexpr (IgnoreNA) {
@@ -591,7 +594,7 @@ template <typename T, bool IsMasked> class PandasHashTable {
591594
}
592595
} else {
593596
nb::call_guard<nb::gil_scoped_release>();
594-
for (decltype(values.shape(0)) i = 0; i < n; i++) {
597+
for (auto i = decltype(n){0}; i < n; i++) {
595598
const auto val = values_v(i);
596599
auto k = hash_map_.get(val);
597600
if (k == hash_map_.end()) {
@@ -627,7 +630,7 @@ template <typename T, bool IsMasked> class PandasHashTable {
627630
}
628631
nb::call_guard<nb::gil_scoped_release>();
629632

630-
for (decltype(values.shape(0)) i = 0; i < n; i++) {
633+
for (auto i = decltype(n){0}; i < n; i++) {
631634
if constexpr (IgnoreNA) {
632635
// TODO: current pandas code is a bit messy here...
633636
// labels[i] = na_sentinel
@@ -644,7 +647,7 @@ template <typename T, bool IsMasked> class PandasHashTable {
644647
}
645648
} else {
646649
nb::call_guard<nb::gil_scoped_release>();
647-
for (decltype(values.shape(0)) i = 0; i < n; i++) {
650+
for (auto i = decltype(n){0}; i < n; i++) {
648651
const auto val = values_v(i);
649652
auto k = hash_map_.get(val);
650653
if (k == hash_map_.end()) {

0 commit comments

Comments
 (0)