Python seaborn.heatmap() Examples
The following are 30 code examples of seaborn.heatmap().
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example.
You may also want to check out all available functions/classes of the module seaborn, or try the search function.
Example #1
Source File: time_align.py From scanorama with MIT License | 7 votes |
def time_align_visualize(alignments, time, y, namespace='time_align'):
    """Save an alignment heatmap and a maximum-spanning-tree drawing as SVG.

    Two files are written: ``<namespace>_heatmap.svg`` (heatmap of the
    symmetrised alignment matrix with a unit diagonal, flipped along axis 0)
    and ``<namespace>.svg`` (the maximum spanning tree of the alignment graph,
    with node ``i`` placed at ``(time[i], y[i])``).

    Args:
        alignments: square matrix of pairwise alignment weights.
        time: per-node x coordinates, indexed by node number.
        y: per-node y coordinates, indexed by node number.
        namespace: filename prefix for the two output files.
    """
    plt.figure()
    heat = np.flip(
        alignments + alignments.T + np.eye(alignments.shape[0]), axis=0
    )
    sns.heatmap(heat, cmap="YlGnBu", vmin=0, vmax=1)
    plt.savefig(namespace + '_heatmap.svg')

    # from_numpy_matrix was removed in networkx 3.0; from_numpy_array is the
    # replacement, so prefer it and fall back for older installs.
    build_graph = getattr(nx, 'from_numpy_array', None) or nx.from_numpy_matrix
    G = nx.maximum_spanning_tree(build_graph(alignments))

    pos = {i: np.array([time[i], y[i]]) for i in range(len(G.nodes))}

    plt.figure()
    # `edgelist` is the supported keyword for choosing which edges to draw;
    # the previous `edges=` keyword is not part of the drawing API.
    nx.draw(G, pos, edgelist=list(G.edges()), width=10)
    plt.ylim([-1, 1])
    plt.savefig(namespace + '.svg')
Example #2
Source File: heatmap.py From Attention-on-Attention-for-VQA with MIT License | 6 votes |
def plot_heatmap(a, b, title='title', saveLoc='temp'):
    """Show two 6x6 attention maps side by side.

    Args:
        a: array with 36 elements; reshaped to 6x6 and drawn on the left.
        b: array with 36 elements; reshaped to 6x6 and drawn on the right
            (this panel carries the colorbar).
        title: accepted for interface compatibility; not used.
        saveLoc: accepted for interface compatibility; not used.
    """
    a = a.reshape((6, 6))
    b = b.reshape((6, 6))
    fig, (ax1, ax2) = plt.subplots(1, 2)

    h1 = sns.heatmap(a, cmap="magma", cbar=False, ax=ax1)
    h1.set_title("Attention 1")
    h1.invert_yaxis()
    h1.set_xlabel('')
    h1.set_ylabel('')

    h2 = sns.heatmap(b, cmap="magma", ax=ax2)
    # The second panel previously reused the "Attention 1" title.
    h2.set_title("Attention 2")
    h2.invert_yaxis()
    h2.set_xlabel('')
    h2.set_ylabel('')

    plt.show()
Example #3
Source File: metrics.py From axcell with Apache License 2.0 | 6 votes |
def plot_confusion_matrix(self, name):
    """Draw the confusion matrix obtained for ``name`` as an annotated heatmap.

    The matrix and its class names come from ``self.confusion_matrix(name)``.
    Cells equal to zero are masked out; rows are labelled "True" and columns
    "Predicted".
    """
    cm, target_names = self.confusion_matrix(name)
    class_labels = list(target_names)
    frame = pd.DataFrame(cm, index=class_labels, columns=class_labels)

    plt.figure(figsize=(20, 20))
    axis = sn.heatmap(
        frame,
        annot=True,
        square=True,
        fmt="d",
        cmap="YlGnBu",
        mask=(cm == 0),
        linecolor="black",
        linewidths=0.01,
    )
    axis.set_ylabel("True")
    axis.set_xlabel("Predicted")
Example #4
Source File: design_matrix.py From nltools with MIT License | 6 votes |
def heatmap(self, figsize=(8, 6), **kwargs):
    """Draw this design matrix as a colorbar-less heatmap (gray by default).

    Only the first and last row tick labels stay visible, all spines are shown
    and a thick black frame is drawn around the matrix. Use ``.plot()`` for
    ordinary pandas plotting.

    Args:
        figsize: size of the new matplotlib figure.
        **kwargs: forwarded to ``seaborn.heatmap``; ``cmap`` overrides the
            default ``'gray'`` colormap.
    """
    colormap = kwargs.pop('cmap', 'gray')
    _, axis = plt.subplots(1, figsize=figsize)
    axis = sns.heatmap(self, cmap=colormap, cbar=False, ax=axis, **kwargs)

    for spine in axis.spines.values():
        spine.set_visible(True)

    # Keep only the labels of the first and the last row.
    last_row = self.shape[0] - 1
    for position, tick_label in enumerate(axis.get_yticklabels()):
        tick_label.set_visible(position in (0, last_row))

    # Thick border lines at the origin and at the far edges of the matrix.
    axis.axhline(linewidth=4, color="k")
    axis.axvline(linewidth=4, color="k")
    axis.axhline(y=self.shape[0], color='k', linewidth=4)
    axis.axvline(x=self.shape[1], color='k', linewidth=4)
    plt.yticks(rotation=0)
