Python matplotlib.pyplot.title() Examples

The following are 30 code examples of `matplotlib.pyplot.title()`, each taken from an open-source project. The line above each example names the original source file, project, and license. You may also want to look at the other functions and classes available in the `matplotlib.pyplot` module.
Example #1
Source File: __init__.py    From EDeN with MIT License 11 votes vote down vote up
def plot_confusion_matrix(y_true, y_pred, size=None, normalize=False):
    """Draw a confusion-matrix heatmap on the current matplotlib figure.

    Args:
        y_true: ground-truth labels.
        y_pred: predicted labels.
        size: if given, a new ``size`` x ``size`` inch figure is created first;
            otherwise the current figure is used.
        normalize: if True, each row (true class) is divided by its total so
            cells show per-class fractions; rows with no samples show 0.
    """
    # confusion_matrix orders rows AND columns by the sorted union of labels
    # in y_true and y_pred. Use that same list for both tick axes so tick
    # labels line up with the matrix even if one side lacks some class.
    labels = sorted(set(y_true) | set(y_pred))
    cm = confusion_matrix(y_true, y_pred, labels=labels)
    fmt = "%d"
    if normalize:
        row_sums = cm.sum(axis=1)[:, np.newaxis]
        # Classes that never occur in y_true have a zero row sum; leave them
        # at 0 instead of producing NaN.
        cm = np.divide(cm.astype('float'), row_sums,
                       out=np.zeros(cm.shape, dtype=float),
                       where=row_sums != 0)
        fmt = "%.2f"
    if size is not None:
        plt.figure(figsize=(size, size))
    heatmap(cm, xlabel='Predicted label', ylabel='True label',
            xticklabels=labels, yticklabels=labels,
            cmap=plt.cm.Blues, fmt=fmt)
    if normalize:
        plt.title("Confusion matrix (norm.)")
    else:
        plt.title("Confusion matrix")
    plt.gca().invert_yaxis()
Example #2
Source File: data_augmentation.py    From Sound-Recognition-Tutorial with Apache License 2.0 10 votes vote down vote up
def demo_plot():
    """Plot an example clip next to its time-stretched and pitch-shifted versions.

    Loads a fixed ESC-10 dog recording at 44.1 kHz, applies two
    augmentations, and shows the three waveforms stacked vertically with
    identical axis ranges.
    """
    audio = './data/esc10/audio/Dog/1-30226-A.ogg'
    y, sr = librosa.load(audio, sr=44100)
    # n_steps controls how far the pitch is shifted.
    # sr is passed by keyword: librosa >= 0.10 makes it keyword-only.
    y_ps = librosa.effects.pitch_shift(y, sr=sr, n_steps=6)
    # rate controls how strongly the signal is stretched in time.
    y_ts = librosa.effects.time_stretch(y, rate=1.2)
    panels = [
        (y, 'Original waveform'),
        (y_ts, 'Time Stretch transformed waveform'),
        (y_ps, 'Pitch Shift transformed waveform'),
    ]
    for position, (signal, panel_title) in enumerate(panels, start=1):
        plt.subplot(3, 1, position)
        plt.plot(signal)
        plt.title(panel_title)
        # Same range on every panel so the waveforms are directly comparable
        # (use [88000, 94000, -0.4, 0.4] to zoom in on a short segment).
        plt.axis([0, 200000, -0.4, 0.4])
    plt.tight_layout()
    plt.show()
Example #3
Source File: util.py    From neural-fingerprinting with BSD 3-Clause "New" or "Revised" License 8 votes vote down vote up
def compute_roc(y_true, y_pred, plot=False):
    """Compute the ROC curve and its area, optionally displaying the curve.

    :param y_true: ground-truth labels
    :param y_pred: predicted scores, passed to ``roc_curve`` as the score input
    :param plot: when True, draw the curve in a new 7x6 figure and show it
    :return: tuple ``(fpr, tpr, auc_score)``
    """
    fpr, tpr, _thresholds = roc_curve(y_true, y_pred)
    auc_score = auc(fpr, tpr)
    if not plot:
        return fpr, tpr, auc_score

    plt.figure(figsize=(7, 6))
    curve_label = 'ROC (AUC = %0.4f)' % auc_score
    plt.plot(fpr, tpr, color='blue', label=curve_label)
    plt.legend(loc='lower right')
    plt.title("ROC Curve")
    plt.xlabel("FPR")
    plt.ylabel("TPR")
    plt.show()
    return fpr, tpr, auc_score
