Python seaborn.axes_style() Examples
The following are 30 code examples of seaborn.axes_style().
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example.
You may also want to check out all available functions/classes of the module seaborn, or try the search function.
Example #1
Source File: plot_functions.py From idea_relations with MIT License | 7 votes |
def joint_plot(x, y, xlabel=None, ylabel=None, xlim=None, ylim=None, loc="best",
               color='#0485d1', size=8, markersize=50, kind="kde", scatter_color="r"):
    """Draw a KDE joint plot of ``x`` vs ``y`` with a subsampled scatter overlay.

    When both ``xlabel`` and ``ylabel`` are given, the data are wrapped in a
    DataFrame keyed by those labels so they become the axis names; otherwise
    the raw sequences are handed to ``SubsampleJointGrid`` directly.

    Note: ``loc``, ``color``, ``markersize`` and ``kind`` are accepted but not
    used by this implementation.

    Returns:
        The ``SubsampleJointGrid`` that was drawn on.
    """
    grid_kws = dict(space=0.1, ratio=2, size=size, xlim=xlim, ylim=ylim)
    with sns.axes_style("darkgrid"):
        if xlabel and ylabel:
            frame = DataFrame(data={xlabel: x, ylabel: y})
            grid = SubsampleJointGrid(xlabel, ylabel, data=frame, **grid_kws)
        else:
            grid = SubsampleJointGrid(x, y, **grid_kws)
        grid.plot_joint(sns.kdeplot, shade=True, cmap="Blues")
        grid.plot_sub_joint(plt.scatter, 1000, s=20, c=scatter_color, alpha=0.3)
        grid.plot_marginals(sns.distplot, kde=False, rug=False)
        grid.annotate(ss.pearsonr, fontsize=25,
                      template="{stat} = {val:.2g}\np = {p:.2g}")
        # Replace tick labels with the raw tick values
        # (presumably to avoid offset/scientific formatting — confirm).
        joint_ax = grid.ax_joint
        joint_ax.set_yticklabels(joint_ax.get_yticks())
        joint_ax.set_xticklabels(joint_ax.get_xticks())
    return grid
Example #2
Source File: plotting_utils.py From QUANTAXIS with MIT License | 6 votes |
def customize(func):
    """Decorator that sets the plotting context and style for the output figure.

    The wrapped function accepts an extra keyword ``set_context`` (default
    True), which is removed before ``func`` is called. When it is true, the
    call runs inside ``plotting_context()``, ``axes_style()`` and the seaborn
    "colorblind" palette, after ``sns.despine(left=True)``. When it is false,
    ``func`` is called with no styling applied.
    """
    @wraps(func)
    def call_w_context(*args, **kwargs):
        if not kwargs.pop("set_context", True):
            return func(*args, **kwargs)
        palette = sns.color_palette("colorblind")
        with plotting_context(), axes_style(), palette:
            sns.despine(left=True)
            return func(*args, **kwargs)
    return call_w_context
Example #3
Source File: plotting_utils.py From QUANTAXIS with MIT License | 6 votes |
def axes_style(style: str = "darkgrid", rc: dict = None):
    """Create the default axes style.

    :param style: seaborn style name
    :param rc: dict of rc parameter overrides. Entries from ``rc_default`` are
        filled in only for keys the caller did not set. The caller's dict is
        never modified.
    :return: the result of ``sns.axes_style`` (usable as a context manager)
    """
    # Copy so that filling in defaults never mutates the caller's dict.
    rc = {} if rc is None else dict(rc)
    rc_default = {}
    for name, val in rc_default.items():
        # dict provides ``setdefault``. The former ``set_default`` call would
        # raise AttributeError as soon as ``rc_default`` gained an entry.
        rc.setdefault(name, val)
    return sns.axes_style(style=style, rc=rc)
Example #4
Source File: plot_utils.py From jqfactor_analyzer with MIT License | 6 votes |
def customize(func):
    """Decorator applying the package's plotting context and axes style.

    Before every call, Chinese font support is enabled via ``_use_chinese(True)``
    if ``PlotConfig.FONT_SETTED`` is false. The wrapped function accepts an
    extra keyword ``set_context`` (default True), which is removed before
    ``func`` is called. When it is true, ``func`` runs inside
    ``plotting_context()`` and ``axes_style()`` after ``sns.despine(left=True)``.
    """
    @wraps(func)
    def call_w_context(*args, **kwargs):
        if not PlotConfig.FONT_SETTED:
            _use_chinese(True)
        use_context = kwargs.pop('set_context', True)
        if not use_context:
            return func(*args, **kwargs)
        with plotting_context(), axes_style():
            sns.despine(left=True)
            return func(*args, **kwargs)
    return call_w_context
Example #5
Source File: interpretation.py From lumin with Apache License 2.0 | 6 votes |
def plot_embedding(embed:OrderedDict, feat:str, savename:Optional[str]=None,
                   settings:PlotSettings=PlotSettings()) -> None:
    r'''
    Visualise the weights of a categorical entity-embedding matrix as an annotated heatmap

    Arguments:
        embed: state_dict of trained nn.Embedding; its ``'weight'`` entry is plotted
        feat: name of feature embedded, used as the y-axis label
        savename: Optional name of file to which to save the plot
        settings: :class:`~lumin.plotting.plot_settings.PlotSettings` class to control figure appearance
    '''
    lbl_kws = dict(fontsize=settings.lbl_sz, color=settings.lbl_col)
    tick_kws = dict(fontsize=settings.tk_sz, color=settings.tk_col)
    with sns.axes_style(**settings.style):
        plt.figure(figsize=(settings.w_small, settings.h_small))
        sns.heatmap(to_np(embed['weight']), annot=True, fmt='.1f', linewidths=.5,
                    cmap=settings.div_palette, annot_kws={'fontsize': settings.leg_sz})
        plt.xlabel("Embedding", **lbl_kws)
        plt.ylabel(feat, **lbl_kws)
        plt.xticks(**tick_kws)
        plt.yticks(**tick_kws)
        plt.title(settings.title, fontsize=settings.title_sz, color=settings.title_col,
                  loc=settings.title_loc)
        if savename is not None:
            plt.savefig(settings.savepath/f'{savename}{settings.format}', bbox_inches='tight')
        plt.show()
Example #6
Source File: plotting.py From MJHMC with GNU General Public License v2.0 | 6 votes |
def gauss_2d(nsamples=1000):
    """
    Another simple test plot: a 1d gaussian is sampled by the HMCBase control
    sampler (x axis) and by a MarkovJumpHMC sampler (y axis), and the two
    sample sets are shown together as a joint 2d hexbin plot.
    """
    target = TestGaussian(ndims=1)
    baseline = HMCBase(distribution=target)
    mjhmc = MarkovJumpHMC(distribution=target, resample=False)
    with sns.axes_style("white"):
        # Draw the control samples first, then the experimental ones.
        x_samples = baseline.sample(nsamples)[0]
        y_samples = mjhmc.sample(nsamples)[0]
        sns.jointplot(x_samples, y_samples, kind='hex', stat_func=None)
Example #7
Source File: cyclic_callbacks.py From lumin with Apache License 2.0 | 6 votes |
def plot(self):
    r'''
    Plots the history of the lr and momentum evolution as a function of iterations

    Draws two stacked panels: learning rate (``self.hist['lr']``) on top and
    momentum (``self.hist['mom']``) below, then calls ``plt.show()``.
    '''
    ps = self.plot_settings
    # Unpack ``style`` as keyword arguments, matching the other plotting code
    # (``sns.axes_style(**settings.style)``). Passed positionally, seaborn treats
    # the mapping as an rc dict and drops its ``style``/``rc`` keys.
    with sns.axes_style(**ps.style), sns.color_palette(ps.cat_palette):
        _, axs = plt.subplots(2, 1, figsize=(ps.w_mid, ps.h_mid))
        lbl_kws = dict(fontsize=ps.lbl_sz, color=ps.lbl_col)
        axs[1].set_xlabel("Iterations", **lbl_kws)
        axs[0].set_ylabel("Learning Rate", **lbl_kws)
        axs[1].set_ylabel("Momentum", **lbl_kws)
        axs[0].plot(range(len(self.hist['lr'])), self.hist['lr'])
        axs[1].plot(range(len(self.hist['mom'])), self.hist['mom'])
        for ax in axs:
            ax.tick_params(axis='x', labelsize=ps.tk_sz, labelcolor=ps.tk_col)
            ax.tick_params(axis='y', labelsize=ps.tk_sz, labelcolor=ps.tk_col)
        plt.show()
