Union-find is a classic data structure used in many algorithms, for instance cycle detection in undirected graphs and Tarjan's offline lowest common ancestor algorithm. The code illustrated in this post passes the AOJ problem Disjoint Set: Union Find Tree and is available on GitHub in AlgImpl.

## The Problem

Given n elements, each initially in its own set, {0}, {1}, …, {n-1} (0-based, as in the code), union-find supports two operations on these sets: union and find.

• A union operation combines two sets into one.
• A find operation identifies the set that contains a particular element.

An intermixed sequence of union and find operations is performed, and the goal is to make both operations as efficient as possible.
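To make the efficiency goal concrete, it helps to look at a naive baseline, often called quick-find, which is not the approach this post takes. The sketch below is illustrative only: the class name QuickFind and its label array are hypothetical. find is a single array read, but union has to scan and relabel the whole array, so a sequence of n - 1 unions costs O(n²) in total.

```java
// Hypothetical naive baseline ("quick-find"), shown only for comparison.
// label[i] names the set that element i belongs to.
public class QuickFind {
    private final int[] label;

    public QuickFind(int n) {
        label = new int[n];
        for (int i = 0; i < n; i++) {
            label[i] = i; // each element starts in its own set
        }
    }

    // find is a single array lookup: O(1).
    public int find(int p) {
        return label[p];
    }

    // union relabels every member of p's set: O(n) per call.
    public void union(int p, int q) {
        int from = label[p];
        int to = label[q];
        if (from == to) {
            return; // already in the same set
        }
        for (int i = 0; i < label.length; i++) {
            if (label[i] == from) {
                label[i] = to;
            }
        }
    }
}
```

For example, after union(0, 1) and union(3, 4) on five elements, find(0) equals find(1) while find(1) still differs from find(3). The cost of union is what the tree-based representation below avoids.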

## Set Representation

An int array id is a parent-link representation of a forest of trees: id[i] is the parent of element i, and a root is its own parent. A concrete example is shown below:

| idx    | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 |
|--------|---|---|---|---|---|---|---|---|---|---|
| parent | 0 | 1 | 1 | 8 | 3 | 0 | 5 | 7 | 8 | 8 |

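Each tree is one set, and the set is identified by the root of its tree; this root is what find returns. In the example above, find(4) follows 4 → 3 → 8 and returns 8, so the sets are {0, 5, 6}, {1, 2}, {3, 4, 8, 9} and {7}.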
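Without any optimization, find on this representation simply follows parent links until it reaches an element that is its own parent. A minimal sketch, using the example array from the table; ParentLinkDemo and findRoot are hypothetical names.

```java
// Unoptimized find on a parent-link forest (no weighting, no path compression).
public class ParentLinkDemo {
    // The example forest from the table: ID[i] is the parent of i.
    static final int[] ID = {0, 1, 1, 8, 3, 0, 5, 7, 8, 8};

    // Follows parent links from p until reaching a root (an element that is its own parent).
    static int findRoot(int[] id, int p) {
        while (id[p] != p) {
            p = id[p];
        }
        return p;
    }

    public static void main(String[] args) {
        for (int i = 0; i < ID.length; i++) {
            System.out.println(i + " -> root " + findRoot(ID, i));
        }
    }
}
```

Running it reports root 8 for element 4 and root 0 for element 6. The cost of each call is proportional to the depth of the element, which is why the optimizations below focus on keeping trees shallow.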

## Optimizations

Two optimizations are applied to boost performance.

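1. Weighting (union by size): always attach the root of the smaller tree to the root of the larger one, which keeps trees shallow.
2. Path compression: during find, make every node on the search path point directly to the root, flattening the tree.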
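The find in the full code compresses paths recursively. As an alternative that avoids recursion, path compression can also be done iteratively in two passes: first walk up to locate the root, then re-point every node on the path at it. The sketch below applies this variant to the example array; PathCompressionDemo is a hypothetical name.

```java
import java.util.Arrays;

// Iterative two-pass path compression, an alternative to the recursive version.
public class PathCompressionDemo {
    static int find(int[] id, int p) {
        // Pass 1: locate the root.
        int root = p;
        while (id[root] != root) {
            root = id[root];
        }
        // Pass 2: re-point every node on the path directly at the root.
        while (id[p] != root) {
            int next = id[p];
            id[p] = root;
            p = next;
        }
        return root;
    }

    public static void main(String[] args) {
        int[] id = {0, 1, 1, 8, 3, 0, 5, 7, 8, 8}; // example forest from the table
        find(id, 4); // path 4 -> 3 -> 8
        find(id, 6); // path 6 -> 5 -> 0
        System.out.println(Arrays.toString(id)); // [0, 1, 1, 8, 8, 0, 0, 7, 8, 8]
    }
}
```

After the two calls, elements 4 and 6 point straight at their roots 8 and 0, so later finds on them take a single step.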

Together, these two techniques make the amortized cost of find and union O(α(n)), where α is the inverse Ackermann function. Since α(n) ≤ 4 for any n that could arise in practice, this is effectively constant.
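Weighting on its own already bounds tree height by log2 n. The hypothetical experiment below builds trees without path compression and measures the maximum depth. The naive strategy (always hang p's root under q's root) is an assumption chosen for contrast. On the chain sequence union(0, 1), union(1, 2), …, naive linking produces a single path of depth n - 1, while weighting keeps every node one step from the root. Pairwise merging of equal-sized blocks is the worst case for weighting and reaches exactly log2 n.

```java
import java.util.Arrays;

// Hypothetical experiment: maximum tree depth with and without weighting (no path compression).
public class DepthDemo {
    static int root(int[] id, int p) {
        while (id[p] != p) {
            p = id[p];
        }
        return p;
    }

    // Number of parent links from p up to its root.
    static int depth(int[] id, int p) {
        int d = 0;
        while (id[p] != p) {
            p = id[p];
            d++;
        }
        return d;
    }

    static int maxDepth(int[] id) {
        int max = 0;
        for (int i = 0; i < id.length; i++) {
            max = Math.max(max, depth(id, i));
        }
        return max;
    }

    static int[] identity(int n) {
        int[] id = new int[n];
        for (int i = 0; i < n; i++) {
            id[i] = i;
        }
        return id;
    }

    static int[] ones(int n) {
        int[] size = new int[n];
        Arrays.fill(size, 1);
        return size;
    }

    // Assumed naive rule: always hang p's root under q's root.
    static void naiveUnion(int[] id, int p, int q) {
        int pRoot = root(id, p);
        int qRoot = root(id, q);
        if (pRoot != qRoot) {
            id[pRoot] = qRoot;
        }
    }

    // Same rule as UnionFind.union: smaller tree under larger, ties hang q's root under p's.
    static void weightedUnion(int[] id, int[] size, int p, int q) {
        int pRoot = root(id, p);
        int qRoot = root(id, q);
        if (pRoot == qRoot) {
            return;
        }
        if (size[pRoot] < size[qRoot]) {
            id[pRoot] = qRoot;
            size[qRoot] += size[pRoot];
        } else {
            id[qRoot] = pRoot;
            size[pRoot] += size[qRoot];
        }
    }

    // Chain sequence: union(0, 1), union(1, 2), ..., union(n - 2, n - 1).
    static int chainNaive(int n) {
        int[] id = identity(n);
        for (int i = 0; i + 1 < n; i++) {
            naiveUnion(id, i, i + 1);
        }
        return maxDepth(id);
    }

    static int chainWeighted(int n) {
        int[] id = identity(n);
        int[] size = ones(n);
        for (int i = 0; i + 1 < n; i++) {
            weightedUnion(id, size, i, i + 1);
        }
        return maxDepth(id);
    }

    // Pairwise merging of equal-sized blocks: the worst case for weighting.
    static int pairwiseWeighted(int n) {
        int[] id = identity(n);
        int[] size = ones(n);
        for (int step = 1; step < n; step *= 2) {
            for (int i = 0; i + step < n; i += 2 * step) {
                weightedUnion(id, size, i, i + step);
            }
        }
        return maxDepth(id);
    }

    public static void main(String[] args) {
        int n = 1024;
        System.out.println("naive chain depth:       " + chainNaive(n));       // 1023
        System.out.println("weighted chain depth:    " + chainWeighted(n));    // 1
        System.out.println("weighted pairwise depth: " + pairwiseWeighted(n)); // 10
    }
}
```

Path compression then flattens whatever depth remains on each find, which is where the further improvement from O(log n) to O(α(n)) amortized comes from.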

## Full Java Code

```java
public class UnionFind {
    int[] id;      // id[i] is the parent of i; a root satisfies id[i] == i
    int count;     // number of disjoint sets
    int[] weight;  // size indexed by root id

    public UnionFind(int n) {
        id = new int[n];
        weight = new int[n];
        count = n;
        for (int idx = 0; idx < id.length; idx++) {
            id[idx] = idx;     // every element starts as its own root
            weight[idx] = 1;
        }
    }

    public void union(int p, int q) {
        int pRoot = find(p);
        int qRoot = find(q);
        if (pRoot == qRoot) {
            return;            // already in the same set
        }
        // make smaller root point to larger one
        if (weight[pRoot] < weight[qRoot]) {
            id[pRoot] = qRoot;
            weight[qRoot] += weight[pRoot];
        } else {
            id[qRoot] = pRoot;
            weight[pRoot] += weight[qRoot];
        }
        count--;               // two sets merged into one
    }

    // path compression: every node on the path is re-pointed at the root
    public int find(int p) {
        if (id[p] != p) {
            id[p] = find(id[p]);
        }
        return id[p];
    }

    public boolean connected(int p, int q) {
        return find(p) == find(q);
    }

    public int count() {
        return count;
    }
}
```
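The introduction mentions cycle detection in undirected graphs as an application. A minimal sketch of it follows, assuming vertices numbered 0 to n-1 and edges given as int pairs; CycleDetection and hasCycle are hypothetical names. It reuses the same union-by-size and recursive path-compression rules as the class above, inlined on plain arrays so the block stands alone. An edge whose endpoints already share a root closes a cycle. Note that a parallel edge or a self-loop is also reported as a cycle.

```java
// Sketch: cycle detection in an undirected graph with union-find (vertices 0..n-1).
public class CycleDetection {
    // Find with recursive path compression, same rule as UnionFind.find.
    static int find(int[] parent, int p) {
        if (parent[p] != p) {
            parent[p] = find(parent, parent[p]);
        }
        return parent[p];
    }

    // Returns true as soon as an edge joins two vertices that are already connected.
    static boolean hasCycle(int n, int[][] edges) {
        int[] parent = new int[n];
        int[] size = new int[n];
        for (int i = 0; i < n; i++) {
            parent[i] = i;
            size[i] = 1;
        }
        for (int[] edge : edges) {
            int a = find(parent, edge[0]);
            int b = find(parent, edge[1]);
            if (a == b) {
                return true; // endpoints already in one tree: this edge closes a cycle
            }
            // union by size keeps trees shallow, bounding recursion depth in find
            if (size[a] < size[b]) {
                parent[a] = b;
                size[b] += size[a];
            } else {
                parent[b] = a;
                size[a] += size[b];
            }
        }
        return false;
    }

    public static void main(String[] args) {
        int[][] path = {{0, 1}, {1, 2}, {2, 3}};
        int[][] triangle = {{0, 1}, {1, 2}, {2, 0}};
        System.out.println(hasCycle(4, path));     // false
        System.out.println(hasCycle(3, triangle)); // true
    }
}
```

Each edge costs two finds and at most one union, so the whole check runs in O(m α(n)) amortized time for m edges.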