Python matplotlib.pyplot.tight_layout() Examples
The following are 30 code examples of matplotlib.pyplot.tight_layout().
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example.
You may also want to check out all available functions/classes of the module matplotlib.pyplot, or try the search function.
Example #1
Source File: data_augmentation.py From Sound-Recognition-Tutorial with Apache License 2.0 | 10 votes |
def demo_plot():
    """Plot one audio clip next to its time-stretched and pitch-shifted copies.

    Loads a single ESC-10 recording at 44.1 kHz, derives two augmented
    versions with librosa and draws the three waveforms in stacked subplots
    that share the same axis limits, so they can be compared visually.
    """
    audio = './data/esc10/audio/Dog/1-30226-A.ogg'
    y, sr = librosa.load(audio, sr=44100)
    # n_steps controls the size of the pitch shift.
    # `sr` must be passed by keyword: librosa >= 0.10 made it keyword-only.
    y_ps = librosa.effects.pitch_shift(y, sr=sr, n_steps=6)
    # rate controls the stretch factor along the time axis.
    y_ts = librosa.effects.time_stretch(y, rate=1.2)

    plt.subplot(311)
    plt.plot(y)
    plt.title('Original waveform')
    plt.axis([0, 200000, -0.4, 0.4])
    # plt.axis([88000, 94000, -0.4, 0.4])  # zoomed-in alternative

    plt.subplot(312)
    plt.plot(y_ts)
    plt.title('Time Stretch transformed waveform')
    plt.axis([0, 200000, -0.4, 0.4])

    plt.subplot(313)
    plt.plot(y_ps)
    plt.title('Pitch Shift transformed waveform')
    plt.axis([0, 200000, -0.4, 0.4])
    # plt.axis([88000, 94000, -0.4, 0.4])  # zoomed-in alternative

    plt.tight_layout()
    plt.show()
Example #2
Source File: plot_utils.py From celer with BSD 3-Clause "New" or "Revised" License | 7 votes |
def plot_path_hist(results, labels, tols, figsize, ylim=None):
    """Draw a grouped bar chart of path-computation time per tolerance.

    Parameters
    ----------
    results : sequence of sequences
        ``results[k]`` holds the bar heights (one per tolerance) for
        competitor ``k``.
    labels : sequence of str
        Legend entry for each competitor.
    tols : sequence of float
        Tolerances; one group of bars is drawn per tolerance and the tick
        labels are the tolerances in ``%.0e`` format.
    figsize : tuple
        Passed to ``plt.subplots``.
    ylim : tuple, optional
        y-axis limits; matplotlib's automatic limits are kept when None.

    Returns
    -------
    matplotlib.figure.Figure
        The figure that was drawn (shown non-blocking).
    """
    configure_plt()
    sns.set_palette('colorblind')

    n_groups = len(results)
    fig, ax = plt.subplots(figsize=figsize)
    bar_width = 1. / (n_groups + 1)
    positions = np.arange(len(tols))
    # Shift so the bars of each group are centred on the integer positions.
    centre_shift = (1 - n_groups) / 2.
    for k, (heights, label) in enumerate(zip(results, labels)):
        plt.bar(positions + (k + centre_shift) * bar_width, heights,
                bar_width, label=label)

    ax.set_ylabel('path computation time (s)')
    tick_text = ["%.0e" % tol for tol in tols]
    plt.xticks(range(len(tols)), tick_text)
    if ylim is not None:
        plt.ylim(ylim)
    ax.set_xlabel(r"$\epsilon$")
    plt.legend(loc='upper left')
    plt.tight_layout()
    plt.show(block=False)
    return fig
Example #3
Source File: pearsons_filtering.py From simba with GNU Lesser General Public License v3.0 | 7 votes |
def pearson_filter(projectPath, featuresDf, del_corr_status, del_corr_threshold, del_corr_plot_status):
    """Drop feature columns that correlate strongly with an earlier, kept column.

    For each pair (i, j) with j < i in the Pearson correlation matrix of
    ``featuresDf``, column i is removed when ``corr[i, j] >= del_corr_threshold``
    and column j has not itself been removed. The signed coefficient is
    compared (no ``abs()``), so strongly negative correlations are kept.

    ``featuresDf`` is modified in place and also returned.

    Parameters
    ----------
    projectPath : str
        Project folder; the optional heatmap is written to ``<projectPath>/logs``.
    featuresDf : pandas.DataFrame
        Feature table to filter.
    del_corr_status :
        Not used by this function.
    del_corr_threshold : float
        Correlation at or above which a column is dropped.
    del_corr_plot_status : str
        ``'yes'`` to save a correlation heatmap of the remaining features.

    Returns
    -------
    pandas.DataFrame
        The same (mutated) ``featuresDf``.
    """
    print('Reducing features. Correlation threshold: ' + str(del_corr_threshold))
    col_corr = set()
    corr_matrix = featuresDf.corr()
    columns = corr_matrix.columns
    for i in range(len(columns)):
        for j in range(i):
            if (corr_matrix.iloc[i, j] >= del_corr_threshold) and (columns[j] not in col_corr):
                colname = columns[i]
                col_corr.add(colname)
                if colname in featuresDf.columns:
                    del featuresDf[colname]

    if del_corr_plot_status == 'yes':
        print('Creating feature correlation heatmap...')
        dateTime = datetime.now().strftime('%Y%m%d%H%M%S')
        logs_dir = os.path.join(projectPath, 'logs')
        # savefig does not create missing directories.
        os.makedirs(logs_dir, exist_ok=True)
        plt.matshow(featuresDf.corr())
        plt.tight_layout()
        plt.savefig(os.path.join(logs_dir, 'Feature_correlations_' + dateTime + '.png'), dpi=300)
        plt.close('all')
        print('Feature correlation heatmap .png saved in project_folder/logs directory')
    return featuresDf
