Sum Types and Pattern Matching in Java

So you’re a Java programmer and while hanging out with your hip dev friends you’ve heard of sum types (aka union types or variants). Now you’re wondering whether they make sense for you and how you would use them in Java. In this post I’ll explain the what and the how of sum types and their use in Java.

Product and Sum Types

Type theory says that you can model any real world data using two concepts: product and sum types. As an example, think of the following way of representing a person:

class Person {
  String name;
  int age;
}

Person is a product type between the type String and the type int, meaning that any combination of a string and a number is valid (of course within some limits for the age). The name product type comes from the Cartesian product, i.e., the number of values this type can take is the product of the number of possible values for String and the number of possible values for int.
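To make the counting argument concrete, here is a small sketch with two made-up enums (Size, Color and Shirt are illustration names, not part of the example above). Since String and int have far too many values to enumerate, the sketch uses tiny types instead and lists every value of their product:

```java
import java.util.ArrayList;
import java.util.List;

class ProductTypeDemo {
    enum Size { SMALL, MEDIUM, LARGE }   // 3 possible values
    enum Color { RED, BLUE }             // 2 possible values

    // A product type: every combination of a Size and a Color is a valid Shirt.
    static class Shirt {
        final Size size;
        final Color color;

        Shirt(Size size, Color color) {
            this.size = size;
            this.color = color;
        }
    }

    // Enumerates every value the product type can take.
    static List<Shirt> allShirts() {
        List<Shirt> shirts = new ArrayList<>();
        for (Size size : Size.values()) {
            for (Color color : Color.values()) {
                shirts.add(new Shirt(size, color));
            }
        }
        return shirts;
    }

    public static void main(String[] args) {
        // 3 sizes times 2 colors gives 6 shirts
        System.out.println(allShirts().size());
    }
}
```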

As another example, think of a binary tree in which each leaf holds a value. There are two types of nodes here: internal nodes, which have a left and a right child, and leaf nodes, which hold a value.

One way to represent the Node type is this:

class Node<T> {
  NodeType nodeType;
  Node<T> left;
  Node<T> right;
  T value;
}

enum NodeType { Internal, Leaf }

The logic using this type would have to first check NodeType in order to know whether it should expect left and right to be valid (non-null) or value to be valid.
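As a rough sketch of what that looks like in practice, here is a leaf counter written against the tagged representation. The leaf and internal factory helpers and the countLeaves method are hypothetical names added for illustration:

```java
class TaggedNodeDemo {
    enum NodeType { Internal, Leaf }

    // Same fields as the Node class above, with the children typed as Node<T>.
    static class Node<T> {
        NodeType nodeType;
        Node<T> left;
        Node<T> right;
        T value;
    }

    // Hypothetical factory helpers (not in the original example).
    static <T> Node<T> leaf(T value) {
        Node<T> n = new Node<>();
        n.nodeType = NodeType.Leaf;
        n.value = value;
        return n;
    }

    static <T> Node<T> internal(Node<T> left, Node<T> right) {
        Node<T> n = new Node<>();
        n.nodeType = NodeType.Internal;
        n.left = left;
        n.right = right;
        return n;
    }

    // The tag has to be checked first: only Internal nodes are expected
    // to have non-null children.
    static <T> int countLeaves(Node<T> node) {
        if (node.nodeType == NodeType.Leaf) {
            return 1;
        }
        return countLeaves(node.left) + countLeaves(node.right);
    }
}
```

Note that the compiler doesn't help here: nothing stops anyone from creating a node tagged Leaf with a non-null left, or a node tagged Internal with a null right, in which case countLeaves would blow up with a NullPointerException.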

Now you might be thinking: “Yep, I’ve been doing this for ages and everything turned out hunky-dory”. But your new colleague Jimmy or your future self might be of a different opinion.

Can we do better than this? Yes, and you’ve probably guessed it: Node is a sum type between the type InternalNode and the type LeafNode. That is, it can be either one or the other, and depending on what it is, you can expect different data. InternalNode and LeafNode are also called the variants of the sum type Node.

Many languages have sum types as first class citizens. However, in Java you can still achieve the same by using inheritance. So, for our example:

interface Node<T> {
    class InternalNode<T> implements Node<T> {
        Node<T> left;
        Node<T> right;
    }

    class LeafNode<T> implements Node<T> {
        T value;
    }
}
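To actually build a tree with this representation you need some way to set the fields. The sketch below assumes we add constructors and make the fields final; neither is in the snippet above, and the wrapping SumTypeDemo class and sampleTree method are just illustration names:

```java
class SumTypeDemo {
    interface Node<T> {
        final class InternalNode<T> implements Node<T> {
            final Node<T> left;
            final Node<T> right;

            // Assumed constructor, not part of the original snippet.
            InternalNode(Node<T> left, Node<T> right) {
                this.left = left;
                this.right = right;
            }
        }

        final class LeafNode<T> implements Node<T> {
            final T value;

            // Assumed constructor, not part of the original snippet.
            LeafNode(T value) {
                this.value = value;
            }
        }
    }

    // Builds the tree (a, (b, c)).
    static Node<String> sampleTree() {
        return new Node.InternalNode<>(
            new Node.LeafNode<>("a"),
            new Node.InternalNode<>(
                new Node.LeafNode<>("b"),
                new Node.LeafNode<>("c")));
    }
}
```

With this shape a leaf simply has no left or right field to misuse, and an internal node has no value field.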

This representation has several advantages:

Narrowing Down a Sum Type, aka Pattern Matching

Let’s say we want to write some algorithms using the Node type:

  1. compute the depth of the tree
  2. extract all the values from the tree
  3. print the tree

All these need to narrow down the Node type. How do we do that? The recommended way is to use polymorphism in one way or another, or in other words: “thou shalt not cast”. If you are writing more than one algorithm, you’d probably use the visitor pattern. However, to me it seems that to use the visitor pattern you have to write quite a lot of plumbing for every little sum type you define.
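For comparison, this is roughly what the visitor route looks like for the tree depth. It is a minimal sketch, not code from this post: NodeVisitor, accept, visitInternal and visitLeaf are names I made up, and the constructors are assumed.

```java
class VisitorDemo {
    // The visitor has one method per variant.
    interface NodeVisitor<T, R> {
        R visitInternal(InternalNode<T> node);
        R visitLeaf(LeafNode<T> node);
    }

    interface Node<T> {
        <R> R accept(NodeVisitor<T, R> visitor);
    }

    static final class InternalNode<T> implements Node<T> {
        final Node<T> left;
        final Node<T> right;

        InternalNode(Node<T> left, Node<T> right) {
            this.left = left;
            this.right = right;
        }

        @Override
        public <R> R accept(NodeVisitor<T, R> visitor) {
            return visitor.visitInternal(this);
        }
    }

    static final class LeafNode<T> implements Node<T> {
        final T value;

        LeafNode(T value) {
            this.value = value;
        }

        @Override
        public <R> R accept(NodeVisitor<T, R> visitor) {
            return visitor.visitLeaf(this);
        }
    }

    // Each algorithm becomes one visitor implementation.
    static <T> int treeDepth(Node<T> node) {
        return node.accept(new NodeVisitor<T, Integer>() {
            @Override
            public Integer visitInternal(InternalNode<T> n) {
                return 1 + Math.max(treeDepth(n.left), treeDepth(n.right));
            }

            @Override
            public Integer visitLeaf(LeafNode<T> n) {
                return 1;
            }
        });
    }
}
```

No casts anywhere, and forgetting a variant is a compile error. The price is the visitor interface plus an accept method in every variant, which is the plumbing I was complaining about.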

I risk getting burned at the stake for this, but I think using instanceof and casting is not such a bad alternative. Actually, many modern languages provide a concept equivalent to this. This concept is called pattern matching (well, pattern matching is a somewhat broader concept, but that’s for another post).

This is how we’d implement computing the tree depth:

static <V> int treeDepth(final Node<V> node) {
    if (node instanceof InternalNode) {
        InternalNode<V> internalNode = (InternalNode<V>) node;
        return 1 + max(treeDepth(internalNode.left), treeDepth(internalNode.right));
    } else if (node instanceof LeafNode) {
        return 1;
    } else {
        throw new RuntimeException("node type not known: " + node);
    }
}

I might be spared for doing this, because pattern matching for instanceof is planned for Java itself in the near future. That is, if you use instanceof in a conditional, the value is automatically cast to that variant, so no manual cast is needed anymore.
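The same instanceof-and-cast approach covers the second algorithm from the list, extracting all the values. This is a sketch under a few assumptions: the variants have constructors and final fields, the values come out left to right, and the names values and collect are my own.

```java
import java.util.ArrayList;
import java.util.List;

class ValuesDemo {
    interface Node<T> {}

    static final class InternalNode<T> implements Node<T> {
        final Node<T> left;
        final Node<T> right;

        InternalNode(Node<T> left, Node<T> right) {
            this.left = left;
            this.right = right;
        }
    }

    static final class LeafNode<T> implements Node<T> {
        final T value;

        LeafNode(T value) {
            this.value = value;
        }
    }

    // Returns the values of all leaves, from left to right.
    static <V> List<V> values(Node<V> node) {
        List<V> result = new ArrayList<>();
        collect(node, result);
        return result;
    }

    private static <V> void collect(Node<V> node, List<V> out) {
        if (node instanceof InternalNode) {
            InternalNode<V> internalNode = (InternalNode<V>) node;
            collect(internalNode.left, out);
            collect(internalNode.right, out);
        } else if (node instanceof LeafNode) {
            out.add(((LeafNode<V>) node).value);
        } else {
            throw new RuntimeException("node type not known: " + node);
        }
    }
}
```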

Hiding the instanceof

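All those instanceofs and casts are pretty ugly and polluting. Can we abstract them away somehow? Sure thing! We could make a method to which we provide the object to match, one case per variant (the variant's class plus a lambda to run when it matches), and a default to fall back on when nothing matches.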

This is how we’d implement treeDepth with such a construct:

import static java.lang.Integer.max;
import static patternmatch.patternmatch.PatternMatch.Case.exprCase;
import static patternmatch.patternmatch.PatternMatch.Default.exprDefault;
import static patternmatch.patternmatch.PatternMatch.match;

class Depth {
    private static <V> int treeDepth(final Node<V> node) {
        return match(node,
            exprCase(InternalNode.class, (InternalNode n) ->
                    1 + max(treeDepth(n.left), treeDepth(n.right))),
            exprCase(LeafNode.class, (LeafNode n) -> 1),
            exprDefault(() -> {
                throw new RuntimeException("node not known: " + node);
            }));
    }
}

As you can see, exprCase receives the variant to be matched (InternalNode or LeafNode) and a lambda to be called when the variant matches. exprDefault gets a lambda to be called if the object doesn’t match InternalNode or LeafNode. In some languages it is possible to seal the sum type to certain variants, but in Java it’s impossible to stop someone from adding more variants. So adding a default is always a good idea, to detect such an addition in the future.

So how can we implement the match method? I have a poor man’s implementation for it on GitHub. It’s far from fancy, but it gets the job done. It basically provides overloaded versions of match for 1, 2, …, 7 variants. It’s trivial to add more if you need to.

Here’s a glimpse into what it looks like:

import java.util.Optional;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.Supplier;

public class PatternMatch {
    public static <T, V1 extends T, R> R match(final T t,
        final Case<V1, R> c1, final Default<R> defaultF) {
        return oMatch(t, c1).orElseGet(defaultF.f);
    }

    public static <T, V1 extends T, V2 extends T, R> R match(final T t,
        final Case<V1, R> c1, final Case<V2, R> c2,
        final Default<R> defaultF) {
        return oMatch(t, c1).orElseGet(() -> match(t, c2, defaultF));
    }

    ...

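    private static <T, V extends T, R> Optional<R> oMatch(final T t,
        final Case<V, R> c) {
        // Note: this compares classes exactly, so (unlike instanceof)
        // an instance of a subclass of the variant does not match.
        return c.type.equals(t.getClass())
            ? Optional.of(c.f.apply(c.type.cast(t)))
            : Optional.empty();
    }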

    public static class Case<V, R> {
        final Class<V> type;
        final Function<V, R> f;

        private Case(final Class<V> t, final Function<V, R> f) {
            this.type = t;
            this.f = f;
        }

        public static <V, R> Case<V, R> exprCase(final Class<V> t,
            final Function<V, R> f) {
            return new Case<>(t, f);
        }
        ...
    }

    public static class Default<R> {
        final Supplier<R> f;

        private Default(final Supplier<R> f) {
            this.f = f;
        }

        public static <R> Default<R> exprDefault(final Supplier<R> f) {
            return new Default<>(f);
        }

        ...
    }
}

Basically the method oMatch gets an object and a Case and if the object matches the variant from the case, it runs the lambda from the Case and returns the result of the lambda wrapped in an Optional. If not, it returns an empty Optional.

match(o, case1, case2, ..., casen, defaultF) calls oMatch(o, case1). If that matches, the result is returned, otherwise match(o, case2, ..., casen, defaultF) is called. When match(o, casen, defaultF) is reached and oMatch(o, casen) doesn’t match, we call defaultF.
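If you want to play with the try-each-case-then-fall-back idea without the overloads, here is a condensed, self-contained sketch. It is a simplification and not the code from my repository: MiniMatch and describe are made-up names, the cases are passed as varargs (so the default has to come before them), the chain is a loop instead of recursion, and matching uses Class.isInstance, which also accepts subclasses of the variant.

```java
import java.util.Optional;
import java.util.function.Function;
import java.util.function.Supplier;

class MiniMatch {
    static final class Case<V, R> {
        final Class<V> type;
        final Function<V, R> f;

        Case(Class<V> type, Function<V, R> f) {
            this.type = type;
            this.f = f;
        }
    }

    static <V, R> Case<V, R> exprCase(Class<V> type, Function<V, R> f) {
        return new Case<>(type, f);
    }

    // Runs the case's lambda if o is an instance of the case's type.
    // As with Optional.of in general, the lambda must not return null.
    private static <V, R> Optional<R> oMatch(Object o, Case<V, R> c) {
        return c.type.isInstance(o)
            ? Optional.of(c.f.apply(c.type.cast(o)))
            : Optional.empty();
    }

    // Tries the cases in order; falls back to defaultF if none matches.
    @SafeVarargs
    static <R> R match(Object o, Supplier<R> defaultF, Case<?, R>... cases) {
        for (Case<?, R> c : cases) {
            Optional<R> result = oMatch(o, c);
            if (result.isPresent()) {
                return result.get();
            }
        }
        return defaultF.get();
    }

    // Hypothetical usage example.
    static String describe(Object o) {
        return match(o,
            () -> "something else",
            exprCase(Integer.class, (Integer i) -> "int " + i),
            exprCase(String.class, (String s) -> "string of length " + s.length()));
    }
}
```

The varargs version trades some type safety for brevity: the overloads in the real class can require every case type to extend the matched type, while here any Case is accepted.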

As you’ve probably guessed, exprCase and exprDefault are useful when you want your match() to return a value, i.e., it’s an expression. If however you want to just pass lambdas that return void (e.g., if you want to print the sum type), you can use stmtCase and stmtDefault from the same class.

For a better example of how to use match(), check the code on GitHub.

Wrapping up

In this post you’ve learned about sum types and how to implement them in Java. We also went through implementing a poor man’s version of pattern matching in Java. I hope that helps, and see you next time.