@@ -198,10 +198,33 @@ class index_iterator {
198198 }
199199
200200 NDARRAY_INLINE NDARRAY_HOST_DEVICE index_iterator operator ++(int ) { return index_iterator (i_++); }
201+ NDARRAY_INLINE NDARRAY_HOST_DEVICE index_iterator operator --(int ) { return index_iterator (i_--); }
201202 NDARRAY_INLINE NDARRAY_HOST_DEVICE index_iterator& operator ++() {
202203 ++i_;
203204 return *this ;
204205 }
206+ NDARRAY_INLINE NDARRAY_HOST_DEVICE index_iterator& operator --() {
207+ --i_;
208+ return *this ;
209+ }
210+ NDARRAY_INLINE NDARRAY_HOST_DEVICE index_iterator& operator +=(index_t r) {
211+ i_ += r;
212+ return *this ;
213+ }
214+ NDARRAY_INLINE NDARRAY_HOST_DEVICE index_iterator& operator -=(index_t r) {
215+ i_ -= r;
216+ return *this ;
217+ }
218+ NDARRAY_INLINE NDARRAY_HOST_DEVICE index_iterator operator +(index_t r) {
219+ return index_iterator (i_ + r);
220+ }
221+ NDARRAY_INLINE NDARRAY_HOST_DEVICE index_iterator operator -(index_t r) {
222+ return index_iterator (i_ - r);
223+ }
224+ NDARRAY_INLINE NDARRAY_HOST_DEVICE index_t operator -(const index_iterator& r) {
225+ return i_ - r.i_ ;
226+ }
227+ NDARRAY_INLINE NDARRAY_HOST_DEVICE index_t operator [](index_t n) const { return i_ + n; }
205228};
206229
207230template <index_t Min, index_t Extent, index_t Stride>
@@ -271,6 +294,7 @@ class interval {
271294 NDARRAY_INLINE NDARRAY_HOST_DEVICE void set_min (index_t min) { min_ = min; }
272295 /* * Get or set the number of indices in this interval. */
273296 NDARRAY_INLINE NDARRAY_HOST_DEVICE index_t extent () const { return extent_; }
297+ NDARRAY_INLINE NDARRAY_HOST_DEVICE index_t size () const { return extent_; }
274298 NDARRAY_INLINE NDARRAY_HOST_DEVICE void set_extent (index_t extent) { extent_ = extent; }
275299
276300 /* * Get or set the last index in this interval. */
@@ -433,6 +457,7 @@ class dim : protected interval<Min_, Extent_> {
433457 using base_range::begin;
434458 using base_range::end;
435459 using base_range::extent;
460+ using base_range::size;
436461 using base_range::is_in_range;
437462 using base_range::max;
438463 using base_range::min;
@@ -490,6 +515,8 @@ using broadcast_dim = dim<Min, Extent, 0>;
490515namespace internal {
491516
492517// An iterator for a range of intervals.
518+ // This is like a random access iterator in that it can move forward in constant time, but
519+ // but unlike a random access iterator, it cannot be moved in reverse.
493520template <index_t InnerExtent = dynamic>
494521class split_iterator {
495522 fixed_interval<InnerExtent> i;
@@ -507,47 +534,69 @@ class split_iterator {
507534 }
508535
509536 NDARRAY_HOST_DEVICE fixed_interval<InnerExtent> operator *() const { return i; }
537+ NDARRAY_HOST_DEVICE const fixed_interval<InnerExtent>* operator ->() const { return &i; }
510538
511- NDARRAY_HOST_DEVICE split_iterator& operator ++() {
539+ NDARRAY_HOST_DEVICE split_iterator& operator +=(index_t n) {
540+ assert (n >= 0 );
512541 if (is_static (InnerExtent)) {
513542 // When the extent of the inner split is a compile-time constant,
514543 // we can't shrink the out of bounds interval. Instead, shift the min,
515544 // assuming the outer dimension is bigger than the inner extent.
516- i.set_min (i.min () + InnerExtent);
545+ i.set_min (i.min () + InnerExtent * n );
517546 // Only shift the min when this straddles the end of the buffer,
518547 // so the iterator can advance to the end (one past the max).
519548 if (i.min () <= outer_max && i.max () > outer_max) { i.set_min (outer_max - InnerExtent + 1 ); }
520549 } else {
521550 // When the extent of the inner split is not a compile-time constant,
522551 // we can just modify the extent.
523- i.set_min (i.min () + i.extent ());
552+ i.set_min (i.min () + i.extent () * n );
524553 index_t max = min (i.max (), outer_max);
525554 i.set_extent (max - i.min () + 1 );
526555 }
527556 return *this ;
528557 }
558+ NDARRAY_HOST_DEVICE split_iterator operator +(index_t n) const {
559+ split_iterator<InnerExtent> result (*this );
560+ return result += n;
561+ }
562+ NDARRAY_HOST_DEVICE split_iterator& operator ++() {
563+ return *this += 1 ;
564+ }
529565 NDARRAY_HOST_DEVICE split_iterator operator ++(int ) {
530566 split_iterator<InnerExtent> result (*this );
531- ++ *this ;
567+ *this += 1 ;
532568 return result;
533569 }
570+
571+ NDARRAY_HOST_DEVICE index_t operator -(const split_iterator& r) const {
572+ return r.i .extent () > 0 ? (i.max () - r.i .min () + r.i .extent () - i.extent ()) / r.i .extent () : 0 ;
573+ }
574+
575+ NDARRAY_HOST_DEVICE fixed_interval<InnerExtent> operator [](index_t n) const {
576+ split_iterator result (*this );
577+ result += n;
578+ return *result;
579+ }
534580};
535581
536- // TODO: Remove this when std::iterator_range is standard.
537- template <class T >
538- class iterator_range {
539- T begin_;
540- T end_;
582+ template <index_t InnerExtent = dynamic>
583+ class split_result {
584+ public:
585+ using iterator = split_iterator<InnerExtent>;
586+
587+ private:
588+ iterator begin_;
589+ iterator end_;
541590
542591public:
543- NDARRAY_HOST_DEVICE iterator_range (T begin, T end) : begin_(begin), end_(end) {}
592+ NDARRAY_HOST_DEVICE split_result (iterator begin, iterator end) : begin_(begin), end_(end) {}
544593
545- NDARRAY_HOST_DEVICE T begin () const { return begin_; }
546- NDARRAY_HOST_DEVICE T end () const { return end_; }
547- };
594+ NDARRAY_HOST_DEVICE iterator begin () const { return begin_; }
595+ NDARRAY_HOST_DEVICE iterator end () const { return end_; }
548596
549- template <index_t InnerExtent = dynamic>
550- using split_iterator_range = iterator_range<split_iterator<InnerExtent>>;
597+ NDARRAY_HOST_DEVICE index_t size () const { return end_ - begin_; }
598+ NDARRAY_HOST_DEVICE iterator operator [](index_t i) const { return begin_ + i; }
599+ };
551600
552601} // namespace internal
553602
@@ -562,14 +611,14 @@ using split_iterator_range = iterator_range<split_iterator<InnerExtent>>;
562611 * - `split<5>(interval<>(0, 12))` produces the intervals `[0, 5)`,
563612 * `[5, 10)`, `[7, 12)`. Note the last two intervals overlap. */
564613template <index_t InnerExtent, index_t Min, index_t Extent>
565- NDARRAY_HOST_DEVICE internal::split_iterator_range <InnerExtent> split (
614+ NDARRAY_HOST_DEVICE internal::split_result <InnerExtent> split (
566615 const interval<Min, Extent>& v) {
567616 assert (v.extent () >= InnerExtent);
568617 return {{fixed_interval<InnerExtent>(v.min ()), v.max ()},
569618 {fixed_interval<InnerExtent>(v.max () + 1 ), v.max ()}};
570619}
571620template <index_t InnerExtent, index_t Min, index_t Extent, index_t Stride>
572- NDARRAY_HOST_DEVICE internal::split_iterator_range <InnerExtent> split (
621+ NDARRAY_HOST_DEVICE internal::split_result <InnerExtent> split (
573622 const dim<Min, Extent, Stride>& v) {
574623 return split<InnerExtent>(interval<Min, Extent>(v.min (), v.extent ()));
575624}
@@ -585,13 +634,13 @@ NDARRAY_HOST_DEVICE internal::split_iterator_range<InnerExtent> split(
585634// avoid some conversion messes. dim<Min, Extent> probably can't implicitly
586635// convert to interval<>.
587636template <index_t Min, index_t Extent>
588- NDARRAY_HOST_DEVICE internal::split_iterator_range <> split (
637+ NDARRAY_HOST_DEVICE internal::split_result <> split (
589638 const interval<Min, Extent>& v, index_t inner_extent) {
590639 return {{interval<>(v.min (), internal::min (inner_extent, v.extent ())), v.max ()},
591640 {interval<>(v.max () + 1 , 0 ), v.max ()}};
592641}
593642template <index_t Min, index_t Extent, index_t Stride>
594- NDARRAY_HOST_DEVICE internal::split_iterator_range <> split (
643+ NDARRAY_HOST_DEVICE internal::split_result <> split (
595644 const dim<Min, Extent, Stride>& v, index_t inner_extent) {
596645 return split (interval<Min, Extent>(v.min (), v.extent ()), inner_extent);
597646}
@@ -608,10 +657,10 @@ NDARRAY_INLINE NDARRAY_HOST_DEVICE auto apply(Fn&& fn, const Args& args, index_s
608657 -> decltype(fn(std::get<Is>(args)...)) {
609658 return fn (std::get<Is>(args)...);
610659}
611- template <class Fn , class ... Args>
612- NDARRAY_INLINE NDARRAY_HOST_DEVICE auto apply (Fn&& fn, const std::tuple< Args...> & args)
613- -> decltype(internal::apply(fn, args, make_index_sequence<sizeof ...( Args) >())) {
614- return internal::apply (fn, args, make_index_sequence<sizeof ...( Args) >());
660+ template <class Fn , class Args >
661+ NDARRAY_INLINE NDARRAY_HOST_DEVICE auto apply (Fn&& fn, const Args& args)
662+ -> decltype(internal::apply(fn, args, make_index_sequence<std::tuple_size< Args>::value >())) {
663+ return internal::apply (fn, args, make_index_sequence<std::tuple_size< Args>::value >());
615664}
616665
617666template <class Fn , class ... Args>
0 commit comments