Python matplotlib.patches.ConnectionPatch() Examples

The following are 3 code examples of matplotlib.patches.ConnectionPatch(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module matplotlib.patches, or try the search function.
Example #1
Source File: sift.py · From SoTu (MIT License) · 6 votes
def draw(self, img_q, img_t, pt_qt, save_path=None):
        """Draw two images side by side with lines joining matched points.

        ``img_q`` is shown in the left axes and ``img_t`` in the right one.
        Each ``(pt_q, pt_t)`` pair in ``pt_qt`` is marked with a red ``x`` in
        both images and joined by a thin green line.

        Args:
            img_q: left-hand image, anything ``Axes.imshow`` accepts.
            img_t: right-hand image, anything ``Axes.imshow`` accepts.
            pt_qt: iterable of ``(pt_q, pt_t)`` pairs; each point is read as
                ``pt[0]`` (x) and ``pt[1]`` (y) in data coordinates.
            save_path: optional output file. When given, the figure is written
                there with ``Figure.savefig`` and then closed so pyplot does
                not keep it alive. When omitted, ``plt.show()`` is called as
                before; note that with the Agg backend forced below, that
                call displays nothing.

        Returns:
            The ``matplotlib.figure.Figure`` that was drawn, so callers can
            save or inspect it themselves.
        """
        import matplotlib
        # Force a non-interactive backend so drawing works without a display.
        matplotlib.use('Agg')
        from matplotlib import pyplot as plt
        from matplotlib.patches import ConnectionPatch

        fig, (ax_q, ax_t) = plt.subplots(1, 2, figsize=(8, 4))
        for pt_q, pt_t in pt_qt:
            # The line spans both axes: xyA is in ax_t, xyB is in ax_q.
            con = ConnectionPatch(pt_t, pt_q,
                                  coordsA='data', coordsB='data',
                                  axesA=ax_t, axesB=ax_q,
                                  color='g', linewidth=0.5)
            ax_t.add_artist(con)
            ax_q.plot(pt_q[0], pt_q[1], 'rx')
            ax_t.plot(pt_t[0], pt_t[1], 'rx')
        ax_q.imshow(img_q)
        ax_t.imshow(img_t)
        ax_q.axis('off')
        ax_t.axis('off')
        plt.subplots_adjust(wspace=0, hspace=0)
        if save_path is None:
            # Original behaviour: a no-op under Agg, but kept for callers that
            # switch backends or read the figure back via plt.gcf().
            plt.show()
        else:
            fig.savefig(save_path)
            # pyplot keeps every open figure referenced until it is closed.
            plt.close(fig)
        return fig
Example #2
Source File: test_patches.py · From twitter-stock-recommendation (MIT License) · 5 votes
def test_connection_patch():
    """Connect a data point in the right-hand axes to one in the left.

    An arrow runs from (0.1, 0.1) in data coordinates of the second axes to
    (0.9, 0.9) in data coordinates of the first axes. The snippet contains no
    assertion of its own; presumably the original file checks the rendered
    output via a decorator that is not shown here.
    """
    _, axes = plt.subplots(1, 2)
    left_ax, right_ax = axes
    arrow_kwargs = dict(coordsA='data', coordsB='data',
                        axesA=right_ax, axesB=left_ax,
                        arrowstyle="->")
    arrow = mpatches.ConnectionPatch(xyA=(0.1, 0.1), xyB=(0.9, 0.9),
                                     **arrow_kwargs)
    right_ax.add_artist(arrow)
Example #3
Source File: vis_utils.py · From VisualizingNDF (MIT License) · 4 votes
def get_path_saliency(samples, labels, paths, pred, model, tree_idx, name, orientation = 'horizontal'):
    """Show each input sample next to saliency maps along its computational path.

    One row is drawn per sample: the input image (after
    ``revert_preprocessing``) followed by one saliency map per node on that
    sample's path. Arrows link consecutive panels within the row.

    Args:
        samples: indexable collection of inputs; each element must support
            ``unsqueeze(dim=0)`` (presumably torch tensors shaped (C, H, W) --
            TODO confirm against callers).
        labels: ground-truth values; ``labels[i] * 100`` appears in the title.
        paths: ``paths[i]`` is a sequence whose items are indexed as
            ``[0]`` = node id and ``[1]`` = probability of reaching that node.
            Every path is assumed to have the length of ``paths[0]``.
        pred: predictions; ``pred[i].data.item() * 100`` appears in the title.
        model, tree_idx, name: forwarded to ``get_map``; ``name`` is also
            passed to ``revert_preprocessing``.
        orientation: layout of the panels; only ``'horizontal'`` is
            implemented.

    Raises:
        NotImplementedError: if ``orientation`` is not ``'horizontal'``.
    """
    # Reject unsupported layouts before creating a figure or computing maps.
    if orientation != 'horizontal':
        raise NotImplementedError(
            'orientation {!r} is not supported'.format(orientation))
    plt.figure(figsize=(20, 4))
    # NOTE: this mutates the global rcParams; the change outlives this call.
    plt.rcParams.update({'font.size': 18})
    num_samples = len(samples)
    path_length = len(paths[0])
    num_cols = path_length + 1  # input image + one panel per path node
    for sample_idx in range(num_samples):
        sample = samples[sample_idx]
        first_subplot = sample_idx * num_cols + 1  # 1-based subplot index
        # plot the sample itself in the first column of its row
        sample_ax = plt.subplot(num_samples, num_cols, first_subplot)
        sample_to_plot = revert_preprocessing(sample.unsqueeze(dim=0), name)
        sample_ax.imshow(sample_to_plot.squeeze().cpu().numpy().transpose((1, 2, 0)))
        sample_ax.axis('off')
        sample_ax.set_title('Pred:{:.2f}, GT:{:.0f}'.format(
            pred[sample_idx].data.item() * 100, labels[sample_idx] * 100))
        row_axes = [sample_ax]
        path = paths[sample_idx]
        for node_idx in range(path_length):
            # compute and plot saliency for each node
            node = path[node_idx][0]
            # probability of arriving at this node
            prob = path[node_idx][1]
            saliency_map = get_map(model, sample, node, tree_idx, name)
            node_ax = plt.subplot(num_samples, num_cols,
                                  first_subplot + node_idx + 1)
            node_ax.imshow(saliency_map, cmap='hot')
            node_ax.set_title('(N{:d}, P{:.2f})'.format(node, prob))
            node_ax.axis('off')
            row_axes.append(node_ax)
        # Link consecutive panels of this row only. Drawing arrows for the
        # whole grid on every iteration would duplicate them and create axes
        # for rows that have not been plotted yet.
        for left_ax, right_ax in zip(row_axes, row_axes[1:]):
            arrow = ConnectionPatch(xyA=[1.1, 0.5], xyB=[-0.1, 0.5],
                                    coordsA='axes fraction',
                                    coordsB='axes fraction',
                                    axesA=left_ax, axesB=right_ax,
                                    arrowstyle="fancy")
            left_ax.add_artist(arrow)
    # Subplot layout: margins are figure fractions; wspace/hspace are
    # fractions of the average axis width/height reserved between subplots.
    plt.subplots_adjust(left=0.02, bottom=0.01, right=1, top=0.90,
                        wspace=0.20, hspace=0.24)
    plt.show()