Example #4
Source File: visualise_att_maps_epoch.py From Attention-Gated-Networks with MIT License | 7 votes |
def plotNNFilter(units, figure_id, interp='bilinear', colormap=cm.jet, colormap_lim=None):
    """Show every channel of ``units`` in its own subplot with a colorbar.

    Channels are read along axis 2 (``units[:, :, i]``) and each one is
    transposed before display. The subplots form a roughly square grid with
    one extra row.

    Parameters
    ----------
    units : array-like
        Indexed as ``units[:, :, i]``; ``units.shape[2]`` is the channel count.
    figure_id :
        Figure number/label; the figure is reused and cleared.
    interp : str
        Interpolation passed to ``plt.imshow``.
    colormap :
        Colormap passed to ``plt.imshow``.
    colormap_lim : sequence, optional
        ``(vmin, vmax)`` applied with ``plt.clim`` to every subplot.
    """
    plt.ion()
    filters = units.shape[2]
    # Guard against zero channels, which would otherwise divide by zero.
    n_columns = max(1, round(math.sqrt(filters)))
    n_rows = math.ceil(filters / n_columns) + 1
    # figsize is (width, height): width scales with columns, height with rows.
    fig = plt.figure(figure_id, figsize=(n_columns * 3, n_rows * 3))
    fig.clf()

    for i in range(filters):
        ax1 = plt.subplot(n_rows, n_columns, i + 1)
        plt.imshow(units[:, :, i].T, interpolation=interp, cmap=colormap)
        plt.axis('on')
        ax1.set_xticklabels([])
        ax1.set_yticklabels([])
        plt.colorbar()
        if colormap_lim:
            plt.clim(colormap_lim[0], colormap_lim[1])

    # NOTE(review): tight_layout() below recomputes the spacing, so the
    # zero spacing requested here is largely overridden.
    plt.subplots_adjust(wspace=0, hspace=0)
    plt.tight_layout()

# Epochs
Example #5
Source File: thames.py From pywr with GNU General Public License v3.0 | 6 votes |
def figures(ext, show):
    """Plot every recorder table stored in ``thames_output.h5``.

    For each table, the left panel shows the 2100-2125 slice of the data and
    the right panel shows its quantiles.

    Parameters
    ----------
    ext : str or None
        When given, each figure is saved as ``<name>.<ext>`` at 300 dpi.
    show : bool
        When true, each figure is displayed with ``plt.show()``; otherwise it
        is closed after saving so figures do not accumulate.
    """
    for name, df in TablesRecorder.generate_dataframes('thames_output.h5'):
        df.columns = ['Very low', 'Low', 'Central', 'High', 'Very high']

        fig, (ax1, ax2) = plt.subplots(figsize=(12, 4), ncols=2, sharey='row',
                                       gridspec_kw={'width_ratios': [3, 1]})
        df['2100':'2125'].plot(ax=ax1)
        df.quantile(np.linspace(0, 1)).plot(ax=ax2)

        if name.startswith('reservoir'):
            ax1.set_ylabel('Volume [$Mm^3$]')
        else:
            ax1.set_ylabel('Flow [$Mm^3/day$]')

        for ax in (ax1, ax2):
            ax.set_title(name)
            ax.grid(True)
        fig.tight_layout()

        if ext is not None:
            fig.savefig(f'{name}.{ext}', dpi=300)

        if show:
            plt.show()
        else:
            # Without this, one open figure per table accumulates in pyplot.
            plt.close(fig)
Example #6
Source File: test_frame.py From recruit with Apache License 2.0 | 6 votes |
def test_if_scatterplot_colorbars_are_next_to_parent_axes(self):
    """Both scatter plots' colorbars must be spaced like their parent axes.

    The assertion treats ``fig.axes[0:2]`` as the two parent axes and
    ``fig.axes[2:4]`` as their colorbars, comparing left/right x-coordinates.
    """
    import matplotlib.pyplot as plt

    frame = pd.DataFrame(np.random.random((1000, 3)),
                         columns=['A label', 'B label', 'C label'])
    fig, axes = plt.subplots(1, 2)
    for target in axes:
        frame.plot.scatter('A label', 'B label', c='C label', ax=target)
    plt.tight_layout()

    corners = np.array([a.get_position().get_points() for a in fig.axes])
    x_coords = corners[:, :, 0]
    parents_gap = x_coords[1, :] - x_coords[0, :]
    colorbars_gap = x_coords[3, :] - x_coords[2, :]
    assert np.isclose(parents_gap, colorbars_gap, atol=1e-7).all()
Example #7
Source File: data_provider.py From ICDAR-2019-SROIE with MIT License | 6 votes |
def generator(vis=False):
    """Yield training samples forever, reshuffling the image order each pass.

    Each item is ``([im], bbox, im_info)``, where ``im`` is the image as read
    by ``cv2.imread``, ``bbox`` is the result of ``load_annoataion`` on the
    matching ``<DATA_FOLDER>/label/<stem>.txt`` file, and ``im_info`` is a
    ``(1, 3)`` array ``[[h, w, c]]``. Images that cannot be read, or whose
    label file is missing or empty, are skipped with a printed message.

    Parameters
    ----------
    vis : bool
        When true, each sample is drawn with its boxes and shown before it is
        yielded.

    Raises
    ------
    ValueError
        If ``get_training_data()`` returns no images. Otherwise the loop would
        spin forever without yielding.
    """
    image_list = np.array(get_training_data())
    print('{} training images in {}'.format(image_list.shape[0], DATA_FOLDER))
    if image_list.shape[0] == 0:
        raise ValueError('No training images found in {}'.format(DATA_FOLDER))
    index = np.arange(0, image_list.shape[0])
    while True:
        np.random.shuffle(index)
        for i in index:
            try:
                im_fn = image_list[i]
                im = cv2.imread(im_fn)
                if im is None:
                    # cv2.imread reports an unreadable file by returning None.
                    print("Failed to read image {}!".format(im_fn))
                    continue
                h, w, c = im.shape
                im_info = np.array([h, w, c]).reshape([1, 3])

                _, fn = os.path.split(im_fn)
                fn, _ = os.path.splitext(fn)
                txt_fn = os.path.join(DATA_FOLDER, "label", fn + '.txt')
                if not os.path.exists(txt_fn):
                    print("Ground truth for image {} not exist!".format(im_fn))
                    continue
                bbox = load_annoataion(txt_fn)
                if len(bbox) == 0:
                    print("Ground truth for image {} empty!".format(im_fn))
                    continue

                if vis:
                    for p in bbox:
                        cv2.rectangle(im, (p[0], p[1]), (p[2], p[3]), color=(0, 0, 255), thickness=1)
                    fig, axs = plt.subplots(1, 1, figsize=(30, 30))
                    axs.imshow(im[:, :, ::-1])  # reverse channel order for display
                    axs.set_xticks([])
                    axs.set_yticks([])
                    plt.tight_layout()
                    plt.show()
                    plt.close()

                yield [im], bbox, im_info

            except Exception as e:
                # Best effort: report the failure and move on to the next sample.
                print(e)
                continue
