Python matplotlib.pyplot.subplots_adjust() Examples

The following are 30 code examples of matplotlib.pyplot.subplots_adjust(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module matplotlib.pyplot, or try the search function.
Example #1
Source File: matplotlib_trading_chart.py    From tensortrade with Apache License 2.0 7 votes vote down vote up
def __init__(self, df):
        """Store *df* and build the figure with net-worth, price and volume axes.

        The figure is shown immediately in non-blocking mode.
        """
        self.df = df

        # New pyplot figure that hosts every axis created below.
        self.fig = plt.figure()

        # Both subplots are placed on the same 6-row, single-column grid.
        grid_shape = (6, 1)

        # Net worth occupies the two top rows.
        self.net_worth_ax = plt.subplot2grid(
            grid_shape, (0, 0), rowspan=2, colspan=1)

        # Price starts at row 2 and shares its x-axis with the net-worth axis.
        # NOTE(review): rowspan=8 is larger than the rows left in the grid —
        # confirm whether 4 was intended.
        self.price_ax = plt.subplot2grid(
            grid_shape, (2, 0), rowspan=8, colspan=1,
            sharex=self.net_worth_ax)

        # Volume is drawn on a twin of the price axis (same x, separate y).
        self.volume_ax = self.price_ax.twinx()

        # Margins around the subplots; hspace=0 removes the vertical gap.
        margins = dict(left=0.11, bottom=0.24, right=0.90, top=0.90,
                       wspace=0.2, hspace=0)
        plt.subplots_adjust(**margins)

        # Display without blocking the caller.
        plt.show(block=False)
Example #2
Source File: visualise_att_maps_epoch.py    From Attention-Gated-Networks with MIT License 7 votes vote down vote up
def plotNNFilter(units, figure_id, interp='bilinear', colormap=cm.jet, colormap_lim=None):
    """Draw every slice ``units[:, :, i]`` (transposed) in its own subplot.

    :param units: array indexed as ``units[:, :, i]``; ``units.shape[2]`` is
        the number of subplots drawn
    :param figure_id: identifier passed to ``plt.figure``; the figure is cleared
    :param interp: interpolation mode forwarded to ``plt.imshow``
    :param colormap: colormap forwarded to ``plt.imshow``
    :param colormap_lim: optional pair ``(low, high)`` applied with ``plt.clim``
    """
    plt.ion()
    num_filters = units.shape[2]
    grid_cols = round(math.sqrt(num_filters))
    grid_rows = math.ceil(num_filters / grid_cols) + 1
    # NOTE(review): figsize is given as (rows*3, cols*3) while matplotlib reads
    # it as (width, height) — confirm this ordering is intended.
    figure = plt.figure(figure_id, figsize=(grid_rows * 3, grid_cols * 3))
    figure.clf()

    for position in range(1, num_filters + 1):
        axis = plt.subplot(grid_rows, grid_cols, position)
        plt.imshow(units[:, :, position - 1].T, interpolation=interp, cmap=colormap)
        plt.axis('on')
        # Keep the ticks but blank out their labels.
        for clear_labels in (axis.set_xticklabels, axis.set_yticklabels):
            clear_labels([])
        plt.colorbar()
        if colormap_lim:
            low, high = colormap_lim[0], colormap_lim[1]
            plt.clim(low, high)

    plt.subplots_adjust(wspace=0, hspace=0)
    plt.tight_layout()

# Epochs 
Example #3
Source File: visualise_attention.py    From Attention-Gated-Networks with MIT License 6 votes vote down vote up
def plotNNFilterOverlay(input_im, units, figure_id, interp='bilinear',
                        colormap=cm.jet, colormap_lim=None, title='', alpha=0.8):
    """Overlay each slice ``units[:, :, i]`` on the grayscale ``input_im[:, :, 0]``.

    All slices are drawn into the same (current) axes of a cleared 5x5-inch
    figure, one after the other; a colorbar and the title are added per slice.

    :param colormap_lim: optional pair ``(low, high)`` applied with ``plt.clim``
    :param alpha: opacity of the overlaid slice
    """
    plt.ion()
    num_filters = units.shape[2]
    figure = plt.figure(figure_id, figsize=(5, 5))
    figure.clf()

    background = dict(interpolation=interp, cmap='gray')
    overlay = dict(interpolation=interp, cmap=colormap, alpha=alpha)
    for channel in range(num_filters):
        plt.imshow(input_im[:, :, 0], **background)
        plt.imshow(units[:, :, channel], **overlay)
        plt.axis('off')
        plt.colorbar()
        plt.title(title, fontsize='small')
        if colormap_lim:
            low, high = colormap_lim[0], colormap_lim[1]
            plt.clim(low, high)

    plt.subplots_adjust(wspace=0, hspace=0)
    plt.tight_layout()

    # plt.savefig('{}/{}.png'.format(dir_name,time.time()))




## Load options 
Example #4
Source File: visualise_attention.py    From Attention-Gated-Networks with MIT License 6 votes vote down vote up
def plotNNFilter(units, figure_id, interp='bilinear', colormap=cm.jet, colormap_lim=None, title=''):
    """Plot each transposed slice ``units[:, :, i]`` on a grid and add a suptitle.

    The grid has ``round(sqrt(n))`` columns and ``ceil(n / columns) + 1`` rows,
    where ``n = units.shape[2]``. ``colormap_lim``, when truthy, supplies the
    two values handed to ``plt.clim`` for every subplot.
    """
    plt.ion()
    depth = units.shape[2]
    cols = round(math.sqrt(depth))
    rows = math.ceil(depth / cols) + 1
    canvas = plt.figure(figure_id, figsize=(rows * 3, cols * 3))
    canvas.clf()

    slot = 0
    while slot < depth:
        panel = plt.subplot(rows, cols, slot + 1)
        plt.imshow(units[:, :, slot].T, interpolation=interp, cmap=colormap)
        plt.axis('on')
        panel.set_xticklabels([])
        panel.set_yticklabels([])
        plt.colorbar()
        if colormap_lim:
            plt.clim(colormap_lim[0], colormap_lim[1])
        slot += 1

    plt.subplots_adjust(wspace=0, hspace=0)
    plt.tight_layout()
    plt.suptitle(title)
Example #5
Source File: visualise_fmaps.py    From Attention-Gated-Networks with MIT License 6 votes vote down vote up
def plotNNFilter(units, figure_id, interp='bilinear', colormap=cm.jet, colormap_lim=None):
    """Show the slices of ``units`` along its third axis as a grid of images.

    Each slice is transposed before display, gets its own colorbar, and has
    its tick labels blanked. The figure named ``figure_id`` is cleared first.
    """
    plt.ion()
    filters = units.shape[2]
    n_columns = round(math.sqrt(filters))
    n_rows = math.ceil(filters / n_columns) + 1
    inches_per_cell = 3
    fig = plt.figure(
        figure_id,
        figsize=(n_rows * inches_per_cell, n_columns * inches_per_cell),
    )
    fig.clf()

    no_labels = []
    for idx in range(filters):
        cell = plt.subplot(n_rows, n_columns, idx + 1)
        slice_t = units[:, :, idx].T
        plt.imshow(slice_t, interpolation=interp, cmap=colormap)
        plt.axis('on')
        cell.set_xticklabels(no_labels)
        cell.set_yticklabels(no_labels)
        plt.colorbar()
        if colormap_lim:
            plt.clim(colormap_lim[0], colormap_lim[1])

    # NOTE(review): tight_layout recomputes the subplot parameters, so it
    # probably overrides the zero spacing set just before it — confirm.
    plt.subplots_adjust(wspace=0, hspace=0)
    plt.tight_layout()

# Load options 
Example #6
Source File: mnist.py    From WannaPark with GNU General Public License v3.0 6 votes vote down vote up
def plot_bad_images(images):
    """This takes a list of images misclassified by a pretty good
    neural network --- one achieving over 93 percent accuracy --- and
    turns them into a figure.

    ``images`` is indexed with each entry of the hard-coded
    ``bad_image_indices`` list; every selected image is drawn in its own
    subplot, titled with its index, with the ticks removed.
    """
    bad_image_indices = [8, 18, 33, 92, 119, 124, 149, 151, 193, 233, 241, 247, 259, 300, 313, 321, 324, 341, 349, 352, 359, 362, 381, 412, 435, 445, 449, 478, 479, 495, 502, 511, 528, 531, 547, 571, 578, 582, 597, 610, 619, 628, 629, 659, 667, 691, 707, 717, 726, 740, 791, 810, 844, 846, 898, 938, 939, 947, 956, 959, 965, 982, 1014, 1033, 1039, 1044, 1050, 1055, 1107, 1112, 1124, 1147, 1181, 1191, 1192, 1198, 1202, 1204, 1206, 1224, 1226, 1232, 1242, 1243, 1247, 1256, 1260, 1263, 1283, 1289, 1299, 1310, 1319, 1326, 1328, 1357, 1378, 1393, 1413, 1422, 1435, 1467, 1469, 1494, 1500, 1522, 1523, 1525, 1527, 1530, 1549, 1553, 1609, 1611, 1634, 1641, 1676, 1678, 1681, 1709, 1717, 1722, 1730, 1732, 1737, 1741, 1754, 1759, 1772, 1773, 1790, 1808, 1813, 1823, 1843, 1850, 1857, 1868, 1878, 1880, 1883, 1901, 1913, 1930, 1938, 1940, 1952, 1969, 1970, 1984, 2001, 2009, 2016, 2018, 2035, 2040, 2043, 2044, 2053, 2063, 2098, 2105, 2109, 2118, 2129, 2130, 2135, 2148, 2161, 2168, 2174, 2182, 2185, 2186, 2189, 2224, 2229, 2237, 2266, 2272, 2293, 2299, 2319, 2325, 2326, 2334, 2369, 2371, 2380, 2381, 2387, 2393, 2395, 2406, 2408, 2414, 2422, 2433, 2450, 2488, 2514, 2526, 2548, 2574, 2589, 2598, 2607, 2610, 2631, 2648, 2654, 2695, 2713, 2720, 2721, 2730, 2770, 2771, 2780, 2863, 2866, 2896, 2907, 2925, 2927, 2939, 2995, 3005, 3023, 3030, 3060, 3073, 3102, 3108, 3110, 3114, 3115, 3117, 3130, 3132, 3157, 3160, 3167, 3183, 3189, 3206, 3240, 3254, 3260, 3280, 3329, 3330, 3333, 3383, 3384, 3475, 3490, 3503, 3520, 3525, 3559, 3567, 3573, 3597, 3598, 3604, 3629, 3664, 3702, 3716, 3718, 3725, 3726, 3727, 3751, 3752, 3757, 3763, 3766, 3767, 3769, 3776, 3780, 3798, 3806, 3808, 3811, 3817, 3821, 3838, 3848, 3853, 3855, 3869, 3876, 3902, 3906, 3926, 3941, 3943, 3951, 3954, 3962, 3976, 3985, 3995, 4000, 4002, 4007, 4017, 4018, 4065, 4075, 4078, 4093, 4102, 4139, 4140, 4152, 4154, 4163, 4165, 4176, 4199, 4201, 4205, 4207, 4212, 4224, 4238, 4248, 4256, 4284, 4289, 4297, 4300, 4306, 4344, 4355, 4356, 
4359, 4360, 4369, 4405, 4425, 4433, 4435, 4449, 4487, 4497, 4498, 4500, 4521, 4536, 4548, 4563, 4571, 4575, 4601, 4615, 4620, 4633, 4639, 4662, 4690, 4722, 4731, 4735, 4737, 4739, 4740, 4761, 4798, 4807, 4814, 4823, 4833, 4837, 4874, 4876, 4879, 4880, 4886, 4890, 4910, 4950, 4951, 4952, 4956, 4963, 4966, 4968, 4978, 4990, 5001, 5020, 5054, 5067, 5068, 5078, 5135, 5140, 5143, 5176, 5183, 5201, 5210, 5331, 5409, 5457, 5495, 5600, 5601, 5617, 5623, 5634, 5642, 5677, 5678, 5718, 5734, 5735, 5749, 5752, 5771, 5787, 5835, 5842, 5845, 5858, 5887, 5888, 5891, 5906, 5913, 5936, 5937, 5945, 5955, 5957, 5972, 5973, 5985, 5987, 5997, 6035, 6042, 6043, 6045, 6053, 6059, 6065, 6071, 6081, 6091, 6112, 6124, 6157, 6166, 6168, 6172, 6173, 6347, 6370, 6386, 6390, 6391, 6392, 6421, 6426, 6428, 6505, 6542, 6555, 6556, 6560, 6564, 6568, 6571, 6572, 6597, 6598, 6603, 6608, 6625, 6651, 6694, 6706, 6721, 6725, 6740, 6746, 6768, 6783, 6785, 6796, 6817, 6827, 6847, 6870, 6872, 6926, 6945, 7002, 7035, 7043, 7089, 7121, 7130, 7198, 7216, 7233, 7248, 7265, 7426, 7432, 7434, 7494, 7498, 7691, 7777, 7779, 7797, 7800, 7809, 7812, 7821, 7849, 7876, 7886, 7897, 7902, 7905, 7917, 7921, 7945, 7999, 8020, 8059, 8081, 8094, 8095, 8115, 8246, 8256, 8262, 8272, 8273, 8278, 8279, 8293, 8322, 8339, 8353, 8408, 8453, 8456, 8502, 8520, 8522, 8607, 9009, 9010, 9013, 9015, 9019, 9022, 9024, 9026, 9036, 9045, 9046, 9128, 9214, 9280, 9316, 9342, 9382, 9433, 9446, 9506, 9540, 9544, 9587, 9614, 9634, 9642, 9645, 9700, 9716, 9719, 9729, 9732, 9738, 9740, 9741, 9742, 9744, 9745, 9749, 9752, 9768, 9770, 9777, 9779, 9792, 9808, 9831, 9839, 9856, 9858, 9867, 9879, 9883, 9888, 9890, 9893, 9905, 9944, 9970, 9982]
    bad_images = [images[j] for j in bad_image_indices]
    fig = plt.figure(figsize=(10, 15))
    # enumerate/zip replace the Python-2-only xrange loop, so the function
    # runs unchanged on both Python 2 and Python 3. Subplot slots are 1-based.
    for slot, (index, image) in enumerate(zip(bad_image_indices, bad_images), 1):
        ax = fig.add_subplot(25, 125, slot)
        ax.matshow(image, cmap = matplotlib.cm.binary)
        ax.set_title(str(index))
        # Empty tick arrays hide the ticks on the subplot just created.
        plt.xticks(np.array([]))
        plt.yticks(np.array([]))
    plt.subplots_adjust(hspace = 1.2)
    plt.show()
Example #7
Source File: TradingChart.py    From RLTrader with GNU General Public License v3.0 6 votes vote down vote up
def __init__(self, df):
        """Keep a reference to *df* and lay out the chart's three axes.

        Creates the figure, a net-worth axis on top, a price axis below it
        (sharing the x-axis), and a twin volume axis; then shows the figure
        without blocking.
        """
        self.df = df

        self.fig = plt.figure()

        # Top panel: rows 0-1 of a 6x1 grid.
        top_kwargs = {'rowspan': 2, 'colspan': 1}
        self.net_worth_ax = plt.subplot2grid((6, 1), (0, 0), **top_kwargs)

        # Bottom panel: begins at row 2, x-axis tied to the top panel.
        bottom_kwargs = {'rowspan': 8, 'colspan': 1, 'sharex': self.net_worth_ax}
        self.price_ax = plt.subplot2grid((6, 1), (2, 0), **bottom_kwargs)

        # Second y-axis on the price panel, used for volume.
        self.volume_ax = self.price_ax.twinx()

        # Padding around the panels; no vertical gap between them.
        plt.subplots_adjust(
            left=0.11,
            bottom=0.24,
            right=0.90,
            top=0.90,
            wspace=0.2,
            hspace=0,
        )

        # Non-blocking so the rest of the program keeps running.
        plt.show(block=False)
Example #8
Source File: plot_alert_pattern_subgraphs.py    From AMLSim with Apache License 2.0 6 votes vote down vote up
def plot_alerts(_g, _bank_accts, _output_png):
    """Draw graph ``_g`` with nodes coloured per bank and save it as an image.

    :param _g: graph handed to the networkx drawing and layout functions
    :param _bank_accts: mapping from bank ID to the nodes drawn for that bank
    :param _output_png: destination passed to ``plt.savefig``
    """
    bank_ids = _bank_accts.keys()
    palette = plt.get_cmap("tab10")
    layout = nx.nx_agraph.graphviz_layout(_g)

    plt.figure(figsize=(12.0, 8.0))
    plt.axis('off')

    # One colour per bank, taken from the palette by enumeration order.
    for colour_idx, bank_id in enumerate(bank_ids):
        accounts = _bank_accts[bank_id]
        nx.draw_networkx_nodes(_g, layout, accounts, node_size=300,
                               node_color=palette(colour_idx), label=bank_id)
        node_names = {n: n for n in accounts}
        nx.draw_networkx_labels(_g, layout, node_names, font_size=10)

    labels_by_edge = nx.get_edge_attributes(_g, "label")
    nx.draw_networkx_edges(_g, layout)
    nx.draw_networkx_edge_labels(_g, layout, labels_by_edge, font_size=6)

    plt.legend(numpoints=1)
    # Let the drawing fill the whole figure.
    plt.subplots_adjust(left=0, right=1, bottom=0, top=1)
    plt.savefig(_output_png, dpi=120)
Example #9
Source File: example7.py    From bert-as-service with MIT License 6 votes vote down vote up
def vis(embed, vis_alg='PCA', pool_alg='REDUCE_MEAN'):
    """Scatter the first two columns of every array in ``embed`` on a 2x6 grid.

    Points are coloured by the module-level ``subset_label``; a narrow
    colorbar with four text labels is added on the right. ``vis_alg`` and
    ``pool_alg`` only appear in the figure title.
    """
    plt.close()
    fig = plt.figure()
    # NOTE(review): this rcParams change comes after plt.figure(), so it
    # presumably only affects figures created later — confirm.
    plt.rcParams['figure.figsize'] = [21, 7]
    for layer_no, ebd in enumerate(embed, start=1):
        ax = plt.subplot(2, 6, layer_no)
        xs, ys = ebd[:, 0], ebd[:, 1]
        plt.scatter(xs, ys, c=subset_label,
                    cmap=ListedColormap(["blue", "green", "yellow", "red"]),
                    marker='.', alpha=0.7, s=2)
        ax.set_title('pool_layer=-%d' % layer_no)
    plt.tight_layout()
    plt.subplots_adjust(bottom=0.1, right=0.95, top=0.9)

    # Dedicated axes for the colorbar: [left, bottom, width, height].
    cax = plt.axes([0.96, 0.1, 0.01, 0.3])
    cbar = plt.colorbar(cax=cax, ticks=range(num_label))
    cbar.ax.get_yaxis().set_ticks([])
    category_names = ['ent.', 'bus.', 'sci.', 'heal.']
    for j, lab in enumerate(category_names):
        y_pos = (2 * j + 1) / 8.0
        cbar.ax.text(.5, y_pos, lab, ha='center', va='center', rotation=270)

    heading = '%s visualization of BERT layers using "bert-as-service" (-pool_strategy=%s)' % (vis_alg, pool_alg)
    fig.suptitle(heading, fontsize=14)
    plt.show()
Example #10
Source File: visualize.py    From supair with MIT License 6 votes vote down vote up
def show_images(images, nrow=10, text=None, overlay=None, matplot=False):
    """Send ``images`` to the ``vis`` object, directly or via a matplotlib figure.

    With ``matplot`` false, each image goes through ``draw_overlay`` (with
    its entry from ``text`` when given) and the stack is sent with
    ``vis.images``. With ``matplot`` true, images are drawn on a fixed 2x4
    grid, each optionally blended with its resized ``overlay`` entry, and
    the figure is sent with ``vis.matplot`` and then closed.
    """
    images = np.squeeze(images)

    if matplot:
        fig, axes = plt.subplots(2, 4, figsize=(8, 4))
        plt.subplots_adjust(top=1.0, bottom=0.0,
                            left=0.1, right=0.9,
                            wspace=0.1, hspace=-0.15)
        for i, image in enumerate(images):
            # The grid is 4 columns wide: index -> (row, column).
            row, col = divmod(i, 4)
            cur_axes = axes[row, col]
            setup_axis(cur_axes)
            cur_axes.imshow(image, cmap='gray', interpolation='none')
            if overlay is None:
                continue
            cur_overlay = scipy.misc.imresize(overlay[i], image.shape)
            cur_axes.imshow(cur_overlay, cmap='RdYlGn', alpha=0.5)
        vis.matplot(plt)
        plt.close(fig)
        return

    annotated = []
    for j in range(len(images)):
        caption = text[j] if text is not None else None
        annotated.append(draw_overlay(images[j], caption))
    images = np.stack(annotated, axis=0)
    vis.images(images, padding=4, nrow=nrow)
Example #11
Source File: test_skew.py    From neural-network-animation with MIT License 6 votes vote down vote up
def test_skew_rectange():
    """Draw a 2x2 rectangle under 25 combinations of x/y skew, one per subplot."""

    fix, grid = plt.subplots(5, 5, sharex=True, sharey=True, figsize=(16, 12))
    axes = grid.flat

    # Skew steps in multiples of 45 degrees, for x and y independently.
    steps = [-3, -1, 0, 1, 3]
    rotations = list(itertools.product(steps, repeat=2))

    # Axes are shared, so the limits are set on the first one only.
    first = axes[0]
    first.set_xlim([-4, 4])
    first.set_ylim([-4, 4])
    first.set_aspect('equal')

    for ax, (xrots, yrots) in zip(axes, rotations):
        xdeg = 45 * xrots
        ydeg = 45 * yrots
        skew = transforms.Affine2D().skew_deg(xdeg, ydeg)

        ax.set_title('Skew of {0} in X and {1} in Y'.format(xdeg, ydeg))
        patch = mpatch.Rectangle([-1, -1], 2, 2,
                                 transform=skew + ax.transData,
                                 alpha=0.5, facecolor='coral')
        ax.add_patch(patch)

    plt.subplots_adjust(wspace=0, left=0, right=1, bottom=0)
Example #12
Source File: dal.py    From dal with MIT License 6 votes vote down vote up
def init_figure(self):
        """Mark the figure as initialised and, if enabled, create its 12 axes.

        ``self.init_fig`` is always set to True. The figure, the 3x5 grid
        spec and the ``ax_*`` attributes are only created when
        ``self.args.figure == True``.
        """
        self.init_fig = True
        if self.args.figure == True:# and self.obj_fig==None:
            self.obj_fig = plt.figure(figsize=(16,12))
            plt.set_cmap('viridis')

            self.gridspec = gridspec.GridSpec(3,5)

            # (attribute name, grid cell) in creation order; the ax_p* axes
            # span columns 2 and 3.
            wide = slice(2, 4)
            layout = (
                ('ax_map', (0, 0)),
                ('ax_scan', (1, 0)),
                ('ax_pose', (2, 0)),
                ('ax_bel', (0, 1)),
                ('ax_lik', (1, 1)),
                ('ax_gtl', (2, 1)),
                ('ax_pbel', (0, wide)),
                ('ax_plik', (1, wide)),
                ('ax_pgtl', (2, wide)),
                ('ax_act', (0, 4)),
                ('ax_rew', (1, 4)),
                ('ax_err', (2, 4)),
            )
            for attr_name, cell in layout:
                setattr(self, attr_name, plt.subplot(self.gridspec[cell]))

            plt.subplots_adjust(hspace = 0.4, wspace=0.4, top=0.95, bottom=0.05)
Example #13
Source File: dal_ros_aml.py    From dal with MIT License 6 votes vote down vote up
def init_figure(self):
        """Set ``self.init_fig`` and build the 3x5 axes layout when requested.

        Nothing beyond the flag is touched unless ``self.args.figure == True``.
        """
        self.init_fig = True
        if not (self.args.figure == True):
            return

        self.obj_fig = plt.figure(figsize=(16,12))
        plt.set_cmap('viridis')

        grid = gridspec.GridSpec(3,5)
        self.gridspec = grid

        # Column 0.
        self.ax_map = plt.subplot(grid[0,0])
        self.ax_scan = plt.subplot(grid[1,0])
        self.ax_pose = plt.subplot(grid[2,0])

        # Column 1.
        self.ax_bel = plt.subplot(grid[0,1])
        self.ax_lik = plt.subplot(grid[1,1])
        self.ax_gtl = plt.subplot(grid[2,1])

        # Columns 2-3, each axis two columns wide.
        self.ax_pbel = plt.subplot(grid[0,2:4])
        self.ax_plik = plt.subplot(grid[1,2:4])
        self.ax_pgtl = plt.subplot(grid[2,2:4])

        # Column 4.
        self.ax_act = plt.subplot(grid[0,4])
        self.ax_rew = plt.subplot(grid[1,4])
        self.ax_err = plt.subplot(grid[2,4])

        plt.subplots_adjust(hspace = 0.4, wspace=0.4, top=0.95, bottom=0.05)
Example #14
Source File: test_skew.py    From python3_ios with BSD 3-Clause "New" or "Revised" License 6 votes vote down vote up
def test_skew_rectangle():
    """Render a skewed 2x2 rectangle in each cell of a shared 5x5 axes grid."""

    fix, axes = plt.subplots(5, 5, sharex=True, sharey=True, figsize=(8, 8))
    axes = axes.flat

    multipliers = [-3, -1, 0, 1, 3]
    rotations = list(itertools.product(multipliers, repeat=2))

    # Limits and aspect are set once; share=True applies the aspect to the
    # shared axes as well.
    lead_ax = axes[0]
    lead_ax.set_xlim([-3, 3])
    lead_ax.set_ylim([-3, 3])
    lead_ax.set_aspect('equal', share=True)

    for ax, pair in zip(axes, rotations):
        xrots, yrots = pair
        xdeg, ydeg = 45 * xrots, 45 * yrots
        t = transforms.Affine2D().skew_deg(xdeg, ydeg)

        title = 'Skew of {0} in X and {1} in Y'.format(xdeg, ydeg)
        ax.set_title(title)
        rect = mpatch.Rectangle([-1, -1], 2, 2,
                                transform=t + ax.transData,
                                alpha=0.5, facecolor='coral')
        ax.add_patch(rect)

    plt.subplots_adjust(wspace=0, left=0.01, right=0.99, bottom=0.01, top=0.99)
Example #15
Source File: atlas3.py    From ssbio with MIT License 5 votes vote down vote up
def make_pairplot(self, num_components_to_plot=4, outpath=None, dpi=150):
        """Draw a seaborn pairplot of the leading columns of the observations frame.

        :param num_components_to_plot: how many columns of
            ``self.principal_observations_df`` to include, starting at the first
        :param outpath: if truthy, the figure is saved there; otherwise it is shown
        :param dpi: resolution used when saving
        """
        # Pick the columns by position.
        components_to_plot = []
        for position in range(num_components_to_plot):
            components_to_plot.append(self.principal_observations_df.columns[position])

        # Build the grid, then leave room at the top for the title.
        plot = sns.pairplot(data=self.principal_observations_df,
                            hue=self.observation_colname,
                            vars=components_to_plot,
                            markers=self.markers,
                            size=4)
        plt.subplots_adjust(top=.95)
        plt.suptitle(self.plot_title)

        if not outpath:
            plt.show()
        else:
            plot.fig.savefig(outpath, dpi=dpi)
        plt.close()
Example #16
Source File: test_colorbar.py    From neural-network-animation with MIT License 5 votes vote down vote up
def _test_remove_from_figure(use_gridspec):
    """
    Check that adding and then removing a colorbar leaves ``ax.figbox``
    unchanged, for the given ``use_gridspec`` setting.
    """
    fig = plt.figure()
    ax = fig.add_subplot(111)
    scatter = ax.scatter([1, 2], [3, 4], cmap="spring")
    scatter.set_array(np.array([5, 6]))

    # Snapshot the axes box before the colorbar takes space from it.
    box_before = np.array(ax.figbox)

    colorbar = fig.colorbar(scatter, use_gridspec=use_gridspec)
    fig.subplots_adjust()
    colorbar.remove()
    fig.subplots_adjust()

    box_after = np.array(ax.figbox)
    unchanged = (box_before == box_after)
    assert unchanged.all()
Example #17
Source File: lqramsey.py    From QuantEcon.lectures.code with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
def gen_fig_2(path):
    """
    Plot ``path.ξ`` and ``path.Π`` against time in two stacked panels.

    The parameter is the path namedtuple returned by compute_paths().  See
    the docstring of that function for details.
    """

    T = len(path.c)

    # == Prepare axes == #
    fig, axes = plt.subplots(2, 1, figsize=(10, 10))
    plt.subplots_adjust(hspace=0.5)
    legend_args = {'bbox_to_anchor': (0., 1.02, 1., .102),
                   'loc': 3,
                   'mode': 'expand'}
    p_args = {'lw': 2, 'alpha': 0.7}

    def draw_panel(ax, series, label):
        # One line over periods 2..T, with grid, x-label and legend.
        ax.plot(list(range(2, T+1)), series, label=label, **p_args)
        ax.grid()
        ax.set_xlabel('Time')
        ax.legend(ncol=1, **legend_args)

    # == Plot adjustment factor == #
    draw_panel(axes[0], path.ξ, r'$\xi_t$')

    # == Plot adjusted cumulative return == #
    draw_panel(axes[1], path.Π, r'$\Pi_t$')

    plt.show()
Example #18
Source File: test_colorbar.py    From neural-network-animation with MIT License 5 votes vote down vote up
def _colorbar_extension_shape(spacing):
    """
    Build a figure holding 4 horizontal colorbars with rectangular
    extensions, one per extension type, using the given ``spacing``.

    Helper function for test_colorbar_extension_shape.
    """
    # Colormap plus one norm per extension type.
    cmap, norms = _get_cmap_norms()
    # Figure with generous vertical whitespace between the subplots.
    fig = plt.figure()
    fig.subplots_adjust(hspace=4)

    extension_types = ('neither', 'min', 'max', 'both')
    for row, extension_type in enumerate(extension_types, start=1):
        # The norm's boundaries serve as both boundaries and values.
        norm = norms[extension_type]
        boundaries = values = norm.boundaries
        # One subplot per extension type, stacked vertically.
        cax = fig.add_subplot(4, 1, row)
        # Hide tick labels and tick lines on both axes.
        decorations = (cax.get_xticklabels() + cax.get_yticklabels() +
                       cax.get_xticklines() + cax.get_yticklines())
        for item in decorations:
            item.set_visible(False)
        # Draw the colorbar into the subplot.
        cb = ColorbarBase(cax, cmap=cmap, norm=norm,
                          boundaries=boundaries, values=values,
                          extend=extension_type, extendrect=True,
                          orientation='horizontal', spacing=spacing)
    # Hand the finished figure back to the caller.
    return fig
Example #19
Source File: test_colorbar.py    From neural-network-animation with MIT License 5 votes vote down vote up
def _colorbar_extension_length(spacing):
    """
    Build a figure holding 12 horizontal colorbars: every extension type
    combined with three ``extendfrac`` settings, using the given ``spacing``.

    Helper function for test_colorbar_extension_length.
    """
    # Colormap plus one norm per extension type.
    cmap, norms = _get_cmap_norms()
    # Figure with extra vertical whitespace between the subplots.
    fig = plt.figure()
    fig.subplots_adjust(hspace=.6)

    extension_types = ('neither', 'min', 'max', 'both')
    fractions = (None, 'auto', 0.1)
    slot = 0  # running 1-based subplot position, rows filled top to bottom
    for extension_type in extension_types:
        # The norm's boundaries serve as both boundaries and values.
        norm = norms[extension_type]
        boundaries = values = norm.boundaries
        for extendfrac in fractions:
            slot += 1
            cax = fig.add_subplot(12, 1, slot)
            # Hide tick labels and tick lines on both axes.
            decorations = (cax.get_xticklabels() + cax.get_yticklabels() +
                           cax.get_xticklines() + cax.get_yticklines())
            for item in decorations:
                item.set_visible(False)
            # Draw the colorbar into the subplot.
            cb = ColorbarBase(cax, cmap=cmap, norm=norm,
                              boundaries=boundaries, values=values,
                              extend=extension_type, extendfrac=extendfrac,
                              orientation='horizontal', spacing=spacing)
    # Hand the finished figure back to the caller.
    return fig
Example #20
Source File: test_colorbar.py    From neural-network-animation with MIT License 5 votes vote down vote up
def test_gridspec_make_colorbar():
    """Draw two filled-contour panels, each with a gridspec-based colorbar.

    The left panel gets a vertical colorbar, the right one a horizontal one.
    """
    plt.figure()
    data = np.arange(1200).reshape(30, 40)
    levels = [0, 200, 400, 600, 800, 1000, 1200]

    panels = ((121, 'vertical'), (122, 'horizontal'))
    for position, orientation in panels:
        plt.subplot(position)
        plt.contourf(data, levels=levels)
        plt.colorbar(use_gridspec=True, orientation=orientation)

    plt.subplots_adjust(top=0.95, right=0.95, bottom=0.2, hspace=0.25)
Example #21
Source File: image_processing.py    From tindetheus with MIT License 5 votes vote down vote up
def show_images(images, holdon=False, title=None, nmax=49):
    """Display up to ``nmax`` images on a grid of subplots.

    :param images: sequence of items readable by ``imageio.imread``
    :param holdon: when exactly ``False`` the figure is shown without
        blocking and the GUI event loop is given 0.1 s to draw
    :param title: optional identifier passed to ``plt.figure``
    :param nmax: maximum number of images drawn; the rest are skipped
        after a console message
    """
    # use matplotlib to display profile images
    n = len(images)
    if n > nmax:
        n = nmax
        n_col = 7
    else:
        n_col = 3
    # Ceiling division by the actual column count. The previous code added
    # rows using a hard-coded 3 even when the grid is 7 columns wide.
    n_row = -(-n // n_col)
    if title is None:
        plt.figure()
    else:
        plt.figure(title)
    plt.tight_layout()
    for j, i in enumerate(images):
        if j == nmax:
            print('\n\nToo many images to show... \n\n')
            break
        temp_image = imageio.imread(i)
        if len(temp_image.shape) < 3:
            # needs to be converted to rgb
            temp_image = to_rgb(temp_image)
        plt.subplot(n_row, n_col, j+1)
        plt.imshow(temp_image)
        plt.axis('off')
        plt.subplots_adjust(wspace=0, hspace=0)

    if holdon is False:
        plt.show(block=False)
        plt.pause(0.1)
Example #22
Source File: utils.py    From labelKeypoint with GNU General Public License v3.0 5 votes vote down vote up
def draw_label(label, img, label_names, colormap=None):
    """Render ``label`` blended over ``img`` with a legend and return it as RGB.

    The picture is drawn with pyplot, saved to an in-memory buffer, reopened
    with PIL, resized to the width/height of ``img`` and returned as a
    numpy array.

    :param label_names: sequence of class names; entry *i* is paired with
        ``colormap[i]`` in the legend
    :param colormap: optional colour table; built with ``label_colormap``
        when omitted
    """
    # Remove all margins and tick locators so only the image remains.
    plt.subplots_adjust(left=0, right=1, top=1, bottom=0,
                        wspace=0, hspace=0)
    plt.margins(0, 0)
    current_axes = plt.gca()
    current_axes.xaxis.set_major_locator(plt.NullLocator())
    plt.gca().yaxis.set_major_locator(plt.NullLocator())

    if colormap is None:
        colormap = label_colormap(len(label_names))

    blended = label2rgb(label, img, n_labels=len(label_names))
    plt.imshow(blended)
    plt.axis('off')

    # One coloured proxy rectangle per class for the legend.
    plt_handlers = [
        plt.Rectangle((0, 0), 1, 1, fc=colormap[class_id])
        for class_id, _ in enumerate(label_names)
    ]
    plt_titles = [class_name for class_name in label_names]
    plt.legend(plt_handlers, plt_titles, loc='lower right', framealpha=.5)

    buffer = io.BytesIO()
    plt.savefig(buffer, bbox_inches='tight', pad_inches=0)
    plt.cla()
    plt.close()

    # PIL sizes are (width, height); numpy shapes are (rows, columns, ...).
    out_size = (img.shape[1], img.shape[0])
    picture = PIL.Image.open(buffer)
    picture = picture.resize(out_size, PIL.Image.BILINEAR)
    picture = picture.convert('RGB')
    return np.asarray(picture)
Example #23
Source File: gradient_ascent.py    From flashtorch with MIT License 5 votes vote down vote up
def _visualize_filters(self, layer, filter_idxs, num_iter, num_subplots,
                           title):
        """Optimise each filter in ``filter_idxs`` and plot the final outputs.

        The result of every ``self.optimize`` call is appended to
        ``self.output`` (reset at the start); the last element of each
        result is shown in a 4-column grid of subplots.
        """
        # Main figure: 4 columns, as many rows as num_subplots needs.
        num_cols = 4
        num_rows = int(np.ceil(num_subplots / num_cols))

        fig = plt.figure(figsize=(16, num_rows * 5))
        plt.title(title)
        plt.axis('off')

        self.output = []

        # One subplot per requested filter.
        for position, filter_idx in enumerate(filter_idxs, start=1):
            result = self.optimize(layer, filter_idx, num_iter=num_iter)
            self.output.append(result)

            ax = fig.add_subplot(num_rows, num_cols, position)
            ax.set_xticks([])
            ax.set_yticks([])
            ax.set_title(f'filter {filter_idx}')

            final_step = standardize_and_clip(result[-1],
                                              saturation=0.15,
                                              brightness=0.7)
            ax.imshow(format_for_plotting(final_step))

        plt.subplots_adjust(wspace=0, hspace=0)  # noqa
Example #24
Source File: models.py    From pyhawkes with MIT License 5 votes vote down vote up
def plot(self, fig=None, handles=None, figsize=(6,4), color="#377eb8",
             data_index=0, T_slice=None):
        """
        Plot the network and the rates, or refresh an existing plot.

        :param fig: figure to reuse; created when None and no handles are given
        :param handles: ``(net_lns, rate_lns)`` from an earlier call; when
            given, the existing artists are updated and no axes are created
        :param figsize: size of a newly created figure
        :param color: colour forwarded to the plotting helpers
        :param data_index: forwarded to ``self.plot_rates``
        :param T_slice: forwarded to ``self.plot_rates`` on first draw only
        :return: ``(fig, (net_lns, rate_lns))``
        """
        import matplotlib.pyplot as plt

        if handles is not None:
            # Refresh the artists created by a previous call.
            net_lns, rate_lns = handles
            # self.plot_adjacency_matrix(im=im)
            self.plot_network(lns=net_lns)
            self.plot_rates(lns=rate_lns, data_index=data_index)
            plt.pause(0.001)
            return fig, (net_lns, rate_lns)

        if fig is None:
            fig = plt.figure(figsize=figsize)

        # Network in the leftmost column, spanning all K rows.
        rate_width = 3
        ax_net = plt.subplot2grid((self.K, 1+rate_width), (0,0),
                                  rowspan=self.K, colspan=1)
        # im = self.plot_adjacency_matrix(ax=ax_net)
        net_lns = self.plot_network(ax=ax_net, color=color)

        # One rate axis per row, filling the remaining columns.
        axs_rate = []
        for k in range(self.K):
            axs_rate.append(
                plt.subplot2grid((self.K,4), (k,1), rowspan=1, colspan=rate_width))
        rate_lns = self.plot_rates(axs=axs_rate, data_index=data_index,
                                   T_slice=T_slice, color=color)

        plt.subplots_adjust(wspace=1.0)

        return fig, (net_lns, rate_lns)
Example #25
Source File: make_figure.py    From pyhawkes with MIT License 5 votes vote down vote up
def make_figure_a(S, F, C):
    """
    Plot fluorescence traces, filtered fluorescence, and spike times
    for the neurons listed in ``ks`` (columns 0 and 1), one panel each.

    :param S: array whose nonzero entries in column k mark spike times
    :param F: array of raw fluorescence, one column per neuron
    :param C: array of filtered fluorescence, one column per neuron

    The figure is saved to "figure3a.pdf" and then shown.
    """
    col = harvard_colors()
    dt = 0.02
    T_start = 0
    T_stop = 1 * 50 * 60
    t = dt * np.arange(T_start, T_stop)

    ks = [0,1]
    nk = len(ks)
    fig = create_figure((3,3))
    for ind,k in enumerate(ks):
        ax = fig.add_subplot(nk,1,ind+1)
        ax.plot(t, F[T_start:T_stop, k], color=col[1], label="$F$")    # Plot the raw fluorescence
        # Raw string: "\w" is not a valid escape sequence in a normal literal.
        ax.plot(t, C[T_start:T_stop, k], color=col[0], lw=1.5, label=r"$\widehat{F}$")    # Plot the filtered fluorescence
        spks  = np.where(S[T_start:T_stop, k])[0]
        ax.plot(t[spks], C[spks,k], 'ko', label="S")            # Plot the spike times in black

        # Make a legend
        if ind == 0:
            # Put a legend above
            plt.legend(bbox_to_anchor=(0., 1.02, 1., .102), loc=3,
                       ncol=3, mode="expand", borderaxespad=0.,
                       prop={'size':9})

        # Add labels
        ax.set_ylabel("$F_%d(t)$" % (k+1))
        if ind == nk-1:
            ax.set_xlabel("Time $t$ [sec]")

        # Format the ticks
        ax.set_ylim([-0.1,1.0])
        plt.locator_params(nbins=5, axis="y")


    plt.subplots_adjust(left=0.2, bottom=0.2)
    fig.savefig("figure3a.pdf")
    plt.show()
Example #26
Source File: tsne.py    From lightnet with MIT License 5 votes vote down vote up
def tsne_plot(labels, tokens):
    """Fit a 2-D t-SNE on ``tokens`` and plot an image per label on a grid.

    The embedding is rescaled per column (minimum subtracted, then divided
    by the maximum), snapped to a grid by ``tsne_to_grid``, and the image
    returned by ``getImage`` for each label is placed at its grid position.
    """
    reducer = TSNE(perplexity=40, n_components=2, init='pca', n_iter=2500, random_state=23)
    X_2d = reducer.fit_transform(tokens)
    X_2d -= X_2d.min(axis=0)
    X_2d /= X_2d.max(axis=0)

    width = 1200
    grid, to_plot = tsne_to_grid(X_2d)
    out_dim = int(width / np.sqrt(to_plot))

    inches = width / 100
    fig, ax = plt.subplots(figsize=(inches, inches))
    # The axes fill the whole figure.
    plt.subplots_adjust(left=0, bottom=0, right=1, top=1, wspace=None, hspace=None)

    shown_labels = labels[0:to_plot]
    for pos, label in zip(grid, shown_labels):
        x, y = pos[0], pos[1]
        ax.scatter(x, y)
        # Text annotations are switched off; flip the condition to enable.
        if False:
            ax.annotate(label,
                        xy=(pos[0], pos[1]),
                        xytext=(5, 2),
                        fontsize=9,
                        textcoords='offset points',
                        ha='right',
                        va='bottom')
        thumbnail = getImage(label, new_size = out_dim / 2)
        ax.add_artist(AnnotationBbox(thumbnail, (x, y), frameon=False))

    plt.show()
Example #27
Source File: utils.py    From pbt with MIT License 5 votes vote down vote up
def plots(imgs, figsize=(12, 12), rows=None, cols=None,
          interp=None, titles=None, cmap='gray',
          fig=None):
    """Draw one or more images on a grid of subplots inside ``fig``.

    :param imgs: a single image or a list of images; each is converted
        with ``np.array``
    :param figsize: used for the figure height when ``fig`` is supplied;
        recomputed from the grid when a new figure is created
    :param rows: number of grid rows; derived when omitted
    :param cols: number of grid columns; derived when omitted
    :param interp: interpolation, either one value for all images or a list
    :param titles: optional sequence of per-image titles
    :param cmap: colormap, either one value for all images or a list;
        forced to 'gray' when a single value is given and the first image
        is 2-D
    :param fig: existing figure to draw into; it is cleared first. A new
        figure is created when omitted.
    """
    if not isinstance(imgs, list):
        imgs = [imgs]
    imgs = [np.array(img) for img in imgs]
    if not isinstance(cmap, list):
        if imgs[0].ndim == 2:
            cmap = 'gray'
        cmap = [cmap] * len(imgs)
    if not isinstance(interp, list):
        interp = [interp] * len(imgs)
    n = len(imgs)
    if not rows and not cols:
        cols = n
        rows = 1
    elif not rows:
        rows = cols
    elif not cols:
        cols = rows
    if not fig:
        # New figure: fixed width, height grown with the number of rows.
        rows = int(np.ceil(len(imgs) / cols))
        w = 12
        h = rows * (w / cols + 1)
        figsize = (w, h)
        fig = plt.figure(figsize=figsize)
    fontsize = 13 if cols == 5 else 16
    fig.set_figheight(figsize[1], forward=True)
    fig.clear()
    for i, img in enumerate(imgs):
        sp = fig.add_subplot(rows, cols, i+1)
        if titles:
            sp.set_title(titles[i], fontsize=fontsize)
        # Draw through the axes itself: the pyplot functions act on
        # pyplot's current figure, which need not be the `fig` passed in.
        sp.imshow(img, interpolation=interp[i], cmap=cmap[i])
        sp.axis('off')
    # Same margins as before, applied once to this figure after the loop.
    fig.subplots_adjust(left=0, bottom=0, right=1, top=1, wspace=.1, hspace=0)
    #  plt.tight_layout()
    fig.canvas.draw()
Example #28
Source File: dal_ros_aml.py    From dal with MIT License 5 votes vote down vote up
def init_figure(self):
        """Flag the figure as initialised and create the axes grid if enabled.

        When ``self.args.figure == True`` a 16x12-inch figure with a 3x5
        grid spec is created and twelve ``ax_*`` attributes are assigned,
        column by column; columns 2-3 are merged into one wide column.
        """
        self.init_fig = True
        if self.args.figure == True:# and self.obj_fig==None:
            self.obj_fig = plt.figure(figsize=(16,12))
            plt.set_cmap('viridis')

            self.gridspec = gridspec.GridSpec(3,5)

            # Each entry: (grid column key, attribute names for rows 0..2).
            # self.ax_prior =  plt.subplot(self.gridspec[2,1])
            columns = (
                (0, ('ax_map', 'ax_scan', 'ax_pose')),
                (1, ('ax_bel', 'ax_lik', 'ax_gtl')),
                (slice(2, 4), ('ax_pbel', 'ax_plik', 'ax_pgtl')),
                (4, ('ax_act', 'ax_rew', 'ax_err')),
            )
            for column_key, attr_names in columns:
                for row, attr_name in enumerate(attr_names):
                    axis = plt.subplot(self.gridspec[row, column_key])
                    setattr(self, attr_name, axis)

            plt.subplots_adjust(hspace = 0.4, wspace=0.4, top=0.95, bottom=0.05)
Example #29
Source File: core.py    From sdwan-harvester with GNU General Public License v2.0 5 votes vote down vote up
def create_pie_chart(elements, suptitle, png, figure_id):
    """
    Draw a pie chart of ``elements`` and save it as an image file.

    The slice with the largest value (first one on ties) is pulled out
    of the pie.

    :param elements: dict with elements (dict)
    :param suptitle: name of chart (str)
    :param png: name of output file (str)
    :param figure_id: id of current plot (started with 1) (int)
    :return: None
    """
    values = list(elements.values())
    keys = list(elements.keys())
    plt.figure(figure_id)
    plt.subplots_adjust(bottom=.05, left=.01, right=.99, top=.90, hspace=.35)

    # No offset for any slice except the largest one.
    explode = [0] * len(keys)
    largest_at = values.index(max(values))
    explode[largest_at] = 0.1

    plt.pie(values,
            labels=keys,
            autopct=make_autopct(values),
            explode=explode,
            textprops={'fontsize': PIE_LABEL_FONT_SIZE})
    plt.axis("equal")
    plt.suptitle(suptitle, fontsize=PIE_SUPTITLE_FONT_SIZE)

    plt.gcf().set_dpi(PIE_DPI)
    destination = "{dest}/{png}/{result_file}".format(dest=RESULTS_DIR,
                                                      png=PNG_DIR,
                                                      result_file=png)
    plt.savefig(destination)
Example #30
Source File: visualization.py    From NTM-Keras with MIT License 5 votes vote down vote up
def update(self, matrix_list, name_list):
        """Redraw the first three matrices as stacked images with a colorbar.

        :param matrix_list: sequence whose items 0, 1 and 2 are shown in the
            top, middle and bottom rows
        :param name_list: not used by the current drawing code (it is only
            referenced by the disabled text labels)
        """
        # Three rows on a 3x1 grid; only the first row forces an equal aspect.
        for row in range(3):
            row_axes = plt.subplot2grid((3, 1), (row, 0), colspan=1)
            if row == 0:
                row_axes.set_aspect('equal')
            plt.imshow(matrix_list[row], interpolation='none')
            row_axes.set_xticks([])
            row_axes.set_yticks([])

        # # add text
        # plt.text(-2, -19.5, name_list[0], ha='right')
        # plt.text(-2, -7.5, name_list[1], ha='right')
        # plt.text(-2, 4.5, name_list[2], ha='right')
        # plt.text(6, 10, 'Time $\longrightarrow$', ha='right')

        # Hide tick labels on the whole current figure.
        make_tick_labels_invisible(plt.gcf())
        # Tighten the spacing and leave room on the right for the colorbar.
        plt.subplots_adjust(hspace=0.05, wspace=0.05, bottom=0.1, right=0.8, top=0.9)
        # Colorbar axes, given as [left, bottom, width, height].
        colorbar_axes = plt.axes([0.85, 0.125, 0.015, 0.75])
        plt.colorbar(cax=colorbar_axes)

        # Refresh the figure and give the GUI a moment to draw.
        # plt.show()
        plt.draw()
        plt.pause(0.025)
        # plt.pause(15)