
This is a brief note on something I really wish I had realized a couple of years ago. My GSoC project didn't quite work out the way I hoped it would, but it seems that it turned into someone's master's thesis. Last week, I realized that my failed attempt to translate the ideas of Conal Elliott's "The Simple Essence of Automatic Differentiation" to the setting of accelerate was not in vain (see footnote 1). If you're willing to lean on the existing ad package, you can end up with a very nice framework for composing functions along with their vector-Jacobian and Jacobian-vector products. The formalism is flexible enough to work with (monadic) code generation and accelerate, which are both important to me. Critically, it allows the implicit "reverse mode AD tape" of a computation like map f v to have a single node, rather than one node for each element of the vector, and to hold containers full of unboxed IEEE floats rather than boxed values containing functions. The node reduction is a major performance win, and an accelerate implementation would be impossible without unboxed types.

Types from the paper

Let's revisit the first type of the simple essence paper.

newtype D a b = D (a -> (b, a -+> b))
type a -+> b = a -> b -- a linear map which yields the derivative

Over the course of the paper, Elliott generalizes this to the following type

newtype D k a b = D (a -> (b, k a b))

where k is instantiated to several concrete categories, yielding forward mode and reverse mode automatic differentiation.
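To make the generalized type concrete, here is a small sketch (not code from the paper) that instantiates k to ordinary functions (->), so the "linear map" is just a Haskell function. The names runD, sinD, sqrD and deriv are made up for illustration; the point is that the Category instance for D k is exactly the chain rule.

```haskell
import Prelude hiding (id, (.))
import Control.Category

-- Elliott's generalized type: a function paired with its derivative,
-- where the derivative lives in some category k of linear maps.
newtype D k a b = D { runD :: a -> (b, k a b) }

-- Composition is the chain rule: the derivative of (g . f) at a is the
-- derivative of g at (f a), composed after the derivative of f at a.
instance Category k => Category (D k) where
  id = D (\a -> (a, id))
  D g . D f = D $ \a ->
    let (b, f') = f a
        (c, g') = g b
    in (c, g' . f')

-- Hypothetical primitives with k = (->): the linear map multiplies a
-- tangent by the local derivative.
sinD :: D (->) Double Double
sinD = D $ \x -> (sin x, \dx -> cos x * dx)

sqrD :: D (->) Double Double
sqrD = D $ \x -> (x * x, \dx -> 2 * x * dx)

-- Derivative of a scalar function: apply the linear map to the tangent 1.
deriv :: D (->) Double Double -> Double -> Double
deriv (D f) x = snd (f x) 1
```

Here deriv (sqrD . sinD) x computes 2 sin x cos x, i.e. sin (2x) up to rounding.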

The trick

The proper setting for automatic differentiation is not Hask, where the objects are types and the morphisms are functions, but the functor category on Hask, where the objects are functors and the morphisms are natural (well, parametric) transformations between them. The following types move us into the functor category setting.

-- push forward a tangent vector, aka Jacobian-vector product
newtype FAD f g = FAD {
  getFAD :: forall a. Num a => f a -> (g a, f a -> g a)
}

-- pull back a cotangent covector, aka vector-Jacobian product
newtype RAD f g = RAD {
  getRAD :: forall a. Num a => f a -> (g a, g a -> f a)
}

Once again, the second element of the tuple is a linear map, but the linear map is now the entire Jacobian!
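As a minimal sketch of what these types hold, assuming only base (a hand-rolled V2 stands in for linear's, and hFwd, hRev and gradient are illustrative names), here is h (x, y) = x*x + x*y in both modes. Its Jacobian is the 1x2 matrix [2x + y, x]: forward mode multiplies a tangent by it, reverse mode multiplies a cotangent by its transpose.

```haskell
{-# LANGUAGE RankNTypes #-}
import Data.Functor.Identity (Identity (..))

-- A minimal two-element container (a stand-in for linear's V2).
data V2 a = V2 a a deriving (Show, Eq)

newtype FAD f g = FAD { getFAD :: forall a. Num a => f a -> (g a, f a -> g a) }
newtype RAD f g = RAD { getRAD :: forall a. Num a => f a -> (g a, g a -> f a) }

-- Forward mode: push a tangent (dx, dy) through the Jacobian [2x + y, x].
hFwd :: FAD V2 Identity
hFwd = FAD $ \(V2 x y) ->
  ( Identity (x * x + x * y)
  , \(V2 dx dy) -> Identity ((2 * x + y) * dx + x * dy) )

-- Reverse mode: pull a cotangent w back through the transposed Jacobian.
hRev :: RAD V2 Identity
hRev = RAD $ \(V2 x y) ->
  ( Identity (x * x + x * y)
  , \(Identity w) -> V2 ((2 * x + y) * w) (x * w) )

-- The gradient of a scalar-valued function is the pullback of 1.
gradient :: Num a => RAD f Identity -> f a -> f a
gradient r v = snd (getRAD r v) (Identity 1)
```

For example, gradient hRev (V2 3 4) evaluates to V2 10 3. Because of the forall, the same hRev works at Int, Double, or any other Num instance without change.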

This is wonderful and purely functional, but you eventually need effects, if only to use LAPACK or cuBLAS. Thus, the user needs to be able to write their own vector-Jacobian and Jacobian-vector products that contain effects.

newtype FADT m f g = FADT {
  getFADT :: forall a. Num a => f a -> m (g a, f a -> m (g a))
}

newtype RADT m f g = RADT {
  getRADT :: forall a. Num a => f a -> m (g a, g a -> m (f a))
}

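-- The pure FAD and RAD above are the case m = Identity, up to Identity
-- wrappers; these synonyms would replace the earlier newtypes.
type FAD f g = FADT Identity f g
type RAD f g = RADT Identity f g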
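A minimal sketch of a user-written, effectful vector-Jacobian product, assuming only base: the pullback of a hypothetical squaring primitive runs in IO and bumps a counter (a stand-in for something like a LAPACK or cuBLAS call), while the pure version is the same code at m = Identity.

```haskell
{-# LANGUAGE RankNTypes #-}
import Data.Functor.Identity (Identity (..))
import Data.IORef (IORef, modifyIORef', newIORef, readIORef)

newtype RADT m f g = RADT
  { getRADT :: forall a. Num a => f a -> m (g a, g a -> m (f a)) }

-- Hypothetical primitive whose pullback runs in IO. Bumping a counter
-- stands in for a real effect such as a foreign BLAS call.
squareCounted :: IORef Int -> RADT IO Identity Identity
squareCounted calls = RADT $ \(Identity x) ->
  pure ( Identity (x * x)
       , \(Identity w) -> do
           modifyIORef' calls (+ 1)
           pure (Identity (2 * x * w)) )

-- The same primitive with no effects, using m = Identity.
squarePure :: RADT Identity Identity Identity
squarePure = RADT $ \(Identity x) ->
  pure (Identity (x * x), \(Identity w) -> pure (Identity (2 * x * w)))
```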

(By the way: if you're after high performance, and you're computing a function that is a sum of functions that depend on a small fraction of the total variables, try instantiating m as a monad that generates LLVM rather than IO. This works even better if you only need Hessian-vector products, e.g. as part of a truncated Newton-CG optimization routine. If you're implementing a neural network and spend all your time on matrix multiplications, IO is fine.)

A sketch of a library

From now on, I'll only use the RADT type, since all the instances for FADT are substantially similar. We can immediately write a Category instance.

instance Monad m => Category (RADT m) where
  id = RADT $ \fa -> pure (fa, pure)
  RADT g . RADT f = RADT $ \fa -> do
    (ga, dg) <- f fa
    (ha, dh) <- g ga
    pure (ha, dg <=< dh)
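To check that this instance really is the chain rule, here is a self-contained sketch that repeats the instance and composes two hypothetical scalar primitives, scaleR and squareR; valueAndDeriv is an illustrative helper that pulls back the cotangent 1.

```haskell
{-# LANGUAGE RankNTypes #-}
import Prelude hiding (id, (.))
import Control.Category
import Control.Monad ((<=<))
import Data.Functor.Identity (Identity (..))

newtype RADT m f g = RADT
  { getRADT :: forall a. Num a => f a -> m (g a, g a -> m (f a)) }

-- Run f then g forward; pull back through g's pullback, then f's.
instance Monad m => Category (RADT m) where
  id = RADT $ \fa -> pure (fa, pure)
  RADT g . RADT f = RADT $ \fa -> do
    (ga, dg) <- f fa
    (ha, dh) <- g ga
    pure (ha, dg <=< dh)

-- Hypothetical scalar primitives.
scaleR :: Monad m => Integer -> RADT m Identity Identity
scaleR c = RADT $ \(Identity x) ->
  let k = fromInteger c
  in pure (Identity (k * x), \(Identity w) -> pure (Identity (k * w)))

squareR :: Monad m => RADT m Identity Identity
squareR = RADT $ \(Identity x) ->
  pure (Identity (x * x), \(Identity w) -> pure (Identity (2 * x * w)))

-- Value and derivative (the pullback of 1) of a scalar RADT.
valueAndDeriv :: (Monad m, Num a) => RADT m Identity Identity -> a -> m (a, a)
valueAndDeriv r x = do
  (y, back) <- getRADT r (Identity x)
  dx <- back (Identity 1)
  pure (runIdentity y, runIdentity dx)
```

For instance, runIdentity (valueAndDeriv (squareR . scaleR 3) 1) is (9, 18): the value of (3x)^2 at x = 1 and its derivative 18x.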

Since we're working in the category of "natural" transformations, we're not going to be able to write many other instances with normal Haskell typeclasses. However, we can write a higher-kinded Strong instance, albeit without the profunctor instance we would like.

-- HStrong should really have a FProfunctor p constraint, but I don't think
-- it's possible to write one with this encoding.
class HStrong p where
  ffirst :: p f g -> p (Product f h) (Product g h)
  fsecond :: p f g -> p (Product h f) (Product h g)

instance Monad m => HStrong (RADT m) where
  -- The untouched component and its cotangent pass straight through.
  ffirst (RADT d) = RADT $ \(Pair fa ha) -> do
    (ga, dg) <- d fa
    pure (Pair ga ha, \(Pair dg' dh') -> flip Pair dh' <$> dg dg')
  fsecond (RADT d) = RADT $ \(Pair ha fa) -> do
    (ga, dg) <- d fa
    pure (Pair ha ga, \(Pair dh' dg') -> Pair dh' <$> dg dg')
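A small self-contained sketch of HStrong in use, assuming only base: ffirst lifts a hypothetical scalar primitive cubeR to act on the first half of a Product, and the cotangent on the second half is passed through unchanged by the pullback.

```haskell
{-# LANGUAGE RankNTypes #-}
import Data.Functor.Identity (Identity (..))
import Data.Functor.Product (Product (..))

newtype RADT m f g = RADT
  { getRADT :: forall a. Num a => f a -> m (g a, g a -> m (f a)) }

class HStrong p where
  ffirst :: p f g -> p (Product f h) (Product g h)
  fsecond :: p f g -> p (Product h f) (Product h g)

-- The untouched component and its cotangent pass straight through.
instance Monad m => HStrong (RADT m) where
  ffirst (RADT d) = RADT $ \(Pair fa ha) -> do
    (ga, dg) <- d fa
    pure (Pair ga ha, \(Pair dg' dh') -> flip Pair dh' <$> dg dg')
  fsecond (RADT d) = RADT $ \(Pair ha fa) -> do
    (ga, dg) <- d fa
    pure (Pair ha ga, \(Pair dh' dg') -> Pair dh' <$> dg dg')

-- A hypothetical scalar primitive: x maps to x^3.
cubeR :: Monad m => RADT m Identity Identity
cubeR = RADT $ \(Identity x) ->
  pure (Identity (x * x * x), \(Identity w) -> pure (Identity (3 * x * x * w)))
```

Applying ffirst cubeR to Pair (Identity 2) (Identity 5) gives Pair (Identity 8) (Identity 5), and pulling back Pair (Identity 1) (Identity 1) gives Pair (Identity 12) (Identity 1).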

We can also write the moral equivalent of arr, with many extra constraints to satisfy the ad and linear libraries.

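liftRADT
  :: (Additive f, Traversable f, Additive g,
      Foldable g, Functor g, Monad m)
  -- the scalar type must be quantified here too, so that it can be
  -- instantiated to whatever Num type RADT's own forall picks
  => (forall s a. (Reifies s Tape, Num a)
      => f (Reverse s a) -> g (Reverse s a))
  -> RADT m f g
liftRADT f = RADT $ \fa ->
  let j = jacobian' f fa   -- one (value, gradient row) pair per output
      ga = fmap fst j
      dg = fmap snd j
  in pure (ga, pure . (*! dg))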
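Here jacobian' returns one (value, gradient) pair per output, so dg holds the rows of the Jacobian and the pullback contracts an output cotangent against those rows. The list-based sketch below spells out that contraction, plus the forward-mode counterpart, without the ad or linear packages; vjp and jvp are illustrative names, not functions from either library.

```haskell
-- Vector-Jacobian product: contract a cotangent on the outputs with a
-- Jacobian stored as one gradient row per output.
vjp :: Num a => [a] -> [[a]] -> [a]
vjp ws rows = foldr (zipWith (+)) zeros (zipWith scale ws rows)
  where
    scale w = map (w *)
    -- one zero per input (empty if there are no outputs)
    zeros = map (const 0) (concat (take 1 rows))

-- Jacobian-vector product: dot each output's gradient row with a tangent.
jvp :: Num a => [[a]] -> [a] -> [a]
jvp rows v = map (sum . zipWith (*) v) rows
```

The two are adjoint: for any cotangent w and tangent v, the dot product of w with jvp j v equals the dot product of vjp w j with v.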

Ideally, we could use rebindable syntax to take advantage of arrow notation, but the preprocessor generates code with tuples, which are not the categorical product type we care about here. (We would want to use Data.Functor.Product instead.) In any case, we can write what I think should be a higher-kinded indexed monad instance.

class HIndexedMonad (m :: (* -> *) -> (* -> *) -> *) where
  freturn :: (forall a. f a -> m f f)
  fbind :: m f g -> (forall a. f a -> m g h) -> m f h

instance Monad m => HIndexedMonad (RADT m) where
  freturn f = RADT $ \fa -> pure (fa, pure)
  fbind (RADT f) m = RADT $ \fa -> do
    let RADT h = m fa
    (ga, dg) <- f fa
    (ha, dh) <- h ga
    pure (ha, dh >=> dg)

but this isn't very useful, since the forall guarantees we can't do anything interesting with the variables we bind with QualifiedDo notation.

What do some basic instances that use real numbers instead of category theory look like in this setting?

instance (Additive f, Monad m) => Num (RADT m f Identity) where
  fromInteger n = RADT $ \_ ->
    let a = Identity (fromInteger n)
        da = pure . const zero   -- constants have zero gradient
    in pure (a, da)

  RADT f + RADT g = RADT $ \z -> do
    (u,dudw) <- f z
    (v,dvdw) <- g z
    let w = u + v
        dw dw' = do
          du' <- dudw dw'
          dv' <- dvdw dw'
          pure $ du' ^+^ dv'
    pure (w, dw)

  RADT f - RADT g = RADT $ \z -> do
    (u,dudw) <- f z
    (v,dvdw) <- g z
    let w = u - v
        dw dw' = do
          du' <- dudw dw'
          dv' <- dvdw dw'
          pure $ du' ^-^ dv'
    pure (w, dw)

  RADT f * RADT g = RADT $ \z -> do
    (u,dudw) <- f z
    (v,dvdw) <- g z
    let w = u * v
        dw dw' = do
          du' <- dudw dw'
          dv' <- dvdw dw'
          pure $ du' ^* runIdentity v ^+^ dv' ^* runIdentity u
    pure (w, dw)

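  abs (RADT f) = RADT $ \z -> do
    (u, dudw) <- f z
    let w = abs u
        dw dw' = do
          du' <- dudw dw'
          pure $ du' ^* signum (runIdentity u)
    pure (w, dw)

  -- signum is piecewise constant, so its gradient is zero almost everywhere
  signum (RADT f) = RADT $ \z -> do
    (u, _) <- f z
    pure (signum u, pure . const zero)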
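The instance above is written against linear's Additive. As a self-contained sketch, the version below swaps in a tiny hypothetical Additive class with only zero and (^+^), defines negate and signum as well, and then uses the instance to differentiate a polynomial written with ordinary arithmetic. The names x and valueAndDeriv are made up for illustration.

```haskell
{-# LANGUAGE FlexibleInstances, RankNTypes #-}
import Data.Functor.Identity (Identity (..))

newtype RADT m f g = RADT
  { getRADT :: forall a. Num a => f a -> m (g a, g a -> m (f a)) }

-- Hypothetical stand-in for the members of linear's Additive class that
-- the instance needs, so this sketch depends only on base.
class Functor f => Additive f where
  zero :: Num a => f a
  (^+^) :: Num a => f a -> f a -> f a

infixl 6 ^+^

instance Additive Identity where
  zero = Identity 0
  Identity a ^+^ Identity b = Identity (a + b)

-- Scale a vector by a scalar.
(^*) :: (Functor f, Num a) => f a -> a -> f a
v ^* s = fmap (* s) v

infixl 7 ^*

instance (Additive f, Monad m) => Num (RADT m f Identity) where
  -- Constants have zero gradient.
  fromInteger n = RADT $ \_ -> pure (Identity (fromInteger n), pure . const zero)
  -- Sum rule: add the pulled-back cotangents.
  RADT f + RADT g = RADT $ \z -> do
    (u, du) <- f z
    (v, dv) <- g z
    pure (u + v, \w -> (^+^) <$> du w <*> dv w)
  -- Product rule: scale each pullback by the other factor's value.
  RADT f * RADT g = RADT $ \z -> do
    (u, du) <- f z
    (v, dv) <- g z
    let back w = do
          du' <- du w
          dv' <- dv w
          pure (du' ^* runIdentity v ^+^ dv' ^* runIdentity u)
    pure (u * v, back)
  negate (RADT f) = RADT $ \z -> do
    (u, du) <- f z
    pure (negate u, fmap (fmap negate) . du)
  abs (RADT f) = RADT $ \z -> do
    (u, du) <- f z
    pure (abs u, fmap (^* signum (runIdentity u)) . du)
  signum (RADT f) = RADT $ \z -> do
    (u, _) <- f z
    pure (signum u, pure . const zero)

-- The input variable: the identity on a one-element container.
x :: Monad m => RADT m Identity Identity
x = RADT $ \z -> pure (z, pure)

-- Evaluate a scalar expression and its derivative (the pullback of 1).
valueAndDeriv :: Num a => RADT Identity Identity Identity -> a -> (a, a)
valueAndDeriv r a = case runIdentity (getRADT r (Identity a)) of
  (y, back) -> (runIdentity y, runIdentity (runIdentity (back (Identity 1))))
```

With these definitions, valueAndDeriv (x * x + 3 * x) 2 is (10, 7), matching d/dx (x^2 + 3x) = 2x + 3.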

This is not a terribly useful library; since we can't use proc or do notation, we have to write everything in a point-free style. Let's fix this problem.

van Laarhoven lenses

There is hardly any code to be written once we realize that RAD is just a lens with an extra forall: a function f a -> (g a, g a -> f a) packages exactly the getter and setter of a Lens' (f a) (g a). If we are willing to abandon some type safety by replacing Data.Functor.Product with regular tuples, we can even use arrow notation (assuming we only use a NatLens between the arrow tip and tail that computes derivatives, and only pack arguments into tuples on the right-hand side of the tail).

type NatLens f g h k = forall a. Floating a => Lens (f a) (g a) (h a) (k a)
type NatLens' f g = NatLens f f g g

l1 :: NatLens' V2 V2
l1 = lens (\(V2 u v) -> V2 (sin u) (cos v))
          (\(V2 u v) (V2 dudk dvdk) ->
             V2 (cos u*dudk) (negate $ sin v*dvdk))

l2 :: NatLens' V2 Identity
l2 = lens (\(V2 u v) -> Identity $ u*u + v*v)
          (\(V2 u v) (Identity dwdk) -> V2 (2*u*dwdk) (2*v*dwdk))

l3 :: NatLens' V2 Identity
l3 = proc dz -> do
  dy <- l2 -< dz
  dx <- l1 -< dy
  returnA -< dx

To compute l3 x, use l3 as a getter; to compute grad l3 x, use l3 as a setter. Note that when using proc notation, we plumb the cotangent covectors around, not the normal values.
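As a base-only sketch of this getter/setter reading (assuming a hand-rolled V2, and minimal lens, view and set in place of the lens package's more general versions), the block below reads the value through l3 and writes the cotangent 1 through it to get the gradient. Here l3 is written as l1 . l2, which is what the proc block amounts to for the function arrow.

```haskell
{-# LANGUAGE RankNTypes #-}
import Data.Functor.Const (Const (..))
import Data.Functor.Identity (Identity (..))

-- Minimal van Laarhoven lens machinery so the sketch only needs base;
-- the lens package provides more general versions of these.
type Lens s t a b = forall f. Functor f => (a -> f b) -> s -> f t

lens :: (s -> a) -> (s -> b -> t) -> Lens s t a b
lens getter setter k s = setter s <$> k (getter s)

view :: ((a -> Const a a) -> s -> Const a s) -> s -> a
view l s = getConst (l Const s)

set :: ((a -> Identity b) -> s -> Identity t) -> b -> s -> t
set l b s = runIdentity (l (const (Identity b)) s)

data V2 a = V2 a a deriving (Show, Eq)

type NatLens' f g = forall a. Floating a => Lens (f a) (f a) (g a) (g a)

l1 :: NatLens' V2 V2
l1 = lens (\(V2 u v) -> V2 (sin u) (cos v))
          (\(V2 u v) (V2 dudk dvdk) -> V2 (cos u * dudk) (negate (sin v * dvdk)))

l2 :: NatLens' V2 Identity
l2 = lens (\(V2 u v) -> Identity (u * u + v * v))
          (\(V2 u v) (Identity dwdk) -> V2 (2 * u * dwdk) (2 * v * dwdk))

-- Plain lens composition, equivalent to the proc block for l3.
l3 :: NatLens' V2 Identity
l3 = l1 . l2
```

For p = V2 u v, view l3 p is Identity (sin u * sin u + cos v * cos v), and set l3 (Identity 1) p is V2 (2 * sin u * cos u) (-2 * sin v * cos v), the gradient.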

This can be extended to work with accelerate in a straightforward way: just write a slightly different NatLens type alias that uses Exp or Acc as appropriate! We can also maintain some sharing if we avoid the lens combinator and write our lenses by hand, since lens receives the getter and setter as separate functions, so any work they have in common is done twice.
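A sketch comparing the two styles on a Euclidean norm lens, where both the value and the pullback need r = sqrt (u*u + v*v). The scaffolding is the same hypothetical base-only lens machinery as before; normLens and normLensShared are illustrative names.

```haskell
{-# LANGUAGE RankNTypes #-}
import Data.Functor.Const (Const (..))
import Data.Functor.Identity (Identity (..))

type Lens s t a b = forall f. Functor f => (a -> f b) -> s -> f t

lens :: (s -> a) -> (s -> b -> t) -> Lens s t a b
lens getter setter k s = setter s <$> k (getter s)

view :: ((a -> Const a a) -> s -> Const a s) -> s -> a
view l s = getConst (l Const s)

set :: ((a -> Identity b) -> s -> Identity t) -> b -> s -> t
set l b s = runIdentity (l (const (Identity b)) s)

data V2 a = V2 a a deriving (Show, Eq)

type NatLens' f g = forall a. Floating a => Lens (f a) (f a) (g a) (g a)

-- Via the lens combinator: the getter and the setter each compute r.
normLens :: NatLens' V2 Identity
normLens = lens (\(V2 u v) -> Identity (sqrt (u * u + v * v)))
                (\(V2 u v) (Identity w) ->
                   let r = sqrt (u * u + v * v) in V2 (u / r * w) (v / r * w))

-- Written by hand: r is computed once, used as the focus, and captured by
-- the pullback that rebuilds the gradient.
normLensShared :: NatLens' V2 Identity
normLensShared k (V2 u v) =
  let r = sqrt (u * u + v * v)
  in (\(Identity w) -> V2 (u / r * w) (v / r * w)) <$> k (Identity r)
```

Both lenses give the same results, e.g. set normLensShared (Identity 1) (V2 3 4) is V2 0.6 0.8; only the amount of repeated work in the source differs, and how much of it survives depends on the compiler or backend.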

Conclusion

This is a very nice start to a library for efficient and user-extensible semi-automatic differentiation. Abandoning type safety for arrow/proc notation is unfortunate, but critical to making this generally useful. Gelisam's category-syntax library could be a good starting point for a type-safe quasiquoter. In a practical library, the FADT and RADT types and the NatLens type should have a way to accumulate constraints that extend Num on the forall'd type:

newtype FADT c m f g = FADT {
  getFADT :: forall a. c a => f a -> m (g a, f a -> m (g a))
}

newtype RADT c m f g = RADT {
  getRADT :: forall a. c a => f a -> m (g a, g a -> m (f a))
}

since we can't even write a Floating instance right now! This will likely require the constraints package, since we want to say something like "this works for any constraint c that entails Floating."
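A sketch of the constraint-parameterized RADT, assuming the ConstraintKinds extension: fixing c = Floating makes a sin primitive expressible, and composition works uniformly in c. It sidesteps the harder problem of combining primitives whose constraints entail one another, which is where something like the constraints package would come in; sinR and andThen are hypothetical names.

```haskell
{-# LANGUAGE ConstraintKinds, KindSignatures, RankNTypes #-}
import Data.Functor.Identity (Identity (..))
import Data.Kind (Constraint, Type)

-- RADT with the constraint on the forall'd scalar type as a parameter.
newtype RADT (c :: Type -> Constraint) m f g = RADT
  { getRADT :: forall a. c a => f a -> m (g a, g a -> m (f a)) }

-- With c = Floating, primitives may use sin and cos.
sinR :: Monad m => RADT Floating m Identity Identity
sinR = RADT $ \(Identity x) ->
  pure (Identity (sin x), \(Identity w) -> pure (Identity (cos x * w)))

-- Composition for any c: run f then g forward, and pull back through
-- g's pullback first, then f's.
andThen :: Monad m => RADT c m f g -> RADT c m g h -> RADT c m f h
andThen (RADT f) (RADT g) = RADT $ \fa -> do
  (ga, pullF) <- f fa
  (ha, pullG) <- g ga
  pure (ha, \hbar -> pullG hbar >>= pullF)
```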

ADDENDUM: Unfortunately, after implementing this approach in more depth, I discovered that the lens-with-arrow-notation approach has some serious usability issues. For nontrivial code, it is extremely confusing to replace all arguments to your functions with cotangent vectors, and the type errors are just awful. In my opinion, the best way to implement an AD library backed by accelerate would be to stick to the type in Conal's paper, and ensure all operations are natural "by hand." The library would use Backpack to provide a signature containing a subset of Accelerate's functions, which would no longer take normal arrays as arguments; they would take RAD a b arguments, for suitable a and b. As an example, this is what a zipWith combinator might look like (using lists instead of accelerate arrays for now):

newtype RAD a b = RAD { getRAD :: a -> (b, b -> a) }

type Tuple a = Product Identity Identity a

pattern Tuple :: a -> a -> Tuple a
pattern Tuple a b = Pair (Identity a) (Identity b)

zipWithA
  :: Num a
  => (forall s. Reifies s Tape => Tuple (Reverse s a) -> Reverse s a)
  -> RAD b [a]
  -> RAD c [a]
  -> RAD (b,c) [a]
zipWithA f (RAD g) (RAD h) = RAD $ \(u, v) ->
  let (x, dx) = g u
      (y, dy) = h v
      (z, dudv) = unzip $ zipWith (\a b -> grad' f (Tuple a b)) x y
      dz' dz = let (du,dv) = unzip $ fmap (\(Tuple u' v') -> (u', v')) dudv
               in (dx $ zipWith (*) du dz, dy $ zipWith (*) dv dz)
  in (z, dz')
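The combinator above relies on ad's grad'. To see the shape of the pullback without that dependency, here is a hypothetical base-only variant, zipWithD, where the caller passes each element's value together with its two partial derivatives; otherwise the pullback mirrors zipWithA's.

```haskell
newtype RAD a b = RAD { getRAD :: a -> (b, b -> a) }

-- Hypothetical variant of zipWithA: the caller supplies each element's
-- value together with its two partial derivatives, in place of grad'.
zipWithD
  :: Num a
  => (a -> a -> (a, (a, a)))
  -> RAD b [a]
  -> RAD c [a]
  -> RAD (b, c) [a]
zipWithD f (RAD g) (RAD h) = RAD $ \(u, v) ->
  let (xs, dx) = g u
      (ys, dy) = h v
      (zs, partials) = unzip (zipWith f xs ys)
      (du, dv) = unzip partials
      -- Scale the incoming cotangent by each partial, then pull back further.
      back dz = (dx (zipWith (*) du dz), dy (zipWith (*) dv dz))
  in (zs, back)

-- A leaf: the identity on a list.
idRAD :: RAD [a] [a]
idRAD = RAD $ \xs -> (xs, id)

-- Elementwise multiplication with its partials (d/da = b, d/db = a).
mulD :: Num a => a -> a -> (a, (a, a))
mulD a b = (a * b, (b, a))
```

For example, getRAD (zipWithD mulD idRAD idRAD) ([1, 2, 3], [4, 5, 6]) has value [4, 10, 18], and pulling back [1, 1, 1] gives ([4, 5, 6], [1, 2, 3]).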

To demonstrate these ideas, I will write an implementation of the No U-Turn Sampler backed by accelerate.

Footnotes:

1

The main idea of this post started as an alternate solution to the problems described in section 15 of the simple essence paper, entitled "Scaling Up". However, I eventually realized that the core D a b type is what needed to be generalized to make an accelerate implementation possible.