In this blog post, I will describe the basics of a new form of Automatic Differentiation (AD) that I will be implementing in Haskell for the Google Summer of Code, along with some background and the reasons the method is expected to yield a speedup for practical problems.

Introduction

There's a lot of digital ink that's been spilled on the subject of AD, so I'll outsource the description elsewhere. If you want to learn more about AD, or the terms "forward-mode" and "reverse-mode" don't mean anything to you, here are some good Haskell-based articles on the subject.

AD is useful for a wide range of quantitative problems, and it is used extensively in machine learning and optimization—especially in neural networks. A lot of the work that goes into ML research is calculus: you compute the derivatives of some loss function you want to optimize, and stuff them into some kind of gradient descent woodchipper. AD lets you avoid doing these calculations by hand, and focus on dreaming up new things to optimize.

The problem is that AD can be significantly slower than hand-coded gradient computation. Moreover, optimal Jacobian accumulation is NP-complete.

The new method

The method I'll be working on this summer takes a new approach to AD, and is based on in-progress work by Martin Elsman, Fritz Henglein, Gabriele Keller, Ken Friis Larsen, and Dimitrios Vytiniotis. They refer to the method as "combinatory automatic differentiation", although I think "point-free automatic differentiation" also captures the basic idea.

There's a lot more to say about the method, but I'll stick to a high-level overview here.

First, some notation: denote function composition by \(\circ\), and composition of linear maps by \(\bullet\). The Fréchet derivative (a generalization of the usual derivative to maps between Banach spaces, which spits out the Jacobian when applied to \(f: \mathbb{R}^m \to \mathbb{R}^n\)) is a linear map, and taking some liberties with notation, we can write the chain rule like so: \[ (f \circ g)'(x) = f'(g(x)) \bullet g'(x) \] In the finite-dimensional case, \(\bullet\) is matrix multiplication, which has two key attributes.

  • It is amenable to parallelization and GPU-friendly.
  • People have been working on making implementations go faster since computers were invented.

How does this help us?

Suppose we have written our program \(f\) as a long series of function compositions. \[ f = f_k \circ f_{k-1} \circ \cdots \circ f_2 \circ f_1 \] We can always do this with pure functional code, even if the result is extremely hard to understand. The Jacobian accumulation problem can then be understood as asking us to choose the parenthesization of this composition that minimizes the operation count once the chain rule is applied. (Since function composition is associative, all associations give the same answer, but some are faster to evaluate than others.) As special cases, we have forward-mode AD \[ f = f_k \circ (f_{k-1} \circ (\cdots \circ (f_2 \circ f_1) \cdots)) \] and reverse-mode AD \[ f = (\cdots(f_k \circ f_{k-1}) \circ \cdots \circ f_2) \circ f_1 \]
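
To make the "some will be faster than others" remark concrete, here is a minimal sketch that counts scalar multiplications for the two special cases. It assumes dense Jacobians and naive matrix multiplication, and the names mulCost, forwardCost and reverseCost are hypothetical helpers of mine, not part of the method.

```haskell
-- Hypothetical helpers: count scalar multiplications for two ways of
-- associating the Jacobian product, assuming dense matrices.
-- dims = [n0, n1, ..., nk], where f_i maps R^(n_(i-1)) to R^(n_i),
-- so the Jacobian J_i of f_i is an n_i x n_(i-1) matrix.

-- Cost of naively multiplying a (p x q) matrix by a (q x r) matrix.
mulCost :: Int -> Int -> Int -> Int
mulCost p q r = p * q * r

-- Forward mode: J_k (J_(k-1) (... (J_2 J_1))).
-- Every partial product has n0 columns.
forwardCost :: [Int] -> Int
forwardCost [] = 0
forwardCost (n0 : ns) =
  sum [ mulCost b a n0 | (a, b) <- zip ns (drop 1 ns) ]

-- Reverse mode: ((J_k J_(k-1)) ...) J_1.
-- Every partial product has nk rows.
reverseCost :: [Int] -> Int
reverseCost [] = 0
reverseCost dims =
  sum [ mulCost nk b a | (a, b) <- zip inner (drop 1 inner) ]
  where nk = last dims
        inner = init dims
```

For dims = [1000, 100, 10, 1] (many inputs, a single output), forwardCost gives 1010000 and reverseCost gives 101000. Reversing the list swaps the two numbers, which is the usual rule of thumb for when to prefer each mode.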

The main insight of the combinatory method is that working at the level of these function trains allows us to extract data parallelism from the chain rule, as long as we take care to avoid expression swell and optimize where possible. (I will discuss both of these topics in the future.) In the combinatory setting, the derivative of our function train might look like this \[ f' = ( \cdots (f_k \circ f_{k-1}) \circ \cdots \circ f_{k-p})' \bullet (f_j \circ f_{j-1})' \bullet (\cdots(f_m \circ f_{m-1}) \circ \cdots \circ f_{m-n})' \bullet \cdots \bullet (f_2 \circ f_1)' \] If we split up the function train wisely, we'll be able to spend most of our computational time doing two things when applying AD to practical problems.

  • Evaluating a Jacobian/Hessian matrix (mostly in parallel)
  • Parallel matrix multiplication

Sketch of an implementation

I'll now describe a quick and dirty prototype implementation of the idea, which is extremely inefficient—it uses list-based matrix multiplication, among other things.

We'll begin with some imports.

import Data.List (transpose)

For simplicity, we'll use an extremely minimal expression datatype, which is understood as the point-free composition of functions. Eventually, this will be replaced by the AST of accelerate. The Pi constructor represents projection onto the \(k\)-th component of a tuple, and it replaces variables. (Think of it as a kind of de Bruijn indexing.)

data AdExpr = Constant Double
  | Pi Int
  | Plus AdExpr AdExpr
  | Times AdExpr AdExpr
  deriving (Show, Eq)

Let's define a partial Num instance for AdExpr, so we can write things in a nicer way.

instance Num AdExpr where
  (+) = Plus
  (*) = Times
  fromInteger i = Constant (fromInteger i)

Next, we'll define an evaluator for our simple AST at a vector v.

evalAd :: AdExpr -> [Double] -> Double
evalAd (Constant x) _ = x
evalAd (Pi i) v = v !! i
evalAd (Plus x y) v = evalAd x v + evalAd y v
evalAd (Times x y) v = evalAd x v * evalAd y v
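
As a quick illustration of how the Num instance and the evaluator fit together, here is a small self-contained sketch. It restates the definitions above so it can be run on its own; the extra Num methods and the name sample are my additions for the example (the instance in this post leaves those methods undefined).

```haskell
data AdExpr = Constant Double
  | Pi Int
  | Plus AdExpr AdExpr
  | Times AdExpr AdExpr
  deriving (Show, Eq)

instance Num AdExpr where
  (+) = Plus
  (*) = Times
  fromInteger i = Constant (fromInteger i)
  -- Assumption: filled in only so the instance is complete.
  negate x = Times (Constant (-1)) x
  abs = error "abs is not supported by AdExpr"
  signum = error "signum is not supported by AdExpr"

evalAd :: AdExpr -> [Double] -> Double
evalAd (Constant x) _ = x
evalAd (Pi i) v = v !! i
evalAd (Plus x y) v = evalAd x v + evalAd y v
evalAd (Times x y) v = evalAd x v * evalAd y v

-- Numeric literals go through fromInteger, and the usual operator
-- precedence applies, so this is
--   Plus (Times (Pi 0) (Pi 1)) (Constant 3.0)
sample :: AdExpr
sample = Pi 0 * Pi 1 + 3
```

Evaluating evalAd sample [2, 5] looks up Pi 0 as 2 and Pi 1 as 5, and returns 13.0.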

In this point-free setting, differentiation of an expression is a breeze. The diff function differentiates an expression with respect to the \(i\)-th input variable.

diff :: AdExpr -> Int -> (AdExpr -> AdExpr)
diff expr i = Times (go expr)  -- the derivative acts by multiplication
  where -- go builds the partial derivative with respect to input i
        go (Plus x y) = Plus (go x) (go y)
        go (Times x y) = Plus (Times (go x) y) (Times x (go y))
        go (Constant _) = Constant 0.0
        go (Pi j) = Constant (if i == j then 1.0 else 0.0)
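
To see the product rule at work, here is a small self-contained sketch. It restates the datatype and evaluator so it can be run on its own, fills in the remaining Num methods (my assumption), and writes diff with a local helper go (a hypothetical name) that builds the derivative expression before wrapping it in Times.

```haskell
data AdExpr = Constant Double
  | Pi Int
  | Plus AdExpr AdExpr
  | Times AdExpr AdExpr
  deriving (Show, Eq)

instance Num AdExpr where
  (+) = Plus
  (*) = Times
  fromInteger i = Constant (fromInteger i)
  -- Assumption: filled in only so the instance is complete.
  negate x = Times (Constant (-1)) x
  abs = error "abs is not supported by AdExpr"
  signum = error "signum is not supported by AdExpr"

evalAd :: AdExpr -> [Double] -> Double
evalAd (Constant x) _ = x
evalAd (Pi i) v = v !! i
evalAd (Plus x y) v = evalAd x v + evalAd y v
evalAd (Times x y) v = evalAd x v * evalAd y v

-- The result is a linear map on expressions: multiplication by the
-- partial derivative with respect to input i.
diff :: AdExpr -> Int -> (AdExpr -> AdExpr)
diff expr i = Times (go expr)
  where
    go (Constant _) = Constant 0.0
    go (Pi j) = Constant (if i == j then 1.0 else 0.0)
    go (Plus x y) = Plus (go x) (go y)
    go (Times x y) = Plus (Times (go x) y) (Times x (go y))

-- d/dx (x * y), applied to the constant 1. Unsimplified, this is
--   Times (Plus (Times (Constant 1.0) (Pi 1))
--               (Times (Pi 0) (Constant 0.0)))
--         (Constant 1.0)
dxy :: AdExpr
dxy = diff (Pi 0 * Pi 1) 0 (Constant 1.0)
```

Evaluating evalAd dxy [2, 5] gives 5.0, the value of y. The leftover multiplications by 0 and 1 in the unsimplified expression are a small taste of the expression swell mentioned earlier.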

Now that we can differentiate with respect to an individual variable, we can compute the gradient and Jacobian of functions. We use the domainDim function to compute how many "variables" an expression depends on.

grad :: AdExpr -> [AdExpr]
grad expr = map (\i -> diff expr i $ Constant 1.0) [0..domainDim expr - 1]

jacobian :: [AdExpr] -> [[AdExpr]]
jacobian expr = map grad' expr
  where dim = max 0 . decf . maximum $ map domainDim expr
        decf x = x-1
        grad' e = map (\i -> diff e i $ Constant 1.0) [0..dim]

domainDim :: AdExpr -> Int
domainDim expr = 1 + go expr
  where -- go returns the largest projection index used, or -1 if there is none,
        -- so an expression that only uses Pi 0 has dimension 1 rather than 0.
        go (Pi i) = i
        go (Constant _) = -1
        go (Plus x y) = max (go x) (go y)
        go (Times x y) = max (go x) (go y)
Finally, we'll define list-based matrix multiplication.

matMatMul :: Num a => [[a]] -> [[a]] -> [[a]]
matMatMul a b =
 [[ sum $ zipWith (*) ar bc | bc <- (transpose b) ] | ar <- a]
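
A quick check of matMatMul on a 2x2 example, restated here so the snippet runs on its own (the name example is mine).

```haskell
import Data.List (transpose)

-- Row-by-column product of two matrices stored as lists of rows.
matMatMul :: Num a => [[a]] -> [[a]] -> [[a]]
matMatMul a b =
  [[ sum $ zipWith (*) ar bc | bc <- transpose b ] | ar <- a]

-- [[1,2],[3,4]] times [[5,6],[7,8]]
example :: [[Int]]
example = matMatMul [[1, 2], [3, 4]] [[5, 6], [7, 8]]
-- example == [[19,22],[43,50]]
```

One caveat of this list-based version: zipWith stops at the shorter list, so mismatched inner dimensions are silently truncated rather than reported as an error.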

With this infrastructure, we can define an "efficient" parallelizable implementation of the chain rule, assuming our function train was broken up in such a way that each piece's Jacobian can be evaluated in parallel.

chain :: [AdExpr] -> [AdExpr] -> [Double] -> [[Double]]
chain f g a = matMatMul jfga jga
  where ga = map (evalAd' a) g                       -- g(a)
        jga = map (map $ evalAd' a) $ jacobian g     -- Jacobian of g at a
        jfga = map (map $ evalAd' ga) $ jacobian f   -- Jacobian of f at g(a)
        evalAd' = flip evalAd

As an example, we'll define two functions.

foo :: [AdExpr]
foo =
  let
    x = Pi 0
    y = Pi 1
    z = Pi 2
  in [y*z + x*z + x*y, x*x + y*y + z*z]

bar :: [AdExpr]
bar =
  let
    u = Pi 0
    v = Pi 1
  in
    [u*u+2*v, v*v*v+u]

and compute the Jacobian of their composition.

λ> chain bar foo [1,2,3]
[[114.0,96.0,78.0],[1181.0,2356.0,3531.0]]
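
As a sanity check, this result can be reproduced by hand. Writing \(J\) for the Jacobian (notation introduced only for this check), the two functions have \[ J_{\mathrm{foo}}(x,y,z) = \begin{pmatrix} y+z & x+z & x+y \\ 2x & 2y & 2z \end{pmatrix}, \qquad J_{\mathrm{bar}}(u,v) = \begin{pmatrix} 2u & 2 \\ 1 & 3v^2 \end{pmatrix} \] At \((x,y,z) = (1,2,3)\) we get \(\mathrm{foo}(1,2,3) = (11, 14)\), so the chain rule gives \[ J_{\mathrm{bar}}(11,14) \bullet J_{\mathrm{foo}}(1,2,3) = \begin{pmatrix} 22 & 2 \\ 1 & 588 \end{pmatrix} \begin{pmatrix} 5 & 4 & 3 \\ 2 & 4 & 6 \end{pmatrix} = \begin{pmatrix} 114 & 96 & 78 \\ 1181 & 2356 & 3531 \end{pmatrix} \] which agrees with the output above.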

We can also define a function to evaluate the Jacobian of a train of function compositions. The functions are listed in order of application, so pipeline [foo, bar] [1,2,3] computes the same Jacobian as chain bar foo [1,2,3]. (In the real implementation, this is another opportunity for optimization. Unlike the Jacobian accumulation problem, the matrix chain multiplication problem can be solved efficiently.)

pipeline :: [[AdExpr]] -> [Double] -> [[Double]]
pipeline [f] v = map (map $ flip evalAd v) $ jacobian f
pipeline (f:fs) gv = matMatMul rest jf
  where rest = pipeline fs $ map (flip evalAd gv) f
        jf = map (map $ flip evalAd gv) $ jacobian f
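
The matrix chain multiplication problem mentioned above has a textbook dynamic-programming solution, sketched below. This is not part of the prototype: it assumes dense matrices with known dimensions and naive multiplication, it only computes the minimal cost (not the parenthesization itself), and matrixChainCost is a name I made up for the example. It relies on Data.Array from the array package that ships with GHC.

```haskell
import Data.Array (Array, array, listArray, (!))

-- dims = [p0, p1, ..., pn]: the i-th matrix of the product (1-based,
-- left to right) has shape p_(i-1) x p_i. Returns the minimal number
-- of scalar multiplications over all parenthesizations.
matrixChainCost :: [Int] -> Int
matrixChainCost dims
  | n < 2 = 0                      -- zero or one matrix: nothing to multiply
  | otherwise = table ! (1, n)
  where
    n = length dims - 1
    p = listArray (0, n) dims :: Array Int Int
    -- Lazily filled table: entry (i, j) is the best cost for the
    -- sub-product of matrices i..j. Laziness provides the memoization.
    table :: Array (Int, Int) Int
    table = array ((1, 1), (n, n))
      [ ((i, j), cost i j) | i <- [1 .. n], j <- [1 .. n] ]
    cost i j
      | i >= j = 0
      | otherwise = minimum
          [ table ! (i, k) + table ! (k + 1, j) + p ! (i - 1) * p ! k * p ! j
          | k <- [i .. j - 1] ]
```

For example, matrixChainCost [10, 30, 5, 60] is 4500: multiplying the first two matrices first costs 1500 + 3000, while the other order costs 9000 + 18000. In the setting of this post, the dimensions would come from the shapes of the Jacobians in the function train.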

Conclusion

In this blog post, I've described what I hope to achieve over the summer, and provided a proof-of-concept implementation for the combinatory AD method. In future blog posts, I will tackle additional related topics, including the following.

  • Optimization/avoiding expression swell
  • Integration with accelerate
  • Dealing with variables
  • Benchmarks
  • Integration with linear algebra primitives