Python mpl_toolkits.axes_grid1.ImageGrid() Examples

The following are 23 code examples of mpl_toolkits.axes_grid1.ImageGrid(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module mpl_toolkits.axes_grid1, or try the search function.
Example #1
Source File: utils.py    From geoist with MIT License 6 votes vote down vote up
def plot_matrix(A,cbar_location='right',figsize=(18,18),cmap='coolwarm',fname=None):
    """Display a 2-D array as an image with its own colorbar.

    Args:
        A: Array-like accepted by ``Axes.imshow``.
        cbar_location: Side of the image on which the colorbar is placed.
        figsize: Figure size handed to ``plt.figure``.
        cmap: Colormap handed to ``imshow``.
        fname: If ``None`` the figure is shown interactively; otherwise the
            figure is saved to this path.

    Returns:
        The ``(fig, axes)`` pair, where ``axes`` is the ``ImageGrid``.
    """
    fig = plt.figure(figsize=figsize)
    # ``add_all`` is deliberately not passed: it was deprecated in
    # Matplotlib 3.3 and later removed, and the axes are added to the figure
    # by default anyway.
    axes = ImageGrid(fig, 111,  # similar to subplot(111)
                     nrows_ncols=(1, 1),
                     axes_pad=2.0,
                     label_mode="L",
                     cbar_mode='each',
                     cbar_location=cbar_location,
                     cbar_pad='2%',
                     )
    im = axes[0].imshow(A, cmap=cmap, interpolation='none')
    axes.cbar_axes[0].colorbar(im)
    if fname is None:
        plt.show()
    else:
        # Save through the figure object so the output never depends on
        # which figure pyplot currently considers active.
        fig.savefig(fname)
    return fig, axes
Example #2
Source File: demo_axes_grid.py    From python3_ios with BSD 3-Clause "New" or "Revised" License 6 votes vote down vote up
def demo_simple_grid(fig):
    """
    Draw a 2x2 image grid with 0.05 inch padding between the images.

    Only the lower-left axes carries tick labels.
    """
    layout = dict(nrows_ncols=(2, 2), axes_pad=0.05, label_mode="1")
    grid = ImageGrid(fig, 141, **layout)  # similar to subplot(141)

    image, image_extent = get_demo_image()
    for index in range(4):
        grid[index].imshow(image, extent=image_extent,
                           interpolation="nearest")

    # share_all is left at False, so these ticks only reach the axes in the
    # first column and the second row.
    lower_left = grid.axes_llc
    lower_left.set_xticks([-2, 0, 2])
    lower_left.set_yticks([-2, 0, 2])
Example #3
Source File: demo_axes_grid.py    From python3_ios with BSD 3-Clause "New" or "Revised" License 6 votes vote down vote up
def demo_grid_with_single_cbar(fig):
    """
    Draw a 2x2 image grid whose images share one colorbar placed on top.
    """
    layout = dict(
        nrows_ncols=(2, 2),
        axes_pad=0.0,
        share_all=True,
        label_mode="L",
        cbar_location="top",
        cbar_mode="single",
    )
    grid = ImageGrid(fig, 142, **layout)  # similar to subplot(142)

    image, image_extent = get_demo_image()
    for index in range(4):
        mappable = grid[index].imshow(image, extent=image_extent,
                                      interpolation="nearest")
    # The single colorbar is driven by the image drawn last.
    grid.cbar_axes[0].colorbar(mappable)

    for colorbar_axes in grid.cbar_axes:
        colorbar_axes.toggle_label(False)

    # With share_all=True these ticks propagate to every axes in the grid.
    lower_left = grid.axes_llc
    lower_left.set_xticks([-2, 0, 2])
    lower_left.set_yticks([-2, 0, 2])
Example #4
Source File: utils.py    From cycle-consistent-vae with MIT License 6 votes vote down vote up
def imshow_grid(images, shape=[2, 8], name='default', save=False):
    """
    Plot images in a grid of a given shape.
    Initial code from: https://github.com/pumpikano/tf-dann/blob/master/utils.py

    Args:
        images: Indexable, sized collection of images accepted by
            ``Axes.imshow``.  If it holds fewer images than the grid has
            cells, the remaining cells are left blank.
        shape: ``[rows, cols]`` of the grid.  The default list is only read,
            never mutated.
        name: Base name (without extension) of the file written when
            ``save`` is true.
        save: If true, write ``reconstructed_images/<name>.png`` and clear
            the figure; otherwise show the figure.
    """
    import os

    fig = plt.figure(1)
    grid = ImageGrid(fig, 111, nrows_ncols=shape, axes_pad=0.05)

    size = shape[0] * shape[1]
    available = min(size, len(images))
    for i in range(size):
        grid[i].axis('off')
        # Cells past the end of ``images`` stay empty instead of raising
        # IndexError half-way through the plot.
        if i < available:
            grid[i].imshow(images[i])  # The AxesGrid object work as a list of axes.

    if save:
        # savefig does not create missing directories.
        if not os.path.isdir('reconstructed_images'):
            os.makedirs('reconstructed_images')
        plt.savefig('reconstructed_images/' + str(name) + '.png')
        plt.clf()
    else:
        plt.show()
Example #5
Source File: grid_plots.py    From IMPLEMENTATION_Variational-Auto-Encoder with MIT License 6 votes vote down vote up
def show_samples(images, row, col, image_shape, name="Unknown", save=True, shift=False):
    """Render the first ``row * col`` images as a gap-free grid.

    Args:
        images: Indexable collection whose items can be reshaped to
            ``image_shape``.
        row: Number of grid rows.
        col: Number of grid columns.
        image_shape: Shape each item is reshaped to before ``imshow``.
        name: Base name of the PNG written when ``save`` is true.
        save: If true, write ``figs/train/grid/<name>.png``; otherwise show
            the figure.
        shift: If true, rescale with ``(images + 1) / 2`` before plotting
            (presumably maps [-1, 1] data to [0, 1] -- confirm with caller).
    """
    num_images = row*col
    if shift:
        images = (images+1.)/2.
    fig = plt.figure(figsize=(col, row))
    grid = ImageGrid(fig, 111,
                     nrows_ncols=(row, col),
                     axes_pad=0.)
    # ``range`` instead of ``xrange``: the latter does not exist on Python 3,
    # while ``range`` works on both major versions.
    for i in range(num_images):
        im = images[i].reshape(image_shape)
        axis = grid[i]
        axis.axis('off')
        axis.imshow(im)
    plt.axis('off')
    plt.tight_layout()
    if save:
        fig.savefig('figs/train/grid/'+name+'.png', bbox_inches="tight", pad_inches=0, format='png')
    else:
        plt.show()


#From some github code 
Example #6
Source File: utils.py    From disentangling-factors-of-variation-using-adversarial-training with MIT License 6 votes vote down vote up
def imshow_grid(images, shape=[2, 8], name='default', save=False):
    """
    Plot images in a grid of a given shape.
    Initial code from: https://github.com/pumpikano/tf-dann/blob/master/utils.py

    Args:
        images: Indexable, sized collection of images accepted by
            ``Axes.imshow``.  When it is shorter than the number of grid
            cells, the surplus cells are left blank.
        shape: ``[rows, cols]`` of the grid.  The default list is only read,
            never mutated.
        name: Base name (without extension) of the saved file.
        save: If true, write ``reconstructed_images/<name>.png`` and clear
            the figure; otherwise show the figure.
    """
    import os

    fig = plt.figure(1)
    grid = ImageGrid(fig, 111, nrows_ncols=shape, axes_pad=0.05)

    size = shape[0] * shape[1]
    n_drawn = min(size, len(images))
    for i in range(size):
        grid[i].axis('off')
        # Guard against an IndexError when there are more cells than images.
        if i < n_drawn:
            grid[i].imshow(images[i])  # The AxesGrid object work as a list of axes.

    if save:
        # The output directory is not created by savefig itself.
        if not os.path.isdir('reconstructed_images'):
            os.makedirs('reconstructed_images')
        plt.savefig('reconstructed_images/' + str(name) + '.png')
        plt.clf()
    else:
        plt.show()
Example #7
Source File: utils.py    From discgen with MIT License 5 votes vote down vote up
def plot_image_grid(images, num_rows, num_cols, save_path=None):
    """Draw a batch of channel-first images on a regular grid.

    Parameters
    ----------
    images : numpy.ndarray
        Images to display, with shape
        ``(num_rows * num_cols, num_channels, height, width)``.
    num_rows : int
        Number of rows for the image grid.
    num_cols : int
        Number of columns for the image grid.
    save_path : str, optional
        Destination file for the grid. When left at ``None`` the grid
        is displayed on screen instead of being saved.

    """
    figure = pyplot.figure()
    grid = ImageGrid(figure, 111, (num_rows, num_cols), axes_pad=0.1)

    for image, axis in zip(images, grid):
        # imshow wants channels last, the batch stores them first.
        channels_last = image.transpose(1, 2, 0)
        axis.imshow(channels_last, interpolation='nearest')
        axis.set_yticklabels([''] * image.shape[1])
        axis.set_xticklabels([''] * image.shape[2])
        axis.axis('off')

    if save_path is not None:
        pyplot.savefig(save_path, transparent=True, bbox_inches='tight')
    else:
        pyplot.show()
Example #8
Source File: TensorFlowInterface.py    From IntroToDeepLearning with MIT License 5 votes vote down vote up
def plotFields(layer,fieldShape=None,channel=None,figOffset=1,cmap=None,padding=0.01):
	"""Plot a layer's receptive fields and their summed absolute weights.

	Args:
		layer: Object exposing a ``W`` attribute, or the weight variable
			itself.  ``.eval()`` is called on the weights and ``layer.name``
			is used in the figure titles.
		fieldShape: Sequence (list or tuple) appended to the leading weight
			dimensions when the evaluated, transposed weights have fewer
			than four dimensions.  Required in that case.
		channel: If given, index along the second axis (unpacked as
			``channels``) of a 4-D weight array to display; otherwise the
			first two axes are flattened together.
		figOffset: Figure identifier for the grid; ``figOffset+1`` is used
			for the summary image.
		cmap: Colormap handed to ``imshow``.
		padding: ``axes_pad`` of the image grid.
	"""
	# Receptive Fields Summary
	# Only a missing ``W`` attribute means "layer is already the weights";
	# any other error must surface instead of being swallowed.
	W = getattr(layer, 'W', layer)
	wp = W.eval().transpose()
	if len(np.shape(wp)) < 4:		# Fully connected layer, has no shape
		# list(...) lets callers pass fieldShape as a tuple as well as a list.
		fields = np.reshape(wp,list(wp.shape[0:-1])+list(fieldShape))
	else:			# Convolutional layer already has shape
		features, channels, iy, ix = np.shape(wp)
		if channel is not None:
			fields = wp[:,channel,:,:]
		else:
			fields = np.reshape(wp,[features*channels,iy,ix])

	# Near-square layout with at least as many cells as there are fields.
	perRow = int(math.floor(math.sqrt(fields.shape[0])))
	perColumn = int(math.ceil(fields.shape[0]/float(perRow)))

	fig = mpl.figure(figOffset)
	mpl.clf()

	# Using image grid
	from mpl_toolkits.axes_grid1 import ImageGrid
	grid = ImageGrid(fig,111,nrows_ncols=(perRow,perColumn),axes_pad=padding,cbar_mode='single')
	for i in range(0,np.shape(fields)[0]):
		im = grid[i].imshow(fields[i],cmap=cmap)

	# The shared colorbar follows the image drawn last.
	grid.cbar_axes[0].colorbar(im)
	mpl.title('%s Receptive Fields' % layer.name)

	# Second figure: sum of absolute weights over all fields.
	mpl.figure(figOffset+1)
	mpl.clf()
	mpl.imshow(np.sum(np.abs(fields),0),cmap=cmap)
	mpl.title('%s Total Absolute Input Dependency' % layer.name)
	mpl.colorbar()
Example #9
Source File: pfmodel_ts.py    From geoist with MIT License 5 votes vote down vote up
def plot_field(self,field=None,surveys=None,fname=None,plot_station=True):
        """Plot filled contours of the field, one grid cell per survey.

        Args:
            field: Values to plot.  When ``None`` the ``'g'`` column of
                ``self.orig_data`` is used; otherwise the values are wrapped
                in a Series sharing ``self.orig_data``'s index.
            surveys: Iterable of survey ids (matched against the
                ``'i_survey'`` column).  Defaults to ``range(self.ns)``.
            fname: If ``None`` the figure is shown, otherwise it is saved to
                this path at 150 dpi.
            plot_station: If true, scatter the sample locations on top of
                each contour plot.

        Raises:
            ValueError: If ``self.cell_type`` is neither ``'prism'`` nor
                ``'tesseroid'``.
        """
        if surveys is None:
            surveys = range(self.ns)
        if field is None:
            obs_g = self.orig_data['g']
        else:
            obs_g = pd.Series(field,index=self.orig_data.index)
        # Resolve the coordinate columns before any figure is created so an
        # unsupported cell type fails with a clear message.
        if self.cell_type == 'prism':
            axis_order = ['y','x']
        elif self.cell_type == 'tesseroid':
            axis_order = ['lon','lat']
        else:
            raise ValueError('unsupported cell_type: {!r}'.format(self.cell_type))
        fig = plt.figure(figsize=(10, 10))
        nrows = int(np.ceil(np.sqrt(len(surveys))))
        grid = ImageGrid(fig, 111,
                        nrows_ncols=(nrows, nrows),
                        axes_pad=0.05,
                        cbar_mode='single',
                        cbar_location='right',
                        cbar_pad=0.1
                        )
        for ind,i_survey in enumerate(surveys):
            grid[ind].set_axis_off()
            # Select rows by the survey id, not by the position in
            # ``surveys``; the two only coincide for the default range.
            mask = self.orig_data['i_survey']==i_survey
            tmp = self.orig_data[mask]
            x = tmp[axis_order[0]].values
            y = tmp[axis_order[1]].values
            g = obs_g[mask].values
            im = grid[ind].tricontourf(x, y, g, 20)
            if plot_station:
                grid[ind].scatter(x,y)
        # The shared colorbar follows the contour set drawn last.
        grid.cbar_axes[0].colorbar(im)
        if fname is None:
            plt.show()
        else:
            plt.savefig(fname,dpi=150)
Example #10
Source File: pfmodel_ts.py    From geoist with MIT License 5 votes vote down vote up
def plot_density(self,density=None,surveys=None,fname=None):
        """Show one density slice per survey in a square image grid.

        ``density`` (or ``self.solution`` when it is ``None``) is reshaped to
        ``(self.ns, self.ny, self.nx)``; for ``'prism'`` cells the last two
        axes are swapped before plotting.  The figure is shown when ``fname``
        is ``None`` and saved to ``fname`` at 150 dpi otherwise.
        """
        survey_ids = range(self.ns) if surveys is None else surveys
        fig = plt.figure(figsize=(10, 10))
        side = int(np.ceil(np.sqrt(len(survey_ids))))
        grid = ImageGrid(
            fig, 111,
            nrows_ncols=(side, side),
            axes_pad=0.05,
            cbar_mode='single',
            cbar_location='right',
            cbar_pad=0.1,
        )
        values = self.solution if density is None else density
        slices = values.reshape(self.ns, self.ny, self.nx)
        if self.cell_type == 'prism':
            slices = np.transpose(slices, axes=[0, 2, 1])
        for cell, survey_id in enumerate(survey_ids):
            axes = grid[cell]
            axes.set_axis_off()
            im = axes.imshow(slices[survey_id], origin='lower')
        # The shared colorbar follows the image drawn last.
        grid.cbar_axes[0].colorbar(im)
        if fname is not None:
            plt.savefig(fname, dpi=150)
        else:
            plt.show()
Example #11
Source File: utils.py    From multi-level-vae with MIT License 5 votes vote down vote up
def imshow_grid(images, shape=[2, 8], name='default', save=False):
    """Plot images in a grid of a given shape.

    Args:
        images: Indexable, sized collection of images accepted by
            ``Axes.imshow``.  Grid cells without a matching image are left
            blank.
        shape: ``[rows, cols]`` of the grid.  The default list is only read,
            never mutated.
        name: Base name (without extension) of the saved file.
        save: If true, write ``reconstructed_images/<name>.png`` and clear
            the figure; otherwise show the figure.
    """
    import os

    fig = plt.figure(1)
    grid = ImageGrid(fig, 111, nrows_ncols=shape, axes_pad=0.05)

    size = shape[0] * shape[1]
    n_images = min(size, len(images))
    for i in range(size):
        grid[i].axis('off')
        # A short ``images`` must not abort the plot with an IndexError.
        if i < n_images:
            grid[i].imshow(images[i])  # The AxesGrid object work as a list of axes.

    if save:
        # Create the target directory first; savefig will not do it.
        if not os.path.isdir('reconstructed_images'):
            os.makedirs('reconstructed_images')
        plt.savefig('reconstructed_images/' + str(name) + '.png')
        plt.clf()
    else:
        plt.show()
Example #12
Source File: discgen_utils.py    From Neural-Photo-Editor with MIT License 5 votes vote down vote up
def plot_image_grid(images, num_rows, num_cols, save_path=None):
    """Lay out channel-first images on a grid and show or save the result.

    Parameters
    ----------
    images : numpy.ndarray
        Images to display, with shape
        ``(num_rows * num_cols, num_channels, height, width)``.
    num_rows : int
        Number of rows for the image grid.
    num_cols : int
        Number of columns for the image grid.
    save_path : str, optional
        File to write the grid to. With the default ``None`` the grid
        is displayed on screen; otherwise it is saved and the figure
        is closed.

    """
    figure = pyplot.figure()
    grid = ImageGrid(figure, 111, (num_rows, num_cols), axes_pad=0.1)

    for image, axis in zip(images, grid):
        height, width = image.shape[1], image.shape[2]
        # Move the channel axis last, as imshow expects.
        axis.imshow(image.transpose(1, 2, 0), interpolation='nearest')
        axis.set_yticklabels([''] * height)
        axis.set_xticklabels([''] * width)
        axis.axis('off')

    if save_path is None:
        pyplot.show()
        return
    pyplot.savefig(save_path, transparent=True, bbox_inches='tight',dpi=212)
    pyplot.close()
Example #13
Source File: utils_wgan.py    From Pose-Guided-Person-Image-Generation with MIT License 5 votes vote down vote up
def save_imshow_grid(images, train_dir, filename, shape):
    """
    Plot images in a grid of a given shape and save the figure.

    Args:
        images: Indexable, sized collection of images accepted by
            ``Axes.imshow``.  Grid cells without a matching image are left
            blank.
        train_dir: Directory the file is written to.
        filename: Name of the output file inside ``train_dir``.
        shape: ``(rows, cols)`` of the grid.
    """
    fig = plt.figure(1)
    grid = ImageGrid(fig, 111, nrows_ncols=shape, axes_pad=0.05)

    size = shape[0] * shape[1]
    available = min(size, len(images))
    for i in trange(size, desc="Saving images"):
        grid[i].axis('off')
        # Do not index past the end of ``images``.
        if i < available:
            grid[i].imshow(images[i])

    fig.savefig(os.path.join(train_dir, filename))
    # Figure 1 is reused by every call; clear it so the next grid is not
    # drawn on top of this one and old axes do not pile up.
    fig.clf()
Example #14
Source File: wgan_utils.py    From ambient-gan with MIT License 5 votes vote down vote up
def save_imshow_grid(images, logs_dir, filename, shape):
    """
    Plot images in a grid of a given shape and save the figure.

    Args:
        images: Indexable, sized collection of images accepted by
            ``Axes.imshow``.  Cells beyond the last image stay blank.
        logs_dir: Directory the file is written to.
        filename: Name of the output file inside ``logs_dir``.
        shape: ``(rows, cols)`` of the grid.
    """
    fig = plt.figure(1)
    grid = ImageGrid(fig, 111, nrows_ncols=shape, axes_pad=0.05)

    size = shape[0] * shape[1]
    n_drawn = min(size, len(images))
    for i in trange(size, desc="Saving images"):
        grid[i].axis('off')
        # Skip cells that have no image rather than raising IndexError.
        if i < n_drawn:
            grid[i].imshow(images[i])

    fig.savefig(os.path.join(logs_dir, filename))
    # Every call draws on figure 1; clearing it keeps successive grids from
    # stacking on top of each other.
    fig.clf()
Example #15
Source File: utils.py    From WassersteinGAN.tensorflow with MIT License 5 votes vote down vote up
def save_imshow_grid(images, logs_dir, filename, shape):
    """
    Plot images in a grid of a given shape and save the figure.

    Args:
        images: Indexable, sized collection of images accepted by
            ``Axes.imshow``.  If there are fewer images than grid cells the
            remaining cells are left blank.
        logs_dir: Directory the file is written to.
        filename: Name of the output file inside ``logs_dir``.
        shape: ``(rows, cols)`` of the grid.
    """
    fig = plt.figure(1)
    grid = ImageGrid(fig, 111, nrows_ncols=shape, axes_pad=0.05)

    size = shape[0] * shape[1]
    n_images = min(size, len(images))
    for i in trange(size, desc="Saving images"):
        grid[i].axis('off')
        # Only cells backed by an image are drawn.
        if i < n_images:
            grid[i].imshow(images[i])

    fig.savefig(os.path.join(logs_dir, filename))
    # Figure 1 is shared between calls, so wipe it once the file is written;
    # otherwise each call adds another grid to the same figure.
    fig.clf()
Example #16
Source File: utils.py    From MagnetLoss-PyTorch with MIT License 5 votes vote down vote up
def show_images(H):
    """Draw the images in ``H`` on the smallest square grid that fits them.

    Args:
        H: Array whose first axis indexes the images; ``H[i]`` is handed to
            ``imshow`` with the ``'Greys'`` colormap.
    """
    # make a square grid
    num = H.shape[0]
    rows = int(np.ceil(np.sqrt(float(num))))

    fig = plt.figure(1, [10, 10])
    grid = ImageGrid(fig, 111, nrows_ncols=[rows, rows])

    # Hide the frame and ticks of every cell up front, used or not.  This
    # avoids relying on the image loop's leftover index to find the unused
    # cells.
    for j in range(len(grid)):
        grid[j].axis('off')

    for i in range(num):
        grid[i].imshow(H[i], cmap='Greys')
Example #17
Source File: demo_axes_grid.py    From python3_ios with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
def demo_grid_with_each_cbar_labelled(fig):
    """
    Draw a 2x2 image grid where every image has its own colorbar on the
    right, each with a different colour range.
    """
    layout = dict(
        nrows_ncols=(2, 2),
        axes_pad=(0.45, 0.15),
        label_mode="1",
        share_all=True,
        cbar_location="right",
        cbar_mode="each",
        cbar_size="7%",
        cbar_pad="2%",
    )
    grid = ImageGrid(fig, 144, **layout)  # similar to subplot(144)
    image, image_extent = get_demo_image()

    # One (vmin, vmax) pair per image so that every colorbar differs.
    limits = ((0, 1), (-2, 2), (-1.7, 1.4), (-1.5, 1))
    for index in range(4):
        low, high = limits[index]
        mappable = grid[index].imshow(image, extent=image_extent,
                                      interpolation="nearest",
                                      vmin=low, vmax=high)
        grid.cbar_axes[index].colorbar(mappable)

    # Tick each colorbar at the two ends of its range only.
    for index, colorbar_axes in enumerate(grid.cbar_axes):
        low, high = limits[index]
        colorbar_axes.set_yticks((low, high))

    # share_all=True makes these ticks apply to every axes.
    lower_left = grid.axes_llc
    lower_left.set_xticks([-2, 0, 2])
    lower_left.set_yticks([-2, 0, 2])
Example #18
Source File: demo_axes_grid.py    From python3_ios with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
def demo_grid_with_each_cbar(fig):
    """
    Draw a 2x2 image grid where every image has its own colorbar on top.
    """
    layout = dict(
        nrows_ncols=(2, 2),
        axes_pad=0.1,
        label_mode="1",
        share_all=True,
        cbar_location="top",
        cbar_mode="each",
        cbar_size="7%",
        cbar_pad="2%",
    )
    grid = ImageGrid(fig, 143, **layout)  # similar to subplot(143)
    image, image_extent = get_demo_image()
    for index in range(4):
        mappable = grid[index].imshow(image, extent=image_extent,
                                      interpolation="nearest")
        grid.cbar_axes[index].colorbar(mappable)

    for colorbar_axes in grid.cbar_axes:
        colorbar_axes.toggle_label(False)

    # share_all=True makes these ticks apply to every axes.
    lower_left = grid.axes_llc
    lower_left.set_xticks([-2, 0, 2])
    lower_left.set_yticks([-2, 0, 2])
Example #19
Source File: utlis.py    From deepJDOT with MIT License 5 votes vote down vote up
def imshow_grid(images, shape=[2, 8]):
    """Show ``images`` on a ``shape[0]`` x ``shape[1]`` grid.

    When ``images`` has at most three dimensions the gray colormap is used;
    otherwise imshow's default colour handling applies.
    """
    fig = plt.figure(1)
    grid = ImageGrid(fig, 111, nrows_ncols=shape, axes_pad=0.05)
    use_gray = len(np.shape(images)) <= 3
    n_cells = shape[0] * shape[1]
    for idx in range(n_cells):
        # The grid object can be indexed like a list of axes.
        cell = grid[idx]
        cell.axis('off')
        if use_gray:
            cell.imshow(images[idx], cmap=plt.get_cmap('gray'))
        else:
            cell.imshow(images[idx])

    plt.show()
Example #20
Source File: TensorFlowInterface.py    From IntroToDeepLearning with MIT License 5 votes vote down vote up
def plotFields(layer,fieldShape=None,channel=None,maxFields=25,figName='ReceptiveFields',cmap=None,padding=0.01):
	"""Plot up to ``maxFields`` receptive fields and the summed absolute weights.

	Args:
		layer: Object exposing ``W`` (whose ``.eval()`` yields the weights)
			and ``name`` (used in the figure titles).
		fieldShape: Sequence (list or tuple) appended to the leading weight
			dimensions when the evaluated, transposed weights have fewer
			than four dimensions.  Required in that case.
		channel: If given, index along the second axis (unpacked as
			``channels``) of a 4-D weight array to display; otherwise the
			first two axes are flattened together.
		maxFields: Upper bound on the number of fields drawn in the grid.
			The summary image still sums over all fields.
		figName: Figure identifier for the grid; ``figName+' Total'`` is
			used for the summary image.
		cmap: Colormap handed to ``imshow``.
		padding: ``axes_pad`` of the image grid.
	"""
	# Receptive Fields Summary
	W = layer.W
	wp = W.eval().transpose()
	if len(np.shape(wp)) < 4:		# Fully connected layer, has no shape
		# list(...) lets callers pass fieldShape as a tuple as well as a list.
		fields = np.reshape(wp,list(wp.shape[0:-1])+list(fieldShape))
	else:			# Convolutional layer already has shape
		features, channels, iy, ix = np.shape(wp)
		if channel is not None:
			fields = wp[:,channel,:,:]
		else:
			fields = np.reshape(wp,[features*channels,iy,ix])

	# Near-square layout with at least ``fieldsN`` cells.
	fieldsN = min(fields.shape[0],maxFields)
	perRow = int(math.floor(math.sqrt(fieldsN)))
	perColumn = int(math.ceil(fieldsN/float(perRow)))

	fig = mpl.figure(figName)
	mpl.clf()

	# Using image grid
	from mpl_toolkits.axes_grid1 import ImageGrid
	grid = ImageGrid(fig,111,nrows_ncols=(perRow,perColumn),axes_pad=padding,cbar_mode='single')
	for i in range(0,fieldsN):
		im = grid[i].imshow(fields[i],cmap=cmap)

	# The shared colorbar follows the image drawn last.
	grid.cbar_axes[0].colorbar(im)
	mpl.title('%s Receptive Fields' % layer.name)

	# Second figure: sum of absolute weights over all fields.
	mpl.figure(figName+' Total')
	mpl.clf()
	mpl.imshow(np.sum(np.abs(fields),0),cmap=cmap)
	mpl.title('%s Total Absolute Input Dependency' % layer.name)
	mpl.colorbar()
Example #21
Source File: utils.py    From geoist with MIT License 4 votes vote down vote up
def plot_kernel(ggz,nxyz=None,nobs=None,obs_extent=(-100,100,-100,100),image_grid=(3,5),fname=None):
    '''inspect the kernel matrix

    Args:
        ggz (ndarray): Kernel matrix. Each column correspond to a source point. Each row correspond to an observe station.
        nxyz (tuple): How many source points along x,y,z axis respectively.
            When None, the cube root of the column count is used for all three.
        nobs (tuple): How many observe stations along x,y axis respectively.
            When None, the square root of the row count is used for both.
        obs_extent (tuple): Define the observe area in a order of (min_x,max_x,min_y,max_y).
        image_grid (tuple): Define the dimension of image grid.
        fname (str): If None the figure is shown, otherwise it is saved to this path.
    '''
    if nxyz is None:
        nx = ny = nz = int(round(np.power(ggz.shape[1],1./3.)))
    else:
        nx,ny,nz = nxyz
    if nobs is None:
        obs_x = obs_y = int(round(np.power(ggz.shape[0],1./2.)))
    else:
        obs_x,obs_y = nobs
    n_rows,n_cols = image_grid
    fmt = ticker.ScalarFormatter()
    fmt.set_powerlimits((0,0))
    fig = plt.figure(figsize=(15,12))
    # One depth layer per grid row, spread evenly from the first to the last.
    iz = np.linspace(0,nz-1,n_rows).astype(np.int32)
    # Five source columns are sampled per layer, so the flat list lines up
    # with the grid cells only when image_grid has five columns.
    ind = []
    for i in iz:
        ind.extend([i*nx*ny,i*nx*ny+nx-1,i*nx*ny+nx*ny//2+nx//2,(i+1)*nx*ny-nx,(i+1)*nx*ny-1])
    # ``add_all`` is deliberately not passed: it was deprecated in
    # Matplotlib 3.3 and later removed, and the axes are added to the figure
    # by default anyway.
    axes = ImageGrid(fig, 111,  # similar to subplot(111)
                     nrows_ncols=(n_rows,n_cols),
                     axes_pad=0.5,
                     label_mode="L",
                     cbar_mode = 'each',
                     cbar_location = 'right',
                     cbar_pad='30%'
                     )
    i = 0
    for row in range(n_rows):
        for col in range(n_cols):
            cell = col+row*n_cols
            ixyz = get_prism_pos(int(ind[i]),nx,ny,nz)
            im = axes[cell].imshow(ggz[:,int(ind[i])].reshape(-1,obs_x).transpose(),extent=obs_extent,origin='lower')
            axes[cell].set_title('layer {} of {}'.format(ixyz[2],nz))
            i += 1
            axes.cbar_axes[cell].colorbar(im,format=fmt)
    if fname is None:
        plt.show()
    else:
        plt.savefig(fname)
Example #22
Source File: plot.py    From adversarial-autoencoder with MIT License 4 votes vote down vote up
def plot_latent_space(weightsfile):
    """Decode a regular 2-D grid of latent points and save them as one image.

    Builds the model, loads the decoder and discriminator weights from
    ``weightsfile``, decodes the latent grid in batches of 128 and writes the
    resulting 21 x 21 image grid to ``latent_train_val.png``.
    """
    print('building model')
    layers = model.build_model()
    batch_size = 128
    decoder_func = theano_funcs.create_decoder_func(layers)

    print('loading weights from %s' % (weightsfile))
    weighted_layers = [
        layers['l_decoder_out'],
        layers['l_discriminator_out'],
    ]
    model.load_weights(weighted_layers, weightsfile)

    # regularly-spaced grid of points sampled from p(z)
    mesh = np.mgrid[2:-2.2:-0.2, -2:2.2:0.2]
    Z = mesh.reshape(2, -1).T[:, ::-1].astype(np.float32)

    print('generating samples')
    reconstructions = [
        decoder_func(Z[idx])
        for idx in get_batch_idx(Z.shape[0], batch_size)
    ]

    X = np.vstack(reconstructions)
    X = X.reshape(X.shape[0], 28, 28)

    fig = plt.figure(1, (12., 12.))
    # Frameless, axis-free backdrop that only carries the title.
    backdrop = plt.axes(frameon=False)
    backdrop.get_xaxis().set_visible(False)
    backdrop.get_yaxis().set_visible(False)
    plt.title('samples generated from latent space of autoencoder')
    grid = ImageGrid(fig, 111, nrows_ncols=(21, 21), share_all=True)

    print('plotting latent space')
    for i, sample in enumerate(X):
        cell = grid[i]
        cell.imshow((sample * 255).astype(np.uint8), cmap='Greys_r')
        cell.get_xaxis().set_visible(False)
        cell.get_yaxis().set_visible(False)
        cell.set_frame_on(False)

    plt.savefig('latent_train_val.png', bbox_inches='tight')
Example #23
Source File: plot.py    From pde-surrogate with MIT License 4 votes vote down vote up
def save_samples(save_dir, images, epoch, index, name, nrow=4, heatmap=True, cmap='jet', title=False):
    """Save a batch of samples either as per-channel heatmap grids or as one image.

    Args:
        images (Tensor): B x C x H x W
        nrow: Number of grid rows; the column count is ``B // nrow``.
        heatmap: If true, write one colour-mapped grid figure per channel;
            otherwise hand the batch to ``torchvision.utils.save_image``.
        title: If true, add an ``Epoch <epoch>`` super-title to each
            heatmap figure.
    """
    images = to_numpy(images)
    ncol = images.shape[0] // nrow

    if not heatmap:
        # NOTE(review): ``images`` has already gone through ``to_numpy`` at
        # this point -- confirm save_image accepts what that helper returns.
        torchvision.utils.save_image(images,
                          save_dir + '/fake_samples_epoch_{}.png'.format(epoch),
                          nrow=nrow,
                          normalize=True)
        return

    grid_options = dict(
        nrows_ncols=(nrow, ncol),
        axes_pad=0.1,
        share_all=False,
        cbar_location="top",
        cbar_mode="single",
        cbar_size="3%",
        cbar_pad=0.1,
    )
    for c in range(images.shape[1]):
        fig = plt.figure(1, (12, 12))
        grid = ImageGrid(fig, 111, **grid_options)
        for j, ax in enumerate(grid):
            im = ax.imshow(images[j][c], cmap=cmap)
            ax.set_axis_off()
            ax.set_aspect('equal')
        # The shared colorbar follows the image drawn last.
        cbar = grid.cbar_axes[0].colorbar(im)
        cbar.ax.tick_params(labelsize=10)
        cbar.ax.toggle_label(True)
        if title:
            plt.suptitle(f'Epoch {epoch}')
            plt.subplots_adjust(top=0.95)
        out_path = save_dir + '/epoch{}_{}_c{}_index{}.png'.format(epoch, name, c, index)
        plt.savefig(out_path, bbox_inches='tight')
        plt.close(fig)