Semi-automatic Differentiation
This is a brief note on something I really wish I had realized a couple of years ago.
My GSoC project didn't quite work out the way I hoped it would, but it seems that it turned into someone's master's thesis.
Last week, I realized that my failed attempt to translate the ideas of Conal Elliott's simple essence of automatic differentiation to the setting of accelerate was not in vain.¹
If you're willing to lean on the existing ad package, you can end up with a very nice framework for composing functions, as well as their vector-Jacobian and Jacobian-vector products.
The formalism is flexible enough to work with (monadic) code generation and accelerate, which are both important to me.
Critically, this formalism allows the implicit "reverse mode AD tape" of the computation map f v to have a single node, not one node for each variable in the vector, and to have containers full of unboxed IEEE floats, not boxed types containing functions. The node reduction is a major performance win, and an accelerate implementation would be impossible without unboxed types.
Types from the paper
Let's revisit the first type of the simple essence paper.
newtype D a b = D (a -> (b, a -+> b))

type a -+> b = a -> b  -- a linear map which yields the derivative
Over the course of the paper, Elliott generalizes this to the following type:
newtype D k a b = D (a -> (b, k a b))
where k is instantiated to several concrete categories, yielding forward mode and reverse mode automatic differentiation.
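For intuition, here is a rough sketch of the two instantiations (my paraphrase of the paper; the names here are invented for illustration):

-- Forward mode: k is ordinary functions, pushing tangents forward.
type ForwardMode a b = D (->) a b

-- Reverse mode: k is (roughly) the dual category, pulling cotangents back.
newtype Dual k a b = Dual (k b a)
type ReverseMode a b = D (Dual (->)) a b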
The trick
The proper setting for automatic differentiation is not Hask, where the objects are types and the morphisms are functions, but the functor category on Hask, where the objects are functors and the morphisms are natural (well, parametric) transformations between them.
The following types move us into the functor category setting.
-- push forward a tangent vector, aka Jacobian-vector product
newtype FAD f g = FAD { getFAD :: forall a. Num a => f a -> (g a, f a -> g a) }

-- pull back a cotangent covector, aka vector-Jacobian product
newtype RAD f g = RAD { getRAD :: forall a. Num a => f a -> (g a, g a -> f a) }
Once again, the second element of the tuple is a linear map, but the linear map is now the entire Jacobian!
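For concreteness, here is a hand-written RAD value (my own illustration, using V2 from the linear package): the squared L2 norm together with its pullback.

import Data.Functor.Identity (Identity (..))
import Linear.V2 (V2 (..))

sqNorm :: RAD V2 Identity
sqNorm = RAD $ \(V2 u v) ->
  ( Identity (u*u + v*v)                      -- primal value
  , \(Identity dz) -> V2 (2*u*dz) (2*v*dz) )  -- the entire vector-Jacobian product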
This is wonderful and purely functional, but you eventually need effects, if only to use LAPACK or cuBLAS. Thus, the user needs to be able to write their own vector-Jacobian and Jacobian-vector products that contain effects.
newtype FADT m f g = FADT { getFADT :: forall a. Num a => f a -> m (g a, f a -> m (g a)) }
newtype RADT m f g = RADT { getRADT :: forall a. Num a => f a -> m (g a, g a -> m (f a)) }

type FAD f g = FADT Identity f g
type RAD f g = RADT Identity f g
(By the way: if you're after high performance, and you're computing a function that is a sum of functions that depend on a small fraction of the total variables, try instantiating m as a monad that generates LLVM rather than IO. This works even better if you only need Hessian-vector products, e.g. as part of a truncated Newton-CG optimization routine. If you're implementing a neural network and spend all your time on matrix multiplications, IO is fine.)
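To make the placement of the monad concrete, here is the squared norm from earlier as an effectful RADT; the "effect" is just logging, standing in for a real BLAS call:

sqNormT :: RADT IO V2 Identity
sqNormT = RADT $ \(V2 u v) -> do
  putStrLn "primal pass"
  pure ( Identity (u*u + v*v)
       , \(Identity dz) -> do
           putStrLn "adjoint pass"
           pure (V2 (2*u*dz) (2*v*dz)) )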
A sketch of a library
From now on, I'll only use the RADT type, since all the instances for FADT are substantially similar.
We can immediately write a Category instance.
instance Monad m => Category (RADT m) where
  id = RADT $ \fa -> pure (fa, pure)
  RADT g . RADT f = RADT $ \fa -> do
    (ga, dg) <- f fa
    (ha, dh) <- g ga
    pure (ha, dg <=< dh)
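As a small driver for the examples that follow (a helper of my own, specialized to Double for readability):

-- Run a pure RADT: returns the primal value and the vector-Jacobian product.
evalRADT :: RADT Identity f g -> f Double -> (g Double, g Double -> f Double)
evalRADT (RADT d) fa =
  let Identity (ga, dg) = d fa
  in (ga, runIdentity . dg)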
Since we're working in the category of "natural" transformations, we're not going to be able to write many other instances with normal Haskell typeclasses. However, we can write a higher-kinded Strong instance, albeit without the profunctor instance we would like.
-- HStrong should really have a FProfunctor p constraint, but I don't think
-- it's possible to write one with this encoding.
class HStrong p where
  ffirst  :: p f g -> p (Product f h) (Product g h)
  fsecond :: p f g -> p (Product h f) (Product h g)

instance Monad m => HStrong (RADT m) where
  ffirst (RADT d) = RADT $ \(Pair fa ha) -> do
    (ga, dg) <- d fa
    pure (Pair ga ha, \(Pair df' dh') -> dg df' >>= pure . flip Pair dh')
  fsecond (RADT d) = RADT $ \(Pair ha fa) -> do
    (ga, dg) <- d fa
    pure (Pair ha ga, \(Pair dh' df') -> dg df' >>= pure . Pair dh')
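As a quick sanity check on the types, ffirst on the effectful squared norm sketched above threads a second functor through untouched:

carryThrough :: RADT IO (Product V2 V2) (Product Identity V2)
carryThrough = ffirst sqNormT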
We can also write the moral equivalent of arr, with many extra constraints to satisfy the ad and linear libraries.
liftRADT
  :: (Additive f, Traversable f, Additive g, Foldable g, Functor g, Monad m)
  => (forall s a. (Reifies s Tape, Num a) => f (Reverse s a) -> g (Reverse s a))
  -> RADT m f g
liftRADT f = RADT $ \fa ->
  let j  = jacobian' f fa
      g  = fmap fst j
      dg = fmap snd j
  in pure (g, pure . (*! dg))
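For example, a polynomial map lifts directly (polynomial because the forall only provides Num; sin and friends have to wait for the constraint machinery discussed in the conclusion):

quad :: Monad m => RADT m V2 V2
quad = liftRADT (\(V2 u v) -> V2 (u*u + v) (u*v))

-- quad composes via the Category instance, e.g.
-- evalRADT (quad . quad) (V2 1 2)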
Ideally, we could use rebindable syntax to take advantage of arrow notation, but the preprocessor generates code with tuples, which are not the categorical product type we care about here. (We would want to use Data.Functor.Product instead.)
In any case, we can write what I think should be a higher-kinded indexed monad instance.
class HIndexedMonad (m :: (* -> *) -> (* -> *) -> *) where
  freturn :: forall a. f a -> m f f
  fbind   :: m f g -> (forall a. f a -> m g h) -> m f h

instance Monad m => HIndexedMonad (RADT m) where
  freturn f = RADT $ \fa -> pure (fa, pure)
  fbind (RADT f) m = RADT $ \fa -> do
    let RADT h = m fa
    (ga, dg) <- f fa
    (ha, dh) <- h ga
    pure (ha, dh >=> dg)
but this isn't very useful, since the forall guarantees we can't do anything interesting with the variables we bind with QualifiedDo notation.
What do some basic instances that use real numbers instead of category theory look like in this setting?
instance (Additive f, Monad m) => Num (RADT m f Identity) where
  fromInteger n = RADT $ \_ ->
    let a  = Identity (fromInteger n)
        da = pure . const zero
    in pure (a, da)
  RADT f + RADT g = RADT $ \z -> do
    (u, dudw) <- f z
    (v, dvdw) <- g z
    let w = u + v
        dw dw' = do
          du' <- dudw dw'
          dv' <- dvdw dw'
          pure $ du' ^+^ dv'
    pure (w, dw)
  RADT f - RADT g = RADT $ \z -> do
    (u, dudw) <- f z
    (v, dvdw) <- g z
    let w = u - v
        dw dw' = do
          du' <- dudw dw'
          dv' <- dvdw dw'
          pure $ du' ^-^ dv'
    pure (w, dw)
  RADT f * RADT g = RADT $ \z -> do
    (u, dudw) <- f z
    (v, dvdw) <- g z
    let w = u * v
        dw dw' = do
          du' <- dudw dw'
          dv' <- dvdw dw'
          pure $ du' ^* runIdentity v ^+^ dv' ^* runIdentity u
    pure (w, dw)
  abs (RADT f) = RADT $ \z -> do
    (u, dudw) <- f z
    let w = abs u
        dw dw' = do
          du' <- dudw dw'
          pure $ du' ^* signum (runIdentity u)
    pure (w, dw)
This is not a terribly useful library; since we can't use proc or do notation, we have to write everything in a point-free style. Let's fix this problem.
van Laarhoven lenses
There is hardly any code to be written if we realize that RAD is just a lens with an extra forall.
If we are willing to abandon some type safety by replacing Data.Functor.Product with regular tuples, we can even use arrow notation (assuming we only use a NatLens between the arrow tip and tail that computes derivatives, and only pack arguments into tuples on the right-hand side of the tail).
type NatLens f g h k = forall a. Floating a => Lens (f a) (g a) (h a) (k a)
type NatLens' f g = NatLens f f g g

l1 :: NatLens' V2 V2
l1 = lens (\(V2 u v) -> V2 (sin u) (cos v))
          (\(V2 u v) (V2 dudk dvdk) -> V2 (cos u * dudk) (negate $ sin v * dvdk))

l2 :: NatLens' V2 Identity
l2 = lens (\(V2 u v) -> Identity $ u*u + v*v)
          (\(V2 u v) (Identity dwdk) -> V2 (2*u*dwdk) (2*v*dwdk))

l3 :: NatLens' V2 Identity
l3 = proc dz -> do
  dy <- l2 -< dz
  dx <- l1 -< dy
  returnA -< dx
To compute l3 x, use l3 as a getter; to compute grad l3 x, use l3 as a setter.
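Concretely, a sketch with Control.Lens in scope; seeding the setter with 1 plays the role of the output cotangent:

value :: Floating a => V2 a -> a
value x = runIdentity (x ^. l3)     -- forward pass: l3 as a getter

gradient :: Floating a => V2 a -> V2 a
gradient x = set l3 (Identity 1) x  -- reverse pass: l3 as a setter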
Note that when using proc notation, we plumb the cotangent covectors around, not the normal values.
This can be extended to work with accelerate in a straightforward way: just write a slightly different NatLens type alias that uses Exp or Acc as appropriate!
We can also maintain some sharing if we don't use the lens combinator, and write our lenses by hand.
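For example, in exp (u*v) the primal result is exactly the factor the pullback needs, so a hand-written lens can compute it once (my own illustration):

expProd :: NatLens' V2 Identity
expProd k (V2 u v) =
  let e = exp (u*v)  -- shared between the primal value and both partials
  in fmap (\(Identity dwdk) -> V2 (v*e*dwdk) (u*e*dwdk)) (k (Identity e))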
Conclusion
This is a very nice start to a library for efficient and user-extensible semi-automatic differentiation.
Abandoning type safety for arrow/proc notation is critical to making this generally useful, but unfortunate.
Gelisam's category-syntax library could be a good starting point for a type-safe quasiquoter.
In a practical library, the FADT and RADT types, and the NatLens type, should have a way to accumulate constraints that extend Num on the forall'd type
newtype FADT c m f g = FADT { getFADT :: forall a. c a => f a -> m (g a, f a -> m (g a)) }
newtype RADT c m f g = RADT { getRADT :: forall a. c a => f a -> m (g a, g a -> m (f a)) }
since we can't even write a Floating instance right now! This will likely require the constraints package, since we want to say something like "this works for any constraint c that entails Floating."
ADDENDUM: Unfortunately, after implementing this approach in some more depth, I discovered that the lens-with-arrow-notation approach ends up having some serious usability issues.
For nontrivial code, it is extremely confusing to replace all arguments to your functions with cotangent vectors, and the type errors are just awful.
In my opinion, the best way to implement an AD library backed by accelerate would be to stick to the type in Conal's paper, and ensure all operations are natural "by hand."
The library would use Backpack to provide a signature that contains a subset of Accelerate functions, which would no longer take normal arrays as arguments; they would take RAD a b arguments, for suitable a and b.
As an example, this is what a zipWith combinator might look like (using lists instead of accelerate arrays for now):
newtype RAD a b = RAD { getRAD :: a -> (b, b -> a) }

type Tuple a = Product Identity Identity a

pattern Tuple :: a -> a -> Tuple a
pattern Tuple a b = Pair (Identity a) (Identity b)

zipWithA
  :: Num a
  => (forall s. Reifies s Tape => Tuple (Reverse s a) -> Reverse s a)
  -> RAD b [a]
  -> RAD c [a]
  -> RAD (b, c) [a]
zipWithA f (RAD g) (RAD h) = RAD $ \(u, v) ->
  let (x, dx) = g u
      (y, dy) = h v
      (z, dudv) = unzip $ zipWith (\a b -> grad' f (Tuple a b)) x y
      dz' dz =
        let (du, dv) = unzip $ fmap (\(Tuple u' v') -> (u', v')) dudv
        in (dx $ zipWith (*) du dz, dy $ zipWith (*) dv dz)
  in (z, dz')
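For instance, an elementwise product combinator falls out in one line:

prodA :: Num a => RAD b [a] -> RAD c [a] -> RAD (b, c) [a]
prodA = zipWithA (\(Tuple a b) -> a * b)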
To demonstrate these ideas, I will write an implementation of the No U-Turn Sampler backed by accelerate.
Footnotes:
1. The main idea of this post started as an alternate solution to the problems described in section 15 of the simple essence paper, entitled "Scaling Up". However, I eventually realized that the core D a b type is what needed to be generalized to make an accelerate implementation possible.