Example #8
Source File: plot_utils.py From jqfactor_analyzer with MIT License | 5 votes |
def axes_style(style='darkgrid', rc=None):
    """Return the seaborn axes style for ``style``.

    Keys from the (currently empty) ``rc_default`` table are added to ``rc``
    only where the caller has not set them. This happens in place on the
    caller's dict.
    """
    rc_default = {}
    if rc is None:
        rc = {}
    for key in rc_default:
        rc.setdefault(key, rc_default[key])
    return sns.axes_style(style=style, rc=rc)
Example #9
Source File: plotfunctions.py From DataScience-webapp-with-flask with MIT License | 5 votes |
def plot_boxplot(ds, cat, num):
    """Render a box plot of ``num`` grouped by ``cat`` and return it as base64 PNG.

    Args:
        ds: DataFrame holding both columns.
        cat: name of the categorical column (x axis).
        num: name of the numeric column (y axis).

    Returns:
        bytes: the PNG image, base64-encoded.
    """
    import base64
    from io import BytesIO

    sns.set()
    plt.gcf().clear()
    with sns.axes_style(style='ticks'):
        sns.factorplot(cat, num, data=ds, kind="box")
    plt.xlabel(cat)
    plt.ylabel(num)
    figfile = BytesIO()
    plt.savefig(figfile, format='png')
    # Close the figure factorplot created (the current one). Otherwise every
    # call leaves another open figure registered in pyplot and memory grows.
    plt.close()
    return base64.b64encode(figfile.getvalue())
Example #10
Source File: rt-heatmap.py From pyrocore with GNU General Public License v2.0 | 5 votes |
def heatmap(self, df, imagefile):
    """ Create the heat map.

    ``self.args`` names the (index, columns, values) columns of ``df``. The
    values column drives the colour map, whose stops are placed at fixed
    values divided by ``cmax``, where ``cmax`` is at least
    ``self.CMAP_MIN_MAX``. The figure is saved to ``imagefile``.
    """
    import seaborn as sns
    import matplotlib.pyplot as plt
    from matplotlib.colors import LinearSegmentedColormap

    sns.set()
    with sns.axes_style('whitegrid'):
        _, ax = plt.subplots(figsize=(5, 11))  # inches
        cmax = max(df[self.args[2]].max(), self.CMAP_MIN_MAX)
        csteps = {
            0.0: 'darkred', 0.3/cmax: 'red', 0.6/cmax: 'orangered', 0.9/cmax: 'coral',
            1.0/cmax: 'skyblue', 1.5/cmax: 'blue', 1.9/cmax: 'darkblue',
            2.0/cmax: 'darkgreen', 3.0/cmax: 'green',
            (self.CMAP_MIN_MAX - .1)/cmax: 'palegreen', 1.0: 'yellow',
        }
        cmap = LinearSegmentedColormap.from_list('RdGrYl', sorted(csteps.items()), N=256)
        # DataFrame.pivot arguments are keyword-only since pandas 2.0, so map the
        # (index, columns, values) triple onto keywords (also valid on older pandas).
        dataset = df.pivot(**dict(zip(('index', 'columns', 'values'), self.args)))
        sns.heatmap(dataset, mask=dataset.isnull(), annot=False, linewidths=.5, square=True,
                    ax=ax, cmap=cmap, annot_kws=dict(stretch='condensed'))
        ax.tick_params(axis='y', labelrotation=30, labelsize=8)
        plt.savefig(imagefile)
Example #11
Source File: plotting.py From MJHMC with GNU General Public License v2.0 | 5 votes |
def hist_2d(distr, nsamples, **kwargs):
    """
    Plot a 2d kernel-density joint plot of samples drawn from a distribution

    Args:
      distr: Distribution object to sample from
      nsamples: number of samples to use to generate plot
      **kwargs: forwarded to the MarkovJumpHMC sampler constructor

    Returns:
      the object returned by ``sns.jointplot``
    """
    draws = MarkovJumpHMC(distribution=distr, **kwargs).sample(nsamples)
    with sns.axes_style("white"):
        return sns.jointplot(draws[0], draws[1], kind='kde', stat_func=None)
Example #12
Source File: budgeted_stream_plot.py From DARENet with MIT License | 5 votes |
def _load_info(result_path, filename):
    """Unpickle one result dictionary and rescale its ``'CMCs'`` to percentages."""
    # NOTE: pickle can execute arbitrary code; only load trusted result files.
    with open(osp.join(result_path, filename), "rb") as f:
        info = pickle.load(f)
    info['CMCs'] = [cmc * 100 for cmc in info['CMCs']]
    return info


def main(args):
    """Plot CMC rank-1 accuracy against average budget and save it as a PDF.

    The three DaRe(R)+RE strategies (random, distance, margin) are drawn as
    lines, and fixed reference results are drawn as scatter markers.

    Args:
        args: namespace with ``result_path`` (directory containing the
            ``*_info.pkl`` result files) and ``figname`` (output path without
            the ``.pdf`` extension).
    """
    distance_confidence_info = _load_info(args.result_path, "distance_confidence_info.pkl")
    margin_confidence_info = _load_info(args.result_path, "margin_confidence_info.pkl")
    random_info = _load_info(args.result_path, "random_info.pkl")

    with sns.axes_style("white"):
        fig = plt.figure(figsize=(6, 4.5))
        ax = fig.add_subplot(111)
        ax.plot(random_info['resulted_budgets'], random_info['CMCs'], marker='.', linewidth=2.5,
                markersize=0, label="DaRe(R)+RE (random)", color=flatui[0])
        ax.plot(distance_confidence_info['resulted_budgets'], distance_confidence_info['CMCs'],
                marker='*', linewidth=2.5, markersize=0, label="DaRe(R)+RE (distance)", color=flatui[1])
        ax.plot(margin_confidence_info['resulted_budgets'], margin_confidence_info['CMCs'],
                marker='*', linewidth=2.5, markersize=0, label="DaRe(R)+RE (margin)", color=flatui[2])
        ax.scatter(SVDNet_R_RE[0], SVDNet_R_RE[1], marker='*', s=150, label="SVDNet(R)+RE", color=flatui[3])
        ax.scatter(IDE_R_KISSME[0], IDE_R_KISSME[1], marker='h', s=100, label="IDE(R)+KISSME", color=flatui[4])
        ax.scatter(IDE_C_KISSME[0], IDE_C_KISSME[1], marker='o', s=100, label="IDE(C)+KISSME", color=flatui[5])
        ax.scatter(TriNet_R[0], TriNet_R[1], marker='D', s=60, label="TriNet(R)", color=flatui[6])
        ax.scatter(SVDNet_C[0], SVDNet_C[1], marker='p', s=100, label="SVDNet(C)", color=flatui[7])
        plt.xlabel("Average Budget (in MUL-ADD)", size=15)
        # Raw string gives the same text as before, but "\%" is no longer an
        # invalid escape sequence (a SyntaxWarning on Python 3.12+).
        plt.ylabel(r"CMC Rank 1 Accuracy (\%)", size=15)
        handles, labels = ax.get_legend_handles_labels()
        label_order = ['TriNet(R)', 'SVDNet(C)', 'SVDNet(R)+RE', 'IDE(R)+KISSME', 'IDE(C)+KISSME',
                       'DaRe(R)+RE (random)', 'DaRe(R)+RE (distance)', 'DaRe(R)+RE (margin)']
        # Reorder the legend entries; every label above is unique, so a dict lookup suffices.
        handle_by_label = dict(zip(labels, handles))
        new_handles = [handle_by_label[l] for l in label_order if l in handle_by_label]
        ax.legend(new_handles, label_order, loc='lower right')
        plt.grid(linestyle='dotted')
        plt.tight_layout(pad=1, w_pad=1, h_pad=1)
        plt.xlim(3e8, 4.5e9)
        plt.ylim(55, 95)
        plt.savefig(args.figname + ".pdf", bbox_inches='tight')
        plt.close()