Example #4
Source File: feature_vis.py    From transferlearning with MIT License 8 votes vote down vote up
def plot_tsne(self, save_eps=False):
    """Plot a 2-D t-SNE embedding of ``self.features`` and show it.

    Each sample is drawn as the text of its label from ``self.labels``,
    coloured from the ``Set1`` colormap at ``label / 10``. Set
    ``save_eps=True`` to also save the figure as ``tsne.eps``.
    """
    tsne = TSNE(n_components=2, init='pca', random_state=0)
    features = tsne.fit_transform(self.features)
    # Min-max scale each embedding axis into [0, 1]. A zero span (all points
    # share a coordinate, e.g. a single sample) would otherwise divide by
    # zero and place every label at NaN.
    x_min, x_max = np.min(features, 0), np.max(features, 0)
    span = x_max - x_min
    span[span == 0] = 1
    data = (features - x_min) / span
    del features
    for i, (x, y) in enumerate(data):
        label = self.labels[i]
        plt.text(x, y, str(label),
                 color=plt.cm.Set1(label / 10.),
                 fontdict={'weight': 'bold', 'size': 9})
    plt.xticks([])
    plt.yticks([])
    plt.title('T-SNE')
    if save_eps:
        plt.savefig('tsne.eps', dpi=600, format='eps')
    plt.show()
Example #5
Source File: plotFigures.py    From fullrmc with GNU Affero General Public License v3.0 7 votes vote down vote up
def plot(PDF, figName, imgpath, show=False, save=True):
    """Plot the experimental PDF against the computed total and inter partials.

    :param PDF: constraint object providing ``get_constraint_value()``,
        ``experimentalDistances``, ``experimentalPDF``, ``shellsCenter`` and
        ``squaredDeviations``.
    :param figName: output path used when ``save`` is True.
    :param imgpath: not used by this function.
    :param show: display the figure interactively.
    :param save: write the figure to ``figName``.
    The figure is always closed at the end.
    """
    output = PDF.get_constraint_value()
    plt.plot(PDF.experimentalDistances, PDF.experimentalPDF, 'ro',
             label="experimental", markersize=7.5, markevery=1)
    plt.plot(PDF.shellsCenter, output["pdf"], 'k', linewidth=3.0,
             markevery=25, label="total")

    # One extra curve per "inter" partial, styled from STYLE in order.
    styleIndex = 0
    for key, val in output.items():
        if key in ("pdf_total", "pdf"):
            continue
        if "inter" in key:
            plt.plot(PDF.shellsCenter, val, STYLE[styleIndex], markevery=5,
                     label=key.split('rdf_inter_')[1])
            styleIndex += 1
    plt.legend(frameon=False, ncol=1)
    # set labels; raw strings keep the LaTeX backslashes without relying on
    # invalid escape sequences such as "\A" (SyntaxWarning on Python 3.12+).
    plt.title(r"$\chi^{2}=%.6f$" % PDF.squaredDeviations, size=20)
    plt.xlabel(r"$r (\AA)$", size=20)
    plt.ylabel("$g(r)$", size=20)
    if save:
        plt.savefig(figName)
    if show:
        plt.show()
    plt.close()
Example #6
Source File: __init__.py    From EDeN with MIT License 7 votes vote down vote up
def plot_roc_curve(y_true, y_score, size=None):
    """Draw the ROC curve of ``y_score`` versus ``y_true`` on the current axes.

    A dashed grey diagonal marks chance level and the title reports the ROC
    AUC. When ``size`` is given, a new square figure is created first and the
    axis aspect is set to equal.
    """
    fpr, tpr, _ = roc_curve(y_true, y_score)
    if size is not None:
        plt.figure(figsize=(size, size))
        plt.axis('equal')
    plt.plot(fpr, tpr, lw=2, color='navy')
    chance_line = (0, 1)
    plt.plot(chance_line, chance_line, color='gray', lw=1, linestyle='--')
    plt.xlabel('False positive rate')
    plt.ylabel('True positive rate')
    padded_unit = (-0.05, 1.05)
    plt.ylim(padded_unit)
    plt.xlim(padded_unit)
    plt.grid()
    auc_value = roc_auc_score(y_true, y_score)
    plt.title('Receiver operating characteristic AUC={0:0.2f}'.format(auc_value))
