Compiling With CPS

Posted on April 30, 2015
Tags: compilers, haskell

Hello folks. It’s been a busy month so I haven’t had much of a chance to write, but I think now’s a good time to talk about another compiler-related subject: continuation passing style conversion.

When you’re compiling a functional language (in a sane way), your compiler mostly consists of phases which run over the AST and simplify it. For example, in a language with pattern matching, it’s almost certainly the case that we can write something like

    case x of
      (1, 2) -> True
      (_, _) -> False

Wonderfully concise code. However, it’s hard to compile nested patterns like that. In the compiler, we might simplify this to

    case x of
     (a, b) -> case a of
                 1 -> case b of
                        2 -> True
                        _ -> False
                 _ -> False

note to future me, write a pattern matching compiler

We’ve transformed our large nested pattern into a series of simpler, unnested patterns. The benefit here is that this maps more directly to a series of conditionals (or jumps).
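
As a quick sanity check, here is a small sketch of both versions as ordinary Haskell functions. The names nested and flat are made up for illustration, and this is hand-written code, not the output of a real pattern matching compiler.

```haskell
-- The nested pattern, as a programmer would write it.
nested :: (Int, Int) -> Bool
nested x = case x of
  (1, 2) -> True
  (_, _) -> False

-- The same function after flattening: each case inspects
-- exactly one constructor or literal.
flat :: (Int, Int) -> Bool
flat x = case x of
  (a, b) -> case a of
    1 -> case b of
      2 -> True
      _ -> False
    _ -> False
```

Both functions agree on every input; only the shape of the matching has changed.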

Now one of the biggest decisions in any compiler is what to do with expressions. We want to get rid of complicated nested expressions because chances are our compilation target doesn’t support them. In my second to last post we transformed a functional language into something like SSA. In this post, we’re going to walk through a different intermediate representation: CPS.

What is CPS

CPS is a restriction of how a functional language works. In CPS we don’t have nested expressions anymore. We instead have a series of lets which telescope out, each binding a “flat” expression. This is the process of “removing expressions” from our language. A compiler is probably targeting something with a much weaker notion of expressions (like assembly), so we change our tree-like structure into something more linear.

Additionally, no functions return. Instead, they take a continuation and when they’re about to return they instead pass their value to it. This means that conceptually, all functions are transformed from a -> b to (a, b -> void) -> void. Logically, this is actually a reasonable thing to do. This corresponds to mapping a proposition b to ¬ ¬ b. What’s cool here is that since each function call calls a continuation instead of returning its result, we can imagine that each function is just transferring control over to some other part of the program instead of returning. This leads to a very slick and efficient way of implementing CPSed function calls, as we’ll see.

This means we’d change something like

    fact n = if n == 0 then 1 else n * fact (n - 1)

into

    fact n k =
      if n == 0
      then k 1
      else let n' = n - 1 in
           fact n' (\r ->
                          let r' = n * r in
                          k r')

To see what’s going on here, we

  1. Added an extra argument to fact, its return continuation
  2. In the first branch, we pass the result to the continuation instead of returning it
  3. In the next branch, we lift the nested expression n - 1 into a flat let binding
  4. We add an extra argument to the recursive call, the continuation
  5. In this continuation, we multiply the result of the recursive call by n (Note here that we did close over n, this lambda is a real lambda)
  6. Finally, we pass the final result to the original continuation k.

The only tree-style-nesting here comes from the top if expression, everything else is completely linear.
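
To make this concrete, here is a runnable sketch of both versions. One assumption on my part: the answer type is left as a type variable r instead of a real void type, so that we can pass id as the initial continuation and get a number back. The name factCPS is made up to keep the two versions apart.

```haskell
-- Direct style, for comparison.
fact :: Integer -> Integer
fact n = if n == 0 then 1 else n * fact (n - 1)

-- CPS version: the extra argument k is the return continuation.
-- The type variable r stands in for the "void" answer type.
factCPS :: Integer -> (Integer -> r) -> r
factCPS n k =
  if n == 0
    then k 1
    else
      let n' = n - 1
      in factCPS n' (\r -> let r' = n * r in k r')
```

Calling factCPS 5 id runs the whole chain of continuations and hands back 120, the same answer as fact 5.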

Let’s formalize this process by converting Simply Typed Lambda Calculus (STLC) to CPS form.

STLC to CPS

First things first, we specify an AST for normal STLC.

    data Tp = Arr Tp Tp | Int deriving Show

    data Op = Plus | Minus | Times | Divide

    -- The Tp in App refers to the return type, it's helpful later
    data Exp a = App (Exp a) (Exp a) Tp

               | Lam Tp (Scope () Exp a)
               | Num Int
                 -- No need for binding here since we have Minus
               | IfZ (Exp a) (Exp a) (Exp a)
               | Binop Op (Exp a) (Exp a)
               | Var a

We’ve supplemented our lambda calculus with natural numbers and some binary operations because it makes things a bit more fun. Additionally, we’re using bound to deal with bindings for lambdas. This means there’s a terribly boring monad instance lying around that I won’t bore you with.

To convert to CPS, we first need to figure out how to convert our types. Since CPS functions never return, we want them to go to Void, the uninhabited type. However, since our language doesn’t allow Void outside of continuations, and doesn’t allow functions that don’t go to Void, let’s bundle them up into one new type Cont a which is just a function a -> Void. However, this presents us with a problem: how do we turn an Arr a b into this style of function? It seems like our function should take two arguments, a and b -> Void, so that it can produce a Void of its own. However, this requires products, since currying isn’t possible with the restriction that all functions return Void! Therefore, we supplement our CPS language with pairs and projections for them.

Now we can write the AST for CPS types and a conversion between Tp and it.

    data CTp = Cont CTp | CInt | CPair CTp CTp

    cpsTp :: Tp -> CTp
    cpsTp (Arr l r) = Cont $ CPair (cpsTp l) (Cont (cpsTp r))
    cpsTp Int = CInt
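
To see what the translation does to function types, here is the same code repeated in a self-contained form with two examples. The deriving clauses and the names example1 and example2 are additions for illustration.

```haskell
data Tp = Arr Tp Tp | Int deriving (Show, Eq)

data CTp = Cont CTp | CInt | CPair CTp CTp deriving (Show, Eq)

cpsTp :: Tp -> CTp
cpsTp (Arr l r) = Cont $ CPair (cpsTp l) (Cont (cpsTp r))
cpsTp Int = CInt

-- Int -> Int becomes a continuation taking (argument, return continuation).
example1 :: CTp
example1 = cpsTp (Arr Int Int)
-- Cont (CPair CInt (Cont CInt))

-- (Int -> Int) -> Int: the argument is itself a CPSed function.
example2 :: CTp
example2 = cpsTp (Arr (Arr Int Int) Int)
-- Cont (CPair (Cont (CPair CInt (Cont CInt))) (Cont CInt))
```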

The only interesting thing here is how we translate function types, but we talked about that above. Now for expressions.

We want to define a new data type that encapsulates the restrictions of CPS. In order to do this we factor out our data types into “flat expressions” and “CPS expressions”. Flat expressions are things like values and variables while CPS expressions contain things like “Jump to this continuation” or “Branch on this flat expression”. Finally, there are let expressions to perform various operations on expressions.

    data LetBinder a = OpBind Op (FlatExp a) (FlatExp a)
                     | ProjL a
                     | ProjR a
                     | Pair (FlatExp a) (FlatExp a)

    data FlatExp a = CNum Int | CVar a | CLam CTp a (CExp a)

    data CExp a = Let a (LetBinder a) (CExp a)
                | CIf (FlatExp a) (CExp a) (CExp a)
                | Jump (FlatExp a) (FlatExp a)
                | Halt (FlatExp a)

Lets let us bind the results of a few “primitive operations” across values and variables to a fresh variable. This is where things like “incrementing a number” happen. Additionally, in order to create a pair or access its elements we need to use a Let.

Notice that here application is spelled Jump, hinting that it really is just a jmp and not dealing with the stack in any way. Since they’re all jumps, we cannot overflow the stack as we might with a normal calling convention. To seal off the chain of function calls we have Halt; it takes a FlatExp and returns it as the result of the program.

Expressions here are also parameterized over variables but we can’t use bound with them (for reasons that deserve a blogpost-y rant :). Because of this we settle for just ensuring that each a is globally unique.

So now instead of having a bunch of nested Exps, we have flat expressions which compute exactly one thing and linearize the tree of expressions into a series of flat ones with let binders. It’s still not quite “linear” since both lambdas and if branches let us have something tree-like.
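
To get a feel for how such a program runs, here is a sketch of a tiny evaluator for this language. It is not part of the compiler: Value, Env, evalFlat, evalBinder, applyOp, run and resultInt are all names invented for this example, and I am assuming that CIf takes its first branch when the scrutinee is zero. The data types are repeated so the block stands alone.

```haskell
import Data.Maybe (fromMaybe)

data Op = Plus | Minus | Times | Divide
data CTp = Cont CTp | CInt | CPair CTp CTp

data LetBinder a = OpBind Op (FlatExp a) (FlatExp a)
                 | ProjL a
                 | ProjR a
                 | Pair (FlatExp a) (FlatExp a)

data FlatExp a = CNum Int | CVar a | CLam CTp a (CExp a)

data CExp a = Let a (LetBinder a) (CExp a)
            | CIf (FlatExp a) (CExp a) (CExp a)
            | Jump (FlatExp a) (FlatExp a)
            | Halt (FlatExp a)

-- Runtime values (hypothetical): a lambda becomes a closure over its environment.
data Value a = VNum Int | VPair (Value a) (Value a) | VClos (Env a) a (CExp a)
type Env a = [(a, Value a)]

evalFlat :: Eq a => Env a -> FlatExp a -> Value a
evalFlat _   (CNum i)     = VNum i
evalFlat env (CVar v)     = fromMaybe (error "unbound variable") (lookup v env)
evalFlat env (CLam _ x b) = VClos env x b  -- the type annotation is ignored

evalBinder :: Eq a => Env a -> LetBinder a -> Value a
evalBinder env (OpBind op l r) = case (evalFlat env l, evalFlat env r) of
  (VNum x, VNum y) -> VNum (applyOp op x y)
  _                -> error "operator applied to a non-number"
evalBinder env (ProjL v) = case evalFlat env (CVar v) of
  VPair l _ -> l
  _         -> error "projection from a non-pair"
evalBinder env (ProjR v) = case evalFlat env (CVar v) of
  VPair _ r -> r
  _         -> error "projection from a non-pair"
evalBinder env (Pair l r) = VPair (evalFlat env l) (evalFlat env r)

applyOp :: Op -> Int -> Int -> Int
applyOp Plus   = (+)
applyOp Minus  = (-)
applyOp Times  = (*)
applyOp Divide = div

-- Every recursive call to run is a tail call: nothing waits for a Jump to return.
run :: Eq a => Env a -> CExp a -> Value a
run env (Let x b rest) = run ((x, evalBinder env b) : env) rest
run env (CIf c t e)    = case evalFlat env c of
  VNum 0 -> run env t  -- assumption: the first branch is taken on zero
  VNum _ -> run env e
  _      -> error "branch on a non-number"
run env (Jump f arg)   = case evalFlat env f of
  VClos cenv x body -> run ((x, evalFlat env arg) : cenv) body
  _                 -> error "jump to a non-function"
run env (Halt v)       = evalFlat env v

-- (\x -> x + 1) 41, written by hand in CPS form.
example :: CExp String
example =
  Let "p" (Pair (CNum 41) (CLam CInt "r" (Halt (CVar "r")))) $
  Jump (CLam CInt "pair" $
          Let "x" (ProjL "pair") $
          Let "k" (ProjR "pair") $
          Let "out" (OpBind Plus (CVar "x") (CNum 1)) $
          Jump (CVar "k") (CVar "out"))
       (CVar "p")

resultInt :: Value a -> Maybe Int
resultInt (VNum n) = Just n
resultInt _        = Nothing
```

Evaluating resultInt (run [] example) gives Just 42. The thing to notice is that run never does any work after a recursive call comes back, which is the evaluator-level version of Jump being a plain jmp.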

We can now define conversion to CPS with one major helper function

    cps :: (Eq a, Enum a)
        => Exp a
        -> (FlatExp a -> Gen a (CExp a))
        -> Gen a (CExp a)

This takes an expression, a “continuation” and produces a CExp. We have some monad-gen stuff going on here because we need unique variables. The “continuation” is an actual Haskell function. So our function breaks an expression down to a FlatExp and then feeds it to the continuation.

    cps (Var a) c = c (CVar a)
    cps (Num i) c = c (CNum i)

The first two cases are easy since variables and numbers are already flat expressions, they go straight into the continuation.

    cps (IfZ i t e) c = cps i $ \ic -> CIf ic <$> cps t c <*> cps e c

For IfZ we first recurse on the i. Then once we have a flattened computation representing i, we use CIf and recurse.

    cps (Binop op l r) c =
      cps l $ \fl ->
      cps r $ \fr ->
      gen >>= \out ->
      Let out (OpBind op fl fr) <$> c (CVar out)

Like before, we use cps to recurse on the left and right sides of the expression. This gives us two flat expressions which we use with OpBind to compute the result and bind it to out. Now that we have a variable for the result we just toss it to the continuation.

    cps (Lam tp body) c = do
      [pairArg, newCont, newArg] <- replicateM 3 gen
      let body' = instantiate1 (Var newArg) body
      cbody <- cps body' (return . Jump (CVar newCont))
      c (CLam (cpsTp tp) pairArg
         $ Let newArg  (ProjL pairArg)
         $ Let newCont (ProjR pairArg)
         $ cbody)

Converting a lambda is a little bit more work. It needs to take a pair, so a lot of the work is projecting out the left component (the argument) and the right component (the continuation). With those two things in hand we recurse into the body using the continuation supplied as an argument. The actual code makes this process look a little out of order. Remember that we only use cbody once we’ve bound the projections to newArg and newCont respectively.

    cps (App l r tp) c = do
      arg <- gen
      cont <- CLam (cpsTp tp) arg <$> c (CVar arg)
      cps l $ \fl ->
        cps r $ \fr ->
        gen >>= \pair ->
        return $ Let pair (Pair fr cont) (Jump fl (CVar pair))

For application, we just create a lambda for the current continuation. We then evaluate the left and right sides of the application using recursive calls. Now that we have a function to jump to, we create a pair of the argument and the continuation and bind it to a name. From there, we just jump to fl, the function. Turning the continuation into a lambda is a little strange; it’s also why we needed an annotation for App. The lambda uses the return type of the application and constructs a continuation that maps a to c a. Note that c a is a Haskell expression that produces a CExp a (inside Gen).

    convert :: Exp Int -> CExp Int
    convert = runGen . flip cps (return . Halt)

With this, we’ve written a nice little compiler pass to convert expressions into their CPS forms. By doing this we’ve “eliminated expressions”. Everything is now flat and evaluation basically proceeds by evaluating one small computation and using the result to compute another and another.
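
If you want to play with the pass without installing bound or monad-gen, here is a self-contained sketch of the same conversion using only base. It makes a few simplifying assumptions that differ from the code above: lambdas carry a plain named binder instead of a Scope, variables use a made-up Name type so generated names can’t clash with source names, source binders are assumed to be unique, and Gen is a hand-rolled counter monad (runGenFrom is my own name, not monad-gen’s API).

```haskell
data Tp = Arr Tp Tp | Int deriving (Show, Eq)
data Op = Plus | Minus | Times | Divide deriving (Show, Eq)

-- Hypothetical name type: User names come from the source program,
-- Fresh names from the generator, so the two can never collide.
data Name = User String | Fresh Int deriving (Show, Eq)

-- Source terms with plain named binders in place of bound's Scope.
data Exp = App Exp Exp Tp
         | Lam Tp Name Exp
         | Num Int
         | IfZ Exp Exp Exp
         | Binop Op Exp Exp
         | Var Name

data CTp = Cont CTp | CInt | CPair CTp CTp deriving (Show, Eq)

data LetBinder = OpBind Op FlatExp FlatExp
               | ProjL Name
               | ProjR Name
               | Pair FlatExp FlatExp
               deriving (Show, Eq)

data FlatExp = CNum Int | CVar Name | CLam CTp Name CExp deriving (Show, Eq)

data CExp = Let Name LetBinder CExp
          | CIf FlatExp CExp CExp
          | Jump FlatExp FlatExp
          | Halt FlatExp
          deriving (Show, Eq)

cpsTp :: Tp -> CTp
cpsTp (Arr l r) = Cont $ CPair (cpsTp l) (Cont (cpsTp r))
cpsTp Int = CInt

-- A counter-threading monad standing in for monad-gen's Gen.
newtype Gen a = Gen { runGenFrom :: Int -> (a, Int) }

instance Functor Gen where
  fmap f (Gen g) = Gen $ \n -> let (a, n') = g n in (f a, n')

instance Applicative Gen where
  pure a = Gen $ \n -> (a, n)
  Gen f <*> Gen g = Gen $ \n ->
    let (h, n') = f n
        (a, n'') = g n'
    in (h a, n'')

instance Monad Gen where
  Gen g >>= f = Gen $ \n -> let (a, n') = g n in runGenFrom (f a) n'

gen :: Gen Name
gen = Gen $ \n -> (Fresh n, n + 1)

cps :: Exp -> (FlatExp -> Gen CExp) -> Gen CExp
cps (Var a) c = c (CVar a)
cps (Num i) c = c (CNum i)
cps (IfZ i t e) c = cps i $ \ic -> CIf ic <$> cps t c <*> cps e c
cps (Binop op l r) c =
  cps l $ \fl ->
  cps r $ \fr ->
  gen >>= \out ->
  Let out (OpBind op fl fr) <$> c (CVar out)
cps (Lam tp x body) c = do
  pairArg <- gen
  newCont <- gen
  cbody <- cps body (return . Jump (CVar newCont))
  c (CLam (cpsTp tp) pairArg
     $ Let x (ProjL pairArg)
     $ Let newCont (ProjR pairArg)
     $ cbody)
cps (App l r tp) c = do
  arg <- gen
  cont <- CLam (cpsTp tp) arg <$> c (CVar arg)
  cps l $ \fl ->
    cps r $ \fr ->
    gen >>= \pair ->
    return $ Let pair (Pair fr cont) (Jump fl (CVar pair))

convert :: Exp -> CExp
convert e = fst (runGenFrom (cps e (return . Halt)) 0)

-- 1 + 2 becomes a single let followed by Halt.
onePlusTwo :: CExp
onePlusTwo = convert (Binop Plus (Num 1) (Num 2))
```

Here onePlusTwo comes out as Let (Fresh 0) (OpBind Plus (CNum 1) (CNum 2)) (Halt (CVar (Fresh 0))): the nested addition has been named and the result handed to the final continuation.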

There’s still some things left to compile out before this is machine code though

Once we’ve done these steps we’ve basically written a compiler. However, they’re all influenced by the fact that we’ve compiled out expressions and (really) function calls with our conversion to CPS, which makes the process much, much simpler.

Wrap Up

CPS conversion is a nice alternative to something like STG machines for lazy languages or SSA for imperative ones. As far as I’m aware, the main SML compiler (SML/NJ) compiles code in this way. As does Ur/Web if I’m not crazy. Additionally, the course entitled “Higher Order, Typed Compilation” which is taught here at CMU uses CPS conversion to make compiling SML really quite pleasant.

In fact, someone (Andrew Appel?) once wrote a paper that noted that SSA and CPS are actually the same. The key difference is that in SSA we merge multiple blocks together using the phi function. In CPS, we just let multiple source blocks jump to the same destination block (continuation). You can see this in our conversion of IfZ to CPS, instead of using phi to merge in the two branches, they both just use the continuation to jump to the remainder of the program. It makes things a little simpler because no one person needs to worry about

Finally, if you’re compiling a language like Scheme with call/cc, using CPS conversion makes the whole thing completely trivial. All you do is define call/cc at the CPS level as

    call/cc (f, c) = f ((λ (x, c') → c x), c)

So instead of using the continuation supplied to us in the expression we give to f, we use the one for the whole call/cc invocation! This causes us to not return into the body of f but instead to carry on the rest of the program as if f had returned whatever value x is. This is how my old Scheme compiler did things, I put figuring out how to implement call/cc off for a week before I realized it was a 10 minute job!
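
Here is that definition as a runnable Haskell sketch. As assumptions for the example, the answer type is a type variable r rather than void, and escapes and returnsNormally are two made-up bodies to pass in as f.

```haskell
-- A CPSed function takes a pair of (argument, continuation).
-- callCC hands f an "escape" function that throws away whatever
-- continuation it is given and resumes the continuation c instead.
callCC :: (((a, b -> r) -> r, a -> r) -> r, a -> r) -> r
callCC (f, c) = f (\(x, _c') -> c x, c)

-- Roughly call/cc (\k -> 10 + k 5): the pending "10 +" is discarded.
escapes :: ((Int, Int -> Int) -> Int, Int -> Int) -> Int
escapes (k, ret) = k (5, \v -> ret (10 + v))

-- Roughly call/cc (\k -> 10): the escape function is never used.
returnsNormally :: ((Int, Int -> Int) -> Int, Int -> Int) -> Int
returnsNormally (_k, ret) = ret 10
```

With the outer continuation \v -> 1 + v, callCC (escapes, \v -> 1 + v) gives 6 while callCC (returnsNormally, \v -> 1 + v) gives 11.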

Hope this was helpful!