Example #13
Source File: opt_callbacks.py From lumin with Apache License 2.0 | 5 votes |
def plot_lr(self) -> None:
    r'''
    Plot the LR as a function of iterations.

    Reads the learning-rate history from ``self.history['lr']`` and shows it
    with ``plt.show()``.
    '''
    ps = self.plot_settings
    # Unpack ``style`` as keyword arguments, matching the other plotting code
    # (``sns.axes_style(**settings.style)``). Passed positionally, seaborn treats
    # the mapping as an rc dict and drops its ``style``/``rc`` keys.
    with sns.axes_style(**ps.style), sns.color_palette(ps.cat_palette):
        # NOTE(review): the figure is square (h_small x h_small), whereas other
        # small plots use (w_small, h_small) — confirm this is intended.
        plt.figure(figsize=(ps.h_small, ps.h_small))
        plt.plot(range(len(self.history['lr'])), self.history['lr'])
        plt.xticks(fontsize=ps.tk_sz, color=ps.tk_col)
        plt.yticks(fontsize=ps.tk_sz, color=ps.tk_col)
        plt.ylabel("Learning rate", fontsize=ps.lbl_sz, color=ps.lbl_col)
        plt.xlabel("Iterations", fontsize=ps.lbl_sz, color=ps.lbl_col)
        plt.show()
Example #14
Source File: cyclic_callbacks.py From lumin with Apache License 2.0 | 5 votes |
def plot(self) -> None:
    r'''
    Plots the history of the parameter evolution as a function of iterations

    The values come from ``self.hist`` and the y-axis is labelled with
    ``self.param_name``.
    '''
    ps = self.plot_settings
    # Unpack ``style`` as keyword arguments, matching the other plotting code
    # (``sns.axes_style(**settings.style)``). Passed positionally, seaborn treats
    # the mapping as an rc dict and drops its ``style``/``rc`` keys.
    with sns.axes_style(**ps.style), sns.color_palette(ps.cat_palette):
        plt.figure(figsize=(ps.w_mid, ps.h_mid))
        plt.xlabel("Iterations", fontsize=ps.lbl_sz, color=ps.lbl_col)
        plt.ylabel(self.param_name, fontsize=ps.lbl_sz, color=ps.lbl_col)
        plt.plot(range(len(self.hist)), self.hist)
        plt.xticks(fontsize=ps.tk_sz, color=ps.tk_col)
        plt.yticks(fontsize=ps.tk_sz, color=ps.tk_col)
        plt.show()
Example #15
Source File: metric_logger.py From lumin with Apache License 2.0 | 5 votes |
def reset(self) -> None:
    r'''
    Resets/initialises the logger's values and plots, and produces a placeholder plot.
    Should be called prior to `update_vals` or `update_plot`.

    With ``self.extra_detail`` the figure holds three axes: a large loss panel
    and smaller velocity and generalisation panels. Otherwise it holds a single
    loss panel. The figure is shown through ``display(..., display_id=True)``,
    and the handle is stored in ``self.display``.
    '''
    n_losses = len(self.loss_names)
    self.loss_vals = [[] for _ in range(n_losses)]
    self.vel_vals = [[] for _ in range(n_losses)]
    self.gen_vals = [[] for _ in range(n_losses - 1)]
    self.mean_losses = [None] * n_losses
    self.subepochs, self.epochs = [0], [0]
    self.count, self.log = 1, False

    s = self.settings
    tick_kws = dict(labelsize=0.8*s.tk_sz, labelcolor=s.tk_col)
    lbl_kws = dict(fontsize=0.8*s.lbl_sz, color=s.lbl_col)
    with sns.axes_style(**s.style):
        if not self.extra_detail:
            self.fig, self.loss_ax = plt.subplots(1, figsize=(s.w_mid, s.h_mid))
            self.loss_ax.tick_params(axis='x', **tick_kws)
            self.loss_ax.tick_params(axis='y', **tick_kws)
            self.loss_ax.set_xlabel('Sub-Epoch', **lbl_kws)
            self.loss_ax.set_ylabel('Loss', **lbl_kws)
            self.display = display(self.loss_ax.figure, display_id=True)
            return
        self.fig = plt.figure(figsize=(s.w_mid, s.h_mid), constrained_layout=True)
        layout = self.fig.add_gridspec(2, 3)
        self.loss_ax = self.fig.add_subplot(layout[:, :-1])
        self.vel_ax = self.fig.add_subplot(layout[:1, 2:])
        self.gen_ax = self.fig.add_subplot(layout[1:2, 2:])
        for panel in (self.loss_ax, self.vel_ax, self.gen_ax):
            panel.tick_params(axis='x', **tick_kws)
            panel.tick_params(axis='y', **tick_kws)
        self.loss_ax.set_xlabel('Sub-Epoch', **lbl_kws)
        self.loss_ax.set_ylabel('Loss', **lbl_kws)
        self.vel_ax.set_ylabel(r'$\Delta \bar{L}\ /$ Epoch', **lbl_kws)
        self.gen_ax.set_xlabel('Epoch', **lbl_kws)
        self.gen_ax.set_ylabel('Validation / Train', **lbl_kws)
        self.display = display(self.fig, display_id=True)
Example #16
Source File: training.py From lumin with Apache License 2.0 | 5 votes |
def plot_train_history(histories:List[Dict[str,List[float]]], savename:Optional[str]=None, ignore_trn=True,
                       settings:PlotSettings=PlotSettings(), show:bool=True) -> None:
    r'''
    Plot histories object returned by :meth:`~lumin.nn.training.fold_train.fold_train_ensemble` showing the loss evolution over time per model trained.

    Arguments:
        histories: list of dictionaries mapping loss type to values at each (sub)-epoch
        savename: Optional name of file to which to save the plot
        ignore_trn: whether to ignore training loss (any key containing ``'trn'``)
        settings: :class:`~lumin.plotting.plot_settings.PlotSettings` class to control figure appearance
        show: whether or not to show the plot, or just save it
    '''
    with sns.axes_style(**settings.style), sns.color_palette(settings.cat_palette) as palette:
        plt.figure(figsize=(settings.w_mid, settings.h_mid))
        for i, history in enumerate(histories):
            # Only the first history's curves get labels, so the legend has one entry per loss type.
            label_curves = i == 0
            for j, loss_key in enumerate(history):
                if ignore_trn and 'trn' in loss_key:
                    continue
                extra = {'label': _lookup_name(loss_key)} if label_curves else {}
                plt.plot(history[loss_key], color=palette[j], **extra)
        plt.legend(loc=settings.leg_loc, fontsize=settings.leg_sz)
        plt.xticks(fontsize=settings.tk_sz, color=settings.tk_col)
        plt.yticks(fontsize=settings.tk_sz, color=settings.tk_col)
        plt.xlabel("Epoch", fontsize=settings.lbl_sz, color=settings.lbl_col)
        plt.ylabel("Loss", fontsize=settings.lbl_sz, color=settings.lbl_col)
        if savename is not None:
            # NOTE(review): unlike the other lumin plotting functions in this file,
            # the path is not joined with settings.savepath — confirm intended.
            plt.savefig(f'{savename}{settings.format}', bbox_inches='tight')
        if show:
            plt.show()
Example #17
Source File: results.py From lumin with Apache License 2.0 | 5 votes |
def plot_binary_class_pred(df:pd.DataFrame, pred_name:str='pred', targ_name:str='gen_target', wgt_name:str=None,
                           wgt_scale:float=1, log_y:bool=False, lim_x:Tuple[float,float]=(0,1), density=True,
                           savename:Optional[str]=None, settings:PlotSettings=PlotSettings()) -> None:
    r'''
    Basic plotter for prediction distribution in a binary classification problem.
    Note that labels are set using the settings.targ2class dictionary, which by default is {0: 'Background', 1: 'Signal'}.

    Arguments:
        df: DataFrame with targets and predictions
        pred_name: name of column to use as predictions
        targ_name: name of column to use as targets
        wgt_name: optional name of column to use as sample weights
        wgt_scale: applies a global multiplicative rescaling to sample weights. Default 1 = no rescaling
        log_y: whether to use a log scale for the y-axis
        lim_x: limit for plotting on the x-axis
        density: whether to normalise each distribution to one, or keep set to sum of weights / datapoints
        savename: Optional name of file to which to save the plot
        settings: :class:`~lumin.plotting.plot_settings.PlotSettings` class to control figure appearance
    '''
    lbl_kws = dict(fontsize=settings.lbl_sz, color=settings.lbl_col)
    tick_kws = dict(fontsize=settings.tk_sz, color=settings.tk_col)
    if density:
        y_lbl = r"$\frac{1}{N}\ \frac{dN}{dp}$"
    elif wgt_scale != 1:
        y_lbl = str(wgt_scale) + r"$\times\frac{dN}{dp}$"
    else:
        y_lbl = r"$\frac{dN}{dp}$"
    with sns.axes_style(**settings.style), sns.color_palette(settings.cat_palette):
        plt.figure(figsize=(settings.w_mid, settings.h_mid))
        for targ in sorted(set(df[targ_name])):
            rows = df[targ_name] == targ
            hist_kws = {} if wgt_name is None else {'weights': wgt_scale*df.loc[rows, wgt_name]}
            sns.distplot(df.loc[rows, pred_name], label=settings.targ2class[targ], hist_kws=hist_kws,
                         norm_hist=density, kde=False)
        plt.legend(loc=settings.leg_loc, fontsize=settings.leg_sz)
        plt.xlabel("Class prediction", **lbl_kws)
        plt.xlim(lim_x)
        plt.ylabel(y_lbl, **lbl_kws)
        if log_y:
            # NOTE(review): Matplotlib 3.3 renamed `nonposy` to `nonpositive` and 3.5 removed it.
            plt.yscale('log', nonposy='clip')
            plt.grid(True, which="both")
        plt.xticks(**tick_kws)
        plt.yticks(**tick_kws)
        plt.title(settings.title, fontsize=settings.title_sz, color=settings.title_col,
                  loc=settings.title_loc)
        if savename is not None:
            plt.savefig(settings.savepath/f'{savename}{settings.format}', bbox_inches='tight')
        plt.show()
