# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

import matplotlib.pyplot as plt


def image_grid(
    images,
    rows=None,
    cols=None,
    fill: bool = True,
    show_axes: bool = False,
    rgb: bool = True,
    figsize=(15, 9),
):
    """
    A util function for plotting a grid of images.

    Args:
        images: (N, H, W, 4) array of RGBA images
        rows: number of rows in the grid
        cols: number of columns in the grid
        fill: boolean; if True, the horizontal and vertical spacing between
            grid cells is set to zero so the images fill the figure
        show_axes: boolean indicating if the axes of the plots should be visible
        rgb: boolean. If True, only the RGB channels are plotted.
            If False, only the alpha channel is plotted.
        figsize: (width, height) of the figure, forwarded to ``plt.subplots``.
            Defaults to (15, 9).

    Raises:
        ValueError: if exactly one of ``rows`` and ``cols`` is given.

    Returns:
        None
    """
    if (rows is None) != (cols is None):
        raise ValueError("Specify either both rows and cols or neither.")

    if rows is None:
        # Default layout: a single column with one row per image.
        rows = len(images)
        cols = 1

    gridspec_kw = {"wspace": 0.0, "hspace": 0.0} if fill else {}
    # squeeze=False makes plt.subplots always return a 2D array of Axes, so
    # .ravel() below also works for a 1x1 grid (where the default
    # squeeze=True would return a bare Axes object without .ravel()).
    fig, axarr = plt.subplots(
        rows, cols, gridspec_kw=gridspec_kw, figsize=figsize, squeeze=False
    )
    # Zero margins on every side so the grid occupies the whole figure.
    bleed = 0
    fig.subplots_adjust(left=bleed, bottom=bleed, right=(1 - bleed), top=(1 - bleed))

    axes = axarr.ravel()
    num_drawn = 0
    for ax, im in zip(axes, images):
        if rgb:
            # only render RGB channels
            ax.imshow(im[..., :3])
        else:
            # only render Alpha channel
            ax.imshow(im[..., 3])
        if not show_axes:
            ax.set_axis_off()
        num_drawn += 1

    # Grid cells beyond the number of images stay empty; hide their axes too
    # so that show_axes=False applies uniformly across the whole grid.
    if not show_axes:
        for ax in axes[num_drawn:]:
            ax.set_axis_off()