Python model.MaskRCNN() Examples

The following is 1 code example of model.MaskRCNN(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module model, or try the search function.
Example #1
Source File: segment_train.py    From SketchyScene with MIT License 4 votes vote down vote up
def instance_segment_train(**kwargs):
    """Build a training-mode ``modellib.MaskRCNN``, train it and save its weights.

    Snapshots go to ``outputs/snapshot`` and logs to ``outputs/log``; both
    directories are created if they do not exist yet. Pretrained weight
    files are looked up under ``pretrained_model``.

    Keyword Args:
        data_base_dir: Handed unchanged to ``SketchDataset``.
        init_with: Selects the starting weights.
            ``"imagenet"`` loads the ResNet-50 "notop" weight file,
            ``"coco"`` loads ``mask_rcnn_coco.h5`` while excluding the
            ``mrcnn_class_logits``, ``mrcnn_bbox_fc``, ``mrcnn_bbox`` and
            ``mrcnn_mask`` layers, ``"last"`` loads the path found at index 1
            of ``model.find_last()``, and any other value trains without
            loading weights.

    Raises:
        KeyError: If ``data_base_dir`` or ``init_with`` is not supplied.
    """
    dataset_root = kwargs['data_base_dir']
    weights_choice = kwargs['init_with']

    # Output locations: one folder for weight snapshots, one for logs.
    output_root = 'outputs'
    snapshot_dir = os.path.join(output_root, 'snapshot')
    logs_dir = os.path.join(output_root, 'log')
    for folder in (snapshot_dir, logs_dir):
        os.makedirs(folder, exist_ok=True)

    # Candidate pretrained weight files.
    pretrained_root = 'pretrained_model'
    coco_weights = os.path.join(pretrained_root, 'mask_rcnn_coco.h5')
    imagenet_weights = os.path.join(
        pretrained_root, 'resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5')

    train_config = SketchTrainConfig()
    train_config.display()

    # Training split of the dataset.
    train_set = SketchDataset(dataset_root)
    train_set.load_sketches("train")
    train_set.prepare()

    # Network in training mode.
    net = modellib.MaskRCNN(mode="training", config=train_config,
                            model_dir=snapshot_dir, log_dir=logs_dir)

    def _announce_and_load(weights_path, **load_options):
        # Every weight-loading branch prints the path first, then loads by name.
        print("Loading weights from ", weights_path)
        net.load_weights(weights_path, by_name=True, **load_options)

    if weights_choice == "imagenet":
        _announce_and_load(imagenet_weights)
    elif weights_choice == "coco":
        # The heads listed here are excluded because they differ due to
        # the different number of classes.
        _announce_and_load(
            coco_weights,
            exclude=["mrcnn_class_logits", "mrcnn_bbox_fc",
                     "mrcnn_bbox", "mrcnn_mask"])
    elif weights_choice == "last":
        # Resume from the most recently trained model.
        _announce_and_load(net.find_last()[1])
    else:
        print("Training from fresh start.")

    # Fine tune all layers.
    net.train(train_set,
              learning_rate=train_config.LEARNING_RATE,
              epochs=train_config.TOTAL_EPOCH,
              layers="all")

    # Persist the final weights next to the snapshots.
    final_name = "".join(
        ("mask_rcnn_", train_config.NAME, "_", str(train_config.TOTAL_EPOCH), ".h5"))
    net.keras_model.save_weights(os.path.join(snapshot_dir, final_name))