Skip to content

Commit 23f8d25

Browse files
committed
fix: type inferance in generate loop
1 parent 8c99347 commit 23f8d25

File tree

2 files changed

+16
-11
lines changed

2 files changed

+16
-11
lines changed

src/tools.jl

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -419,7 +419,7 @@ function gen_loop(inds, inds_ax, fn_args, lhs, rhs, opts...)
419419
end
420420
ret_name = lhs.args[1]
421421
ret_size = Expr[]
422-
ret_type = _infer_ret_type(inds, fn_args, rhs)
422+
ret_type = _infer_ret_type(inds, inds_ax, fn_args, rhs)
423423
loop_head = Expr(:block)
424424
for (i, ind) in enumerate(inds)
425425
array_name, dims = first(inds_ax[ind])
@@ -480,16 +480,16 @@ function _parse_opt(opt::Expr)
480480
end
481481
end
482482

483-
function _infer_ret_type(inds, fn_args, ex)
484-
ex_replace = MacroTools.postwalk(ex) do sym
485-
if sym isa Symbol && sym in inds
486-
return :begin # replace the iteration symbol with first index
487-
else
488-
return sym
489-
end
483+
function _infer_ret_type(inds, inds_ax, fn_args, ex)
484+
ex′ = Expr(:block, ex)
485+
for ind in inds
486+
array_name, dims = first(inds_ax[ind])
487+
dim = first(dims)
488+
assignment = :($ind = firstindex($array_name, $dim))
489+
pushfirst!(ex′.args, assignment)
490490
end
491491
args = collect(fn_args)
492-
promote_op_fn = Expr(:(->), Expr(:tuple, args...), ex_replace)
492+
promote_op_fn = Expr(:(->), Expr(:tuple, args...), ex′)
493493
args_type = map(arg -> :(typeof($arg)), args)
494494
ret_type = :(Base.promote_op($promote_op_fn, $(args_type...)))
495495
return ret_type # expression for inferring the return type

test/tools.jl

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,16 @@ end
1616
# those functions is sensitive to the order of arguments, which is undetermined
1717
f1 = @quickloop R1[i, j] := A[i] * B[j]
1818
f2 = @quickloop R2[j, i] := A[i] * B[j]
19-
f3 = @quickloop A[i, j] # just make a copy
19+
f3 = @quickloop ifelse(i == j, zero(eltype(A)), A[i, j]) # make diagonal zero
2020
f4 = @quickloop R3[i, j, k] := A[i, j] * B[i, k]
2121
@test (f1(A, B) == A .* B' == transpose(f2(A, B)) ||
2222
f1(A, B) == B .* A' == transpose(f2(A, B)))
23-
@test C == f3(C)
23+
F = f3(C)
24+
equality_f3 = true
25+
for i in 1:4, j in 1:4
26+
equality_f3 &= ifelse(i == j, F[i, j] == 0, F[i, j] == C[i, j])
27+
end
28+
@test equality_f3
2429
E = f4(C, D)
2530
equality_f4 = true
2631
for k in 1:4, j in 1:4, i in 1:4

0 commit comments

Comments
 (0)