Example #18
Source File: data_viewing.py From lumin with Apache License 2.0 | 5 votes |
def _merge_feature_sets(z, feats, threshold):
    r'''
    Collect the feature clusters from linkage matrix ``z`` whose merge distance is at most ``1-threshold``.
    Keys follow scipy's cluster ids (``1 + i + len(z)`` for the merge in row ``i``). Each value holds the
    merged feature names (``'children'``) and the merge ``'distance'``.
    '''
    sets = {}
    for i, merge in enumerate(z):
        if merge[2] > 1-threshold:
            continue
        # Ids up to len(z) are original features (leaves); larger ids are earlier merges.
        if merge[0] <= len(z):
            a = [feats[int(merge[0])]]
        else:
            a = sets.pop(int(merge[0]))['children']
        if merge[1] <= len(z):
            b = [feats[int(merge[1])]]
        else:
            b = sets.pop(int(merge[1]))['children']
        sets[1 + i + len(z)] = {'children': [*a, *b], 'distance': merge[2]}
    return sets


def plot_rank_order_dendrogram(df:pd.DataFrame, threshold:float=0.8, savename:Optional[str]=None,
                               settings:PlotSettings=PlotSettings()) -> Dict[str,Union[List[str],float]]:
    r'''
    Plots a dendrogram of features in df clustered via Spearman's rank correlation coefficient.
    Also returns sets of features with correlation coefficients greater than the threshold

    Arguments:
        df: Pandas DataFrame containing data
        threshold: Threshold on correlation coefficient
        savename: Optional name of file to which to save the plot
        settings: :class:`~lumin.plotting.plot_settings.PlotSettings` class to control figure appearance

    Returns:
        Dict of sets of features with correlation coefficients greater than the threshold and cluster distance
    '''
    rho = scipy.stats.spearmanr(df).correlation
    if np.ndim(rho) == 0:
        # With exactly two columns spearmanr returns a scalar coefficient rather
        # than a matrix, which squareform cannot convert; rebuild the 2x2 matrix.
        rho = np.array([[1., rho], [rho, 1.]])
    corr = np.round(rho, 4)
    # Abs because the negative of a feature is a trivial transformation: information unaffected
    corr_condensed = hc.distance.squareform(1-np.abs(corr))
    z = hc.linkage(corr_condensed, method='average', optimal_ordering=True)
    with sns.axes_style('white'), sns.color_palette(settings.cat_palette):
        plt.figure(figsize=(settings.w_large, (0.5*len(df.columns))))
        hc.dendrogram(z, labels=df.columns, orientation='left', leaf_font_size=settings.lbl_sz,
                      color_threshold=1-threshold)
        plt.xlabel("Distance (1 - |Spearman's Rank Correlation Coefficient|)", fontsize=settings.lbl_sz,
                   color=settings.lbl_col)
        plt.xticks(fontsize=settings.tk_sz, color=settings.tk_col)
        if savename is not None:
            plt.savefig(settings.savepath/f'{savename}{settings.format}', bbox_inches='tight')
        plt.show()
    return _merge_feature_sets(z, df.columns, threshold)
Example #19
Source File: interpretation.py From lumin with Apache License 2.0 | 5 votes |
def plot_importance(df:pd.DataFrame, feat_name:str='Feature', imp_name:str='Importance', unc_name:str='Uncertainty',
                    threshold:Optional[float]=None, x_lbl:str='Importance via feature permutation',
                    savename:Optional[str]=None, settings:PlotSettings=PlotSettings()) -> None:
    r'''
    Plot feature importances as computed via `get_nn_feat_importance`, `get_ensemble_feat_importance`, or `rf_rank_features`

    Arguments:
        df: DataFrame containing columns of features, importances and, optionally, uncertainties
        feat_name: column name for features
        imp_name: column name for importances
        unc_name: column name for uncertainties; used as x error bars when present in df
        threshold: if set, will draw a line at the threshold used for feature importance
        x_lbl: label to put on the x-axis
        savename: Optional name of file to which to save the plot of feature importances
        settings: :class:`~lumin.plotting.plot_settings.PlotSettings` class to control figure appearance
    '''
    with sns.axes_style(**settings.style), sns.color_palette(settings.cat_palette) as palette:
        _, ax = plt.subplots(figsize=(settings.w_large, (0.75)*settings.lbl_sz))
        # Use the caller's uncertainty column. Previously 'Uncertainty' was
        # hard-coded, which broke any other unc_name.
        xerr = unc_name if unc_name in df else None
        df.plot(x=feat_name, y=imp_name, kind='barh', ax=ax, legend=False, xerr=xerr,
                error_kw={'elinewidth': 3}, color=palette[0])
        if threshold is not None:
            ax.axvline(x=threshold, label=f'Threshold {threshold}', color=palette[1], linestyle='--', linewidth=3)
            plt.legend(loc=settings.leg_loc, fontsize=settings.leg_sz)
        ax.set_xlabel(x_lbl, fontsize=settings.lbl_sz, color=settings.lbl_col)
        ax.set_ylabel('Feature', fontsize=settings.lbl_sz, color=settings.lbl_col)
        plt.xticks(fontsize=settings.tk_sz, color=settings.tk_col)
        plt.yticks(fontsize=settings.tk_sz, color=settings.tk_col)
        if savename is not None:
            plt.savefig(settings.savepath/f'{savename}{settings.format}')
        plt.show()
Example #20
Source File: plot_functions.py From idea_relations with MIT License | 5 votes |
def start_plotting(fig_size, fig_pos, style="white", rc=None, despine=False):
    """Create a figure with a single axes under a temporary seaborn style.

    Args:
        fig_size: figure size passed to ``plt.figure``.
        fig_pos: if falsy, a standard ``111`` subplot is added; otherwise the
            rectangle passed to ``fig.add_axes``.
        style: seaborn style forwarded to ``sns.axes_style``.
        rc: rc overrides forwarded to ``sns.axes_style``.
        despine: if true, call ``sns.despine(left=True)`` after creating the axes.

    Returns:
        (fig, ax)
    """
    with sns.axes_style(style, rc):
        fig = plt.figure(figsize=fig_size)
        ax = fig.add_axes(fig_pos) if fig_pos else fig.add_subplot(111)
        if despine:
            sns.despine(left=True)
    return fig, ax