Example #8
Source File: plot.py From TaskBot with GNU General Public License v3.0 | 6 votes |
def plot_attention(sentences, attentions, labels, **kwargs):
    """Draw an attention matrix as a heat map and print a token in every cell.

    Parameters
    ----------
    sentences : sequence of sequences
        ``sentences[r][c]`` is written into the cell at row ``r``, column ``c``.
    attentions : array-like
        2-D matrix of weights; its min/max set the colour range.
    labels : sequence of str
        y-axis tick labels, one per row.
    **kwargs
        Forwarded to ``plt.subplots``.
    """
    fig, ax = plt.subplots(**kwargs)
    heat = ax.imshow(attentions, interpolation='nearest',
                     vmin=attentions.min(), vmax=attentions.max())
    plt.colorbar(heat, shrink=0.5, ticks=[0, 1])
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")
    ax.set_yticks(range(len(labels)))
    ax.set_yticklabels(labels, fontproperties=getChineseFont())

    # Annotate each cell with its token.
    for row in range(attentions.shape[0]):
        for col in range(attentions.shape[1]):
            ax.text(col, row, sentences[row][col],
                    ha="center", va="center", color="b", size=10,
                    fontproperties=getChineseFont())

    ax.set_title("Attention Visual")
    fig.tight_layout()
    plt.show()
Example #9
Source File: simplest_raster_plot.py From ibllib with MIT License | 6 votes |
def raster_complete(R, times, Clusters):
    '''
    Plot a rasterplot for the complete recording
    (might be slow, restrict R if so),
    ordered by insertion depth

    R is drawn as an image whose x-extent spans times[0]..times[-1] and whose
    y-extent spans Clusters[0]..Clusters[-1]. The colour ceiling is derived
    from the module-level T_BIN.
    '''
    plt.imshow(R, aspect='auto', cmap='binary', vmax=T_BIN / 0.001 / 4,
               origin='lower', extent=np.r_[times[[0, -1]], Clusters[[0, -1]]])
    plt.xlabel('Time (s)')
    plt.ylabel('Cluster #; ordered by depth')
    # Lay out before showing: with a blocking backend show() only returns once
    # the window is closed, so a tight_layout() after it never affects the plot.
    plt.tight_layout()
    plt.show()
Example #10
Source File: time_bench.py From astroalign with MIT License | 6 votes |
def plot_command(self, ns):
    """Plot benchmark results read from ``ns.file``.

    The subplot layout comes from ``COLSROWS[ns.orientation]`` and the size
    from ``ns.size`` (falling back to ``DEFAULT_SIZES[ns.orientation]``). The
    figure is shown interactively when ``ns.out`` is None, otherwise saved to
    ``ns.out``.
    """
    import matplotlib.pyplot as plt

    results = pd.read_csv(ns.file)
    orientation = COLSROWS[ns.orientation]
    size = ns.size if ns.size else DEFAULT_SIZES[ns.orientation]

    fig, axes = plt.subplots(**orientation)
    fig.set_size_inches(*size)
    plot(results, *axes)
    fig.suptitle("")
    plt.tight_layout()

    if ns.out is None:
        print(f"Showing plot for data stored in '{ns.file.name}'...")
        # FigureCanvasBase.set_window_title was deprecated in Matplotlib 3.4
        # and removed in 3.6; the window title now lives on the figure manager,
        # which is absent (None) for non-GUI canvases.
        manager = getattr(fig.canvas, "manager", None)
        if manager is not None:
            manager.set_window_title(f"{self.parser.prog} - {ns.file.name}")
        plt.show()
    else:
        print(
            f"Storing plot for data in '{ns.file.name}' -> '{ns.out}'...")
        plt.savefig(ns.out)
        print("DONE!")
Example #11
Source File: flux_bench.py From astroalign with MIT License | 6 votes |
def plot_command(self, ns):
    """Plot benchmark results read from ``ns.file`` on a single axes.

    The size comes from ``ns.size`` (falling back to ``DEFAULT_SIZE``). The
    figure is shown interactively when ``ns.out`` is None, otherwise saved to
    ``ns.out``.
    """
    import matplotlib.pyplot as plt

    results = pd.read_csv(ns.file)
    size = ns.size if ns.size else DEFAULT_SIZE

    fig, ax = plt.subplots()
    fig.set_size_inches(*size)
    plot(results, ax)
    fig.suptitle("")
    plt.tight_layout()

    if ns.out is None:
        print(f"Showing plot for data stored in '{ns.file.name}'...")
        # FigureCanvasBase.set_window_title was removed in Matplotlib 3.6; the
        # title is set on the figure manager, which is None for non-GUI canvases.
        manager = getattr(fig.canvas, "manager", None)
        if manager is not None:
            manager.set_window_title(f"{self.parser.prog} - {ns.file.name}")
        plt.show()
    else:
        print(
            f"Storing plot for data in '{ns.file.name}' -> '{ns.out}'...")
        plt.savefig(ns.out)
        print("DONE!")
