Python matplotlib.pyplot.savefig() Examples
The following are 30 code examples of matplotlib.pyplot.savefig().
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example.
You may also want to check out all available functions/classes of the module matplotlib.pyplot, or try the search function.
Example #1
Source File: utils.py From pruning_yolov3 with GNU General Public License v3.0 | 8 votes |
def plot_wh_methods():
    """Compare width-height anchor multiplier curves and save them to ``comparison.png``.

    Plots ``exp(x)`` (labelled 'yolo method') against ``(2 * sigmoid(x)) ** 2``
    and ``(2 * sigmoid(x)) ** 2.5`` for ``x`` in ``np.arange(-4.0, 4.0, .1)``.
    Background: https://github.com/ultralytics/yolov3/issues/168

    Usage: ``from utils.utils import *; plot_wh_methods()``
    """
    grid = np.arange(-4.0, 4.0, .1)
    doubled_sigmoid = torch.sigmoid(torch.from_numpy(grid)).numpy() * 2
    curves = [
        (np.exp(grid), 'yolo method'),
        (doubled_sigmoid ** 2, '^2 power method'),
        (doubled_sigmoid ** 2.5, '^2.5 power method'),
    ]

    fig = plt.figure(figsize=(6, 3), dpi=150)
    for values, curve_name in curves:
        plt.plot(grid, values, '.-', label=curve_name)
    plt.xlim(left=-4, right=4)
    plt.ylim(bottom=0, top=6)
    plt.xlabel('input')
    plt.ylabel('output')
    plt.legend()
    fig.tight_layout()
    # The figure itself is created at dpi=150, but the file is written at dpi=200.
    fig.savefig('comparison.png', dpi=200)
Example #2
Source File: feature_vis.py From transferlearning with MIT License | 8 votes |
def plot_tsne(self, save_eps=False):
    """Plot a 2-D t-SNE embedding of ``self.features``, one text glyph per sample.

    Each sample is drawn as the string form of ``self.labels[i]`` and coloured
    with ``plt.cm.Set1(label / 10.)``. Before drawing, the embedding is min-max
    scaled per axis into [0, 1]. Tick marks are hidden.

    Args:
        save_eps: if True, also save the figure as ``tsne.eps`` (600 dpi, EPS
            format) before ``plt.show()`` is called.
    """
    embedding = TSNE(n_components=2, init='pca', random_state=0).fit_transform(self.features)
    low = np.min(embedding, 0)
    high = np.max(embedding, 0)
    scaled = (embedding - low) / (high - low)
    del embedding  # release the unscaled copy early

    for idx, (px, py) in enumerate(scaled):
        label = self.labels[idx]
        plt.text(px, py, str(label),
                 color=plt.cm.Set1(label / 10.),
                 fontdict={'weight': 'bold', 'size': 9})

    plt.xticks([])
    plt.yticks([])
    plt.title('T-SNE')
    if save_eps:
        plt.savefig('tsne.eps', dpi=600, format='eps')
    plt.show()
Example #3
Source File: plotFigures.py From fullrmc with GNU Affero General Public License v3.0 | 7 votes |
def plot(PDF, figName, imgpath, show=False, save=True):
    """Plot experimental vs. computed PDF curves of a constraint; optionally save/show.

    Draws three kinds of curve:
    - ``PDF.experimentalPDF`` as red dots;
    - the total computed ``output["pdf"]`` as a thick black line;
    - every component whose key contains ``"inter"``, styled with successive
      entries of the module-level ``STYLE`` list.

    The title shows ``PDF.squaredDeviations``.

    Args:
        PDF: constraint object exposing ``get_constraint_value()``,
            ``experimentalDistances``, ``experimentalPDF``, ``shellsCenter`` and
            ``squaredDeviations``.
        figName: path passed to ``plt.savefig`` when ``save`` is true.
        imgpath: accepted for interface compatibility; not used here.
        show: call ``plt.show()`` after (optionally) saving.
        save: write the figure to ``figName``.

    The current figure is always closed before returning.
    """
    output = PDF.get_constraint_value()
    plt.plot(PDF.experimentalDistances, PDF.experimentalPDF, 'ro',
             label="experimental", markersize=7.5, markevery=1)
    plt.plot(PDF.shellsCenter, output["pdf"], 'k', linewidth=3.0,
             markevery=25, label="total")
    styleIndex = 0
    for key, val in output.items():
        if key in ("pdf_total", "pdf"):
            continue
        if "inter" in key:
            # NOTE(review): assumes every "inter" key contains 'rdf_inter_';
            # otherwise the [1] lookup raises IndexError -- confirm against the
            # constraint's key naming.
            plt.plot(PDF.shellsCenter, val, STYLE[styleIndex], markevery=5,
                     label=key.split('rdf_inter_')[1])
            styleIndex += 1
    plt.legend(frameon=False, ncol=1)
    # Raw strings keep the TeX backslashes without invalid-escape warnings
    # ("\A" is not a valid Python escape sequence).
    plt.title(r"$\chi^{2}=%.6f$" % PDF.squaredDeviations, size=20)
    plt.xlabel(r"$r (\AA)$", size=20)
    plt.ylabel("$g(r)$", size=20)
    if save:
        plt.savefig(figName)
    if show:
        plt.show()
    plt.close()
