Source code for pygsp.graphs.nngraphs.imgpatches

# -*- coding: utf-8 -*-

import numpy as np

from pygsp.graphs import NNGraph  # prevent circular import in Python < 3.5


[docs]class ImgPatches(NNGraph): r"""NN-graph between patches of an image. Extract a feature vector in the form of a patch for every pixel of an image, then construct a nearest-neighbor graph between these feature vectors. The feature matrix, i.e. the patches, can be found in :attr:`Xin`. Parameters ---------- img : array Input image. patch_shape : tuple, optional Dimensions of the patch window. Syntax: (height, width), or (height,), in which case width = height. kwargs : dict Parameters passed to :class:`NNGraph`. Notes ----- The feature vector of a pixel `i` will consist of the stacking of the intensity values of all pixels in the patch centered at `i`, for all color channels. So, if the input image has `d` color channels, the dimension of the feature vector of each pixel is (patch_shape[0] * patch_shape[1] * d). Examples -------- >>> import matplotlib.pyplot as plt >>> from skimage import data, img_as_float >>> img = img_as_float(data.camera()[::64, ::64]) >>> G = graphs.ImgPatches(img, patch_shape=(3, 3)) >>> print('{} nodes ({} x {} pixels)'.format(G.Xin.shape[0], *img.shape)) 64 nodes (8 x 8 pixels) >>> print('{} features per node'.format(G.Xin.shape[1])) 9 features per node >>> G.set_coordinates(kind='spring', seed=42) >>> fig, axes = plt.subplots(1, 2) >>> _ = axes[0].spy(G.W, markersize=2) >>> G.plot(ax=axes[1]) """ def __init__(self, img, patch_shape=(3, 3), **kwargs): self.img = img try: h, w, d = img.shape except ValueError: try: h, w = img.shape d = 0 except ValueError: print("Image should be at least a 2D array.") try: r, c = patch_shape except ValueError: r = patch_shape[0] c = r pad_width = [(int((r - 0.5) / 2.), int((r + 0.5) / 2.)), (int((c - 0.5) / 2.), int((c + 0.5) / 2.))] if d == 0: window_shape = (r, c) d = 1 # For the reshape in the return call else: pad_width += [(0, 0)] window_shape = (r, c, d) # Pad the image. img = np.pad(img, pad_width=pad_width, mode='symmetric') # Extract patches as node features. # Alternative: sklearn.feature_extraction.image.extract_patches_2d. # sklearn has much less dependencies than skimage. try: import skimage except Exception: raise ImportError('Cannot import skimage, which is needed to ' 'extract patches. Try to install it with ' 'pip (or conda) install scikit-image.') patches = skimage.util.view_as_windows(img, window_shape=window_shape) patches = patches.reshape((h * w, r * c * d)) super(ImgPatches, self).__init__(patches, gtype='patch-graph', **kwargs)