From 38ef58f0a426b6f7453b73dfd8556f2b7ebbfa56 Mon Sep 17 00:00:00 2001 From: Gerhard Aigner Date: Wed, 5 Feb 2025 19:29:20 +0100 Subject: [PATCH] make replace/replace! work with count --- src/chainedvector.jl | 45 ++++++++++++++++++++++++++++++++++---- test/chainedvector.jl | 50 ++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 90 insertions(+), 5 deletions(-) diff --git a/src/chainedvector.jl b/src/chainedvector.jl index 2ce72da..4cdec8b 100644 --- a/src/chainedvector.jl +++ b/src/chainedvector.jl @@ -974,9 +974,46 @@ function Base.filter!(f, a::ChainedVector) return a end -Base.replace(f::Base.Callable, a::ChainedVector) = ChainedVector([replace(f, A) for A in a.arrays]) -Base.replace!(f::Base.Callable, a::ChainedVector) = (foreach(A -> replace!(f, A), a.arrays); return a) -Base.replace(a::ChainedVector, old_new::Pair...; count::Union{Integer,Nothing}=nothing) = ChainedVector([replace(A, old_new...; count=count) for A in a.arrays]) -Base.replace!(a::ChainedVector, old_new::Pair...; count::Integer=typemax(Int)) = (foreach(A -> replace!(A, old_new...; count=count), a.arrays); return a) +function _check_count(count::Integer) + count < 0 && throw(DomainError(count, "`count` must not be negative")) + return min(count, typemax(Int)) % Int +end + +Base.replace(f::Base.Callable, a::ChainedVector; count::Integer=typemax(Int)) = + _replace!(f, copy(a), a, _check_count(count)) + +Base.replace!(f::Base.Callable, a::ChainedVector; count::Integer=typemax(Int)) = + _replace!(f, a, a, _check_count(count)) + +Base.replace(A::ChainedVector, old_new::Pair...; count::Integer=typemax(Int)) = + _replace_pairs!(copy(A), A, _check_count(count), old_new) + +Base.replace!(A::ChainedVector, old_new::Pair...; count::Integer=typemax(Int)) = + _replace_pairs!(A, A, _check_count(count), old_new) + +function _replace_pairs!(res, A::ChainedVector{T}, count::Int, old_new::Tuple{Vararg{Pair}}) where {T} + @inline function new(x) + for (old, new) in old_new + isequal(x, old) && return new + end + return x # no replace + end + _replace!(new, res, A, count) +end + +function _replace!(new::Base.Callable, res, A::ChainedVector{T}, count::Int) where {T} + count == 0 && return res + c = 0 + for i in eachindex(A) + x = A[i] + y = new(x) + if x !== y + res[i] = y + c += 1 + c == count && break + end + end + return res +end Base.Broadcast.broadcasted(f::F, A::ChainedVector) where {F} = map(f, A) diff --git a/test/chainedvector.jl b/test/chainedvector.jl index faf3d93..cfde689 100644 --- a/test/chainedvector.jl +++ b/test/chainedvector.jl @@ -350,6 +350,13 @@ @test map(x -> x == 1 ? 2.0 : x, x) == replace!(x, 1 => 2) @test isempty(x) + @test replace!(ChainedVector([[1,2], [1,2]]), 2=>20) == [1,20,1,20] + @test replace!(ChainedVector([[1,2], [1,2]]), 2=>20, count=1) == [1,20,1,2] + @test replace!(ChainedVector([[1,2], [1,2]]), 2=>20, count=2) == [1,20,1,20] + x = [1,2] + @test replace!(ChainedVector([x,[2,3]]), 2=>99) == [1,99,99,3] + @test x == [1,99] + # copyto! # ChainedVector dest: doffs, soffs, n x = ChainedVector([[1,2,3], [4,5,6], [7,8,9,10]]) @@ -593,6 +600,47 @@ end end end +@testset "replace[!] comparison with Vector" begin + + testvecs = ( + [[1, 2], [3, 2, 5]], + [[1, 2]], + [[2],[2],[2],[2,3]], + [[1,2,missing]], + [[missing,1],[missing,2,1]], + [[missing]] + ) + function missing_equal(a,b) + ismissing(a) && ismissing(b) && return true + ismissing(a) ⊻ ismissing(b) && return false + return all(skipmissing(a) .== skipmissing(b)) + end + gen_cv_v(x) = (c = ChainedVector(x); (c, collect(c))) + for f in (replace, replace!) + for x in testvecs + cv, v = gen_cv_v(x) + @test missing_equal(f(v, 2 => 22),f(cv, 2 => 22)) + @test missing_equal(v,cv) + + cv, v = gen_cv_v(x) + @test missing_equal(f(x -> x ÷ 2, v), f(x -> x ÷ 2, cv)) + @test missing_equal(v,cv) + + for c in (0, 1, 2, 3) + cv, v = gen_cv_v(x) + @test missing_equal(f(x -> x ÷ 2, v, count=c), f(x -> x ÷ 2, cv, count=c)) + @test missing_equal(v,cv) + + for p in ((2=>2,),(2 => 22,), (2 => 22, 3 => 33)) + cv, v = gen_cv_v(x) + @test missing_equal(f(v, p..., count=c), f(cv, p..., count=c)) + @test missing_equal(v,cv) + end + end + end + end +end + @testset "iteration protocol on ChainedVector" begin for len in 0:6 @@ -752,7 +800,7 @@ end end @testset "getindex with UnitRange" begin - x = ChainedVector([collect(1:i) for i = 10:100]) + x = ChainedVector([collect(1:i) for i = 1:10]) @test isempty(x[1:0]) @test x[1:1] == [1] @test x[1:end] == x