Example #4
Source File: utils.py From pruning_yolov3 with GNU General Public License v3.0 | 7 votes |
def plot_evolution_results(hyp):
    """Plot hyperparameter evolution results from ``evolve.txt`` into ``evolve.png``.

    Usage: ``from utils.utils import *; plot_evolution_results(hyp)``

    For the i-th key of ``hyp`` (in iteration order), column ``i + 5`` of
    ``evolve.txt`` is plotted against the fitness returned by ``fitness(x)``.
    Each key gets one subplot on a 4x5 grid. The value from the row with the
    highest fitness is highlighted and printed.

    Note: ``matplotlib.rc('font', size=8)`` changes the global font size for
    the rest of the session.
    """
    x = np.loadtxt('evolve.txt', ndmin=2)
    f = fitness(x)
    # Best single result; both values are loop-invariant, so compute them once.
    best_row = f.argmax()
    best_fitness = f.max()
    fig = plt.figure(figsize=(12, 10))
    matplotlib.rc('font', **{'size': 8})
    for i, k in enumerate(hyp):
        y = x[:, i + 5]
        mu = y[best_row]
        plt.subplot(4, 5, i + 1)
        plt.plot(mu, best_fitness, 'o', markersize=10)
        plt.plot(y, f, '.')
        plt.title('%s = %.3g' % (k, mu), fontdict={'size': 9})
        print('%15s: %.3g' % (k, mu))
    fig.tight_layout()
    plt.savefig('evolve.png', dpi=200)
Example #5
Source File: utils.py From pruning_yolov3 with GNU General Public License v3.0 | 6 votes |
def plot_results_overlay(start=0, stop=0):
    """Plot each ``results*.txt`` file with train and val curves overlaid.

    Usage: ``from utils.utils import *; plot_results_overlay()``

    Files are matched in the working directory and in ``../../Downloads``.
    For each file, five panels are drawn (GIoU, Objectness, Classification,
    P-R, mAP-F1), each overlaying two columns of the file. The figure is saved
    next to the file, with ``.txt`` replaced by ``.png``.

    Args:
        start: first row index to plot.
        stop: one past the last row index to plot; 0 means "to the end".
    """
    s = ['train', 'train', 'train', 'Precision', 'mAP', 'val', 'val', 'val', 'Recall', 'F1']  # legends
    t = ['GIoU', 'Objectness', 'Classification', 'P-R', 'mAP-F1']  # titles
    for f in sorted(glob.glob('results*.txt') + glob.glob('../../Downloads/results*.txt')):
        results = np.loadtxt(f, usecols=[2, 3, 4, 8, 9, 12, 13, 14, 10, 11], ndmin=2).T
        n = results.shape[1]  # number of rows
        x = range(start, min(stop, n) if stop else n)
        fig, ax = plt.subplots(1, 5, figsize=(14, 3.5))
        ax = ax.ravel()
        for i in range(5):
            # Panel i overlays column i and its partner column i + 5.
            for j in (i, i + 5):
                # Fancy indexing returns a copy, so the NaN edit below stays local.
                y = results[j, x]
                if i < 3:
                    y[y == 0] = np.nan  # don't show zero loss values
                ax[i].plot(x, y, marker='.', label=s[j])
            ax[i].set_title(t[i])
            ax[i].legend()
            if i == 0:
                ax[i].set_ylabel(f)  # add filename
        fig.tight_layout()
        fig.savefig(f.replace('.txt', '.png'), dpi=200)
Example #6
Source File: utils.py From scanorama with MIT License | 6 votes |
def visualize_cluster(coords, cluster, cluster_labels, cluster_name=None,
                      size=1, viz_prefix='vc', image_suffix='.svg'):
    """Scatter-plot ``coords``, highlighting the points assigned to ``cluster``.

    Points whose label equals ``cluster`` are drawn in blue, on top of all
    other points, which are drawn in grey. The figure is titled with
    ``cluster_name`` (or ``cluster`` when that is falsy). It is saved at
    500 dpi to ``'<viz_prefix>_cluster<cluster><image_suffix>'``.
    """
    title = cluster_name if cluster_name else cluster

    members, others = [], []
    for pos, lab in enumerate(cluster_labels):
        if lab == cluster:
            members.append(pos)
        else:
            others.append(pos)

    grey, blue = np.array(['#cccccc', '#377eb8'])
    image_fname = '{}_cluster{}{}'.format(viz_prefix, cluster, image_suffix)

    plt.figure()
    plt.scatter(coords[others, 0], coords[others, 1], c=grey, s=size)
    plt.scatter(coords[members, 0], coords[members, 1], c=blue, s=size)
    plt.title(str(title))
    plt.savefig(image_fname, dpi=500)
