Python seaborn.color_palette() Examples

The following are 30 code examples of seaborn.color_palette(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module seaborn, or try the search function.
Example #1
Source File: evaluation.py    From tsinfer with GNU General Public License v3.0 7 votes vote down vote up
def edge_plot(ts, filename):
    """Draw one horizontal segment per child edge of every tree in ``ts`` and save it.

    Each segment spans the tree's interval at the height of the child's node
    id. Its colour is looked up in a ``husl`` palette of ``2 ** n - 1``
    colours (``n`` being ``ts.num_samples``) at index
    ``unrank(tree.samples(child), n)``. The figure is written through
    ``save_figure(filename)``.
    """
    num_samples = ts.num_samples
    palette = sns.color_palette("husl", 2 ** num_samples - 1)
    segments = []
    segment_colours = []
    for tree in ts.trees():
        start, end = tree.interval
        for node in tree.nodes():
            kids = tree.children(node)
            # Unary nodes are skipped: they have the same samples under them
            # as their next non-unary descendant.
            if len(kids) <= 1:
                continue
            for child in kids:
                segments.append([(start, child), (end, child)])
                colour_index = unrank(tree.samples(child), num_samples)
                segment_colours.append(palette[colour_index])

    collection = mc.LineCollection(segments, linewidths=2, colors=segment_colours)
    _, axes = plt.subplots()
    axes.add_collection(collection)
    axes.autoscale()
    save_figure(filename)
Example #2
Source File: burst_plot.py    From FRETBursts with GNU General Public License v2.0 6 votes vote down vote up
def _register_colormaps():
    """Register the ``alex_lv``, ``alex_light``, ``YlGnBu_crop`` and ``alex_dark`` colormaps.

    The first three are built from slices of 64-colour seaborn palettes and
    have their "under" colour made fully transparent; ``alex_dark`` is
    matplotlib's ``GnBu_r``.
    """
    import matplotlib as mpl
    import seaborn as sns

    def _transparent_under(name, colors):
        # Build a segmented colormap whose under-range colour has alpha 0.
        cmap = mpl.colors.LinearSegmentedColormap.from_list(name, colors)
        cmap.set_under(alpha=0)
        return cmap

    spectral_slice = sns.color_palette('nipy_spectral', 64)[2:43]
    mpl.cm.register_cmap(name='alex_lv',
                         cmap=_transparent_under('alex_lv', spectral_slice))

    ylgnbu_tail = sns.color_palette('YlGnBu', 64)[16:]
    light_cmap = _transparent_under('alex', ylgnbu_tail)
    # The same colormap object is registered under two names.
    for alias in ('alex_light', 'YlGnBu_crop'):
        mpl.cm.register_cmap(name=alias, cmap=light_cmap)
    mpl.cm.register_cmap(name='alex_dark', cmap=mpl.cm.GnBu_r)

    # Temporary hack to workaround issue
    # https://github.com/mwaskom/seaborn/issues/855
    for attr in ('alex_light', 'alex_dark'):
        setattr(mpl.cm, attr, mpl.cm.get_cmap(attr))


# Register colormaps on import if not mocking 
Example #3
Source File: flower_classifier.py    From deep-learning-flower-identifier with MIT License 6 votes vote down vote up
def plot_solution(image_path, model):
    """
    Show an image above a bar chart of the predictions made for it.

    :param image_path: '/'-separated path of the image; its component at
        index 3 is the key looked up in ``cat_to_name`` for the title
    :param model: model passed to ``predict`` together with ``image_path``
    :return: None
    """
    plt.figure(figsize=(6, 10))
    image_axis = plt.subplot(2, 1, 1)

    # Title comes from the category encoded in the path.
    heading = cat_to_name[image_path.split('/')[3]]

    picture = process_image(image_path)
    imshow(picture, image_axis, title=heading)

    # The middle return value of predict() is not needed for the chart.
    probabilities, _, flower_names = predict(image_path, model)

    plt.subplot(2, 1, 2)
    base_colour = sns.color_palette()[0]
    sns.barplot(x=probabilities, y=flower_names, color=base_colour)
    plt.show()
Example #4
Source File: basenji_sat_h5.py    From basenji with Apache License 2.0 6 votes vote down vote up
def plot_sad(ax, sat_loss_ti, sat_gain_ti):
  """ Draw the loss and gain SAD score curves on one axis.

    Args:
        ax (Axis): matplotlib axis to plot to.
        sat_loss_ti (L_sm array): Minimum mutation delta across satmut length;
          drawn negated.
        sat_gain_ti (L_sm array): Maximum mutation delta across satmut length.
    """

  # The two extremes of a 10-step diverging palette colour the two curves.
  diverging = sns.color_palette('RdBu_r', 10)
  loss_color, gain_color = diverging[0], diverging[-1]

  ax.plot(-sat_loss_ti, c=loss_color, label='loss', linewidth=1)
  ax.plot(sat_gain_ti, c=gain_color, label='gain', linewidth=1)
  ax.set_xlim(0, len(sat_loss_ti))
  ax.legend()

  # No x ticks, and a thin frame on all four sides.
  ax.xaxis.set_ticks([])
  for side in ('top', 'bottom', 'left', 'right'):
    ax.spines[side].set_linewidth(0.5)
Example #5
Source File: utils.py    From scikit-downscale with Apache License 2.0 6 votes vote down vote up
def zscore_ds_plot(training, target, future, corrected):
    """Plot the ``uas`` attribute of four datasets on fixed daily time axes.

    ``target`` is drawn against every day of 1980-1989. ``training`` uses the
    same range with 29 February removed. ``future`` and ``corrected`` share
    1990-1999 with 29 February removed. Returns None.
    """
    series_names = ["training", "future", "target", "corrected"]
    colors = dict(zip(series_names, sns.color_palette("Set2", n_colors=4)))

    alpha = 0.5

    def _drop_feb29(index):
        # Keep every timestamp that is not 29 February.
        return index[~((index.month == 2) & (index.day == 29))]

    time_target = pd.date_range("1980-01-01", "1989-12-31", freq="D")
    time_training = _drop_feb29(time_target)
    time_future = _drop_feb29(pd.date_range("1990-01-01", "1999-12-31", freq="D"))

    plt.figure(figsize=(8, 4))
    plot_plan = (
        ("training", time_training, training),
        ("target", time_target, target),
        ("future", time_future, future),
        ("corrected", time_future, corrected),
    )
    for name, times, dataset in plot_plan:
        plt.plot(times, dataset.uas, label=name, alpha=alpha, c=colors[name])

    plt.xlabel("Time")
    plt.ylabel("Eastward Near-Surface Wind (m s-1)")
    plt.legend()

    return None
Example #6
Source File: plotting_utils.py    From QUANTAXIS with MIT License 6 votes vote down vote up
def customize(func):
    """
    Decorator that sets the content and style of the output figure.

    The wrapper accepts an extra ``set_context`` keyword (default ``True``)
    that is removed before ``func`` is called. When it is truthy, ``func``
    runs inside the plotting context, axes style and "colorblind" palette
    context managers, after ``sns.despine(left=True)``. Otherwise ``func``
    is called directly.
    """

    @wraps(func)
    def call_w_context(*args, **kwargs):
        use_context = kwargs.pop("set_context", True)
        if not use_context:
            return func(*args, **kwargs)

        palette = sns.color_palette("colorblind")
        with plotting_context(), axes_style(), palette:
            sns.despine(left=True)
            return func(*args, **kwargs)

    return call_w_context
Example #7
Source File: helpers.py    From hypertools with MIT License 6 votes vote down vote up
def vals2colors(vals, cmap='GnBu_d',res=100):
    """Maps values to colors
    Args:
    vals (list or list of lists) - list of values to map to colors
    cmap (str) - seaborn palette name (default is 'GnBu_d')
    res (int) - resolution of the color map (default: 100)
    Returns:
    list of rgb tuples
    """
    # A list containing at least one sub-list is flattened one level.
    contains_sublist = any(isinstance(item, list) for item in vals)
    if contains_sublist:
        vals = list(itertools.chain.from_iterable(vals))

    # One palette row per bin; bins run from min(vals) to max(vals) + 1.
    palette = np.array(sns.color_palette(cmap, res))
    bin_edges = np.linspace(np.min(vals), np.max(vals)+1, res+1)
    ranks = np.digitize(vals, bin_edges) - 1
    return [tuple(row) for row in palette[ranks, :]]
Example #8
Source File: cyclic_callbacks.py    From lumin with Apache License 2.0 6 votes vote down vote up
def plot(self):
        r'''
        Plots the recorded lr and momentum histories against iteration number, one panel each, and shows the figure
        '''

        settings = self.plot_settings
        with sns.axes_style(settings.style), sns.color_palette(settings.cat_palette):
            _, (lr_ax, mom_ax) = plt.subplots(2, 1, figsize=(settings.w_mid, settings.h_mid))
            label_style = dict(fontsize=settings.lbl_sz, color=settings.lbl_col)
            mom_ax.set_xlabel("Iterations", **label_style)
            lr_ax.set_ylabel("Learning Rate", **label_style)
            mom_ax.set_ylabel("Momentum", **label_style)
            for axis, key in ((lr_ax, 'lr'), (mom_ax, 'mom')):
                axis.plot(range(len(self.hist[key])), self.hist[key])
            # Same tick styling on both axes of both panels.
            for axis in (lr_ax, mom_ax):
                for direction in ('x', 'y'):
                    axis.tick_params(axis=direction, labelsize=settings.tk_sz, labelcolor=settings.tk_col)
            plt.show()
Example #9
Source File: opt_callbacks.py    From lumin with Apache License 2.0 6 votes vote down vote up
def plot(self, n_skip:int=0, n_max:Optional[int]=None, lim_y:Optional[Tuple[float,float]]=None) -> None:
        r'''
        Plot the recorded loss against the recorded learning rate and show the figure.

        Arguments:
            n_skip: Number of initial iterations to skip in plotting
            n_max: Maximum iteration number to plot
            lim_y: y-range for plotting
        '''

        # TODO: Decide on whether to keep this; could just pass to plot_lr_finders

        settings = self.plot_settings
        window = slice(n_skip, n_max)
        with sns.axes_style(settings.style), sns.color_palette(settings.cat_palette):
            plt.figure(figsize=(settings.w_mid, settings.h_mid))
            plt.plot(self.history['lr'][window], self.history['loss'][window], label='Training loss', color='g')
            # Use a log x-axis once the LR bounds span three or more decades.
            decades = np.log10(self.lr_bounds[1])-np.log10(self.lr_bounds[0])
            if decades >= 3:
                plt.xscale('log')
            plt.ylim(lim_y)
            plt.grid(True, which="both")
            plt.legend(loc=settings.leg_loc, fontsize=settings.leg_sz)
            tick_style = dict(fontsize=settings.tk_sz, color=settings.tk_col)
            plt.xticks(**tick_style)
            plt.yticks(**tick_style)
            label_style = dict(fontsize=settings.lbl_sz, color=settings.lbl_col)
            plt.ylabel("Loss", **label_style)
            plt.xlabel("Learning rate", **label_style)
            plt.show()
Example #10
Source File: analysis.py    From dl-eeg-review with MIT License 6 votes vote down vote up
def plot_architectures(df, save_cfg=cfg.saving_config):
    """Plot a ring-shaped pie chart of the architectures used in the study.

    Args:
        df: DataFrame with an 'Architecture (clean)' column; its value
            counts define the wedges.
        save_cfg: dict-like saving configuration providing 'text_width',
            'savepath' and 'format'; the whole mapping is also forwarded as
            keyword arguments to ``fig.savefig``. Pass None to skip saving,
            in which case matplotlib's default figure size is used.

    Returns:
        The matplotlib axes holding the chart.
    """
    # Read the configuration only when there is one, so that save_cfg=None
    # (already treated as "do not save" below) no longer fails on indexing.
    if save_cfg is not None:
        side = save_cfg['text_width'] / 3
        figsize = (side, side)
    else:
        figsize = None
    fig, ax = plt.subplots(figsize=figsize)
    colors = sns.color_palette()
    counts = df['Architecture (clean)'].value_counts()
    _, _, pct = ax.pie(counts.values, labels=counts.index, autopct='%1.1f%%',
           wedgeprops=dict(width=0.3, edgecolor='w'), colors=colors,
           pctdistance=0.55)
    # Shrink the percentage labels drawn inside the ring.
    for i in pct:
        i.set_fontsize(5)

    ax.axis('equal')
    plt.tight_layout()

    if save_cfg is not None:
        fname = os.path.join(save_cfg['savepath'], 'architectures')
        fig.savefig(fname + '.' + save_cfg['format'], **save_cfg)

    return ax
Example #11
Source File: callbacks.py    From ivis with GNU General Public License v2.0 6 votes vote down vote up
def plot_embeddings(self, embeddings):
        """Render a scatter plot of 2-D embeddings into a TensorFlow summary image.

        Args:
            embeddings: array whose first two columns are plotted after
                min-max scaling to [0, 1]. When None, ``self.embeddings``
                is used instead.

        Returns:
            A ``tf.Summary.Image`` holding the PNG-encoded figure.
        """
        # The argument used to be overwritten by self.embeddings and ignored.
        # Honour it, keeping self.embeddings as the fallback.
        if embeddings is None:
            embeddings = self.embeddings
        embeddings = MinMaxScaler((0, 1)).fit_transform(embeddings)

        fig = plt.figure()
        buf = io.BytesIO()
        try:
            sns.scatterplot(x=embeddings[:, 0], y=embeddings[:, 1], s=1,
                            hue=self.labels,
                            palette=sns.color_palette("hls", self.n_classes),
                            linewidth=0)

            plt.savefig(buf, format='png', dpi=300)
        finally:
            # Close the figure even when plotting or saving fails.
            plt.close(fig)
        buf.seek(0)

        image = tf.Summary.Image(encoded_image_string=buf.getvalue())
        return image
Example #12
Source File: embedding.py    From agnez with BSD 3-Clause "New" or "Revised" License 6 votes vote down vote up
def _prepare_fig_labels(data, labels):
    '''Helper function for setting up the animation canvas

    Returns the labels repeated once per step along the first axis of
    `data` and flattened, an "hls" palette with `labels.max() + 1`
    colours, and a new 8x8 figure.
    '''
    # The palette is sized from the largest label, before labels are expanded.
    n_colors = labels.max() + 1
    palette = np.array(sns.color_palette("hls", n_colors))

    n_steps, n_points, n_dims = data.shape
    data = data.transpose(1, 0, 2).reshape((n_steps * n_points, n_dims))

    # Repeat every label n_steps times so each point keeps its label per step.
    repeated = labels[np.newaxis].repeat(n_steps, axis=0)
    labels = repeated.transpose(1, 0).flatten()

    fig = plt.figure(figsize=(8, 8))
    return labels, palette, fig
Example #13
Source File: _plot.py    From q2-qemistree with BSD 2-Clause "Simplified" License 6 votes vote down vote up
def format_colors(feature_metadata, category, color_palette):
    '''Build the text of a TREE_COLORS annotation block.

    Three header lines are followed by one tab-separated "clade" line per
    row of `feature_metadata`. The colour is picked from the row's value in
    `category`; rows whose 'structure_source' equals 'MS2' get a normal
    line of width 6, all others a dashed line of width 4.'''
    annotations = feature_metadata[category].unique()
    color_map = values_to_colors(annotations, color_palette)

    lines = ['TREE_COLORS', 'SEPARATOR TAB', 'DATA']

    for idx in feature_metadata.index:
        color = color_map[feature_metadata.loc[idx, category]]
        is_ms2 = feature_metadata.loc[idx, 'structure_source'] == 'MS2'
        style, width = ('normal', 6) if is_ms2 else ('dashed', 4)
        lines.append('%s\tclade\t%s\t%s\t%s' % (idx, color, style, width))

    return '\n'.join(lines)
Example #14
Source File: n2d.py    From n2d with GNU General Public License v3.0 6 votes vote down vote up
def plot(x, y, plot_id, names=None):
    """Scatter the first 5000 rows of ``x`` coloured by label; write a CSV and a PNG.

    Output locations and the number of clusters are read from the
    module-level ``args``. ``names``, when given, is used to map the label
    values through ``Series.map`` before plotting.
    """
    limit = 5000
    viz_df = pd.DataFrame(data=x[:limit])
    viz_df['Label'] = y[:limit]
    if names is not None:
        viz_df['Label'] = viz_df['Label'].map(names)

    output_stem = args.save_dir + '/' + args.dataset
    viz_df.to_csv(output_stem + '.csv')

    plt.subplots(figsize=(8, 5))
    label_order = sorted(viz_df['Label'].unique())
    cluster_palette = sns.color_palette("hls", n_colors=args.n_clusters)
    sns.scatterplot(x=0, y=1, hue='Label', legend='full', hue_order=label_order,
                    palette=cluster_palette,
                    alpha=.5,
                    data=viz_df)
    legend = plt.legend(bbox_to_anchor=(-.1, 1.00, 1.1, .5), loc="lower left", markerfirst=True,
                        mode="expand", borderaxespad=0, ncol=args.n_clusters + 1, handletextpad=0.01, )

    # Blank out the first legend text entry.
    legend.texts[0].set_text("")
    plt.ylabel("")
    plt.xlabel("")
    plt.tight_layout()
    plt.savefig(output_stem + '-' + plot_id + '.png', dpi=300)
    plt.clf()
Example #15
Source File: typeI_analysis_2.py    From SAMPL6 with MIT License 5 votes vote down vote up
def stacked_barplot_2groups(df, x_label, y_label1, y_label2, fig_size=(10, 7), invert=False):
    """Draw a two-layer stacked bar plot from columns of a DataFrame.

    Args:
        df: Pandas DataFrame holding the columns named below.
        x_label: str, column whose values label the x ticks.
        y_label1: str, column drawn as the bottom layer of each bar.
        y_label2: str, column stacked on top of ``y_label1``.
        fig_size: figure size handed to ``plt.figure``.
        invert: when equal to ``True`` the y axis is flipped.
    """
    # Two grays: the first for the bottom layer, the second for the top one.
    bar_colors = sns.color_palette(["#95a5a6", "#34495e"])

    # Plot style
    plt.close()
    plt.style.use(["seaborn-talk", "seaborn-whitegrid"])
    for rc_key, rc_value in (('axes.labelsize', 18),
                             ('xtick.labelsize', 14),
                             ('ytick.labelsize', 16)):
        plt.rcParams[rc_key] = rc_value
    # NOTE(review): this runs before plt.figure() below, so it presumably acts
    # on an implicitly created figure rather than the one that is plotted on
    # -- confirm whether that is intended.
    plt.tight_layout()
    width = 0.70
    plt.figure(figsize=fig_size)

    positions = range(len(df[x_label]))
    bottom_values = df[y_label1]
    top_values = df[y_label2]

    bottom_bars = plt.bar(positions, bottom_values, width=width, color=bar_colors[0])
    top_bars = plt.bar(positions, top_values, width=width, bottom=bottom_values, color=bar_colors[1])

    plt.xticks(positions, df[x_label], rotation=90)
    plt.xlabel(x_label)
    plt.ylabel("number of $pK_{a}s$")
    plt.legend((bottom_bars[0], top_bars[0]), (y_label1, y_label2))

    # Flip plot upside down
    if invert == True:
        plt.gca().invert_yaxis()

# =============================================================================
# CONSTANTS
# =============================================================================

# Paths to input data. 
Example #16
Source File: typeIII_analysis.py    From SAMPL6 with MIT License 5 votes vote down vote up
def barplot_with_CI_errorbars(df, x_label, y_label, y_lower_label, y_upper_label):
    """Creates bar plot of a given dataframe with asymmetric error bars for y axis.

    Two columns holding the lower and upper error-bar lengths are added to
    ``df`` itself (the frame is modified in place).

    Args:
        df: Pandas Dataframe that should have columns with columnnames specified in other arguments.
        x_label: str, column name of x axis categories
        y_label: str, column name of y axis values
        y_lower_label: str, column name of lower error values of y axis
        y_upper_label: str, column name of upper error values of y axis

    """
    # Column names for new columns for delta y_err which is calculated as | y_err - y |
    # Raw strings: "\D" is not a valid escape sequence in a normal literal.
    # The resulting text is unchanged.
    delta_lower_yerr_label = r"$\Delta$" + y_lower_label
    delta_upper_yerr_label = r"$\Delta$" + y_upper_label
    data = df  # Pandas DataFrame (same object, not a copy)
    data[delta_lower_yerr_label] = data[y_label] - data[y_lower_label]
    data[delta_upper_yerr_label] = data[y_upper_label] - data[y_label]

    # Colour of the error bars: fourth entry of the "GnBu_d" palette.
    current_palette = sns.color_palette("GnBu_d")
    sns_color = current_palette[3]

    # Plot style
    plt.close()
    plt.style.use(["seaborn-talk", "seaborn-whitegrid"])
    plt.rcParams['axes.labelsize'] = 18
    plt.rcParams['xtick.labelsize'] = 14
    plt.rcParams['ytick.labelsize'] = 16

    # Plot
    x = range(len(data[y_label]))
    y = data[y_label]
    plt.bar(x, y)
    plt.xticks(x, data[x_label], rotation=90)
    plt.errorbar(x, y, yerr=(data[delta_lower_yerr_label], data[delta_upper_yerr_label]),
                 fmt="none", ecolor=sns_color, capsize=3, capthick=True)
    plt.xlabel(x_label)
    plt.ylabel(y_label)
Example #17
Source File: embedding.py    From agnez with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
def embedding2dplot(data, labels, show_median=True, show_legend=True):
    '''2D embedding visualization.

    Modified from:
    https://beta.oreilly.com/learning/an-illustrated-introduction-to-the-t-sne-algorithm

    Scatters the first two columns of `data`, coloured by `labels` from an
    "hls" palette with `labels.max() + 1` colours. With `show_median`, each
    label value that has points is written at the median position of its
    points. With `show_legend`, legend patches from `_get_legend` are added.

    Returns the figure, the axes and the scatter collection.
    '''
    # We choose a color palette with seaborn.
    max_label = labels.max()
    palette = np.array(sns.color_palette("hls", max_label+1))
    # We create a scatter plot.
    fig = plt.figure(figsize=(8, 8))
    ax = plt.subplot(aspect='equal')
    # The `np.int` alias no longer exists in current NumPy; the builtin
    # `int` is the type it referred to.
    sc = ax.scatter(data[:, 0], data[:, 1], lw=0, s=40,
                    c=palette[labels.astype(int)])
    plt.xlim(-25, 25)
    plt.ylim(-25, 25)
    ax.axis('off')
    ax.axis('tight')

    # We add the labels for each cluster.
    if show_median:
        # Cover every label value up to the maximum instead of a fixed ten.
        for i in range(int(max_label) + 1):
            members = data[labels == i, :]
            if len(members) == 0:
                # No points carry this label, so there is no median to mark.
                continue
            # Position of each label.
            xtext, ytext = np.median(members, axis=0)
            txt = ax.text(xtext, ytext, str(i), fontsize=24)
            txt.set_path_effects([
                PathEffects.Stroke(linewidth=5, foreground="w"),
                PathEffects.Normal()])

    # Show labels as legend patches
    if show_legend:
        handles = _get_legend(palette, labels)
        ax.legend(handles=handles)
    return fig, ax, sc
Example #18
Source File: basenji_hidden.py    From basenji with Apache License 2.0 5 votes vote down vote up
def regplot(vals1, vals2, out_pdf, alpha=0.5, x_label=None, y_label=None):
  """Save a scatter plot of vals2 against vals1 with a lowess fit line.

  Args:
    vals1: x values.
    vals2: y values.
    out_pdf: path handed to plt.savefig.
    alpha: opacity of the scatter points.
    x_label: optional x axis label; left untouched when None.
    y_label: optional y axis label; left untouched when None.
  """
  plt.figure()

  gold = sns.color_palette('husl', 8)[1]
  # x and y are passed by keyword: recent seaborn releases no longer accept
  # them positionally, while keywords work on every release.
  ax = sns.regplot(
      x=vals1,
      y=vals2,
      color='black',
      lowess=True,
      scatter_kws={'color': 'black',
                   's': 4,
                   'alpha': alpha},
      line_kws={'color': gold})

  xmin, xmax = plots.scatter_lims(vals1)
  ymin, ymax = plots.scatter_lims(vals2)

  ax.set_xlim(xmin, xmax)
  if x_label is not None:
    ax.set_xlabel(x_label)
  ax.set_ylim(ymin, ymax)
  if y_label is not None:
    ax.set_ylabel(y_label)

  ax.grid(True, linestyle=':')

  plt.savefig(out_pdf)
  plt.close()


################################################################################
# __main__
################################################################################ 
Example #19
Source File: callbacks.py    From ivis with GNU General Public License v2.0 5 votes vote down vote up
def plot_embeddings(self, filename):
        """Save a scatter plot of ``self.embeddings`` as ``filename`` inside ``self.log_dir``.

        The embeddings are min-max scaled to [0, 1]; the first two columns
        are plotted, coloured by ``self.labels`` with an "hls" palette of
        ``self.n_classes`` colours.
        """
        scaler = MinMaxScaler((0, 1))
        scaled = scaler.fit_transform(self.embeddings)

        fig = plt.figure()
        class_palette = sns.color_palette("hls", self.n_classes)
        sns.scatterplot(x=scaled[:, 0], y=scaled[:, 1], s=1,
                        hue=self.labels,
                        palette=class_palette,
                        linewidth=0)

        target_path = os.path.join(self.log_dir, filename)
        plt.savefig(target_path, dpi=300)
        plt.close(fig)
Example #20
Source File: analysis.py    From dl-eeg-review with MIT License 5 votes vote down vote up
def plot_number_layers(df, save_cfg=cfg.saving_config):
    """Plot a bar chart of the number of papers per number of layers.

    Args:
        df: DataFrame with a 'Layers (clean)' column; its value counts are
            reindexed to '1'..'31' plus 'N/M', and missing entries dropped.
        save_cfg: dict-like saving configuration providing 'text_width',
            'savepath' and 'format'; the whole mapping is also forwarded as
            keyword arguments to ``fig.savefig``. Pass None to skip saving,
            in which case matplotlib's default figure size is used.

    Returns:
        The matplotlib axes holding the chart.
    """
    # Read the configuration only when there is one, so that save_cfg=None
    # (already treated as "do not save" below) no longer fails on indexing.
    if save_cfg is not None:
        figsize = (save_cfg['text_width'] / 4 * 2, save_cfg['text_width'] / 3)
    else:
        figsize = None
    fig, ax = plt.subplots(figsize=figsize)

    n_layers_df = df['Layers (clean)'].value_counts().reindex(
        [str(i) for i in range(1, 32)] + ['N/M'])
    n_layers_df = n_layers_df.dropna().astype(int)

    from matplotlib.colors import ListedColormap
    cmap = ListedColormap(sns.color_palette(None).as_hex())

    n_layers_df.plot(kind='bar', width=0.8, rot=0, colormap=cmap, ax=ax)
    ax.set_xlabel('Number of layers')
    ax.set_ylabel('Number of papers')
    plt.tight_layout()

    if save_cfg is not None:
        fname = os.path.join(save_cfg['savepath'], 'number_layers')
        fig.savefig(fname + '.' + save_cfg['format'], **save_cfg)

        # Second copy as a 300 dpi PNG, using a modified copy of the config.
        save_cfg2 = save_cfg.copy()
        save_cfg2['format'] = 'png'
        save_cfg2['dpi'] = 300
        fig.savefig(fname + '.png', **save_cfg2)

    return ax
Example #21
Source File: analysis.py    From dl-eeg-review with MIT License 5 votes vote down vote up
def plot_architectures_per_year(df, save_cfg=cfg.saving_config):
    """Plot stacked bar graph of architectures per year.

    Note that ``df`` is modified in place: 'Year' is cast to int32 and an
    'Arch' column copying 'Architecture (clean)' is added.

    Args:
        df: DataFrame with 'Year' and 'Architecture (clean)' columns.
        save_cfg: dict-like saving configuration providing 'text_width',
            'savepath' and 'format'; the whole mapping is also forwarded as
            keyword arguments to ``fig.savefig``. Pass None to skip saving,
            in which case matplotlib's default figure size is used.

    Returns:
        The matplotlib axes holding the chart.
    """
    # Read the configuration only when there is one, so that save_cfg=None
    # (already treated as "do not save" below) no longer fails on indexing.
    if save_cfg is not None:
        figsize = (save_cfg['text_width'] / 3 * 2, save_cfg['text_width'] / 3)
    else:
        figsize = None
    fig, ax = plt.subplots(figsize=figsize)
    colors = sns.color_palette()

    df['Year'] = df['Year'].astype('int32')
    col_name = 'Architecture (clean)'
    df['Arch'] = df[col_name]
    # Columns are ordered by overall frequency of each architecture.
    order = df[col_name].value_counts().index
    counts = df.groupby(['Year', 'Arch']).size().unstack('Arch')
    counts = counts[order]

    counts.plot(kind='bar', stacked=True, title='', ax=ax, color=colors)
    ax.legend(loc='upper left', bbox_to_anchor=(1, 1))
    ax.set_ylabel('Number of papers')
    ax.set_xlabel('')

    plt.tight_layout()

    if save_cfg is not None:
        fname = os.path.join(save_cfg['savepath'], 'architectures_per_year')
        fig.savefig(fname + '.' + save_cfg['format'], **save_cfg)

    return ax
Example #22
Source File: bam_cov.py    From basenji with Apache License 2.0 5 votes vote down vote up
def regplot_gc(vals1, vals2, model, out_pdf):
  """Save a scatter plot of vals2 against vals1 with two fitted curves.

  One curve is seaborn's order-3 polynomial fit; the other is
  model.predict evaluated on the sorted vals1.

  Args:
    vals1: x values (labelled 'GC%'); must support np.sort.
    vals2: y values (labelled 'Coverage').
    model: object whose predict() accepts a column vector of x values.
    out_pdf: path handed to plt.savefig.
  """
  gold = sns.color_palette('husl', 8)[1]

  plt.figure(figsize=(6, 6))

  # plot data and seaborn model
  # x and y are passed by keyword: recent seaborn releases no longer accept
  # them positionally, while keywords work on every release.
  ax = sns.regplot(
      x=vals1,
      y=vals2,
      color='black',
      order=3,
      scatter_kws={'color': 'black',
                   's': 4,
                   'alpha': 0.5},
      line_kws={'color': gold})

  # plot my model predictions
  svals1 = np.sort(vals1)
  preds2 = model.predict(svals1[:, np.newaxis])
  ax.plot(svals1, preds2)

  # adjust axis
  ymin, ymax = scatter_lims(vals2)
  ax.set_xlim(0.2, 0.8)
  ax.set_xlabel('GC%')
  ax.set_ylim(ymin, ymax)
  ax.set_ylabel('Coverage')

  ax.grid(True, linestyle=':')

  plt.savefig(out_pdf)
  plt.close()
Example #23
Source File: plotter.py    From message-analyser with MIT License 5 votes vote down vote up
def barplot_messages_per_minutes(msgs, path_to_save, minutes=2):
    """Save a bar plot of message counts per ``minutes``-long slot.

    The groups come from ``stools.get_messages_per_minutes``. Each bar is
    shaded from a reversed "GnBu_d" palette according to its count. The
    image is written to ``path_to_save`` as ``<function name>.png``.
    """
    sns.set(style="whitegrid", palette="muted")
    sns.despine(top=True)

    messages_per_minutes = stools.get_messages_per_minutes(msgs, minutes)

    xticks_labels = stools.get_hours()
    xticks = [hour * 60 // minutes for hour in range(24)]

    counts = [len(slot) for slot in messages_per_minutes.values()]
    fewest, most = min(counts), max(counts)
    # One shade per possible count between the extremes, reversed.
    shades = sns.color_palette("GnBu_d", most - fewest + 1)[::-1]
    bar_colors = np.array(shades)[[count - fewest for count in counts]]

    ax = sns.barplot(x=list(range(len(messages_per_minutes))), y=counts,
                     edgecolor="none",
                     palette=bar_colors)
    _change_bar_width(ax, 1.)
    ax.set(xlabel="hour", ylabel="messages")
    ax.set_xticklabels(xticks_labels)

    ax.tick_params(axis='x', bottom=True, color="#A9A9A9")
    plt.xticks(xticks, rotation=65)

    fig = plt.gcf()
    fig.set_size_inches(20, 10)

    own_name = barplot_messages_per_minutes.__name__
    fig.savefig(os.path.join(path_to_save, own_name + ".png"), dpi=500)
    log_line(f"{own_name} was created.")
    plt.close("all")
Example #24
Source File: plotter.py    From message-analyser with MIT License 5 votes vote down vote up
def barplot_messages_per_day(msgs, path_to_save):
    """Save a bar plot of message counts per day.

    The groups come from ``stools.get_messages_per_day``. Each bar is shaded
    from a reversed "Greens_d" palette according to its count. The image is
    written to ``path_to_save`` as ``<function name>.png``.
    """
    sns.set(style="whitegrid", palette="muted")
    sns.despine(top=True)

    messages_per_day_vals = stools.get_messages_per_day(msgs).values()

    xticks, xticks_labels, xlabel = _get_xticks(msgs)

    counts = [len(day) for day in messages_per_day_vals]
    fewest, most = min(counts), max(counts)
    # One shade per possible count between the extremes, reversed.
    shades = sns.color_palette("Greens_d", most - fewest + 1)[::-1]
    bar_colors = np.array(shades)[[count - fewest for count in counts]]

    ax = sns.barplot(x=list(range(len(messages_per_day_vals))), y=counts,
                     edgecolor="none", palette=bar_colors)
    _change_bar_width(ax, 1.)
    ax.set(xlabel=xlabel, ylabel="messages")
    ax.set_xticklabels(xticks_labels)

    ax.tick_params(axis='x', bottom=True, color="#A9A9A9")
    plt.xticks(xticks, rotation=65)

    fig = plt.gcf()
    fig.set_size_inches(20, 10)
    own_name = barplot_messages_per_day.__name__
    fig.savefig(os.path.join(path_to_save, own_name + ".png"), dpi=500)

    log_line(f"{own_name} was created.")
    plt.close("all")
Example #25
Source File: opt_callbacks.py    From lumin with Apache License 2.0 5 votes vote down vote up
def plot_lr(self) -> None:
        r'''
        Plot the recorded learning rate against iteration number and show the figure.
        '''

        settings = self.plot_settings
        lr_history = self.history['lr']
        with sns.axes_style(settings.style), sns.color_palette(settings.cat_palette):
            plt.figure(figsize=(settings.h_small, settings.h_small))
            plt.plot(range(len(lr_history)), lr_history)
            tick_style = dict(fontsize=settings.tk_sz, color=settings.tk_col)
            plt.xticks(**tick_style)
            plt.yticks(**tick_style)
            label_style = dict(fontsize=settings.lbl_sz, color=settings.lbl_col)
            plt.ylabel("Learning rate", **label_style)
            plt.xlabel("Iterations", **label_style)
            plt.show()
Example #26
Source File: bam_cov.py    From basenji with Apache License 2.0 5 votes vote down vote up
def regplot_shift(vals1, vals2, preds2, out_pdf):
  """Save a scatter plot of vals2 against vals1 with two fitted curves.

  One curve is seaborn's order-3 polynomial fit; the other draws preds2
  against vals1.

  Args:
    vals1: x values (labelled 'Shift').
    vals2: y values (labelled 'Covariance').
    preds2: predicted y values, plotted against vals1.
    out_pdf: path handed to plt.savefig.
  """
  gold = sns.color_palette('husl', 8)[1]

  plt.figure(figsize=(6, 6))

  # plot data and seaborn model
  # x and y are passed by keyword: recent seaborn releases no longer accept
  # them positionally, while keywords work on every release.
  ax = sns.regplot(
      x=vals1,
      y=vals2,
      color='black',
      order=3,
      scatter_kws={'color': 'black',
                   's': 4,
                   'alpha': 0.5},
      line_kws={'color': gold})

  # plot my model predictions
  ax.plot(vals1, preds2)

  # adjust axis
  ymin, ymax = scatter_lims(vals2)
  ax.set_xlabel('Shift')
  ax.set_ylim(ymin, ymax)
  ax.set_ylabel('Covariance')

  ax.grid(True, linestyle=':')

  plt.savefig(out_pdf)
  plt.close()
Example #27
Source File: plots.py    From Comparative-Annotation-Toolkit with Apache License 2.0 5 votes vote down vote up
def generic_unstacked_barplot(df, pdf, title_string, legend_labels, ylabel, names, box_label,
                              bbox_to_anchor=(1.12, 0.7)):
    """Draw one group of side-by-side bars per column of ``df``, one bar per row.

    The module-level ``bar_width`` is divided evenly between the rows of
    ``df``. The finished bars and all remaining arguments are handed to
    ``_generic_histogram``.
    """
    fig, ax = plt.subplots()
    bars = []
    shorter_bar_width = bar_width / len(df)
    # Fetch the palette once. Asking for len(df) colours makes seaborn repeat
    # the current palette when there are more rows than colours, where
    # indexing the default palette would raise IndexError.
    row_colors = sns.color_palette(n_colors=len(df))
    positions = np.arange(len(df.columns))
    for i, (_, d) in enumerate(df.iterrows()):
        bars.append(ax.bar(positions + shorter_bar_width * i, d, shorter_bar_width,
                           color=row_colors[i], linewidth=0.0))
    _generic_histogram(bars, legend_labels, title_string, pdf, ax, fig, ylabel, names, box_label, bbox_to_anchor)
Example #28
Source File: cyclic_callbacks.py    From lumin with Apache License 2.0 5 votes vote down vote up
def plot(self) -> None:
        r'''
        Plots the recorded parameter history against iteration number and shows the figure
        '''

        settings = self.plot_settings
        with sns.axes_style(settings.style), sns.color_palette(settings.cat_palette):
            plt.figure(figsize=(settings.w_mid, settings.h_mid))
            label_style = dict(fontsize=settings.lbl_sz, color=settings.lbl_col)
            plt.xlabel("Iterations", **label_style)
            plt.ylabel(self.param_name, **label_style)
            plt.plot(range(len(self.hist)), self.hist)
            tick_style = dict(fontsize=settings.tk_sz, color=settings.tk_col)
            plt.xticks(**tick_style)
            plt.yticks(**tick_style)
            plt.show()
Example #29
Source File: training.py    From lumin with Apache License 2.0 5 votes vote down vote up
def plot_train_history(histories:List[Dict[str,List[float]]], savename:Optional[str]=None, ignore_trn=True, settings:PlotSettings=PlotSettings(),
                       show:bool=True) -> None:
    r'''
    Plot the loss curves stored in each history, one line per loss type per history.

    Loss types share a colour across histories (chosen by their position in the history), and only the first history
    contributes legend entries.

    Arguments:
        histories: list of dictionaries mapping loss type to values at each (sub)-epoch
        savename: Optional name of file to which to save the plot; `settings.format` is appended to it
        ignore_trn: whether to skip loss types whose name contains 'trn'
        settings: :class:`~lumin.plotting.plot_settings.PlotSettings` class to control figure appearance
        show: whether or not to show the plot, or just save it
    '''
    with sns.axes_style(**settings.style), sns.color_palette(settings.cat_palette) as palette:
        plt.figure(figsize=(settings.w_mid, settings.h_mid))
        for model_idx, history in enumerate(histories):
            for loss_idx, loss_name in enumerate(history):
                if 'trn' in loss_name and ignore_trn:
                    continue
                # Only the first history labels its lines, so the legend has one entry per loss type.
                extra = {'label': _lookup_name(loss_name)} if model_idx == 0 else {}
                plt.plot(history[loss_name], color=palette[loss_idx], **extra)

        plt.legend(loc=settings.leg_loc, fontsize=settings.leg_sz)
        tick_style = dict(fontsize=settings.tk_sz, color=settings.tk_col)
        plt.xticks(**tick_style)
        plt.yticks(**tick_style)
        label_style = dict(fontsize=settings.lbl_sz, color=settings.lbl_col)
        plt.xlabel("Epoch", **label_style)
        plt.ylabel("Loss", **label_style)
        if savename is not None:
            plt.savefig(f'{savename}{settings.format}', bbox_inches='tight')
        if show:
            plt.show()
Example #30
Source File: _plot.py    From q2-qemistree with BSD 2-Clause "Simplified" License 5 votes vote down vote up
def values_to_colors(categories, color_palette: str):
    '''Return a dict mapping each entry of `categories` to a hex colour
    taken from the seaborn palette `color_palette`.

    A UserWarning is issued when the palette yields fewer distinct colours
    than there are categories.'''
    colors = sns.color_palette(color_palette,
                               n_colors=len(categories)).as_hex()
    # give a heads up to the user
    if len(categories) > len(set(colors)):
        warnings.warn("The mapping between colors and categories"
                      " is not unique, some colors have been repeated",
                      UserWarning)
    return dict(zip(categories, colors))