Python object_detection.utils.visualization_utils.encode_image_array_as_png_str() Examples
Code examples showing how to use object_detection.utils.visualization_utils.encode_image_array_as_png_str().
Example #1
Source file: from the mtl-ssl project (Apache License 2.0) | 4 votes
def _write_pr_curve_txt(filename, precisions, recalls):
    """Writes one '<precision>\\t<recall>' line per PR-curve point to filename."""
    with open(filename, 'w') as f:
        f.writelines('%f\t%f\n' % (p, r) for p, r in zip(precisions, recalls))


def _render_pr_curve_image(precisions, recalls):
    """Plots a PR curve and returns it as a (height, width, 3) uint8 RGB array.

    The figure is always closed afterwards; pyplot otherwise keeps every
    created figure alive, leaking memory across classes and calls.
    """
    fig = plt.figure(figsize=[8, 8])
    try:
        plt.xlabel('Recall')
        plt.ylabel('Precision')
        plt.title('PR curve')
        plt.plot(recalls, precisions, ls='-', lw=1.0)
        fig.canvas.draw()
        # Pixel size of the rendered canvas.
        width, height = fig.get_size_inches() * fig.get_dpi()
        # np.frombuffer replaces the deprecated binary mode of np.fromstring;
        # .copy() yields a writable array, as np.fromstring did.
        rgb = np.frombuffer(fig.canvas.tostring_rgb(), dtype='uint8')
        return rgb.reshape(int(height), int(width), 3).copy()
    finally:
        plt.close(fig)


def visualize_pr_curve(per_class_precisions, per_class_recalls, global_step,
                       eval_dir, nms_type, nms_iou_threshold, soft_nms_sigma):
    """Visualizes PR curves and writes visualizations to image summaries.

    For every key of ``per_class_precisions`` this:
      1. computes the PR curve with ``compute_pr_curve``;
      2. writes the (precision, recall) points as tab-separated text to
         ``<eval_dir>/pr_curve_<key>_<nms_str>.txt``;
      3. renders the curve with matplotlib and records it as an image summary
         tagged ``PRcurve_<key>_`` at ``global_step`` in an event file under
         ``eval_dir``.

    Args:
      per_class_precisions: mapping from class key to that class's precision
        list.
      per_class_recalls: mapping from class key to that class's recall list;
        must contain every key of ``per_class_precisions``.
      global_step: step at which the image summaries are recorded.
      eval_dir: directory receiving the text files and the event file.
      nms_type: NMS setting; used only to build the text-file name suffix.
      nms_iou_threshold: NMS setting; used only to build the file-name suffix.
      soft_nms_sigma: NMS setting; used only to build the file-name suffix.
    """
    subset_keys = list(per_class_precisions.keys())
    if not subset_keys:
        # Nothing to visualize; do not create an empty event file.
        return

    # The suffix depends only on the NMS arguments, so build it once.
    nms_str = '_'.join(
        get_string_list_for_nms(nms_type, nms_iou_threshold, soft_nms_sigma))

    # One writer for all classes, so a single event file is created per call
    # instead of one per class.
    summary_writer = tf.summary.FileWriter(eval_dir)
    try:
        for key in subset_keys:
            precisions, recalls = compute_pr_curve(per_class_precisions[key],
                                                   per_class_recalls[key])
            # str() lets non-string class keys (e.g. integer ids) build
            # file names and tags; string keys are unchanged.
            key_str = str(key)

            filename = os.path.join(
                eval_dir, 'pr_curve_' + key_str + '_' + nms_str + '.txt')
            _write_pr_curve_txt(filename, precisions, recalls)

            image = _render_pr_curve_image(precisions, recalls)
            tag = 'PRcurve_' + key_str + '_'
            summary = tf.Summary(value=[
                tf.Summary.Value(tag=tag, image=tf.Summary.Image(
                    encoded_image_string=(
                        vis_utils.encode_image_array_as_png_str(image))))
            ])
            summary_writer.add_summary(summary, global_step)
    finally:
        # Flush and close even if a class fails mid-loop.
        summary_writer.close()

    # NOTE(review): presumably this pause keeps the timestamp-based event-file
    # name of the next writer opened on eval_dir distinct from this one;
    # confirm whether it is still needed.
    time.sleep(1)