Example #21
Source File: threshold.py From lumin with Apache License 2.0 | 4 votes |
def binary_class_cut_by_ams(df:pd.DataFrame, top_perc:float=5.0, min_pred:float=0.9, wgt_factor:float=1.0,
                            br:float=0.0, syst_unc_b:float=0.0, pred_name:str='pred', targ_name:str='gen_target',
                            wgt_name:str='gen_weight', plot_settings:PlotSettings=PlotSettings()) -> Tuple[float,float,float]:
    r'''
    Optimise a cut on a signal-background classifier prediction by the Approximate Median Significance.
    The cut should generalise better because it takes the mean class prediction of the top top_perc
    percentage of points as ranked by AMS.

    Note: an ``'ams'`` column is added to ``df`` in place. If that column already exists, it is reused
    as-is rather than recomputed.

    Arguments:
        df: Pandas DataFrame containing data
        top_perc: top percentage of events to consider as ranked by AMS
        min_pred: minimum prediction to consider
        wgt_factor: single multiplicative coefficient for rescaling signal and background weights before computing AMS
        br: background offset bias
        syst_unc_b: fractional systematic uncertainty on background
        pred_name: column to use as predictions
        targ_name: column to use as truth labels for signal (1) and background (0)
        wgt_name: column to use as weights for signal and background events
        plot_settings: :class:`~lumin.plotting.plot_settings.PlotSettings` class to control figure appearance

    Returns:
        Optimised cut
        AMS at cut
        Maximum AMS
    '''
    # TODO: Multithread AMS calculation
    # Use the caller-specified target column (previously hard-coded to df.gen_target).
    sig, bkg = (df[targ_name] == 1), (df[targ_name] == 0)
    if 'ams' not in df.columns:
        df['ams'] = -1
        eligible = df[pred_name] >= min_pred
        df.loc[eligible, 'ams'] = df[eligible].apply(
            lambda row: calc_ams(wgt_factor*np.sum(df.loc[(df[pred_name] >= row[pred_name]) & sig, wgt_name]),
                                 wgt_factor*np.sum(df.loc[(df[pred_name] >= row[pred_name]) & bkg, wgt_name]),
                                 br=br, unc_b=syst_unc_b),
            axis=1)

    sort = df.sort_values(by='ams', ascending=False)
    cuts = sort[pred_name].values[0:int(top_perc*len(sort)/100)]
    cut = np.mean(cuts)
    # Use wgt_name here too (previously hard-coded to 'gen_weight').
    ams = calc_ams(wgt_factor*np.sum(sort.loc[(sort[pred_name] >= cut) & sig, wgt_name]),
                   wgt_factor*np.sum(sort.loc[(sort[pred_name] >= cut) & bkg, wgt_name]),
                   br=br, unc_b=syst_unc_b)

    print(f'Mean cut at {cut} corresponds to AMS of {ams}')
    print(f'Maximum AMS for data is {sort.iloc[0]["ams"]} at cut of {sort.iloc[0][pred_name]}')
    ps = plot_settings
    # Unpack ``style`` as keyword arguments, like the other plotting code does.
    with sns.axes_style(**ps.style), sns.color_palette(ps.cat_palette) as palette:
        plt.figure(figsize=(ps.w_small, ps.h_small))
        sns.distplot(cuts, label=f'Top {top_perc}%')
        plt.axvline(x=cut, label='Mean prediction', color=palette[1])
        plt.axvline(x=sort.iloc[0][pred_name], label='Max. AMS', color=palette[2])
        plt.legend(loc=ps.leg_loc, fontsize=ps.leg_sz)
        plt.xticks(fontsize=ps.tk_sz, color=ps.tk_col)
        plt.yticks(fontsize=ps.tk_sz, color=ps.tk_col)
        plt.xlabel('Class prediction', fontsize=ps.lbl_sz, color=ps.lbl_col)
        plt.ylabel(r"$\frac{1}{N}\ \frac{dN}{dp}$", fontsize=ps.lbl_sz, color=ps.lbl_col)
        plt.show()
    return cut, ams, sort.iloc[0]["ams"]
Example #22
Source File: metric_logger.py From lumin with Apache License 2.0 | 4 votes |
def update_plot(self, best:Optional[float]=None) -> None:
    r'''
    Updates the plot(s), optionally showing the user-chosen best loss achieved.

    Redraws the loss panel from ``self.loss_vals`` against ``self.subepochs[1:]``. When
    ``self.extra_detail`` is set, it also redraws the velocity (``self.vel_vals``) and
    generalisation (``self.gen_vals``) panels against ``self.epochs[1:]``. Once more than
    5 epochs are stored, it drops the oldest entry of ``self.epochs`` and of every velocity
    and generalisation series. The shown figure is refreshed via ``self.display.update``.

    Arguments:
        best: the value of the best loss achieved so far; drawn as a dashed line when given
    '''
    # Loss
    self.loss_ax.clear()
    with sns.axes_style(**self.settings.style), sns.color_palette(self.settings.cat_palette):
        for v,m in zip(self.loss_vals,self.loss_names):
            self.loss_ax.plot(self.subepochs[1:], v, label=m)
        if best is not None:
            # Constant line at the best loss, spanning the same sub-epochs
            self.loss_ax.plot(self.subepochs[1:], np.ones_like(self.subepochs[1:])*best, label=f'Best = {best:.3E}', linestyle='--')
        if self.log:
            # NOTE(review): Matplotlib 3.3 renamed `nonposy` to `nonpositive` and 3.5 removed it.
            self.loss_ax.set_yscale('log', nonposy='clip')
            self.loss_ax.tick_params(axis='y', labelsize=0.8*self.settings.tk_sz, labelcolor=self.settings.tk_col, which='both')
        self.loss_ax.grid(True, which="both")
        self.loss_ax.legend(loc='upper right', fontsize=0.8*self.settings.leg_sz)
        self.loss_ax.set_xlabel('Sub-Epoch', fontsize=0.8*self.settings.lbl_sz, color=self.settings.lbl_col)
        self.loss_ax.set_ylabel('Loss', fontsize=0.8*self.settings.lbl_sz, color=self.settings.lbl_col)

    if self.extra_detail:
        # Velocity: legend entries show each series' latest value
        self.vel_ax.clear()
        self.vel_ax.tick_params(axis='y', labelsize=0.8*self.settings.tk_sz, labelcolor=self.settings.tk_col, which='both')
        self.vel_ax.grid(True, which="both")
        with sns.color_palette(self.settings.cat_palette):
            for v,m in zip(self.vel_vals,self.loss_names):
                self.vel_ax.plot(self.epochs[1:], v, label=f'{m} {v[-1]:.2E}')
        self.vel_ax.legend(loc='lower right', fontsize=0.8*self.settings.leg_sz)
        self.vel_ax.set_ylabel(r'$\Delta \bar{L}\ /$ Epoch', fontsize=0.8*self.settings.lbl_sz, color=self.settings.lbl_col)

        # Generalisation: series i is paired with loss_names[i+1], and colour palette[i+1]
        # is used so it matches that loss's colour
        self.gen_ax.clear()
        self.gen_ax.grid(True, which="both")
        with sns.color_palette(self.settings.cat_palette) as palette:
            for i, (v,m) in enumerate(zip(self.gen_vals,self.loss_names[1:])):
                self.gen_ax.plot(self.epochs[1:], v, label=f'{m} {v[-1]:.2f}', color=palette[i+1])
        self.gen_ax.legend(loc='upper left', fontsize=0.8*self.settings.leg_sz)
        self.gen_ax.set_xlabel('Epoch', fontsize=0.8*self.settings.lbl_sz, color=self.settings.lbl_col)
        self.gen_ax.set_ylabel('Validation / Train', fontsize=0.8*self.settings.lbl_sz, color=self.settings.lbl_col)

        # Keep a sliding window: once more than 5 epochs are stored, drop the oldest entry everywhere
        if len(self.epochs) > 5:
            self.epochs = self.epochs[1:]
            for i in range(len(self.vel_vals)): self.vel_vals[i] = self.vel_vals[i][1:]
            for i in range(len(self.gen_vals)): self.gen_vals[i] = self.gen_vals[i][1:]

        self.display.update(self.fig)
    else:
        self.display.update(self.loss_ax.figure)