Example #5
Source File: visFunction.py From uiKLine with MIT License | 6 votes |
def plotSigHeats(signals,markets,start=0,step=2,size=1,iters=6):
    """Plot a heatmap of back-tested signal profit/loss over a stop-limit grid.

    Used to look for stable parameter regions. For every pair
    ``(climit, wlimit)`` taken from ``start, start + step, ...`` (``iters``
    values each) the last element of the capital series returned by
    ``plotSigCaps`` is stored, with ``climit`` varying along the columns and
    ``wlimit`` along the rows.

    Args:
        signals: forwarded to ``plotSigCaps``.
        markets: forwarded to ``plotSigCaps``.
        start: first limit value of the grid.
        step: spacing between consecutive limit values.
        size: forwarded to ``plotSigCaps``.
        iters: number of grid points per axis.

    Returns:
        pandas.DataFrame of shape ``(iters, iters)`` holding the final values.
    """
    sigMat = pd.DataFrame(index=range(iters), columns=range(iters))
    for i in range(iters):
        climit = start + i*step
        for j in range(iters):
            wlimit = start + j*step
            caps, _ = plotSigCaps(signals, markets, climit=climit,
                                  wlimit=wlimit, size=size, op=False)
            # A single .loc write (row j, column i) replaces the chained
            # `sigMat[i][j] = ...`, which may assign into a temporary copy.
            sigMat.loc[j, i] = caps[-1]

    sns.heatmap(sigMat.values.astype(np.float64), annot=True, fmt='.2f',
                annot_kws={"weight": "bold"})

    xTicks = [i + 0.5 for i in range(iters)]
    yTicks = [iters - i - 0.5 for i in range(iters)]
    xyLabels = [str(start + i*step) for i in range(iters)]
    _, labels = plt.yticks(yTicks, xyLabels)
    plt.setp(labels, rotation=0)
    _, labels = plt.xticks(xTicks, xyLabels)
    plt.setp(labels, rotation=90)
    plt.xlabel('Loss Stop @')
    plt.ylabel('Profit Stop @')
    return sigMat
Example #6
Source File: common.py From typhon with MIT License | 6 votes |
def _plot_weights(self, title, file, layer_index=0, vmin=-5, vmax=5):
    """Save an annotated heatmap of one layer's weight matrix.

    The weights are read from ``coefs_`` of the last step of
    ``self.iwp.estimator``; rows are labelled with ``self.iwp.inputs``.

    Args:
        title: accepted but not used by this method.
        file: target passed to ``Figure.savefig``.
        layer_index: index into ``coefs_`` selecting the layer to plot.
        vmin: lower colour limit.
        vmax: upper colour limit.
    """
    import seaborn as sns

    sns.set_context("paper")

    layer = self.iwp.estimator.steps[-1][1].coefs_[layer_index]

    figure, axis = plt.subplots(figsize=(18, 12))

    weights = pd.DataFrame(layer)
    weights.index = self.iwp.inputs

    sns.set(font_scale=1.1)

    # Heatmap with the numeric value printed in every cell.
    heatmap_options = {
        "annot": True,
        "fmt": ".1f",
        "linewidths": .5,
        "ax": axis,
        "cmap": "difference",
        "center": 0,
        "vmin": vmin,
        "vmax": vmax,
    }
    sns.heatmap(weights, **heatmap_options)

    axis.tick_params(labelsize=18)
    figure.tight_layout()
    figure.savefig(file)
Example #7
Source File: visFunction.py From uiKLine with MIT License | 6 votes |
def plotSigHeats(signals,markets,start=0,step=2,size=1,iters=6):
    """Plot a heatmap of back-tested signal profit/loss over a stop-limit grid.

    Used to look for stable parameter regions. For every pair
    ``(climit, wlimit)`` taken from ``start, start + step, ...`` (``iters``
    values each) the last element of the capital series returned by
    ``plotSigCaps`` is stored, with ``climit`` varying along the columns and
    ``wlimit`` along the rows.

    Args:
        signals: forwarded to ``plotSigCaps``.
        markets: forwarded to ``plotSigCaps``.
        start: first limit value of the grid.
        step: spacing between consecutive limit values.
        size: forwarded to ``plotSigCaps``.
        iters: number of grid points per axis.

    Returns:
        pandas.DataFrame of shape ``(iters, iters)`` holding the final values.
    """
    sigMat = pd.DataFrame(index=range(iters), columns=range(iters))
    for i in range(iters):
        climit = start + i*step
        for j in range(iters):
            wlimit = start + j*step
            caps, _ = plotSigCaps(signals, markets, climit=climit,
                                  wlimit=wlimit, size=size, op=False)
            # A single .loc write (row j, column i) replaces the chained
            # `sigMat[i][j] = ...`, which may assign into a temporary copy.
            sigMat.loc[j, i] = caps[-1]

    sns.heatmap(sigMat.values.astype(np.float64), annot=True, fmt='.2f',
                annot_kws={"weight": "bold"})

    xTicks = [i + 0.5 for i in range(iters)]
    yTicks = [iters - i - 0.5 for i in range(iters)]
    xyLabels = [str(start + i*step) for i in range(iters)]
    _, labels = plt.yticks(yTicks, xyLabels)
    plt.setp(labels, rotation=0)
    _, labels = plt.xticks(xTicks, xyLabels)
    plt.setp(labels, rotation=90)
    plt.xlabel('Loss Stop @')
    plt.ylabel('Profit Stop @')
    return sigMat
