Skip to content

Commit

Permalink
Excise getindex adjoint
Browse files Browse the repository at this point in the history
We have a better rule in Chainrules now
  • Loading branch information
ToucheSir committed Nov 10, 2022
1 parent d39ab59 commit d73f176
Showing 1 changed file with 2 additions and 40 deletions.
42 changes: 2 additions & 40 deletions src/lib/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,48 +21,10 @@ end
@adjoint (::Type{T})(sz) where {T<:Zeros} = T(sz), Δ->(nothing,)
@adjoint (::Type{T})(sz) where {T<:Ones} = T(sz), Δ->(nothing,)

@adjoint getindex(x::AbstractArray, inds...) = x[inds...], ∇getindex(x, inds)

@adjoint view(x::AbstractArray, inds...) = view(x, inds...), ∇getindex(x, inds)

∇getindex(x::AbstractArray{T,N}, inds) where {T,N} = dy -> begin
if inds isa NTuple{N,Int} && T <: Number
dx = OneElement(dy, inds, axes(x))
elseif inds isa NTuple{<:Any, Integer}
dx = _zero(x, typeof(dy))
dx[inds...] = dy
else
dx = _zero(x, eltype(dy))
dxv = view(dx, inds...)
dxv .= accum.(dxv, _droplike(dy, dxv))
end
return (_project(x, dx), map(_->nothing, inds)...)
end

"""
OneElement(val, ind, axes) <: AbstractArray
Extremely simple `struct` used for the gradient of scalar `getindex`.
"""
struct OneElement{T,N,I,A} <: AbstractArray{T,N}
val::T
ind::I
axes::A
OneElement(val::T, ind::I, axes::A) where {T<:Number, I<:NTuple{N,Int}, A<:NTuple{N,AbstractUnitRange}} where {N} = new{T,N,I,A}(val, ind, axes)
end
Base.size(A::OneElement) = map(length, A.axes)
Base.axes(A::OneElement) = A.axes
Base.getindex(A::OneElement{T,N}, i::Vararg{Int,N}) where {T,N} = ifelse(i==A.ind, A.val, zero(T))


_zero(xs::AbstractArray{<:Number}, T::Type{Nothing}) = fill!(similar(xs), zero(eltype(xs)))
_zero(xs::AbstractArray{<:Number}, T) = fill!(similar(xs, T), false)
_zero(xs::AbstractArray, T) = fill!(similar(xs, Union{Nothing, T}), nothing)

_droplike(dy, dxv) = dy
_droplike(dy::Union{LinearAlgebra.Adjoint, LinearAlgebra.Transpose}, dxv::AbstractVector) =
dropdims(dy; dims=2)

@adjoint getindex(::Type{T}, xs...) where {T} = T[xs...], dy -> (nothing, dy...)

_throw_mutation_error(f, args...) = error("""
Expand All @@ -83,7 +45,7 @@ Possible fixes:
_ -> _throw_mutation_error(copyto!, xs)

for f in [push!, pop!, pushfirst!, popfirst!]
@eval @adjoint! $f(x::AbstractVector, ys...) = $f(x, ys...),
@eval @adjoint! $f(x::AbstractVector, ys...) = $f(x, ys...),
_ -> _throw_mutation_error($f, x)
end

Expand Down Expand Up @@ -310,7 +272,7 @@ end
# =============

@adjoint parent(x::LinearAlgebra.Adjoint) = parent(x), ȳ -> (LinearAlgebra.Adjoint(ȳ),)
@adjoint parent(x::LinearAlgebra.Transpose) = parent(x), ȳ -> (LinearAlgebra.Transpose(ȳ),)
@adjoint parent(x::LinearAlgebra.Transpose) = parent(x), ȳ -> (LinearAlgebra.Transpose(ȳ),)

function _kron(mat1::AbstractMatrix,mat2::AbstractMatrix)
m1, n1 = size(mat1)
Expand Down

0 comments on commit d73f176

Please sign in to comment.