An Explanation of Type Inference for ML/Haskell
A couple of days ago I wrote a small implementation of a type inferencer for a mini ML language. It turns out there are very few explanations of how to do this properly, and the ones that exist tend to describe the really naive, super-exponential algorithm. I wrote the algorithm in SML, but nothing should be unfamiliar to the average Haskeller.
Type inference breaks down into essentially two components:
- Constraint Generation
- Unification
We inspect the program we’re trying to infer a type for and generate a bunch of statements (constraints) which are of the form
This type is equal to this type
These types have “unification variables” in them. These aren’t normal ML/Haskell type variables. They’re generated by the compiler, for the compiler, and will eventually be filled in with either
- A rigid polymorphic variable
- A normal concrete type
They should be thought of as holes in an otherwise normal type. For example, if we’re looking at the expression
f a
We first just say that f : 'f, where 'f is one of those unification variables I mentioned. Next we say that a : 'a. Since we’re applying f to a, we can generate the constraints
'f ~ 'x -> 'y
'a ~ 'x
We generate these because we can only apply things of the form _ -> _. We then unify these constraints to produce f : 'a -> 'y and a : 'a. We’d then use the surrounding constraints to produce more information about what exactly 'a and 'y might be. If these were all the constraints we had, we’d then “generalize” 'a and 'y to be normal type variables, making our expression have the type y where f : a -> y and a : a.
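To make “solving” concrete, here is a small sketch that is separate from the inferencer developed below. It writes the constraints for f a down as data and checks that the substitution described above ('x := 'a and 'f := 'a -> 'y) makes both sides of every constraint identical. The string-named ty type and the apply and satisfies helpers are hypothetical; the real implementation numbers its variables instead.

```sml
(* Hypothetical representation: unification variables named by strings. *)
datatype ty = TVar of string | TArr of ty * ty

(* Apply a substitution (an association list) to a type. Assumes the
   substitution is idempotent: no bound variable appears on a right-hand
   side, so a single pass is enough. *)
fun apply sub (TVar v) =
      (case List.find (fn (v', _) => v = v') sub of
           SOME (_, t) => t
         | NONE => TVar v)
  | apply sub (TArr (l, r)) = TArr (apply sub l, apply sub r)

(* A substitution solves a constraint list if it makes both sides of
   every constraint identical. *)
fun satisfies sub constrs =
  List.all (fn (l, r) => apply sub l = apply sub r) constrs

(* The constraints generated for `f a`. *)
val constrs = [ (TVar "f", TArr (TVar "x", TVar "y"))
              , (TVar "a", TVar "x") ]

(* The solution: 'x becomes 'a and 'f becomes 'a -> 'y. *)
val solution = [ ("x", TVar "a")
               , ("f", TArr (TVar "a", TVar "y")) ]

val ok = satisfies solution constrs
```

Here ok is true, while the empty substitution fails, since it leaves 'f and 'x -> 'y distinct.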
Now on to some specifics.
Set Up
In order to actually talk about type inference we first have to define our language. We have the abstract syntax tree:
type tvar = int
local val freshSource = ref 0 in
    fun fresh () : tvar =
        !freshSource before freshSource := !freshSource + 1
end
datatype monotype = TBool
                  | TArr of monotype * monotype
                  | TVar of tvar
datatype polytype = PolyType of int list * monotype
datatype exp = True
             | False
             | Var of int
             | App of exp * exp
             | Let of exp * exp
             | Fn of exp
             | If of exp * exp * exp
First we have type variables, which are globally unique integers. To give us a method for actually producing them we have fresh, which uses a ref cell to never return the same result twice. This is probably surprising to Haskellers: SML isn’t purely functional, and frankly this is less noisy than using something like monad-gen.
From there we have monotypes. These are normal ML types without any polymorphism: there are type/unification variables, booleans, and functions. Polytypes are just monotypes with an extra forall at the front, and this is where we get polymorphism from. A polytype binds a number of type variables, stored in this representation as an int list. There is one ambiguity here: when looking at a variable, it’s not clear whether it’s supposed to be a type variable (bound in a forall) or a unification variable. The idea is that we never ever inspect a type bound under a forall except when we’re converting it to a monotype with fresh unification variables in place of all of the bound variables. Thus, when inferring a type, every variable we come across is a unification variable.
Finally, we have expressions. Aside from the normal constants, we have variables, lambdas, applications, let, and if. We represent variables with de Bruijn indices: a variable is a number that tells you how many binders are between it and where it was bound. For example, const (fn x => fn y => x) would be written Fn (Fn (Var 1)) in this representation.
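Converting ordinary named syntax into this representation is a simple walk that carries the list of names in scope, innermost first. The sketch below assumes a hypothetical named datatype and toDeBruijn function that aren’t part of the post; it treats Let as non-recursive, matching how constrain handles Let later on (the bound expression is checked outside the new binder).

```sml
(* The expression type from above. *)
datatype exp = True
             | False
             | Var of int
             | App of exp * exp
             | Let of exp * exp
             | Fn of exp
             | If of exp * exp * exp

(* Hypothetical named syntax, used only to illustrate the conversion. *)
datatype named = NTrue
               | NFalse
               | NVar of string
               | NApp of named * named
               | NLet of string * named * named
               | NFn of string * named
               | NIf of named * named * named

exception Unbound of string

(* `env` lists the names in scope, innermost binder first, so a
   variable's de Bruijn index is its position in `env`. *)
fun index env x =
  let fun go [] _ = raise Unbound x
        | go (y :: ys) i = if x = y then i else go ys (i + 1)
  in go env 0 end

fun toDeBruijn env e =
  case e of
      NTrue => True
    | NFalse => False
    | NVar x => Var (index env x)
    | NApp (f, a) => App (toDeBruijn env f, toDeBruijn env a)
      (* the bound expression is converted outside the new binder *)
    | NLet (x, bound, body) =>
        Let (toDeBruijn env bound, toDeBruijn (x :: env) body)
    | NFn (x, body) => Fn (toDeBruijn (x :: env) body)
    | NIf (i, t, el) =>
        If (toDeBruijn env i, toDeBruijn env t, toDeBruijn env el)

(* const = fn x => fn y => x *)
val constDB = toDeBruijn [] (NFn ("x", NFn ("y", NVar "x")))
```

Inside the inner fn the environment is ["y", "x"], so x gets index 1 and constDB is Fn (Fn (Var 1)).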
With this in mind we define some helpful utility functions. When type checking, we have a context full of information. The two kinds of fact we can know about a variable are that it has a polytype or that it has a monotype:
datatype info = PolyTypeVar of polytype
              | MonoTypeVar of monotype
type context = info list
Here the ith element of a context is the piece of information we know about the ith de Bruijn variable. We’ll also need to substitute a type for a type variable, and to find all the free variables in a type.
fun subst ty' var ty =
    case ty of
        TVar var' => if var = var' then ty' else TVar var'
      | TArr (l, r) => TArr (subst ty' var l, subst ty' var r)
      | TBool => TBool
fun freeVars t =
    case t of
        TVar v => [v]
      | TArr (l, r) => freeVars l @ freeVars r
      | TBool => []
Both of these functions just recurse over types and do some work in the variable case. Note that the list freeVars returns can contain duplicates; this turns out not to matter anywhere except in generalizeMonoType. The basic idea is that, given a monotype with a bunch of unification variables and a surrounding context, we figure out which variables can be bound up in a polymorphic type. If they don’t appear in the surrounding context, we generalize them by binding them in a new polytype’s forall.
fun dedup [] = []
  | dedup (x :: xs) =
    if List.exists (fn y => x = y) xs
    then dedup xs
    else x :: dedup xs
fun generalizeMonoType ctx ty =
    let fun notMem xs x = List.all (fn y => x <> y) xs
        fun free (MonoTypeVar m) = freeVars m
          | free (PolyTypeVar (PolyType (bs, m))) =
            List.filter (notMem bs) (freeVars m)
        val ctxVars = List.concat (List.map free ctx)
        val polyVars = List.filter (notMem ctxVars) (freeVars ty)
    in PolyType (dedup polyVars, ty) end
Here the bulk of the code is deciding whether or not a variable is free in the surrounding context, using free. It looks at a piece of info to determine which variables occur free in it; for a polytype, the variables bound by its forall are excluded. We then accumulate all of these variables into ctxVars and use this list to decide what to generalize.
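As a quick check on which variables get generalized, here is a self-contained copy of the definitions above applied to one example. The scenario is made up: the context pretends that variable 0 belongs to an enclosing fn, so only variable 1 may be bound by the forall.

```sml
type tvar = int
datatype monotype = TBool
                  | TArr of monotype * monotype
                  | TVar of tvar
datatype polytype = PolyType of int list * monotype
datatype info = PolyTypeVar of polytype
              | MonoTypeVar of monotype
type context = info list

fun freeVars t =
  case t of
      TVar v => [v]
    | TArr (l, r) => freeVars l @ freeVars r
    | TBool => []

fun dedup [] = []
  | dedup (x :: xs) =
      if List.exists (fn y => x = y) xs then dedup xs else x :: dedup xs

fun generalizeMonoType (ctx : context) ty =
  let fun notMem xs x = List.all (fn y => x <> y) xs
      fun free (MonoTypeVar m) = freeVars m
        | free (PolyTypeVar (PolyType (bs, m))) =
            List.filter (notMem bs) (freeVars m)
      val ctxVars = List.concat (List.map free ctx)
      val polyVars = List.filter (notMem ctxVars) (freeVars ty)
  in PolyType (dedup polyVars, ty) end

(* Variable 0 is mentioned by the context, so it must stay a
   unification variable; variable 1 (which occurs twice) is bound once. *)
val ctx = [MonoTypeVar (TVar 0)]
val pty = generalizeMonoType ctx (TArr (TVar 0, TArr (TVar 1, TVar 1)))
```

Here pty is PolyType ([1], TArr (TVar 0, TArr (TVar 1, TVar 1))); with an empty context both variables would be bound, giving PolyType ([0, 1], ...).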
Next we need to take a polytype to a monotype. This is the specialization of a polymorphic type that we love and use when we use map on a function of type int -> double. It works by replacing each bound variable with a fresh unification variable, which is nicely handled by a fold!
fun mintNewMonoType (PolyType (ls, ty)) =
    foldl (fn (v, t) => subst (TVar (fresh ())) v t) ty ls
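The following sketch instantiates the identity function’s type twice to show that each use gets its own fresh variables. It repeats the definitions it needs so it runs on its own. The bound variable is itself drawn from fresh, mirroring how generalized variables arise during inference; since fresh never hands out that number again, a newly minted variable can’t collide with a bound variable that hasn’t been substituted yet.

```sml
type tvar = int
local val freshSource = ref 0 in
  fun fresh () : tvar =
    !freshSource before freshSource := !freshSource + 1
end
datatype monotype = TBool
                  | TArr of monotype * monotype
                  | TVar of tvar
datatype polytype = PolyType of int list * monotype

fun subst ty' var ty =
  case ty of
      TVar var' => if var = var' then ty' else TVar var'
    | TArr (l, r) => TArr (subst ty' var l, subst ty' var r)
    | TBool => TBool

fun mintNewMonoType (PolyType (ls, ty)) =
  foldl (fn (v, t) => subst (TVar (fresh ())) v t) ty ls

(* forall a. a -> a, with `a` itself produced by fresh *)
val a = fresh ()
val idTy = PolyType ([a], TArr (TVar a, TVar a))

(* Two uses of the same polytype: each gets its own variable. *)
val first = mintNewMonoType idTy
val second = mintNewMonoType idTy
```

Both results have the shape TArr (TVar n, TVar n), with a different n each time and neither equal to a.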
Last but not least, we have a function to take a context and a variable and give us a monotype which corresponds to it. This may produce a new monotype if we think the variable has a polytype.
exception UnboundVar of int
fun lookupVar var ctx =
    case (List.nth (ctx, var) handle Subscript => raise UnboundVar var) of
        PolyTypeVar pty => mintNewMonoType pty
      | MonoTypeVar mty => mty
For the sake of nice error messages, we also raise UnboundVar instead of just Subscript in the error case. Now that we’ve gone through all of the utility functions, on to unification!
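Here is a small sketch of lookupVar on a two-element context, again repeating the needed definitions. The context is hypothetical: index 0 stands for a lambda-bound Bool and index 1 for a let-bound identity function.

```sml
type tvar = int
local val freshSource = ref 0 in
  fun fresh () : tvar =
    !freshSource before freshSource := !freshSource + 1
end
datatype monotype = TBool
                  | TArr of monotype * monotype
                  | TVar of tvar
datatype polytype = PolyType of int list * monotype
datatype info = PolyTypeVar of polytype
              | MonoTypeVar of monotype

fun subst ty' var ty =
  case ty of
      TVar var' => if var = var' then ty' else TVar var'
    | TArr (l, r) => TArr (subst ty' var l, subst ty' var r)
    | TBool => TBool

fun mintNewMonoType (PolyType (ls, ty)) =
  foldl (fn (v, t) => subst (TVar (fresh ())) v t) ty ls

exception UnboundVar of int

(* Monotypes are returned as-is; polytypes are instantiated afresh. *)
fun lookupVar var ctx =
  case (List.nth (ctx, var) handle Subscript => raise UnboundVar var) of
      PolyTypeVar pty => mintNewMonoType pty
    | MonoTypeVar mty => mty

(* Index 0: a lambda-bound Bool. Index 1: forall a. a -> a. *)
val a = fresh ()
val ctx = [ MonoTypeVar TBool
          , PolyTypeVar (PolyType ([a], TArr (TVar a, TVar a))) ]
```

With this context, lookupVar 0 ctx always returns TBool, each call to lookupVar 1 ctx returns an arrow over a differently numbered variable, and lookupVar 2 ctx raises UnboundVar 2.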
Unification
A large part of this program is basically “I’ll give you a list of constraints and you give me the solution”. The program to solve these proceeds by pattern matching on the constraints.
In the empty case, we have no constraints so we give back the empty solution.
fun unify [] = []
In the next case we actually have to look at what constraint we’re trying to solve.
| unify (c :: constrs) =
case c of
If we’re lucky, we’re just trying to unify TBool with TBool. This does nothing since these types have no variables and are equal, so we just recurse.
(TBool, TBool) => unify constrs
If we’ve got two function types, we just constrain their domains and ranges to be the same and continue on unifying things.
| (TArr (l, r), TArr (l', r')) => unify ((l, l') :: (r, r') :: constrs)
Now we have to deal with finding a variable. We definitely want to avoid adding (TVar v, TVar v) to our solution, so we’ll have a special case for trying to unify two variables.
| (TVar i, TVar j) =>
if i = j
then unify constrs
else addSol i (TVar j) (unify (substConstrs (TVar j) i constrs))
This is our first time actually adding something to our solution, so there are several new elements here. The first is the function addSol. It’s defined as
fun addSol v ty sol = (v, applySol sol ty) :: sol
So, to keep our solution internally consistent, whenever we add a type to our solution we first apply the solution to it. This ensures that we can substitute a variable in our solution for its corresponding type and not worry about whether we need to do something further. Additionally, whenever we add a new binding we substitute for it in the constraints we have left, so we never end up with an inconsistent solution. This prevents us from unifying v ~ TBool and v ~ TArr (TBool, TBool) in the same solution! The code that does this is the substConstrs (TVar j) i constrs bit.
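A small sketch of both consistency measures follows. The post uses substConstrs without showing its definition, so the one below is an assumption: it substitutes on both sides of every remaining constraint. The example values are made up.

```sml
type tvar = int
datatype monotype = TBool
                  | TArr of monotype * monotype
                  | TVar of tvar
type constr = monotype * monotype
type sol = (tvar * monotype) list

fun subst ty' var ty =
  case ty of
      TVar var' => if var = var' then ty' else TVar var'
    | TArr (l, r) => TArr (subst ty' var l, subst ty' var r)
    | TBool => TBool

fun applySol (sol : sol) ty =
  foldl (fn ((v, t), acc) => subst t v acc) ty sol

(* From the post: resolve the new type before recording it. *)
fun addSol v ty (sol : sol) : sol = (v, applySol sol ty) :: sol

(* Assumed definition of substConstrs (not shown in the post):
   substitute ty for var on both sides of every remaining constraint. *)
fun substConstrs ty var (constrs : constr list) : constr list =
  map (fn (l, r) => (subst ty var l, subst ty var r)) constrs

(* Binding 0 := 1 -> 1 when the solution already has 1 := Bool records
   0 := Bool -> Bool, so no later lookup has to chase variable 1. *)
val extended = addSol 0 (TArr (TVar 1, TVar 1)) [(1, TBool)]

(* After 0 := Bool, the leftover constraint 0 ~ Bool -> Bool becomes
   Bool ~ Bool -> Bool, which unify rejects instead of quietly
   recording a second binding for 0. *)
val leftover = substConstrs TBool 0 [(TVar 0, TArr (TBool, TBool))]
```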
The next case is the general case for unifying a variable with some type. It looks very similar to this one.
| ((TVar i, ty) | (ty, TVar i)) =>
if occursIn i ty
then raise UnificationError c
else addSol i ty (unify (substConstrs ty i constrs))
Here we have the critical occursIn check. This checks whether a variable appears in a type and prevents us from making erroneous unifications like TVar a ~ TArr (TVar a, TVar a). This occurs check is actually very easy to implement:
fun occursIn v ty = List.exists (fn v' => v = v') (freeVars ty)
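The two cases below show what the check distinguishes. The self-application fn x => x x (Fn (App (Var 0, Var 0)) in our syntax) is the classic program that ends up asking for a constraint like the first one.

```sml
datatype monotype = TBool
                  | TArr of monotype * monotype
                  | TVar of int

fun freeVars t =
  case t of
      TVar v => [v]
    | TArr (l, r) => freeVars l @ freeVars r
    | TBool => []

fun occursIn v ty = List.exists (fn v' => v = v') (freeVars ty)

(* 'a ~ 'a -> 'a: must be rejected, no finite type solves it *)
val cyclic = occursIn 0 (TArr (TVar 0, TVar 0))
(* 'a ~ 'b -> bool: fine, 'a can simply be bound to 'b -> bool *)
val acyclic = occursIn 0 (TArr (TVar 1, TBool))
```

Here cyclic is true and acyclic is false.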
Finally we have one last case: the failure case. This is the catch-all case for if we try to unify two things that are obviously incompatible.
| _ => raise UnificationError c
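All together, including the type abbreviations, the UnificationError exception, and the substConstrs helper that the cases above rely on, that code is
type constr = monotype * monotype
type sol = (tvar * monotype) list
exception UnificationError of constr

fun applySol sol ty =
    foldl (fn ((v, ty), ty') => subst ty v ty') ty sol
fun applySolCxt sol cxt =
    let fun applyInfo i =
            case i of
                PolyTypeVar (PolyType (bs, m)) =>
                  PolyTypeVar (PolyType (bs, (applySol sol m)))
              | MonoTypeVar m => MonoTypeVar (applySol sol m)
    in map applyInfo cxt end
fun addSol v ty sol = (v, applySol sol ty) :: sol
fun occursIn v ty = List.exists (fn v' => v = v') (freeVars ty)
(* one straightforward definition of substConstrs: substitute ty for
   var on both sides of every remaining constraint *)
fun substConstrs ty var constrs =
    map (fn (l, r) => (subst ty var l, subst ty var r)) constrs
fun unify ([] : constr list) : sol = []
  | unify (c :: constrs) =
    case c of
        (TBool, TBool) => unify constrs
      | (TVar i, TVar j) =>
          if i = j
          then unify constrs
          else addSol i (TVar j) (unify (substConstrs (TVar j) i constrs))
      (* or-patterns are an SML/NJ extension, not Standard ML '97 *)
      | ((TVar i, ty) | (ty, TVar i)) =>
          if occursIn i ty
          then raise UnificationError c
          else addSol i ty (unify (substConstrs ty i constrs))
      | (TArr (l, r), TArr (l', r')) =>
          unify ((l, l') :: (r, r') :: constrs)
      | _ => raise UnificationError c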
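To try unify on its own, here is a self-contained copy with one small change: the or-pattern case is split into two arms sharing a local helper, bind, because or-patterns are an SML/NJ extension rather than part of Standard ML ’97. substConstrs is given an assumed definition that substitutes on both sides of every constraint.

```sml
type tvar = int
datatype monotype = TBool
                  | TArr of monotype * monotype
                  | TVar of tvar
type constr = monotype * monotype
type sol = (tvar * monotype) list
exception UnificationError of constr

fun subst ty' var ty =
  case ty of
      TVar var' => if var = var' then ty' else TVar var'
    | TArr (l, r) => TArr (subst ty' var l, subst ty' var r)
    | TBool => TBool

fun freeVars t =
  case t of
      TVar v => [v]
    | TArr (l, r) => freeVars l @ freeVars r
    | TBool => []

fun applySol (sol : sol) ty =
  foldl (fn ((v, t), acc) => subst t v acc) ty sol
fun addSol v ty (sol : sol) : sol = (v, applySol sol ty) :: sol
fun occursIn v ty = List.exists (fn v' => v = v') (freeVars ty)
(* Assumed definition: substitute on both sides of each constraint. *)
fun substConstrs ty var (constrs : constr list) : constr list =
  map (fn (l, r) => (subst ty var l, subst ty var r)) constrs

fun unify ([] : constr list) : sol = []
  | unify (c :: constrs) =
      let
        (* Shared by the two "variable on one side" arms below. *)
        fun bind i ty =
          if occursIn i ty
          then raise UnificationError c
          else addSol i ty (unify (substConstrs ty i constrs))
      in
        case c of
            (TBool, TBool) => unify constrs
          | (TVar i, TVar j) =>
              if i = j
              then unify constrs
              else addSol i (TVar j) (unify (substConstrs (TVar j) i constrs))
          | (TVar i, ty) => bind i ty
          | (ty, TVar i) => bind i ty
          | (TArr (l, r), TArr (l', r')) =>
              unify ((l, l') :: (r, r') :: constrs)
          | _ => raise UnificationError c
      end

(* 0 ~ 1 -> bool and 1 ~ bool *)
val solved = unify [(TVar 0, TArr (TVar 1, TBool)), (TVar 1, TBool)]
```

Here solved is [(0, TArr (TBool, TBool)), (1, TBool)]: the binding for 0 already has the solution for 1 applied, as addSol guarantees. Unifying TVar 0 with TArr (TVar 0, TVar 0), or TBool with an arrow, raises UnificationError.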
Constraint Generation
The other half of this algorithm is constraint generation. We generate constraints and use unify to turn them into solutions. This boils down to two functions. The first glues together solutions.
fun <+> (sol1, sol2) =
    let fun notInSol1 v = List.all (fn (v', _) => v <> v') sol1
        val sol2' = List.filter (fn (v, _) => notInSol1 v) sol2
    in
        map (fn (v, ty) => (v, applySol sol2 ty)) sol1 @ sol2'
    end
infixr 3 <+>
Here sol2 is always the solution found after sol1. We first apply sol2 to every type in sol1, so that anything sol1 solved in terms of a variable that sol2 later pinned down gets updated. We then add the bindings from sol2 for variables that sol1 doesn’t already bind. Applying the combined solution then has the same effect as applying sol1 followed by sol2. This doesn’t check that the two solutions are actually consistent; that is done elsewhere.
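One way to sanity-check any function that glues solutions together is the law that applying the combined solution should match applying the earlier solution and then the later one. The sketch below checks that law for a hypothetical compose written in the textbook style (apply the later solution to the earlier one’s types, then add the later solution’s remaining bindings). It also shows why order matters: applySol folds over the list front to back, so putting the later bindings first leaves a variable unresolved.

```sml
type tvar = int
datatype monotype = TBool
                  | TArr of monotype * monotype
                  | TVar of tvar
type sol = (tvar * monotype) list

fun subst ty' var ty =
  case ty of
      TVar var' => if var = var' then ty' else TVar var'
    | TArr (l, r) => TArr (subst ty' var l, subst ty' var r)
    | TBool => TBool

(* Same as the post's applySol: bindings are applied front to back. *)
fun applySol (sol : sol) ty =
  foldl (fn ((v, t), acc) => subst t v acc) ty sol

(* Hypothetical textbook composition, with sol1 found before sol2. *)
fun compose (sol1 : sol, sol2 : sol) : sol =
  let fun notInSol1 v = List.all (fn (v', _) => v <> v') sol1
  in map (fn (v, t) => (v, applySol sol2 t)) sol1
     @ List.filter (fn (v, _) => notInSol1 v) sol2
  end

(* sol1 learned 0 := 1 -> 1; sol2 later learned 1 := Bool. *)
val sol1 : sol = [(0, TArr (TVar 1, TVar 1))]
val sol2 : sol = [(1, TBool)]
val t = TArr (TVar 0, TVar 1)

(* The law: both sides are (Bool -> Bool) -> Bool. *)
val lhs = applySol (compose (sol1, sol2)) t
val rhs = applySol sol2 (applySol sol1 t)

(* Later bindings first: 1 := Bool runs before 0 := 1 -> 1 introduces
   variable 1, so the result still mentions it. *)
val stale = applySol (sol2 @ sol1) (TVar 0)
```

Here compose (sol1, sol2) is [(0, TArr (TBool, TBool)), (1, TBool)], while stale comes out as TArr (TVar 1, TVar 1).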
Next is the main function here, constrain. It generates a type and a solution given a context and an expression. The first few cases are nice and simple:
fun constrain ctx True = (TBool, [])
  | constrain ctx False = (TBool, [])
  | constrain ctx (Var i) = (lookupVar i ctx, [])
In these cases we don’t infer any constraints; we just figure out types based on information we already know. Next, for Fn we generate a fresh variable to represent the argument’s type and just constrain the body.
  | constrain ctx (Fn body) =
    let val argTy = TVar (fresh ())
        val (rTy, sol) = constrain (MonoTypeVar argTy :: ctx) body
    in (TArr (applySol sol argTy, rTy), sol) end
Once we have the solution for the body, we apply it to the argument type, which might replace it with a concrete type if the constraints we inferred for the body demand it. For If we do something similar, except we add a few constraints of our own to solve.
  | constrain ctx (If (i, t, e)) =
    let val (iTy, sol1) = constrain ctx i
        val (tTy, sol2) = constrain (applySolCxt sol1 ctx) t
        val (eTy, sol3) = constrain (applySolCxt (sol1 <+> sol2) ctx) e
        val sol = sol1 <+> sol2 <+> sol3
        val sol = sol <+> unify [ (applySol sol iTy, TBool)
                                , (applySol sol tTy, applySol sol eTy)]
    in
        (tTy, sol)
    end
Notice how we apply each solution to the context for the next thing we’re constraining. This is how we ensure that each solution will be consistent. Once we’ve generated solutions to the constraints in each of the subterms, we smash them together to produce the first solution. Next, we ensure that the subcomponents have the right types by generating constraints that iTy is a bool and that tTy and eTy (the types of the branches) are the same. We have to carefully apply sol to each of these before unifying them to make sure our solution stays consistent.
The App case is practically the same:
  | constrain ctx (App (l, r)) =
    let val (domTy, ranTy) = (TVar (fresh ()), TVar (fresh ()))
        val (funTy, sol1) = constrain ctx l
        val (argTy, sol2) = constrain (applySolCxt sol1 ctx) r
        val sol = sol1 <+> sol2
        val sol = sol <+> unify [(applySol sol funTy,
                                  applySol sol (TArr (domTy, ranTy)))
                                , (applySol sol argTy, applySol sol domTy)]
    in (ranTy, sol) end
The only real difference here is that we generate different constraints: we make sure we’re applying a function whose domain is the same as the argument type.
The most interesting case here is Let. This implements let generalization, which is how we actually get polymorphism. After inferring the type of the thing we’re binding, we generalize it, giving us a polytype to use in the body of the let. The key to generalizing it is the generalizeMonoType we had before.
  | constrain ctx (Let (e, body)) =
    let val (eTy, sol1) = constrain ctx e
        val ctx' = applySolCxt sol1 ctx
        val eTy' = generalizeMonoType ctx' (applySol sol1 eTy)
        val (rTy, sol2) = constrain (PolyTypeVar eTy' :: ctx') body
    in (rTy, sol1 <+> sol2) end
We do pretty much everything we had before, except now we carefully apply the solution we get for the bound expression to the context, and then generalize its type with respect to that new context. This is how we actually get polymorphism: the let-bound variable gets a proper polymorphic type in the body.
That wraps up constraint generation. Now all that’s left to see is the overall driver for type inference.
fun infer e =
    let val (ty, sol) = constrain [] e
    in generalizeMonoType [] (applySol sol ty) end
So all we do is infer and then generalize a type! And there you have it: that’s the essence of how ML and Haskell do type inference.
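For experimenting, the pieces can be assembled into one self-contained program. This is a sketch rather than a verbatim copy: substConstrs and the constr/sol abbreviations are assumed definitions, the or-pattern in unify is split into two arms, and <+> is written as composition that applies the later solution to the earlier one’s types. Expressions use the de Bruijn syntax from the start of the post.

```sml
type tvar = int
local val freshSource = ref 0 in
  fun fresh () : tvar = !freshSource before freshSource := !freshSource + 1
end
datatype monotype = TBool | TArr of monotype * monotype | TVar of tvar
datatype polytype = PolyType of int list * monotype
datatype exp = True | False | Var of int | App of exp * exp
             | Let of exp * exp | Fn of exp | If of exp * exp * exp
datatype info = PolyTypeVar of polytype | MonoTypeVar of monotype
type constr = monotype * monotype
type sol = (tvar * monotype) list
exception UnboundVar of int
exception UnificationError of constr

(* Utilities *)
fun subst ty' var (TVar v) = if v = var then ty' else TVar v
  | subst ty' var (TArr (l, r)) = TArr (subst ty' var l, subst ty' var r)
  | subst _ _ TBool = TBool
fun freeVars (TVar v) = [v]
  | freeVars (TArr (l, r)) = freeVars l @ freeVars r
  | freeVars TBool = []
fun notMem xs x = List.all (fn y => x <> y) xs
fun dedup [] = []
  | dedup (x :: xs) = if List.exists (fn y => x = y) xs then dedup xs else x :: dedup xs
fun generalizeMonoType ctx ty =
  let fun free (MonoTypeVar m) = freeVars m
        | free (PolyTypeVar (PolyType (bs, m))) = List.filter (notMem bs) (freeVars m)
      val ctxVars = List.concat (map free ctx)
  in PolyType (dedup (List.filter (notMem ctxVars) (freeVars ty)), ty) end
fun mintNewMonoType (PolyType (ls, ty)) =
  foldl (fn (v, t) => subst (TVar (fresh ())) v t) ty ls
fun lookupVar var ctx =
  case (List.nth (ctx, var) handle Subscript => raise UnboundVar var) of
      PolyTypeVar pty => mintNewMonoType pty
    | MonoTypeVar mty => mty

(* Unification *)
fun applySol (sol : sol) ty = foldl (fn ((v, t), acc) => subst t v acc) ty sol
fun applySolCxt sol ctx =
  map (fn PolyTypeVar (PolyType (bs, m)) => PolyTypeVar (PolyType (bs, applySol sol m))
        | MonoTypeVar m => MonoTypeVar (applySol sol m)) ctx
fun addSol v ty sol = (v, applySol sol ty) :: sol
fun occursIn v ty = List.exists (fn v' => v = v') (freeVars ty)
fun substConstrs ty var cs = map (fn (l, r) => (subst ty var l, subst ty var r)) cs
fun unify ([] : constr list) : sol = []
  | unify (c :: cs) =
      let fun bind i ty =
            if occursIn i ty then raise UnificationError c
            else addSol i ty (unify (substConstrs ty i cs))
      in case c of
             (TBool, TBool) => unify cs
           | (TVar i, TVar j) => if i = j then unify cs else bind i (TVar j)
           | (TVar i, ty) => bind i ty
           | (ty, TVar i) => bind i ty
           | (TArr (l, r), TArr (l', r')) => unify ((l, l') :: (r, r') :: cs)
           | _ => raise UnificationError c
      end

(* Constraint generation; sol2 is always the solution found later. *)
infixr 3 <+>
fun sol1 <+> sol2 =
  map (fn (v, t) => (v, applySol sol2 t)) sol1
  @ List.filter (fn (v, _) => notMem (map (fn (v', _) => v') sol1) v) sol2

fun constrain _ True = (TBool, [])
  | constrain _ False = (TBool, [])
  | constrain ctx (Var i) = (lookupVar i ctx, [])
  | constrain ctx (Fn body) =
      let val argTy = TVar (fresh ())
          val (rTy, sol) = constrain (MonoTypeVar argTy :: ctx) body
      in (TArr (applySol sol argTy, rTy), sol) end
  | constrain ctx (If (i, t, e)) =
      let val (iTy, sol1) = constrain ctx i
          val (tTy, sol2) = constrain (applySolCxt sol1 ctx) t
          val (eTy, sol3) = constrain (applySolCxt (sol1 <+> sol2) ctx) e
          val sol = sol1 <+> sol2 <+> sol3
          val sol = sol <+> unify [ (applySol sol iTy, TBool)
                                  , (applySol sol tTy, applySol sol eTy) ]
      in (tTy, sol) end
  | constrain ctx (App (l, r)) =
      let val (domTy, ranTy) = (TVar (fresh ()), TVar (fresh ()))
          val (funTy, sol1) = constrain ctx l
          val (argTy, sol2) = constrain (applySolCxt sol1 ctx) r
          val sol = sol1 <+> sol2
          val sol = sol <+> unify [ (applySol sol funTy, applySol sol (TArr (domTy, ranTy)))
                                  , (applySol sol argTy, applySol sol domTy) ]
      in (ranTy, sol) end
  | constrain ctx (Let (e, body)) =
      let val (eTy, sol1) = constrain ctx e
          val ctx' = applySolCxt sol1 ctx
          val eTy' = generalizeMonoType ctx' (applySol sol1 eTy)
          val (rTy, sol2) = constrain (PolyTypeVar eTy' :: ctx') body
      in (rTy, sol1 <+> sol2) end

fun infer e =
  let val (ty, sol) = constrain [] e
  in generalizeMonoType [] (applySol sol ty) end
```

For instance, infer (Fn (If (Var 0, Var 0, True))) gives PolyType ([], TArr (TBool, TBool)), and infer (Let (Fn (Var 0), App (App (Var 0, Var 0), True))) gives PolyType ([], TBool), which only typechecks because the let-bound identity is used at two different types. The exact variable numbers in polymorphic results, such as the one for Fn (Var 0), depend on fresh.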
Wrap Up
Hopefully that clears up a little of the magic of how type inference works. The next challenge is to figure out how to do type inference on a language with patterns and ADTs! This is actually quite fun: pattern checking involves synthesizing a type from a pattern, which needs something like linear logic to handle pattern variables correctly.
With this we’re actually a solid 70% of the way to building a type checker for SML. Until I have more free time, though, I leave this as an exercise to the curious reader.
Cheers,