Example #12
Source File: time_regression.py From astroalign with MIT License | 6 votes |
def plot_command(self, ns):
    """Plot regression-benchmark results read from ``ns.file`` on one axes.

    The size comes from ``ns.size`` (falling back to ``DEFAULT_SIZE``). The
    figure is shown interactively when ``ns.out`` is None, otherwise saved to
    ``ns.out``.
    """
    import matplotlib.pyplot as plt

    results = pd.read_csv(ns.file)
    size = ns.size if ns.size else DEFAULT_SIZE

    fig, ax = plt.subplots()
    fig.set_size_inches(*size)
    plot(results, ax)
    fig.suptitle("")
    plt.tight_layout()

    if ns.out is None:
        print(f"Showing plot for data stored in '{ns.file.name}'...")
        # FigureCanvasBase.set_window_title was removed in Matplotlib 3.6; the
        # title is set on the figure manager, which is None for non-GUI canvases.
        manager = getattr(fig.canvas, "manager", None)
        if manager is not None:
            manager.set_window_title(f"{self.parser.prog} - {ns.file.name}")
        plt.show()
    else:
        print(
            f"Storing plot for data in '{ns.file.name}' -> '{ns.out}'...")
        plt.savefig(ns.out)
        print("DONE!")
Example #13
Source File: visualise_fmaps.py From Attention-Gated-Networks with MIT License | 6 votes |
def plotNNFilter(units, figure_id, interp='bilinear', colormap=cm.jet, colormap_lim=None):
    """Show every channel of ``units`` in its own subplot with a colorbar.

    Channels are read along axis 2 (``units[:, :, i]``) and each one is
    transposed before display. The subplots form a roughly square grid with
    one extra row.

    Parameters
    ----------
    units : array-like
        Indexed as ``units[:, :, i]``; ``units.shape[2]`` is the channel count.
    figure_id :
        Figure number/label; the figure is reused and cleared.
    interp : str
        Interpolation passed to ``plt.imshow``.
    colormap :
        Colormap passed to ``plt.imshow``.
    colormap_lim : sequence, optional
        ``(vmin, vmax)`` applied with ``plt.clim`` to every subplot.
    """
    plt.ion()
    filters = units.shape[2]
    # Guard against zero channels, which would otherwise divide by zero.
    n_columns = max(1, round(math.sqrt(filters)))
    n_rows = math.ceil(filters / n_columns) + 1
    # figsize is (width, height): width scales with columns, height with rows.
    fig = plt.figure(figure_id, figsize=(n_columns * 3, n_rows * 3))
    fig.clf()

    for i in range(filters):
        ax1 = plt.subplot(n_rows, n_columns, i + 1)
        plt.imshow(units[:, :, i].T, interpolation=interp, cmap=colormap)
        plt.axis('on')
        ax1.set_xticklabels([])
        ax1.set_yticklabels([])
        plt.colorbar()
        if colormap_lim:
            plt.clim(colormap_lim[0], colormap_lim[1])

    # NOTE(review): tight_layout() below recomputes the spacing, so the
    # zero spacing requested here is largely overridden.
    plt.subplots_adjust(wspace=0, hspace=0)
    plt.tight_layout()

# Load options
Example #14
Source File: visualise_attention.py From Attention-Gated-Networks with MIT License | 6 votes |
def plotNNFilter(units, figure_id, interp='bilinear', colormap=cm.jet, colormap_lim=None, title=''):
    """Show every channel of ``units`` in its own subplot, under a figure title.

    Channels are read along axis 2 (``units[:, :, i]``) and each one is
    transposed before display. The subplots form a roughly square grid with
    one extra row.

    Parameters
    ----------
    units : array-like
        Indexed as ``units[:, :, i]``; ``units.shape[2]`` is the channel count.
    figure_id :
        Figure number/label; the figure is reused and cleared.
    interp : str
        Interpolation passed to ``plt.imshow``.
    colormap :
        Colormap passed to ``plt.imshow``.
    colormap_lim : sequence, optional
        ``(vmin, vmax)`` applied with ``plt.clim`` to every subplot.
    title : str
        Figure-level title set with ``plt.suptitle``.
    """
    plt.ion()
    filters = units.shape[2]
    # Guard against zero channels, which would otherwise divide by zero.
    n_columns = max(1, round(math.sqrt(filters)))
    n_rows = math.ceil(filters / n_columns) + 1
    # figsize is (width, height): width scales with columns, height with rows.
    fig = plt.figure(figure_id, figsize=(n_columns * 3, n_rows * 3))
    fig.clf()

    for i in range(filters):
        ax1 = plt.subplot(n_rows, n_columns, i + 1)
        plt.imshow(units[:, :, i].T, interpolation=interp, cmap=colormap)
        plt.axis('on')
        ax1.set_xticklabels([])
        ax1.set_yticklabels([])
        plt.colorbar()
        if colormap_lim:
            plt.clim(colormap_lim[0], colormap_lim[1])

    # NOTE(review): tight_layout() below recomputes the spacing, so the
    # zero spacing requested here is largely overridden.
    plt.subplots_adjust(wspace=0, hspace=0)
    plt.tight_layout()
    plt.suptitle(title)
Example #15
Source File: visualise_attention.py From Attention-Gated-Networks with MIT License | 6 votes |
def plotNNFilterOverlay(input_im, units, figure_id, interp='bilinear',
                        colormap=cm.jet, colormap_lim=None, title='', alpha=0.8):
    """Overlay the channels of ``units`` on a grayscale image in a single axes.

    For each channel ``i`` the grayscale ``input_im[:, :, 0]`` is drawn, then
    ``units[:, :, i]`` on top with transparency ``alpha``. All channels go
    into the same axes, so later channels paint over earlier ones. One
    colorbar, describing the last overlay, is added.

    Parameters
    ----------
    input_im : array-like
        Background image; only channel 0 is used.
    units : array-like
        Indexed as ``units[:, :, i]``; ``units.shape[2]`` is the channel count.
    figure_id :
        Figure number/label; the figure is reused and cleared.
    interp : str
        Interpolation passed to ``plt.imshow``.
    colormap :
        Colormap for the overlays.
    colormap_lim : sequence, optional
        ``(vmin, vmax)`` colour limits for the overlays.
    title : str
        Axes title.
    alpha : float
        Overlay opacity.
    """
    plt.ion()
    filters = units.shape[2]
    fig = plt.figure(figure_id, figsize=(5, 5))
    fig.clf()

    if colormap_lim:
        vmin, vmax = colormap_lim[0], colormap_lim[1]
    else:
        vmin = vmax = None

    for i in range(filters):
        plt.imshow(input_im[:, :, 0], interpolation=interp, cmap='gray')
        plt.imshow(units[:, :, i], interpolation=interp, cmap=colormap,
                   alpha=alpha, vmin=vmin, vmax=vmax)

    if filters:
        plt.axis('off')
        # A single colorbar: creating one per loop iteration stacked several
        # colorbars and kept shrinking the image axes.
        plt.colorbar()
        plt.title(title, fontsize='small')

    plt.subplots_adjust(wspace=0, hspace=0)
    plt.tight_layout()

