"""
==========================================
Find the intersection of two segmentations
==========================================

When segmenting an image, you may want to combine multiple alternative
segmentations. The `skimage.segmentation.join_segmentations` function
computes the join of two segmentations, in which two pixels are placed in
the same segment if and only if they are in the same segment in *both*
segmentations.
"""

import numpy as np
from scipy import ndimage as nd
import matplotlib.pyplot as plt
import matplotlib as mpl

from skimage.filter import sobel
from skimage.segmentation import slic, join_segmentations
from skimage.morphology import watershed

from skimage import data

coins = data.coins()

# --- Segmentation 1: Sobel edge magnitude flooded by watershed ----------
# Grey levels above 150 seed the foreground label and grey levels below 30
# seed the background label; every other pixel starts unlabelled (0) and is
# assigned by the watershed.
foreground, background = 1, 2
edges = sobel(coins)
markers = np.zeros_like(coins)
markers[coins > 150] = foreground
markers[coins < 30] = background

ws = watershed(edges, markers)
# Give each connected foreground region its own integer label.
seg1, _ = nd.label(ws == foreground)

# --- Segmentation 2: SLIC superpixels -----------------------------------
# Stack the grey-level image into three identical channels so SLIC receives
# an RGB version of `coins`; Lab conversion is switched off below.
coins_colour = np.repeat(coins[:, :, np.newaxis], 3, axis=2)
seg2 = slic(coins_colour, n_segments=30, max_iter=160, sigma=1, ratio=9,
            convert2lab=False)

# --- Join: two pixels share a segment only if they do in seg1 AND seg2 ---
segj = join_segmentations(seg1, seg2)

### Display the result ###

# make a random colormap for a set number of values
def random_cmap(im):
    np.random.seed(9)
    cmap_array = np.concatenate(
        (np.zeros((1, 3)), np.random.rand(np.ceil(im.max()), 3)))
    return mpl.colors.ListedColormap(cmap_array)

# show the segmentations
fig, axes = plt.subplots(ncols=4, figsize=(9, 2.5))
axes[0].imshow(coins, cmap=plt.cm.gray, interpolation='nearest')
axes[0].set_title('Image')
axes[1].imshow(seg1, cmap=random_cmap(seg1), interpolation='nearest')
axes[1].set_title('Sobel+Watershed')
axes[2].imshow(seg2, cmap=random_cmap(seg2), interpolation='nearest')
axes[2].set_title('SLIC superpixels')
axes[3].imshow(segj, cmap=random_cmap(segj), interpolation='nearest')
axes[3].set_title('Join')

for ax in axes:
    ax.axis('off')
plt.subplots_adjust(hspace=0.01, wspace=0.01, top=1, bottom=0, left=0, right=1)
plt.show()