Example #23
Source File: data_viewing.py From lumin with Apache License 2.0 | 4 votes |
def plot_kdes_from_bs(x:np.ndarray, bs_stats:Dict[str,Any], name2args:Dict[str,Dict[str,Any]],
                      feat:str, units:Optional[str]=None, moments=True,
                      savename:Optional[str]=None, settings:PlotSettings=PlotSettings()) -> None:
    r'''
    Plots KDEs computed via :meth:`~lumin.utils.statistics.bootstrap_stats`

    Arguments:
        x: points at which the KDEs were evaluated (passed to seaborn tsplot as ``time``)
        bs_stats: (filtered) dictionary returned by :meth:`~lumin.utils.statistics.bootstrap_stats`
        name2args: Dictionary mapping names of different distributions to arguments to pass to seaborn tsplot.
            The inner dictionaries are copied, not modified, so the same mapping can be reused across calls.
        feat: Name of feature being plotted (for axis labels)
        units: Optional units to show on axes
        moments: whether to display mean and standard deviation of each distribution
        savename: Optional name of file to which to save the plot
        settings: :class:`~lumin.plotting.plot_settings.PlotSettings` class to control figure appearance
    '''
    # TODO: update to sns 9
    with sns.axes_style(**settings.style), sns.color_palette(settings.cat_palette) as palette:
        plt.figure(figsize=(settings.w_mid, settings.h_mid))
        for i, name in enumerate(name2args):
            # Work on a copy. Editing the caller's dict in place (colour, label->condition,
            # appended moments) leaked state, so a second call appended the moments again.
            args = dict(name2args[name])
            if 'color' not in args:
                args['color'] = palette[i]
            if 'label' in args:
                args['condition'] = args.pop('label')
            if 'condition' in args and moments:
                mean, mean_unc = uncert_round(np.mean(bs_stats[f'{name}_mean']), np.std(bs_stats[f'{name}_mean'], ddof=1))
                std, std_unc = uncert_round(np.mean(bs_stats[f'{name}_std']), np.std(bs_stats[f'{name}_std'], ddof=1))
                args['condition'] += r', $\overline{x}=' + r'{}\pm{}\ \sigma= {}\pm{}$'.format(mean, mean_unc, std, std_unc)
            sns.tsplot(data=bs_stats[f'{name}_kde'], time=x, **args)

        plt.legend(loc=settings.leg_loc, fontsize=settings.leg_sz)
        y_lbl = r'$\frac{1}{N}\ \frac{dN}{d' + feat.replace('$','') + r'}$'
        if units is not None:
            x_lbl = feat + r'$\ [' + units + r']$'
            y_lbl += r'$\ [' + units + r'^{-1}]$'
        else:
            x_lbl = feat
        plt.xlabel(x_lbl, fontsize=settings.lbl_sz, color=settings.lbl_col)
        plt.ylabel(y_lbl, fontsize=settings.lbl_sz, color=settings.lbl_col)
        plt.xticks(fontsize=settings.tk_sz, color=settings.tk_col)
        plt.yticks(fontsize=settings.tk_sz, color=settings.tk_col)
        plt.title(settings.title, fontsize=settings.title_sz, color=settings.title_col, loc=settings.title_loc)
        if savename is not None:
            plt.savefig(settings.savepath/f'{savename}{settings.format}', bbox_inches='tight')
        plt.show()
Example #24
Source File: data_viewing.py From lumin with Apache License 2.0 | 4 votes |
def compare_events(events:list) -> None:
    r'''
    Plots one or more events side by side in their transverse and longitudinal projections

    Arguments:
        events: list of DataFrames containing vector coordinates for 3 momenta; only the first row of each is used.
            Vectors are identified from the columns of the first event ending in `_px`; a missing `_py` or `_pz`
            column is treated as 0.
    '''

    # TODO: check typing, why list?
    # TODO: add plot settings & saving

    def circle(): return plt.Circle((0, 0), 1, color='grey', fill=False, linewidth=2)
    def rectangle(): return plt.Rectangle((-2, -1), 4, 2, color='grey', fill=False, linewidth=2)

    # Per row of panels: (outline factory, x-range, x-label, y-label). A new outline instance is made for every axes.
    row_styles = [(circle,    (-1.1, 1.1), r"$p_x$", r"$p_y$"),
                  (rectangle, (-2.2, 2.2), r"$p_z$", r"$p_x$"),
                  (rectangle, (-2.2, 2.2), r"$p_z$", r"$p_y$")]

    with sns.axes_style('whitegrid'), sns.color_palette('tab10'):
        # squeeze=False keeps axs 2D (3 x len(events)), so a single event can be plotted as well
        fig, axs = plt.subplots(3, len(events), figsize=(9*len(events), 18),
                                gridspec_kw={'height_ratios': [1, 0.5, 0.5]}, squeeze=False)
        # Select vectors by a case-sensitive '_px' suffix so the selection agrees with the column lookups below
        vectors = [c[:-len('_px')] for c in events[0].columns if c.endswith('_px')]
        for vector in vectors:
            for i, in_data in enumerate(events):
                px = in_data[vector + '_px'].values[0]
                try:
                    py = in_data[vector + '_py'].values[0]
                except KeyError:
                    py = 0
                try:
                    pz = in_data[vector + '_pz'].values[0]
                except KeyError:
                    pz = 0
                axs[0, i].plot((0, px), (0, py), label=vector)
                axs[1, i].plot((0, pz), (0, px), label=vector)
                axs[2, i].plot((0, pz), (0, py), label=vector)

        for row, (make_outline, x_lim, x_lbl, y_lbl) in zip(axs, row_styles):
            for ax in row:
                ax.add_artist(make_outline())
                ax.set_xlim(*x_lim)
                ax.set_ylim(-1.1, 1.1)
                ax.set_xlabel(x_lbl, fontsize=16, color='black')
                ax.set_ylabel(y_lbl, fontsize=16, color='black')
                ax.legend(loc='right', fontsize=12)
        fig.show()
Example #25
Source File: interpretation.py From lumin with Apache License 2.0 | 4 votes |
def plot_multibody_weighted_outputs(model:AbsModel, inputs:Union[np.ndarray,Tensor], block_names:Optional[List[str]]=None,
                                    use_mean:bool=False, savename:Optional[str]=None,
                                    settings:PlotSettings=PlotSettings()) -> None:
    r'''
    Box-plot how strongly the tail of a :class:MultiBlock model draws on the output of each body block.

    The inputs are run through the model while a forward hook records what enters the first layer of the tail.
    For every body block, the slice of that recorded input coming from the block is multiplied with the matching
    slice of the tail's first-layer weights, and the absolute value of the product is collected per datum.
    Only tails whose first dense layer has exactly one neuron are supported.

    Arguments:
        model: model to interpret; the first layer of its tail must have a single output neuron
        inputs: input data passed to the model's predict method
        block_names: x-axis labels for the body blocks; defaults to the block indices ('0', '1', ...)
        use_mean: if True, each block's weighted output is divided by that block's number of output neurons
        savename: Optional name of file to which to save the plot
        settings: :class:`~lumin.plotting.plot_settings.PlotSettings` class to control figure appearance
    '''

    tail_head = model.tail[0]
    assert tail_head.weight.shape[0] == 1, 'This function currently only supports models whose tail block contains a single neuron in the first dense layer'
    blocks = model.body.blocks
    if block_names is None:
        block_names = [str(idx) for idx in range(len(blocks))]
    else:
        assert len(block_names) == len(blocks), 'block_names passed, but number of names does not match number of blocks'

    hook = FowardHook(tail_head)
    model.predict(inputs)
    tail_input, tail_weights = hook.input[0], tail_head.weight[0]

    weighted, start = [], 0
    for blk in blocks:
        width = blk.get_out_size()
        stop = start + width
        contribution = torch.abs(tail_input[:, start:stop] @ tail_weights[start:stop])
        if use_mean:
            contribution = contribution / width
        weighted.append(to_np(contribution))
        start = stop

    y_label = r"Mean $|\bar{w}\cdot\bar{x}|$" if use_mean else r"$|\bar{w}\cdot\bar{x}|$"
    with sns.axes_style(**settings.style), sns.color_palette(settings.cat_palette):
        plt.figure(figsize=(settings.w_mid, settings.h_mid))
        sns.boxplot(x=block_names, y=weighted)
        plt.xlabel("Block", fontsize=settings.lbl_sz, color=settings.lbl_col)
        plt.ylabel(y_label, fontsize=settings.lbl_sz, color=settings.lbl_col)
        plt.xticks(fontsize=settings.tk_sz, color=settings.tk_col)
        plt.yticks(fontsize=settings.tk_sz, color=settings.tk_col)
        plt.title(settings.title, fontsize=settings.title_sz, color=settings.title_col, loc=settings.title_loc)
        if savename is not None:
            plt.savefig(settings.savepath/f'{savename}{settings.format}', bbox_inches='tight')
        plt.show()
