Python object_detection.utils.visualization_utils.encode_image_array_as_png_str() Examples

The following is 1 code example of object_detection.utils.visualization_utils.encode_image_array_as_png_str(). You can vote up the examples you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module object_detection.utils.visualization_utils, or try the search function.
Example #1
Source File: eval_util.py    From mtl-ssl with Apache License 2.0 4 votes vote down vote up
def _render_pr_curve_image(precisions, recalls):
  """Plots one precision-recall curve and returns it as an RGB pixel array.

  Args:
    precisions: sequence of precision values (y axis).
    recalls: sequence of recall values (x axis), same length as `precisions`.

  Returns:
    A uint8 numpy array of shape [height, width, 3] holding the rendered plot.
  """
  fig = plt.figure(figsize=[8, 8])
  try:
    # Draw through the axes object rather than pyplot's implicit "current
    # figure", so the result does not depend on global pyplot state.
    axes = fig.add_subplot(1, 1, 1)
    axes.set_xlabel('Recall')
    axes.set_ylabel('Precision')
    axes.set_title('PR curve')
    axes.plot(recalls, precisions, ls='-', lw=1.0)
    fig.canvas.draw()

    # Ask the canvas for its integer pixel size instead of deriving it from
    # float inches * dpi, which can be off by one after truncation.
    width, height = fig.canvas.get_width_height()
    pixels = np.frombuffer(fig.canvas.tostring_rgb(), dtype='uint8')
    # frombuffer yields a read-only view; copy so callers get an ordinary
    # writable array, as np.fromstring used to return.
    return pixels.reshape(int(height), int(width), 3).copy()
  finally:
    # Figures created via pyplot stay registered until closed explicitly;
    # without this, one figure per class would be leaked on every call.
    plt.close(fig)


def visualize_pr_curve(per_class_precisions, per_class_recalls, global_step, eval_dir,
                       nms_type, nms_iou_threshold, soft_nms_sigma):
  """Visualizes pr curve and writes visualizations to image summaries.

  For every key of `per_class_precisions` this computes the PR curve, dumps it
  to `<eval_dir>/pr_curve_<key>_<nms settings>.txt` as tab-separated
  "precision<TAB>recall" lines, and writes the rendered plot as an image
  summary tagged `PRcurve_<key>_` into `eval_dir`.

  Args:
    per_class_precisions: dict mapping a subset key to the precision list of
      each class. Keys are concatenated into file names and tags, so they must
      be strings.
    per_class_recalls: dict with the same keys as `per_class_precisions`,
      mapping to the recall list of each class.
    global_step: step value recorded with each image summary.
    eval_dir: directory receiving both the text files and the event files.
    nms_type: forwarded to `get_string_list_for_nms` to build the file suffix.
    nms_iou_threshold: forwarded to `get_string_list_for_nms`.
    soft_nms_sigma: forwarded to `get_string_list_for_nms`.
  """
  for key in per_class_precisions:
    precisions, recalls = compute_pr_curve(per_class_precisions[key],
                                           per_class_recalls[key])

    # Save the raw curve as text, one "precision<TAB>recall" pair per line.
    nms_str = '_'.join(get_string_list_for_nms(nms_type,
                                               nms_iou_threshold,
                                               soft_nms_sigma))
    filename = os.path.join(eval_dir, 'pr_curve_' + key + '_' + nms_str + '.txt')
    with open(filename, 'w') as f:
      for p, r in zip(precisions, recalls):
        f.write('%f\t%f\n' % (p, r))

    # Render the curve and publish it as an image summary.
    image = _render_pr_curve_image(precisions, recalls)
    tag = 'PRcurve_' + key + '_'
    summary = tf.Summary(value=[
      tf.Summary.Value(tag=tag, image=tf.Summary.Image(
        encoded_image_string=vis_utils.encode_image_array_as_png_str(image)))
    ])
    summary_writer = tf.summary.FileWriter(eval_dir)
    try:
      summary_writer.add_summary(summary, global_step)
    finally:
      # Always release the event file, even if writing the summary fails.
      summary_writer.close()
    # NOTE(review): presumably this pause keeps consecutive writers from
    # colliding on second-resolution event file names -- confirm before
    # removing.
    time.sleep(1)