Source code for pygsp.graphs.nngraphs.grid2dimgpatches

# -*- coding: utf-8 -*-

# prevent circular import in Python < 3.5
from pygsp.graphs import Graph, Grid2d, ImgPatches


class Grid2dImgPatches(Graph):
    r"""Union of a patch graph with a 2D grid graph.

    The weight matrix is obtained by aggregating the weight matrix of an
    :class:`ImgPatches` graph built on ``img`` with the weight matrix of a
    :class:`Grid2d` graph sized from the first two dimensions of ``img``.
    The resulting graph reuses the grid graph's vertex coordinates and
    plotting parameters.

    Parameters
    ----------
    img : array
        Input image.
    patch_shape : tuple, optional
        Dimensions of the patch window. Syntax: (height, width).
        Default is (3, 3).
    n_nbrs : int, optional
        Number of neighbors to consider. Default is 8.
    aggregate : callable, optional
        Function used for aggregating the weights Wp of the patch graph and
        the weights Wg of the 2D grid graph. It is called as
        ``aggregate(Wp, Wg)`` and its return value is used as the weight
        matrix of the combined graph. Default is
        ``lambda Wp, Wg: Wp + Wg``.
    kwargs : dict
        Additional keyword arguments, forwarded unchanged to the
        :class:`Grid2d` constructor, the :class:`ImgPatches` constructor
        and :class:`Graph`. For example, ``dist_type`` (string) selects the
        type of distance between patches to compute; see
        :func:`pyflann.index.set_distance_type` for possible options.
        ``gtype``, ``coords``, ``plotting`` and ``perform_all_checks`` are
        set explicitly by this class and must not be supplied here.

    Examples
    --------
    >>> import matplotlib.pyplot as plt
    >>> from skimage import data, img_as_float
    >>> img = img_as_float(data.camera()[::64, ::64])
    >>> G = graphs.Grid2dImgPatches(img, use_flann=False)
    >>> fig, axes = plt.subplots(1, 2)
    >>> _ = axes[0].spy(G.W, markersize=2)
    >>> G.plot(ax=axes[1])

    """

    def __init__(self, img, patch_shape=(3, 3), n_nbrs=8,
                 aggregate=lambda Wp, Wg: Wp + Wg, **kwargs):

        # Grid graph dimensioned from the first two axes of the image.
        Gg = Grid2d(img.shape[0], img.shape[1], **kwargs)
        # Patch-similarity graph built directly on the image.
        Gp = ImgPatches(img, patch_shape=patch_shape, n_nbrs=n_nbrs,
                        **kwargs)

        # Combined graph type: '<grid gtype>_<patches gtype>'.
        gtype = '{}_{}'.format(Gg.gtype, Gp.gtype)

        # NOTE(review): kwargs is forwarded again here next to explicit
        # gtype/coords/plotting/perform_all_checks, so a caller supplying any
        # of those keys gets a duplicate-keyword TypeError. Checks are
        # skipped presumably because both component graphs were already
        # built by their own constructors -- TODO confirm.
        super(Grid2dImgPatches, self).__init__(W=aggregate(Gp.W, Gg.W),
                                               gtype=gtype,
                                               coords=Gg.coords,
                                               plotting=Gg.plotting,
                                               perform_all_checks=False,
                                               **kwargs)