Example #7
Source File: utils.py From pruning_yolov3 with GNU General Public License v3.0 | 6 votes |
def plot_images(imgs, targets, paths=None, fname='images.jpg'):
    """Save a grid of up to 16 training images overlaid with their target boxes.

    Args:
        imgs: image batch tensor, moved to CPU with ``.cpu().numpy()``. It is
            unpacked below as ``(batch, channel, height, width)``.
        targets: target tensor. Column 0 selects the image index. Columns 2:6
            are passed to ``xywh2xyxy`` and then scaled by width/height
            (presumably normalized xywh -- confirm against the data loader).
        paths: optional per-image file paths. The file name, truncated to 40
            characters, becomes each subplot's title.
        fname: output image path.
    """
    imgs = imgs.cpu().numpy()
    targets = targets.cpu().numpy()

    fig = plt.figure(figsize=(10, 10))
    bs, _, h, w = imgs.shape  # batch size, _, height, width
    bs = min(bs, 16)  # limit plot to 16 images
    # subplot() needs integer grid sizes. np.ceil returns a float, which newer
    # matplotlib versions reject.
    ns = int(np.ceil(bs ** 0.5))  # subplots per row/column
    for i in range(bs):
        boxes = xywh2xyxy(targets[targets[:, 0] == i, 2:6]).T
        boxes[[0, 2]] *= w
        boxes[[1, 3]] *= h
        plt.subplot(ns, ns, i + 1).imshow(imgs[i].transpose(1, 2, 0))
        # Closed polyline through the box corners (rows taken as x1, y1, x2, y2).
        plt.plot(boxes[[0, 2, 2, 0, 0]], boxes[[1, 1, 3, 3, 1]], '.-')
        plt.axis('off')
        if paths is not None:
            s = Path(paths[i]).name
            plt.title(s[:40], fontdict={'size': 8})  # limit to 40 characters
    fig.tight_layout()
    fig.savefig(fname, dpi=200)
    plt.close(fig)
Example #8
Source File: utils.py From pruning_yolov3 with GNU General Public License v3.0 | 6 votes |
def plot_test_txt():
    """Save 2-D and 1-D histograms of the box centres read from ``test.txt``.

    Usage: ``from utils.utils import *; plot_test_txt()``

    The first four columns of ``test.txt`` are converted with ``xyxy2xywh``.
    The resulting x/y centres are plotted twice:
    - as a joint 2-D histogram, saved to ``hist2d.jpg``;
    - as two 1-D histograms, saved to ``hist1d.jpg``.
    """
    table = np.loadtxt('test.txt', dtype=np.float32)
    xywh = xyxy2xywh(table[:, :4])
    centre_x = xywh[:, 0]
    centre_y = xywh[:, 1]

    joint_fig, joint_ax = plt.subplots(1, 1, figsize=(6, 6))
    joint_ax.hist2d(centre_x, centre_y, bins=600, cmax=10, cmin=0)
    joint_ax.set_aspect('equal')
    joint_fig.tight_layout()
    plt.savefig('hist2d.jpg', dpi=300)

    marginal_fig, (ax_x, ax_y) = plt.subplots(1, 2, figsize=(12, 6))
    ax_x.hist(centre_x, bins=600)
    ax_y.hist(centre_y, bins=600)
    marginal_fig.tight_layout()
    plt.savefig('hist1d.jpg', dpi=200)
Example #9
Source File: plot_part1.py From cs294-112_hws with MIT License | 6 votes |
def plot_13(data):
    """Plot reward series for three runs and save them to ``results/p13.png``.

    ``data`` must unpack into exactly four runs; the first is ignored. For each
    remaining run, ``add_plot`` is called for 'MeanReward100Episodes' and for
    'BestMeanReward'. The 'BestMeanReward' call also gets a label:
    - the third run is labelled 'gamma = 0.9';
    - the second run is labelled 'gamma = 0.99';
    - the fourth run is labelled 'gamma = 0.999'.

    Runs are drawn in that order.
    """
    _, second, third, fourth = data
    plt.figure()
    for run, gamma_label in ((third, 'gamma = 0.9'),
                             (second, 'gamma = 0.99'),
                             (fourth, 'gamma = 0.999')):
        add_plot(run, 'MeanReward100Episodes')
        add_plot(run, 'BestMeanReward', gamma_label)
    plt.legend()
    plt.xlabel('Time step')
    plt.ylabel('Reward')
    plt.savefig(os.path.join('results', 'p13.png'),
                bbox_inches='tight', transparent=True, pad_inches=0.1)