Example #7
Source File: utils.py    From pruning_yolov3 with GNU General Public License v3.0 7 votes vote down vote up
def plot_evolution_results(hyp):  # from utils.utils import *; plot_evolution_results(hyp)
    """Plot hyperparameter-evolution results from ``evolve.txt`` to ``evolve.png``.

    One subplot per key of ``hyp`` shows fitness against that hyperparameter's
    values, with the best single result highlighted; the best values are also
    printed. Hyperparameter ``i`` is read from column ``i + 5`` of evolve.txt
    (presumably the columns are written in ``hyp`` order; TODO confirm).
    Side effect: sets the global matplotlib font size to 8.
    """
    x = np.loadtxt('evolve.txt', ndmin=2)
    f = fitness(x)
    best = f.argmax()  # row index of the best single result
    n_cols = 5
    # Keep the original 4x5 grid, but grow it instead of failing in
    # plt.subplot when there are more than 20 hyperparameters.
    n_rows = max(4, (len(hyp) + n_cols - 1) // n_cols)
    fig = plt.figure(figsize=(12, 10))
    matplotlib.rc('font', **{'size': 8})
    for i, k in enumerate(hyp):
        y = x[:, i + 5]
        mu = y[best]  # value of this hyperparameter in the best result
        plt.subplot(n_rows, n_cols, i + 1)
        plt.plot(mu, f[best], 'o', markersize=10)
        plt.plot(y, f, '.')
        plt.title('%s = %.3g' % (k, mu), fontdict={'size': 9})
        print('%15s: %.3g' % (k, mu))
    fig.tight_layout()
    plt.savefig('evolve.png', dpi=200)
    # Release the figure; it was otherwise left open after saving.
    plt.close(fig)
Example #8
Source File: utils.py    From dc_tts with Apache License 2.0 6 votes vote down vote up
def plot_alignment(alignment, gs, dir=hp.logdir):
    """Plots the alignment and saves it as ``{dir}/alignment_{gs}.png``.

    Args:
      alignment: A numpy array with shape of (encoder_steps, decoder_steps)
      gs: (int) global step.
      dir: Output path; created (including missing parents) if needed.
        Note the default is read from ``hp.logdir`` once, at import time.
    """
    # makedirs(exist_ok=True) creates parent directories too and avoids the
    # check-then-create race of exists() + mkdir().
    os.makedirs(dir, exist_ok=True)

    fig, ax = plt.subplots()
    im = ax.imshow(alignment)

    fig.colorbar(im)
    ax.set_title('{} Steps'.format(gs))
    fig.savefig('{}/alignment_{}.png'.format(dir, gs), format='png')
    plt.close(fig)
Example #9
Source File: stress_gui.py    From fenics-topopt with MIT License 6 votes vote down vote up
def update(self, xPhys, u, title=None):
    """Redraw the density and stress images for a new design, then pause briefly.

    ``xPhys`` sets the grayscale density image and the alpha channel of
    the stress overlay. If ``title`` is given it becomes the plot title;
    otherwise the maximum stress is written as the x-axis label.
    """
    grid_shape = (self.nelx, self.nely)
    self.im.set_array(-xPhys.reshape(grid_shape).T)
    stress = self.stress_calculator.calculate_stress(xPhys, u, self.nu)
    peak = max(stress)
    self.myColorMap.set_norm(colors.Normalize(vmin=0, vmax=peak))
    rgba = self.myColorMap.to_rgba(stress)
    # The overlay's opacity follows the element densities.
    rgba[:, :, 3] = xPhys.reshape(-1, 1)
    self.stress_im.set_array(
        np.swapaxes(rgba.reshape(grid_shape + (4,)), 0, 1))
    canvas = self.fig.canvas
    canvas.draw()
    canvas.flush_events()
    if title is None:
        plt.xlabel("Max stress = {:.2f}".format(peak[0]))
    else:
        plt.title(title)
    plt.pause(0.01)
Example #10
Source File: malware.py    From trees with Apache License 2.0 6 votes vote down vote up
def classify(self, features, show=False):
    """Score every record in ``features`` against each tree in ``self.root``.

    Returns an array of shape ``(n_records, len(self.root))`` in which entry
    ``[r, t]`` is ``result[0] / sum(result.values())`` for the leaf that tree
    ``t`` assigns to record ``r`` (presumably the share of class-0 samples in
    that leaf; TODO confirm). If ``show`` is True, the matrix is also saved as
    a grayscale image to ``../scratch/tree_scores.png``.
    """
    recs, _ = features.shape
    scores = np.zeros((recs, len(self.root)))
    # Call form of print works on Python 2 and 3; the bare print statement
    # is a SyntaxError on Python 3.
    print(scores.shape)
    R = Record(np.arange(recs, dtype=int), features)

    for i, T in enumerate(self.root):
        # classify(T, R) is the tree-level helper, not this method.
        for idxs, result in classify(T, R):
            for idx in idxs.indexes():
                scores[idx, i] = float(result[0]) / sum(result.values())

    if show:
        # Discard whatever figure is current before drawing the matrix.
        plt.cla()
        plt.clf()
        plt.close()

        plt.imshow(scores, cmap=plt.cm.gray)
        plt.title('Scores matrix')
        plt.savefig(r"../scratch/tree_scores.png", bbox_inches='tight')
        # Release the figure so repeated calls don't accumulate open figures.
        plt.close()

    return scores
Example #11
Source File: utils.py    From kss with Apache License 2.0 6 votes vote down vote up
def plot_alignment(alignment, gs, dir=hp.logdir):
    """Plots the alignment and saves it as ``{dir}/alignment_{gs}.png``.

    Args:
      alignment: A numpy array with shape of (encoder_steps, decoder_steps)
      gs: (int) global step.
      dir: Output path; created (including missing parents) if needed.
        Note the default is read from ``hp.logdir`` once, at import time.
    """
    # makedirs(exist_ok=True) creates parent directories too and avoids the
    # check-then-create race of exists() + mkdir().
    os.makedirs(dir, exist_ok=True)

    fig, ax = plt.subplots()
    im = ax.imshow(alignment)

    fig.colorbar(im)
    ax.set_title('{} Steps'.format(gs))
    fig.savefig('{}/alignment_{}.png'.format(dir, gs), format='png')
    # Close the figure; it was previously left open, leaking one per call.
    plt.close(fig)
Example #12
Source File: display_methods.py    From indras_net with GNU General Public License v3.0 6 votes vote down vote up
def __init__(self, title, varieties, data_points,
             anim=False, data_func=None, is_headless=False, legend_pos=4):
    """Build a line graph of ``varieties`` and optionally animate it.

    ``data_points`` is only a fallback: when ``varieties`` is non-empty it is
    replaced by the length of the first variety's ``"data"``. When ``anim`` is
    true and the display is not headless, ``self.update_plot`` is driven by a
    ``FuncAnimation`` (1000 frames, 500 ms apart). ``legend_pos`` is accepted
    but not used here.
    """
    # Module-level reference to the animation, presumably so it is not
    # garbage-collected while the figure is open.
    global anim_func

    self.title = title
    self.anim = anim
    self.data_func = data_func
    # Set before draw_graph runs so the drawing code can consult it.
    self.headless = is_headless
    if varieties:
        first_key = next(iter(varieties))
        data_points = len(varieties[first_key]["data"])
    self.draw_graph(data_points, varieties)

    if anim and not self.headless:
        anim_func = animation.FuncAnimation(self.fig,
                                            self.update_plot,
                                            frames=1000,
                                            interval=500,
                                            blit=False)
Example #13
Source File: display_methods.py    From indras_net with GNU General Public License v3.0 6 votes vote down vote up
def __init__(self, title, varieties, data_points, attrs,
             anim=False, data_func=None, is_headless=False):
    """Close the current figure and draw a fresh graph of ``varieties``.

    When ``varieties`` is non-empty, ``data_points`` is replaced by the
    length of the first variety's ``"data"``. ``anim`` and ``data_func`` are
    accepted for interface compatibility but are not used: animation is
    disabled for this graph type.
    """
    global anim_func

    plt.close()
    self.legend = ["Type"]
    self.title = title
    if varieties:
        data_points = len(varieties[next(iter(varieties))]["data"])
    self.headless = is_headless
    self.draw_graph(data_points, varieties, attrs)
Example #14
Source File: massachusetts_road_segm.py    From Recipes with MIT License 6 votes vote down vote up
def plot_some_results(pred_fn, test_generator, n_images=10):
    """Save input / ground-truth / prediction comparison figures.

    Runs ``pred_fn`` on batches ``(data, seg)`` from ``test_generator`` and
    writes one three-panel PNG per sample
    (``road_segmentation_result_000.png``, ``..._001.png``, ...). It stops as
    soon as ``n_images`` figures have been saved or the generator is
    exhausted, so an endless generator is safe.
    """
    if n_images <= 0:
        return
    fig_ctr = 0
    for data, seg in test_generator:
        res = pred_fn(data)
        for d, s, r in zip(data, seg, res):
            plt.figure(figsize=(12, 6))
            plt.subplot(1, 3, 1)
            # assumes channels-first (C, H, W) input; TODO confirm
            plt.imshow(d.transpose(1, 2, 0))
            plt.title("input patch")
            plt.subplot(1, 3, 2)
            plt.imshow(s[0])
            plt.title("ground truth")
            plt.subplot(1, 3, 3)
            plt.imshow(r)
            plt.title("segmentation")
            plt.savefig("road_segmentation_result_%03.0f.png" % fig_ctr)
            plt.close()
            fig_ctr += 1
            # Return (not just break) so the outer loop stops pulling batches;
            # ">=" yields exactly n_images figures instead of n_images + 1.
            if fig_ctr >= n_images:
                return
Example #15
Source File: utils.py    From pruning_yolov3 with GNU General Public License v3.0 6 votes vote down vote up
def plot_images(imgs, targets, paths=None, fname='images.jpg'):
    """Plot up to 16 training images in a square grid with their target boxes.

    ``targets`` rows are matched to images by column 0 (image index) and
    columns 2:6 are converted with ``xywh2xyxy`` and scaled by the image
    width/height (presumably normalized xywh; TODO confirm). If ``paths``
    is given, each subplot is titled with the first 40 characters of the
    file name. The figure is saved to ``fname`` and closed.
    """
    imgs = imgs.cpu().numpy()
    targets = targets.cpu().numpy()
    # targets = targets[targets[:, 1] == 21]  # plot only one class

    fig = plt.figure(figsize=(10, 10))
    bs, _, h, w = imgs.shape  # batch size, _, height, width
    bs = min(bs, 16)  # limit plot to 16 images
    # Grid side length; plt.subplot requires ints and np.ceil returns a
    # float, which current matplotlib rejects.
    ns = int(np.ceil(bs ** 0.5))

    for i in range(bs):
        boxes = xywh2xyxy(targets[targets[:, 0] == i, 2:6]).T
        boxes[[0, 2]] *= w
        boxes[[1, 3]] *= h
        plt.subplot(ns, ns, i + 1).imshow(imgs[i].transpose(1, 2, 0))
        plt.plot(boxes[[0, 2, 2, 0, 0]], boxes[[1, 1, 3, 3, 1]], '.-')
        plt.axis('off')
        if paths is not None:
            s = Path(paths[i]).name
            plt.title(s[:40], fontdict={'size': 8})  # limit to 40 characters
    fig.tight_layout()
    fig.savefig(fname, dpi=200)
    plt.close(fig)
Example #16
Source File: util.py    From neural-fingerprinting with BSD 3-Clause "New" or "Revised" License 6 votes vote down vote up
def compute_roc_rfeinman(probs_neg, probs_pos, plot=False):
    """Compute the ROC curve separating negative-class from positive-class scores.

    Scores in ``probs_neg`` are labelled 0 and scores in ``probs_pos`` are
    labelled 1 before calling ``roc_curve``.

    :param probs_neg: scores of negative samples
    :param probs_pos: scores of positive samples
    :param plot: when True, draw the curve in a new 7x6 figure and show it
    :return: tuple ``(fpr, tpr, auc_score)``
    """
    scores = np.concatenate((probs_neg, probs_pos))
    targets = np.concatenate((np.zeros_like(probs_neg), np.ones_like(probs_pos)))
    fpr, tpr, _thresholds = roc_curve(targets, scores)
    auc_score = auc(fpr, tpr)
    if not plot:
        return fpr, tpr, auc_score

    plt.figure(figsize=(7, 6))
    curve_label = 'ROC (AUC = %0.4f)' % auc_score
    plt.plot(fpr, tpr, color='blue', label=curve_label)
    plt.legend(loc='lower right')
    plt.title("ROC Curve")
    plt.xlabel("FPR")
    plt.ylabel("TPR")
    plt.show()
    return fpr, tpr, auc_score
Example #17
Source File: plot_lfads.py    From DOTA_models with Apache License 2.0 6 votes vote down vote up
def plot_time_series(vals_bxtxn, bidx=None, n_to_plot=np.inf, scale=1.0,
                     color='r', title=None):
  """Plot channels of a (batch x time x channel) array as vertically stacked traces.

  If ``bidx`` is None the mean over axis 0 is plotted, otherwise only batch
  element ``bidx``. At most ``n_to_plot`` channels are drawn, and channel
  ``k`` is shifted up by ``scale * k`` so the traces do not overlap.
  """
  vals_txn = (np.mean(vals_bxtxn, axis=0) if bidx is None
              else vals_bxtxn[bidx, :, :])

  _, num_channels = vals_txn.shape
  n_shown = min(n_to_plot, num_channels)

  traces = vals_txn[:, 0:n_shown]
  offsets = scale * np.array(range(n_shown))
  plt.plot(traces + offsets, color=color, lw=1.0)
  plt.axis('tight')
  if title:
    plt.title(title)
Example #18
Source File: dataset.py    From neural-combinatorial-optimization-rl-tensorflow with MIT License 6 votes vote down vote up
def visualize_sampling(self, permutations):
    """Show a heatmap of which city was chosen at each step across a batch of tours.

    ``permutations`` holds one tour per row; heatmap cell ``[t, i]`` counts
    how many tours chose city ``i`` at step ``t``. Side effect: sets the
    global matplotlib ``font.size`` rcParam to 22.
    """
    n_steps = len(permutations[0])
    counts_grid = np.zeros([n_steps, n_steps])

    # Row t of the transposed batch = the cities every tour chose at step t.
    for step, chosen in enumerate(np.transpose(permutations)):
        cities, freq = np.unique(chosen, return_counts=True, axis=0)
        # np.unique returns distinct cities, so fancy-index accumulation
        # adds each count exactly once.
        counts_grid[step, cities] += freq

    fig = plt.figure()
    rcParams.update({'font.size': 22})
    ax = fig.add_subplot(1, 1, 1)
    ax.set_aspect('equal')
    plt.imshow(counts_grid, interpolation='nearest', cmap='gray')
    plt.colorbar()
    plt.title('Sampled permutations')
    plt.ylabel('Time t')
    plt.xlabel('City i')
    plt.show()
Example #19
Source File: plot_utils.py    From keras-anomaly-detection with MIT License 6 votes vote down vote up
def visualize_anomaly(y_true, reconstruction_error, threshold):
    """Scatter reconstruction error per sample, grouped by class, with a threshold line.

    Prints summary statistics first. Samples with ``y_true == 1`` are
    labelled "Fraud", all others "Normal"; a red line marks ``threshold``.
    """
    error_df = pd.DataFrame({'reconstruction_error': reconstruction_error,
                             'true_class': y_true})
    print(error_df.describe())

    _, ax = plt.subplots()
    for class_value, rows in error_df.groupby('true_class'):
        series_label = "Fraud" if class_value == 1 else "Normal"
        ax.plot(rows.index, rows.reconstruction_error, marker='o', ms=3.5,
                linestyle='', label=series_label)

    left, right = ax.get_xlim()
    ax.hlines(threshold, left, right, colors="r", zorder=100, label='Threshold')
    ax.legend()
    ax.set_title("Reconstruction error for different classes")
    ax.set_ylabel("Reconstruction error")
    ax.set_xlabel("Data point index")
    plt.show()
Example #20
Source File: plotter.py    From deep-summarization with MIT License 5 votes vote down vote up
def plot_one_metric(self, models_metric, title):
    """Plot one metric for every model against ``self.steps`` on the current axes.

    :param models_metric: sequence of per-model metric series; the series at
        position ``k`` is labelled ``self.file_desc[k]``
    :param title: plot title
    :return: None
    """
    for position, series in enumerate(models_metric):
        series_label = self.file_desc[position]
        plt.plot(self.steps, series, label=series_label)
    plt.title(title)
    plt.legend()
    plt.xlabel('Number of batches')
    plt.ylabel('Score')
Example #21
Source File: gail-eval.py    From HardRLWithYoutube with MIT License 5 votes vote down vote up
def _plot_gail_scores(x, curves, ylabel, title, path, ylim=None):
    """Draw expert/BC/GAIL curves on the current figure, save to ``path`` and close it."""
    for y in curves:
        plt.plot(x, y)
    plt.xlabel('Number of expert trajectories')
    plt.ylabel(ylabel)
    plt.title(title)
    plt.legend(['expert', 'bc-imitator', 'gail-imitator'], loc='lower right')
    # Pass the on/off flag positionally: Matplotlib 3.5 renamed the keyword
    # from ``b`` to ``visible`` and later releases no longer accept ``b``.
    plt.grid(True, which='major', color='gray', linestyle='--')
    if ylim is not None:
        plt.ylim(*ylim)
    plt.savefig(path)
    plt.close()


def plot(env_name, bc_log, gail_log, stochastic):
    """Save unnormalized and normalized score plots comparing expert, BC and GAIL.

    Writes ``result/{env}-unnormalized-{mode}-scores.png`` and
    ``result/{env}-normalized-{mode}-scores.png``, where ``mode`` is
    ``stochastic`` or ``deterministic``. The x axis is
    ``CONFIG['traj_limitation']``; in the normalized plot the expert is the
    constant line 1.0 and the y axis is limited to [0, 1.6].
    """
    traj_limitation = CONFIG['traj_limitation']
    mode = 'stochastic' if stochastic else 'deterministic'

    _plot_gail_scores(
        traj_limitation,
        [bc_log['upper_bound'], bc_log['avg_ret'], gail_log['avg_ret']],
        'Accumulated reward',
        '{} unnormalized scores'.format(env_name),
        'result/{}-unnormalized-{}-scores.png'.format(env_name, mode))

    _plot_gail_scores(
        traj_limitation,
        [np.ones(len(traj_limitation)),
         bc_log['normalized_ret'], gail_log['normalized_ret']],
        'Normalized performance',
        '{} normalized scores'.format(env_name),
        'result/{}-normalized-{}-scores.png'.format(env_name, mode),
        ylim=(0, 1.6))
Example #22
Source File: utils.py    From GST-Tacotron with MIT License 5 votes vote down vote up
def plot_alignment(alignment, gs):
    """Plot an attention alignment and save it under ``hp.logdir``.

    Args:
      alignment: a numpy matrix of shape (encoder_steps, decoder_steps).
      gs: (int) global step; the image is saved as
        ``{hp.logdir}/alignment_{gs // 1000}k.png``.
    """
    fig, ax = plt.subplots()
    im = ax.imshow(alignment)

    fig.colorbar(im)
    ax.set_title('{} Steps'.format(gs))
    fig.savefig('{}/alignment_{}k.png'.format(hp.logdir, gs // 1000), format='png')
    # Close the figure; it was previously left open, leaking one per call.
    plt.close(fig)
Example #23
Source File: display_methods.py    From indras_net with GNU General Public License v3.0 5 votes vote down vote up
def draw_graph(graph, title, hierarchy=False, root=None):
    """Show the current matplotlib figure with ``title`` set.

    NOTE(review): the networkx drawing step (optionally with a tree layout
    rooted at ``root`` when ``hierarchy`` is true) is currently disabled, so
    ``graph``, ``hierarchy`` and ``root`` are accepted but unused; nothing is
    drawn besides the title.
    """
    plt.title(title)
    plt.show()
Example #24
Source File: results_plotter.py    From HardRLWithYoutube with MIT License 5 votes vote down vote up
def plot_curves(xy_list, xaxis, title):
    """Plot episode rewards of several runs with a rolling-mean overlay.

    Each ``(x, y)`` pair in ``xy_list`` is drawn as a small scatter plus its
    ``EPISODES_WINDOW`` rolling mean computed by ``window_func``. The x axis
    spans from 0 to the largest final x value of any run.

    Args:
        xy_list: sequence of ``(x, y)`` array pairs, one per run.
        xaxis: x-axis label.
        title: plot title.
    """
    plt.figure(figsize=(8, 2))
    maxx = max(xy[0][-1] for xy in xy_list)
    minx = 0
    for i, (x, y) in enumerate(xy_list):
        # Cycle through the palette so more runs than colours doesn't raise
        # IndexError.
        color = COLORS[i % len(COLORS)]
        plt.scatter(x, y, s=2)
        # Average over the last EPISODES_WINDOW episodes.
        x, y_mean = window_func(x, y, EPISODES_WINDOW, np.mean)
        plt.plot(x, y_mean, color=color)
    plt.xlim(minx, maxx)
    plt.title(title)
    plt.xlabel(xaxis)
    plt.ylabel("Episode Rewards")
    plt.tight_layout()
Example #25
Source File: plot_utils.py    From keras-anomaly-detection with MIT License 5 votes vote down vote up
def plot_training_history(history):
    """Plot training and validation loss per epoch; do nothing if ``history`` is None.

    ``history`` must map ``'loss'`` and ``'val_loss'`` to per-epoch sequences
    (presumably a Keras ``History.history`` dict; TODO confirm).
    """
    if history is not None:
        for series_key in ('loss', 'val_loss'):
            plt.plot(history[series_key])
        plt.title('model loss')
        plt.ylabel('loss')
        plt.xlabel('epoch')
        plt.legend(['train', 'test'], loc='upper right')
        plt.show()
Example #26
Source File: plot_utils.py    From keras-anomaly-detection with MIT License 5 votes vote down vote up
def visualize_reconstruction_error(reconstruction_error, threshold):
    """Scatter each sample's reconstruction error with a red threshold line.

    The threshold line spans sample indices ``0`` to
    ``len(reconstruction_error) - 1``; the figure is shown with ``plt.show()``.
    """
    point_style = dict(marker='o', ms=3.5, linestyle='', label='Point')
    plt.plot(reconstruction_error, **point_style)

    last_index = len(reconstruction_error) - 1
    plt.hlines(threshold, xmin=0, xmax=last_index, colors="r", zorder=100,
               label='Threshold')
    plt.legend()
    plt.title("Reconstruction error")
    plt.ylabel("Reconstruction error")
    plt.xlabel("Data point index")
    plt.show()
Example #27
Source File: cli.py    From tmhmm.py with MIT License 5 votes vote down vote up
def plot(posterior_file, outputfile):
    """Plot inside / transmembrane / outside posterior probabilities and save them.

    The three series come from ``load_posterior_file(posterior_file)``; the
    transmembrane curve is also filled down to zero. The legend sits below
    the axes, and the figure is written to ``outputfile``.
    """
    inside, membrane, outside = load_posterior_file(posterior_file)

    plt.figure(figsize=(16, 8))
    plt.title('Posterior probabilities')
    plt.suptitle('tmhmm.py')
    plt.plot(inside, label='inside', color='blue')
    plt.plot(membrane, label='transmembrane', color='red')
    positions = range(len(inside))
    plt.fill_between(positions, membrane, color='red')
    plt.plot(outside, label='outside', color='black')
    legend_options = dict(frameon=False, bbox_to_anchor=[0.5, 0],
                          loc='upper center', ncol=3, borderaxespad=1.5)
    plt.legend(**legend_options)
    plt.tight_layout(pad=3)
    plt.savefig(outputfile)
Example #28
Source File: plot_lfads.py    From DOTA_models with Apache License 2.0 5 votes vote down vote up
def _plot_item(W, name, full_name, nspaces):
  """Print a scalar, or plot a vector / matrix ``W`` in a new figure titled ``full_name``.

  Scalars (``W.shape == ()``) are printed as ``name : value`` and no figure
  is created. Vectors (1-D arrays, or 2-D arrays with a singleton dimension)
  are drawn with ``plt.stem``. Any other array is shown as ``|W|`` with a
  colorbar. ``nspaces`` is not used.
  """
  if W.shape == ():
    print(name, ": ", W)
    return

  plt.figure()
  if W.ndim == 1:
    # Plain 1-D vectors previously hit W.shape[1] and raised IndexError.
    plt.stem(W)
  elif W.shape[0] == 1:
    plt.stem(W.T)
  elif W.shape[1] == 1:
    plt.stem(W)
  else:
    plt.imshow(np.abs(W), interpolation='nearest', cmap='jet')
    plt.colorbar()
  plt.title(full_name)
Example #29
Source File: plot_lfads.py    From DOTA_models with Apache License 2.0 5 votes vote down vote up
def plot_priors():
  """Compare prior and posterior g0 statistics stored in ``train_modelvals``.

  Creates three figures: (1) overlaid histograms of posterior (blue) and
  prior (green) means and variances; (2) 2x2 heatmaps of the transposed
  prior/posterior mean and variance arrays; (3) a stem plot of the sorted
  log standard deviation (over axis 0) of the posterior means.
  """
  prior_mean = train_modelvals['prior_g0_mean']
  prior_var = train_modelvals['prior_g0_var']
  post_mean = train_modelvals['posterior_g0_mean']
  post_var = train_modelvals['posterior_g0_var']

  # Figure 1: posterior vs. prior histograms.
  plt.figure(figsize=(10,4), tight_layout=True)
  histogram_panels = [
      (post_mean, prior_mean, 'Histogram of Prior/Posterior Mean Values'),
      (post_var, prior_var, 'Histogram of Prior/Posterior Log Variance Values'),
  ]
  for panel, (posterior, prior, panel_title) in enumerate(histogram_panels, start=1):
    plt.subplot(1, 2, panel)
    plt.hist(posterior.flatten(), bins=20, color='b')
    plt.hist(prior.flatten(), bins=20, color='g')
    plt.title(panel_title)

  # Figure 2: heatmaps of the transposed arrays.
  plt.figure(figsize=(10,10), tight_layout=True)
  heatmap_panels = [
      (prior_mean, 'Prior g0 means'),
      (post_mean, 'Posterior g0 means'),
      (prior_var, 'Prior g0 variance Values'),
      (post_var, 'Posterior g0 variance Values'),
  ]
  for panel, (values, panel_title) in enumerate(heatmap_panels, start=1):
    plt.subplot(2, 2, panel)
    plt.imshow(values.T, interpolation='nearest', cmap='jet')
    plt.colorbar(fraction=0.025, pad=0.04)
    plt.title(panel_title)

  # Figure 3: sorted log std-dev of the posterior means along axis 0.
  plt.figure(figsize=(10,5))
  plt.stem(np.sort(np.log(post_mean.std(axis=0))))
  plt.title('Log standard deviation of h0 means')
Example #30
Source File: dcgan.py    From dynamic-training-with-apache-mxnet-on-aws with Apache License 2.0 5 votes vote down vote up
def visual(title, X, name):
    """Tile a batch of images into one grid image, draw it, and save it to ``name``.

    ``X`` must be 4-D. Axis 1 is moved last and used as the channel axis
    (presumably (batch, channels, height, width) input; TODO confirm).
    Values are min-max rescaled to 0..255 over the whole batch, and the
    channel order is reversed before display (presumably BGR -> RGB).
    """
    assert len(X.shape) == 4
    X = X.transpose((0, 2, 3, 1))
    lo, hi = np.min(X), np.max(X)
    # A constant batch has zero range; scale by 1 instead of dividing by zero,
    # which produced NaN values and an invalid cast to uint8.
    value_range = (hi - lo) if hi > lo else 1.0
    X = np.clip((X - lo) * (255.0 / value_range), 0, 255).astype(np.uint8)
    n = np.ceil(np.sqrt(X.shape[0]))  # tiles per grid side
    buff = np.zeros((int(n*X.shape[1]), int(n*X.shape[2]), int(X.shape[3])), dtype=np.uint8)
    for i, img in enumerate(X):
        fill_buf(buff, i, img, X.shape[1:3])
    buff = buff[:, :, ::-1]
    plt.imshow(buff)
    plt.title(title)
    plt.savefig(name)