Example #8
Source File: plots.py From cgpm with Apache License 2.0 | 6 votes |
def plot_heatmap(
        D, xordering=None, yordering=None, xticklabels=None,
        yticklabels=None, vmin=None, vmax=None, ax=None):
    """Draw matrix ``D`` as a heatmap, optionally permuting rows and columns.

    Args:
        D: 2-D array; it is copied, so the caller's data is left untouched.
        xordering: optional index array used to reorder the columns (and the
            x tick labels).
        yordering: optional index array used to reorder the rows (and the
            y tick labels).
        xticklabels: one label per column; defaults to ``0..D.shape[1]-1``.
        yticklabels: one label per row; defaults to ``0..D.shape[0]-1``.
        vmin: lower colour limit passed to seaborn.
        vmax: upper colour limit passed to seaborn.
        ax: axes to draw on; a new figure is created when omitted.

    Returns:
        The matplotlib axes containing the heatmap.
    """
    import seaborn as sns
    D = np.copy(D)
    if ax is None:
        _, ax = plt.subplots()
    # x labels annotate columns (axis 1) and y labels annotate rows (axis 0);
    # the defaults used to be sized the other way round, which is wrong for
    # any non-square matrix.
    if xticklabels is None:
        xticklabels = np.arange(D.shape[1])
    if yticklabels is None:
        yticklabels = np.arange(D.shape[0])
    # Coerce to arrays so plain lists can be reordered by an index array.
    xticklabels = np.asarray(xticklabels)
    yticklabels = np.asarray(yticklabels)
    if xordering is not None:
        xticklabels = xticklabels[xordering]
        D = D[:, xordering]
    if yordering is not None:
        yticklabels = yticklabels[yordering]
        D = D[yordering, :]
    sns.heatmap(
        D, yticklabels=yticklabels, xticklabels=xticklabels,
        linewidths=0.2, cmap='BuGn', ax=ax, vmin=vmin, vmax=vmax)
    ax.set_xticklabels(xticklabels, rotation=90)
    ax.set_yticklabels(yticklabels, rotation=0)
    return ax
Example #9
Source File: test_ext_signature.py From feets with MIT License | 6 votes |
def test_plot_SignaturePhMag(fig_test, fig_ref):
    """Compare ``Signature.plot`` output with a hand-built reference figure.

    ``fig_test`` receives the extractor's own plot; ``fig_ref`` receives the
    expected title, axis labels and heatmap built directly with seaborn.
    """
    # Figure under test: let the extractor draw itself.
    extractor = extractors.Signature()
    kwargs = extractor.get_default_params()
    overrides = {
        "feature": "SignaturePhMag",
        "value": [[1, 2, 3, 4]],
        "ax": fig_test.subplots(),
        "plot_kws": {},
        "time": [1, 2, 3, 4],
        "magnitude": [1, 2, 3, 4],
        "error": [1, 2, 3, 4],
        "features": {"PeriodLS": 1, "Amplitude": 10},
    }
    kwargs.update(overrides)
    extractor.plot(**kwargs)

    # Reference figure: reproduce the expected plot by hand.
    expected_ax = fig_ref.subplots()
    phase_bins = kwargs['phase_bins']
    mag_bins = kwargs['mag_bins']
    expected_ax.set_title(f"SignaturePhMag - {phase_bins}x{mag_bins}")
    expected_ax.set_xlabel("Phase")
    expected_ax.set_ylabel("Magnitude")
    sns.heatmap(kwargs["value"], ax=expected_ax, **kwargs["plot_kws"])