Example #10
Source File: utils.py From DeepLung with GNU General Public License v3.0 | 6 votes |
def plotnoduledist(annopath):
    """Save a density-normalized histogram of nodule diameters to ``nodulediamhist.png``.

    Reads the ``diameter_mm`` column from ``train/``, ``val/`` and
    ``test/annotations.csv`` under ``annopath``. Paths are built by plain string
    concatenation, so ``annopath`` must end with a path separator.
    """
    import pandas as pd

    per_split = []
    for split in ('train', 'val', 'test'):
        df = pd.read_csv(annopath + split + '/annotations.csv')
        # Series.reshape was removed from pandas; reshape the underlying array.
        per_split.append(np.asarray(df['diameter_mm']).reshape((-1, 1)))
    # Stack in the order test, val, train (same ordering as before).
    diameter = np.vstack(per_split[::-1])

    plt.figure()
    # `normed` was removed from hist() in matplotlib 3.1; `density` replaces it.
    plt.hist(diameter, density=True, bins=50)
    plt.ylabel('probability')
    plt.xlabel('Diameters')
    plt.title('Nodule Diameters Histogram')
    plt.savefig('nodulediamhist.png')
Example #11
Source File: plot_3.py From cs294-112_hws with MIT License | 6 votes |
def main():
    """Command-line entry point: plot the datasets under ``logdir`` and save the figure.

    The figure is written to ``results/<save_name>.png``; ``save_name``
    defaults to 'results'. The ``results`` directory is created if missing.
    """
    import argparse

    cli = argparse.ArgumentParser()
    cli.add_argument('logdir', type=str)
    cli.add_argument('--save_name', type=str, default='results')
    opts = cli.parse_args()

    out_dir = 'results'
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

    datasets = get_datasets(opts.logdir)
    plot_3(datasets)
    out_path = os.path.join(out_dir, opts.save_name + '.png')
    plt.savefig(out_path, bbox_inches='tight', transparent=True, pad_inches=0.1)
Example #12
Source File: utils.py From kss with Apache License 2.0 | 6 votes |
def plot_alignment(alignment, gs, dir=hp.logdir):
    """Plots the alignment and saves it to ``<dir>/alignment_<gs>.png``.

    Args:
      alignment: A numpy array with shape of (encoder_steps, decoder_steps)
      gs: (int) global step.
      dir: Output path. Created with ``os.mkdir`` (one level only) if missing.
    """
    if not os.path.exists(dir):
        os.mkdir(dir)

    fig, ax = plt.subplots()
    im = ax.imshow(alignment)

    fig.colorbar(im)
    plt.title('{} Steps'.format(gs))
    plt.savefig('{}/alignment_{}.png'.format(dir, gs), format='png')
    # Release the figure. pyplot keeps every open figure alive, and this is
    # presumably called once per logged step.
    plt.close(fig)
Example #13
Source File: massachusetts_road_segm.py From Recipes with MIT License | 6 votes |
def plot_some_results(pred_fn, test_generator, n_images=10):
    """Save input / ground-truth / prediction panels for the first test patches.

    Each sample yields one figure, ``road_segmentation_result_XXX.png``, where
    XXX is a zero-padded counter. The three panels are:
    - the input patch, transposed to channels-last for ``imshow``;
    - ``seg[0]``;
    - the prediction returned by ``pred_fn``.

    Args:
        pred_fn: callable mapping a batch of inputs to a batch of predictions.
        test_generator: iterable of ``(data, seg)`` batches.
        n_images: number of figures to write. Iteration over the generator
            stops as soon as this many have been saved.
    """
    if n_images <= 0:
        return
    fig_ctr = 0
    for data, seg in test_generator:
        res = pred_fn(data)
        for d, s, r in zip(data, seg, res):
            plt.figure(figsize=(12, 6))
            panels = (
                (d.transpose(1, 2, 0), "input patch"),
                (s[0], "ground truth"),
                (r, "segmentation"),
            )
            for col, (img, title) in enumerate(panels, start=1):
                plt.subplot(1, 3, col)
                plt.imshow(img)
                plt.title(title)
            plt.savefig("road_segmentation_result_%03.0f.png" % fig_ctr)
            plt.close()
            fig_ctr += 1
            if fig_ctr >= n_images:
                # Return rather than break: break only left the inner loop,
                # so the outer loop kept pulling batches and writing images.
                return
Example #14
Source File: utils.py From dc_tts with Apache License 2.0 | 6 votes |
def plot_alignment(alignment, gs, dir=hp.logdir):
    """Save a heat-map of an attention alignment to ``<dir>/alignment_<gs>.png``.

    Args:
      alignment: A numpy array with shape of (encoder_steps, decoder_steps)
      gs: (int) global step, shown in the title and used in the file name.
      dir: Output path. Created with ``os.mkdir`` (one level only) if missing.

    The figure is closed after saving.
    """
    if not os.path.exists(dir):
        os.mkdir(dir)

    figure, axes = plt.subplots()
    figure.colorbar(axes.imshow(alignment))
    plt.title('{} Steps'.format(gs))
    out_file = '{}/alignment_{}.png'.format(dir, gs)
    plt.savefig(out_file, format='png')
    plt.close(figure)