# Load options
Example #16
Source File: core.py From prickle with MIT License | 6 votes |
def imshow(data, which, levels):
    """
    Display order book data as an image, where order book data is either of
    `df_price` or `df_volume` returned by `load_hdf5` or `load_postgres`.

    Parameters
    ----------
    data : pandas.DataFrame
        Must contain the ``askprc.N``/``bidprc.N`` (prices) or
        ``askvol.N``/``bidvol.N`` (volumes) columns for N = 1..levels.
    which : str
        ``'prices'`` or ``'volumes'``.
    levels : int
        Number of book levels per side to display.

    Raises
    ------
    ValueError
        If ``which`` is neither ``'prices'`` nor ``'volumes'``.
    """
    prefixes = {'prices': ('askprc', 'bidprc'), 'volumes': ('askvol', 'bidvol')}
    if which not in prefixes:
        raise ValueError(
            "which must be 'prices' or 'volumes', got {!r}".format(which))
    ask, bid = prefixes[which]
    # Ask levels from `levels` down to 1, then bid levels from 1 up to `levels`.
    idx = ['{}.{}'.format(ask, i) for i in range(levels, 0, -1)]
    idx.extend('{}.{}'.format(bid, i) for i in range(1, levels + 1))

    plt.imshow(data.loc[:, idx].T, interpolation='nearest', aspect='auto')
    plt.yticks(range(0, levels * 2, 1), idx)
    plt.colorbar()
    plt.tight_layout()
    plt.show()
Example #17
Source File: demo.py From TFFRCNN with MIT License | 5 votes |
def vis_detections(im, class_name, dets, ax, thresh=0.5):
    """Draw detected bounding boxes onto ``ax``.

    Each row of ``dets`` is read as ``[x1, y1, x2, y2, ..., score]``: the
    first four values form the box and the last value is the score. Rows
    with ``score >= thresh`` are drawn as red rectangles labelled with the
    class name and score. If no row passes, nothing is drawn, not even the
    title.

    ``im`` is not used by this function.
    """
    inds = np.where(dets[:, -1] >= thresh)[0]
    if len(inds) == 0:
        return

    for i in inds:
        bbox = dets[i, :4]
        score = dets[i, -1]
        ax.add_patch(
            plt.Rectangle((bbox[0], bbox[1]),
                          bbox[2] - bbox[0],
                          bbox[3] - bbox[1], fill=False,
                          edgecolor='red', linewidth=3.5)
        )
        ax.text(bbox[0], bbox[1] - 2,
                '{:s} {:.3f}'.format(class_name, score),
                bbox=dict(facecolor='blue', alpha=0.5),
                fontsize=14, color='white')

    ax.set_title(('{} detections with '
                  'p({} | box) >= {:.1f}').format(class_name, class_name, thresh),
                 fontsize=14)
    # Operate on the axes/figure we were given rather than pyplot's
    # "current" ones, which may belong to a different figure.
    ax.axis('off')
    ax.figure.tight_layout()
    ax.figure.canvas.draw_idle()
Example #18
Source File: visualization_utils.py From ludwig with Apache License 2.0 | 5 votes |
def predictions_distribution_plot(
        probabilities,
        algorithm_names=None,
        filename=None
):
    """Overlay histograms of predicted probabilities, one per algorithm.

    Parameters
    ----------
    probabilities : sequence of array-like
        ``probabilities[i]`` holds algorithm i's predicted values in [0, 1].
    algorithm_names : sequence of str, optional
        Legend labels; missing entries get an empty label.
    filename : str, optional
        When given, the figure is saved there; otherwise it is shown.
    """
    sns.set_style('whitegrid')
    colors = plt.get_cmap('tab10').colors
    num_algorithms = len(probabilities)

    plt.figure(figsize=(9, 9))
    plt.grid(which='both')
    plt.grid(which='minor', alpha=0.5)
    plt.grid(which='major', alpha=0.75)

    for i in range(num_algorithms):
        label = (algorithm_names[i]
                 if algorithm_names is not None and i < len(algorithm_names)
                 else '')
        # tab10 provides only 10 colours; cycle instead of raising IndexError.
        plt.hist(probabilities[i], range=(0, 1), bins=41,
                 color=colors[i % len(colors)], label=label,
                 histtype='stepfilled', alpha=0.5, lw=2)

    plt.xlabel('Mean predicted value')
    plt.xlim([0, 1])
    plt.xticks(np.linspace(0.0, 1.0, num=21))
    plt.ylabel('Count')
    plt.legend(loc='upper center', ncol=2)
    plt.tight_layout()
    ludwig.contrib.contrib_command("visualize_figure", plt.gcf())
    if filename:
        plt.savefig(filename)
    else:
        plt.show()
Example #19
Source File: visualization_utils.py From ludwig with Apache License 2.0 | 5 votes |
def compare_classifiers_multiclass_multimetric_plot(
        scores,
        metrics,
        labels=None,
        title=None,
        filename=None
):
    """Grouped bar chart of several per-class metrics.

    Parameters
    ----------
    scores : sequence of sequences
        ``scores[i]`` holds one value per class for ``metrics[i]``; all rows
        must have the same length.
    metrics : sequence of str
        Legend label for each row of ``scores``.
    labels : sequence of str, optional
        Class names for the x ticks; class indices are used when None.
    title : str, optional
        Axes title.
    filename : str, optional
        When given, the figure is saved there; otherwise it is shown.
    """
    assert len(scores) > 0
    sns.set_style('whitegrid')

    fig, ax = plt.subplots()
    if title is not None:
        ax.set_title(title)

    width = 0.9 / len(scores)
    ticks = np.arange(len(scores[0]))
    colors = plt.get_cmap('tab10').colors

    ax.set_xlabel('class')
    # Bar i is centred at ticks + i * width, so the middle of each group is
    # half the span of the len(scores) bars to the right of `ticks`.
    ax.set_xticks(ticks + (len(scores) - 1) * width / 2)
    if labels is not None:
        ax.set_xticklabels(labels, rotation=90)
    else:
        ax.set_xticklabels(ticks, rotation=90)

    for i, score in enumerate(scores):
        # tab10 provides only 10 colours; cycle instead of raising IndexError.
        ax.bar(ticks + i * width, score, width, label=metrics[i],
               color=colors[i % len(colors)])

    ax.legend(loc='center left', bbox_to_anchor=(1, 0.5))
    plt.tight_layout()
    ludwig.contrib.contrib_command("visualize_figure", plt.gcf())
    if filename:
        plt.savefig(filename)
    else:
        plt.show()
