UNB/ CS/ David Bremner/ teaching/ cs3383/ lectures/ 24.0-demos/ union.py
#!/usr/bin/env python3
class Partition:
    """Disjoint-set (union-find) forest over the integers 0 .. n-1.

    Sets are merged with union by rank, and ``find`` performs path
    compression.  ``print`` writes a Graphviz DOT snapshot of the current
    forest, which is useful for visualising how the trees evolve.
    """

    # Graphviz node: the element number on top, its rank (small, blue) below.
    _NODE_TEMPLATE = (
        '{:d} [shape=none,label=< <table cellborder="0"><tr><td>{:d}</td></tr>'
        '<tr><td><font color="blue" point-size="10">{:d}</font></td></tr>'
        '</table> >]'
    )

    def __init__(self, n):
        """Create the n singleton sets {0}, {1}, ..., {n-1}.

        This is sometimes called makeset(j), applied here to every j in
        range(n).
        """
        # parent[i] == i marks i as the root (representative) of its set.
        self.parent = list(range(n))
        # rank[i] only changes while i is a root; union() compares ranks to
        # decide which root becomes the child of the other.
        self.rank = [0] * n

    def find(self, key):
        """Return the root (representative) of the set containing ``key``.

        Every element on the path from ``key`` to the root is re-pointed
        directly at the root (path compression).  The walk is iterative, so
        it cannot run into Python's recursion limit.

        Raises:
            IndexError: if ``key`` is not in ``range(len(self.parent))``.
                Negative keys are rejected explicitly; plain list indexing
                would otherwise silently alias them to elements counted from
                the end of the list.
        """
        if not 0 <= key < len(self.parent):
            raise IndexError("partition key out of range: {!r}".format(key))
        parent = self.parent
        # First pass: climb to the root.
        root = key
        while parent[root] != root:
            root = parent[root]
        # Second pass: point every node on the path straight at the root.
        while key != root:
            next_key = parent[key]
            parent[key] = root
            key = next_key
        return root

    def union(self, x, y):
        """Merge the sets containing ``x`` and ``y``.

        This does nothing if they are already in the same set.  With union
        by rank, the root of strictly higher rank becomes the parent.  On a
        tie, x's root is placed under y's root and y's root's rank grows by
        one.
        """
        rx = self.find(x)
        ry = self.find(y)
        if rx == ry:
            return
        if self.rank[rx] > self.rank[ry]:
            self.parent[ry] = rx
        else:
            self.parent[rx] = ry
            if self.rank[rx] == self.rank[ry]:
                self.rank[ry] += 1

    def print(self, name):
        """Write a Graphviz snapshot of the forest to the file ``<name>.dot``.

        Each element becomes a node labelled with its number and rank.  Each
        non-root element gets an edge ``child -> parent``.  The graph uses
        rankdir="BT", so those edges are laid out bottom-to-top.

        Args:
            name: base name of the output file (".dot" is appended).  It is
                also used as the graph's name inside the DOT source.
        """
        path = "{:s}.dot".format(name)
        # Inside a DOT quoted ID, \" is an escaped quote.  Escape any quotes
        # in the name so the header line stays well-formed.
        graph_id = name.replace('"', '\\"')
        with open(path, 'w', encoding='utf-8') as f:
            print('digraph "{:s}" {{ rankdir="BT"'.format(graph_id), file=f)
            for i, r in enumerate(self.rank):
                print(self._NODE_TEMPLATE.format(i, i, r), file=f)
            for i, p in enumerate(self.parent):
                if p != i:
                    print("{:d} -> {:d}".format(i, p), file=f)
            print("}", file=f)

if __name__ == "__main__":
    # Demo: grow a 7-element forest in three stages.  After each stage,
    # dump a DOT snapshot so the effect of union by rank and path
    # compression can be compared across the files.
    stages = [
        ([(0, 3), (1, 4), (2, 5)], "compress1"),
        ([(2, 6), (4, 0)], "compress2"),
        ([(1, 6)], "compress3"),
    ]
    forest = Partition(7)
    for pairs, snapshot in stages:
        for a, b in pairs:
            forest.union(a, b)
        forest.print(snapshot)