Example #15
Source File: utils.py From DeepLung with GNU General Public License v3.0 | 6 votes |
def plothistdiameter(trainpath='/media/data1/wentao/tianchi/preprocessing/newtrain/',
                     testpath='/media/data1/wentao/tianchi/preprocessing/newtest/'):
    """Save a 50-bin histogram of nodule diameters to ``processnodulesizehist.png``.

    Every ``*_label.npy`` file in ``trainpath`` and ``testpath`` is loaded, and
    the last column of each row is collected. Paths are joined by string
    concatenation, so both directories must end with a path separator.
    """
    diameterlist = []
    for folder in (trainpath, testpath):
        for fname in os.listdir(folder):
            if not fname.endswith('_label.npy'):
                continue
            label = np.load(folder + fname)
            # Uses `range` instead of the Python-2-only `xrange`. Iterating rows,
            # rather than slicing label[:, -1], keeps empty 1-D label arrays
            # (shape (0,)) working.
            for lidx in range(label.shape[0]):
                diameterlist.append(label[lidx, -1])

    plt.figure()
    plt.hist(diameterlist, 50)
    plt.xlabel('Nodule Diameter')
    plt.ylabel('# Nodules')
    plt.title('Nodule Size Histogram')
    plt.savefig('processnodulesizehist.png')
Example #16
Source File: utils.py From Pytorch-Networks with MIT License | 6 votes |
def plot_result_data(acc_total, acc_val_total, loss_total, losss_val_total, cfg_path, epoch):
    """Save train-vs-val accuracy and loss curves under ``cfg_path``.

    Writes ``<cfg_path>/acc.png``, clears the current axes, then writes
    ``<cfg_path>/loss.png``. Every series is plotted against ``range(epoch)``,
    so each must have exactly ``epoch`` entries.
    """
    import matplotlib.pyplot as plt

    epochs = range(epoch)

    def _draw(train_series, val_series, names, ylabel):
        # Both curves use the same thin solid style; legend at the upper right.
        for series, name in zip((train_series, val_series), names):
            plt.plot(epochs, series, linestyle="-", linewidth=1, label=name)
        plt.legend(names, loc='upper right')
        plt.xlabel("Training Epoch")
        plt.ylabel(ylabel)

    _draw(acc_total, acc_val_total, ('acc_train', 'acc_val'), "Acc on dataset")
    plt.savefig('{}/acc.png'.format(cfg_path))
    plt.cla()
    _draw(loss_total, losss_val_total, ('loss_train', 'loss_val'), "Loss on dataset")
    plt.savefig('{}/loss.png'.format(cfg_path))
Example #17
Source File: plot_confusion_matrix.py From Chinese-Character-and-Calligraphic-Image-Processing with MIT License | 6 votes |
def plotCM(classes, matrix, savname):
    """Save a row-normalized confusion-matrix image to ``savname``.

    Args:
        classes: sequence of class names (list, tuple or array), used as tick
            labels on both axes.
        matrix: confusion-matrix ndarray, presumably square. Each row is divided
            by its sum; all-zero rows become NaN.
        savname: output path passed to ``plt.savefig``.

    Only the diagonal entries are annotated, as percentages with two decimals.
    Switches pyplot to the non-interactive 'agg' backend (a global side
    effect) and closes the figure after saving.
    """
    # Normalize by row. `np.float` was removed in NumPy 1.24; the builtin
    # `float` gives the same float64 dtype. astype() copies, so the caller's
    # array is not modified.
    matrix = matrix.astype(float)
    matrix /= matrix.sum(axis=1, keepdims=True)

    plt.switch_backend('agg')
    fig = plt.figure()
    ax = fig.add_subplot(111)
    cax = ax.matshow(matrix)
    fig.colorbar(cax)
    ax.xaxis.set_major_locator(MultipleLocator(1))
    ax.yaxis.set_major_locator(MultipleLocator(1))
    for i in range(matrix.shape[0]):
        ax.text(i, i, '%.2f' % (matrix[i, i] * 100), va='center', ha='center')
    # The leading '' presumably pads for a tick the locator places before
    # index 0. list() lets callers pass tuples/arrays as well as lists.
    tick_labels = [''] + list(classes)
    ax.set_xticklabels(tick_labels, rotation=90)
    ax.set_yticklabels(tick_labels)
    plt.savefig(savname)
    plt.close(fig)