Example #10
Source File: spatial_heatmap.py From NanoPlot with GNU General Public License v3.0 | 6 votes |
def spatial_heatmap(array, path, title=None, color="Greens", figformat="png"):
    """Create a per-channel read-count heatmap from channel information.

    Each distinct value in ``array`` is counted and the count is written into
    the layout template at the positions where ``layout.structure`` equals
    that value.

    Args:
        array: array of channel identifiers, one entry per read.
        path: output path without extension.
        title: plot title; defaults to the title of the created ``Plot``.
        color: colormap name passed to seaborn.
        figformat: file extension and format used when saving.

    Returns:
        A one-element list holding the saved ``Plot`` object.
    """
    logging.info("Nanoplotter: Creating heatmap of reads per channel using {} reads."
                 .format(array.size))
    activity_map = Plot(
        path=path + "." + figformat,
        title="Number of reads generated per channel")
    layout = make_layout(maxval=np.amax(array))
    # Series.value_counts() replaces the deprecated top-level pd.value_counts.
    value_counts = pd.Series(array).value_counts()
    for channel, count in value_counts.items():
        layout.template[np.where(layout.structure == channel)] = count
    plt.figure()
    ax = sns.heatmap(
        data=pd.DataFrame(layout.template, index=layout.yticks, columns=layout.xticks),
        xticklabels="auto",
        yticklabels="auto",
        square=True,
        cbar_kws={"orientation": "horizontal"},
        cmap=color,
        linewidths=0.20)
    ax.set_title(title or activity_map.title)
    activity_map.fig = ax.get_figure()
    activity_map.save(format=figformat)
    plt.close("all")
    return [activity_map]
Example #11
Source File: base_backend.py From delira with GNU Affero General Public License v3.0 | 6 votes |
def _heatmap(self, plot_kwargs=None, figure_kwargs=None, **kwargs):
    """
    Create a heatmap plot inside a managed figure and push it

    Parameters
    ----------
    plot_kwargs : dict
        keyword arguments forwarded to ``seaborn.heatmap``
    figure_kwargs : dict
        keyword arguments used to create the figure
    **kwargs :
        additional keyword arguments for pushing the created figure to
        the logging writer

    """
    plot_kwargs = {} if plot_kwargs is None else plot_kwargs
    figure_kwargs = {} if figure_kwargs is None else figure_kwargs

    with self.FigureManager(self._figure, figure_kwargs, kwargs):
        # seaborn is imported lazily, only once the figure context is open.
        import seaborn
        seaborn.heatmap(**plot_kwargs)
Example #12
Source File: heat_map.py From NAS-Benchmark with GNU General Public License v3.0 | 6 votes |
def draw(self):
    """Save an annotated heatmap of ``self.df1`` as a PDF.

    The figure is written to ``<self.store_path>/normal_hm.pdf``.
    Only ``self.df1`` is plotted here.
    """
    _, axis = plt.subplots(figsize=(15, 9))

    annotation_style = {'size': 13, 'weight': 'bold'}
    sns.heatmap(self.df1, annot=True, ax=axis, annot_kws=annotation_style)

    axis.set_xlabel('Ops without none operation', labelpad=14, fontsize='medium')
    # NOTE(review): "Possiable" is spelled this way in the emitted label;
    # kept unchanged here because it is runtime text.
    axis.set_ylabel('Possiable Input Index', labelpad=14, fontsize='medium')

    output_file = self.store_path + '/normal_hm.pdf'
    plt.savefig(output_file, bbox_inches='tight', dpi=600)
Example #13
Source File: experiment.py From axcell with Apache License 2.0 | 6 votes |
def _plot_confusion_matrix(self, cm, normalize, fmt=None):
    """Draw confusion matrix ``cm`` as an annotated heatmap.

    Args:
        cm: confusion matrix (true classes along rows).
        normalize: when true, divide each row by its sum; rows summing to
            zero are divided by one instead.
        fmt: annotation format; defaults to ``"0.2f"`` when normalising and
            ``"d"`` otherwise.
    """
    if normalize:
        row_totals = cm.sum(axis=1)[:, None]
        row_totals[row_totals == 0] = 1  # avoid dividing empty rows by zero
        cm = cm / row_totals

    if fmt is None:
        if normalize:
            fmt = "0.2f"
        else:
            fmt = "d"

    class_labels = list(self.get_cm_labels(cm))
    frame = pd.DataFrame(cm, index=class_labels, columns=class_labels)

    plt.figure(figsize=(10, 10))
    axis = sn.heatmap(
        frame,
        annot=True,
        square=True,
        fmt=fmt,
        cmap="YlGnBu",
        mask=(cm == 0),
        linecolor="black",
        linewidths=0.01,
    )
    axis.set_ylabel("True")
    axis.set_xlabel("Predicted")
Example #14
Source File: basenji_sat_h5.py From basenji with Apache License 2.0 | 6 votes |
def plot_heat(ax, sat_delta_ti, min_limit):
    """ Plot satmut deltas.

    Args:
        ax (Axis): matplotlib axis to plot to.
        sat_delta_ti (4 x L_sm array): Single target delta matrix for
            saturated mutagenesis region,
        min_limit (float): Minimum heatmap limit.
    """
    import numpy as np

    # NaN-aware maximum: with a plain .max() a single NaN makes the result
    # NaN, and max(min_limit, nan) then silently falls back to min_limit.
    vlim = max(min_limit, np.nanmax(np.abs(sat_delta_ti)))

    # Symmetric colour limits keep zero at the centre of the diverging map.
    sns.heatmap(
        sat_delta_ti,
        linewidths=0,
        cmap='RdBu_r',
        vmin=-vlim,
        vmax=vlim,
        xticklabels=False,
        ax=ax)
    ax.yaxis.set_ticklabels('ACGT', rotation='horizontal')  # , size=10)