Example #20
Source File: generate.py From TFFRCNN with MIT License | 5 votes |
def _vis_proposals(im, dets, thresh=0.5): """Draw detected bounding boxes.""" inds = np.where(dets[:, -1] >= thresh)[0] if len(inds) == 0: return class_name = 'obj' im = im[:, :, (2, 1, 0)] fig, ax = plt.subplots(figsize=(12, 12)) ax.imshow(im, aspect='equal') for i in inds: bbox = dets[i, :4] score = dets[i, -1] ax.add_patch( plt.Rectangle((bbox[0], bbox[1]), bbox[2] - bbox[0], bbox[3] - bbox[1], fill=False, edgecolor='red', linewidth=3.5) ) ax.text(bbox[0], bbox[1] - 2, '{:s} {:.3f}'.format(class_name, score), bbox=dict(facecolor='blue', alpha=0.5), fontsize=14, color='white') ax.set_title(('{} detections with ' 'p({} | box) >= {:.1f}').format(class_name, class_name, thresh), fontsize=14) plt.axis('off') plt.tight_layout() plt.draw()
Example #21
Source File: random_walk.py From reinforcement-learning-an-introduction with MIT License | 5 votes |
def example_6_2():
    """Render the two Example 6.2 panels and save them.

    The top panel is drawn by ``compute_state_value()`` and the bottom by
    ``rms_error()``. The result is written to ``../images/example_6_2.png``
    and the figure is then closed.
    """
    plt.figure(figsize=(10, 20))
    for position, draw_panel in ((1, compute_state_value), (2, rms_error)):
        plt.subplot(2, 1, position)
        draw_panel()
    plt.tight_layout()
    plt.savefig('../images/example_6_2.png')
    plt.close()
Example #22
Source File: visualization_utils.py From ludwig with Apache License 2.0 | 5 votes |
def compare_classifiers_line_plot(
        xs,
        scores,
        metric,
        algorithm_names=None,
        title=None,
        filename=None
):
    """Plot one line per algorithm of ``metric`` against ``k``.

    Parameters
    ----------
    xs : sequence
        x positions (``k`` values); also used as tick labels.
    scores : sequence of sequences
        ``scores[i]`` holds algorithm i's metric value at each x.
    metric : str
        y-axis label.
    algorithm_names : sequence of str, optional
        Legend labels; missing entries become ``'Algorithm {i}'``.
    title : str, optional
        Axes title.
    filename : str, optional
        When given, the figure is saved there; otherwise it is shown.
    """
    sns.set_style('whitegrid')
    colors = plt.get_cmap('tab10').colors

    fig, ax = plt.subplots()

    ax.grid(which='both')
    ax.grid(which='minor', alpha=0.5)
    ax.grid(which='major', alpha=0.75)

    if title is not None:
        ax.set_title(title)

    ax.set_xticks(xs)
    ax.set_xticklabels(xs)
    ax.set_xlabel('k')
    ax.set_ylabel(metric)

    for i, score in enumerate(scores):
        label = (algorithm_names[i]
                 if algorithm_names is not None and i < len(algorithm_names)
                 else 'Algorithm {}'.format(i))
        # tab10 provides only 10 colours; cycle instead of raising IndexError.
        ax.plot(xs, score, label=label, color=colors[i % len(colors)],
                linewidth=3, marker='o')

    ax.legend(loc='center left', bbox_to_anchor=(1, 0.5))
    plt.tight_layout()
    ludwig.contrib.contrib_command("visualize_figure", plt.gcf())
    if filename:
        plt.savefig(filename)
    else:
        plt.show()
Example #23
Source File: DLC_pupil_event.py From ibllib with MIT License | 5 votes |
def plot_mean_std_around_event(event, diameter, times, eid):
    '''
    event in {'stimOn_times', 'feedback_times', 'stimOff_times'}

    Plot the mean +/- std of ``diameter`` in a window of 70 samples on each
    side of every event time taken from the module-level ``trials[event]``.
    The first and last 5 events are skipped, and any window that would run
    past either end of ``diameter`` is dropped so that all segments have the
    same length.
    '''
    event_times = trials[event]
    window_size = 70
    segments = []
    # skip first and last trials to get same window length
    for t in event_times[5:-5]:
        idx = find_nearest(times, t)  # presumably an integer sample index
        start, stop = idx - window_size, idx + window_size
        # A truncated window (or, with a negative start, a wrapped/empty slice)
        # would make the stacked array ragged; skip such windows.
        if start < 0 or stop > len(diameter):
            continue
        segments.append(diameter[start:stop])

    M = np.nanmean(np.array(segments), axis=0)
    E = np.nanstd(np.array(segments), axis=0)

    fig, ax = plt.subplots()
    ax.fill_between(
        range(len(M)),
        M - E,
        M + E,
        alpha=0.5,
        edgecolor='#CC4F1B',
        facecolor='#FF9848')

    plt.plot(range(len(M)), M, color='k', linewidth=3)
    # The event sits at the centre of each window.
    plt.axvline(x=window_size, color='r', linewidth=1, label=event)
    plt.legend()
    plt.ylabel('pupil diameter [px]')
    plt.xlabel('frames')
    plt.title(eid)
    plt.tight_layout()
