This summer, I worked on a Google Summer of Code project intended to implement automatic differentiation for Haskell's accelerate vector programming EDSL, using a new method of automatic differentiation. To make a long story short, the differentiation part worked out even better than I hoped, but the automatic part didn't, for a variety of reasons that I'll detail in this post.

Let's start with the good stuff, however. The paper for the method that I was using is still unfinished, so I can't say for sure, but I believe I worked out the punchline (with the help of a similar paper by Conal Elliott): in a purely functional setting, the following combinators are all the infrastructure you need to support vectorized, parallel AD capable of arbitrary mixed-mode computation.

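-- matrixMultiply a b computes the matrix product of a and b
scale a = \da -> a*da                     -- scalar case: multiply the tangent by a
compForward x = (matrixMultiply x .)      -- run the rest of the chain, then multiply by x on the left
compReverse x = (. flip matrixMultiply x) -- multiply by x on the right, then run the rest of the chain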

(Of course, you need derivative information for primitives as well.)

Differentiation…

How does this work? Let's start with forward-mode AD for scalar expressions. The usual way to do forward-mode AD is with dual numbers:

data D a = D a a  -- a value paired with its derivative

instance Num a => Num (D a) where
  D f f' + D g g' = D (f+g) (f'+g')
  D f f' * D g g' = D (f*g) (f'*g + f*g')
  -- etc.
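Filling in the "etc." gives a version that loads into GHCi. The remaining Num methods and the hypothetical diff helper below are an illustrative sketch rather than part of the original code, and abs and signum simply pick a convention at 0, where they aren't differentiable.

```haskell
-- Dual numbers: a value paired with its derivative.
data D a = D a a deriving Show

instance Num a => Num (D a) where
  D f f' + D g g' = D (f + g) (f' + g')
  D f f' - D g g' = D (f - g) (f' - g')
  D f f' * D g g' = D (f * g) (f' * g + f * g')  -- product rule
  negate (D f f') = D (negate f) (negate f')
  abs (D f f')    = D (abs f) (signum f * f')    -- convention at 0: derivative 0
  signum (D f _)  = D (signum f) 0
  fromInteger n   = D (fromInteger n) 0          -- constants have zero derivative

-- Hypothetical helper: seed the derivative slot with 1, read off (f x, f' x).
diff :: Num a => (D a -> D a) -> a -> (a, a)
diff f x = let D y y' = f (D x 1) in (y, y')
```

For example, diff (\x -> x*x + 3*x) 2 evaluates to (10,7): the value 4 + 6 and the derivative 2*2 + 3.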

Instead of keeping track of the derivative with a number, we can keep track of it with a linear function; in mathematical terms, we've moved from the usual derivative to the Fréchet derivative.

-- (-+>) is a linear function (defining it needs TypeOperators)
type (-+>) a b = a -> b
newtype D a b = D (a -> (b, a -+> b))

In this setting, the automatically differentiated version of square x = x*x would be square = D (\x -> (x*x, scale (2*x))). We can also share work between the function and its derivative; exp is its own derivative, so its value only needs to be computed once:

expD = D (\x -> let y = exp x in (y, scale y))  -- named expD so the inner exp is Prelude's

We can compose D functions with their Category instance (an exercise for the reader). Once we're done composing them, we can get a numerical answer in a straightforward way:

withDerivative (D f) x = let (y, f') = f x in (y, f' 1.0)
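One possible answer to the exercise, following the construction in Conal Elliott's paper: the identity pairs each point with the identity linear map, and composition applies the chain rule by composing the two linear maps. Linear maps are plain functions in this sketch, and the names squareD and expD are chosen here only to avoid clashing with Prelude's exp.

```haskell
{-# LANGUAGE TypeOperators #-}
import Prelude hiding (id, (.))
import Control.Category

-- (-+>) stands for a linear map; in this sketch it is just an ordinary function.
type a -+> b = a -> b

-- A differentiable function: maps a point to its value and its derivative there.
newtype D a b = D (a -> (b, a -+> b))

instance Category D where
  -- The identity's derivative is the identity linear map.
  id = D (\a -> (a, id))
  -- Chain rule: run f, feed its output to g, compose the two derivatives.
  D g . D f = D $ \a ->
    let (b, f') = f a
        (c, g') = g b
    in (c, g' . f')

scale :: Num a => a -> (a -+> a)
scale a = \da -> a * da

squareD :: D Double Double
squareD = D (\x -> (x * x, scale (2 * x)))

expD :: D Double Double
expD = D (\x -> let y = exp x in (y, scale y))

withDerivative :: D Double Double -> Double -> (Double, Double)
withDerivative (D f) x = let (y, f') = f x in (y, f' 1.0)
```

With this instance, withDerivative (expD . squareD) 1 returns exp 1 and 2 * exp 1, the value and derivative of exp (x^2) at 1.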

We can now extend this trick to vector-valued functions by replacing scale x with matrixMultiply x, but there's an added complication: the way the matrix products are associated has a large impact on performance. Suppose we have some Jacobians already computed in a staged or lazy way:

import qualified Data.Array.Accelerate as A

-- evaluated at g(h(j(k(x))))
f' = A.use $ A.fromList (A.Z A.:. (2::Int) A.:. (10::Int)) [0::Double ..]
-- evaluated at h(j(k(x)))
g' = A.use $ A.fromList (A.Z A.:. (10::Int) A.:. (3::Int)) [0::Double ..]
-- evaluated at j(k(x))
h' = A.use $ A.fromList (A.Z A.:. (3::Int) A.:. (50::Int)) [0::Double ..]
-- evaluated at k(x)
j' = A.use $ A.fromList (A.Z A.:. (50::Int) A.:. (80::Int)) [0::Double ..]
-- evaluated at x
k' = A.use $ A.fromList (A.Z A.:. (80::Int) A.:. (100::Int)) [0::Double ..]

and we want to compute the Jacobian of \(f \circ g \circ h \circ j \circ k\). The chain rule tells us what the answer is: we just matrix-multiply everything. It doesn't tell us how to do that efficiently, however. If we have many inputs and few outputs, we should multiply from left to right (reverse mode); otherwise, going from right to left (forward mode) is a better choice. Here k takes 100 inputs and f produces only 2 outputs, so reverse mode should win. We can do this elegantly with the compForward and compReverse combinators. (Note the analogy between passing id in the matrix case and 1.0 in the scalar case from earlier, and notice which matrix gets passed as the argument.)

λ> run $ (compForward f' 
          . compForward g'
          . compForward h'
          . compForward j' $ id) k'
Matrix (Z :. 2 :. 100) [...]
(4.70 secs, 8,707,288,416 bytes)

λ> run $ (compReverse g' 
          . compReverse h'
          . compReverse j'
          . compReverse k' $ id) f'
Matrix (Z :. 2 :. 100) [...]
(0.19 secs, 188,243,768 bytes)

Mixed mode doesn't help in this case, but it's easy to support: any time you want to switch modes, just feed the comp train an id and start a new one.

λ> run $ (compForward f' . compForward g' $ id)
       . (compReverse j' . compReverse k' $ id) $ h'
Matrix (Z :. 2 :. 100) [...]
(0.47 secs, 688,244,264 bytes)
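To see why the association matters, the sketch below replaces each matrix by its shape plus a running count of scalar multiplications, so the same compForward/compReverse definitions can be run in plain Haskell without accelerate. The M type, leaf and this matrixMultiply are hypothetical stand-ins (naive multiplication cost only); the shapes match f' through k' above.

```haskell
-- A matrix abstracted to its shape and the number of scalar
-- multiplications spent computing it (a hypothetical stand-in).
data M = M { rows :: Int, cols :: Int, cost :: Int } deriving (Show, Eq)

-- Naive multiplication of an r x c matrix by a c x c' matrix costs r*c*c'.
matrixMultiply :: M -> M -> M
matrixMultiply (M r c x) (M r' c' y)
  | c /= r'   = error "matrixMultiply: shape mismatch"
  | otherwise = M r c' (x + y + r * c * c')

-- The combinators from the top of the post, unchanged.
compForward, compReverse :: M -> (M -> M) -> (M -> M)
compForward x = (matrixMultiply x .)
compReverse x = (. flip matrixMultiply x)

leaf :: Int -> Int -> M
leaf r c = M r c 0

f', g', h', j', k' :: M
f' = leaf 2 10
g' = leaf 10 3
h' = leaf 3 50
j' = leaf 50 80
k' = leaf 80 100

-- f' (g' (h' (j' k'))): multiplies right to left
forwardJ :: M
forwardJ = (compForward f' . compForward g' . compForward h' . compForward j' $ id) k'

-- (((f' g') h') j') k': multiplies left to right
reverseJ :: M
reverseJ = (compReverse g' . compReverse h' . compReverse j' . compReverse k' $ id) f'

-- f' (g' ((h' j') k'))
mixedJ :: M
mixedJ = ((compForward f' . compForward g' $ id) . (compReverse j' . compReverse k' $ id)) h'
```

All three produce a 2 x 100 result, but the forward chain costs 420,000 multiplications, the reverse chain 24,360 and the mixed one 41,000, the same ordering as the timings above.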

That's the differentiation; all that's left is the automatic part.

… of the not-o-matic variety

This was all relatively straightforward (in hindsight, at least), so why didn't I succeed in writing an AD library?

The short answer is "the automatic part."

It's possible to manually write a bunch of D combinators and write everything in point-free style using Category instances, but that gets old fast. You can write half an Arrow instance for D, but arr is up to its usual dirty tricks: it would have to turn an arbitrary Haskell function into a D, and we can't differentiate an arbitrary Haskell function. For this reason, the implementation described in Conal Elliott's paper is intended to be used with his compiling to categories GHC plugin.

Unfortunately, compiler plugin aside, the compiling to categories implementation can't be extended to accelerate, for two reasons. Firstly, the way it deals with vectorization is totally incompatible with the way accelerate works: it relies on a matrix being a vector of vectors, and on a bunch of instances that accelerate can't support. But the real dealbreaker is the separation of the Acc and Exp types in accelerate. In this setting, compReverse and compForward are replaced by a Category instance, and you need two separate category instances, k and l, to generalize Acc and Exp functions respectively. Now say you want to wrap the accelerate function map :: (Exp a -> Exp b) -> Acc (Array sh a) -> Acc (Array sh b) (class constraints omitted) into something like mapC :: a `l` b -> a `k` b. If you try to write mapC f . mapC g, you're hit with ambiguous type errors, and you now have to manually thread Proxy l arguments everywhere.

Since the whole point of automatic differentiation is to avoid threading derivatives around manually, which is mechanical but easy to get wrong, our other choice is to transform accelerate syntax trees so that they thread the AD state around. Unfortunately, Trevor McDonell et al. put a lot of effort and thought into making the accelerate compiler preserve types as a bug-squashing mechanism. They did a very good job of it: it is extremely hard (perhaps impossible) to write code that changes the type of a function's accelerate syntax tree in a generic way.

I'll walk through a few examples of things that are critical for AD but very hard, or impossible, to do generically with accelerate.

Type-changing function transformations

The AST for accelerate array functions has two constructors: Alam and Abody. A function of two arguments is translated to Alam (Alam (Abody accelerateExpression)), where accelerateExpression has type OpenAcc aenv t: basically an AST for expressions that keeps track of the result type t and the environment aenv of array variables. Suppose we've written an optimization f :: OpenAcc aenv t -> OpenAcc aenv t. We'd like to apply it to functions with arbitrarily many arguments, so we write the following:

applyRewriteAcc k (Abody b) = Abody (k b)
applyRewriteAcc k (Alam f)  = Alam (applyRewriteAcc k f)

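This gives a type error: each Alam extends the environment, so the rewrite k has to work at a different environment type at every level. It won't compile unless you give it this type signature and turn on Rank2Types: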

applyRewriteAcc
    :: (forall aenv' t'. OpenAcc aenv' t' -> OpenAcc aenv' t')
    -> PreOpenAfun OpenAcc aenv t
    -> PreOpenAfun OpenAcc aenv t

Great: you can now apply your optimization, and you know it's type preserving! Sadly, this doesn't work if you try to generalize the signature to something that leaves the environment alone but changes the type of the expression. Every change to that type signature I tried that would allow a general type-changing rewrite gave a type error. However, you can get it to work for transformations with somewhat more concrete types (a fixed number of arguments and a "concrete enough" output type).
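The type-preserving case can be reproduced with a toy typed AST. Everything below (Idx, Expr, Fun, simplify, applyRewrite) is a hypothetical miniature, not accelerate's actual OpenAcc/PreOpenAfun, but it needs the same rank-2 signature for the same reason: under each Lam the environment becomes (env, a), so a rewrite fixed to the outer env type would not fit.

```haskell
{-# LANGUAGE GADTs, RankNTypes #-}

-- Typed de Bruijn indices into a nested-tuple environment.
data Idx env t where
  ZeroIdx :: Idx (env, t) t
  SuccIdx :: Idx env t -> Idx (env, s) t

-- Expressions indexed by their environment and result type.
data Expr env t where
  Var :: Idx env t -> Expr env t
  Lit :: Int -> Expr env Int
  Add :: Expr env Int -> Expr env Int -> Expr env Int

-- Functions: each Lam extends the environment, like Alam/Abody.
data Fun env t where
  Body :: Expr env t -> Fun env t
  Lam  :: Fun (env, a) t -> Fun env (a -> t)

-- A type-preserving rewrite: drop additions of a literal zero.
simplify :: Expr env t -> Expr env t
simplify (Add a b) = case (simplify a, simplify b) of
  (Lit 0, b') -> b'
  (a', Lit 0) -> a'
  (a', b')    -> Add a' b'
simplify e = e

-- The rank-2 argument lets k be used at every environment under the lambdas.
applyRewrite :: (forall env' t'. Expr env' t' -> Expr env' t') -> Fun env t -> Fun env t
applyRewrite k (Body b) = Body (k b)
applyRewrite k (Lam f)  = Lam (applyRewrite k f)

-- An evaluator, to check that the rewrite preserves meaning.
prj :: Idx env t -> env -> t
prj ZeroIdx     (_, x)   = x
prj (SuccIdx i) (env, _) = prj i env

eval :: env -> Expr env t -> t
eval env (Var i)   = prj i env
eval _   (Lit n)   = n
eval env (Add a b) = eval env a + eval env b

evalFun :: env -> Fun env t -> t
evalFun env (Body b) = eval env b
evalFun env (Lam f)  = \x -> evalFun (env, x) f

-- Number of nodes in the function body.
bodySize :: Fun env t -> Int
bodySize (Body b) = size b
  where
    size :: Expr env' t' -> Int
    size (Add x y) = 1 + size x + size y
    size _         = 1
bodySize (Lam f) = bodySize f

-- \x y -> (0 + x) + (y + 0)
example :: Fun () (Int -> Int -> Int)
example = Lam (Lam (Body (Add (Add (Lit 0) (Var (SuccIdx ZeroIdx)))
                              (Add (Var ZeroIdx) (Lit 0)))))
```

applyRewrite simplify example still evaluates to 7 at 3 and 4, but its body shrinks from 7 nodes to 3 (just x + y).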

Let's try to write a function that rewrites the AST of an Acc (Vector a) -> Acc (Vector a) and transforms it into a representation of Acc (Vector a) -> Acc (Vector a, Matrix a), which is the type of our AD transformation when the input is a vector. When you pattern match on the function's argument x, certain constructors give you a type error. Indeed, the types of these constructors contradict the type of x, but they can still show up at a lower level in the syntax tree, so you have to repeat almost all the patterns twice.

These constructors are a small minority, so you can try to push on with automatic differentiation. The next problem you run into is accelerate's stream fusion, which also doesn't like its types being changed. I believe this can be overcome; I just left it on the table because there were plenty of other things to do that I had no idea how to approach at the time and that didn't require solving the fusion problem. Once you've taken care of that, you can automatically differentiate functions of a single array argument, assuming you know the derivatives of everything and how to compose them.

I think with more time, these problems could have been overcome.

Adding a constant to a syntax tree in a type-preserving manner

At one point, I got stuck on the problem of adding a constant to an accelerate syntax tree (in a type-preserving manner, so nothing in the previous section applied). This is the code that Trevor sent back, and I'm grateful he did, because otherwise I'd have been stuck for a very long time.

-- This is kind of a hack? We can traverse the representation of any type
-- down to primitive values in order to get a zero.
--
delta :: forall acc env aenv c. Elt c => Int -> Int -> PreOpenExp acc env aenv c
delta i' j' = Const $ go (eltType (undefined::c))
  where
    go :: TupleType a -> a
    go TypeRunit         = ()
    go (TypeRpair ta tb) = (go ta, go tb)
    go (TypeRscalar t)   = scalar t

    scalar :: ScalarType a -> a
    scalar (SingleScalarType t) = single t
    scalar (VectorScalarType t) = vector t

    vector :: VectorType a -> a
    vector (Vector2Type t) = let x = single t in V2 x x

    single :: SingleType a -> a
    single (NumSingleType    t) = num t
    single (NonNumSingleType t) = nonnum t

    num :: NumType a -> a
    num (IntegralNumType t) | IntegralDict <- integralDict t = if i' == j' then 1 else 0
    num (FloatingNumType t) | FloatingDict <- floatingDict t = if i' == j' then 1 else 0

    nonnum :: NonNumType a -> a
    nonnum = undefined -- uh..?

It turns out accelerate has a generics implementation in it that knows about most of the types it compiles.
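The same trick works outside accelerate: recurse on a GADT that witnesses a type's structure to build a value of that type. Here is a toy version with a hypothetical TypeR witness, far smaller than accelerate's TupleType/ScalarType hierarchy. As with nonnum above, non-numeric leaves have no obvious answer; returning False for Bool is an arbitrary choice made for this sketch.

```haskell
{-# LANGUAGE GADTs #-}

-- A hypothetical type witness: a value of type TypeR a describes the shape of a.
data TypeR a where
  TUnit   :: TypeR ()
  TPair   :: TypeR a -> TypeR b -> TypeR (a, b)
  TInt    :: TypeR Int
  TDouble :: TypeR Double
  TBool   :: TypeR Bool

-- Kronecker delta at every numeric leaf: 1 if i == j, else 0.
delta :: TypeR a -> Int -> Int -> a
delta TUnit         _ _ = ()
delta (TPair ta tb) i j = (delta ta i j, delta tb i j)
delta TInt          i j = if i == j then 1 else 0
delta TDouble       i j = if i == j then 1 else 0
delta TBool         _ _ = False  -- arbitrary choice for a non-numeric leaf
```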

Tensor contractions

accelerate supports multidimensional arrays, not just vectors and matrices. Although the output of a function might be a vector or a matrix, there's a good chance a higher-rank array shows up in an intermediate step. The generalized Jacobian of a function \(f\) between higher-rank arrays looks like
\[ D_{\mu_1\mu_2\cdots \mu_m \nu_1 \nu_2 \cdots \nu_n}(f) = \frac{\partial f_{\mu_1\mu_2\cdots \mu_m}}{\partial x_{\nu_1\nu_2\cdots \nu_n}} \]
To compose two of these functions, we use the chain rule, which contracts over the shared \(\nu\) indices (with \(D(f)\) evaluated at \(g(x)\)):
\[ D_{\mu_1\mu_2\cdots \mu_m \lambda_1 \lambda_2\cdots \lambda_p}(f \circ g) = \sum_{\nu_1,\ldots,\nu_n} D_{\mu_1\mu_2\cdots \mu_m \nu_1 \nu_2 \cdots \nu_n}(f)\, D_{\nu_1\nu_2\cdots \nu_n \lambda_1 \lambda_2 \cdots \lambda_p}(g) = \sum_{\nu_1,\ldots,\nu_n} \frac{\partial f_{\mu_1\mu_2\cdots \mu_m}}{\partial g_{\nu_1\nu_2\cdots \nu_n}} \frac{\partial g_{\nu_1\nu_2\cdots \nu_n}}{\partial x_{\lambda_1\lambda_2\cdots \lambda_p}} \]
This may be possible to do generically in accelerate, but I'm confident it isn't possible without some new primitives. On the other hand, the vectorization of accelerate combinators over inner dimensions might mean that, thanks to sparsity, we don't need full generality. I still need to think about this more, especially for mappings from arrays of rank \(n\) to rank \(n+1\) and from rank \(n\) to rank \(n-1\).
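For dense arrays, the contraction in the chain rule is ordinary matrix multiplication in disguise: flatten the \(\mu\) indices into a row index, the \(\nu\) indices into a shared inner index and the \(\lambda\) indices into a column index. The list-based sketch below illustrates this; the Tensor type and contract are hypothetical, and the sketch ignores both sparsity and how this would map onto accelerate primitives.

```haskell
import Data.List (transpose)

-- A dense tensor stored in row-major order with an explicit shape.
data Tensor = Tensor { shape :: [Int], values :: [Double] } deriving (Show, Eq)

chunksOf :: Int -> [a] -> [[a]]
chunksOf n xs
  | n <= 0 || null xs = []
  | otherwise         = let (a, b) = splitAt n xs in a : chunksOf n b

-- Contract the last n indices of a with the first n indices of b:
--   C_{mu lambda} = sum over nu of A_{mu nu} * B_{nu lambda}.
-- Flattening the mu, nu and lambda blocks turns this into a matrix product.
contract :: Int -> Tensor -> Tensor -> Tensor
contract n (Tensor sa va) (Tensor sb vb)
  | nuA /= nuB = error "contract: contracted dimensions do not match"
  | otherwise  = Tensor (muShape ++ lambdaShape)
                        [ sum (zipWith (*) row col) | row <- rowsA, col <- colsB ]
  where
    (muShape, nuA)     = splitAt (length sa - n) sa
    (nuB, lambdaShape) = splitAt n sb
    rowsA = chunksOf (product nuA) va                      -- one row per mu multi-index
    colsB = transpose (chunksOf (product lambdaShape) vb)  -- one column per lambda multi-index
```

With n = 1 and two matrices this is the usual matrix product; with n = 0 it degenerates to an outer product.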

Differentiation of accelerate combinators

Since accelerate is a vector language, we can't build up differentiation purely from basic scalar operations; we have to differentiate the combinators themselves. (I did this in detail for folds here.) Unfortunately, folds are a misleading example, in that differentiating them gives you the honest generalized Jacobian. In the generic case, you get a sparse tensor in a combinator-specific format. Moreover, the set of combinators in the accelerate AST is not differentially closed (the derivative of a combinator can't always be expressed with the same set of combinators), but you can avoid using most of the ones that break differential closedness. Most of these combinator-specific formats can or should be translatable to standard sparse tensor formats, but I don't know for sure.

On the minus side, accelerate will need to support tensor contraction for all pairs of sparse formats. On the plus side, this implies a very high potential for performance.
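To make the "all pairs of formats" point concrete, here is a list-based sketch with just two hypothetical Jacobian formats: dense, and the diagonal that an elementwise map produces. Even with two formats, composition already needs four cases. This is an illustration only, not how accelerate or the draft paper represents these Jacobians.

```haskell
import Data.List (transpose)

-- Two hypothetical Jacobian formats: a dense matrix (rows = outputs) and
-- the diagonal that an elementwise map produces.
data Jac = Dense [[Double]] | Diagonal [Double] deriving (Show, Eq)

-- Jacobian of map f over xs, given f's derivative: diagonal.
jacMap :: (Double -> Double) -> [Double] -> Jac
jacMap f' xs = Diagonal (map f' xs)

-- Jacobian of a sum (a fold with (+) and 0) over xs: one dense row of ones.
jacSum :: [Double] -> Jac
jacSum xs = Dense [map (const 1) xs]

-- Jacobian of (outer . inner) = J(outer) * J(inner).
-- Every pair of formats needs its own contraction rule.
compose :: Jac -> Jac -> Jac
compose (Diagonal a) (Diagonal b) = Diagonal (zipWith (*) a b)
compose (Dense m)    (Diagonal d) = Dense [zipWith (*) row d | row <- m]           -- scale columns
compose (Diagonal d) (Dense m)    = Dense (zipWith (\s row -> map (s *) row) d m)  -- scale rows
compose (Dense a)    (Dense b)    = Dense [[sum (zipWith (*) r c) | c <- transpose b] | r <- a]
```

For example, the gradient of sum (map (^2) xs) at [1,2,3] is compose (jacSum xs) (jacMap (\x -> 2*x) xs), which gives Dense [[2.0,4.0,6.0]].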

Dealing with tuples

Tuples are easy enough to deal with in expressions, but they make life difficult when they are function arguments. In the generic case, you end up having to compose linear maps that have different ranks in different parts of the tuple. You also have to smoothly join tuple-based matrix/tensor contraction with array-based contraction. Supporting the stencil operations in accelerate will most likely have to be done by converting them to generate expressions.
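A small scalar-times-vector example shows the mismatch. The function scaleVec below is hypothetical, with lists standing in for accelerate arrays; it takes a tuple argument, and its derivative has to combine an n x n piece for the vector component with an n x 1 piece for the scalar component.

```haskell
-- f (v, s) = s * v for a vector v (a list here) and a scalar s.
-- The derivative at (v, s) is one linear map on tuples of tangents, built
-- from pieces of different rank: dv enters through an n x n map (s times
-- the identity) and ds through an n x 1 map (the column v).
scaleVec :: ([Double], Double) -> ([Double], ([Double], Double) -> [Double])
scaleVec (v, s) = (map (s *) v, deriv)
  where
    deriv (dv, ds) = zipWith (+) (map (s *) dv) (map (ds *) v)
```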

Hessians and beyond

I need to think more about how to support higher-order scalar and vector derivatives with this method - there was just not enough time to do so. The abstract of the draft paper I was sent mentioned a solution to this, but that section wasn't written yet.

Conclusion and status report

  • Even though I didn't get an AD library working by the end of the summer, I'm still optimistic about the potential of the method.
  • If you're willing to deal with the syntactic overhead of threading derivatives around manually, this method is already usable today. I got it to work on an extremely simple vector optimization problem in reverse mode (see Demo.hs in the GitHub repository), and I could have extended it to a simple neural network if I had known the answers to the questions in the tensor contraction section and had had more time.
  • I'm extremely close to being able to differentiate simple end-user programs automatically (i.e. ones that just use folds and maps). Stream fusion breakage is all that's left in the way of that milestone.
  • Looking back at the specifics of what I got stuck on, most of it had to do with types and with the fact that I didn't have access to a complete paper. (There was no mention of vector combinator differentiation in the paper, nor of any of the mixed-mode combinators I discussed in the first section.)
  • In my last meeting with my mentors, I said I thought the only way forward was an AST that compiled to accelerate, but going over the accelerate difficulties a month-ish after I last touched them, I'm not so sure now - the part of the code that deals with the gory internals of accelerate may largely be done.
  • A lot depends on the answers to the tensor contraction questions. If we don't need fully generic contractions, we're in business. If we do, there's a lot more work that needs to be done on accelerate before AD will be useful.
  • While the accelerate library was frustrating to use for automated transformation purposes, writing new code in it was an enjoyable experience.
  • Working out the compForward and compReverse trick really sold me on Haskell - there is no way I would have discovered that in any other language.

Finally, I'd like to thank my mentors Trevor McDonell, Sacha Sokoloski, and Edward Kmett for a lot of great conversations and their help with arcane problems of all sorts. I would not have gotten half as far and learned a quarter as much without their guidance.