
Commit 809ebb6

dsharlet and jiawen authored
Add z order traversal helpers (#97)
* Make split results and index iterators look more like containers * Let apply work with anything tuple-like * Add more split tests * Add z order helpers * These need to be inlined * Add z order version of matrix multiply * Hide for_each_index_z_order in internal * Dead code * Fix comment * It's better to do ko outermost and only z-order io, jo * Rename to match for_each_index_in_order * Pass z by value to avoid save/restore, makes calls tail recursive(?) * Tweak tiling and comments * No need to pass functions by value here * Add profiler output to gitignore * Fix special test case * Apply suggestions from code review Co-authored-by: Jiawen (Kevin) Chen <jiawen@users.noreply.github.com> * Make index_iterator a fully implemented random_access_iterator * Add blas comparison * Don't depend on ::size() for range * Update comment * Add prefetching * Add performance data * Test edge cases --------- Co-authored-by: Jiawen (Kevin) Chen <jiawen@users.noreply.github.com>
1 parent ba14d8a commit 809ebb6

File tree: 10 files changed (+442, -40 lines)


.gitignore

Lines changed: 5 additions & 1 deletion
@@ -35,4 +35,8 @@ docs/*
 *~
 
 # Visual Studio folder status
-.vs
+.vs
+
+# perf files
+perf.data
+perf.data.old

Makefile

Lines changed: 1 addition & 1 deletion
@@ -2,7 +2,7 @@ CFLAGS := $(CFLAGS) -O2 -ffast-math -fstrict-aliasing -fPIE
 CXXFLAGS := $(CXXFLAGS) -std=c++14 -Wall
 LDFLAGS := $(LDFLAGS)
 
-DEPS := include/array/array.h include/array/ein_reduce.h include/array/image.h include/array/matrix.h
+DEPS := include/array/array.h include/array/ein_reduce.h include/array/image.h include/array/matrix.h include/array/z_order.h
 
 TEST_SRC := $(filter-out test/errors.cpp, $(wildcard test/*.cpp))
 TEST_OBJ := $(TEST_SRC:%.cpp=obj/%.o)

examples/linear_algebra/Makefile

Lines changed: 7 additions & 2 deletions
@@ -2,11 +2,16 @@ CFLAGS := $(CFLAGS) -O2 -march=native -ffast-math -fstrict-aliasing -fno-excepti
 CXXFLAGS := $(CXXFLAGS) -std=c++14 -Wall
 LDFLAGS := $(LDFLAGS)
 
-DEPS := ../../include/array/array.h ../../include/array/matrix.h ../benchmark.h ../../include/array/ein_reduce.h
+ifneq ($(BLAS), )
+CFLAGS += -DBLAS
+LDFLAGS += -lblas
+endif
+
+DEPS := ../../include/array/array.h ../../include/array/matrix.h ../benchmark.h ../../include/array/ein_reduce.h ../../include/array/z_order.h
 
 bin/%: %.cpp $(DEPS)
         mkdir -p $(@D)
-        $(CXX) -I../../include -I../ -o $@ $< $(CFLAGS) $(CXXFLAGS) -lstdc++ -lm
+        $(CXX) -I../../include -I../ -o $@ $< $(CFLAGS) $(CXXFLAGS) -lstdc++ -lm $(LDFLAGS)
 
 .PHONY: all clean test
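Note: with the conditional above, the cblas comparison is opt-in. Invoking make with a non-empty BLAS variable (for example, make BLAS=1) adds -DBLAS to CFLAGS and links -lblas, assuming a system BLAS library is available.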

examples/linear_algebra/matrix.cpp

Lines changed: 93 additions & 9 deletions
@@ -14,12 +14,17 @@
 
 #include "array/matrix.h"
 #include "array/ein_reduce.h"
+#include "array/z_order.h"
 #include "benchmark.h"
 
 #include <functional>
 #include <iostream>
 #include <random>
 
+#ifdef BLAS
+#include "cblas.h"
+#endif
+
 using namespace nda;
 
 // Make it easier to read the generated assembly for these functions.
@@ -209,9 +214,7 @@ NOINLINE void multiply_reduce_tiles(const_matrix_ref<T> A, const_matrix_ref<T> B
   }
 }
 
-// With clang -O2, this generates (almost) the same fast inner loop as the above!!
-// It only spills one accumulator register, and produces statistically identical
-// performance.
+// With clang -O2, this generates exactly the same fast inner loop as the above!!
 template <typename T>
 NOINLINE void multiply_ein_reduce_tiles(
     const_matrix_ref<T> A, const_matrix_ref<T> B, matrix_ref<T> C) {
@@ -259,13 +262,90 @@ NOINLINE void multiply_ein_reduce_tiles(
   }
 }
 
+// This is similar to the above, but:
+// - It additionally splits the reduction dimension k,
+// - It traverses the io, jo loops in z order, to improve locality,
+// - It prefetches in the inner loop.
+// This version achieves ~90% of the theoretical peak performance of my AMD Ryzen 5800X.
+template <typename T>
+NOINLINE void multiply_reduce_tiles_z_order(const_matrix_ref<T> A, const_matrix_ref<T> B, matrix_ref<T> C) {
+  // Adjust this depending on the target architecture. For AVX2,
+  // vectors are 256-bit.
+  constexpr index_t vector_size = 32 / sizeof(T);
+  constexpr index_t cache_line_size = 64 / sizeof(T);
+
+  // We want the tiles to be as big as possible without spilling any
+  // of the accumulator registers to the stack.
+  constexpr index_t tile_rows = 4;
+  constexpr index_t tile_cols = vector_size * 3;
+  constexpr index_t tile_k = 256;
+
+  // TODO: It seems like z-ordering all of io, jo, ko should be best...
+  // But this seems better, even without the added convenience for initializing
+  // the output.
+  for (auto ko : split(A.j(), tile_k)) {
+    auto split_i = split<tile_rows>(C.i());
+    auto split_j = split<tile_cols>(C.j());
+    for_all_in_z_order(std::make_tuple(split_i, split_j), [&](auto io, auto jo) {
+      // Make a reference to this tile of the output.
+      auto C_ijo = C(io, jo);
+
+      // Define an accumulator buffer.
+      T buffer[tile_rows * tile_cols] = {0};
+      auto accumulator = make_array_ref(buffer, make_compact(C_ijo.shape()));
+
+      // Perform the matrix multiplication for this tile.
+      for (index_t k : ko) {
+        for (index_t i = 0; i < io.extent(); i += cache_line_size) {
+          _mm_prefetch(&A(io.min() + i, k + 8), _MM_HINT_T0);
+        }
+        for (index_t j = 0; j < jo.extent(); j += cache_line_size) {
+          _mm_prefetch(&B(k + 4, jo.min() + j), _MM_HINT_T0);
+        }
+        for (index_t i : io) {
+          for (index_t j : jo) {
+            accumulator(i, j) += A(i, k) * B(k, j);
+          }
+        }
+      }
+
+      // Add the accumulators for this iteration of ko to the output.
+      // Because we split the K dimension, we are doing this more than once per
+      // tile of output. To avoid adding to overlapping regions more than once
+      // (when `split<>` is applied to a dimension not divided by the split factor),
+      // we need to only initialize the result for the first iteration of ko.
+      if (ko.min() == A.j().min()) {
+        for (index_t i : io) {
+          for (index_t j : jo) {
+            C_ijo(i, j) = accumulator(i, j);
+          }
+        }
+      } else {
+        for (index_t i : io) {
+          for (index_t j : jo) {
+            C_ijo(i, j) += accumulator(i, j);
+          }
+        }
+      }
+    });
+  }
+}
+
+#ifdef BLAS
+void multiply_blas(const_matrix_ref<float> A, const_matrix_ref<float> B, matrix_ref<float> C) {
+  cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, C.i().extent(), C.j().extent(),
+      A.j().extent(), 1.0, A.base(), A.i().stride(), B.base(), B.i().stride(), 0.0, C.base(),
+      C.i().stride());
+}
+#endif
+
 float relative_error(float A, float B) { return std::abs(A - B) / std::max(A, B); }
 
 int main(int, const char**) {
   // Define two input matrices.
-  constexpr index_t M = 32;
-  constexpr index_t K = 10000;
-  constexpr index_t N = 64;
+  constexpr index_t M = 384;
+  constexpr index_t K = 1536;
+  constexpr index_t N = 384;
   matrix<float> A({M, K});
   matrix<float> B({K, N});
 
@@ -278,8 +358,7 @@ int main(int, const char**) {
   generate(B, [&]() { return uniform(rng); });
 
   matrix<float> c_ref({M, N});
-  double ref_time = benchmark([&]() { multiply_ref(A.data(), B.data(), c_ref.data(), M, K, N); });
-  std::cout << "reference time: " << ref_time * 1e3 << " ms" << std::endl;
+  multiply_ref(A.data(), B.data(), c_ref.data(), M, K, N);
 
   struct version {
     const char* name;
@@ -294,12 +373,17 @@ int main(int, const char**) {
       {"ein_reduce_matrix", multiply_ein_reduce_matrix<float>},
       {"reduce_tiles", multiply_reduce_tiles<float>},
       {"ein_reduce_tiles", multiply_ein_reduce_tiles<float>},
+      {"reduce_tiles_z_order", multiply_reduce_tiles_z_order<float>},
+#ifdef BLAS
+      {"blas", multiply_blas},
+#endif
   };
   for (auto i : versions) {
     // Compute the result using all matrix multiply methods.
    matrix<float> C({M, N});
    double time = benchmark([&]() { i.fn(A.cref(), B.cref(), C.ref()); });
-    std::cout << i.name << " time: " << time * 1e3 << " ms" << std::endl;
+    double flops = M * N * K * 2 / time;
+    std::cout << i.name << " time: " << time * 1e3 << " ms, " << flops / 1e9 << " GFLOP/s" << std::endl;
 
     // Verify the results from all methods are equal.
     const float tolerance = 1e-4f;
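For readers who have not seen the new header, here is a minimal sketch of how the traversal helper is driven. The name for_all_in_z_order and its call shape (a tuple of split results plus a callable taking one interval per dimension) are taken from the call in matrix.cpp above; the interval<> (min, extent) constructor is assumed from its uses elsewhere in this commit. Treat it as illustrative rather than canonical.

// Minimal usage sketch for the z-order traversal helper (illustrative only).
#include <iostream>
#include <tuple>

#include "array/array.h"
#include "array/z_order.h"

using namespace nda;

int main() {
  interval<> rows(0, 16);  // indices [0, 16)
  interval<> cols(0, 16);  // indices [0, 16)

  // Split each dimension into tiles of 4 indices.
  auto split_i = split<4>(rows);
  auto split_j = split<4>(cols);

  // Visit the (io, jo) tiles in z (Morton) order rather than row-major order.
  for_all_in_z_order(std::make_tuple(split_i, split_j), [&](auto io, auto jo) {
    std::cout << "tile rows [" << io.min() << ", " << io.max() << "], "
              << "cols [" << jo.min() << ", " << jo.max() << "]\n";
  });
  return 0;
}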

include/array/array.h

Lines changed: 72 additions & 23 deletions
@@ -198,10 +198,33 @@ class index_iterator {
   }
 
   NDARRAY_INLINE NDARRAY_HOST_DEVICE index_iterator operator++(int) { return index_iterator(i_++); }
+  NDARRAY_INLINE NDARRAY_HOST_DEVICE index_iterator operator--(int) { return index_iterator(i_--); }
   NDARRAY_INLINE NDARRAY_HOST_DEVICE index_iterator& operator++() {
     ++i_;
     return *this;
   }
+  NDARRAY_INLINE NDARRAY_HOST_DEVICE index_iterator& operator--() {
+    --i_;
+    return *this;
+  }
+  NDARRAY_INLINE NDARRAY_HOST_DEVICE index_iterator& operator+=(index_t r) {
+    i_ += r;
+    return *this;
+  }
+  NDARRAY_INLINE NDARRAY_HOST_DEVICE index_iterator& operator-=(index_t r) {
+    i_ -= r;
+    return *this;
+  }
+  NDARRAY_INLINE NDARRAY_HOST_DEVICE index_iterator operator+(index_t r) {
+    return index_iterator(i_ + r);
+  }
+  NDARRAY_INLINE NDARRAY_HOST_DEVICE index_iterator operator-(index_t r) {
+    return index_iterator(i_ - r);
+  }
+  NDARRAY_INLINE NDARRAY_HOST_DEVICE index_t operator-(const index_iterator& r) {
+    return i_ - r.i_;
+  }
+  NDARRAY_INLINE NDARRAY_HOST_DEVICE index_t operator[](index_t n) const { return i_ + n; }
 };
 
 template <index_t Min, index_t Extent, index_t Stride>
@@ -271,6 +294,7 @@ class interval {
   NDARRAY_INLINE NDARRAY_HOST_DEVICE void set_min(index_t min) { min_ = min; }
   /** Get or set the number of indices in this interval. */
   NDARRAY_INLINE NDARRAY_HOST_DEVICE index_t extent() const { return extent_; }
+  NDARRAY_INLINE NDARRAY_HOST_DEVICE index_t size() const { return extent_; }
   NDARRAY_INLINE NDARRAY_HOST_DEVICE void set_extent(index_t extent) { extent_ = extent; }
 
   /** Get or set the last index in this interval. */
@@ -433,6 +457,7 @@ class dim : protected interval<Min_, Extent_> {
   using base_range::begin;
   using base_range::end;
   using base_range::extent;
+  using base_range::size;
   using base_range::is_in_range;
   using base_range::max;
   using base_range::min;
@@ -490,6 +515,8 @@ using broadcast_dim = dim<Min, Extent, 0>;
 namespace internal {
 
 // An iterator for a range of intervals.
+// This is like a random access iterator in that it can move forward in constant time,
+// but unlike a random access iterator, it cannot be moved in reverse.
 template <index_t InnerExtent = dynamic>
 class split_iterator {
   fixed_interval<InnerExtent> i;
@@ -507,47 +534,69 @@ class split_iterator {
   }
 
   NDARRAY_HOST_DEVICE fixed_interval<InnerExtent> operator*() const { return i; }
+  NDARRAY_HOST_DEVICE const fixed_interval<InnerExtent>* operator->() const { return &i; }
 
-  NDARRAY_HOST_DEVICE split_iterator& operator++() {
+  NDARRAY_HOST_DEVICE split_iterator& operator+=(index_t n) {
+    assert(n >= 0);
     if (is_static(InnerExtent)) {
       // When the extent of the inner split is a compile-time constant,
       // we can't shrink the out of bounds interval. Instead, shift the min,
       // assuming the outer dimension is bigger than the inner extent.
-      i.set_min(i.min() + InnerExtent);
+      i.set_min(i.min() + InnerExtent * n);
       // Only shift the min when this straddles the end of the buffer,
       // so the iterator can advance to the end (one past the max).
       if (i.min() <= outer_max && i.max() > outer_max) { i.set_min(outer_max - InnerExtent + 1); }
     } else {
       // When the extent of the inner split is not a compile-time constant,
       // we can just modify the extent.
-      i.set_min(i.min() + i.extent());
+      i.set_min(i.min() + i.extent() * n);
       index_t max = min(i.max(), outer_max);
       i.set_extent(max - i.min() + 1);
     }
     return *this;
   }
+  NDARRAY_HOST_DEVICE split_iterator operator+(index_t n) const {
+    split_iterator<InnerExtent> result(*this);
+    return result += n;
+  }
+  NDARRAY_HOST_DEVICE split_iterator& operator++() {
+    return *this += 1;
+  }
   NDARRAY_HOST_DEVICE split_iterator operator++(int) {
     split_iterator<InnerExtent> result(*this);
-    ++*this;
+    *this += 1;
     return result;
   }
+
+  NDARRAY_HOST_DEVICE index_t operator-(const split_iterator& r) const {
+    return r.i.extent() > 0 ? (i.max() - r.i.min() + r.i.extent() - i.extent()) / r.i.extent() : 0;
+  }
+
+  NDARRAY_HOST_DEVICE fixed_interval<InnerExtent> operator[](index_t n) const {
+    split_iterator result(*this);
+    result += n;
+    return *result;
+  }
 };
 
-// TODO: Remove this when std::iterator_range is standard.
-template <class T>
-class iterator_range {
-  T begin_;
-  T end_;
+template <index_t InnerExtent = dynamic>
+class split_result {
+public:
+  using iterator = split_iterator<InnerExtent>;
+
+private:
+  iterator begin_;
+  iterator end_;
 
 public:
-  NDARRAY_HOST_DEVICE iterator_range(T begin, T end) : begin_(begin), end_(end) {}
+  NDARRAY_HOST_DEVICE split_result(iterator begin, iterator end) : begin_(begin), end_(end) {}
 
-  NDARRAY_HOST_DEVICE T begin() const { return begin_; }
-  NDARRAY_HOST_DEVICE T end() const { return end_; }
-};
+  NDARRAY_HOST_DEVICE iterator begin() const { return begin_; }
+  NDARRAY_HOST_DEVICE iterator end() const { return end_; }
 
-template <index_t InnerExtent = dynamic>
-using split_iterator_range = iterator_range<split_iterator<InnerExtent>>;
+  NDARRAY_HOST_DEVICE index_t size() const { return end_ - begin_; }
+  NDARRAY_HOST_DEVICE iterator operator[](index_t i) const { return begin_ + i; }
+};
 
 } // namespace internal
 
@@ -562,14 +611,14 @@ using split_iterator_range = iterator_range<split_iterator<InnerExtent>>;
  * - `split<5>(interval<>(0, 12))` produces the intervals `[0, 5)`,
  * `[5, 10)`, `[7, 12)`. Note the last two intervals overlap. */
 template <index_t InnerExtent, index_t Min, index_t Extent>
-NDARRAY_HOST_DEVICE internal::split_iterator_range<InnerExtent> split(
+NDARRAY_HOST_DEVICE internal::split_result<InnerExtent> split(
     const interval<Min, Extent>& v) {
   assert(v.extent() >= InnerExtent);
   return {{fixed_interval<InnerExtent>(v.min()), v.max()},
       {fixed_interval<InnerExtent>(v.max() + 1), v.max()}};
 }
 template <index_t InnerExtent, index_t Min, index_t Extent, index_t Stride>
-NDARRAY_HOST_DEVICE internal::split_iterator_range<InnerExtent> split(
+NDARRAY_HOST_DEVICE internal::split_result<InnerExtent> split(
    const dim<Min, Extent, Stride>& v) {
   return split<InnerExtent>(interval<Min, Extent>(v.min(), v.extent()));
 }
@@ -585,13 +634,13 @@ NDARRAY_HOST_DEVICE internal::split_iterator_range<InnerExtent> split(
 // avoid some conversion messes. dim<Min, Extent> probably can't implicitly
 // convert to interval<>.
 template <index_t Min, index_t Extent>
-NDARRAY_HOST_DEVICE internal::split_iterator_range<> split(
+NDARRAY_HOST_DEVICE internal::split_result<> split(
     const interval<Min, Extent>& v, index_t inner_extent) {
   return {{interval<>(v.min(), internal::min(inner_extent, v.extent())), v.max()},
       {interval<>(v.max() + 1, 0), v.max()}};
 }
 template <index_t Min, index_t Extent, index_t Stride>
-NDARRAY_HOST_DEVICE internal::split_iterator_range<> split(
+NDARRAY_HOST_DEVICE internal::split_result<> split(
     const dim<Min, Extent, Stride>& v, index_t inner_extent) {
   return split(interval<Min, Extent>(v.min(), v.extent()), inner_extent);
 }
@@ -608,10 +657,10 @@ NDARRAY_INLINE NDARRAY_HOST_DEVICE auto apply(Fn&& fn, const Args& args, index_s
     -> decltype(fn(std::get<Is>(args)...)) {
   return fn(std::get<Is>(args)...);
 }
-template <class Fn, class... Args>
-NDARRAY_INLINE NDARRAY_HOST_DEVICE auto apply(Fn&& fn, const std::tuple<Args...>& args)
-    -> decltype(internal::apply(fn, args, make_index_sequence<sizeof...(Args)>())) {
-  return internal::apply(fn, args, make_index_sequence<sizeof...(Args)>());
+template <class Fn, class Args>
+NDARRAY_INLINE NDARRAY_HOST_DEVICE auto apply(Fn&& fn, const Args& args)
+    -> decltype(internal::apply(fn, args, make_index_sequence<std::tuple_size<Args>::value>())) {
+  return internal::apply(fn, args, make_index_sequence<std::tuple_size<Args>::value>());
 }
 
 template <class Fn, class... Args>
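Taken together, these hunks make split results and index iterators behave like containers and random access iterators. The sketch below exercises only members shown in the hunks above (interval::size, split_result::size and operator[], and the new index_iterator operators); the interval<> (min, extent) constructor is assumed from its uses elsewhere in this commit, so treat it as illustrative rather than canonical.

// Illustrative only: exercises the container-like members added above.
#include <cassert>

#include "array/array.h"

using namespace nda;

int main() {
  interval<> x(0, 12);  // 12 indices: 0..11 (min/extent constructor assumed).
  assert(x.size() == x.extent());  // size() is now a synonym for extent().

  // Per the doc comment above, split<5> of 12 indices yields [0, 5), [5, 10), [7, 12).
  auto tiles = split<5>(x);
  assert(tiles.size() == 3);  // Split results now know their size...
  auto last = *tiles[2];      // ...and can be indexed like a container.
  assert(last.min() == 7 && last.extent() == 5);

  // index_iterator is now random access: +=, difference, and dereference.
  auto it = x.begin();
  it += 4;
  assert(*it == 4);
  assert(x.end() - x.begin() == x.size());
  return 0;
}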
