Compiling With CPS
Hello folks. It’s been a busy month so I haven’t had much of a chance to write, but I think now’s a good time to talk about another compiler-related subject: continuation passing style conversion.
When you’re compiling a functional language (in a sane way) your compiler mostly consists of phases which run over the AST and simplify it. For example, in a language with pattern matching, it’s almost certainly the case that we can write something like
Wonderfully concise code. However, it’s hard to compile nested patterns like that. In the compiler, we might simplify this to
note to future me, write a pattern matching compiler
We’ve transformed our large nested pattern into a series of simpler, unnested patterns. The benefit here is that this maps more directly to a series of conditionals (or jumps).
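As a hypothetical illustration of this kind of simplification, here is a minimal sketch in Haskell. The functions bothOnes and bothOnesFlat are made up for this example; they only show the shape of a nested pattern and of its flattened equivalent, where each case inspects a single constructor or literal.

```haskell
-- Nested pattern, the way a programmer would naturally write it.
bothOnes :: (Int, Int) -> Bool
bothOnes p = case p of
  (1, 1) -> True
  _      -> False

-- The same function after flattening: every case looks at exactly one
-- constructor or literal, so it maps onto a chain of conditionals.
bothOnesFlat :: (Int, Int) -> Bool
bothOnesFlat p = case p of
  (a, b) -> case a of
    1 -> case b of
      1 -> True
      _ -> False
    _ -> False
```

Both versions agree on every input; the second is simply closer to what a backend wants to see.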
Now one of the biggest decisions in any compiler is what to do with expressions. We want to get rid of complicated nested expressions because chances are our compilation target doesn’t support them. In my second to last post we transformed a functional language into something like SSA. In this post, we’re going to walk through a different intermediate representation: CPS.
What is CPS
CPS is a restriction of how a functional language works. In CPS we don’t have nested expressions anymore. We instead have a series of lets which telescope out and each bind a “flat” expression. This is the process of “removing expressions” from our language. A compiler is probably targeting something with a much weaker notion of expressions (like assembly) and so we change our tree-like structure into something more linear.
Additionally, no functions return. Instead, they take a continuation and when they’re about to return they instead pass their value to it. This means that conceptually, all functions are transformed from a -> b to (a, b -> void) -> void. Logically, this is actually a reasonable thing to do. This corresponds to mapping a proposition b to ¬ ¬ b. What’s cool here is that since each function call calls a continuation instead of returning its result, we can imagine that each function is just transferring control over to some other part of the program instead of returning. This leads to a very slick and efficient way of implementing CPSed function calls as we’ll see.
This means we’d change something like
into
To see what’s going on here, we
- Added an extra argument to fact, its return continuation
- In the first branch, we pass the result to the continuation instead of returning it
- In the next branch, we lift the nested expression n - 1 into a flat let binding
- We add an extra argument to the recursive call, the continuation
- In this continuation, we multiply the result of the recursive call by n (Note here that we did close over n, this lambda is a real lambda)
- Finally, we pass the final result to the original continuation k.
The only tree-style-nesting here comes from the top if expression, everything else is completely linear.
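As a rough illustration of the steps listed above, here is a sketch of a factorial function in direct style and in CPS, written in Haskell. This is an assumption about what such a pair could look like, not necessarily the exact code this post was describing; the name factCps and the answer type r are made up for the example.

```haskell
-- Direct style: the recursive call is nested inside the multiplication.
fact :: Integer -> Integer
fact n = if n == 0 then 1 else n * fact (n - 1)

-- CPS: the extra argument k is the return continuation. Nothing is
-- "returned"; every branch ends by calling a continuation.
factCps :: Integer -> (Integer -> r) -> r
factCps n k =
  if n == 0
    then k 1                               -- pass the result to k
    else
      let n' = n - 1                       -- nested expression lifted into a let
       in factCps n' (\r ->                -- continuation for the recursive call
            let res = n * r                -- closes over n
             in k res)                     -- hand the final result to k
```

Calling factCps 5 id runs the whole computation with the identity function as the final continuation.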
Let’s formalize this process by converting Simply Typed Lambda Calculus (STLC) to CPS form.
STLC to CPS
First things first, we specify an AST for normal STLC.
data Tp = Arr Tp Tp | Int deriving Show
data Op = Plus | Minus | Times | Divide
-- The Tp in App refers to the return type, it's helpful later
data Exp a = App (Exp a) (Exp a) Tp
           | Lam Tp (Scope () Exp a)
           | Num Int
           -- No need for binding here since we have Minus
           | IfZ (Exp a) (Exp a) (Exp a)
           | Binop Op (Exp a) (Exp a)
           | Var a
We’ve supplemented our lambda calculus with natural numbers and some binary operations because it makes things a bit more fun. Additionally, we’re using bound to deal with bindings for lambdas. This means there’s a terribly boring monad instance lying around that I won’t bore you with.
To convert to CPS, we first need to figure out how to convert our types. Since CPS functions never return, we want them to go to Void, the uninhabited type. However, since our language doesn’t allow Void outside of continuations, and doesn’t allow functions that don’t go to Void, let’s bundle them up into one new type Cont a, which is just a function a -> Void. However, this presents us with a problem: how do we turn an Arr a b into this style of function? It seems like our function should take two arguments, a and b -> Void, so that it can produce a Void of its own. However, this requires products since currying isn’t possible with the restriction that all functions return Void! Therefore, we supplement our CPS language with pairs and projections for them.
Now we can write the AST for CPS types and a conversion between Tp and it.
data CTp = Cont CTp | CInt | CPair CTp CTp
cpsTp :: Tp -> CTp
cpsTp (Arr l r) = Cont $ CPair (cpsTp l) (Cont (cpsTp r))
cpsTp Int = CInt
The only interesting thing here is how we translate function types, but we talked about that above. Now for expressions.
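To make the function case concrete, here is a small self-contained sketch that repeats the two type definitions and cpsTp from above and applies the translation to two example types. The deriving (Show, Eq) clause on CTp and the names intToInt and higherOrder are additions for this example only.

```haskell
data Tp = Arr Tp Tp | Int deriving Show
data CTp = Cont CTp | CInt | CPair CTp CTp deriving (Show, Eq)

cpsTp :: Tp -> CTp
cpsTp (Arr l r) = Cont $ CPair (cpsTp l) (Cont (cpsTp r))
cpsTp Int = CInt

-- Int -> Int becomes a continuation taking a pair of
-- (the argument, the return continuation).
intToInt :: CTp
intToInt = cpsTp (Arr Int Int)
-- Cont (CPair CInt (Cont CInt))

-- (Int -> Int) -> Int: the argument type is itself translated recursively.
higherOrder :: CTp
higherOrder = cpsTp (Arr (Arr Int Int) Int)
-- Cont (CPair (Cont (CPair CInt (Cont CInt))) (Cont CInt))
```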
We want to define a new data type that encapsulates the restrictions of CPS. In order to do this we factor out our data types into “flat expressions” and “CPS expressions”. Flat expressions are things like values and variables while CPS expressions contain things like “Jump to this continuation” or “Branch on this flat expression”. Finally, there are let expressions to perform various operations on expressions.
data LetBinder a = OpBind Op (FlatExp a) (FlatExp a)
                 | ProjL a
                 | ProjR a
                 | Pair (FlatExp a) (FlatExp a)
data FlatExp a = CNum Int | CVar a | CLam CTp a (CExp a)
data CExp a = Let a (LetBinder a) (CExp a)
            | CIf (FlatExp a) (CExp a) (CExp a)
            | Jump (FlatExp a) (FlatExp a)
            | Halt (FlatExp a)
Lets let us bind the results of a few “primitive operations” across values and variables to a fresh variable. This is where things like “incrementing a number” happen. Additionally, in order to create a pair or access its elements we need to use a Let.
Notice that here application is spelled Jump, hinting that it really is just a jmp and not dealing with the stack in any way. Since they’re all jumps, we can’t overflow the stack as would be an issue with a normal calling convention. To seal off the chain of function calls we have Halt; it takes a FlatExp and returns it as the result of the program.
Expressions here are also parameterized over variables but we can’t use bound with them (for reasons that deserve a blogpost-y rant :). Because of this we settle for just ensuring that each a is globally unique.
So now instead of having a bunch of nested Exps, we have flat expressions which compute exactly one thing and linearize the tree of expressions into a series of flat ones with let binders. It’s still not quite “linear” since both lambdas and if branches let us have something tree-like.
We can now define conversion to CPS with one major helper function
This takes an expression, a “continuation” and produces a CExp. We have some monad-gen stuff going on here because we need unique variables. The “continuation” is an actual Haskell function. So our function breaks an expression down to a FlatExp and then feeds it to the continuation.
The first two cases are easy since variables and numbers are already flat expressions, they go straight into the continuation.
For IfZ we first recurse on the i. Then once we have a flattened computation representing i, we use CIf and recurse on both branches, handing each of them the same continuation.
cps (Binop op l r) c =
  cps l $ \fl ->
  cps r $ \fr ->
  gen >>= \out ->
  Let out (OpBind op fl fr) <$> c (CVar out)
Like before, we use cps to recurse on the left and right sides of the expression. This gives us two flat expressions which we use with OpBind to compute the result and bind it to out. Now that we have a variable for the result we just toss it to the continuation.
cps (Lam tp body) c = do
  [pairArg, newCont, newArg] <- replicateM 3 gen
  let body' = instantiate1 (Var newArg) body
  cbody <- cps body' (return . Jump (CVar newCont))
  c (CLam (cpsTp tp) pairArg
      $ Let newArg (ProjL pairArg)
      $ Let newCont (ProjR pairArg)
      $ cbody)
Converting a lambda is a little bit more work. It needs to take a pair so a lot of the work is projecting out the left component (the argument) and the right component (the continuation). With those two things in hand we recurse in the body using the continuation supplied as an argument. The actual code makes this process look a little out of order. Remember that we only use cbody once we’ve bound the projections to newArg and newCont respectively.
cps (App l r tp) c = do
  arg <- gen
  cont <- CLam (cpsTp tp) arg <$> c (CVar arg)
  cps l $ \fl ->
    cps r $ \fr ->
    gen >>= \pair ->
    return $ Let pair (Pair fr cont) (Jump fl (CVar pair))
For application we just create a lambda for the current continuation. We then evaluate the left and right sides of the application using recursive calls. Now that we have a function to jump to, we create a pair of the argument and the continuation and bind it to a name. From there, we just jump to fl, the function. Turning the continuation into a lambda is a little strange; it’s also why we needed an annotation for App. The lambda uses the return type of the application and constructs a continuation that maps a to c a. Note that c a is a Haskell expression that produces a CExp a.
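To see the cases fit together, here is a self-contained sketch of the whole pass that compiles on its own. It makes several simplifying assumptions that differ from the code above: variables are plain Strings instead of using bound (so Lam carries its binder name, and source names are assumed to be unique and not to clash with generated ones), the fresh-name monad Gen is hand-rolled rather than taken from monad-gen, and the Var, Num and IfZ cases are written from the prose descriptions. The helper toCps is hypothetical as well.

```haskell
-- Tiny fresh-name monad: a hand-rolled stand-in for monad-gen (assumption).
newtype Gen a = Gen { runGen :: Int -> (a, Int) }
instance Functor Gen where
  fmap f (Gen g) = Gen $ \n -> let (a, n') = g n in (f a, n')
instance Applicative Gen where
  pure a = Gen $ \n -> (a, n)
  Gen f <*> Gen g = Gen $ \n ->
    let (h, n1) = f n
        (a, n2) = g n1
     in (h a, n2)
instance Monad Gen where
  Gen g >>= k = Gen $ \n -> let (a, n1) = g n in runGen (k a) n1

gen :: Gen String
gen = Gen $ \n -> ("v" ++ show n, n + 1)

-- Source language with named binders (assumption: no bound here).
data Tp = Arr Tp Tp | Int deriving (Show, Eq)
data Op = Plus | Minus | Times | Divide deriving (Show, Eq)
data Exp = App Exp Exp Tp | Lam String Tp Exp | Num Int
         | IfZ Exp Exp Exp | Binop Op Exp Exp | Var String

-- CPS language, specialised to String variables.
data CTp = Cont CTp | CInt | CPair CTp CTp deriving (Show, Eq)
data LetBinder = OpBind Op FlatExp FlatExp | ProjL String | ProjR String
               | Pair FlatExp FlatExp deriving (Show, Eq)
data FlatExp = CNum Int | CVar String | CLam CTp String CExp deriving (Show, Eq)
data CExp = Let String LetBinder CExp | CIf FlatExp CExp CExp
          | Jump FlatExp FlatExp | Halt FlatExp deriving (Show, Eq)

cpsTp :: Tp -> CTp
cpsTp (Arr l r) = Cont $ CPair (cpsTp l) (Cont (cpsTp r))
cpsTp Int = CInt

cps :: Exp -> (FlatExp -> Gen CExp) -> Gen CExp
-- Variables and numbers are already flat.
cps (Var v) c = c (CVar v)
cps (Num n) c = c (CNum n)
-- Flatten the scrutinee, then give both branches the same continuation.
cps (IfZ i t e) c = cps i $ \fi -> CIf fi <$> cps t c <*> cps e c
cps (Binop op l r) c =
  cps l $ \fl -> cps r $ \fr -> do
    out <- gen
    Let out (OpBind op fl fr) <$> c (CVar out)
cps (Lam x tp body) c = do
  pairArg <- gen
  newCont <- gen
  cbody <- cps body (return . Jump (CVar newCont))
  c (CLam (cpsTp tp) pairArg
      $ Let x (ProjL pairArg)
      $ Let newCont (ProjR pairArg) cbody)
cps (App l r tp) c = do
  arg <- gen
  cont <- CLam (cpsTp tp) arg <$> c (CVar arg)
  cps l $ \fl -> cps r $ \fr -> do
    pair <- gen
    return $ Let pair (Pair fr cont) (Jump fl (CVar pair))

-- Hypothetical entry point: the top-level continuation just halts.
toCps :: Exp -> CExp
toCps e = fst (runGen (cps e (return . Halt)) 0)
```

For instance, toCps (Binop Plus (Num 1) (Num 2)) binds the sum to a fresh variable and then halts with it. Note that the IfZ case duplicates whatever the continuation builds into both branches, which is the simplest thing to do but can grow the output.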
With this, we’ve written a nice little compiler pass to convert expressions into their CPS forms. By doing this we’ve “eliminated expressions”. Everything is now flat and evaluation basically proceeds by evaluating one small computation and using the result to compute another and another.
There are still some things left to compile out before this is machine code, though:
- Closures - these can be converted to explicitly pass records with closure conversion
- Hoist lambdas out of nested scope - this gets rid of anonymous functions, something we don’t have in C or assembly
- Make allocation explicit - Allocate a block of memory for a group of let statements and have them explicitly move the results of their computations to it
- Register allocation - Cleverly choose whether to store some particular variable in a register or load it in as needed.
Once we’ve done these steps we’ve basically written a compiler. However, they’re all influenced by the fact that we’ve compiled out expressions and (really) function calls with our conversion to CPS, it makes the process much much simpler.
Wrap Up
CPS conversion is a nice alternative to something like STG machines for lazy languages or SSA for imperative ones. As far as I’m aware the main SML compiler (SML/NJ) compiles code in this way. As does Ur/Web if I’m not crazy. Additionally, the course entitled “Higher Order, Typed Compilation” which is taught here at CMU uses CPS conversion to make compiling SML really quite pleasant.
In fact, someone (Andrew Appel?) once wrote a paper that noted that SSA and CPS are actually the same. The key difference is that in SSA we merge multiple blocks together using the phi function. In CPS, we just let multiple source blocks jump to the same destination block (continuation). You can see this in our conversion of IfZ to CPS, instead of using phi to merge in the two branches, they both just use the continuation to jump to the remainder of the program. It makes things a little simpler because no one person needs to worry about
Finally, if you’re compiling a language like Scheme with call/cc, using CPS conversion makes the whole thing completely trivial. All you do is define call/cc at the CPS level as
call/cc (f, c) = f ((λ (x, c') → c x), c)
So instead of using the continuation supplied to us in the expression we give to f, we use the one for the whole call/cc invocation! This causes us to not return into the body of f but instead to carry on the rest of the program as if f had returned whatever value x is. This is how my old Scheme compiler did things, I put figuring out how to implement call/cc off for a week before I realized it was a 10 minute job!
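The definition above can be transcribed almost directly into Haskell if we model CPS functions as ordinary functions taking a pair. The following is a sketch under that assumption: the type synonyms K and Fn, the answer type r (standing in for void), and the two example values are all made up for illustration.

```haskell
-- A continuation expecting an a, and a CPS function from a to b that
-- takes its argument paired with a return continuation.
type K r a = a -> r
type Fn r a b = (a, K r b) -> r

-- call/cc (f, c) = f ((λ (x, c') → c x), c)
-- The escape function ignores its own continuation c' and jumps to c,
-- the continuation of the whole call/cc invocation.
callCC :: (Fn r (Fn r a b) a, K r a) -> r
callCC (f, c) = f (\(x, _c') -> c x, c)

-- Escaping: the pending "+ 1" continuation is discarded.
escapes :: Int
escapes = callCC (body, id)
  where
    body :: (Fn Int Int Int, K Int Int) -> Int
    body (escape, _k) = escape (42, \v -> v + 1)

-- Not escaping: the body just returns normally through k.
returnsNormally :: Int
returnsNormally = callCC (body, id)
  where
    body :: (Fn Int Int Int, K Int Int) -> Int
    body (_escape, k) = k 7
```

Here escapes evaluates to 42 rather than 43, because invoking the escape function throws away the continuation it was called with.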
Hope this was helpful!