Support non-standard scalars in test_scalar#61
Support non-standard scalars in test_scalar#61sethaxen wants to merge 11 commits intoJuliaDiff:mainfrom
Conversation
Not supported in Julia 1.0
Not supported by Julia v1
|
will review tomorrow |
| # check that same tangent is produced for tangent 1.0 and 1.0 + 0.0im | ||
| @testset "$f at $z, with tangent $Δz" for (i, Δz) in enumerate(Δzs) | ||
| frule_test(f, (z, Δz); rtol=rtol, atol=atol, fdm=fdm, fkwargs=fkwargs, kwargs...) | ||
| if !isa(Δz, Real) && i == 1 |
There was a problem hiding this comment.
Shouldn't this be:
| if !isa(Δz, Real) && i == 1 | |
| if !isa(Δz, Real) && length(Δzs) == 1 |
There was a problem hiding this comment.
In this case, no. i == 1 when when the given tangent vector is purely real, even if it isn't a Real. And this test checks that using an actually Real tangent vector gives the same result.
oxinabox
left a comment
There was a problem hiding this comment.
I think i would like to look at this again after the comments are addressed.
I am not sure i properly undestand what is going on for the
if !isa(Δz, Real) && i == 1
branches,
and I think I would be better able to, if we have split this into two methods, one for real and one not for real.
src/testers.jl
Outdated
| vΩ, Ω_from_vec = to_vec(Ω) | ||
| # orthonormal cotangent vectors | ||
| vΩ_basis = Diagonal(ones(eltype(vΩ), length(vΩ))) | ||
| ΔΩs = [Ω_from_vec(vΩ_basis[:, i]) for i in axes(vΩ_basis, 2)] |
There was a problem hiding this comment.
We should move this out into a helper function basis_vectors
| To use this tester for a scalar type `MyNumber <: AbstractNumber`, | ||
| `FiniteDifferences.to_vec(::MyNumber)` must be implemented. | ||
| """ | ||
| function test_scalar(f, z; rtol=1e-9, atol=1e-9, fdm=_fdm, fkwargs=NamedTuple(), kwargs...) |
There was a problem hiding this comment.
Can we simplify this code by defining a seperate method for:
| function test_scalar(f, z; rtol=1e-9, atol=1e-9, fdm=_fdm, fkwargs=NamedTuple(), kwargs...) | |
| function test_scalar(f, z::Real; rtol=1e-9, atol=1e-9, fdm=_fdm, fkwargs=NamedTuple(), kwargs...) |
?
There was a problem hiding this comment.
It would simplify the frule test because we wouldn't need the basis, but if the output is non-real we still need the basis on the output for the rrule test. Adding a separate method would require us to maintain that code in two places.
| function FiniteDifferences.to_vec(q::Quaternion) | ||
| function Quaternion_from_vec(q_vec) | ||
| return Quaternion(q_vec[1], q_vec[2], q_vec[3], q_vec[4]) | ||
| end | ||
| return [q.s, q.v1, q.v2, q.v3], Quaternion_from_vec | ||
| end |
There was a problem hiding this comment.
We should move this to be defined in the package itself.
There was a problem hiding this comment.
Do you mean define this in Quaternions or FiniteDifferences?
There was a problem hiding this comment.
Or ChainRulesTestUtils?
There was a problem hiding this comment.
I don't know if it makes sense to make Quaternions an optional dependency for FiniteDifferences. Since I am only defining this for the purpose of testing, I'm comfortable with being type-piratical but just in the test suite where it can't pollute the methods table for other users. Thoughts?
| return quatfun(q), quatfun_pullback | ||
| end | ||
|
|
||
| q = quatrand() |
There was a problem hiding this comment.
Should we define rand_tangent(:: Quaternion) in this package also?
@willtebbutt do you have plans around further advancing rand_tangent ?
There was a problem hiding this comment.
I do not currently -- not sure that there's much to do beyond integrating it in with ChainRulesTestUtils in some way or another and continuing to add new methods where necessary.
Co-authored-by: Lyndon White <oxinabox@ucc.asn.au>
|
|
||
| Get a set of basis (co)tangent vectors for `x`. | ||
|
|
||
| This function assumes that the (co)tangent vectors are of the same type as `x` and requires |
There was a problem hiding this comment.
It'd be nice to give basis vectors of the same type as the result of rand_tangent instead, but I'm not certain how to do that.
|
Is this ready for rereview? |
Not yet, I'll try to finish it up this week. |
test_scalarcurrently is veryRealandComplexfocused. This PR generalizestest_scalarto work the same for any scalar for whichFiniteDifferences.to_vec(and a handful of base functions) are implemented.We test it with
Quaternions.Quaternion. We'd ideally test against a more minimal number, but it turns out one needs to implement quite a few base methods to get a new number to work correctly.