Example #15
Source File: basenji_motifs_denovo.py From basenji with Apache License 2.0 | 6 votes |
def plot_kernel(kernel_weights, out_pdf):
    """Save a heatmap of column-mean-centred kernel weights.

    Args:
        kernel_weights: 2-D array of shape ``(depth, width)``. It is not
            modified; centring is done on a copy.
        out_pdf: output path passed to ``plt.savefig``.
    """
    depth, width = kernel_weights.shape
    fig_width = 2 + 1.5*np.log2(width)

    # normalize: subtract each column's mean. Done out of place so the
    # caller's array is left untouched (an in-place `-=` also fails for
    # integer arrays).
    kernel_weights = kernel_weights - kernel_weights.mean(axis=0)

    # plot
    sns.set(font_scale=1.5)
    plt.figure(figsize=(fig_width, depth))
    sns.heatmap(kernel_weights, cmap='PRGn', linewidths=0.2, center=0)
    ax = plt.gca()
    ax.set_xticklabels(range(1, width+1))

    if depth == 4:
        # Four rows are labelled with the nucleotide letters.
        ax.set_yticklabels('ACGT', rotation='horizontal')
    else:
        ax.set_yticklabels(range(1, depth+1), rotation='horizontal')

    plt.savefig(out_pdf)
    plt.close()
Example #16
Source File: functions.py From Match-LSTM with MIT License | 6 votes |
def draw_heatmap_sea(x, xlabels, ylabels, answer, save_path, inches=(11, 3), bottom=0.45, linewidths=0.2):
    """
    Draw a matrix as a seaborn heatmap and save it to disk.

    :param x: matrix handed to ``sns.heatmap``
    :param xlabels: tick labels for the x axis
    :param ylabels: tick labels for the y axis
    :param answer: string appended to ``'Answer: '`` to form the title
    :param save_path: destination passed to ``fig.savefig``
    :param inches: figure size in inches
    :param bottom: bottom margin passed to ``plt.subplots_adjust``
    :param linewidths: width of the lines separating cells
    :return: None
    """
    figure, axis = plt.subplots()
    plt.subplots_adjust(bottom=bottom)
    plt.title('Answer: ' + answer)

    sns.heatmap(
        x,
        linewidths=linewidths,
        ax=axis,
        cmap='Blues',
        xticklabels=xlabels,
        yticklabels=ylabels,
    )

    figure.set_size_inches(inches)
    figure.savefig(save_path)
Example #17
Source File: QARisk.py From QUANTAXIS with MIT License | 6 votes |
def plot_signal(self, start=None, end=None):
    """
    Plot the buy/sell signals as a heatmap.

    ``start`` and ``end`` default to the account's ``start_date`` and
    ``end_date`` and select the ``datetime`` range that is drawn.
    Returns the ``plt`` module so the caller can show or save the figure.
    """
    if start is None:
        start = self.account.start_date
    if end is None:
        end = self.account.end_date

    _, axis = plt.subplots(figsize=(20, 18))

    trades = self.account.trade.reset_index()
    trades = trades.drop('account_cookie', axis=1).set_index('datetime')
    sns.heatmap(trades.loc[start:end], cmap="YlGnBu", linewidths=0.05, ax=axis)

    axis.set_title('SIGNAL TABLE --ACCOUNT: {}'.format(self.account.account_cookie))
    axis.set_xlabel('Code')
    axis.set_ylabel('DATETIME')
    return plt
Example #18
Source File: QARisk.py From QUANTAXIS with MIT License | 6 votes |
def plot_dailyhold(self, start=None, end=None):
    """
    Plot the daily holdings as a heatmap.

    ``start`` and ``end`` default to the account's ``start_date`` and
    ``end_date`` and select the ``date`` range that is drawn.
    Returns the ``plt`` module so the caller can show or save the figure.
    """
    if start is None:
        start = self.account.start_date
    if end is None:
        end = self.account.end_date

    _, axis = plt.subplots(figsize=(20, 8))

    holdings = self.account.daily_hold.reset_index().set_index('date')
    sns.heatmap(holdings.loc[start:end], cmap="YlGnBu", linewidths=0.05, ax=axis)

    axis.set_title('HOLD TABLE --ACCOUNT: {}'.format(self.account.account_cookie))
    axis.set_xlabel('Code')
    axis.set_ylabel('DATETIME')
    return plt
