class UnionFind:
    """Disjoint-set (union-find) structure over hashable elements.

    Uses union by size in ``merge`` and path compression in ``find``.

    Attributes:
        parent: maps each element to its parent element, or to ``None``
            when the element is the root of its set.
        size_of_set: maps each element to the size of the set it roots.
            Only values stored under current roots are meaningful; entries
            for merged-away roots are left stale.
        num_of_set: number of disjoint sets currently tracked.

    NOTE(review): ``None`` is used as the "is a root" marker in ``parent``,
    so adding ``None`` itself as an element is unsafe.  If ``None`` becomes
    the root of another element, that element's parent is stored as
    ``None``, and the element would then be mistaken for a root.
    """

    def __init__(self):
        self.parent = {}
        self.size_of_set = {}
        self.num_of_set = 0

    def add(self, x):
        """Add ``x`` as a new singleton set; no-op if ``x`` is already known."""
        if x in self.parent:
            return
        # A parent of None marks x as the root of its own set.
        self.parent[x] = None
        self.num_of_set += 1
        self.size_of_set[x] = 1

    def merge(self, x, y):
        """Union the sets containing ``x`` and ``y``.

        The smaller set's root is attached under the larger set's root.
        On a size tie, ``x``'s root is attached under ``y``'s root.

        Raises:
            KeyError: if ``x`` or ``y`` was never added.
        """
        root_x, root_y = self.find(x), self.find(y)
        if root_x == root_y:
            return
        # Ensure root_y is the larger root so the merged tree stays shallow.
        if self.size_of_set[root_x] > self.size_of_set[root_y]:
            root_x, root_y = root_y, root_x
        self.parent[root_x] = root_y
        self.num_of_set -= 1
        self.size_of_set[root_y] += self.size_of_set[root_x]

    def find(self, x):
        """Return the root of the set containing ``x``, compressing the path.

        Raises:
            KeyError: if ``x`` was never added.
        """
        # First pass: walk parent links up to the root.
        root = x
        while self.parent[root] is not None:
            root = self.parent[root]
        # Second pass (path compression): point every node on the walked
        # path directly at the root.  Together with union by size this
        # gives amortized near-constant (inverse-Ackermann) time per call.
        # It does not guarantee O(1) or a fixed height bound.
        while x != root:
            next_node = self.parent[x]
            self.parent[x] = root
            x = next_node
        return root

    def is_connected(self, x, y):
        """Return True if ``x`` and ``y`` are in the same set.

        Raises:
            KeyError: if ``x`` or ``y`` was never added.
        """
        return self.find(x) == self.find(y)