Written by ettolrach (she/her), 2025-12-23.
Most commonly used programming languages don't use a well-known algorithm for their type inference. A few, such as Rust, Haskell, and F#, use Hindley-Milner type inference.
Another algorithm, which isn't used by many popular languages, is called bidirectional type inference. Compared with Hindley-Milner, it's supposed to be easier to implement and to give better error messages, with the drawback of being harder to understand and to get started with.
There aren't a lot of tutorials out there, especially ones that are designed for someone who doesn't have much knowledge of the theory. Two other tutorials I would recommend if you'd like to read more are Dunfield and Krishnaswami (2013) for a more formal introduction, and Christiansen (2013) for another informal and easy tutorial.
We will start with a simple language, the simply typed lambda calculus. If there is interest, in a future blog post I'll try to write a nice implementation of the 'complete and easy' variant for polymorphism, so you can learn how to typecheck a polymorphically typed language.
Although I'll be using Rust below, the implementation should be fairly similar in most other languages.
The simply typed lambda calculus (STLC) captures the essence of a functional programming language. We only have free variables, abstractions (a.k.a. lambdas), and (function) application. We will also have primitives in the form of the booleans false and true. Later, we'll extend the language with a conditional: if-then-else. If you've used a modern language (any commonly used language apart from C, really), then you've already got experience using these! We also allow you to annotate (sometimes called decorate) a term with a type. Our syntax is as follows:
$$\begin{align*} L, M :=& & &\texttt{false}, \, \texttt{true} \quad \text{(boolean literals)}\\ |& & &x \quad \text{(variables)}\\ |& & &\lambda x. L \quad \text{(abstractions)}\\ |& & &L M \quad \text{(applications)}\\ |& & &L: A \quad \text{(annotations)} \end{align*}$$
Our types will be kept simple: we just have one base type, Bool, and a function (arrow) type.
$$\begin{align*} A, B :=& & &\texttt{Bool}\\ |& & &A \to B \end{align*}$$
As the name implies, typechecking goes in two directions: some terms infer their type and others are checked against a given type. Some authors instead say that terms synthesise or inherit types, and I've also seen a mix of the two: synthesising and checking. I prefer inferring and checking, but that's purely my personal choice.
First, we'll go over how we can choose which direction our terms should go. If you ever get bored or lost, feel free to skip ahead to the implementation, and if something doesn't make sense, you can always go back to the previous section.
But how do you decide which terms infer and which check? We will follow the Pfenning method (Dunfield and Pfenning, 2004). The first step is to write out your typing rules.
$$\begin{gathered} \dfrac{x \colon A \in \Gamma}{\Gamma \vdash x \colon A} \; (\text{Var}) \quad \dfrac{\Gamma \vdash L \colon A}{\Gamma \vdash (L: A) \colon A} \; (\text{Anno}) \quad \dfrac{\Gamma \vdash L \colon A \quad A = B}{\Gamma \vdash L \colon B} \; (\text{Ty Eq})\\[7mm] \dfrac{}{\Gamma \vdash \texttt{false} \colon \texttt{Bool}} \; (\text{False}) \quad \dfrac{}{\Gamma \vdash \texttt{true} \colon \texttt{Bool}} \; (\text{True}) \quad\\[5mm] \dfrac{\Gamma, x \colon A \vdash L \colon B}{\Gamma \vdash (\lambda x. L) \colon A \to B} \; (\text{\( \to \) Intr}) \quad \dfrac{\Gamma \vdash L \colon A \to B \quad \Gamma \vdash M \colon A}{\Gamma \vdash L M \colon B} \; (\text{\( \to \) Elim}) \end{gathered}$$
- Var typechecks free variables: if our context Γ contains the mapping from x to type A, then the free variable x has type A.
- Anno helps us make type inference work: we can annotate a term L with a type A whenever L has type A, and the annotated term has that same type A.
- Ty Eq is for type equality: if a term has type A and A equals B, then the term also has type B. This seems pretty useless for now, but will prove useful soon.
- False and True are primitives which have the type Bool and can be constructed at any time, like in Rust.
- -> Intr is our rule for arrow type introduction, that is, creating an abstraction (a lambda). If the term L has type B when we add a free variable x of type A to the context, then the abstraction with argument x of type A and body L has type A -> B.
- -> Elim eliminates arrow types, that is, it's the rule for application. If L is a term of type A -> B (a function, though not necessarily a literal abstraction: it could also be a variable) and M is a term of type A, then we can apply L to M and get a term of type B.

The second step is to look at each typing rule and turn it into an inferring or checking rule. We do this by replacing the colons with arrows: => for a judgement which infers the type of a term and <= for a judgement which checks the type of a term.
We'll start with the introduction and elimination rules. For each one, we need to find its principal judgement: for an introduction, it's the judgement about the term being introduced, which is usually the conclusion; for an elimination, it's the judgement about the term being eliminated, which is usually the first assumption. Then we bidirectionalise this principal judgement, that is, we replace the colon with either => for inferring or <= for checking. Introductions are checking, eliminations are inferring.
Next, we bidirectionalise the other judgements in the rule, starting with the earliest assumption, moving on to the later ones, and finishing with the conclusion. To decide which direction to choose, we use the types made available by judgements we have already bidirectionalised wherever possible. For example, if an earlier judgement has already told us a type B, then a later judgement L : B should be checking, so that it can use that information.
Let's start! The first rules we'll look at are False and True. Both are introductions, so they should be checking rules. There are no other judgements in these rules, so we are done.
$$\begin{gathered} \dfrac{}{\Gamma \vdash \texttt{false} \Leftarrow \texttt{Bool}} \; (\text{False}) \quad \dfrac{}{\Gamma \vdash \texttt{true} \Leftarrow \texttt{Bool}} \; (\text{True}) \end{gathered}$$
Next is -> Intr, an introduction. The principal judgement is the conclusion, Γ ⊢ (\x. L) : A -> B, so we make it checking. Because we're checking, we already know both A and B, so we can add x : A to the context and check that L really is of type B.
$$\begin{gathered} \dfrac{\Gamma, x \colon A \vdash L \Leftarrow B}{\Gamma \vdash (\lambda x. L) \Leftarrow A \to B} \; (\text{\( \to \) Intr}) \end{gathered}$$
-> Elim is an elimination rule. The principal judgement is the one for the function being eliminated, Γ ⊢ L : A -> B, so we make it inferring. Now that we know the argument must have type A, we can make the other assumption (Γ ⊢ M : A) a checking judgement. Lastly, because we have already inferred the type B, the conclusion infers B too.
$$\begin{gathered} \dfrac{\Gamma \vdash L \Rightarrow A \to B \quad \Gamma \vdash M \Leftarrow A}{\Gamma \vdash L M \Rightarrow B} \; (\text{\( \to \) Elim}) \end{gathered}$$
Finally, there are three rules that are neither introductions nor eliminations: Var, Anno, and Ty Eq.
- For Var, we infer, because we already know the type from the context.
- For Anno, the principal judgement is the conclusion (Γ ⊢ (L: A) : A). Since the annotation gives us the type, we might as well infer it. As for the assumption, because we already know the type A, we should check that the term really is of type A. Intuitively, we're making sure that the term really has the type the user annotated it with.
- Ty Eq addresses a problem we would have without it. If we want to infer the type of a term which can only be checked, we can use an annotation to "flip" the arrow around and turn a checking term into an inferring one. But we don't have the inverse! Ty Eq gives us exactly that: a way to check a term whose type can only be inferred. Keeping the equality check explicit also makes it easy to add subtyping later (simply replace = with <:). All that to say, the conclusion should be a checking judgement and the assumption an inferring judgement.

Thus, we have the following rules (along with the previous ones, for a nice overview):
$$\begin{gathered} \dfrac{x \colon A \in \Gamma}{\Gamma \vdash x \Rightarrow A} \; (\text{Var}) \quad \dfrac{\Gamma \vdash L \Leftarrow A}{\Gamma \vdash (L: A) \Rightarrow A} \; (\text{Anno}) \quad \dfrac{\Gamma \vdash L \Rightarrow A \quad A = B}{\Gamma \vdash L \Leftarrow B} \; (\text{Ty Eq})\\[7mm] \dfrac{}{\Gamma \vdash \texttt{false} \Leftarrow \texttt{Bool}} \; (\text{False}) \quad \dfrac{}{\Gamma \vdash \texttt{true} \Leftarrow \texttt{Bool}} \; (\text{True})\\[5mm] \dfrac{\Gamma, x \colon A \vdash L \Leftarrow B}{\Gamma \vdash (\lambda x. L) \Leftarrow A \to B} \; (\text{\( \to \) Intr}) \quad \dfrac{\Gamma \vdash L \Rightarrow A \to B \quad \Gamma \vdash M \Leftarrow A}{\Gamma \vdash L M \Rightarrow B} \; (\text{\( \to \) Elim}) \end{gathered}$$
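To see how these rules fit together, here is a sketch of the derivation for ((\x. x): Bool -> Bool) false, written as the sequence of questions the algorithm asks. Each indented line is a premise needed by the line above it, and the rule used is on the right. Ty Eq switches from checking x to inferring it with Var, then compares the inferred Bool with the expected Bool:

$$\begin{align*}
&\Gamma \vdash ((\lambda x. x) : \texttt{Bool} \to \texttt{Bool}) \, \texttt{false} \Rightarrow \texttt{Bool} && (\text{\( \to \) Elim})\\
&\quad \Gamma \vdash ((\lambda x. x) : \texttt{Bool} \to \texttt{Bool}) \Rightarrow \texttt{Bool} \to \texttt{Bool} && (\text{Anno})\\
&\quad\quad \Gamma \vdash \lambda x. x \Leftarrow \texttt{Bool} \to \texttt{Bool} && (\text{\( \to \) Intr})\\
&\quad\quad\quad \Gamma, x \colon \texttt{Bool} \vdash x \Leftarrow \texttt{Bool} && (\text{Ty Eq})\\
&\quad\quad\quad\quad \Gamma, x \colon \texttt{Bool} \vdash x \Rightarrow \texttt{Bool} && (\text{Var})\\
&\quad \Gamma \vdash \texttt{false} \Leftarrow \texttt{Bool} && (\text{False})
\end{align*}$$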
Finally! Let's get to writing some code. The full source code will be available at the bidirectional-rs repository on my GitHub. I've also written the tokeniser and parser, so you may wish to play around with the language using the code in the repo. We implement our terms using an abstract syntax tree.
```rust
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Type {
    Bool,
    Arrow(Box<Type>, Box<Type>),
    /// Only used when a type error ends type inference early.
    Unknown,
}

#[derive(Debug, Clone)]
pub enum Expr {
    True,
    False,
    // IfThenElse(Box<TrackedExpr>, Box<TrackedExpr>, Box<TrackedExpr>),
    Fv(String),
    Abs(String, Box<TrackedExpr>),
    App(Box<TrackedExpr>, Box<TrackedExpr>),
    Anno(Box<TrackedExpr>, Type),
}

#[derive(Debug, Clone)]
pub struct TrackedExpr {
    pub expr: Expr,
    pub line: usize,
    pub column: usize,
}
```
Type has our base type (Bool) and our arrow type. It also has an "unknown" type which will only be used in error messages when we need to return early without having worked out the correct type yet.
Expr is as you would expect. There is also a commented-out variant, IfThenElse, which is for an exercise later. The variants which recurse refer to TrackedExpr rather than Expr so that we can provide better error messages: TrackedExpr is a struct which keeps "track" of the line and column where each term appears.
Also note, this will not be the most efficient implementation. There is almost certainly a way to make the algorithm work without using the Clone trait, instead using some of the Copy types (such as u64). But that's not the focus of this post.
One of the benefits of this method of typechecking is that we can provide quite specific error messages. We will use these types for reporting errors:

```rust
#[derive(Debug)]
pub enum ErrorKind {
    /// Expr didn't match type.
    AnnotationMismatch { caused_by: Box<TypeError> },
    /// Expression requires an annotation.
    AnnotationRequired,
    /// An application wasn't a function (arrow) type.
    ApplicationNotArrow { actual: Type },
    /// Argument to abstraction had wrong type.
    ArgWrongType { expected: Type, actual: Type },
    /// Checked for the expected type, got the actual type.
    CheckedWrongType { expected: Type, actual: Type },
    /// A free variable was used before it was in scope.
    UndeclaredFv,
}

#[derive(Debug)]
pub struct TypeError {
    pub kind: ErrorKind,
    pub line: usize,
    pub column: usize,
}
```
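Since the point of these error types is to produce readable messages, it helps to print types the way we write them. Below is a minimal sketch of a Display implementation for Type (the enum is repeated so the block stands on its own). It assumes -> associates to the right, as is conventional, and printing Unknown as ? is an arbitrary choice made for this sketch, not something taken from the repository:

```rust
use std::fmt;

// Mirrors the `Type` enum above.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Type {
    Bool,
    Arrow(Box<Type>, Box<Type>),
    Unknown,
}

impl fmt::Display for Type {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Type::Bool => write!(f, "Bool"),
            // Placeholder for a type we never finished working out.
            Type::Unknown => write!(f, "?"),
            Type::Arrow(a, b) => match **a {
                // Arrows associate to the right, so only an arrow on the
                // left-hand side needs parentheses: (Bool -> Bool) -> Bool.
                Type::Arrow(..) => write!(f, "({}) -> {}", a, b),
                _ => write!(f, "{} -> {}", a, b),
            },
        }
    }
}
```

With this, a CheckedWrongType error could be rendered as something like "expected Bool -> Bool, found Bool"; the exact wording is up to you.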
We will need to write two mutually recursive functions: infer and check. They correspond to the => and <= judgements respectively. infer returns a type or an error, while check returns the unit type or an error (we only have something meaningful to say when the check fails; if it succeeds, there's nothing to report).

```rust
use std::collections::HashMap;

pub fn infer(
    tracked_expr: TrackedExpr,
    context: &mut HashMap<String, Type>,
) -> Result<Type, TypeError> {
    // --snip--
}

pub fn check(
    tracked_expr: TrackedExpr,
    ty: Type,
    context: &mut HashMap<String, Type>,
) -> Result<(), TypeError> {
    // --snip--
}
```
Let's start with infer.
```rust
pub fn infer(
    tracked_expr: TrackedExpr,
    context: &mut HashMap<String, Type>,
) -> Result<Type, TypeError> {
    let line = tracked_expr.line;
    let column = tracked_expr.column;
    match tracked_expr.expr {
        Expr::Fv(s) => {
            if let Some(ty) = context.get(&s) {
                Ok(ty.clone())
            } else {
                Err(TypeError {
                    kind: ErrorKind::UndeclaredFv,
                    line,
                    column,
                })
            }
        }
```
For free variables, we need to check our context for which type the identifier maps to. If the identifier exists, we return the type it maps to. Otherwise, we report an error.
```rust
        Expr::Anno(expr, ty) => {
            match check(*expr, ty.clone(), context) {
                Ok(_) => Ok(ty),
                Err(e) => Err(TypeError {
                    kind: ErrorKind::AnnotationMismatch {
                        caused_by: Box::new(e),
                    },
                    line,
                    column,
                }),
            }
        }
```
Recall our bidirectional rule Anno where we checked that L had type A. In the code above, we do this by calling the check function using the term expr and type ty along with our context, which we leave unchanged. If it was successful, then we return the annotated type. Otherwise, we report an error, making sure to pass along the underlying typechecking failure.
```rust
        Expr::App(l, m) => match infer(*l, context)? {
            Type::Arrow(a, b) => match check(*m, *a, context) {
                Ok(_) => Ok(*b),
                // Report a mismatched argument more specifically.
                Err(TypeError {
                    kind: ErrorKind::CheckedWrongType { expected: a, actual },
                    line,
                    column,
                }) => Err(TypeError {
                    kind: ErrorKind::ArgWrongType { expected: a, actual },
                    line,
                    column,
                }),
                Err(e) => Err(e),
            },
            actual => Err(TypeError {
                kind: ErrorKind::ApplicationNotArrow { actual },
                line,
                column,
            }),
        },
```
Now for applications, which correspond to our -> Elim rule. This part looks complicated, but most of the code is just error reporting, so don't be intimidated! In fact, only the first three lines of code are actually important.
The rule tells us that we should first infer the type of L. In the code, we call infer and match on the returned type. If we get anything other than an arrow type back, then we've got ourselves a type error. But if we do get an arrow type, then we've inferred the type A -> B. We now need to check that our argument M has the argument type A. If it does, then we return the resulting type of the application: B. If it doesn't, we report an error (turning a CheckedWrongType error into the more specific ArgWrongType).
```rust
        _ => Err(TypeError {
            kind: ErrorKind::AnnotationRequired,
            line,
            column,
        }),
    }
}
```

Finally, if we get anything else (a boolean literal or an abstraction), we're going in the 'wrong direction': those terms can only be checked, so we need the user to add a type annotation to help us out.
Now for the check function.
```rust
pub fn check(
    tracked_expr: TrackedExpr,
    ty: Type,
    context: &mut HashMap<String, Type>,
) -> Result<(), TypeError> {
    let line = tracked_expr.line;
    let column = tracked_expr.column;
    match tracked_expr.expr {
        Expr::False | Expr::True => {
            if ty == Type::Bool {
                Ok(())
            } else {
                Err(TypeError {
                    kind: ErrorKind::CheckedWrongType {
                        expected: ty,
                        actual: Type::Bool,
                    },
                    line,
                    column,
                })
            }
        }
```
Like before, we get the line and column for error reporting and match on the expression. If it's the literals false or true, then we need to make sure that the type we're checking is Bool. Otherwise, we return an error.
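```rust
        Expr::Abs(s, expr) => match ty {
            Type::Arrow(a, b) => {
                // Extend the context with s : A, remembering any outer
                // binding of the same name that this one shadows.
                let shadowed = context.insert(s.clone(), *a);
                let res = check(*expr, *b, context);
                // Undo the extension: restore the shadowed binding if
                // there was one, otherwise remove s again.
                match shadowed {
                    Some(old) => context.insert(s, old),
                    None => context.remove(&s),
                };
                res
            }
            _ => Err(TypeError {
                kind: ErrorKind::CheckedWrongType {
                    expected: ty,
                    actual: Type::Arrow(Box::new(Type::Unknown), Box::new(Type::Unknown)),
                },
                line,
                column,
            }),
        },
```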
For abstractions, recall our rule -> Intr. The conclusion tells us we're checking against an arrow type A -> B, and the assumption says that the body L should have type B when we extend the context with x : A. So in the code, we first make sure we're checking an arrow type, then insert s with type A into the context HashMap. We then call check to see whether the body of the abstraction, expr, has type B, undo our change to the context, and return the result of the call to check.
It's worth noting that we wouldn't need to undo the insertion if we weren't using a mutable reference to the HashMap. We could instead clone the context every time we pass it to a recursive call (and presumably use a less expensive map, like a ListMap/VecMap).
```rust
        _ => {
            let inferred_ty = infer(tracked_expr, context)?;
            if inferred_ty == ty {
                Ok(())
            } else {
                Err(TypeError {
                    kind: ErrorKind::CheckedWrongType {
                        expected: ty,
                        actual: inferred_ty,
                    },
                    line,
                    column,
                })
            }
        }
    }
}
```
Lastly, we have Ty Eq. It simply calls infer and checks that it inferred the same type as the one we're checking.
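To make the note about cloning the context concrete, here is a minimal sketch of an immutable context. Context, extend, and get are hypothetical names, and the Vec-based association list merely stands in for the ListMap/VecMap idea. Because extend returns a new context and get searches from the most recent binding, inner binders shadow outer ones and there is nothing to undo after a recursive call:

```rust
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Type {
    Bool,
    Arrow(Box<Type>, Box<Type>),
}

/// Hypothetical persistent context: later entries shadow earlier ones.
#[derive(Debug, Clone, Default)]
pub struct Context {
    entries: Vec<(String, Type)>,
}

impl Context {
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns a new context with `name : ty` added. `self` is untouched,
    /// so the caller's context is unchanged once the recursive call returns.
    pub fn extend(&self, name: &str, ty: Type) -> Context {
        let mut entries = self.entries.clone();
        entries.push((name.to_string(), ty));
        Context { entries }
    }

    /// Looks a variable up, searching from the most recent binding so that
    /// inner binders shadow outer ones.
    pub fn get(&self, name: &str) -> Option<&Type> {
        self.entries
            .iter()
            .rev()
            .find(|(n, _)| n.as_str() == name)
            .map(|(_, ty)| ty)
    }
}
```

In the Abs case, check would then recurse with context.extend(&s, *a) instead of inserting and removing. Each extension copies the whole list, which is fine for small programs but linear in the number of bindings in scope.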
And that's it! We now have the two functions which will drive our typechecking algorithm. We can write a typecheck function like so:
```rust
pub fn typecheck(tracked_expr: TrackedExpr, ty: Type) -> Result<(), TypeError> {
    let mut context = HashMap::new();
    check(tracked_expr, ty, &mut context)
}
```
We can write some tests to show that our algorithm works correctly:
```rust
#[cfg(test)]
pub fn typecheck_assert(tracked_expr: TrackedExpr, ty: Type) {
    let mut context = HashMap::new();
    let checked = check(tracked_expr, ty, &mut context);
    assert!(checked.is_ok());
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::parse;

    #[test]
    fn simple_typechecking() {
        let s = String::from("((\\x. x): Bool -> Bool) false");
        let parsed = parse(s).unwrap();
        typecheck_assert(parsed, Type::Bool);
    }

    #[test]
    fn wrongtype() {
        let s = String::from("((\\x. if x then false else true): Bool -> Bool -> Bool) false");
        let parsed = parse(s).unwrap();
        assert!(typecheck(parsed, Type::Bool).is_err());
    }
}
```
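Since the tests above depend on the parser in the repository, here is a condensed, self-contained sketch of the same checker that fits in a single file. It is a simplification rather than the repository's code: TrackedExpr is dropped, so errors are plain Strings without positions, and the functions take references instead of owned values. The Abs case restores any binding it shadows, so a term like \x. ((\x. x): Bool -> Bool) x still finds the outer x afterwards:

```rust
use std::collections::HashMap;

#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Type {
    Bool,
    Arrow(Box<Type>, Box<Type>),
}

#[derive(Debug, Clone)]
pub enum Expr {
    True,
    False,
    Fv(String),
    Abs(String, Box<Expr>),
    App(Box<Expr>, Box<Expr>),
    Anno(Box<Expr>, Type),
}

/// The `=>` judgement.
pub fn infer(expr: &Expr, ctx: &mut HashMap<String, Type>) -> Result<Type, String> {
    match expr {
        // Var: look the variable up in the context.
        Expr::Fv(x) => ctx
            .get(x)
            .cloned()
            .ok_or_else(|| format!("undeclared variable {x}")),
        // Anno: check against the annotation, then return it.
        Expr::Anno(e, ty) => {
            check(e, ty, ctx)?;
            Ok(ty.clone())
        }
        // -> Elim: infer the function, then check the argument against A.
        Expr::App(l, m) => match infer(l, ctx)? {
            Type::Arrow(a, b) => {
                check(m, &a, ctx)?;
                Ok(*b)
            }
            other => Err(format!("expected a function, found {other:?}")),
        },
        // Literals and abstractions can only be checked.
        Expr::True | Expr::False | Expr::Abs(..) => Err("annotation required".to_string()),
    }
}

/// The `<=` judgement.
pub fn check(expr: &Expr, ty: &Type, ctx: &mut HashMap<String, Type>) -> Result<(), String> {
    match (expr, ty) {
        // False / True
        (Expr::True | Expr::False, Type::Bool) => Ok(()),
        (Expr::True | Expr::False, other) => Err(format!("expected {other:?}, found Bool")),
        // -> Intr: check the body against B with x : A in the context.
        (Expr::Abs(x, body), Type::Arrow(a, b)) => {
            let shadowed = ctx.insert(x.clone(), (**a).clone());
            let result = check(body, b, ctx);
            // Restore any outer binding of x that we shadowed.
            match shadowed {
                Some(old) => ctx.insert(x.clone(), old),
                None => ctx.remove(x),
            };
            result
        }
        (Expr::Abs(..), other) => Err(format!("expected {other:?}, found a function")),
        // Ty Eq: infer the type and compare it with the expected one.
        _ => {
            let inferred = infer(expr, ctx)?;
            if inferred == *ty {
                Ok(())
            } else {
                Err(format!("expected {ty:?}, found {inferred:?}"))
            }
        }
    }
}

pub fn typecheck(expr: &Expr, ty: &Type) -> Result<(), String> {
    check(expr, ty, &mut HashMap::new())
}
```

As in the post, the bare abstraction \x. x checks against Bool -> Bool, but asking infer for its type returns the "annotation required" error.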
The implementation was pretty straightforward because we did the difficult part (figuring out which direction the arrows should point) before writing any code. An advantage we have over the Hindley-Milner inference algorithm is that we can give a more specific error than simply 'failed to unify types A and B'.
One of the commonly cited disadvantages of bidirectional type inference is that we sometimes need type annotations in places where the type seems 'obvious'. However, if we extend our language with let-bindings with type annotations and require everything at the top level to be a let-binding, then we should almost never need another type annotation in the bodies of the let-bindings! This restriction isn't unrealistic for a language. For instance, both Rust and C require all top-level declarations to be functions, constants, or user-defined types (as well as a few other constructs like macros). Even Haskell, which doesn't require top-level functions to have type signatures, recommends using them anyway. Thus, this wouldn't be too annoying for the user.
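As a sketch of what that could look like, here is one possible bidirectional rule for a hypothetical annotated let-binding, let x : A = M in N, which is not part of the language defined above. Following the same recipe, the annotation hands us A, so M is checked against it, and the body N is checked with x : A added to the context:

$$\begin{gathered} \dfrac{\Gamma \vdash M \Leftarrow A \quad \Gamma, x \colon A \vdash N \Leftarrow B}{\Gamma \vdash \texttt{let} \, x : A = M \, \texttt{in} \, N \Leftarrow B} \; (\text{Let}) \end{gathered}$$

With this rule, let id : Bool -> Bool = \x. x in id true checks against Bool without annotating the abstraction itself: the abstraction is checked against the let's annotation, and id true infers Bool via -> Elim and Var.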
Consider an extension to the language called if-then-else:
$$\begin{gathered} \dfrac {\Gamma \vdash L \colon \texttt{Bool} \quad \Gamma \vdash M \colon A \quad \Gamma \vdash N \colon A} {\Gamma \vdash \texttt{if} \, L \, \texttt{then} \, M \, \texttt{else} \, N \colon A} \; \text{(If)} \end{gathered}$$
You can make the L : Bool judgement either checking or inferring. I have a slight preference for making it checking because that reduces the need for type annotations, but both are correct. As for the other judgements, the rule doesn't give us the type A anywhere, so we will need to be given A from elsewhere.
stlc-exercise. You will find IfThenElse added as a variant in the src/term.rs file. Now add a match arm to either the infer or check function in src/typechecking.rs so that the IfThenElse case is handled and the tests pass.
If is a checking rule (if not, check part 1's solution). So you need to add it to the check function.
After a bit of tinkering, you will hopefully have noticed that this rule works as either a checking or an inferring rule. In cases where either option is sound, we can use another criterion: would it make writing code annoying? If we used inferring judgements here, it would be quite annoying. Imagine writing if (length l > 2): Bool then (l `at` 3): Bool else (false: Bool). So, let's make each judgement checking.
$$\begin{gathered} \dfrac {\Gamma \vdash L \Leftarrow \texttt{Bool} \quad \Gamma \vdash M \Leftarrow A \quad \Gamma \vdash N \Leftarrow A} {\Gamma \vdash \texttt{if} \, L \, \texttt{then} \, M \, \texttt{else} \, N \Leftarrow A} \; \text{(If)} \end{gathered}$$
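If you'd like to compare against your solution to part 2, here is a minimal sketch of the If rule in code. It is cut down so that it runs on its own: the Expr below only has the constructs the exercise needs, and errors are plain Strings. The repository's IfThenElse wraps TrackedExprs and its check takes a &mut HashMap, so your match arm will look slightly different:

```rust
use std::collections::HashMap;

#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Type {
    Bool,
    Arrow(Box<Type>, Box<Type>),
}

#[derive(Debug, Clone)]
pub enum Expr {
    True,
    False,
    Fv(String),
    Anno(Box<Expr>, Type),
    IfThenElse(Box<Expr>, Box<Expr>, Box<Expr>),
}

pub fn infer(expr: &Expr, ctx: &HashMap<String, Type>) -> Result<Type, String> {
    match expr {
        Expr::Fv(x) => ctx
            .get(x)
            .cloned()
            .ok_or_else(|| format!("undeclared variable {x}")),
        Expr::Anno(e, ty) => {
            check(e, ty, ctx)?;
            Ok(ty.clone())
        }
        // Literals and if-then-else can only be checked.
        Expr::True | Expr::False | Expr::IfThenElse(..) => {
            Err("annotation required".to_string())
        }
    }
}

pub fn check(expr: &Expr, ty: &Type, ctx: &HashMap<String, Type>) -> Result<(), String> {
    match expr {
        Expr::True | Expr::False if *ty == Type::Bool => Ok(()),
        Expr::True | Expr::False => Err(format!("expected {ty:?}, found Bool")),
        // If: the condition is checked against Bool, and both branches
        // are checked against the type we were asked to check.
        Expr::IfThenElse(l, m, n) => {
            check(l, &Type::Bool, ctx)?;
            check(m, ty, ctx)?;
            check(n, ty, ctx)
        }
        // Ty Eq: fall back to inference and compare.
        _ => {
            let inferred = infer(expr, ctx)?;
            if inferred == *ty {
                Ok(())
            } else {
                Err(format!("expected {ty:?}, found {inferred:?}"))
            }
        }
    }
}
```

Note that a bare if-then-else can only be checked; wrapping it in an annotation, as in (if b then false else true): Bool, turns it into an inferring term via Anno.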
stlc crate.