Example #26
Source File: interpretation.py From lumin with Apache License 2.0 | 4 votes |
def plot_2d_partial_dependence(model:Any, df:pd.DataFrame, feats:Tuple[str,str], train_feats:List[str],
                               ignore_feats:Optional[List[str]]=None, input_pipe:Pipeline=None,
                               sample_sz:Optional[int]=None, wgt_name:Optional[str]=None,
                               n_points:Tuple[int,int]=(20,20),
                               pdp_interact_kargs:Optional[Dict[str,Any]]=None,
                               pdp_interact_plot_kargs:Optional[Dict[str,Any]]=None,
                               savename:Optional[str]=None, settings:PlotSettings=PlotSettings()) -> None:
    r'''
    Wrapper for PDPbox to plot 2D dependence of specified pair of features using provided NN or RF.
    If features have been preprocessed using an SK-Learn Pipeline, then that can be passed in order to rescale them back to their original values.

    Arguments:
        model: any trained model with a .predict method
        df: DataFrame containing training data
        feats: pair of features for which to evaluate the partial dependence of the model
        train_feats: list of all training features including ones which were later ignored, i.e. input features considered when input_pipe was fitted
        ignore_feats: features present in training data which were not used to train the model (necessary to correctly deprocess feature using input_pipe)
        input_pipe: SK-Learn Pipeline which was used to process the training data
        sample_sz: if set, will only compute partial dependence on a random sample with replacement of the training data, sampled according to weights (if set).
            Speeds up computation and allows weighted partial dependencies to be computed.
        wgt_name: Optional column name to use as sampling weights
        n_points: pair of numbers of points at which to evaluate the model output, passed to pdp_interact as num_grid_points
        pdp_interact_kargs: optional dictionary of keyword arguments to pass to pdp_interact
        pdp_interact_plot_kargs: optional dictionary of keyword arguments to pass to pdp_interact_plot.
            `plot_params` entries are merged over the defaults set here, and `figsize` overrides the default size.
        savename: Optional name of file to which to save the plot of the partial dependence
        settings: :class:`~lumin.plotting.plot_settings.PlotSettings` class to control figure appearance
    '''

    check_pdpbox()
    if pdp_interact_kargs is None: pdp_interact_kargs = {}
    # Shallow copy: entries are popped below and the caller's dict must stay untouched
    plot_kargs = {} if pdp_interact_plot_kargs is None else dict(pdp_interact_plot_kargs)
    # Merge user plot_params/figsize instead of passing them twice (which raised a TypeError)
    plot_params = {'title': None, 'subtitle': None, 'cmap': settings.seq_palette,
                   **(plot_kargs.pop('plot_params', None) or {})}
    figsize = plot_kargs.pop('figsize', (settings.h_large, settings.h_large))

    if sample_sz is not None or wgt_name is not None:
        if wgt_name is None:
            weights = None
        else:
            weights = df[wgt_name].values.astype('float64')
            weights *= 1/np.sum(weights)
        df = df.sample(len(df) if sample_sz is None else sample_sz, weights=weights, replace=True)

    ignored = set() if ignore_feats is None else set(ignore_feats)
    model_feats = [f for f in train_feats if f not in ignored]
    interact = pdp.pdp_interact(model, df, model_feats, feats, num_grid_points=list(n_points), **pdp_interact_kargs)
    if input_pipe is not None: _deprocess_interact(interact, input_pipe, feats, train_feats)

    with sns.axes_style(**settings.style), sns.color_palette(settings.cat_palette):
        fig, ax = pdp.pdp_interact_plot(interact, feats, figsize=figsize, plot_params=plot_params, **plot_kargs)
        ax['title_ax'].remove()
        ax['pdp_inter_ax'].set_xlabel(feats[0], fontsize=settings.lbl_sz, color=settings.lbl_col)
        ax['pdp_inter_ax'].set_ylabel(feats[1], fontsize=settings.lbl_sz, color=settings.lbl_col)
        plt.xticks(fontsize=settings.tk_sz, color=settings.tk_col)
        plt.yticks(fontsize=settings.tk_sz, color=settings.tk_col)
        plt.title(settings.title, fontsize=settings.title_sz, color=settings.title_col, loc=settings.title_loc)
        if savename is not None: plt.savefig(settings.savepath/f'{savename}{settings.format}')
        plt.show()
Example #27
Source File: interpretation.py From lumin with Apache License 2.0 | 4 votes |
def plot_1d_partial_dependence(model:Any, df:pd.DataFrame, feat:str, train_feats:List[str],
                               ignore_feats:Optional[List[str]]=None, input_pipe:Pipeline=None,
                               sample_sz:Optional[int]=None, wgt_name:Optional[str]=None,
                               n_clusters:Optional[int]=10, n_points:int=20,
                               pdp_isolate_kargs:Optional[Dict[str,Any]]=None, pdp_plot_kargs:Optional[Dict[str,Any]]=None,
                               y_lim:Optional[Union[Tuple[float,float],List[float]]]=None,
                               savename:Optional[str]=None, settings:PlotSettings=PlotSettings()) -> None:
    r'''
    Wrapper for PDPbox to plot 1D dependence of specified feature using provided NN or RF.
    If features have been preprocessed using an SK-Learn Pipeline, then that can be passed in order to rescale the x-axis back to its original values.

    Arguments:
        model: any trained model with a .predict method
        df: DataFrame containing training data
        feat: feature for which to evaluate the partial dependence of the model
        train_feats: list of all training features including ones which were later ignored, i.e. input features considered when input_pipe was fitted
        ignore_feats: features present in training data which were not used to train the model (necessary to correctly deprocess feature using input_pipe)
        input_pipe: SK-Learn Pipeline which was used to process the training data
        sample_sz: if set, will only compute partial dependence on a random sample with replacement of the training data, sampled according to weights (if set).
            Speeds up computation and allows weighted partial dependencies to be computed.
        wgt_name: Optional column name to use as sampling weights
        n_clusters: number of clusters in which to group dependency lines. Set to None to show all lines
        n_points: number of points at which to evaluate the model output, passed to pdp_isolate as num_grid_points
        pdp_isolate_kargs: optional dictionary of keyword arguments to pass to pdp_isolate
        pdp_plot_kargs: optional dictionary of keyword arguments to pass to pdp_plot.
            `plot_params` entries are merged over the defaults set here, and `figsize` overrides the default size.
        y_lim: If set, will limit y-axis plot range to tuple
        savename: Optional name of file to which to save the plot of the partial dependence
        settings: :class:`~lumin.plotting.plot_settings.PlotSettings` class to control figure appearance
    '''

    check_pdpbox()  # same guard as the 2D wrapper
    if pdp_isolate_kargs is None: pdp_isolate_kargs = {}
    # Shallow copy: entries are popped below and the caller's dict must stay untouched
    plot_kargs = {} if pdp_plot_kargs is None else dict(pdp_plot_kargs)
    # Merge user plot_params/figsize instead of passing them twice (which raised a TypeError)
    plot_params = {'title': None, 'subtitle': None, **(plot_kargs.pop('plot_params', None) or {})}
    figsize = plot_kargs.pop('figsize', (settings.w_mid, settings.h_mid))

    if sample_sz is not None or wgt_name is not None:
        if wgt_name is None:
            weights = None
        else:
            weights = df[wgt_name].values.astype('float64')
            weights *= 1/np.sum(weights)
        df = df.sample(len(df) if sample_sz is None else sample_sz, weights=weights, replace=True)

    ignored = set() if ignore_feats is None else set(ignore_feats)
    model_feats = [f for f in train_feats if f not in ignored]
    iso = pdp.pdp_isolate(model, df, model_feats, feat, num_grid_points=n_points, **pdp_isolate_kargs)
    if input_pipe is not None: _deprocess_iso(iso, input_pipe, feat, train_feats)

    with sns.axes_style(**settings.style), sns.color_palette(settings.cat_palette):
        fig, ax = pdp.pdp_plot(iso, feat, center=False, plot_lines=True, cluster=n_clusters is not None,
                               n_cluster_centers=n_clusters, plot_params=plot_params, figsize=figsize, **plot_kargs)
        ax['title_ax'].remove()
        ax['pdp_ax'].set_xlabel(feat, fontsize=settings.lbl_sz, color=settings.lbl_col)
        ax['pdp_ax'].set_ylabel("Partial dependence", fontsize=settings.lbl_sz, color=settings.lbl_col)
        if y_lim is not None: ax['pdp_ax'].set_ylim(y_lim)
        plt.xticks(fontsize=settings.tk_sz, color=settings.tk_col)
        plt.yticks(fontsize=settings.tk_sz, color=settings.tk_col)
        plt.title(settings.title, fontsize=settings.title_sz, color=settings.title_col, loc=settings.title_loc)
        if savename is not None: plt.savefig(settings.savepath/f'{savename}{settings.format}')
        plt.show()
