# Compiling With CPS

Hello folks. It’s been a busy month so I haven’t had much of a chance to write, but I think now’s a good time to talk about another compiler-related subject: continuation passing style conversion.

When you’re compiling a functional language (in a sane way), your compiler mostly consists of phases which run over the AST and simplify it. For example, in a language with pattern matching, it’s almost certainly the case that we can write something like

```
case x of
  (1, 2) -> True
  (_, _) -> False
```

Wonderfully concise code. However, it’s hard to compile nested patterns like that. In the compiler, we might simplify this to

```
case x of
  (a, b) -> case a of
    1 -> case b of
      2 -> True
      _ -> False
    _ -> False
```

*note to future me, write a pattern matching compiler*

We’ve transformed our large nested pattern into a series of simpler, unnested patterns. The benefit here is that this maps more directly to a series of conditionals (or jumps).
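As a quick sanity check, here is a minimal sketch of the two forms as ordinary Haskell functions. The names `nested` and `flattened` are made up for illustration; the point is only that both compute the same result.

```haskell
-- Nested pattern, as the programmer writes it.
nested :: (Int, Int) -> Bool
nested x = case x of
  (1, 2) -> True
  (_, _) -> False

-- The same function after flattening: each case inspects
-- a single constructor or literal.
flattened :: (Int, Int) -> Bool
flattened x = case x of
  (a, b) -> case a of
    1 -> case b of
      2 -> True
      _ -> False
    _ -> False
```

Both functions return `True` only for `(1, 2)`.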

Now one of the biggest decisions in any compiler is what to do with expressions. We want to get rid of complicated nested expressions because chances are our compilation target doesn’t support them. In my second to last post we transformed a functional language into something like SSA. In this post, we’re going to walk through a different intermediate representation: CPS.

## What is CPS

CPS is a restriction of how a functional language works. In CPS we don’t have nested expressions anymore. We instead have a series of lets which telescope out, each binding a “flat” expression. This is the process of “removing expressions” from our language. A compiler is probably targeting something with a much weaker notion of expressions (like assembly), and so we change our tree-like structure into something more linear.

Additionally, no functions return. Instead, they take a continuation and when they’re about to return they instead pass their value to it. This means that conceptually, all functions are transformed from `a -> b` to `(a, b -> void) -> void`. Logically, this is actually a reasonable thing to do. This corresponds to mapping a proposition `b` to `¬ ¬ b`. What’s cool here is that since each function call calls a continuation instead of returning its result, we can imagine that each function is just transferring control over to some other part of the program instead of returning. This leads to a very slick and efficient way of implementing CPSed function calls, as we’ll see.

This means we’d change something like

` fact n = if n == 0 then 1 else n * fact (n - 1)`

into

```
fact n k =
  if n == 0
  then k 1
  else let n' = n - 1 in
       fact n' (\r ->
         let r' = n * r in
         k r')
```

To see what’s going on here, we

- Added an extra argument to `fact`, its return continuation
- In the first branch, we pass the result to the continuation instead of returning it
- In the next branch, we lift the nested expression `n - 1` into a flat let binding
- We add an extra argument to the recursive call, the continuation
- In this continuation, we multiply the result of the recursive call by `n` (note here that we did close over `n`, this lambda is a real lambda)
- Finally, we pass the final result to the original continuation `k`.

The only tree-style nesting here comes from the top `if` expression; everything else is completely linear.
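The CPS version of `fact` runs as ordinary Haskell if we relax one thing. In this sketch the answer type is a polymorphic `r` rather than `void`, which is my own assumption so that the final continuation can hand a value back to the caller. The name `factCPS` is also made up here.

```haskell
-- Direct style, for comparison.
fact :: Integer -> Integer
fact n = if n == 0 then 1 else n * fact (n - 1)

-- CPS version. The answer type r stands in for void so that
-- the outermost continuation can return something observable.
factCPS :: Integer -> (Integer -> r) -> r
factCPS n k =
  if n == 0
    then k 1
    else
      let n' = n - 1
       in factCPS n' $ \r ->
            let r' = n * r
             in k r'
```

Passing `id` as the initial continuation recovers the direct-style result, so `factCPS 5 id` evaluates to `120`.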

Let’s formalize this process by converting Simply Typed Lambda Calculus (STLC) to CPS form.

## STLC to CPS

First things first, we specify an AST for normal STLC.

```
data Tp = Arr Tp Tp | Int deriving Show
data Op = Plus | Minus | Times | Divide
-- The Tp in App refers to the return type, it's helpful later
data Exp a = App (Exp a) (Exp a) Tp
           | Lam Tp (Scope () Exp a)
           | Num Int
           -- No need for binding here since we have Minus
           | IfZ (Exp a) (Exp a) (Exp a)
           | Binop Op (Exp a) (Exp a)
           | Var a
```

We’ve supplemented our lambda calculus with natural numbers and some binary operations because it makes things a bit more fun. Additionally, we’re using bound to deal with bindings for lambdas. This means there’s a terribly boring monad instance lying around that I won’t bore you with.

To convert to CPS, we first need to figure out how to convert our types. Since CPS functions never return, we want them to go to `Void`, the unoccupied type. However, since our language doesn’t allow `Void` outside of continuations, and doesn’t allow functions that don’t go to `Void`, let’s bundle them up into one new type `Cont a`, which is just a function from `a -> Void`. However, this presents us with a problem: how do we turn an `Arr a b` into this style of function? It seems like our function should take two arguments, `a` and `b -> Void`, so that it can produce a `Void` of its own. However, this requires products, since currying isn’t possible with the restriction that all functions return `Void`! Therefore, we supplement our CPS language with pairs and projections for them.

Now we can write the AST for CPS types and a conversion between `Tp` and it.

```
data CTp = Cont CTp | CInt | CPair CTp CTp
cpsTp :: Tp -> CTp
cpsTp (Arr l r) = Cont $ CPair (cpsTp l) (Cont (cpsTp r))
cpsTp Int = CInt
```
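To make the translation concrete, here is a small self-contained sketch that repeats the definitions above and applies `cpsTp` to two types. The `Show` and `Eq` instances on `CTp` and the names `firstOrder` and `higherOrder` are my additions for the example.

```haskell
data Tp = Arr Tp Tp | Int deriving (Show, Eq)
data CTp = Cont CTp | CInt | CPair CTp CTp deriving (Show, Eq)

cpsTp :: Tp -> CTp
cpsTp (Arr l r) = Cont $ CPair (cpsTp l) (Cont (cpsTp r))
cpsTp Int = CInt

-- Int -> Int
firstOrder :: CTp
firstOrder = cpsTp (Arr Int Int)
-- Cont (CPair CInt (Cont CInt))

-- (Int -> Int) -> Int
higherOrder :: CTp
higherOrder = cpsTp (Arr (Arr Int Int) Int)
-- Cont (CPair (Cont (CPair CInt (Cont CInt))) (Cont CInt))
```

Every arrow becomes a continuation over a pair: the translated argument on the left, the return continuation on the right.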

The only interesting thing here is how we translate function types, but we talked about that above. Now for expressions.

We want to define a new data type that encapsulates the restrictions of CPS. In order to do this we factor out our data types into “flat expressions” and “CPS expressions”. Flat expressions are things like values and variables, while CPS expressions contain things like “Jump to this continuation” or “Branch on this flat expression”. Finally, there are let expressions to perform various operations on expressions.

```
data LetBinder a = OpBind Op (FlatExp a) (FlatExp a)
                 | ProjL a
                 | ProjR a
                 | Pair (FlatExp a) (FlatExp a)
data FlatExp a = CNum Int | CVar a | CLam CTp a (CExp a)
data CExp a = Let a (LetBinder a) (CExp a)
            | CIf (FlatExp a) (CExp a) (CExp a)
            | Jump (FlatExp a) (FlatExp a)
            | Halt (FlatExp a)
```

`Let`s let us bind the results of a few “primitive operations” across values and variables to a fresh variable. This is where things like “incrementing a number” happen. Additionally, in order to create a pair or access its elements we need to use a `Let`.

Notice that here application is spelled `Jump`, hinting that it really is just a `jmp` and doesn’t deal with the stack in any way. Since they’re all jumps, we cannot overflow the stack, as could happen with a normal calling convention. To seal off the chain of function calls we have `Halt`, which takes a `FlatExp` and returns it as the result of the program.

Expressions here are also parameterized over variables, but we can’t use bound with them (for reasons that deserve a blogpost-y rant :). Because of this we settle for just ensuring that each `a` is globally unique.

So now instead of having a bunch of nested `Exp`s, we have flat expressions which compute exactly one thing, and we linearize the tree of expressions into a series of flat ones with let binders. It’s still not quite “linear” since both lambdas and if branches let us have something tree-like.
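To see how little machinery this language needs at runtime, here is a minimal sketch of an evaluator for it. Several things here are my own assumptions rather than part of the post: the `Value`, `Env`, `look`, `flat`, `bind`, `arith`, `run`, and `sample` names, the choice to drop the `CTp` annotation on `CLam`, the use of integer division for `Divide`, and the reading of `CIf` as "take the first branch when the scrutinee is zero" (mirroring `IfZ`).

```haskell
data Op = Plus | Minus | Times | Divide

data LetBinder a
  = OpBind Op (FlatExp a) (FlatExp a)
  | ProjL a
  | ProjR a
  | Pair (FlatExp a) (FlatExp a)

-- Assumption: the type annotation on CLam is dropped in this sketch.
data FlatExp a = CNum Int | CVar a | CLam a (CExp a)

data CExp a
  = Let a (LetBinder a) (CExp a)
  | CIf (FlatExp a) (CExp a) (CExp a)
  | Jump (FlatExp a) (FlatExp a)
  | Halt (FlatExp a)

-- Hypothetical runtime values: integers, pairs, and closures.
data Value a = VInt Int | VPair (Value a) (Value a) | VClo (Env a) a (CExp a)

type Env a = [(a, Value a)]

look :: Eq a => Env a -> a -> Value a
look env v = maybe (error "unbound variable") id (lookup v env)

-- Flat expressions evaluate in one step, with no control flow.
flat :: Eq a => Env a -> FlatExp a -> Value a
flat _ (CNum i) = VInt i
flat env (CVar v) = look env v
flat env (CLam x body) = VClo env x body

arith :: Op -> Int -> Int -> Int
arith Plus = (+)
arith Minus = (-)
arith Times = (*)
arith Divide = div

bind :: Eq a => Env a -> LetBinder a -> Value a
bind env (OpBind op l r) = case (flat env l, flat env r) of
  (VInt x, VInt y) -> VInt (arith op x y)
  _ -> error "arithmetic on a non-integer"
bind env (ProjL v) = case look env v of
  VPair l _ -> l
  _ -> error "projection from a non-pair"
bind env (ProjR v) = case look env v of
  VPair _ r -> r
  _ -> error "projection from a non-pair"
bind env (Pair l r) = VPair (flat env l) (flat env r)

-- Every recursive call to run is a tail call: Jump never returns.
run :: Eq a => Env a -> CExp a -> Value a
run env (Let x b rest) = run ((x, bind env b) : env) rest
run env (CIf c t e) = case flat env c of
  VInt 0 -> run env t
  VInt _ -> run env e
  _ -> error "branch on a non-integer"
run env (Jump f arg) = case flat env f of
  VClo cenv x body -> run ((x, flat env arg) : cenv) body
  _ -> error "jump to a non-function"
run env (Halt v) = flat env v

-- Jumps to \(n, k) -> k (n + 1) with 41 and a continuation that halts.
sample :: CExp String
sample =
  Let "p" (Pair (CNum 41) (CLam "r" (Halt (CVar "r")))) $
    Jump (CLam "arg" body) (CVar "p")
  where
    body =
      Let "n" (ProjL "arg") $
        Let "k" (ProjR "arg") $
          Let "m" (OpBind Plus (CVar "n") (CNum 1)) $
            Jump (CVar "k") (CVar "m")
```

Evaluating `run [] sample` yields `VInt 42`. Notice that `run` never needs to come back from a `Jump`, which is what lets a real backend implement it as a plain `jmp`.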

We can now define conversion to CPS with one major helper function

```
cps :: (Eq a, Enum a)
    => Exp a
    -> (FlatExp a -> Gen a (CExp a))
    -> Gen a (CExp a)
```

This takes an expression and a “continuation”, and produces a `CExp`. We have some monad-gen stuff going on here because we need unique variables. The “continuation” is an actual Haskell function. So our function breaks an expression down to a `FlatExp` and then feeds it to the continuation.

```
cps (Var a) c = c (CVar a)
cps (Num i) c = c (CNum i)
```

The first two cases are easy since variables and numbers are already flat expressions, they go straight into the continuation.

` cps (IfZ i t e) c = cps i $ \ic -> CIf ic <$> cps t c <*> cps e c`

For `IfZ` we first recurse on the `i`. Then once we have a flattened computation representing `i`, we use `CIf` and recurse.

```
cps (Binop op l r) c =
  cps l $ \fl ->
  cps r $ \fr ->
  gen >>= \out ->
  Let out (OpBind op fl fr) <$> c (CVar out)
```

Like before, we use `cps` to recurse on the left and right sides of the expression. This gives us two flat expressions which we use with `OpBind` to compute the result and bind it to `out`. Now that we have a variable for the result we just toss it to the continuation.

```
cps (Lam tp body) c = do
  [pairArg, newCont, newArg] <- replicateM 3 gen
  let body' = instantiate1 (Var newArg) body
  cbody <- cps body' (return . Jump (CVar newCont))
  c (CLam (cpsTp tp) pairArg
     $ Let newArg (ProjL pairArg)
     $ Let newCont (ProjR pairArg)
     $ cbody)
```

Converting a lambda is a little bit more work. It needs to take a pair, so a lot of the work is projecting out the left component (the argument) and the right component (the continuation). With those two things in hand we recurse in the body using the continuation supplied as an argument. The actual code makes this process look a little out of order. Remember that we only use `cbody` once we’ve bound the projections to `newArg` and `newCont` respectively.

```
cps (App l r tp) c = do
  arg <- gen
  cont <- CLam (cpsTp tp) arg <$> c (CVar arg)
  cps l $ \fl ->
    cps r $ \fr ->
    gen >>= \pair ->
    return $ Let pair (Pair fr cont) (Jump fl (CVar pair))
```

For application we just create a lambda for the current continuation. We then evaluate the left and right sides of the application using recursive calls. Now that we have a function to jump to, we create a pair of the argument and the continuation and bind it to a name. From there, we just jump to `fl`, the function. Turning the continuation into a lambda is a little strange; it’s also why we needed an annotation for `App`. The lambda uses the return type of the application and constructs a continuation that maps `a` to `c a`. Note that `c a` is a Haskell expression with the type `CExp a`.

```
convert :: Exp Int -> CExp Int
convert = runGen . flip cps (return . Halt)
```

With this, we’ve written a nice little compiler pass to convert expressions into their CPS forms. By doing this we’ve “eliminated expressions”. Everything is now flat and evaluation basically proceeds by evaluating one small computation and using the result to compute another and another.
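For readers who want to run the whole pass without installing bound or monad-gen, here is a self-contained sketch that uses only `base`. It departs from the code above in ways that are my own assumptions: variables are plain `String`s, source binder names are assumed to be globally unique, type annotations are dropped, `Gen` is a hand-rolled counter monad, and fresh names start with `%` on the assumption that source programs never use that character.

```haskell
import Control.Monad (ap, liftM)

data Op = Plus | Minus | Times | Divide deriving (Eq, Show)

-- Source language with named variables; binder names are assumed unique.
data Exp
  = App Exp Exp
  | Lam String Exp
  | Num Int
  | IfZ Exp Exp Exp
  | Binop Op Exp Exp
  | Var String

data LetBinder
  = OpBind Op FlatExp FlatExp
  | ProjL String
  | ProjR String
  | Pair FlatExp FlatExp
  deriving (Eq, Show)

data FlatExp = CNum Int | CVar String | CLam String CExp
  deriving (Eq, Show)

data CExp
  = Let String LetBinder CExp
  | CIf FlatExp CExp CExp
  | Jump FlatExp FlatExp
  | Halt FlatExp
  deriving (Eq, Show)

-- A hand-rolled fresh-name supply: a state monad over an Int counter.
newtype Gen a = Gen {unGen :: Int -> (a, Int)}

instance Functor Gen where fmap = liftM
instance Applicative Gen where
  pure x = Gen (\n -> (x, n))
  (<*>) = ap
instance Monad Gen where
  Gen m >>= f = Gen (\n -> let (x, n') = m n in unGen (f x) n')

-- Fresh names start with '%', assumed never to appear in source programs.
gen :: Gen String
gen = Gen (\n -> ("%" ++ show n, n + 1))

runGen :: Gen a -> a
runGen (Gen m) = fst (m 0)

cps :: Exp -> (FlatExp -> Gen CExp) -> Gen CExp
cps (Var a) c = c (CVar a)
cps (Num i) c = c (CNum i)
cps (IfZ i t e) c = cps i $ \ic -> CIf ic <$> cps t c <*> cps e c
cps (Binop op l r) c =
  cps l $ \fl ->
    cps r $ \fr -> do
      out <- gen
      Let out (OpBind op fl fr) <$> c (CVar out)
cps (Lam x body) c = do
  pairArg <- gen
  newCont <- gen
  -- The body returns by jumping to the continuation from the pair.
  cbody <- cps body (return . Jump (CVar newCont))
  c (CLam pairArg (Let x (ProjL pairArg) (Let newCont (ProjR pairArg) cbody)))
cps (App l r) c = do
  arg <- gen
  -- Reify the Haskell-level continuation c as an object-level lambda.
  cont <- CLam arg <$> c (CVar arg)
  cps l $ \fl ->
    cps r $ \fr -> do
      pair <- gen
      return (Let pair (Pair fr cont) (Jump fl (CVar pair)))

convert :: Exp -> CExp
convert e = runGen (cps e (return . Halt))
```

For example, `convert (Binop Plus (Num 1) (Num 2))` produces `Let "%0" (OpBind Plus (CNum 1) (CNum 2)) (Halt (CVar "%0"))`: the nested addition becomes one flat let followed by a `Halt`.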

There are still some things left to compile out before this is machine code though:

- Closures - these can be converted to explicitly pass records with closure conversion
- Hoist lambdas out of nested scope - this gets rid of anonymous functions, something we don’t have in C or assembly
- Make allocation explicit - Allocate a block of memory for a group of let statements and have them explicitly move the results of their computations to it
- Register allocation - Cleverly choose whether to store some particular variable in a register or load it in as needed.

Once we’ve done these steps we’ve basically written a compiler. However, they’re all influenced by the fact that we’ve compiled out expressions and (really) function calls with our conversion to CPS, which makes the process much, much simpler.

## Wrap Up

CPS conversion is a nice alternative to something like STG machines for lazy languages or SSA for imperative ones. As far as I’m aware the main SML compiler (SML/NJ) compiles code in this way. As does Ur/Web if I’m not crazy. Additionally, the course entitled “Higher Order, Typed Compilation” which is taught here at CMU uses CPS conversion to make compiling SML really quite pleasant.

In fact, someone (Andrew Appel?) once wrote a paper that noted that SSA and CPS are actually the same. The key difference is that in SSA we merge multiple blocks together using the phi function. In CPS, we just let multiple source blocks jump to the same destination block (continuation). You can see this in our conversion of `IfZ` to CPS: instead of using `phi` to merge in the two branches, they both just use the continuation to jump to the remainder of the program. It makes things a little simpler because no one person needs to worry about
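A tiny sketch of that correspondence, written as plain Haskell rather than in the CPS AST. The names `direct`, `joined`, and `j` are hypothetical; `j` plays the role of the destination block that both branches jump to.

```haskell
-- Direct style: (if c == 0 then 10 else 20) + 1
direct :: Int -> Int
direct c = (if c == 0 then 10 else 20) + 1

-- CPS style: both branches jump to the same continuation j.
-- In SSA, j would be a block beginning with r = phi(10, 20).
joined :: Int -> (Int -> r) -> r
joined c k =
  let j r = k (r + 1)
   in if c == 0 then j 10 else j 20
```

The parameter `r` of `j` does the job of the phi node: each branch supplies its own value for it when it jumps.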

Finally, if you’re compiling a language like Scheme with call/cc, using CPS conversion makes the whole thing completely trivial. All you do is define `call/cc` at the CPS level as

`call/cc (f, c) = f ((λ (x, c') → c x), c)`

So instead of using the continuation supplied to us in the expression we give to `f`, we use the one for the whole `call/cc` invocation! This causes us to not return into the body of `f` but instead to carry on the rest of the program as if `f` had returned whatever value `x` is. This is how my old Scheme compiler did things. I put figuring out how to implement `call/cc` off for a week before I realized it was a 10 minute job!
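Here is the same definition as a runnable Haskell sketch. It is curried rather than using pairs, and the answer type is a polymorphic `r` instead of `void`; both are my simplifications, as are the names `CPSFun`, `callCC`, `escaping`, and `normal`.

```haskell
-- A CPS function from a to b, with final answer type r.
type CPSFun a b r = a -> (b -> r) -> r

-- f receives an escape function that ignores its own continuation c'
-- and resumes the continuation c of the whole call/cc instead.
callCC :: (CPSFun a b r -> (a -> r) -> r) -> (a -> r) -> r
callCC f c = f (\x _c' -> c x) c

-- Escapes with 42; the "+ 1" continuation passed to escape is discarded.
escaping :: (Int -> r) -> r
escaping = callCC (\escape k -> escape 42 (\y -> k (y + 1)))

-- Never calls escape, so it behaves like an ordinary return.
normal :: (Int -> r) -> r
normal = callCC (\_escape k -> k 7)
```

Here `escaping id` evaluates to `42` rather than `43`, because the escape function throws away the continuation it was called with, and `normal id` evaluates to `7`.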

Hope this was helpful!