Example #19
Source File: basenji_sat_plot.py From basenji with Apache License 2.0 | 6 votes |
def plot_heat(ax, sat_delta_ti, min_limit):
    """ Plot satmut deltas.

    Args:
        ax (Axis): matplotlib axis to plot to.
        sat_delta_ti (4 x L_sm array): Single target delta matrix for
            saturated mutagenesis region,
        min_limit (float): Minimum heatmap limit.
    """
    # Largest absolute delta, ignoring NaNs, but never below min_limit.
    largest_delta = np.nanmax(np.abs(sat_delta_ti))
    vlim = max(min_limit, largest_delta)

    heatmap_options = {
        'linewidths': 0,
        'cmap': 'RdBu_r',
        'vmin': -vlim,
        'vmax': vlim,
        'xticklabels': False,
        'ax': ax,
    }
    sns.heatmap(sat_delta_ti, **heatmap_options)
    ax.yaxis.set_ticklabels('ACGT', rotation='horizontal')  # , size=10)
Example #20
Source File: lda_plots.py From numpy-ml with GNU General Public License v3.0 | 6 votes |
def plot_unsmoothed():
    """Train an LDA model on a generated corpus and plot what it recovered.

    The model's ``beta`` and ``gamma`` matrices are drawn side by side as
    heatmaps and saved to ``img/plot_unsmoothed.png``.
    """
    corpus, T = generate_corpus()
    model = LDA(T)
    model.train(corpus, verbose=False)

    _, axes = plt.subplots(1, 2)

    # (matrix, y-axis label, panel title) for the left and right panels.
    panels = (
        (model.beta, "Words", "Recovered topic-word distribution"),
        (model.gamma, "Documents", "Recovered document-topic distribution"),
    )
    for axis, (matrix, row_label, panel_title) in zip(axes, panels):
        drawn = sns.heatmap(matrix, xticklabels=[], yticklabels=[], ax=axis)
        drawn.set_xlabel("Topics")
        drawn.set_ylabel(row_label)
        drawn.set_title(panel_title)

    plt.savefig("img/plot_unsmoothed.png", dpi=300)
    plt.close("all")
Example #21
Source File: interpretation.py From lumin with Apache License 2.0 | 6 votes |
def plot_embedding(embed:OrderedDict, feat:str, savename:Optional[str]=None, settings:PlotSettings=PlotSettings()) -> None:
    r'''
    Show the weights of a categorical entity-embedding matrix as an annotated heatmap

    Arguments:
        embed: state_dict of trained nn.Embedding; its 'weight' entry is plotted
        feat: name of feature embedded, used as the y-axis label
        savename: Optional name of file to which to save the plot
        settings: :class:`~lumin.plotting.plot_settings.PlotSettings` class to control figure appearance
    '''

    label_style = {'fontsize': settings.lbl_sz, 'color': settings.lbl_col}
    tick_style = {'fontsize': settings.tk_sz, 'color': settings.tk_col}

    with sns.axes_style(**settings.style):
        plt.figure(figsize=(settings.w_small, settings.h_small))
        weights = to_np(embed['weight'])
        sns.heatmap(weights, annot=True, fmt='.1f', linewidths=.5, cmap=settings.div_palette,
                    annot_kws={'fontsize':settings.leg_sz})
        plt.xlabel("Embedding", **label_style)
        plt.ylabel(feat, **label_style)
        plt.xticks(**tick_style)
        plt.yticks(**tick_style)
        plt.title(settings.title, fontsize=settings.title_sz, color=settings.title_col, loc=settings.title_loc)
        if savename is not None:
            target = settings.savepath/f'{savename}{settings.format}'
            plt.savefig(target, bbox_inches='tight')
        plt.show()
Example #22
Source File: akita_sat_vcf.py From basenji with Apache License 2.0 | 6 votes |
def plot_heat(ax, sat_score_ti, min_limit=None):
    """ Plot satmut scores.

    Args:
        ax (Axis): matplotlib axis to plot to.
        sat_score_ti (L_sm x 4 array): Single target score matrix for
            saturated mutagenesis region; it is transposed before plotting.
        min_limit (float): Optional minimum for the upper colour limit. When
            given and the largest score is below it, the colour scale is
            extended up to ``min_limit``.
    """
    # Guard against the default None: comparing a number with None raises
    # TypeError, so the limit is applied only when one was supplied.
    if min_limit is not None and np.max(sat_score_ti) < min_limit:
        vmax = min_limit
    else:
        vmax = None

    sns.heatmap(
        sat_score_ti.T,
        linewidths=0,
        xticklabels=False,
        yticklabels=False,
        cmap='Blues',
        vmax=vmax,
        ax=ax)

    # yticklabels break the plot for some reason
    # ax.yaxis.set_ticklabels('ACGT', rotation='horizontal')
Example #23
Source File: stock_visualizer.py From stock-analysis with MIT License | 6 votes |
def heatmap(self, pct_change=False, **kwargs):
    """
    Draw an annotated correlation heatmap between assets.

    Parameters:
        - pct_change: If true, correlate the percent change of the closing
                      price; otherwise correlate the closing price itself.
        - kwargs: Keyword arguments to pass down to `sns.heatmap()`

    Returns:
        A seaborn heatmap
    """
    # One column per asset name, indexed like self.data.
    closes = self.data.pivot_table(
        values='close', index=self.data.index, columns='name'
    )
    series = closes.pct_change() if pct_change else closes
    return sns.heatmap(series.corr(), annot=True, center=0, **kwargs)
