Skip to content

Commit

Permalink
update tests, CR version
Browse files Browse the repository at this point in the history
  • Loading branch information
mcabbott committed Sep 4, 2022
1 parent 805dcf9 commit 164b257
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 13 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"

[compat]
ChainRules = "1.5"
ChainRulesCore = "1.2"
ChainRules = "1.44.6"
ChainRulesCore = "1.15.3"
Combinatorics = "1"
StaticArrays = "1"
StatsBase = "0.33"
Expand Down
10 changes: 5 additions & 5 deletions src/extra_rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -181,9 +181,10 @@ end

@ChainRulesCore.non_differentiable StaticArrays.promote_tuple_eltype(T)

function ChainRules.frule((_, ∂A), ::typeof(getindex), A::AbstractArray, args...)
getindex(A, args...), getindex(∂A, args...)
end
# function ChainRules.frule((_, ∂A), ::typeof(getindex), A::AbstractArray, args...)
# getindex(A, args...), getindex(∂A, args...)
# end
# WARNING: Method definition frule(Any, typeof(Base.getindex), AbstractArray{T, N} where N where T, Any...) in module ChainRules at /Users/me/.julia/packages/ChainRules/KVV0e/src/rulesets/Base/indexing.jl:59 overwritten in module Diffractor at /Users/me/.julia/dev/Diffractor/src/extra_rules.jl:184

function ChainRules.rrule(::DiffractorRuleConfig, ::typeof(map), ::typeof(+), A::AbstractArray, B::AbstractArray)
map(+, A, B), Δ->(NoTangent(), NoTangent(), Δ, Δ)
Expand Down Expand Up @@ -266,5 +267,4 @@ function ChainRulesCore.rrule(::DiffractorRuleConfig, ::Type{InplaceableThunk},
val, Δ->(NoTangent(), NoTangent(), Δ)
end

Base.real(z::ZeroTangent) = z # TODO should be in CRC
Base.real(z::NoTangent) = z
Base.real(z::NoTangent) = z # TODO should be in CRC, https://github.com/JuliaDiff/ChainRulesCore.jl/pull/581
16 changes: 10 additions & 6 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -92,9 +92,9 @@ let var"'" = Diffractor.PrimeDerivativeBack
@test @inferred(sin'(1.0)) == cos(1.0)
@test @inferred(sin''(1.0)) == -sin(1.0)
@test sin'''(1.0) == -cos(1.0)
@test sin''''(1.0) == sin(1.0) broken = VERSION >= v"1.8"
@test sin'''''(1.0) == cos(1.0) broken = VERSION >= v"1.8"
@test sin''''''(1.0) == -sin(1.0) broken = VERSION >= v"1.8"
@test sin''''(1.0) == sin(1.0) # broken = VERSION >= v"1.8"
@test sin'''''(1.0) == cos(1.0) # broken = VERSION >= v"1.8"
@test sin''''''(1.0) == -sin(1.0) # broken = VERSION >= v"1.8"

f_getfield(x) = getfield((x,), 1)
@test f_getfield'(1) == 1
Expand Down Expand Up @@ -265,13 +265,17 @@ z45, delta45 = frule_via_ad(DiffractorRuleConfig(), (0,1), x -> log(exp(x)), 2)
end

@testset "broadcast, 2nd order" begin
@test gradient(x -> sum(gradient(x -> sum(x .^ 2 .+ x'), x)[1]), [1,2,3.0])[1] == [6,6,6]
@test gradient(x -> sum(gradient(x -> sum((x .+ 1) .* x .- x), x)[1]), [1,2,3.0])[1] == [2,2,2]
@test gradient(x -> gradient(y -> sum(y .* y), x)[1] |> sum, [1,2,3.0])[1] == [2,2,2] # calls "split broadcasting generic" with f = unthunk
@test gradient(x -> gradient(y -> sum(y .* x), x)[1].^3 |> sum, [1,2,3.0])[1] == [3,12,27]
@test_broken gradient(x -> gradient(y -> sum(y .* 2 .* y'), x)[1] |> sum, [1,2,3.0])[1] == [12, 12, 12] # Control flow support not fully implemented yet for higher-order

@test_broken gradient(x -> sum(gradient(x -> sum(x .^ 2 .+ x'), x)[1]), [1,2,3.0])[1] == [6,6,6] # BoundsError: attempt to access 18-element Vector{Core.Compiler.BasicBlock} at index [0]
@test_broken gradient(x -> sum(gradient(x -> sum((x .+ 1) .* x .- x), x)[1]), [1,2,3.0])[1] == [2,2,2]
@test_broken gradient(x -> sum(gradient(x -> sum(x .* x ./ 2), x)[1]), [1,2,3.0])[1] == [1,1,1]

@test_broken gradient(x -> sum(gradient(x -> sum(exp.(x)), x)[1]), [1,2,3])[1] exp.(1:3) # MethodError: no method matching copy(::Nothing)
@test_broken gradient(x -> sum(gradient(x -> sum(atan.(x, x')), x)[1]), [1,2,3.0])[1] [0,0,0]
@test_broken gradient(x -> sum(gradient(x -> sum(transpose(x) .* x), x)[1]), [1,2,3]) == ([6,6,6],) # ERROR: (1, current_logger_for_env(std_level::Base.CoreLogging.LogLevel, group, _module) @ Base.CoreLogging logging.jl:500, :($(Expr(:meta, :noinline))))
@test_broken gradient(x -> sum(gradient(x -> sum(transpose(x) .* x), x)[1]), [1,2,3]) == ([6,6,6],) # accum(a::Transpose{Float64, Vector{Float64}}, b::ChainRulesCore.Tangent{Transpose{Int64, Vector{Int64}}, NamedTuple{(:parent,), Tuple{ChainRulesCore.NoTangent}}})
@test_broken gradient(x -> sum(gradient(x -> sum(transpose(x) ./ x.^2), x)[1]), [1,2,3])[1] [27.675925925925927, -0.824074074074074, -2.1018518518518516]

@test_broken gradient(z -> gradient(x -> sum((y -> (x^2*y)).([1,2,3])), z)[1], 5.0) == (12.0,)
Expand Down

0 comments on commit 164b257

Please sign in to comment.