Example #18
Source File: m_dos_pdos_eigenvalues.py From pyscf with Apache License 2.0 | 6 votes |
def dosplot(filename=None, data=None, fermi=None):
    """Plot spin-resolved MF and QP densities of states; save ``dos_eigen.svg``.

    Args:
        filename: text file loaded with ``np.loadtxt``; takes precedence over ``data``.
        data: 2-D array used when ``filename`` is None. After transposing, its
            rows are used as energy, MF spin-up, QP spin-up, MF spin-down and
            QP spin-down (matching the plot labels). Spin-down curves are
            drawn negated.
        fermi: optional energy at which a dashed vertical line is drawn.

    Raises:
        ValueError: if neither ``filename`` nor ``data`` is given.

    Enables LaTeX text and a serif font through ``plt.rc`` (a global side
    effect). Saves the figure, then calls ``plt.show()``.
    """
    if filename is not None:
        data = np.loadtxt(filename)
    elif data is None:
        # Previously this fell through and failed later with an opaque
        # AttributeError on ``None.T``.
        raise ValueError("dosplot needs either filename or data")

    import matplotlib.pyplot as plt
    plt.rc('text', usetex=True)
    plt.rc('font', family='serif')

    columns = data.T
    energy = columns[0]
    # (curve, label, colour, fill alpha, dotted?) per channel.
    # MF curves are dotted and lightly filled; QP curves are solid and more opaque.
    channels = (
        (columns[1], 'MF Spin-UP', 'r', 0.1, True),
        (columns[2], 'QP Spin-UP', 'r', 0.5, False),
        (-columns[3], 'MF Spin-DN', 'b', 0.1, True),
        (-columns[4], 'QP Spin-DN', 'b', 0.5, False),
    )
    for curve, label, color, alpha, dotted in channels:
        style = {'linestyle': ':'} if dotted else {}
        plt.plot(energy, curve, label=label, color=color, **style)
        plt.fill_between(energy, 0, curve, facecolor=color, alpha=alpha, interpolate=True)

    if fermi is not None:
        plt.axvline(x=fermi, color='k', linestyle='--')  # Fermi energy marker
    plt.axhline(y=0, color='k')
    plt.title('Total DOS', fontsize=20)
    plt.xlabel('Energy (eV)', fontsize=15)
    plt.ylabel('Density of States (electron/eV)', fontsize=15)
    plt.legend()
    plt.savefig("dos_eigen.svg", dpi=900)
    plt.show()
Example #19
Source File: malware.py From trees with Apache License 2.0 | 6 votes |
def classify(self, features, show=False):
    """Score every row of ``features`` against each tree in ``self.root``.

    Returns:
        ``scores`` ndarray of shape ``(features.shape[0], len(self.root))``.
        Entry ``[r, t]`` is ``result[0] / sum(result.values())`` for the
        result that tree ``t`` reports for record ``r``. This is presumably
        the fraction of class-0 samples at that leaf -- confirm against the
        module-level ``classify`` helper.

    If ``show`` is true, the current figure is cleared and closed. The score
    matrix is then saved as a grey-scale image to ``../scratch/tree_scores.png``.
    """
    recs, _ = features.shape
    result_shape = (features.shape[0], len(self.root))
    scores = np.zeros(result_shape)
    # Function-call form prints the same text on Python 2 and Python 3.
    print(scores.shape)
    R = Record(np.arange(recs, dtype=int), features)
    for i, T in enumerate(self.root):
        # `classify` here is the module-level function, not this method.
        for idxs, result in classify(T, R):
            for idx in idxs.indexes():
                scores[idx, i] = float(result[0]) / sum(result.values())
    if show:
        plt.cla()
        plt.clf()
        plt.close()
        plt.imshow(scores, cmap=plt.cm.gray)
        plt.title('Scores matrix')
        plt.savefig(r"../scratch/tree_scores.png", bbox_inches='tight')
    return scores
Example #20
Source File: SA.py From sopt with MIT License | 5 votes |
def save_plot(self, save_name="SA.png"):
    """Save the best-target-per-step curve of this SA run to ``save_name``.

    ``self.generations_best_targets`` is drawn as a red line. The title
    includes ``self.steps``.
    """
    best_per_step = self.generations_best_targets
    plt.plot(best_per_step, 'r-')
    plt.xlabel("steps")
    plt.ylabel("best target function value")
    plt.title("SA %d steps simulation" % self.steps)
    plt.savefig(save_name)
Example #21
Source File: Gradients.py From sopt with MIT License | 5 votes |
def save_plot(self, save_name="Adam.png"):
    """Save the target-value-per-epoch curve of this Adam run to ``save_name``.

    ``self.generations_targets`` is drawn as a red line. The title includes
    ``self.epochs``.
    """
    plt.plot(self.generations_targets, 'r-')
    plt.xlabel("epochs")
    plt.ylabel("target function value")
    # This was plt.plot(<string>), which does not set a title; plt.title does.
    plt.title("Adam with %d epochs" % self.epochs)
    plt.savefig(save_name)
Example #22
Source File: malware.py From trees with Apache License 2.0 | 5 votes |
def ROC(scores, labels, names, name="STD"):
    """Compute ROC statistics for ``scores`` and graph them under ``name``.

    Delegates to ``ROC_data``, which returns the max accuracy and the
    true-/false-positive series, and then to ``graph_ROC``. ``graph_ROC``
    receives single-element lists (presumably one entry per curve).
    """
    max_ACC, TP, FP = ROC_data(scores, labels, names, name)
    graph_ROC([max_ACC], [TP], [FP], name)