Example #24
Source File: plot.py From retentioneering-tools with Mozilla Public License 2.0 | 6 votes |
def altair_step_matrix(diff, plot_name=None, title='', vmin=None, vmax=None, font_size=12, **kwargs):
    """Build an Altair heatmap with a text overlay from matrix ``diff``.

    ``title``, ``vmin``, ``vmax`` and ``**kwargs`` are accepted but not used.

    Returns:
        A 4-tuple ``(chart, plot_name, None, diff.retention.retention_config)``.
    """
    # Long format: one row per (index, column, value) cell.
    cells = diff.reset_index().melt('index')
    cells.columns = ['y', 'x', 'z']

    base = alt.Chart(cells).encode(
        x=alt.X('x:O', sort=None),
        y=alt.Y('y:O', sort=None),
    )

    fill = alt.Color('z:Q', scale=alt.Scale(scheme='blues'))
    rectangles = base.mark_rect().encode(color=fill)

    # Black text where |z| < 0.8, white text otherwise.
    text_color = alt.condition(
        abs(alt.datum.z) < 0.8,
        alt.value('black'),
        alt.value('white'),
    )
    labels = base.mark_text(align='center', fontSize=font_size).encode(
        text='z',
        color=text_color,
    )

    chart_width = 3 * font_size * len(diff.columns)
    chart_height = 2 * font_size * diff.shape[0]
    heatmap_object = (rectangles + labels).properties(
        width=chart_width,
        height=chart_height,
    )
    return heatmap_object, plot_name, None, diff.retention.retention_config
Example #25
Source File: Auto_NLP.py From Auto_ViML with Apache License 2.0 | 5 votes |
def plot_confusion_matrix(y_test,y_pred, model_name='Model'):
    """
    Plot a confusion matrix heatmap from ground truths and predictions.

    The title shows ``model_name`` together with the micro- and
    macro-averaged F1 scores.

    Side effects: sets the seaborn style to 'darkgrid', injects a notebook CSS
    rule that widens the container, and changes pandas' global float display
    format.

    Args:
        y_test: ground-truth labels.
        y_pred: predicted labels.
        model_name: name shown on the first line of the title.
    """
    import matplotlib.pyplot as plt
    import seaborn as sns
    # display/HTML are imported from the public IPython.display module;
    # importing them from IPython.core.display is deprecated.
    from IPython.display import display, HTML
    from sklearn.metrics import confusion_matrix, f1_score
    from sklearn.utils.multiclass import unique_labels

    sns.set_style('darkgrid')

    display(HTML("<style>.container { width:95% !important; }</style>"))
    pd.options.display.float_format = '{:,.2f}'.format

    # Label the matrix with every class seen in either vector. Using only the
    # classes of y_test breaks the DataFrame construction whenever a class is
    # predicted that never occurs in y_test.
    labels = unique_labels(y_test, y_pred)
    cm = confusion_matrix(y_test, y_pred, labels=labels)
    cm_df = pd.DataFrame(cm, index=labels.tolist(), columns=labels.tolist())

    # Plot the heatmap
    plt.figure(figsize=(12, 8))
    sns.heatmap(cm_df,
                center=0,
                cmap=sns.diverging_palette(220, 15, as_cmap=True),
                annot=True,
                fmt='g')
    plt.title(' %s \nF1 Score(avg = micro): %0.2f \nF1 Score(avg = macro): %0.2f' % (
        model_name,
        f1_score(y_test, y_pred, average='micro'),
        f1_score(y_test, y_pred, average='macro')), fontsize=13)
    plt.ylabel('True label', fontsize=13)
    plt.xlabel('Predicted label', fontsize=13)
    plt.show()
