Python mpl_toolkits.axes_grid1.make_axes_locatable() Examples

The following are 30 code examples of mpl_toolkits.axes_grid1.make_axes_locatable(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module mpl_toolkits.axes_grid1, or try the search function.
Example #1
Source File: plot.py    From mplhep with MIT License 6 votes vote down vote up
def make_square_add_cbar(ax, size=0.4, pad=0.1):
    """
    Make ``ax`` square and attach a colorbar axes to its right-hand side.

    Both axes resize together to fit the figure automatically, and the
    layout works with tight_layout().

    :param ax: axes to make square.
    :param size: absolute (``axes_size.Fixed``) width of the colorbar axes.
    :param pad: absolute (``axes_size.Fixed``) gap between ``ax`` and the bar.
    :return: the appended colorbar axes.
    """
    divider = make_axes_locatable(ax)
    bar_width = axes_size.Fixed(size)
    gap = axes_size.Fixed(pad)
    # The same [gap, bar] list is used for both directions.
    extra = [gap, bar_width]

    colorbar_ax = divider.append_axes("right", size=bar_width, pad=gap)

    # Replace the divider's layout: the remainder cell followed by the fixed
    # gap and bar sizes, horizontally and vertically.
    divider.set_horizontal([RemainderFixed(extra, extra, divider)] + extra)
    divider.set_vertical([RemainderFixed(extra, extra, divider)] + extra)
    return colorbar_ax
Example #2
Source File: make_room_for_ylabel_using_axesgrid.py    From python3_ios with BSD 3-Clause "New" or "Revised" License 6 votes vote down vote up
def ex3():
        """Two side-by-side axes with auto-adjusted outer padding (figure 3).

        The right axes is created through the left axes' divider, shares its
        y axis and hides its own y tick labels. Auto-adjustable areas (pad 0.1)
        are registered on the left, right, top and bottom, and a long y tick
        label, a title and an x label are set.
        """
        figure = plt.figure(3)
        left_ax = plt.axes([0, 0, 1, 1])
        divider = make_axes_locatable(left_ax)

        # Same-width ("100%") neighbour sharing the y axis.
        right_ax = divider.new_horizontal("100%", pad=0.3, sharey=left_ax)
        right_ax.tick_params(labelleft=False)
        figure.add_axes(right_ax)

        adjustments = (
            ([left_ax], ["left"]),
            ([right_ax], ["right"]),
            ([left_ax, right_ax], ["top", "bottom"]),
        )
        for axes_group, directions in adjustments:
            divider.add_auto_adjustable_area(use_axes=axes_group, pad=0.1,
                                             adjust_dirs=directions)

        left_ax.set_yticks([0.5])
        left_ax.set_yticklabels(["very long label"])

        right_ax.set_title("Title")
        right_ax.set_xlabel("X - Label")
Example #3
Source File: utils.py    From DIAG-NRE with MIT License 6 votes vote down vote up
def show_word_score_heatmap(score_tensor, x_ticks, y_ticks, figsize=(3, 8)):
    """
    Show ``score_tensor`` as a heatmap with labelled ticks and a colorbar.

    ``score_tensor`` must provide ``.numpy()`` and ``.size(dim)`` (as a torch
    tensor does). Dim 1 is paired with ``x_ticks`` and dim 0 with ``y_ticks``.
    The global rcParams are switched to the 'simhei' sans-serif font
    (presumably for CJK labels) with ``axes.unicode_minus`` disabled. The
    figure is displayed with ``plt.show()``.
    """
    mpl.rcParams['font.sans-serif'] = ['simhei']
    mpl.rcParams['axes.unicode_minus'] = False

    fig, ax = plt.subplots(nrows=1, ncols=1, figsize=figsize)
    heat = ax.matshow(score_tensor.numpy())

    n_rows, n_cols = score_tensor.size(0), score_tensor.size(1)
    plt.xticks(range(n_cols), x_ticks, fontsize=14)
    plt.yticks(range(n_rows), y_ticks, fontsize=14)

    # Colorbar in an axes carved off the right of ``ax`` so it is sized
    # relative to the image axes.
    cbar_ax = make_axes_locatable(ax).append_axes("right", size="10%", pad=0.1)
    fig.colorbar(heat, cax=cbar_ax)

    ax.set_aspect('auto')
    plt.show()
Example #4
Source File: labeling_toolbox.py    From DeepLabCut with GNU Lesser General Public License v3.0 6 votes vote down vote up
def drawplot(self, img, img_name, itr, index, bodyparts, cmap, keep_view=False):
        """Redraw ``self.axes`` with the image at path ``img`` plus a body-part colorbar.

        :param img: image file path, read with ``cv2.imread``.
        :param img_name: name appended to the axes title.
        :param itr: current image position; the title reads "itr/(len(index)-1) img_name".
        :param index: collection whose length sets the denominator in the title.
        :param bodyparts: labels for the colorbar ticks (shown in reverse order).
        :param cmap: colormap passed to ``imshow``.
        :param keep_view: if True, restore the x/y limits that were active
            before the redraw.
        :return: ``(self.figure, self.axes, self.canvas, self.toolbar)``;
            ``self.orig_xlim``/``self.orig_ylim`` and ``self.toolbar`` are updated.
        """
        # Remember the current view so it can be restored when keep_view is set.
        xlim = self.axes.get_xlim()
        ylim = self.axes.get_ylim()
        self.axes.clear()

        # convert the image to RGB as you are showing the image with matplotlib
        # NOTE(review): cv2.imread returns None for unreadable paths, which would
        # fail at this slice with a TypeError — confirm callers pass valid paths.
        im = cv2.imread(img)[..., ::-1]
        ax = self.axes.imshow(im, cmap=cmap)
        # Limits of the freshly drawn (un-zoomed) image.
        self.orig_xlim = self.axes.get_xlim()
        self.orig_ylim = self.axes.get_ylim()
        # NOTE(review): each call appends another colorbar axes through a new
        # divider; self.axes.clear() does not remove the previously appended one
        # — confirm the old colorbar axes is removed elsewhere.
        divider = make_axes_locatable(self.axes)
        # One evenly spaced tick per body part across the image value range.
        colorIndex = np.linspace(np.min(im), np.max(im), len(bodyparts))
        cax = divider.append_axes("right", size="5%", pad=0.05)
        cbar = self.figure.colorbar(
            ax, cax=cax, spacing="proportional", ticks=colorIndex
        )
        cbar.set_ticklabels(bodyparts[::-1])
        self.axes.set_title(str(str(itr) + "/" + str(len(index) - 1) + " " + img_name))
        if keep_view:
            self.axes.set_xlim(xlim)
            self.axes.set_ylim(ylim)
        # A new toolbar is created on every redraw.
        self.toolbar = NavigationToolbar(self.canvas)
        return (self.figure, self.axes, self.canvas, self.toolbar)
Example #5
Source File: metrics.py    From mindpark with GNU General Public License v3.0 6 votes vote down vote up
def _process_metric(self, ax, metric):
        """
        Draw one metric onto ``ax``, choosing the plot type from its data.

        Empty metrics get a placeholder. The ticks are made fully transparent,
        the background is filled with the lowest viridis colour, and an
        invisible 7% spacer axes is appended on the right (presumably to keep
        the width aligned with plots that carry a colorbar).

        Otherwise the data is dispatched on its column count and on
        ``self._is_categorical``:

        * one non-categorical column -> ``self._plot_scalar``
        * one categorical column -> one-hot counts -> ``self._plot_distribution``
        * more than one column -> ``self._plot_counts``
        """
        if not metric.data.size:
            ax.tick_params(colors=(0, 0, 0, 0))
            # Axes.set_axis_bgcolor() was removed from Matplotlib;
            # set_facecolor() is the supported replacement.
            # NOTE(review): matplotlib.cm.get_cmap is deprecated in newer
            # Matplotlib as well — consider matplotlib.colormaps['viridis'].
            ax.set_facecolor(cm.get_cmap('viridis')(0))
            divider = make_axes_locatable(ax)
            divider.append_axes('right', size='7%', pad=0.1).axis('off')
            return
        domain = self._domain(metric)
        categorical = self._is_categorical(metric.data)
        if metric.data.shape[1] == 1 and not categorical:
            self._plot_scalar(ax, domain, metric.data[:, 0])
        elif metric.data.shape[1] == 1:
            # Categorical single column: one-hot encode each value relative to
            # the smallest observed category.
            indices = metric.data[:, 0].astype(int)
            min_, max_ = indices.min(), indices.max()
            count = np.eye(max_ - min_ + 1)[indices - min_]
            self._plot_distribution(ax, domain, count)
        elif metric.data.shape[1] > 1:
            self._plot_counts(ax, domain, metric.data)
Example #6
Source File: multiple_individuals_labeling_toolbox.py    From DeepLabCut with GNU Lesser General Public License v3.0 6 votes vote down vote up
def drawplot(self, img, img_name, itr, index, bodyparts, cmap, keep_view=False):
        """Redraw ``self.axes`` with the image at path ``img`` (no colorbar).

        :param img: image file path, read with ``cv2.imread``.
        :param img_name: name appended to the axes title.
        :param itr: current image position; the title reads "itr/(len(index)-1) img_name".
        :param index: collection whose length sets the denominator in the title.
        :param bodyparts: unused by this variant.
        :param cmap: colormap passed to ``imshow``.
        :param keep_view: if True, restore the x/y limits that were active
            before the redraw.
        :return: ``(self.figure, self.axes, self.canvas, self.toolbar, image)``
            where ``image`` is the artist returned by ``imshow``.
        """
        prev_xlim, prev_ylim = self.axes.get_xlim(), self.axes.get_ylim()
        self.axes.clear()

        # OpenCV loads channels as BGR; reverse the last axis for matplotlib.
        rgb = cv2.imread(img)[..., ::-1]
        image_artist = self.axes.imshow(rgb, cmap=cmap)
        self.orig_xlim, self.orig_ylim = self.axes.get_xlim(), self.axes.get_ylim()
        self.axes.set_title(str(itr) + "/" + str(len(index) - 1) + " " + img_name)

        if keep_view:
            self.axes.set_xlim(prev_xlim)
            self.axes.set_ylim(prev_ylim)

        self.toolbar = NavigationToolbar(self.canvas)
        return (self.figure, self.axes, self.canvas, self.toolbar, image_artist)
Example #7
Source File: plotting.py    From devito with MIT License 6 votes vote down vote up
def plot_image(data, vmin=None, vmax=None, colorbar=True, cmap="gray"):
    """
    Plot image data, such as RTM images or FWI gradients.

    The data is transposed before display, and the figure is shown with
    ``plt.show()``.

    Parameters
    ----------
    data : ndarray
        Image data to plot.
    vmin, vmax : float, optional
        Colour-scale limits. Explicit values, including ``0``, are used as
        given. When omitted, they default to the data minimum/maximum pushed
        outward by 10% of their magnitude.
    colorbar : bool
        Whether to attach an aligned colorbar on the right.
    cmap : str
        Choice of colormap. Defaults to gray scale for images as a
        seismic convention.
    """
    # ``vmin or default`` silently discarded an explicit 0; test for None.
    if vmin is None:
        data_min = np.min(data)
        # Same as 0.9 * min for non-negative minima. For negative minima it
        # moves further down instead of clipping the most negative values
        # (which 0.9 * min did).
        vmin = data_min - 0.1 * abs(data_min)
    if vmax is None:
        data_max = np.max(data)
        # Same as 1.1 * max for non-negative maxima; widens negative maxima too.
        vmax = data_max + 0.1 * abs(data_max)

    plot = plt.imshow(np.transpose(data), vmin=vmin, vmax=vmax, cmap=cmap)

    # Create aligned colorbar on the right
    if colorbar:
        ax = plt.gca()
        divider = make_axes_locatable(ax)
        cax = divider.append_axes("right", size="5%", pad=0.05)
        plt.colorbar(plot, cax=cax)
    plt.show()
Example #8
Source File: visualization_util.py    From leap with MIT License 6 votes vote down vote up
def plot_vector_field(fig, ax, vector_field, skip_rate=1):
    """
    Draw the scalar field ``p`` as a heatmap with a quiver of (dx, dy) on top.

    :param fig: figure that receives the colorbar.
    :param ax: axes to draw on.
    :param vector_field: 6-tuple ``(p, dx, dy, x, y, _)``. ``p`` is swapped
        before ``imshow`` because its first axis is x. ``x`` and ``y`` are the
        coordinate arrays used for the extent and the quiver grid.
    :param skip_rate: draw every ``skip_rate``-th arrow along each axis.
    """
    skip = (slice(None, None, skip_rate), slice(None, None, skip_rate))
    p, dx, dy, x, y, _ = vector_field
    im = ax.imshow(
        np.swapaxes(p, 0, 1),  # imshow uses first axis as y-axis
        extent=[x.min(), x.max(), y.min(), y.max()],
        cmap=plt.get_cmap('plasma'),
        interpolation='nearest',
        aspect='auto',
        # Put (0, 0) at the bottom left. Only 'upper'/'lower' are accepted by
        # current Matplotlib; the old 'bottom' spelling is rejected.
        origin='lower',
    )
    # NOTE(review): meshgrid's default 'xy' indexing gives (len(y), len(x))
    # grids while ``p`` is indexed [x, y]; if dx/dy share p's layout the arrows
    # may be transposed — confirm the orientation of dx/dy.
    x, y = np.meshgrid(x, y)
    ax.quiver(x[skip], y[skip], dx[skip], dy[skip])

    divider = make_axes_locatable(ax)
    cax = divider.append_axes('right', size='5%', pad=0.05)
    fig.colorbar(im, cax=cax, orientation='vertical')
Example #9
Source File: visualization_util.py    From leap with MIT License 6 votes vote down vote up
def plot_heatmap(heatmap, fig=None, ax=None, legend_axis=None):
    """
    Draw a heatmap with a vertical colorbar.

    :param heatmap: 4-tuple ``(p, x, y, _)``. ``p`` is swapped before
        ``imshow`` because its first axis is x. ``x`` and ``y`` are the
        coordinate arrays used for the extent.
    :param fig: figure for the colorbar; defaults to ``plt.gcf()``.
    :param ax: axes to draw on; defaults to ``plt.gca()``.
    :param legend_axis: axes to hold the colorbar. When None, a 5%-wide axes
        is appended to the right of ``ax``.
    :return: ``(im, legend_axis)`` so callers can reuse the colorbar axes.
    """
    if fig is None:
        fig = plt.gcf()
    if ax is None:
        ax = plt.gca()
    p, x, y, _ = heatmap
    im = ax.imshow(
        np.swapaxes(p, 0, 1),  # imshow uses first axis as y-axis
        extent=[x.min(), x.max(), y.min(), y.max()],
        cmap=plt.get_cmap('plasma'),
        interpolation='nearest',
        aspect='auto',
        # Put (0, 0) at the bottom left. Only 'upper'/'lower' are accepted by
        # current Matplotlib; the old 'bottom' spelling is rejected.
        origin='lower',
    )
    if legend_axis is None:
        divider = make_axes_locatable(ax)
        legend_axis = divider.append_axes('right', size='5%', pad=0.05)
    fig.colorbar(im, cax=legend_axis, orientation='vertical')
    return im, legend_axis
Example #10
Source File: asep_slow.py    From cellular_automata with GNU General Public License v2.0 6 votes vote down vote up
def plot_cells(state_cells, walls_inf, i):
    """
    Plot the current state of the cells and save it as ``pngs/peds<i>.png``.

    Wall rows are multiplied by infinity and stacked above and below the
    cells. Their non-finite values are therefore drawn in the colormap's
    "bad" colour, which keeps the walls visually distinct from the cells.

    :param state_cells: state of the cells
    :param walls_inf: walls for visualisation purposes (scaled by +inf here)
    :param i: index for figures; used in the title and the output filename
    """
    import copy

    # np.Inf was removed in NumPy 2.0; np.inf is the canonical spelling.
    walls = walls_inf * np.inf
    tmp_cells = np.vstack((walls, state_cells, walls))
    fig = plt.figure()
    ax = fig.add_subplot(111)
    # Copy so set_bad() does not modify the globally registered 'gray' colormap.
    cmap = copy.copy(plt.get_cmap('gray'))
    cmap.set_bad(color='k', alpha=0.8)
    im = ax.imshow(tmp_cells, cmap=cmap, vmin=0, vmax=1, interpolation='nearest')
    divider = make_axes_locatable(ax)
    cax = divider.append_axes('right', size='1%', pad=0.1)
    plt.colorbar(im, cax=cax, ticks=[0, 1])
    ax.set_axis_off()
    # np.sum gives a scalar for any input rank; the builtin sum returned a
    # row for 2-D input, which '%d' cannot format.
    num = np.sum(state_cells)
    text = "t: %3.3d | n: %d\n" % (i, num)
    # NOTE(review): plt.title targets the current axes; after append_axes that
    # is likely the colorbar axes rather than ``ax`` — confirm the placement.
    plt.title("%20s" % text, rotation=0, fontsize=10, verticalalignment='bottom')
    figure_name = os.path.join('pngs', 'peds%.3d.png' % i)
    plt.savefig(figure_name, dpi=100, facecolor='lightgray')
    # A new figure is created on every call; close it so frames don't pile up.
    plt.close(fig)
Example #11
Source File: asep_fast.py    From cellular_automata with GNU General Public License v2.0 6 votes vote down vote up
def plot_cells(state_cells, walls_inf, i):
    """
    Plot the current state of the cells and save it as ``pngs/peds<i>.png``.

    Wall rows are multiplied by infinity and stacked above and below the
    cells. Their non-finite values are therefore drawn in the colormap's
    "bad" colour, which keeps the walls visually distinct from the cells.

    :param state_cells: state of the cells
    :param walls_inf: walls for visualisation purposes (scaled by +inf here)
    :param i: index for figures; used in the title and the output filename
    """
    import copy

    # np.Inf was removed in NumPy 2.0; np.inf is the canonical spelling.
    walls = walls_inf * np.inf
    tmp_cells = np.vstack((walls, state_cells, walls))
    fig = plt.figure()
    ax = fig.add_subplot(111)
    # Copy so set_bad() does not modify the globally registered 'gray' colormap.
    cmap = copy.copy(plt.get_cmap('gray'))
    cmap.set_bad(color='k', alpha=0.8)
    im = ax.imshow(tmp_cells, cmap=cmap, vmin=0, vmax=1, interpolation='nearest')
    divider = make_axes_locatable(ax)
    cax = divider.append_axes('right', size='1%', pad=0.1)
    plt.colorbar(im, cax=cax, ticks=[0, 1])
    ax.set_axis_off()
    # np.sum gives a scalar for any input rank; the builtin sum returned a
    # row for 2-D input, which '%d' cannot format.
    num = np.sum(state_cells)
    text = "t: %3.3d | n: %d\n" % (i, num)
    # NOTE(review): plt.title targets the current axes; after append_axes that
    # is likely the colorbar axes rather than ``ax`` — confirm the placement.
    plt.title("%20s" % text, rotation=0, fontsize=10, verticalalignment='bottom')
    figure_name = os.path.join('pngs', 'peds%.3d.png' % i)
    plt.savefig(figure_name, dpi=100, facecolor='lightgray')
    # A new figure is created on every call; close it so frames don't pile up.
    plt.close(fig)
Example #12
Source File: plotting.py    From nevergrad with MIT License 6 votes vote down vote up
def __init__(self, winrates_df: pd.DataFrame) -> None:
        """Build a heatmap figure of ``winrates_df`` scaled to percent.

        Columns label the x axis (rotated 90 degrees) and the index labels the
        y axis (rotated 45 degrees). A colorbar for the fixed 0-100 range is
        attached on the right.
        """
        self.winrates = winrates_df
        self._fig = plt.figure()
        self._ax = self._fig.add_subplot(111)
        # NB: despite its name, ``_cax`` holds the AxesImage, not an axes.
        percentages = 100 * np.array(self.winrates)
        self._cax = self._ax.imshow(percentages, cmap=cm.seismic, interpolation="none", vmin=0, vmax=100)

        col_names = self.winrates.columns
        self._ax.set_xticks(list(range(len(col_names))))
        self._ax.set_xticklabels(col_names, rotation=90, fontsize=7)

        row_names = self.winrates.index
        self._ax.set_yticks(list(range(len(row_names))))
        self._ax.set_yticklabels(row_names, rotation=45, fontsize=7)

        bar_axes = make_axes_locatable(self._ax).append_axes("right", size="5%", pad=0.05)
        self._fig.colorbar(self._cax, cax=bar_axes)
        plt.tight_layout()
Example #13
Source File: demo_utils.py    From Resemblyzer with Apache License 2.0 6 votes vote down vote up
def plot_similarity_matrix(matrix, labels_a=None, labels_b=None, ax: plt.Axes=None, title=""):
    """
    Draw ``matrix`` as an image with an aligned colorbar on the right.

    :param matrix: 2-D array of similarity values. The colour scale is fixed
        to [0.4, 1] with the "inferno" colormap.
    :param labels_a: optional tick labels for the x axis (ticks at the bottom).
    :param labels_b: optional tick labels for the y axis, applied in reverse
        order because the image origin is at the top.
    :param ax: axes to draw into; a new figure/axes is created when None.
    :param title: axes title.
    :return: the axes drawn into.
    """
    if ax is None:
        _, ax = plt.subplots()
    # Use the figure that owns ``ax``. plt.gcf() is a different figure when
    # the caller passes an axes from a non-current figure.
    fig = ax.figure

    # NOTE(review): the x span uses shape[0] and the y span shape[1]; this only
    # lines up with matshow's columns/rows for square matrices — confirm.
    img = ax.matshow(matrix, extent=(-0.5, matrix.shape[0] - 0.5,
                                     -0.5, matrix.shape[1] - 0.5))

    ax.xaxis.set_ticks_position("bottom")
    if labels_a is not None:
        ax.set_xticks(range(len(labels_a)))
        ax.set_xticklabels(labels_a, rotation=90)
    if labels_b is not None:
        ax.set_yticks(range(len(labels_b)))
        ax.set_yticklabels(labels_b[::-1])  # Upper origin -> reverse y axis
    ax.set_title(title)

    cax = make_axes_locatable(ax).append_axes("right", size="5%", pad=0.15)
    fig.colorbar(img, cax=cax, ticks=np.linspace(0.4, 1, 7))
    # Fixed colour range and colormap for every similarity plot.
    img.set_clim(0.4, 1)
    img.set_cmap("inferno")

    return ax
Example #14
Source File: plot2D.py    From Python_DIC with Apache License 2.0 6 votes vote down vote up
def plot2D_correlation(self, plotFig, plotAx, data_x, data_y, corr):
    """
    Redraw ``corr`` as an image on ``plotAx`` with a colorbar on its right.

    Any second axes already in ``plotFig`` (assumed to be the colorbar from a
    previous call) is deleted first. The mappable, colorbar axes and colorbar
    are stored on ``plotAx`` as ``mappable``, ``cax`` and ``cbar``.
    ``self``, ``data_x`` and ``data_y`` are not used by the current body.
    """
    plotAx.cla() #clear the figure
    plotAx.patch.set_facecolor('none') #remove figure background

    # Remove the colorbar axes from a previous call. The only expected failure
    # is "no second axes yet", so catch just that instead of a bare except.
    try:
        plotFig.delaxes(plotFig.axes[1])
    except IndexError:
        pass

    plotAx.mappable = plotAx.imshow(corr, cmap = 'RdBu')
    plotAx.mappable.axes.xaxis.set_ticklabels([])
    plotAx.mappable.axes.yaxis.set_ticklabels([])
    plotAx.invert_yaxis()

    #colorbar display
    divider = make_axes_locatable(plotAx)
    plotAx.cax = divider.append_axes('right', size='5%', pad='1%')
    plotAx.cbar = plotFig.colorbar(plotAx.mappable, cax=plotAx.cax, extend='min')
    plotAx.cbar.ax.tick_params(labelsize=7)
    # Format the labels explicitly: str() of linspace floats produced labels
    # such as "0.30000000000000004".
    labels = ["{:.1f}".format(value) for value in np.linspace(0, 1, 11)]
    # NOTE(review): ticks at -0.1..0.1 are relabelled 0..1 — confirm this
    # mapping matches the value range of ``corr``.
    ticks = np.linspace(-0.1, 0.1, 11)
    plotAx.cbar.set_ticks(ticks)
    plotAx.cbar.set_ticklabels(labels)
Example #15
Source File: test_frame.py    From elasticintel with GNU General Public License v3.0 5 votes vote down vote up
def test_plain_axes(self):
        """Check that Series/DataFrame.plot don't raise when the figure holds extra axes.

        Four scenarios are covered:

        * a plain ``add_axes`` Axes next to a subplot;
        * a scatter with ``cmap`` drawn into a plain Axes;
        * an axes appended with ``make_axes_locatable(...).append_axes``;
        * an inset axes from ``inset_axes``.

        The test makes no assertions; it passes if none of the plot calls raise.
        """

        # supplied ax itself is a SubplotAxes, but figure contains also
        # a plain Axes object (GH11556)
        fig, ax = self.plt.subplots()
        fig.add_axes([0.2, 0.2, 0.2, 0.2])
        Series(rand(10)).plot(ax=ax)

        # supplied ax itself is a plain Axes, but because of the cmap keyword
        # a new ax is created for the colorbar -> also multiple axes (GH11520)
        df = DataFrame({'a': randn(8), 'b': randn(8)})
        fig = self.plt.figure()
        ax = fig.add_axes((0, 0, 1, 1))
        df.plot(kind='scatter', ax=ax, x='a', y='b', c='a', cmap='hsv')

        # other examples: an axes appended via an axes divider
        fig, ax = self.plt.subplots()
        from mpl_toolkits.axes_grid1 import make_axes_locatable
        divider = make_axes_locatable(ax)
        cax = divider.append_axes("right", size="5%", pad=0.05)
        Series(rand(10)).plot(ax=ax)
        Series(rand(10)).plot(ax=cax)

        # ... and an inset axes inside the main axes
        fig, ax = self.plt.subplots()
        from mpl_toolkits.axes_grid1.inset_locator import inset_axes
        iax = inset_axes(ax, width="30%", height=1., loc=3)
        Series(rand(10)).plot(ax=ax)
        Series(rand(10)).plot(ax=iax)
Example #16
Source File: plot.py    From PyTorchWavelets with MIT License 5 votes vote down vote up
def plot_scalogram(power, scales, t, normalize_columns=True, cmap=None, ax=None, scale_legend=True):
    """
    Plot the wavelet power spectrum (scalogram).

    :param power: np.ndarray, CWT power spectrum of shape [n_scales,signal_length]
    :param scales: np.ndarray, scale distribution of shape [n_scales]
    :param t: np.ndarray, temporal range of shape [signal_length]
    :param normalize_columns: boolean, whether to normalize spectrum per timestep
        (each column is divided by its maximum)
    :param cmap: matplotlib cmap, please refer to their documentation
    :param ax: matplotlib axis object, if None creates a new subplot
    :param scale_legend: boolean, whether to include scale legend on the right
    :return: ax, matplotlib axis object that contains the scalogram
    """
    from matplotlib.collections import Collection

    if not cmap: cmap = plt.get_cmap("PuBu_r")
    if ax is None: _, ax = plt.subplots()
    # NOTE(review): a column whose maximum is 0 produces NaNs here — confirm
    # inputs never contain all-zero timesteps.
    if normalize_columns: power = power/np.max(power, axis=0)

    T, S = np.meshgrid(t, scales)
    cnt = ax.contourf(T, S, power, 100, cmap=cmap)

    # Fix for saving as PDF (aliasing): draw edges in the face colour.
    # Since Matplotlib 3.8 a ContourSet is itself a Collection and the
    # per-level ``.collections`` list is deprecated; keep the old path for
    # earlier versions.
    if isinstance(cnt, Collection):
        cnt.set_edgecolor("face")
    else:
        for c in cnt.collections:
            c.set_edgecolor("face")

    ax.set_yscale('log')
    ax.set_ylabel("Scale (Log Scale)")
    ax.set_xlabel("Time (s)")
    ax.set_title("Wavelet Power Spectrum")

    if scale_legend:
        def format_axes_label(x, pos):
            return "{:.2f}".format(x)
        divider = make_axes_locatable(ax)
        cax = divider.append_axes("right", size="5%", pad=0.05)
        plt.colorbar(cnt, cax=cax, ticks=[np.min(power), 0, np.max(power)],
                     format=ticker.FuncFormatter(format_axes_label))

    return ax
Example #17
Source File: plotting.py    From bindsnet with GNU Affero General Public License v3.0 5 votes vote down vote up
def plot_weights(
    weights: torch.Tensor,
    wmin: Optional[float] = 0,
    wmax: Optional[float] = 1,
    im: Optional[AxesImage] = None,
    figsize: Tuple[int, int] = (5, 5),
    cmap: str = "hot_r",
) -> AxesImage:
    # language=rst
    """
    Plot a connection weight matrix.

    On the first call (``im`` falsy) a new figure is created with the image,
    no ticks, auto aspect and a colorbar on the right. On later calls only the
    image data of ``im`` is replaced.

    :param weights: Weight matrix of ``Connection`` object.
    :param wmin: Minimum allowed weight value.
    :param wmax: Maximum allowed weight value.
    :param im: Used for re-drawing the weights plot.
    :param figsize: Horizontal, vertical figure size in inches.
    :param cmap: Matplotlib colormap.
    :return: ``AxesImage`` for re-drawing the weights plot.
    """
    weights_np = weights.detach().clone().cpu().numpy()
    if im:
        # Redraw path: reuse the existing image, swap in the new values.
        im.set_data(weights_np)
        return im

    fig, ax = plt.subplots(figsize=figsize)
    im = ax.imshow(weights_np, cmap=cmap, vmin=wmin, vmax=wmax)
    cax = make_axes_locatable(ax).append_axes("right", size="5%", pad=0.05)

    ax.set_xticks(())
    ax.set_yticks(())
    ax.set_aspect("auto")

    plt.colorbar(im, cax=cax)
    fig.tight_layout()
    return im
Example #18
Source File: util.py    From satimage with MIT License 5 votes vote down vote up
def nice_imshow(ax, data, vmin=None, vmax=None, cmap=None):
    """Show ``data`` on ``ax`` with nearest interpolation and a right-hand colorbar.

    Missing ``cmap`` defaults to ``cm.jet``. Missing ``vmin``/``vmax`` default
    to ``data.min()``/``data.max()``. Nothing is returned.
    """
    cmap = cm.jet if cmap is None else cmap
    vmin = data.min() if vmin is None else vmin
    vmax = data.max() if vmax is None else vmax
    cax = make_axes_locatable(ax).append_axes("right", size="5%", pad=0.05)
    image = ax.imshow(data, vmin=vmin, vmax=vmax, interpolation='nearest', cmap=cmap)
    pl.colorbar(image, cax=cax)
Example #19
Source File: picketfence.py    From pylinac with MIT License 5 votes vote down vote up
def _add_leaf_error_subplot(self, ax: plt.Axes):
        """Add a bar subplot showing the leaf error.

        The bar axes is appended to ``ax`` through its axes divider:

        * on the right, sharing y, for ``UP_DOWN`` orientation;
        * otherwise at the bottom, sharing x.

        Bars show the values from ``self.pickets.error_hist()`` with their
        error bars. A red line marks ``settings.tolerance`` and, when set, a
        yellow line marks ``settings.action_tolerance``.
        """
        tol_line_height = [self.settings.tolerance, self.settings.tolerance]
        tol_line_width = [0, max(self.image.shape)]
        action_tolerance = self.settings.action_tolerance

        # make the new axis
        divider = make_axes_locatable(ax)
        if self.settings.orientation == UP_DOWN:
            axtop = divider.append_axes('right', 2, pad=1, sharey=ax)
        else:
            axtop = divider.append_axes('bottom', 2, pad=1, sharex=ax)

        # get leaf positions, errors, standard deviation, and leaf numbers
        pos, vals, err, leaf_nums = self.pickets.error_hist()

        # plot the leaf errors as a bar plot
        if self.settings.orientation == UP_DOWN:
            axtop.barh(pos, vals, xerr=err, height=self.pickets[0].sample_width * 2, alpha=0.4, align='center')
            # plot the tolerance line(s)
            # TODO: replace .plot() calls with .axhline when mpld3 fixes functionality
            axtop.plot(tol_line_height, tol_line_width, 'r-', linewidth=3)
            if action_tolerance is not None:
                # Draw at the action tolerance; previously this reused the
                # tolerance position and was hidden on top of the red line.
                axtop.plot([action_tolerance, action_tolerance], tol_line_width, 'y-', linewidth=3)

            # reset xlims to comfortably include the max error or tolerance value
            axtop.set_xlim([0, max(max(vals), self.settings.tolerance) + 0.1])
        else:
            axtop.bar(pos, vals, yerr=err, width=self.pickets[0].sample_width * 2, alpha=0.4, align='center')
            axtop.plot(tol_line_width, tol_line_height,
                       'r-', linewidth=3)
            if action_tolerance is not None:
                axtop.plot(tol_line_width, [action_tolerance, action_tolerance], 'y-', linewidth=3)
            axtop.set_ylim([0, max(max(vals), self.settings.tolerance) + 0.1])

        # add formatting to axis
        axtop.grid(True)
        axtop.set_title("Average Error (mm)")
Example #20
Source File: model_band_plot.py    From lenstronomy with MIT License 5 votes vote down vote up
def convergence_plot(self, ax, text='Convergence', v_min=None, v_max=None,
                         font_size=15, colorbar_label=r'$\log_{10}\ \kappa$',
                         **kwargs):
        """
        Plot log10 of the lens convergence on ``ax`` with a colorbar.

        :param ax: matplotlib axes to draw on
        :param text: description label drawn on the panel
        :param v_min: lower colour limit passed to ``matshow``
        :param v_max: upper colour limit passed to ``matshow``
        :param font_size: font size for scale bar, arrows, label and colorbar label
        :param colorbar_label: colorbar label text
        :param kwargs: ``cmap`` (defaults to ``self._cmap``) and ``no_arrow``
            (a truthy value suppresses the coordinate arrows and the label)
        :return: ``ax``
        """
        kwargs.setdefault('cmap', self._cmap)

        kappa = self._lensModel.kappa(self._x_grid, self._y_grid, self._kwargs_lens_partial)
        kappa_image = util.array2image(kappa)
        frame = self._frame_size
        im = ax.matshow(np.log10(kappa_image), origin='lower',
                        extent=[0, frame, 0, frame],
                        cmap=kwargs['cmap'], vmin=v_min, vmax=v_max)
        for axis in (ax.get_xaxis(), ax.get_yaxis()):
            axis.set_visible(False)
        ax.autoscale(False)
        plot_util.scale_bar(ax, frame, dist=1, text='1"', color='w', font_size=font_size)

        show_arrows = not kwargs.get('no_arrow', False)
        if show_arrows:
            plot_util.coordinate_arrows(ax, frame, self._coords, color='w',
                                        arrow_size=self._arrow_size, font_size=font_size)
            # NOTE(review): unlike magnification_plot, the text label is only
            # drawn together with the arrows — confirm this is intended.
            plot_util.text_description(ax, frame, text=text,
                                       color="w", backgroundcolor='k', flipped=False,
                                       font_size=font_size)

        cax = make_axes_locatable(ax).append_axes("right", size="5%", pad=0.05)
        colorbar = plt.colorbar(im, cax=cax)
        colorbar.set_label(colorbar_label, fontsize=font_size)
        return ax
Example #21
Source File: Plot.py    From Wave-U-Net with MIT License 5 votes vote down vote up
def draw_spectrogram(example_wav="musb_005_angela thomas wade_audio_model_without_context_cut_28234samples_61002samples_93770samples_126538.wav"):
    """
    Plot the dB spectrogram of ``example_wav`` and save it as ``spectrogram_example.pdf``.

    The STFT uses n_fft=512, hop_length=256 and win_length=512. Short red marks
    are drawn at the top and bottom of the image at sample positions 28234,
    61002, 93770 and 126538 (presumably the cut points named in the default
    filename), converted to frames with the 256-sample hop. The x axis is
    labelled in seconds and the y axis in kHz (ticks every 2 kHz).
    """
    y, sr = Utils.load(example_wav, sr=None)
    # librosa >= 0.10 makes every stft argument after ``y`` keyword-only;
    # keywords also work on older releases.
    spec = np.abs(librosa.stft(y, n_fft=512, hop_length=256, win_length=512))
    norm_spec = librosa.power_to_db(spec**2)
    black_time_frames = np.array([28234, 61002, 93770, 126538]) / 256.0

    fig, ax = plt.subplots()
    img = ax.imshow(norm_spec)
    plt.vlines(black_time_frames, [0, 0, 0, 0], [10, 10, 10, 10], colors="red", lw=2, alpha=0.5)
    plt.vlines(black_time_frames, [256, 256, 256, 256], [246, 246, 246, 246], colors="red", lw=2, alpha=0.5)

    divider = make_axes_locatable(ax)
    cax = divider.append_axes("right", size="5%", pad=0.1)
    plt.colorbar(img, cax=cax)

    ax.xaxis.set_label_position("bottom")
    # One x tick per second: sr / 256 frames per second.
    ax.xaxis.set_major_locator(ticker.FixedLocator(([i * sr / 256. for i in range(len(y)//sr + 1)])))
    ax.xaxis.set_major_formatter(ticker.FixedFormatter(([str(i) for i in range(len(y)//sr + 1)])))

    # One y tick every 2 kHz: frequency f sits at bin f / (sr / 2) * 256.
    ax.yaxis.set_major_locator(ticker.FixedLocator(([float(i) * 2000.0 / (sr/2.0) * 256. for i in range(6)])))
    ax.yaxis.set_major_formatter(ticker.FixedFormatter([str(i*2) for i in range(6)]))

    ax.set_xlabel("t (s)")
    ax.set_ylabel('f (KHz)')

    fig.set_size_inches(7., 3.)
    fig.savefig("spectrogram_example.pdf", bbox_inches='tight')
Example #22
Source File: plotting.py    From recurrent-slds with MIT License 5 votes vote down vote up
def plot_z_samples(K, zs, zref=None,
                   plt_slice=None,
                   N_iters=None,
                   title=None,
                   ax=None):
    """
    Show sampled discrete-state sequences as an image, one row per iteration.

    :param K: number of discrete states; colours come from ``colors[:K]``
        and the colour range is [0, K - 1].
    :param zs: sequence of state sequences, stacked with ``np.array``; axis 0
        is the iteration and axis 1 is time.
    :param zref: optional reference sequence drawn in a strip below the main axes.
    :param plt_slice: (start, stop) time window; defaults to the full range.
        Any 2-element sequence is accepted.
    :param N_iters: y value at the bottom of the image extent; defaults to
        ``zs.shape[0]``.
    :param title: optional axes title.
    :param ax: axes to draw into; a new 10x5 figure is created when None.
    """
    if ax is None:
        fig = plt.figure(figsize=(10, 5))
        ax = fig.add_subplot(111)

    zs = np.array(zs)
    # Normalise to a tuple: it is concatenated into ``extent`` below, which
    # previously failed with a TypeError when a list was passed.
    plt_slice = (0, zs.shape[1]) if plt_slice is None else tuple(plt_slice)
    if N_iters is None:
        N_iters = zs.shape[0]

    ax.imshow(zs[:, slice(*plt_slice)], aspect='auto', vmin=0, vmax=K - 1,
              cmap=gradient_cmap(colors[:K]), interpolation="nearest",
              extent=plt_slice + (N_iters, 0))

    ax.set_xticks([])
    ax.set_ylabel("Iteration")

    if zref is not None:
        divider = make_axes_locatable(ax)
        ax2 = divider.append_axes("bottom", size="10%", pad=0.05)

        zref = np.atleast_2d(zref)
        # NOTE(review): no extent is given here, so ax2's x coordinates are
        # column indices rather than the plt_slice time range — confirm.
        ax2.imshow(zref[:, slice(*plt_slice)], aspect='auto', vmin=0, vmax=K - 1,
                   cmap=gradient_cmap(colors[:K]), interpolation="nearest")
        ax2.set_yticks([])
        ax2.set_ylabel("True $z$", rotation=0)
        ax2.yaxis.set_label_coords(-.15, -.5)
        ax2.set_xlabel("Time")

    if title is not None:
        ax.set_title(title)
Example #23
Source File: plotting.py    From recurrent-slds with MIT License 5 votes vote down vote up
def plot_separate_trans_probs(reg, xlim=(-4, 4), ylim=(-3, 3), n_pts=100, ax=None):
    """
    Plot each of ``reg``'s output probabilities over a 2-D grid, one subplot per output.

    An ``n_pts`` x ``n_pts`` grid over ``xlim`` x ``ylim`` is fed to
    ``reg.pi`` as the last two input features; any leading features are zero.
    Each subplot gets its own colorbar on the range [0, 1].

    :param reg: object exposing ``D_out``, ``D_in`` and ``pi(inputs)``.
    :param xlim: (min, max) of the x grid; any 2-element sequence.
    :param ylim: (min, max) of the y grid; any 2-element sequence.
    :param n_pts: grid resolution along each axis.
    :param ax: when given, the subplots are added to ``ax.figure`` (``ax``
        itself is not drawn into). Otherwise a new 12x3 figure is created.
    :return: the last subplot created (``ax`` unchanged if ``reg.D_out`` is 0).
    """
    K = reg.D_out
    # Tuples so ``xlim + tuple(...)`` below also works for list inputs.
    xlim, ylim = tuple(xlim), tuple(ylim)
    XX, YY = np.meshgrid(np.linspace(*xlim, n_pts),
                         np.linspace(*ylim, n_pts))
    XY = np.column_stack((np.ravel(XX), np.ravel(YY)))

    D_reg = reg.D_in
    inputs = np.hstack((np.zeros((n_pts ** 2, D_reg - 2)), XY))
    test_prs = reg.pi(inputs)

    # Previously ``fig`` was undefined (NameError) whenever ``ax`` was passed.
    fig = plt.figure(figsize=(12, 3)) if ax is None else ax.figure

    for k in range(K):
        ax = fig.add_subplot(1, K, k + 1)
        cmap = gradient_cmap([np.ones(3), colors[k % len(colors)]])
        im1 = ax.imshow(test_prs[:, k].reshape(*XX.shape),
                        extent=xlim + tuple(reversed(ylim)),
                        vmin=0, vmax=1, cmap=cmap)

        ax.set_xlim(xlim)
        ax.set_ylim(ylim)

        divider = make_axes_locatable(ax)
        cax = divider.append_axes("right", size="5%", pad=0.05)
        plt.colorbar(im1, cax=cax)

    # Lay out the figure we drew on, which need not be pyplot's current one.
    fig.tight_layout()
    return ax
Example #24
Source File: model_band_plot.py    From lenstronomy with MIT License 5 votes vote down vote up
def subtract_from_data_plot(self, ax, text='Subtracted', v_min=None,
                                v_max=None, point_source_add=False,
                                source_add=False, lens_light_add=False,
                                font_size=15
                                ):
        """
        Plot log10 of (data - selected model components) on ``ax`` with a colorbar.

        :param ax: matplotlib axes to draw on
        :param text: description label drawn on the panel
        :param v_min: lower colour limit; defaults to ``self._v_min_default``
        :param v_max: upper colour limit; defaults to ``self._v_max_default``
        :param point_source_add: include point sources in the subtracted model
        :param source_add: include the source light in the subtracted model
        :param lens_light_add: include the lens light in the subtracted model
        :param font_size: font size for scale bar, label, arrows and colorbar label
        :return: ``ax``
        """
        model = self.bandmodel.image(
            self._kwargs_lens_partial,
            self._kwargs_source_partial,
            self._kwargs_lens_light_partial,
            self._kwargs_ps_partial,
            unconvolved=False,
            source_add=source_add,
            lens_light_add=lens_light_add,
            point_source_add=point_source_add,
        )
        lower = self._v_min_default if v_min is None else v_min
        upper = self._v_max_default if v_max is None else v_max
        frame = self._frame_size
        residual = self._data - model
        im = ax.matshow(np.log10(residual), origin='lower', vmin=lower, vmax=upper,
                        extent=[0, frame, 0, frame], cmap=self._cmap)
        for axis in (ax.get_xaxis(), ax.get_yaxis()):
            axis.set_visible(False)
        ax.autoscale(False)
        plot_util.scale_bar(ax, frame, dist=1, text='1"', font_size=font_size)
        plot_util.text_description(ax, frame, text=text, color="w",
                                   backgroundcolor='k', font_size=font_size)
        plot_util.coordinate_arrows(ax, frame, self._coords,
                                    arrow_size=self._arrow_size, font_size=font_size)
        cax = make_axes_locatable(ax).append_axes("right", size="5%", pad=0.05)
        colorbar = plt.colorbar(im, cax=cax)
        colorbar.set_label(r'log$_{10}$ flux', fontsize=font_size)
        return ax
Example #25
Source File: model_band_plot.py    From lenstronomy with MIT License 5 votes vote down vote up
def decomposition_plot(self, ax, text='Reconstructed', v_min=None, v_max=None,
                           unconvolved=False, point_source_add=False,
                           font_size=15,
                           source_add=False, lens_light_add=False, **kwargs):
        """
        Plot log10 of the selected model components on ``ax`` with a colorbar.

        :param ax: matplotlib axes to draw on
        :param text: description label drawn on the panel
        :param v_min: lower colour limit; defaults to ``self._v_min_default``
        :param v_max: upper colour limit; defaults to ``self._v_max_default``
        :param unconvolved: passed to ``self.bandmodel.image``
        :param point_source_add: include point sources in the model image
        :param font_size: font size for scale bar, label, arrows and colorbar label
        :param source_add: include the source light in the model image
        :param lens_light_add: include the lens light in the model image
        :param kwargs: kwargs to send matplotlib.pyplot.matshow()
            (``cmap`` defaults to ``self._cmap``)
        :return: ``ax``
        """
        model = self.bandmodel.image(self._kwargs_lens_partial, self._kwargs_source_partial, self._kwargs_lens_light_partial,
                                          self._kwargs_ps_partial, unconvolved=unconvolved, source_add=source_add,
                                          lens_light_add=lens_light_add, point_source_add=point_source_add)
        if v_min is None:
            v_min = self._v_min_default
        if v_max is None:
            v_max = self._v_max_default
        kwargs.setdefault('cmap', self._cmap)
        im = ax.matshow(np.log10(model), origin='lower', vmin=v_min, vmax=v_max,
                        extent=[0, self._frame_size, 0, self._frame_size], **kwargs)
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
        ax.autoscale(False)
        plot_util.scale_bar(ax, self._frame_size, dist=1, text='1"', font_size=font_size)
        # Forward font_size like the sibling plots do; it was previously
        # dropped here, so the label ignored the caller's font size.
        plot_util.text_description(ax, self._frame_size, text=text, color="w", backgroundcolor='k',
                                   font_size=font_size)
        plot_util.coordinate_arrows(ax, self._frame_size, self._coords,
                                    arrow_size=self._arrow_size, font_size=font_size)
        divider = make_axes_locatable(ax)
        cax = divider.append_axes("right", size="5%", pad=0.05)
        cb = plt.colorbar(im, cax=cax)
        cb.set_label(r'log$_{10}$ flux', fontsize=font_size)
        return ax
Example #26
Source File: model_band_plot.py    From lenstronomy with MIT License 5 votes vote down vote up
def deflection_plot(self, ax, v_min=None, v_max=None, axis=0,
                        with_caustics=False, image_name_list=None,
                        text="Deflection model", font_size=15,
                        colorbar_label=r'arcsec'):
        """
        Plot one component of the lens deflection field on ``ax``.

        :param ax: matplotlib axes to draw on
        :param v_min: lower colour limit passed to ``matshow``
        :param v_max: upper colour limit passed to ``matshow``
        :param axis: 0 selects the first component returned by
            ``self._lensModel.alpha``; any other value selects the second
        :param with_caustics: when exactly ``True``, overlay the caustics
            (blue) and critical curves (red)
        :param image_name_list: labels passed to ``plot_util.image_position_plot``
        :param text: description label drawn on the panel
        :param font_size: font size for scale bar, arrows, label and colorbar label
        :param colorbar_label: colorbar label text
        :return: ``ax``
        """
        alpha_x, alpha_y = self._lensModel.alpha(self._x_grid, self._y_grid, self._kwargs_lens_partial)
        image_x = util.array2image(alpha_x)
        image_y = util.array2image(alpha_y)
        deflection = image_x if axis == 0 else image_y

        frame = self._frame_size
        im = ax.matshow(deflection, origin='lower', extent=[0, frame, 0, frame],
                        vmin=v_min, vmax=v_max, cmap=self._cmap, alpha=0.5)
        for ax_axis in (ax.get_xaxis(), ax.get_yaxis()):
            ax_axis.set_visible(False)
        ax.autoscale(False)
        plot_util.scale_bar(ax, frame, dist=1, text='1"', color='k', font_size=font_size)
        plot_util.coordinate_arrows(ax, frame, self._coords, color='k',
                                    arrow_size=self._arrow_size, font_size=font_size)
        plot_util.text_description(ax, frame, text=text, color="k",
                                   backgroundcolor='w', font_size=font_size)

        cax = make_axes_locatable(ax).append_axes("right", size="5%", pad=0.05)
        colorbar = plt.colorbar(im, cax=cax)
        colorbar.set_label(colorbar_label, fontsize=font_size)

        if with_caustics is True:
            ra_crit_list, dec_crit_list = self._critical_curves()
            ra_caustic_list, dec_caustic_list = self._caustics()
            plot_util.plot_line_set(ax, self._coords, ra_caustic_list, dec_caustic_list, color='b')
            plot_util.plot_line_set(ax, self._coords, ra_crit_list, dec_crit_list, color='r')

        ra_image, dec_image = self.bandmodel.PointSource.image_position(self._kwargs_ps_partial, self._kwargs_lens_partial)
        plot_util.image_position_plot(ax, self._coords, ra_image, dec_image, image_name_list=image_name_list)
        return ax
Example #27
Source File: model_band_plot.py    From lenstronomy with MIT License 5 votes vote down vote up
def magnification_plot(self, ax, v_min=-10, v_max=10,
                           image_name_list=None, font_size=15, no_arrow=False,
                           text="Magnification model",
                           colorbar_label=r"$\det\ (\mathsf{A}^{-1})$",
                           **kwargs):
        """
        Plot the lens-model magnification map over the frame, with a colorbar and the
        point-source image positions overlaid.

        The magnification is evaluated on the grid (self._x_grid, self._y_grid) for
        self._kwargs_lens_partial and reshaped with util.array2image before plotting.

        :param ax: matplotlib axis instance to draw into
        :param v_min: lower limit of the color scale (passed to matshow as ``vmin``)
        :param v_max: upper limit of the color scale (passed to matshow as ``vmax``)
        :param image_name_list: labels for the image positions, forwarded unchanged to
            plot_util.image_position_plot
        :param font_size: font size used for the scale bar, coordinate arrows,
            description text and colorbar label
        :param no_arrow: if True, the coordinate arrows are not drawn
        :param text: description text drawn on the panel
        :param colorbar_label: label of the colorbar
        :param kwargs: kwargs to send to matplotlib.pyplot.matshow(); ``cmap`` defaults
            to ``self._cmap`` and ``alpha`` to 0.5 when not given
        :return: the axis ``ax``
        """
        # Only fill in defaults the caller did not supply; self._cmap is read lazily
        # so it is not touched when the caller passes an explicit cmap.
        if 'cmap' not in kwargs:
            kwargs['cmap'] = self._cmap
        if 'alpha' not in kwargs:
            kwargs['alpha'] = 0.5
        mag_result = util.array2image(self._lensModel.magnification(self._x_grid, self._y_grid, self._kwargs_lens_partial))
        im = ax.matshow(mag_result, origin='lower', extent=[0, self._frame_size, 0, self._frame_size],
                        vmin=v_min, vmax=v_max, **kwargs)
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
        # Freeze the limits so the overlays below do not rescale the image.
        ax.autoscale(False)
        plot_util.scale_bar(ax, self._frame_size, dist=1, text='1"', color='k', font_size=font_size)
        if not no_arrow:
            plot_util.coordinate_arrows(ax, self._frame_size, self._coords, color='k', arrow_size=self._arrow_size,
                                        font_size=font_size)
        plot_util.text_description(ax, self._frame_size, text=text, color="k",
                                   backgroundcolor='w', font_size=font_size)
        # Attach a narrow colorbar axis to the right of the image axis.
        divider = make_axes_locatable(ax)
        cax = divider.append_axes("right", size="5%", pad=0.05)
        cb = plt.colorbar(im, cax=cax)
        cb.set_label(colorbar_label, fontsize=font_size)
        ra_image, dec_image = self.bandmodel.PointSource.image_position(self._kwargs_ps_partial, self._kwargs_lens_partial)
        plot_util.image_position_plot(ax, self._coords, ra_image, dec_image, color='k', image_name_list=image_name_list)
        return ax
Example #28
Source File: model_band_plot.py    From lenstronomy with MIT License 5 votes vote down vote up
def absolute_residual_plot(self, ax, v_min=-1, v_max=1, font_size=15,
                               text="Residuals",
                               colorbar_label=r'(f$_{model}$-f$_{data}$)',
                               no_arrow=False, **kwargs):
        """
        Plot the absolute residuals ``self._model - self._data`` with a colorbar.

        :param ax: matplotlib axis instance to draw into
        :param v_min: lower limit of the color scale (passed to matshow as ``vmin``)
        :param v_max: upper limit of the color scale (passed to matshow as ``vmax``)
        :param font_size: font size used for the scale bar, description text,
            coordinate arrows and colorbar label
        :param text: description text drawn on the panel
        :param colorbar_label: label of the colorbar
        :param no_arrow: if True, the coordinate arrows are not drawn (default False
            keeps the previous behaviour)
        :param kwargs: additional kwargs to send to matplotlib.pyplot.matshow();
            ``cmap`` defaults to 'bwr' when not given
        :return: the axis ``ax``
        """
        if 'cmap' not in kwargs:
            kwargs['cmap'] = 'bwr'
        im = ax.matshow(self._model - self._data, vmin=v_min, vmax=v_max,
                        extent=[0, self._frame_size, 0, self._frame_size], origin='lower',
                        **kwargs)
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
        # Freeze the limits so the overlays below do not rescale the image.
        ax.autoscale(False)
        plot_util.scale_bar(ax, self._frame_size, dist=1, text='1"', color='k',
                            font_size=font_size)
        plot_util.text_description(ax, self._frame_size, text=text, color="k",
                                   backgroundcolor='w', font_size=font_size)
        if not no_arrow:
            plot_util.coordinate_arrows(ax, self._frame_size, self._coords,
                                        font_size=font_size,
                                        color='k',
                                        arrow_size=self._arrow_size)
        # Attach a narrow colorbar axis to the right of the image axis.
        divider = make_axes_locatable(ax)
        cax = divider.append_axes("right", size="5%", pad=0.05)
        cb = plt.colorbar(im, cax=cax)
        cb.set_label(colorbar_label, fontsize=font_size)
        return ax
Example #29
Source File: model_band_plot.py    From lenstronomy with MIT License 5 votes vote down vote up
def normalized_residual_plot(self, ax, v_min=-6, v_max=6, font_size=15, text="Normalized Residuals",
                                 colorbar_label=r'(f${}_{\rm model}$ - f${}_{\rm data}$)/$\sigma$',
                                 no_arrow=False, **kwargs):
        """
        Plot the normalized residuals stored in ``self._norm_residuals`` with a colorbar.

        :param ax: matplotlib axis instance to draw into
        :param v_min: lower limit of the color scale (passed to matshow as ``vmin``)
        :param v_max: upper limit of the color scale (passed to matshow as ``vmax``)
        :param font_size: font size used for the scale bar, description text,
            coordinate arrows and colorbar label
        :param text: description text drawn on the panel
        :param colorbar_label: label of the colorbar
        :param no_arrow: if True, the coordinate arrows are not drawn
        :param kwargs: kwargs to send to matplotlib.pyplot.matshow(); ``cmap`` defaults
            to 'bwr' when not given
        :return: the axis ``ax``
        """
        if 'cmap' not in kwargs:
            kwargs['cmap'] = 'bwr'
        im = ax.matshow(self._norm_residuals, vmin=v_min, vmax=v_max,
                        extent=[0, self._frame_size, 0, self._frame_size], origin='lower',
                        **kwargs)
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
        # Freeze the limits so the overlays below do not rescale the image.
        ax.autoscale(False)
        plot_util.scale_bar(ax, self._frame_size, dist=1, text='1"', color='k',
                            font_size=font_size)
        plot_util.text_description(ax, self._frame_size, text=text, color="k",
                                   backgroundcolor='w', font_size=font_size)
        if not no_arrow:
            plot_util.coordinate_arrows(ax, self._frame_size, self._coords, color='w',
                                        arrow_size=self._arrow_size, font_size=font_size)
        # Attach a narrow colorbar axis to the right of the image axis.
        divider = make_axes_locatable(ax)
        cax = divider.append_axes("right", size="5%", pad=0.05)
        cb = plt.colorbar(im, cax=cax)
        cb.set_label(colorbar_label,
                     fontsize=font_size)
        return ax
Example #30
Source File: network.py    From psst with MIT License 5 votes vote down vote up
def plot_line_power(obj, results, hour, ax=None):
    '''
    Draw the network with branches styled from ``results`` at ``hour``, plus a
    colorbar labelled 'Loading Factor'.

    :param obj: case or network; resolved to a ``(case, network)`` pair by
        ``_return_case_network``
    :param results: result object forwarded to ``_generate_edges`` (presumably the
        solved power-flow results -- confirm against that helper)
    :param hour: time index forwarded to ``_generate_edges``
    :param ax: matplotlib axis to draw into; if None, a new 16x10 figure is
        created and its axis frame is turned off
    :return: the axis that was drawn into
    '''

    if ax is None:
        # The figure handle is not needed; only the newly created axis is used.
        _, ax = plt.subplots(1, 1, figsize=(16, 10))
        ax.axis('off')

    case, network = _return_case_network(obj)

    network.draw_buses(ax=ax)
    network.draw_loads(ax=ax)
    network.draw_generators(ax=ax)
    network.draw_connections('gen_to_bus', ax=ax)
    network.draw_connections('load_to_bus', ax=ax)

    edgelist, edge_color, edge_width, edge_labels = _generate_edges(results, case, hour)
    # The object returned by draw_branches is used as the colorbar's mappable.
    branches = network.draw_branches(ax=ax, edgelist=edgelist, edge_color=edge_color, width=edge_width, edge_labels=edge_labels)

    # Attach a narrow colorbar axis to the right, with its ticks and label on its left side.
    divider = make_axes_locatable(ax)
    cax = divider.append_axes('right', size='5%', pad=0.05)
    cb = plt.colorbar(branches, cax=cax, orientation='vertical')
    cax.yaxis.set_label_position('left')
    cax.yaxis.set_ticks_position('left')
    cb.set_label('Loading Factor')

    return ax