Example #23
Source File: logutil.py From Depth-Map-Prediction with GNU General Public License v3.0 | 5 votes |
def save_fig(fn, *args, **kwargs):
    '''
    Save the current matplotlib figure to the file ``fn``.
    args same as for pyplot.savefig().

    ``fn`` is opened as given, so relative paths resolve against the working
    directory (presumably the current output dir set up by the caller --
    confirm). The file is opened in binary mode because savefig writes bytes;
    text mode fails on Python 3. Since a file object is passed, the output
    format comes from ``format=`` or rcParams, not from ``fn``'s extension.
    '''
    with open(fn, 'wb') as f:
        pyplot.savefig(f, *args, **kwargs)
Example #24
Source File: gail-eval.py From HardRLWithYoutube with MIT License | 5 votes |
def plot(env_name, bc_log, gail_log, stochastic):
    """Save unnormalized and normalized expert/BC/GAIL score plots for ``env_name``.

    Both figures use ``CONFIG['traj_limitation']`` as the x-axis. They are
    written to
    ``result/<env_name>-{unnormalized,normalized}-{stochastic,deterministic}-scores.png``.

    Args:
        env_name: environment name used in titles and file names.
        bc_log: mapping with 'upper_bound', 'avg_ret' and 'normalized_ret'.
        gail_log: mapping with 'avg_ret' and 'normalized_ret'.
        stochastic: selects the 'stochastic' or 'deterministic' file-name tag.
    """
    traj_limitation = CONFIG['traj_limitation']
    policy_tag = 'stochastic' if stochastic else 'deterministic'

    def _save_comparison(expert, bc, gail, ylabel, score_kind, ylim=None):
        # Draw the expert / BC / GAIL curves on the current figure, save it, close it.
        plt.plot(traj_limitation, expert)
        plt.plot(traj_limitation, bc)
        plt.plot(traj_limitation, gail)
        plt.xlabel('Number of expert trajectories')
        plt.ylabel(ylabel)
        plt.title('{} {} scores'.format(env_name, score_kind))
        plt.legend(['expert', 'bc-imitator', 'gail-imitator'], loc='lower right')
        # The `b=` keyword was renamed `visible=` in matplotlib 3.5 and later
        # removed. Passing it positionally works on old and new versions.
        plt.grid(True, which='major', color='gray', linestyle='--')
        if ylim is not None:
            plt.ylim(*ylim)
        plt.savefig('result/{}-{}-{}-scores.png'.format(env_name, score_kind, policy_tag))
        plt.close()

    _save_comparison(bc_log['upper_bound'], bc_log['avg_ret'], gail_log['avg_ret'],
                     'Accumulated reward', 'unnormalized')
    # On the normalized plot the expert baseline is drawn as a constant 1.
    _save_comparison(np.ones(len(traj_limitation)), bc_log['normalized_ret'],
                     gail_log['normalized_ret'], 'Normalized performance', 'normalized',
                     ylim=(0, 1.6))
Example #25
Source File: mujoco_dset.py From HardRLWithYoutube with MIT License | 5 votes |
def plot(self):
    """Save a histogram of ``self.rets`` to ``histogram_rets.png``, then close the figure."""
    from matplotlib import pyplot

    returns = self.rets
    pyplot.hist(returns)
    pyplot.savefig("histogram_rets.png")
    pyplot.close()
Example #26
Source File: plotting.py From medicaldetectiontoolkit with Apache License 2.0 | 5 votes |
def plot_stat_curves(stats, outfile):
    """Save one ROC and one PRC figure overlaying every entry of ``stats``.

    Each entry is indexed by ``'name'`` and by the curve key (``'roc'`` or
    ``'prc'``). An entry whose curve value is None is skipped. Otherwise, its
    ``[0]`` and ``[1]`` elements are plotted as x and y. The figures are saved
    to ``outfile + '_roc'`` and ``outfile + '_prc'``.
    """
    # curve key -> (legend location code, x-axis label)
    layout = {'roc': (4, '1-spec.'), 'prc': (3, 'precision')}
    for curve in ['roc', 'prc']:
        legend_loc, x_label = layout[curve]
        plt.figure()
        for entry in stats:
            points = entry[curve]
            if points is None:
                continue
            plt.plot(points[0], points[1], label=entry['name'] + '_' + curve)
        plt.title(outfile.split('/')[-1] + '_' + curve)
        plt.legend(loc=legend_loc)
        plt.xlabel(x_label)
        plt.ylabel('recall')
        plt.savefig(outfile + '_' + curve)
        plt.close()