##############################################################################################
Example #26
Source File: Auto_NLP.py From Auto_ViML with Apache License 2.0 | 5 votes |
def plot_confusion_matrix(y_test,y_pred, model_name='Model'):
    """
    Plot a confusion matrix heatmap from ground truths and predictions.

    The title shows ``model_name`` together with the micro- and
    macro-averaged F1 scores.

    Side effects: sets the seaborn style to 'darkgrid', injects a notebook CSS
    rule that widens the container, and changes pandas' global float display
    format.

    Args:
        y_test: ground-truth labels.
        y_pred: predicted labels.
        model_name: name shown on the first line of the title.
    """
    import matplotlib.pyplot as plt
    import seaborn as sns
    # display/HTML are imported from the public IPython.display module;
    # importing them from IPython.core.display is deprecated.
    from IPython.display import display, HTML
    from sklearn.metrics import confusion_matrix, f1_score
    from sklearn.utils.multiclass import unique_labels

    sns.set_style('darkgrid')

    display(HTML("<style>.container { width:95% !important; }</style>"))
    pd.options.display.float_format = '{:,.2f}'.format

    # Label the matrix with every class seen in either vector. Using only the
    # classes of y_test breaks the DataFrame construction whenever a class is
    # predicted that never occurs in y_test.
    labels = unique_labels(y_test, y_pred)
    cm = confusion_matrix(y_test, y_pred, labels=labels)
    cm_df = pd.DataFrame(cm, index=labels.tolist(), columns=labels.tolist())

    # Plot the heatmap
    plt.figure(figsize=(12, 8))
    sns.heatmap(cm_df,
                center=0,
                cmap=sns.diverging_palette(220, 15, as_cmap=True),
                annot=True,
                fmt='g')
    plt.title(' %s \nF1 Score(avg = micro): %0.2f \nF1 Score(avg = macro): %0.2f' % (
        model_name,
        f1_score(y_test, y_pred, average='micro'),
        f1_score(y_test, y_pred, average='macro')), fontsize=13)
    plt.ylabel('True label', fontsize=13)
    plt.xlabel('Predicted label', fontsize=13)
    plt.show()
##############################################################################################
Example #27
Source File: multiple_linear_regression.py From deep-learning-samples with The Unlicense | 5 votes |
def plot_correlation_heatmap(X, header):
    """Show an annotated heatmap of the correlation matrix of X.

    The correlation coefficients are computed on ``X.T`` and ``header``
    supplies the tick labels on both axes.

    This requires the seaborn package to be installed.
    """
    import seaborn

    correlations = np.corrcoef(X.T)
    seaborn.heatmap(
        correlations,
        cbar=True,
        annot=True,
        square=True,
        yticklabels=header,
        xticklabels=header,
    )
    plt.show()
Example #28
Source File: visuals.py From B-SOID with GNU General Public License v3.0 | 5 votes |
def plot_tmat(tm: object):
    """
    Show an annotated heatmap of a transition matrix.
    :param tm: object, transition matrix data frame
    :return: the matplotlib figure that was created
    """
    fig = plt.figure()
    n_behaviors = tm.shape[0]
    fig.suptitle("Transition matrix of {} behaviors".format(n_behaviors))
    sn.heatmap(tm, annot=True)
    plt.xlabel("Next frame behavior")
    plt.ylabel("Current frame behavior")
    plt.show()
    return fig
Example #29
Source File: visuals.py From B-SOID with GNU General Public License v3.0 | 5 votes |
def plot_tmat(tm: object):
    """
    Build an annotated heatmap of a transition matrix without showing it.
    :param tm: object, transition matrix data frame
    :return: the matplotlib figure that was created
    """
    fig = plt.figure()
    n_behaviors = tm.shape[0]
    fig.suptitle("Transition matrix of {} behaviors".format(n_behaviors))
    sn.heatmap(tm, annot=True)
    plt.xlabel("Next frame behavior")
    plt.ylabel("Current frame behavior")
    return fig
Example #30
Source File: viz.py From focus with GNU General Public License v3.0 | 5 votes |
def heatmap(wcor):
    """
    Render the lower triangle of a correlation matrix as a rotated heatmap image.

    The upper triangle (including the diagonal) is masked, the figure is
    rasterised, rotated by 45 degrees and cropped vertically.

    Note: this sets the global ``figure.figsize`` rcParam.

    :param wcor: numpy.ndarray matrix of sample correlation structure for predicted expression

    :return: numpy.ndarray (RGB) formatted heatmap of correlation structure
    """
    mpl.rcParams["figure.figsize"] = [6.4, 6.4]
    fig = plt.figure()
    fig.subplots_adjust(bottom=0.20, left=0.28)

    # The builtin bool replaces the `np.bool` alias, which recent NumPy
    # releases no longer provide.
    mask = np.zeros_like(wcor, dtype=bool)
    mask[np.triu_indices_from(mask)] = True

    ax = sns.heatmap(wcor, mask=mask, cmap="RdBu_r", square=True,
                     linewidths=0, cbar=False, xticklabels=False,
                     yticklabels=False, ax=None, vmin=-1, vmax=1)
    ax.margins(2)
    ax.set_aspect("equal", "box")
    fig.canvas.draw()

    # save image as numpy array
    # np.frombuffer replaces the deprecated binary mode of np.fromstring;
    # .copy() keeps the array writable, as fromstring's result was.
    # NOTE(review): tostring_rgb is deprecated in newer matplotlib - confirm
    # the supported version range before replacing it.
    data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8).copy()
    img = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))

    # rotate heatmap to make upside-down triangle shape
    rows, cols, ch = img.shape
    M = cv2.getRotationMatrix2D((cols / 2, rows / 2), 45, 1)
    dst = cv2.warpAffine(img, M, (cols, rows), borderMode=cv2.BORDER_CONSTANT,
                         borderValue=(255, 255, 255))

    # trim extra whitespace
    crop_img = dst[int(dst.shape[0] / 2.5):int(dst.shape[0] / 1.1)]
    return crop_img