Example #24
Source File: plots.py From yatsm with MIT License | 5 votes |
def plot_crossvalidation_scores(kfold_scores, test_labels):
    """ Plots KFold test summary statistics

    Args:
      kfold_scores (np.ndarray): n by 2 shaped array of mean and standard
        deviation of KFold scores
      test_labels (list): n length list of KFold label names

    NOTE(review): the function currently returns immediately, so nothing is
    plotted and it always returns None. The code after the ``return`` is kept
    for when plotting is re-enabled.
    """
    # TODO: plotting is disabled by this early return; remove it to re-enable.
    return

    ind = np.arange(kfold_scores.shape[0])
    width = 0.5

    fig, ax = plt.subplots()
    # align='edge' puts each bar's left edge at `ind`, so `ind + width / 2`
    # (used below for error bars, text and ticks) is the bar centre. With
    # Matplotlib's default align='center' the bars would be centred on `ind`.
    bars = ax.bar(ind, kfold_scores[:, 0], width, align='edge')
    _, caplines, _ = ax.errorbar(ind + width / 2.0, kfold_scores[:, 0],
                                 fmt='none',
                                 yerr=kfold_scores[:, 1],
                                 capsize=10, elinewidth=3)
    for capline in caplines:
        capline.set_linewidth(10)
        capline.set_markeredgewidth(3)
        capline.set_color('red')

    for i, _ in enumerate(bars):
        txt = r'%.3f $\pm$ %.3f' % (kfold_scores[i, 0], kfold_scores[i, 1])
        ax.text(ind[i] + width / 2.0,
                kfold_scores[i, 0] / 2.0,
                txt,
                ha='center', va='bottom', size='large')

    ax.set_xticks(ind + width / 2.0)
    ax.set_xticklabels(test_labels, ha='center')

    plt.title('KFold Cross Validation Summary Statistics')
    plt.xlabel('Test')
    plt.ylabel(r'Accuracy ($\pm$ standard deviation)')

    plt.tight_layout()
    plt.show()
Example #25
Source File: bsds300.py From nsf with MIT License | 5 votes |
def main():
    """Inspect the BSDS300 training split.

    Prints the data's type, shape and value range, draws a histogram of each
    dimension on an 8x8 grid of subplots, then prints the dataset length and
    the number of batches of 128 produced by a DataLoader.
    """
    dataset = BSDS300Dataset(split='train')
    for summary in (type(dataset.data), dataset.data.shape):
        print(summary)
    print(dataset.data.min(), dataset.data.max())

    fig, grid = plt.subplots(8, 8, figsize=(10, 10), sharex=True, sharey=True)
    cells = grid.reshape(-1)
    for position, column in enumerate(dataset.data.T):
        cells[position].hist(column, bins=100)
    plt.tight_layout()
    plt.show()

    print(len(dataset))
    loader = data.DataLoader(dataset, batch_size=128, drop_last=True)
    print(len(loader))
Example #26
Source File: gas.py From nsf with MIT License | 5 votes |
def main():
    """Inspect the Gas training split.

    Prints the data's type, shape, value range and the location of its
    maximum, then draws a histogram of each dimension on a 3x3 grid of
    subplots, printing each dimension index as it goes.
    """
    dataset = GasDataset(split='train')
    for summary in (type(dataset.data), dataset.data.shape):
        print(summary)
    print(dataset.data.min(), dataset.data.max())
    print(np.where(dataset.data == dataset.data.max()))

    fig, grid = plt.subplots(3, 3, figsize=(10, 10), sharex=True, sharey=True)
    cells = grid.reshape(-1)
    for position, column in enumerate(dataset.data.T):
        print(position)
        cells[position].hist(column, bins=100)
    plt.tight_layout()
    plt.show()
Example #27
Source File: visualize.py From dataiku-contrib with Apache License 2.0 | 5 votes |
def plot_overlaps(gt_class_ids, pred_class_ids, pred_scores,
                  overlaps, class_names, threshold=0.5):
    """Draw a grid showing how ground truth objects are classified.
    gt_class_ids: [N] int. Ground truth class IDs
    pred_class_id: [N] int. Predicted class IDs
    pred_scores: [N] float. The probability scores of predicted classes
    overlaps: [pred_boxes, gt_boxes] IoU overlaps of predictions and GT boxes.
    class_names: list of all class names in the dataset
    threshold: Float. The prediction probability required to predict a class

    Class IDs equal to 0 are removed before labelling. Cells whose overlap
    exceeds ``threshold`` are tagged "match" or "wrong" depending on whether
    the predicted and ground-truth class IDs agree.
    """
    gt_class_ids = gt_class_ids[gt_class_ids != 0]
    pred_class_ids = pred_class_ids[pred_class_ids != 0]

    plt.figure(figsize=(12, 10))
    plt.imshow(overlaps, interpolation='nearest', cmap=plt.cm.Blues)
    plt.yticks(np.arange(len(pred_class_ids)),
               ["{} ({:.2f})".format(class_names[int(class_id)], pred_scores[i])
                for i, class_id in enumerate(pred_class_ids)])
    plt.xticks(np.arange(len(gt_class_ids)),
               [class_names[int(class_id)] for class_id in gt_class_ids], rotation=90)

    # Text colour switches to white on the darker half of the colour scale.
    thresh = overlaps.max() / 2.
    for i, j in itertools.product(range(overlaps.shape[0]),
                                  range(overlaps.shape[1])):
        text = ""
        if overlaps[i, j] > threshold:
            text = "match" if gt_class_ids[j] == pred_class_ids[i] else "wrong"
        color = ("white" if overlaps[i, j] > thresh
                 else "black" if overlaps[i, j] > 0
                 else "grey")
        plt.text(j, i, "{:.3f}\n{}".format(overlaps[i, j], text),
                 horizontalalignment="center", verticalalignment="center",
                 fontsize=9, color=color)

    plt.xlabel("Ground Truth")
    plt.ylabel("Predictions")
    # Run after the axis labels exist so tight_layout reserves room for them.
    plt.tight_layout()
