An Explanation of Type Inference for ML/Haskell

Posted on February 28, 2015
Tags: sml, haskell, types

A couple of days ago I wrote a small implementation of a type inferencer for a mini ML language. It turns out there are very few explanations of how to do this properly, and the ones that exist tend to present the really naive, super-exponential algorithm. I wrote the algorithm in SML, but nothing should be unfamiliar to the average Haskeller.

Type inference breaks down into essentially two components:

  1. Constraint Generation
  2. Unification

We inspect the program we’re trying to infer a type for and generate a bunch of statements (constraints) which are of the form

This type is equal to this type

These types have “unification variables” in them. These aren’t normal ML/Haskell type variables. They’re generated by the compiler, for the compiler, and will eventually be filled in with either

  1. A rigid polymorphic variable
  2. A normal concrete type

They should be thought of as holes in an otherwise normal type. For example, if we’re looking at the expression

   f a

We first just say that f : 'f where 'f is one of those unification variables I mentioned. Next we say that a : 'a. Since we’re applying f to a, and we can only apply things of the form _ -> _, we can generate the constraints

'f ~ 'x -> 'y
'a ~ 'x

We then unify these constraints to produce f : 'x -> 'y and a : 'x. We’d then use the surrounding constraints to find out more about what exactly 'x and 'y might be. If these were all the constraints we had, we’d then “generalize” 'x and 'y into normal type variables, making our expression have the type y where f : x -> y and a : x.

Now onto some specifics

Set Up

In order to actually talk about type inference we first have to define our language. Here are the type variables, the types, and the abstract syntax tree:

    type tvar = int
    local val freshSource = ref 0 in
    fun fresh () : tvar =
        !freshSource before freshSource := !freshSource + 1
    end


    datatype monotype = TBool
                      | TArr of monotype * monotype
                      | TVar of tvar
    datatype polytype = PolyType of int list * monotype

    datatype exp = True
                 | False
                 | Var of int
                 | App of exp * exp
                 | Let of exp * exp
                 | Fn of exp
                 | If of exp * exp * exp

First we have type variables which are globally unique integers. To give us a method for actually producing them we have fresh which uses a ref-cell to never return the same result twice. This is probably surprising to Haskellers: SML isn’t purely functional and frankly this is less noisy than using something like monad-gen.

From there we have monotypes. These are normal ML types without any polymorphism: type/unification variables, booleans, and functions. Polytypes are just monotypes with an extra forall at the front; this is where we get polymorphism from. A polytype binds a number of type variables, stored in this representation as an int list. There is one ambiguity here: when looking at a variable, it’s not clear whether it’s supposed to be a type variable (bound in a forall) or a unification variable. The idea is that we never ever inspect a type bound under a forall except when we’re converting it to a monotype with fresh unification variables in place of all of the bound variables. Thus, when inferring a type, every variable we come across is a unification variable.

Finally, we have expressions. Aside from the normal constants, we have variables, lambdas, applications, let, and if. Variables are represented with de Bruijn indices: a variable is a number that tells you how many binders sit between it and the binder that introduced it. For example, const (fn x => fn y => x) would be written Fn (Fn (Var 1)) in this representation. In Let (e, body), the new variable is in scope in body but not in e.
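A quick way to get a feel for the de Bruijn encoding is to write a few terms down and check that every index really points at an enclosing binder. The closedAt and closed functions below are hypothetical helpers for illustration, not part of the post’s inferencer. Following constrain, a Let’s new variable is treated as in scope in its body only.

```sml
(* The post's expression type, copied so this sketch stands alone *)
datatype exp = True
             | False
             | Var of int
             | App of exp * exp
             | Let of exp * exp
             | Fn of exp
             | If of exp * exp * exp

(* Hypothetical helper: depth counts the binders we are under
   (each Fn, plus the body of each Let); Var i is bound iff i < depth *)
fun closedAt depth e =
    case e of
        True => true
      | False => true
      | Var i => i < depth
      | App (l, r) => closedAt depth l andalso closedAt depth r
      | Let (bound, body) =>
        closedAt depth bound andalso closedAt (depth + 1) body
      | Fn body => closedAt (depth + 1) body
      | If (c, t, f) =>
        closedAt depth c andalso closedAt depth t andalso closedAt depth f

fun closed e = closedAt 0 e

(* const = fn x => fn y => x: one binder (y) sits between x and its binder *)
val const = Fn (Fn (Var 1))
(* let id = fn x => x in id true *)
val letId = Let (Fn (Var 0), App (Var 0, True))
(* fn x => ?: index 1 points past every binder, so the term is open *)
val open1 = Fn (Var 1)
```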

With this in mind we define some helpful utility functions. When type checking, we have a context full of information. There are two kinds of fact we can know about a variable:

    datatype info = PolyTypeVar of polytype
                  | MonoTypeVar of monotype

    type context = info list

Here the ith element of a context is the piece of information we know about the ith de Bruijn variable. We’ll also need to substitute a type for a type variable, and to find all the free variables in a type.

    fun subst ty' var ty =
        case ty of
            TVar var' => if var = var' then ty' else TVar var'
          | TArr (l, r) => TArr (subst ty' var l, subst ty' var r)
          | TBool => TBool

    fun freeVars t =
        case t of
            TVar v => [v]
          | TArr (l, r) => freeVars l @ freeVars r
          | TBool => []
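A tiny sketch of these two functions in use; the variable numbers are arbitrary choices for illustration. Note the argument order of subst: the replacement type comes first, then the variable being replaced, then the type to substitute into.

```sml
(* Copied from the post so the sketch stands alone *)
type tvar = int
datatype monotype = TBool
                  | TArr of monotype * monotype
                  | TVar of tvar

fun subst ty' var ty =
    case ty of
        TVar var' => if var = var' then ty' else TVar var'
      | TArr (l, r) => TArr (subst ty' var l, subst ty' var r)
      | TBool => TBool

fun freeVars t =
    case t of
        TVar v => [v]
      | TArr (l, r) => freeVars l @ freeVars r
      | TBool => []

(* Replacing variable 0 with bool in 0 -> 1 gives bool -> 1 *)
val substituted = subst TBool 0 (TArr (TVar 0, TVar 1))

(* freeVars keeps duplicates: 0 -> 0 reports variable 0 twice *)
val dupVars = freeVars (TArr (TVar 0, TVar 0))
```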

Both of these functions just recurse over types and do some work at the variable case. Note that freeVars can return duplicates; this turns out not to matter in all cases except one: generalizeMonoType. The basic idea is that, given a monotype with a bunch of unification variables and a surrounding context, we figure out which variables can be bound up in a polymorphic type. If they don’t appear in the surrounding context, we generalize them by binding them in the new polytype’s forall.

    fun dedup [] = []
      | dedup (x :: xs) =
        if List.exists (fn y => x = y) xs
        then dedup xs
        else x :: dedup xs

    fun generalizeMonoType ctx ty =
        let fun notMem xs x = List.all (fn y => x <> y) xs
            fun free (MonoTypeVar m) = freeVars m
              | free (PolyTypeVar (PolyType (bs, m))) =
                List.filter (notMem bs) (freeVars m)

            val ctxVars = List.concat (List.map free ctx)
            val polyVars = List.filter (notMem ctxVars) (freeVars ty)
        in PolyType (dedup polyVars, ty) end

Here the bulk of the code is deciding whether or not a variable is free in the surrounding context, using free, which looks at a piece of info to determine which variables occur in it. We then accumulate all of these variables into ctxVars and use this list to decide what to generalize.
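Here is a small sketch exercising generalizeMonoType on a few hand-built contexts. The variable numbers are arbitrary, and the definitions are copied from above so the block stands alone.

```sml
type tvar = int
datatype monotype = TBool | TArr of monotype * monotype | TVar of tvar
datatype polytype = PolyType of int list * monotype
datatype info = PolyTypeVar of polytype | MonoTypeVar of monotype

fun freeVars (TVar v) = [v]
  | freeVars (TArr (l, r)) = freeVars l @ freeVars r
  | freeVars TBool = []

fun dedup [] = []
  | dedup (x :: xs) =
    if List.exists (fn y => x = y) xs then dedup xs else x :: dedup xs

fun generalizeMonoType ctx ty =
    let fun notMem xs x = List.all (fn y => x <> y) xs
        fun free (MonoTypeVar m) = freeVars m
          | free (PolyTypeVar (PolyType (bs, m))) =
            List.filter (notMem bs) (freeVars m)
        val ctxVars = List.concat (List.map free ctx)
        val polyVars = List.filter (notMem ctxVars) (freeVars ty)
    in PolyType (dedup polyVars, ty) end

(* 'a -> 'b where the context already mentions 'b: only 'a is generalized *)
val p1 = generalizeMonoType [MonoTypeVar (TVar 1)] (TArr (TVar 0, TVar 1))

(* 'a -> 'a in an empty context: dedup leaves a single binder for 'a *)
val p2 = generalizeMonoType [] (TArr (TVar 0, TVar 0))

(* 'b is free in a polytype from the context (only 'c is bound there),
   so 'b still can't be generalized *)
val p3 = generalizeMonoType [PolyTypeVar (PolyType ([2], TArr (TVar 2, TVar 1)))]
                            (TArr (TVar 0, TVar 1))
```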

Next we need to take a polytype to a monotype. This is the specialization of a polymorphic type that we love and use whenever, say, we apply map to a function of type int -> double. It works by replacing each bound variable with a fresh unification variable, which is nicely handled by a fold:

    fun mintNewMonoType (PolyType (ls, ty)) =
        foldl (fn (v, t) => subst (TVar (fresh ())) v t) ty ls

Last but not least, we have a function that takes a variable and a context and gives us the monotype corresponding to that variable. If the variable has a polytype, this instantiates it with mintNewMonoType, so each lookup produces a monotype with new unification variables.

    exception UnboundVar of int
    fun lookupVar var ctx =
        case List.nth (ctx, var) handle Subscript => raise UnboundVar var of
            PolyTypeVar pty => mintNewMonoType pty
          | MonoTypeVar mty => mty

For the sake of nice error messages, we also raise UnboundVar instead of just Subscript in the error case. Now that we’ve gone through all of the utility functions, on to unification!
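The sketch below looks variables up in a small hand-built context to show the difference between the two kinds of info: each lookup of a polytype mints fresh variables, while a monotype comes back as is. The context is hypothetical, and the scrutinee of the case is parenthesized for readability.

```sml
(* Definitions copied from the post so the sketch stands alone *)
type tvar = int
local val freshSource = ref 0 in
fun fresh () : tvar =
    !freshSource before freshSource := !freshSource + 1
end

datatype monotype = TBool | TArr of monotype * monotype | TVar of tvar
datatype polytype = PolyType of int list * monotype
datatype info = PolyTypeVar of polytype | MonoTypeVar of monotype

fun subst ty' var ty =
    case ty of
        TVar var' => if var = var' then ty' else TVar var'
      | TArr (l, r) => TArr (subst ty' var l, subst ty' var r)
      | TBool => TBool

fun mintNewMonoType (PolyType (ls, ty)) =
    foldl (fn (v, t) => subst (TVar (fresh ())) v t) ty ls

exception UnboundVar of int
fun lookupVar var ctx =
    case (List.nth (ctx, var) handle Subscript => raise UnboundVar var) of
        PolyTypeVar pty => mintNewMonoType pty
      | MonoTypeVar mty => mty

(* Hypothetical context: index 0 is a variable already known to be bool,
   index 1 is a let-bound id : forall a. a -> a *)
val a = fresh ()
val ctx = [MonoTypeVar TBool,
           PolyTypeVar (PolyType ([a], TArr (TVar a, TVar a)))]

val use1 = lookupVar 1 ctx  (* 'b -> 'b for a fresh 'b *)
val use2 = lookupVar 1 ctx  (* 'c -> 'c: each lookup mints new variables *)
val arg = lookupVar 0 ctx   (* monotypes come back unchanged: TBool *)
```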

Unification

A large part of this program is basically “I’ll give you a list of constraints and you give me the solution”. The program to solve these proceeds by pattern matching on the constraints.

In the empty case, we have no constraints so we give back the empty solution.

    fun unify [] = []

In the next case we actually have to look at what constraint we’re trying to solve.

      | unify (c :: constrs) =
        case c of

If we’re lucky, we’re just trying to unify TBool with TBool. This does nothing since these types have no variables and are equal, so we just recurse.

       (TBool, TBool) => unify constrs

If we’ve got two function types, we just constrain their domains and ranges to be the same and continue on unifying things.

     | (TArr (l, r), TArr (l', r')) => unify ((l, l') :: (r, r') :: constrs)

Now we have to deal with finding a variable. We definitely want to avoid adding (TVar v, TVar v) to our solution, so we’ll have a special case for trying to unify two variables.

     | (TVar i, TVar j) =>
       if i = j
       then unify constrs
       else addSol i (TVar j) (unify (substConstrs (TVar j) i constrs))

This is our first time actually adding something to our solution, so there are several new elements here. The first is the function addSol. It’s defined as

    fun addSol v ty sol = (v, applySol sol ty) :: sol

So in order to make sure our solution is internally consistent, it’s important that whenever we add a type to our solution we first apply the solution to it. This ensures that we can replace a variable bound in our solution with its corresponding type and not worry about whether we need to do something further. Additionally, whenever we add a new binding we substitute for it in the constraints we have left, to ensure we never end up with an inconsistent solution. This prevents us from unifying v ~ TBool and v ~ TArr (TBool, TBool) in the same solution! The code that does this is the substConstrs (TVar j) i constrs bit.

The next case is the general case for unifying a variable with some type. It looks very similar to this one.

     | ((TVar i, ty) | (ty, TVar i)) =>
       if occursIn i ty
       then raise UnificationError c
       else addSol i ty (unify (substConstrs ty i constrs))

Here we have the critical occursIn check. This checks whether a variable appears in a type, and prevents us from making erroneous unifications like TVar a ~ TArr (TVar a, TVar a), which could only be solved by an infinite type. The occurs check is actually very easy to implement:

    fun occursIn v ty = List.exists (fn v' => v = v') (freeVars ty)

Finally we have one last case: the failure case. This is the catch-all case for if we try to unify two things that are obviously incompatible.

     | _ => raise UnificationError c

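All together, that code is below, along with the constraint and solution types, the UnificationError exception, and substConstrs, which the snippets above use without defining:

    type constr = monotype * monotype
    type sol = (tvar * monotype) list
    exception UnificationError of constr

    (* Substitute ty for var on both sides of every remaining constraint *)
    fun substConstrs ty var constrs =
        map (fn (l, r) => (subst ty var l, subst ty var r)) constrs

    (* <+> (below) keeps the later solution at the front of the list, so
       apply bindings back to front: the oldest ones are substituted first *)
    fun applySol sol ty =
        foldr (fn ((v, t), acc) => subst t v acc) ty sol
    fun applySolCxt sol cxt =
        let fun applyInfo i =
                case i of
                    PolyTypeVar (PolyType (bs, m)) =>
                    PolyTypeVar (PolyType (bs, (applySol sol m)))
                  | MonoTypeVar m => MonoTypeVar (applySol sol m)
        in map applyInfo cxt end

    fun addSol v ty sol = (v, applySol sol ty) :: sol

    fun occursIn v ty = List.exists (fn v' => v = v') (freeVars ty)

    fun unify ([] : constr list) : sol = []
      | unify (c :: constrs) =
        case c of
            (TBool, TBool) => unify constrs
          | (TVar i, TVar j) =>
            if i = j
            then unify constrs
            else addSol i (TVar j) (unify (substConstrs (TVar j) i constrs))
          (* Or-patterns are an extension (accepted by SML/NJ),
             not part of Standard ML '97 *)
          | ((TVar i, ty) | (ty, TVar i)) =>
            if occursIn i ty
            then raise UnificationError c
            else addSol i ty (unify (substConstrs ty i constrs))
          | (TArr (l, r), TArr (l', r')) =>
            unify ((l, l') :: (r, r') :: constrs)
          | _ => raise UnificationError c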
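As a sanity check, here is a sketch that runs unify on the f a constraints from the introduction, on the inconsistent pair v ~ TBool and v ~ TArr (TBool, TBool), and on a constraint the occurs check must reject. Two assumptions of this sketch: the or-pattern is split into two clauses through a local helper named bind so that the code is plain Standard ML, and the numbering 'f = 0, 'a = 1, 'x = 2, 'y = 3 is arbitrary. applySol is written with foldr; unify’s own solutions never mention a variable they bind, so the fold direction makes no difference here.

```sml
(* Definitions from the post, copied so the sketch stands alone *)
type tvar = int
datatype monotype = TBool | TArr of monotype * monotype | TVar of tvar
type constr = monotype * monotype
type sol = (tvar * monotype) list
exception UnificationError of constr

fun subst ty' var ty =
    case ty of
        TVar var' => if var = var' then ty' else TVar var'
      | TArr (l, r) => TArr (subst ty' var l, subst ty' var r)
      | TBool => TBool
fun freeVars (TVar v) = [v]
  | freeVars (TArr (l, r)) = freeVars l @ freeVars r
  | freeVars TBool = []

fun applySol sol ty = foldr (fn ((v, t), acc) => subst t v acc) ty sol
fun addSol v ty sol = (v, applySol sol ty) :: sol
fun occursIn v ty = List.exists (fn v' => v = v') (freeVars ty)
fun substConstrs ty var constrs =
    map (fn (l, r) => (subst ty var l, subst ty var r)) constrs

(* Same cases as the post, with the or-pattern split into two clauses *)
fun unify ([] : constr list) : sol = []
  | unify ((c as (t1, t2)) :: constrs) =
    let fun bind i ty =
            if occursIn i ty
            then raise UnificationError c
            else addSol i ty (unify (substConstrs ty i constrs))
    in case c of
           (TBool, TBool) => unify constrs
         | (TVar i, TVar j) => if i = j then unify constrs else bind i t2
         | (TVar i, _) => bind i t2
         | (_, TVar i) => bind i t1
         | (TArr (l, r), TArr (l', r')) => unify ((l, l') :: (r, r') :: constrs)
         | _ => raise UnificationError c
    end

(* The f a example: 'f ~ 'x -> 'y and 'a ~ 'x *)
val fa = unify [(TVar 0, TArr (TVar 2, TVar 3)), (TVar 1, TVar 2)]
val fTy = applySol fa (TVar 0)   (* 'x -> 'y *)
val aTy = applySol fa (TVar 1)   (* 'x *)

(* v ~ bool and v ~ bool -> bool cannot both hold *)
val conflict =
    (ignore (unify [(TVar 0, TBool), (TVar 0, TArr (TBool, TBool))]); false)
    handle UnificationError _ => true

(* The occurs check rejects 'a ~ 'a -> 'a *)
val cyclic =
    (ignore (unify [(TVar 0, TArr (TVar 0, TVar 0))]); false)
    handle UnificationError _ => true
```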

Constraint Generation

The other half of this algorithm is the constraint generation part. We generate constraints and use unify to turn them into solutions. This boils down to two functions. The first glues solutions together.

    fun <+> (sol1, sol2) =
        let fun notInSol2 v = List.all (fn (v', _) => v <> v') sol2
            val sol1' = List.filter (fn (v, _) => notInSol2 v) sol1
        in
            map (fn (v, ty) => (v, applySol sol1 ty)) sol2 @ sol1'
        end
    infixr 3 <+>

Given two solutions, we first figure out which bindings in sol1 are for variables that sol2 doesn’t bind. Next, we apply sol1 everywhere in sol2, giving a consistent solution which contains everything in sol2. Finally, we add in all the bindings that are in sol1 but not in sol2. This doesn’t check that the solutions are actually consistent; that is done elsewhere.
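A small sketch of <+> in action. The two solutions are hand-built for illustration: sol1 was found first and binds 'a (0) to 'b -> bool; sol2 was found afterwards and binds 'b (1) to bool. The combined list keeps sol2’s bindings in front. In this sketch applySol folds with foldr, so the older binding for 'a is substituted first and the result is fully resolved; folding front to back with foldl would stop at 'b -> bool.

```sml
(* Definitions copied from the post so the sketch stands alone *)
type tvar = int
datatype monotype = TBool | TArr of monotype * monotype | TVar of tvar

fun subst ty' var ty =
    case ty of
        TVar var' => if var = var' then ty' else TVar var'
      | TArr (l, r) => TArr (subst ty' var l, subst ty' var r)
      | TBool => TBool

(* Assumption of this sketch: apply bindings back to front, oldest first *)
fun applySol sol ty = foldr (fn ((v, t), acc) => subst t v acc) ty sol

fun <+> (sol1, sol2) =
    let fun notInSol2 v = List.all (fn (v', _) => v <> v') sol2
        val sol1' = List.filter (fn (v, _) => notInSol2 v) sol1
    in
        map (fn (v, ty) => (v, applySol sol1 ty)) sol2 @ sol1'
    end
infixr 3 <+>

(* sol1, found first: 'a := 'b -> bool
   sol2, found after sol1 was applied: 'b := bool *)
val sol1 = [(0, TArr (TVar 1, TBool))]
val sol2 = [(1, TBool)]
val both = sol1 <+> sol2          (* sol2's bindings come first *)
val aTy = applySol both (TVar 0)  (* bool -> bool *)
```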

Next is the main function here, constrain. Given a context and an expression, it generates a type and a solution. The first few cases are nice and simple:

    fun constrain ctx True = (TBool, [])
      | constrain ctx False = (TBool, [])
      | constrain ctx (Var i) = (lookupVar i ctx, [])

In these cases we don’t infer any constraints; we just figure out types based on information we already know. Next, for Fn we generate a fresh variable to represent the argument’s type and just constrain the body.

      | constrain ctx (Fn body) =
        let val argTy = TVar (fresh ())
            val (rTy, sol) = constrain (MonoTypeVar argTy :: ctx) body
        in (TArr (applySol sol argTy, rTy), sol) end

Once we have the solution for the body, we apply it to the argument type which might replace it with a concrete type if the constraints we inferred for the body demand it. For If we do something similar except we add a few constraints of our own to solve.

      | constrain ctx (If (i, t, e)) =
        let val (iTy, sol1) = constrain ctx i
            val (tTy, sol2) = constrain (applySolCxt sol1 ctx) t
            val (eTy, sol3) = constrain (applySolCxt (sol1 <+> sol2) ctx) e
            val sol = sol1 <+> sol2 <+> sol3
            val sol = sol <+> unify [ (applySol sol iTy, TBool)
                                    , (applySol sol tTy, applySol sol eTy)]
        in
            (tTy, sol)
        end

Notice how we apply each solution to the context for the next thing we’re constraining. This is how we ensure that each solution will be consistent. Once we’ve generated solutions to the constraints in each of the subterms, we smash them together to produce the first solution. Next, we ensure that the subcomponents have the right type by generating a few constraints to ensure that iTy is a bool and that tTy and eTy (the types of the branches) are both the same. We have to carefully apply the sol to each of these prior to unifying them to make sure our solution stays consistent.

The App case is practically the same:

      | constrain ctx (App (l, r)) =
        let val (domTy, ranTy) = (TVar (fresh ()), TVar (fresh ()))
            val (funTy, sol1) = constrain ctx l
            val (argTy, sol2) = constrain (applySolCxt sol1 ctx) r
            val sol = sol1 <+> sol2
            val sol = sol <+> unify [(applySol sol funTy,
                                      applySol sol (TArr (domTy, ranTy)))
                                    , (applySol sol argTy, applySol sol domTy)]
        in (ranTy, sol) end

The only real difference here is that we generate different constraints: we make sure we’re applying a function whose domain is the same as the argument type.

The most interesting case here is Let. This implements let generalization which is how we actually get polymorphism. After inferring the type of the thing we’re binding we generalize it, giving us a poly type to use in the body of let. The key to generalizing it is that generalizeMonoType we had before.

      | constrain ctx (Let (e, body)) =
        let val (eTy, sol1) = constrain ctx e
            val ctx' = applySolCxt sol1 ctx
            val eTy' = generalizeMonoType ctx' (applySol sol1 eTy)
            val (rTy, sol2) = constrain (PolyTypeVar eTy' :: ctx') body
        in (rTy, sol1 <+> sol2) end

We do pretty much everything we had before, except that now we carefully apply the solution we get for the bound expression to the context, and then generalize its type with respect to that new context. This is how we actually get polymorphism: it assigns a proper polymorphic type to the let-bound variable.

That wraps up constraint generation. Now all that’s left to see is the overall driver for type inference.

    fun infer e =
        let val (ty, sol) = constrain [] e
        in generalizeMonoType [] (applySol sol ty) end

So all we do is infer and generalize a type! And there you have it: that’s the core of how ML and Haskell do type inference.
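To see everything working together, here is a sketch that collects the post’s definitions into one Standard ML program and runs infer on a few terms. It makes some assumptions of its own: the or-pattern in unify is split into two clauses so the code is plain Standard ML, several definitions are written more compactly, and applySol folds with foldr so that older bindings, which <+> keeps at the back of a solution, are substituted first. The last test (a lambda whose argument is used both as a bool and as a function) exercises that ordering.

```sml
(* Sketch: the post's inferencer as one program (compacted, or-pattern split) *)
type tvar = int
val freshSource = ref 0
fun fresh () : tvar = !freshSource before freshSource := !freshSource + 1
datatype monotype = TBool | TArr of monotype * monotype | TVar of tvar
datatype polytype = PolyType of tvar list * monotype
datatype exp = True | False | Var of int | App of exp * exp
             | Let of exp * exp | Fn of exp | If of exp * exp * exp
datatype info = PolyTypeVar of polytype | MonoTypeVar of monotype

fun subst ty' v (TVar v') = if v = v' then ty' else TVar v'
  | subst ty' v (TArr (l, r)) = TArr (subst ty' v l, subst ty' v r)
  | subst _ _ TBool = TBool
fun freeVars (TVar v) = [v]
  | freeVars (TArr (l, r)) = freeVars l @ freeVars r
  | freeVars TBool = []
fun notMem xs x = List.all (fn y => x <> y) xs
fun dedup [] = []
  | dedup (x :: xs) = if List.exists (fn y => x = y) xs then dedup xs else x :: dedup xs
fun generalizeMonoType ctx ty =
    let fun free (MonoTypeVar m) = freeVars m
          | free (PolyTypeVar (PolyType (bs, m))) = List.filter (notMem bs) (freeVars m)
        val ctxVars = List.concat (map free ctx)
    in PolyType (dedup (List.filter (notMem ctxVars) (freeVars ty)), ty) end
fun mintNewMonoType (PolyType (ls, ty)) =
    foldl (fn (v, t) => subst (TVar (fresh ())) v t) ty ls
exception UnboundVar of int
fun lookupVar i ctx =
    case (List.nth (ctx, i) handle Subscript => raise UnboundVar i) of
        PolyTypeVar p => mintNewMonoType p
      | MonoTypeVar m => m

exception UnificationError of monotype * monotype
(* Older bindings sit at the back of a solution, so apply back to front *)
fun applySol sol ty = foldr (fn ((v, t), acc) => subst t v acc) ty sol
fun applySolCtx sol ctx =
    map (fn MonoTypeVar m => MonoTypeVar (applySol sol m)
          | PolyTypeVar (PolyType (bs, m)) => PolyTypeVar (PolyType (bs, applySol sol m)))
        ctx
fun substConstrs ty v cs = map (fn (l, r) => (subst ty v l, subst ty v r)) cs
fun addSol v ty sol = (v, applySol sol ty) :: sol
fun occursIn v ty = List.exists (fn v' => v = v') (freeVars ty)
fun unify [] = []
  | unify ((c as (t1, t2)) :: cs) =
    let fun bind i ty = if occursIn i ty then raise UnificationError c
                        else addSol i ty (unify (substConstrs ty i cs))
    in case c of
           (TBool, TBool) => unify cs
         | (TVar i, TVar j) => if i = j then unify cs else bind i t2
         | (TVar i, _) => bind i t2
         | (_, TVar i) => bind i t1
         | (TArr (l, r), TArr (l', r')) => unify ((l, l') :: (r, r') :: cs)
         | _ => raise UnificationError c
    end
infixr 3 <+>
fun sol1 <+> sol2 =
    map (fn (v, ty) => (v, applySol sol1 ty)) sol2
    @ List.filter (fn (v, _) => List.all (fn (v', _) => v <> v') sol2) sol1

fun constrain _ True = (TBool, [])
  | constrain _ False = (TBool, [])
  | constrain ctx (Var i) = (lookupVar i ctx, [])
  | constrain ctx (Fn body) =
    let val argTy = TVar (fresh ())
        val (rTy, sol) = constrain (MonoTypeVar argTy :: ctx) body
    in (TArr (applySol sol argTy, rTy), sol) end
  | constrain ctx (If (i, t, e)) =
    let val (iTy, sol1) = constrain ctx i
        val (tTy, sol2) = constrain (applySolCtx sol1 ctx) t
        val (eTy, sol3) = constrain (applySolCtx (sol1 <+> sol2) ctx) e
        val sol = sol1 <+> sol2 <+> sol3
    in (tTy, sol <+> unify [(applySol sol iTy, TBool),
                            (applySol sol tTy, applySol sol eTy)]) end
  | constrain ctx (App (l, r)) =
    let val (domTy, ranTy) = (TVar (fresh ()), TVar (fresh ()))
        val (funTy, sol1) = constrain ctx l
        val (argTy, sol2) = constrain (applySolCtx sol1 ctx) r
        val sol = sol1 <+> sol2
    in (ranTy, sol <+> unify [(applySol sol funTy, applySol sol (TArr (domTy, ranTy))),
                              (applySol sol argTy, applySol sol domTy)]) end
  | constrain ctx (Let (e, body)) =
    let val (eTy, sol1) = constrain ctx e
        val ctx' = applySolCtx sol1 ctx
        val eTy' = generalizeMonoType ctx' (applySol sol1 eTy)
        val (rTy, sol2) = constrain (PolyTypeVar eTy' :: ctx') body
    in (rTy, sol1 <+> sol2) end
fun infer e =
    let val (ty, sol) = constrain [] e
    in generalizeMonoType [] (applySol sol ty) end

val idTy = infer (Fn (Var 0))                   (* forall a. a -> a *)
val constTy = infer (Fn (Fn (Var 1)))           (* forall a b. a -> b -> a *)
(* let id = fn x => x in id id true : bool, using id at two types *)
val letIdTy = infer (Let (Fn (Var 0), App (App (Var 0, Var 0), True)))
fun rejects e = (ignore (infer e); false) handle UnificationError _ => true
```

Because fresh is a global counter, the exact variable numbers in the results depend on evaluation order, so it is safer to check the shape of each inferred type than to compare against specific numbers.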

Wrap Up

Hopefully that clears up a little of the magic of how type inference works. The next challenge is to figure out how to do type inference on a language with patterns and ADTs! This is actually quite fun: pattern checking involves synthesizing a type from a pattern, which needs something like linear logic to handle pattern variables correctly.

With this we’re actually a solid 70% of the way to building a type checker for SML. Until I have more free time, though, I leave this as an exercise to the curious reader.

Cheers,