Example #27
Source File: plotting.py From medicaldetectiontoolkit with Apache License 2.0 | 5 votes |
def plot_prediction_hist(label_list, pred_list, type_list, outfile):
    """
    Plot a log-scale histogram of prediction scores for a specific class.

    :param label_list: list of 1s and 0s. 1 marks a true-positive match and 0 a
        false positive. False negatives (missed ground-truth objects) are added
        as artificial predictions with score 0 and label 1.
    :param pred_list: list of prediction scores.
    :param type_list: list of prediction types, or None. When given, the title
        also shows the counts of 'det_tp', 'det_fp', 'det_fn' and their
        tp+fn total.
    :param outfile: output path. Its last '/'-separated component starts the title.
    """
    preds = np.array(pred_list)
    labels = np.array(label_list)
    title = '{} count:{}'.format(outfile.split('/')[-1], len(label_list))
    plt.figure()
    plt.yscale('log')
    # (label value, colour, legend text) for each histogram layer
    layers = (
        (0, 'g', 'false pos.'),
        (1, 'b', 'true pos. (false neg. @ score=0)'),
    )
    for value, colour, legend_text in layers:
        if value in labels:
            plt.hist(preds[labels == value], alpha=0.3, color=colour,
                     range=(0, 1), bins=50, label=legend_text)
    if type_list is not None:
        n_tp, n_fp, n_fn = [type_list.count(kind) for kind in ('det_tp', 'det_fp', 'det_fn')]
        title += ' tp:{} fp:{} fn:{} pos:{}'.format(n_tp, n_fp, n_fn, n_fn + n_tp)
    plt.legend()
    plt.title(title)
    plt.xlabel('confidence score')
    plt.ylabel('log n')
    plt.savefig(outfile)
    plt.close()
Example #28
Source File: plotting.py From medicaldetectiontoolkit with Apache License 2.0 | 5 votes |
def update_and_save(self, metrics, epoch):
    """Redraw every monitoring figure with ``metrics`` for ``epoch`` and save it.

    Figure ``k`` of ``self.figure_list`` is drawn by ``detection_monitoring_plot``
    on its ``ax1`` axes and saved to ``self.file_name + '_<k>'``.
    """
    for figure_ix, fig in enumerate(self.figure_list):
        detection_monitoring_plot(fig.ax1, metrics, self.exp_name,
                                  self.color_palette, epoch, figure_ix,
                                  self.separate_values_dict, self.do_validation)
        fig.savefig(self.file_name + '_{}'.format(figure_ix))
Example #29
Source File: malware.py From trees with Apache License 2.0 | 5 votes |
def visualizedistances(data, figname=None):
    """Draw the distance matrix with rows and columns sorted by the first label column.

    Args:
        data: ``(D, L, N)`` triple. ``D`` is the distance matrix (presumably
            square), reordered on both axes. ``L`` is a 2-D label array whose
            column 0 gives the sort order. ``N`` is unused.
        figname: output path for the grey-scale image. If None, nothing is
            saved and the figure is left as the current pyplot figure.
    """
    D, L, _ = data
    order = np.argsort(L[:, 0])
    D2 = D[order, :][:, order]
    # Start from a clean slate: clear and close whatever figure is current.
    plt.cla()
    plt.clf()
    plt.close()
    plt.imshow(D2, cmap=plt.cm.gray)
    plt.title('Distance matrix')
    if figname is None:
        # With the default argument, plt.savefig(None) would raise inside
        # matplotlib; there is no path to write to, so skip saving.
        return
    plt.savefig(figname, bbox_inches='tight')
Example #30
Source File: gail-eval.py From lirpg with MIT License | 5 votes |
def plot(env_name, bc_log, gail_log, stochastic):
    """Save unnormalized and normalized expert/BC/GAIL score plots for ``env_name``.

    Both figures use ``CONFIG['traj_limitation']`` as the x-axis. They are
    written to
    ``result/<env_name>-{unnormalized,normalized}-{stochastic,deterministic}-scores.png``.

    Args:
        env_name: environment name used in titles and file names.
        bc_log: mapping with 'upper_bound', 'avg_ret' and 'normalized_ret'.
        gail_log: mapping with 'avg_ret' and 'normalized_ret'.
        stochastic: selects the 'stochastic' or 'deterministic' file-name tag.
    """
    traj_limitation = CONFIG['traj_limitation']
    policy_tag = 'stochastic' if stochastic else 'deterministic'

    def _save_comparison(expert, bc, gail, ylabel, score_kind, ylim=None):
        # Draw the expert / BC / GAIL curves on the current figure, save it, close it.
        plt.plot(traj_limitation, expert)
        plt.plot(traj_limitation, bc)
        plt.plot(traj_limitation, gail)
        plt.xlabel('Number of expert trajectories')
        plt.ylabel(ylabel)
        plt.title('{} {} scores'.format(env_name, score_kind))
        plt.legend(['expert', 'bc-imitator', 'gail-imitator'], loc='lower right')
        # The `b=` keyword was renamed `visible=` in matplotlib 3.5 and later
        # removed. Passing it positionally works on old and new versions.
        plt.grid(True, which='major', color='gray', linestyle='--')
        if ylim is not None:
            plt.ylim(*ylim)
        plt.savefig('result/{}-{}-{}-scores.png'.format(env_name, score_kind, policy_tag))
        plt.close()

    _save_comparison(bc_log['upper_bound'], bc_log['avg_ret'], gail_log['avg_ret'],
                     'Accumulated reward', 'unnormalized')
    # On the normalized plot the expert baseline is drawn as a constant 1.
    _save_comparison(np.ones(len(traj_limitation)), bc_log['normalized_ret'],
                     gail_log['normalized_ret'], 'Normalized performance', 'normalized',
                     ylim=(0, 1.6))