Python tensorpack.logger.info() Examples

The following are 20 code examples of tensorpack.logger.info(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module tensorpack.logger, or try the search function.
Example #1
Source File: dump-model-params.py    From tensorpack with Apache License 2.0 6 votes vote down vote up
def guess_inputs(input_dir):
    """Pick the newest graph file and checkpoint index found in ``input_dir``.

    Graph candidates are entries named ``graph-*.meta``; the one that sorts
    last by name is chosen.  Checkpoint candidates are entries named
    ``model-<id>.index``; the one with the largest integer ``<id>`` is chosen.

    Args:
        input_dir: directory to scan (non-recursively, via ``os.listdir``).

    Returns:
        A ``(model_path, meta_path)`` tuple, both joined onto ``input_dir``.

    Raises:
        AssertionError: if no graph file or no checkpoint index is present.
        ValueError: if a ``model-*.index`` name carries a non-integer id.
    """
    entries = os.listdir(input_dir)
    graph_files = [name for name in entries
                   if name.startswith('graph-') and name.endswith('.meta')]
    checkpoints = [(name, int(name[len('model-'):-len('.index')]))
                   for name in entries
                   if name.startswith('model-') and name.endswith('.index')]

    assert graph_files
    meta = max(graph_files)
    if len(graph_files) > 1:
        logger.info("Choosing {} from {} as graph file.".format(meta, graph_files))
    else:
        logger.info("Choosing {} as graph file.".format(meta))

    assert checkpoints
    # Stable sort on the numeric id: among equal ids the entry listed last wins.
    ranked = sorted(checkpoints, key=lambda entry: entry[1])
    model = ranked[-1][0]
    if len(checkpoints) > 1:
        logger.info("Choosing {} from {} as model file.".format(model, [entry[0] for entry in checkpoints]))
    else:
        logger.info("Choosing {} as model file.".format(model))

    return os.path.join(input_dir, model), os.path.join(input_dir, meta)
Example #2
Source File: detectPlanePlayerBrain.py    From rl-medical with Apache License 2.0 6 votes vote down vote up
def _update_heirarchical(self):
        """Move to the next level of the hierarchical search.

        Halves the angle step (truncated to int), lowers the distance step by
        one, lowers ``self.spacing`` by one while its first component is
        above 1, and then rebuilds the ground-truth plane for the new spacing.
        """
        self.action_angle_step = int(self.action_angle_step / 2)
        self.action_dist_step = self.action_dist_step - 1
        if self.spacing[0] > 1:
            # in-place decrement, exactly as before (matters if spacing is an array)
            self.spacing -= 1

        # NOTE: an earlier variant called getMidSagPlaneFromLandmarks with the
        # same arguments instead of the AC-PC helper used here.
        plane_args = getACPCPlaneFromLandmarks(
            self.sitk_image,
            self._origin3d_point.astype('float'),
            self.ac_point, self.pc_point,
            self.midsag_point,
            self._plane_size, self.spacing)
        self._groundTruth_plane = Plane(*plane_args)
Example #3
Source File: detectPlanePlayerBrain.py    From rl-medical with Apache License 2.0 6 votes vote down vote up
def _update_history(self):
        """Push the current state into every history buffer."""
        # location history: shift everything one slot towards the front and
        # store the first four plane parameters, rounded to 2 decimals, last
        self._loc_history[:-1] = self._loc_history[1:]
        params = self._plane.params
        self._loc_history[-1] = tuple(
            np.around(params[k], decimals=2) for k in range(4))

        # distance histories grow by one entry each
        self._dist_history.append(self.cur_dist)
        self._dist_history_params.append(self.cur_dist_params)

        # plane and best q-value histories grow by one entry each
        self._plane_history.append(self._plane)
        best_q = np.max(self._qvalues)
        self._bestq_history.append(best_q)

        # q-value history is a sliding window like the location history
        self._qvalues_history[:-1] = self._qvalues_history[1:]
        self._qvalues_history[-1] = self._qvalues
Example #4
Source File: detectPlanePlayerCardio.py    From rl-medical with Apache License 2.0 6 votes vote down vote up
def _oscillate(self):
        ''' Return True if the agent is stuck and oscillating, False otherwise.

        A location counts as "oscillating" when it occurs more than twice in
        ``self._loc_history``.  The all-zero tuple ``(0,0,0,0)`` is skipped
        when it is the most frequent entry (presumably it is the buffer's
        initial filler value - confirm against the code that creates the
        history); in that case the second most frequent entry is checked.
        '''
        freq = Counter(self._loc_history).most_common()
        # fewer than two distinct locations: too early in the game to judge
        if len(freq) < 2:
            return False
        if freq[0][0] == (0,0,0,0):
            # most common entry is the all-zero tuple: look at the runner-up
            return freq[1][1] > 2
        # always return a bool (the final branch used to fall through to None)
        return freq[0][1] > 2
Example #5
Source File: detectPlanePlayerCardio.py    From rl-medical with Apache License 2.0 6 votes vote down vote up
def _update_history(self):
        ''' Update the history buffers with the current state.

        The location and q-value histories are fixed-length sliding windows
        (oldest entry dropped, newest written last); the distance, plane and
        best-q histories are appended to.
        '''
        # update location history
        self._loc_history[:-1] = self._loc_history[1:]
        # The plane parameters are what gets recorded.  The previous read of
        # ``self._plane.origin`` was overwritten immediately and is dropped.
        loc = self._plane.params
        # logger.info('loc {}'.format(loc))
        self._loc_history[-1] = (np.around(loc[0],decimals=2),
                                 np.around(loc[1],decimals=2),
                                 np.around(loc[2],decimals=2),
                                 np.around(loc[3],decimals=2))
        # update distance history
        self._dist_history.append(self.cur_dist)
        self._dist_history_params.append(self.cur_dist_params)
        # update params history
        self._plane_history.append(self._plane)
        self._bestq_history.append(np.max(self._qvalues))
        # update q-value history
        self._qvalues_history[:-1] = self._qvalues_history[1:]
        self._qvalues_history[-1] = self._qvalues
Example #6
Source File: dataReader.py    From MARL-for-Anatomical-Landmark-Detection with Apache License 2.0 5 votes vote down vote up
def decode(self, filename,label=False):
        """ Load a single nifti image into an ImageRecord
        Args
          filename: path of the image to read
          label: True if the nifti image is a label; it is then read as int8
                 and left untouched.  Otherwise the image is read as float32,
                 clipped to its [10th, 99th] percentile range and rescaled
                 to [0, 255].
        Returns
          sitk_image: the (possibly intensity-normalised) SimpleITK image
          image: an image container with attributes; name, data, dims
        """
        record = ImageRecord()
        record.name = filename
        assert self._is_nifti(record.name), "unknown image format for %r" % record.name

        pixel_type = sitk.sitkInt8 if label else sitk.sitkFloat32
        sitk_image = sitk.ReadImage(record.name, pixel_type)

        if not label:
            voxels = sitk.GetArrayFromImage(sitk_image)
            v_min = voxels.min().astype('float')
            pct10 = np.percentile(voxels, 10)
            pct99 = np.percentile(voxels, 99)
            v_max = voxels.max().astype('float')
            # first pass: values outside [p10, max] become p10
            sitk_image = sitk.Threshold(sitk_image,
                                        lower=pct10,
                                        upper=v_max,
                                        outsideValue=pct10)
            # second pass: values outside [min, p99] become p99
            sitk_image = sitk.Threshold(sitk_image,
                                        lower=v_min,
                                        upper=pct99,
                                        outsideValue=pct99)
            sitk_image = sitk.RescaleIntensity(sitk_image,
                                               outputMinimum=0,
                                               outputMaximum=255)

        # reverse the axis order of the array SimpleITK hands back
        # (original note: [depth, width, height] -> [width, height, depth])
        record.data = sitk.GetArrayFromImage(sitk_image).transpose(2,1,0)
        record.dims = np.shape(record.data)

        return sitk_image, record
Example #7
Source File: medical.py    From MARL-for-Anatomical-Landmark-Detection with Apache License 2.0 5 votes vote down vote up
def step(self, act, q_values,isOver):
        """Advance the wrapped environment one step and record the new frame.

        Agents whose ``isOver`` flag is set get their entry in ``act``
        overwritten (in place) with action 15 before the environment is
        stepped.  NOTE(review): 15 is presumably a "do nothing" action -
        confirm against the environment's action table.

        Returns the tuple ``(self._observation(), reward, terminal, info)``.
        """
        for agent_idx in range(self.agents):
            if isOver[agent_idx]:
                act[agent_idx] = 15
        state, reward, terminal, info = self.env.step(act, q_values, isOver)
        # frames are stored as tuples
        self.frames.append(tuple(state))
        observation = self._observation()
        return observation, reward, terminal, info
Example #8
Source File: dump-model-params.py    From tensorpack with Apache License 2.0 5 votes vote down vote up
def _import_external_ops(message):
    """Import the op library that ``message`` hints at.

    ``message`` is searched for a few known markers ("horovod",
    "MaxBytesInUse", "Nccl", "ZMQConnection"); the first match decides which
    module is imported for its side effects.  If nothing matches, the message
    is logged as an unhandled error.
    """
    if "horovod" in message.lower():
        logger.info("Importing horovod ...")
        import horovod.tensorflow  # noqa
    elif "MaxBytesInUse" in message:
        logger.info("Importing memory_stats ...")
        from tensorflow.contrib.memory_stats import MaxBytesInUse  # noqa
    elif 'Nccl' in message:
        logger.info("Importing nccl ...")
        if TF_version > (1, 12):
            from tensorflow.python.ops import gen_nccl_ops  # noqa
        else:
            try:
                from tensorflow.contrib.nccl.python.ops.nccl_ops import _validate_and_load_nccl_so
            except Exception:
                # best effort: the loader helper is optional, skip it if unavailable
                pass
            else:
                _validate_and_load_nccl_so()
            from tensorflow.contrib.nccl.ops import gen_nccl_ops  # noqa
    elif 'ZMQConnection' in message:
        import zmq_ops  # noqa
    else:
        logger.error("Unhandled error: " + message)
Example #9
Source File: sampleTrain.py    From rl-medical with Apache License 2.0 5 votes vote down vote up
def decode(self, filename,label=False):
        """ Read one nifti file and package it as an ImageRecord
        Args
          filename: string, path of the input image
          label: True if the nifti image is a label (read as int8, no
                 intensity processing); False for an intensity image (read
                 as float32, clipped to [p10, p99], rescaled to [0, 255])
        Returns
          sitk_image: the SimpleITK image after any intensity processing
          image: an image container with attributes; name, data, dims
        """
        out = ImageRecord()
        out.name = filename
        assert self._is_nifti(out.name), "unknown image format for %r" % out.name

        if label:
            sitk_image = sitk.ReadImage(out.name, sitk.sitkInt8)
        else:
            sitk_image = sitk.ReadImage(out.name, sitk.sitkFloat32)
            intensities = sitk.GetArrayFromImage(sitk_image)
            floor_val = intensities.min().astype('float')
            low_cut = np.percentile(intensities, 10)
            high_cut = np.percentile(intensities, 99)
            ceil_val = intensities.max().astype('float')
            # two Threshold passes: (lower, upper, replacement for outside values)
            clip_passes = ((low_cut, ceil_val, low_cut),
                           (floor_val, high_cut, high_cut))
            for lower, upper, outside in clip_passes:
                sitk_image = sitk.Threshold(sitk_image,
                                            lower=lower,
                                            upper=upper,
                                            outsideValue=outside)
            sitk_image = sitk.RescaleIntensity(sitk_image,
                                               outputMinimum=0,
                                               outputMaximum=255)

        # SimpleITK's array comes back with the axes in reverse order; flip it
        # (original note: [depth, width, height] -> [width, height, depth])
        volume = sitk.GetArrayFromImage(sitk_image)
        out.data = volume.transpose(2,1,0)
        out.dims = np.shape(out.data)

        return sitk_image, out
Example #10
Source File: detectPlanePlayerCardio.py    From rl-medical with Apache License 2.0 5 votes vote down vote up
def _calc_reward_params(self, prev_params, next_params):
        ''' Reward for moving from ``prev_params`` to ``next_params``.

        Both parameter sets are compared with the ground-truth plane through
        ``calcScaledDistTwoParams`` (scaled by the current angle and distance
        steps); the reward is the reduction in that distance, so it is
        positive when the move brought the plane closer to the target.
        '''
        target = self._groundTruth_plane.params
        scales = dict(scale_angle=self.action_angle_step,
                      scale_dist=self.action_dist_step)
        dist_before = calcScaledDistTwoParams(target, prev_params, **scales)
        dist_after = calcScaledDistTwoParams(target, next_params, **scales)
        return dist_before - dist_after
Example #11
Source File: detectPlanePlayerBrain.py    From rl-medical with Apache License 2.0 5 votes vote down vote up
def step(self, action, qvalues):
        """Step the wrapped env, remember the new frame, and return
        ``(self._observation(), reward, done, info)``."""
        frame, reward, done, info = self.env.step(action,qvalues)
        self.frames.append(frame)
        # the observation is built after the new frame has been stored
        observation = self._observation()
        return observation, reward, done, info
Example #12
Source File: detectPlanePlayerBrain.py    From rl-medical with Apache License 2.0 5 votes vote down vote up
def _calc_reward_params(self, prev_params, next_params):
        ''' Compute the reward as the drop in scaled distance to the
        ground-truth plane between the previous and the next parameters
        (``prev_dist - next_dist``; positive means the plane got closer).
        '''
        def dist_to_target(params):
            # distance helper with the scales of the current search level
            return calcScaledDistTwoParams(self._groundTruth_plane.params,
                                           params,
                                           scale_angle = self.action_angle_step,
                                           scale_dist = self.action_dist_step)

        prev_dist = dist_to_target(prev_params)
        next_dist = dist_to_target(next_params)
        return prev_dist - next_dist
Example #13
Source File: dataReader.py    From rl-medical with Apache License 2.0 5 votes vote down vote up
def decode(self, filename, label=False):
        """ Decode a single nifti image
        Args
          filename: string for the input image path
          label: True if the nifti image is a label (loaded as int8 as-is);
                 otherwise the image is loaded as float32, clamped between
                 its 10th and 99th percentiles and rescaled to [0, 255]
        Returns
          sitk_image: the SimpleITK image after any intensity processing
          image: an image container with attributes; name, data, dims
        """
        container = ImageRecord()
        container.name = filename
        assert self._is_nifti(container.name), "unknown image format for %r" % container.name

        pixel_type = sitk.sitkInt8 if label else sitk.sitkFloat32
        sitk_image = sitk.ReadImage(container.name, pixel_type)

        if not label:
            raw = sitk.GetArrayFromImage(sitk_image)
            lowest = raw.min().astype('float')
            p10 = np.percentile(raw, 10)
            p99 = np.percentile(raw, 99)
            highest = raw.max().astype('float')
            # lift the dark tail up to p10 ...
            sitk_image = sitk.Threshold(
                sitk_image, lower=p10, upper=highest, outsideValue=p10)
            # ... and cap the bright tail at p99
            sitk_image = sitk.Threshold(
                sitk_image, lower=lowest, upper=p99, outsideValue=p99)
            sitk_image = sitk.RescaleIntensity(
                sitk_image, outputMinimum=0, outputMaximum=255)

        # flip the axis order of the SimpleITK array
        # (original note: [depth, width, height] -> [width, height, depth])
        container.data = sitk.GetArrayFromImage(sitk_image).transpose(2, 1, 0)
        container.dims = np.shape(container.data)

        return sitk_image, container
Example #14
Source File: medical.py    From rl-medical with Apache License 2.0 5 votes vote down vote up
def step(self, action, q_values):
        """Forward one step to the wrapped env, store the returned frame and
        return ``(self._observation(), reward, done, info)``."""
        result = self.env.step(action, q_values)
        new_frame, reward, done, info = result
        self.frames.append(new_frame)
        return (self._observation(), reward, done, info)
Example #15
Source File: imagenet.py    From LQ-Nets with MIT License 5 votes vote down vote up
def get_config(model, fake=False, data_aug=True):
    """Assemble the ``AutoResumeTrainConfig`` for a training run.

    Args:
        model: forwarded as ``model=`` to the config.
        fake: if true, train on ``FakeData`` (fixed 64 per tower) with no
            callbacks; otherwise use ``get_data`` plus saving, learning-rate
            and validation callbacks.
        data_aug: forwarded to ``get_data``; also enables the scheduled
            learning-rate callback and selects 110 (else 64) epochs.
    """
    nr_tower = max(get_nr_gpu(), 1)
    batch = TOTAL_BATCH_SIZE // nr_tower

    if fake:
        logger.info("For benchmark, batch size is fixed to 64 per tower.")
        dataset_train = FakeData(
            [[64, 224, 224, 3], [64]], 1000, random=False, dtype='uint8')
        callbacks = []
    else:
        logger.info("Running on {} towers. Batch size per tower: {}".format(nr_tower, batch))
        dataset_train = get_data('train', batch, data_aug)
        dataset_val = get_data('val', batch, data_aug)

        callbacks = [ModelSaver()]
        if data_aug:
            lr_schedule = [(30, 1e-2), (60, 1e-3), (85, 1e-4), (95, 1e-5), (105, 1e-6)]
            callbacks.append(ScheduledHyperParamSetter('learning_rate', lr_schedule))
        callbacks.append(HumanHyperParamSetter('learning_rate'))

        infs = [ClassificationError('wrong-top1', 'val-error-top1'),
                ClassificationError('wrong-top5', 'val-error-top5')]
        if nr_tower == 1:
            # single-GPU inference with queue prefetch
            runner = InferenceRunner(QueueInput(dataset_val), infs)
        else:
            # multi-GPU inference (with mandatory queue prefetch)
            runner = DataParallelInferenceRunner(
                dataset_val, infs, list(range(nr_tower)))
        callbacks.append(runner)

    steps_per_epoch = 5000 if TOTAL_BATCH_SIZE == 256 else 10000
    max_epoch = 110 if data_aug else 64
    return AutoResumeTrainConfig(
        model=model,
        dataflow=dataset_train,
        callbacks=callbacks,
        steps_per_epoch=steps_per_epoch,
        max_epoch=max_epoch,
        nr_tower=nr_tower
    )
Example #16
Source File: detectPlaneHelper.py    From rl-medical with Apache License 2.0 4 votes vote down vote up
def getPlane(sitk_image3d, origin, plane_params, plane_size, spacing=(1,1,1)):
    ''' Sample a plane from a 3d nifti image given in normal form.

    ``plane_params[0:3]`` are angles in degrees whose cosines form the plane
    normal (ax+by+cz=d with norm = (a,b,c)); ``plane_params[3]`` is ``d``,
    the offset of the plane origin from ``origin`` along the unit normal.

    Returns ``(grid, grid_smooth, plane_norm, plane_origin, plane_params,
    points)`` where the grid values come from ``sampleGrid`` and
    ``plane_params`` is handed back unchanged.
    '''
    cos_a, cos_b, cos_c = [np.cos(np.deg2rad(plane_params[k])) for k in range(3)]
    d = plane_params[3]
    # unit normal of the plane
    plane_norm = normalizeUnitVector(np.array((cos_a, cos_b, cos_c)))

    # image origin and inverse of its 3x3 direction matrix
    origin3d = np.array(sitk_image3d.GetOrigin())
    direction = np.array(sitk_image3d.GetDirection())
    transformation_inv = np.linalg.inv(np.array(direction.reshape(3,3)))

    def to_physical(point):
        # same mapping the original applied to every point before forming the
        # sampling directions (NOTE(review): the original names these
        # "physical" coordinates - confirm)
        return point.dot(transformation_inv) + origin3d

    # plane origin: move d along the normal from the given origin
    plane_origin = origin + d * plane_norm
    plane_origin_physical = to_physical(plane_origin)

    # x-direction: project a point offset along the volume's x-axis onto the plane
    pointx = (origin[0] + plane_size[0]/2, origin[1], origin[2])
    pointx_proj, _ = projectPointOnPlane(pointx, plane_norm, plane_origin)
    vectorx = normalizeUnitVector(to_physical(pointx_proj) - plane_origin_physical)

    # z-direction: a point further along the plane normal
    pointz = plane_origin + (plane_size[2]/2) * plane_norm
    vectorz = normalizeUnitVector(to_physical(pointz) - plane_origin_physical)

    # y-direction completes the frame
    vectory = normalizeUnitVector(np.cross(vectorz, vectorx))

    # sample a grid in the calculated directions
    grid, grid_smooth, points = sampleGrid(sitk_image3d,
                                           plane_origin,
                                           vectorx,
                                           vectory,
                                           vectorz,
                                           plane_size,
                                           spacing=spacing)

    return grid, grid_smooth, plane_norm, plane_origin, plane_params, points


############################################################################### 
Example #17
Source File: adanet-resnet.py    From adanet with MIT License 4 votes vote down vote up
def get_config(model, fake=False):
    """Build the ``AutoResumeTrainConfig`` for a training run.

    Args:
        model: forwarded as ``model=`` to the config.
        fake: if true, feed ``FakeData`` with no callbacks and run 100 steps
            per epoch; otherwise use ``get_data`` with the learning-rate
            schedule and validation callbacks and run
            ``1280000 // args.batch`` steps per epoch.

    Raises:
        AssertionError: if ``args.batch`` is not divisible by the number of
            towers.
    """
    nr_tower = max(get_num_gpu(), 1)
    assert args.batch % nr_tower == 0
    batch = args.batch // nr_tower

    logger.info("Running on {} towers. Batch size per tower: {}".format(nr_tower, batch))
    if batch < 32 or batch > 64:
        logger.warn("Batch size per tower not in [32, 64]. This probably will lead to worse accuracy than reported.")
    if fake:
        data = QueueInput(FakeData(
            [[batch, 224, 224, 3], [batch],[batch, 224, 224, 3], [batch]], 1000, random=False, dtype='uint8'))
        callbacks = []
    else:
        data = QueueInput(get_data('train', batch))

        # the learning rate is scaled linearly with the total batch size (reference: 256)
        START_LR = 0.1
        BASE_LR = START_LR * (args.batch / 256.0)
        callbacks = [
            ModelSaver(),
            EstimatedTimeLeft(),
            ScheduledHyperParamSetter(
                'learning_rate', [
                    (0, min(START_LR, BASE_LR)), (30, BASE_LR * 1e-1), (45, BASE_LR * 1e-2),
                    (55, BASE_LR * 1e-3)]),
        ]
        if BASE_LR > START_LR:
            # linear warm-up from START_LR to BASE_LR between points 0 and 5
            callbacks.append(
                ScheduledHyperParamSetter(
                    'learning_rate', [(0, START_LR), (5, BASE_LR)], interp='linear'))

        infs = [ClassificationError('wrong-top1', 'val-error-top1'),
                ClassificationError('wrong-top5', 'val-error-top5')]
        dataset_val = get_data('val', batch)
        if nr_tower == 1:
            # single-GPU inference with queue prefetch
            callbacks.append(InferenceRunner(QueueInput(dataset_val), infs))
        else:
            # multi-GPU inference (with mandatory queue prefetch)
            callbacks.append(DataParallelInferenceRunner(
                dataset_val, infs, list(range(nr_tower))))

    return AutoResumeTrainConfig(
        model=model,
        data=data,
        callbacks=callbacks,
        # Follow the ``fake`` argument (not the global ``args.fake``) so the
        # epoch length always matches the data source chosen above.
        steps_per_epoch=100 if fake else 1280000 // args.batch,
        max_epoch=60,
    )
Example #18
Source File: imagenet-resnet.py    From tensorpack with Apache License 2.0 4 votes vote down vote up
def get_config(model):
    """Build the ``TrainConfig`` for a run, driven by the global ``args``.

    With ``args.fake`` set, ``FakeData`` is fed and the only callback is the
    optional GPU tracker; otherwise the training data comes from either
    ``get_imagenet_tfdata`` (``args.symbolic``) or ``get_imagenet_dataflow``,
    and saving, learning-rate and validation callbacks are installed.

    Raises:
        AssertionError: if ``args.batch`` is not divisible by the number of
            towers.
    """
    nr_tower = max(get_num_gpu(), 1)
    assert args.batch % nr_tower == 0
    batch = args.batch // nr_tower

    logger.info("Running on {} towers. Batch size per tower: {}".format(nr_tower, batch))
    if not 32 <= batch <= 64:
        logger.warn("Batch size per tower not in [32, 64]. This probably will lead to worse accuracy than reported.")

    if args.fake:
        fake_shapes = [[batch, 224, 224, 3], [batch]]
        data = QueueInput(FakeData(fake_shapes, 1000, random=False, dtype='uint8'))
        callbacks = []
    else:
        if args.symbolic:
            data = TFDatasetInput(get_imagenet_tfdata(args.data, 'train', batch))
        else:
            data = QueueInput(get_imagenet_dataflow(args.data, 'train', batch))

        # learning rate scales linearly with the total batch size (reference: 256)
        START_LR = 0.1
        BASE_LR = START_LR * (args.batch / 256.0)
        lr_schedule = [
            (0, min(START_LR, BASE_LR)), (30, BASE_LR * 1e-1), (60, BASE_LR * 1e-2),
            (90, BASE_LR * 1e-3), (100, BASE_LR * 1e-4)]
        callbacks = [
            ModelSaver(),
            EstimatedTimeLeft(),
            ScheduledHyperParamSetter('learning_rate', lr_schedule),
        ]
        if BASE_LR > START_LR:
            # linear warm-up from START_LR to BASE_LR between points 0 and 5
            warmup = ScheduledHyperParamSetter(
                'learning_rate', [(0, START_LR), (5, BASE_LR)], interp='linear')
            callbacks.append(warmup)

        infs = [ClassificationError('wrong-top1', 'val-error-top1'),
                ClassificationError('wrong-top5', 'val-error-top5')]
        dataset_val = get_imagenet_dataflow(args.data, 'val', batch)
        if nr_tower == 1:
            # single-GPU inference with queue prefetch
            runner = InferenceRunner(QueueInput(dataset_val), infs)
        else:
            # multi-GPU inference (with mandatory queue prefetch)
            runner = DataParallelInferenceRunner(
                dataset_val, infs, list(range(nr_tower)))
        callbacks.append(runner)

    if get_num_gpu() > 0:
        callbacks.append(GPUUtilizationTracker())

    steps_per_epoch = 100 if args.fake else 1281167 // args.batch
    return TrainConfig(
        model=model,
        data=data,
        callbacks=callbacks,
        steps_per_epoch=steps_per_epoch,
        max_epoch=105,
    )
Example #19
Source File: imagenet-resnet-gn.py    From GroupNorm-reproduce with Apache License 2.0 4 votes vote down vote up
def get_config(model, fake=False):
    """Build the ``TrainConfig`` for a training run.

    Args:
        model: forwarded as ``model=`` to the config.
        fake: if true, train on ``FakeData`` (fixed 64 per tower) for 100
            steps per epoch with no callbacks; otherwise train on the
            ImageNet dataflow with the learning-rate schedule and validation
            callbacks, ``1281167 // args.batch`` steps per epoch.

    Raises:
        AssertionError: if ``args.batch`` is not divisible by the number of
            towers.
    """
    nr_tower = max(get_num_gpu(), 1)
    assert args.batch % nr_tower == 0
    batch = args.batch // nr_tower

    if fake:
        logger.info("For benchmark, batch size is fixed to 64 per tower.")
        dataset_train = FakeData(
            [[64, 224, 224, 3], [64]], 1000, random=False, dtype='uint8')
        callbacks = []
        steps_per_epoch = 100
    else:
        logger.info("Running on {} towers. Batch size per tower: {}".format(nr_tower, batch))

        dataset_train = get_imagenet_dataflow(args.data, 'train', batch)
        # the validation batch is capped at 64
        dataset_val = get_imagenet_dataflow(args.data, 'val', min(64, batch))
        steps_per_epoch = 1281167 // args.batch

        # learning rate scales linearly with the total batch size (reference: 256)
        BASE_LR = 0.1 * args.batch / 256.0
        logger.info("BASELR: {}".format(BASE_LR))
        lr_schedule = [(0, BASE_LR), (30, BASE_LR * 1e-1), (60, BASE_LR * 1e-2),
                       (90, BASE_LR * 1e-3)]
        callbacks = [
            ModelSaver(),
            EstimatedTimeLeft(),
            GPUUtilizationTracker(),
            ScheduledHyperParamSetter('learning_rate', lr_schedule),
        ]
        if BASE_LR > 0.1:
            # step-based linear warm-up from 0.1 to BASE_LR over 5 epochs' worth of steps
            warmup = ScheduledHyperParamSetter(
                'learning_rate', [(0, 0.1), (5 * steps_per_epoch, BASE_LR)],
                interp='linear', step_based=True)
            callbacks.append(warmup)

        infs = [ClassificationError('wrong-top1', 'val-error-top1'),
                ClassificationError('wrong-top5', 'val-error-top5')]
        if nr_tower == 1:
            # single-GPU inference with queue prefetch
            runner = InferenceRunner(QueueInput(dataset_val), infs)
        else:
            # multi-GPU inference (with mandatory queue prefetch)
            runner = DataParallelInferenceRunner(
                dataset_val, infs, list(range(nr_tower)))
        callbacks.append(runner)

    return TrainConfig(
        model=model,
        dataflow=dataset_train,
        callbacks=callbacks,
        steps_per_epoch=steps_per_epoch,
        max_epoch=100,
    )
Example #20
Source File: imagenet-resnet.py    From webvision-2.0-benchmarks with Apache License 2.0 4 votes vote down vote up
def get_config(model, fake=False):
    """Build the ``TrainConfig`` for a training run.

    Args:
        model: forwarded as ``model=`` to the config.
        fake: if true, train on ``FakeData`` (fixed 64 per tower) with no
            callbacks and 100 steps per epoch; otherwise use ``get_data``
            with the learning-rate schedule and validation callbacks and
            ``1280000 // args.batch`` steps per epoch.

    Raises:
        AssertionError: if ``args.batch`` is not divisible by the number of
            towers.
    """
    nr_tower = max(get_nr_gpu(), 1)
    assert args.batch % nr_tower == 0
    batch = args.batch // nr_tower

    if fake:
        logger.info("For benchmark, batch size is fixed to 64 per tower.")
        dataset_train = FakeData(
            [[64, 224, 224, 3], [64]], 1000, random=False, dtype='uint8')
        callbacks = []
    else:
        logger.info("Running on {} towers. Batch size per tower: {}".format(nr_tower, batch))
        dataset_train = get_data('train', batch)
        dataset_val = get_data('val', batch)

        # learning rate scales linearly with the total batch size (reference: 256)
        BASE_LR = 0.1 * (args.batch / 256.0)
        callbacks = [
            ModelSaver(),
            ScheduledHyperParamSetter(
                'learning_rate', [(0, BASE_LR), (30, BASE_LR * 1e-1), (60, BASE_LR * 1e-2),
                                  (90, BASE_LR * 1e-3)]),
        ]
        if BASE_LR > 0.1:
            # linear warm-up from 0.1 to BASE_LR between points 0 and 3
            callbacks.append(
                ScheduledHyperParamSetter(
                    'learning_rate', [(0, 0.1), (3, BASE_LR)], interp='linear'))

        infs = [ClassificationError('wrong-top1', 'val-error-top1'),
                ClassificationError('wrong-top5', 'val-error-top5')]
        if nr_tower == 1:
            # single-GPU inference with queue prefetch
            callbacks.append(InferenceRunner(QueueInput(dataset_val), infs))
        else:
            # multi-GPU inference (with mandatory queue prefetch)
            callbacks.append(DataParallelInferenceRunner(
                dataset_val, infs, list(range(nr_tower))))

    return TrainConfig(
        model=model,
        dataflow=dataset_train,
        callbacks=callbacks,
        # Follow the ``fake`` argument (not the global ``args.fake``) so the
        # epoch length always matches the data source chosen above.
        steps_per_epoch=100 if fake else 1280000 // args.batch,
        max_epoch=110,
    )