Example #28
Source File: general.py From pyfinance with MIT License | 4 votes |
def corr_heatmap(
    x,
    mask_half=True,
    cmap="RdYlGn_r",
    vmin=-1,
    vmax=1,
    linewidths=0.5,
    square=True,
    figsize=(10, 10),
    **kwargs
):
    """Wrapper around seaborn.heatmap for visualizing correlation matrix.

    Parameters
    ----------
    x : DataFrame
        Underlying data (not a correlation matrix)
    mask_half : bool, default True
        If True, mask (whiteout) the upper right triangle of the matrix,
        including the diagonal
    figsize : tuple, default (10, 10)
        Size of the new figure the heatmap is drawn on.  Ignored when an
        existing Axes is passed through ``ax``.

    All other parameters passed to seaborn.heatmap:
    https://seaborn.pydata.org/generated/seaborn.heatmap.html

    Returns
    -------
    The matplotlib Axes returned by seaborn.heatmap.

    Example
    -------
    # Generate some correlated data
    >>> import numpy as np
    >>> import pandas as pd
    >>> k = 10
    >>> size = 400
    >>> mu = np.random.randint(0, 10, k).astype(float)
    >>> a = np.random.random_sample((k, k))
    >>> cov = a @ a.T * 5  # symmetric positive semi-definite
    >>> df = pd.DataFrame(np.random.multivariate_normal(mu, cov, size=size))
    >>> corr_heatmap(df, figsize=(6, 6))
    """
    # Compute the correlation matrix once and reuse it for both the mask and the plot
    corr = x.corr()
    if mask_half:
        mask = np.zeros_like(corr.values)
        # triu_indices_from with the default k=0 also covers the diagonal
        mask[np.triu_indices_from(mask)] = True
    else:
        mask = None
    with sns.axes_style("white"):
        if kwargs.get("ax") is None:
            # figsize was previously accepted but never used; create the
            # axes here so both the size and the "white" style apply to it
            import matplotlib.pyplot as plt

            _, kwargs["ax"] = plt.subplots(figsize=figsize)
        return sns.heatmap(
            corr,
            cmap=cmap,
            vmin=vmin,
            vmax=vmax,
            linewidths=linewidths,
            square=square,
            mask=mask,
            **kwargs
        )
Example #29
Source File: representation_plot.py From srl-zoo with MIT License | 4 votes |
def prettyPlotAgainst(states, rewards, title="Representation", fit_pca=False, cmap='coolwarm'):
    """
    State dimensions are plotted one against the other (it creates a matrix of 2d representation)
    using rewards for coloring, the diagonal is a distribution plot, and the scatter plots have a density outline.
    :param states: (np.ndarray) 2D array, column i holds state dimension i
    :param rewards: (np.ndarray) one reward per row of states, used to color the points
    :param title: (str)
    :param fit_pca: (bool) if True, the states are projected with a PCA (same number of components) before plotting
    :param cmap: (str) name of the colormap used to color points by reward
    """
    with sns.axes_style('white'):
        n = states.shape[1]
        # squeeze=False keeps ax_mat 2D even for a single state dimension (n == 1)
        fig, ax_mat = plt.subplots(n, n, figsize=(10, 10), sharex=False, sharey=False, squeeze=False)
        fig.subplots_adjust(hspace=0.2, wspace=0.2)

        if fit_pca:
            title += " (PCA)"
            states = PCA(n_components=n).fit_transform(states)

        c_idx = cm.get_cmap(cmap)
        norm = colors.Normalize(vmin=np.min(rewards), vmax=np.max(rewards))
        # The distinct reward values are the same for every diagonal panel: compute them once
        unique_rewards = np.unique(rewards)

        for i in range(n):
            for j in range(n):
                x, y = states[:, i], states[:, j]
                ax = ax_mat[i, j]
                if i != j:
                    ax.scatter(x, y, c=rewards, cmap=cmap, s=5)
                    sns.kdeplot(x, y, cmap="Greys", ax=ax, shade=True, shade_lowest=False, alpha=0.2)
                    ax.set_xlim([np.min(x), np.max(x)])
                    ax.set_ylim([np.min(y), np.max(y)])
                elif len(unique_rewards) < 10:
                    # Few distinct rewards: one distribution per reward value, colored like the scatter points
                    for r in unique_rewards:
                        sns.distplot(x[rewards == r], color=c_idx(norm(r)), ax=ax)
                else:
                    sns.distplot(x, ax=ax)

                if i == 0:
                    ax.set_title("Dim {}".format(j), y=1.2)
                if i != j:
                    # Hide ticks on inner off-diagonal panels
                    if i != 0 and i != n - 1:
                        ax.xaxis.set_visible(False)
                    if j != 0 and j != n - 1:
                        ax.yaxis.set_visible(False)

                # Set up ticks only on one side for the "edge" subplots...
                if j == 0:
                    ax.yaxis.set_ticks_position('left')
                if j == n - 1:
                    ax.yaxis.set_ticks_position('right')
                if i == 0:
                    ax.xaxis.set_ticks_position('top')
                if i == n - 1:
                    ax.xaxis.set_ticks_position('bottom')

        plt.suptitle(title, fontsize=16)
        plt.show()
Example #30
Source File: plots.py From Comparative-Annotation-Toolkit with Apache License 2.0 | 4 votes |
def improvement_plot(consensus_data, ordered_genomes, improvement_tgt):
    """Write a multi-page PDF with one page of AUGUSTUS-vs-transMap metric comparisons per genome.

    Each page has four panels; each panel plots a transMap metric (x) against the matching final metric (y),
    with a KDE overlay and the individual transcripts as points. Genomes with no recorded changes get no page.
    The PDF is written to a temporary path and moved to ``improvement_tgt.path`` only once every page is done.

    :param consensus_data: per-genome dict; ``['Evaluation Improvement']['changes']`` holds the 8 metric columns
        and ``['Evaluation Improvement']['unchanged']`` a count shown in the page title
    :param ordered_genomes: genomes to plot, in page order
    :param improvement_tgt: luigi target whose ``path`` receives the final PDF
    """
    def do_kdeplot(x, y, ax, n_levels=None, bw='scott'):
        """Best-effort KDE overlay: a failed fit is logged and the panel is left without a density."""
        try:
            sns.kdeplot(x, y, ax=ax, cut=0, cmap='Purples_d', shade=True, shade_lowest=False,
                        n_levels=n_levels, bw=bw, rasterized=True)
        except Exception:  # narrowed from a bare except so KeyboardInterrupt/SystemExit still propagate
            logger.warning('Unable to do a KDE fit to AUGUSTUS improvement.')

    # One entry per panel, in ax1..ax4 order: (transMap column, final column, KDE levels, KDE bandwidth)
    panels = [('transMap original introns', 'Original introns', 25, 2),
              ('transMap intron annotation support', 'Intron annotation support', 25, 2),
              ('transMap intron RNA support', 'Intron RNA support', 25, 2),
              ('transMap alignment goodness', 'Alignment goodness', 20, 1)]

    af = luigi.local_target.atomic_file(improvement_tgt.path)
    with PdfPages(af.tmp_path) as pdf, sns.axes_style("whitegrid"):
        for genome in ordered_genomes:
            improvement = consensus_data[genome]['Evaluation Improvement']
            data = pd.DataFrame(improvement['changes'])
            unchanged = improvement['unchanged']
            if len(data) == 0:
                continue
            data.columns = ['transMap original introns', 'transMap intron annotation support',
                            'transMap intron RNA support', 'Original introns', 'Intron annotation support',
                            'Intron RNA support', 'transMap alignment goodness', 'Alignment goodness']
            fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(ncols=2, nrows=2)
            axes = [ax1, ax2, ax3, ax4]
            for ax in axes:
                ax.set_xlim(0, 100)
                ax.set_ylim(0, 100)
            for ax, (tm_col, final_col, n_levels, bw) in zip(axes, panels):
                do_kdeplot(data[tm_col], data[final_col], ax, n_levels=n_levels, bw=bw)
                sns.regplot(x=data[tm_col], y=data[final_col], ax=ax, color='#A9B36F',
                            scatter_kws={"s": 3, 'alpha': 0.7, 'rasterized': True}, fit_reg=False)
            fig.suptitle('AUGUSTUS metric improvements for {:,} transcripts in {}.\n'
                         '{:,} transMap transcripts were chosen.'.format(len(data), genome, unchanged))
            for ax in axes:
                ax.set(adjustable='box', aspect='equal')
            fig.subplots_adjust(hspace=0.3)
            multipage_close(pdf, tight_layout=False)
    af.move_to_final_destination()