Sum Types and Pattern Matching in Java
08 Sep 2019So you’re a Java programmer and while hanging out with your hip dev friends you’ve heard of sum types (aka union types or variants). Now you’re wondering if they make sense for you and how would you use them in Java. In this post I’ll explain the what and the how of sum types and their use in Java.
Product and Sum Types
Type theory says that you can model any real world data using two concepts: product and sum types. As an example, think of the following way of representing a person:
class Person {
String name;
int age;
}
Person is a product type between the type String and the type int, meaning that
any combination of a string and and number are valid (of course within some limits for the age).
The name product type comes from cartesian product,
i.e., the number of values this type can take is the product of the number
of possible values for string and the number of possible values for
integer.
As another example, think of a binary tree in which each leaf holds a value. There are two types of nodes here:
- internal nodes which have two child nodes (
leftandright) - leaf nodes that have a
value
One way to represent the Node type is this:
class Node<T> {
NodeType nodeType;
Node left;
Node right;
T value;
}
enum NodeType { Internal, Leaf }
The logic using this type would have to first check NodeType in order to know
whether it should expect left and right to be valid (non-null) or value to be
valid.
Now you might be thinking: “Yep, I’ve been doing this for ages and everything turned out hunky-dory”. But your new colleague Jimmy or your future self might be of a different opinion:
- By looking at this type they don’t know which data they should expect. They think all the combinations are valid. They have to read your other code to understand your intentions.
- Even if they get the drill, they are still humans and make mistakes by the
accessing the
valuedata even if they’ve checked this to be an internal node. The compiler smiles happily though instead of helping.
Can we do better than this? Yes, and you’ve probably guessed it: Node is a sum
type between the type InternalNode and the
type LeafNode. That is, it can be either one or the other, and depending on what it
is, you can expect different data. InternalNode and LeafNode are also called the
variants of the sum type Node.
Many languages have sum types as first class citizens. However, in Java you can still achieve the same by using inheritance. So, for our example:
interface Node<T> {
class InternalNode<T> implements Node<T> {
Node<T> left;
Node<T> right;
}
class LeafNode<T> implements Node<T> {
T value;
}
}
This representation has several advantages:
- by just looking at the type you can see what data is expected in each variant
- you are forced to narrow down the type
Nodeto one of its variants,InternalNodeorLeafNode, before accessing the data. Otherwise, the compiler will bark.
Narrowing Down a Sum Type, aka Pattern Matching
Let’s say we want to write some algorithms using the Node type:
- compute the depth of the tree
- extract all the values from the tree
- print down the tree
All these need to narrow down the Node type. How do we do that? The recommended
way is to use polymorphism in one
way or another, or in other words: “thou shalt not cast”. If you are writing more than
one algorithm, you’d probably use
the visitor pattern. However, to me it
seems that to use the visitor pattern you have to write quite a lot of plumbing for every little sum type you define.
I risk getting burned at the stake for this, but I think using instanceof
and casting is not such a bad alternative. Actually many modern languages
provide a concept equivalent to this. This concept is called pattern matching (well,
pattern matching is a bit broader concept, but that’s for another post).
This is how we’d implement computing the tree depth:
static <V> int treeDepth(final Node<V> node) {
if (node instanceof InternalNode) {
InternalNode<V> internalNode = (InternalNode<V>) node;
return 1 + max(treeDepth(internalNode.left), treeDepth(internalNode.right));
} else if (node instanceof LeafNode) {
return 1;
} else {
throw new RuntimeException("node type not known: " + node);
}
}
I might be spared for doing this, because Java is also planned to introduce
pattern matching for instanceof in the near future. I.e.,
if you use instanceof in a
conditional, the sum type gets automatically casted to that variant, so no manual casting
needed anymore.
Hiding the instanceof
All those instanceofs and casts are pretty ugly and poluting. Can we abstract them
away somehow? Sure thing! We could make a method to which we provide:
- an object of the sum type to match
- a lambda for each variant of the sum type. It will be called if the object matches that variant.
- a default lambda to be called if the object doesn’t match any of the provided variants
This is how we’d implement treeDepth with such a construct:
import static java.lang.Integer.max;
import static patternmatch.patternmatch.PatternMatch.Case.exprCase;
import static patternmatch.patternmatch.PatternMatch.Default.exprDefault;
import static patternmatch.patternmatch.PatternMatch.match;
class Depth {
private static <V> int treeDepth(final Node<V> node) {
return match(node,
exprCase(InternalNode.class, (InternalNode n) ->
1 + max(treeDepth(n.left), treeDepth(n.right))),
exprCase(LeafNode.class, (LeafNode n) -> 1),
exprDefault(() -> {
throw new RuntimeException("node not known: " + node);
}));
}
}
As you can see, exprCase receives the variant to be matched (InternalNode or Leaf)
and a lambda to be called when the variant matches. exprDefault gets a lambda to be
called if the object doesn’t match InternalNode or LeafNode. In some languages
it is possible to seal the sum type to certain variants, but in Java it’s impossible
to stop someone from adding more variants. So adding a default is always a good idea,
to detect that in the future.
So how can we implement the match method? I have a poor man’s implementation for it
on github. It’s far from fancy, but it gets the job done. It basically
provides overloaded versions of match for 1, 2, …, 7 variants. It’s trivial
to add more if you need to.
Here’s a glimpse into what it looks like:
import java.util.Optional;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.Supplier;
public class PatternMatch {
public static <T, V1 extends T, R> R match(final T t,
final Case<V1, R> c1, final Default<R> defaultF) {
return oMatch(t, c1).orElseGet(defaultF.f);
}
public static <T, V1 extends T, V2 extends T, R> R match(final T t,
final Case<V1, R> c1, final Case<V2, R> c2,
final Default<R> defaultF) {
return oMatch(t, c1).orElseGet(() -> match(t, c2, defaultF));
}
...
private static <T, V extends T, R> Optional<R> oMatch(final T t,
final Case<V, R> c) {
return c.type.equals(t.getClass())
? Optional.of(c.f.apply(c.type.cast(t)))
: Optional.empty();
}
public static class Case<V, R> {
final Class<V> type;
final Function<V, R> f;
private Case(final Class<V> t, final Function<V, R> f) {
this.type = t;
this.f = f;
}
public static <V, R> Case<V, R> exprCase(final Class<V> t,
final Function<V, R> f) {
return new Case<>(t, f);
}
...
}
public static class Default<R> {
final Supplier<R> f;
private Default(final Supplier<R> f) {
this.f = f;
}
public static <R> Default<R> exprDefault(final Supplier<R> f) {
return new Default<>(f);
}
...
}
}
Basically the method oMatch gets an object and a Case and if the object
matches the variant from the case, it runs the lambda from the Case and returns the result
of the lambda wrapped in an Optional. If not, it returns an empty Optional.
The match(o, case1, case2, ..., casen, defaultF) calls oMatch(o, case1).
If that matches, the result is returned, otherwise
match(o, case2, ..., casen, defaultF) is called. When match(o, casen, default) is reached and oMatch(o, casen) doesn’t match, we call defaultF.
As you’ve probably guessed, exprCase and exprDefault are useful when you want your
match() to return a value, i.e., it’s an expression. If however you want to just
pass lambdas that return void (e.g., if you want to print the sum type), you can use
stmtCase and stmtDefault from the same class.
For a better example of how to use match() check Github.
Wrapping up
In this post you’ve learned about sum types and how to implement them in Java. Also we went through implementing a poor man’s version of pattern matching in Java. I hope that helps and see you next time.