Example #28
Source File: cluster.py From 2D-Motion-Retargeting with MIT License | 5 votes |
def cluster_motion(net, cluster_data, device, save_path, nr_anims=15, mode='both'):
    """Embed evenly spaced motion samples and save a 2-D scatter plot.

    ``nr_anims`` samples are taken at evenly spaced indices along axis 0 of
    ``cluster_data[0]``. Which slices are encoded depends on ``mode``
    (``'body'``, ``'view'``, or anything else for both). The encodings come
    from ``net.mot_encoder`` and are projected with ``tsne_on_pca``. Points
    are coloured per sample and labelled with ``cluster_data[1]``.

    Parameters
    ----------
    net :
        Model exposing ``mot_encoder``.
    cluster_data : sequence
        ``cluster_data[0]`` is the 5-D motion data and ``cluster_data[1]`` the
        per-sample names.
    device :
        Device the encoder input is moved to.
    save_path : str
        Output image path.
    nr_anims : int
        Number of samples to plot.
    mode : str
        Slice selection, see above.
    """
    data, animations = cluster_data[0], cluster_data[1]
    idx = np.linspace(0, data.shape[0] - 1, nr_anims, dtype=int).tolist()
    data = data[idx]
    animations = animations[idx]
    if mode == 'body':
        data = data[:, :, 0, :, :].reshape(nr_anims, -1, data.shape[3], data.shape[4])
    elif mode == 'view':
        data = data[:, 3, :, :, :].reshape(nr_anims, -1, data.shape[3], data.shape[4])
    else:
        data = data[:, :4, ::2, :, :].reshape(nr_anims, -1, data.shape[3], data.shape[4])
    nr_anims, nr_cv = data.shape[:2]

    features = net.mot_encoder(data.contiguous().view(-1, data.shape[2], data.shape[3]).to(device))
    features = features.detach().cpu().numpy().reshape(features.shape[0], -1)

    features_2d = tsne_on_pca(features)
    features_2d = features_2d.reshape(nr_anims, nr_cv, -1)
    if features_2d.shape[1] < 5:
        # Duplicate points when there are very few per sample.
        features_2d = np.tile(features_2d, (1, 2, 1))

    fig = plt.figure(figsize=(8, 4))
    colors = cm.rainbow(np.linspace(0, 1, nr_anims))
    for i in range(nr_anims):
        x = features_2d[i, :, 0]
        y = features_2d[i, :, 1]
        plt.scatter(x, y, c=colors[i], label=animations[i])

    plt.legend(bbox_to_anchor=(1.04, 1), borderaxespad=0)
    # Leave the right 20% of the figure for the legend.
    plt.tight_layout(rect=[0, 0, 0.8, 1])
    plt.savefig(save_path)
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)
Example #29
Source File: cluster.py From 2D-Motion-Retargeting with MIT License | 5 votes |
def cluster_view(net, cluster_data, device, save_path):
    """Embed every view of one randomly chosen slice and save a 2-D scatter plot.

    One index along axis 1 of ``cluster_data[0]`` is chosen at random. The
    remaining data is encoded with ``net.static_encoder`` (or
    ``net.view_encoder`` when the former is absent), dropping the last two
    rows along axis 1 of the flattened input. It is then projected with
    ``tsne_on_pca(..., is_PCA=False)`` and plotted with one colour per view,
    labelled with ``cluster_data[3]``.
    """
    data, views = cluster_data[0], cluster_data[3]
    # randint's upper bound is exclusive, so pass the size itself to allow
    # every index (including the last, and the only one when the size is 1).
    idx = np.random.randint(data.shape[1])
    data = data[:, idx, :, :, :]

    nr_mc, nr_view = data.shape[0], data.shape[1]

    encoder = net.static_encoder if hasattr(net, 'static_encoder') else net.view_encoder
    features = encoder(data.contiguous().view(-1, data.shape[2], data.shape[3])[:, :-2, :].to(device))
    features = features.detach().cpu().numpy().reshape(features.shape[0], -1)

    features_2d = tsne_on_pca(features, is_PCA=False)
    features_2d = features_2d.reshape(nr_mc, nr_view, -1)

    fig = plt.figure(figsize=(7, 4))
    colors = cm.rainbow(np.linspace(0, 1, nr_view))
    for i in range(nr_view):
        x = features_2d[:, i, 0]
        y = features_2d[:, i, 1]
        plt.scatter(x, y, c=colors[i], label=views[i])

    plt.legend(bbox_to_anchor=(1.04, 1), borderaxespad=0)
    # Leave the right 25% of the figure for the legend.
    plt.tight_layout(rect=[0, 0, 0.75, 1])
    plt.savefig(save_path)
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)
Example #30
Source File: cluster.py From 2D-Motion-Retargeting with MIT License | 5 votes |
def cluster_body(net, cluster_data, device, save_path):
    """Embed every character's data and save a 2-D scatter plot.

    Index 0 along axis 2 of ``cluster_data[0]`` is selected. The data is
    encoded with ``net.static_encoder`` (or ``net.body_encoder`` when the
    former is absent), dropping the last two rows along axis 1 of the
    flattened input. It is then projected with
    ``tsne_on_pca(..., is_PCA=False)`` and plotted with one colour per
    character, labelled with ``cluster_data[2]``.
    """
    data, characters = cluster_data[0], cluster_data[2]
    data = data[:, :, 0, :, :]

    nr_mv, nr_char = data.shape[0], data.shape[1]

    encoder = net.static_encoder if hasattr(net, 'static_encoder') else net.body_encoder
    features = encoder(data.contiguous().view(-1, data.shape[2], data.shape[3])[:, :-2, :].to(device))
    features = features.detach().cpu().numpy().reshape(features.shape[0], -1)

    features_2d = tsne_on_pca(features, is_PCA=False)
    features_2d = features_2d.reshape(nr_mv, nr_char, -1)

    fig = plt.figure(figsize=(7, 4))
    colors = cm.rainbow(np.linspace(0, 1, nr_char))
    for i in range(nr_char):
        x = features_2d[:, i, 0]
        y = features_2d[:, i, 1]
        plt.scatter(x, y, c=colors[i], label=characters[i])

    plt.legend(bbox_to_anchor=(1.04, 1), borderaxespad=0)
    # Leave the right 25% of the figure for the legend.
    plt.tight_layout(rect=[0, 0, 0.75, 1])
    plt.savefig(save_path)
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)