Python mxnet.io.DataBatch() Examples

The following are 30 code examples of mxnet.io.DataBatch(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module mxnet.io, or try the search function.
Example #1
Source File: iterators.py    From dynamic-training-with-apache-mxnet-on-aws with Apache License 2.0 6 votes vote down vote up
def next(self):
    """Return the next batch, or raise StopIteration once ``self.idx`` is used up.

    ``self.idx[self.curr_idx]`` yields a pair: the bucket to read from and
    the first record of the batch inside that bucket.  ``batch_size``
    consecutive records are sliced from ``ndindex``, ``ndsent``, ``ndchar``
    and ``ndlabel`` for that bucket, and the cursor is advanced by one.

    Returns:
        A ``DataBatch`` with ``[sentences, characters]`` as data and
        ``[label]`` as label, keyed by ``self.buckets[bucket]``.

    Raises:
        StopIteration: when ``self.curr_idx`` equals ``len(self.idx)``.
    """
    if self.curr_idx == len(self.idx):
        raise StopIteration

    bucket, start = self.idx[self.curr_idx]
    self.curr_idx += 1

    # One slice object reused for all four per-bucket arrays.
    window = slice(start, start + self.batch_size)
    indices = self.ndindex[bucket][window]
    sentences = self.ndsent[bucket][window]
    characters = self.ndchar[bucket][window]
    label = self.ndlabel[bucket][window]

    bucket_key = self.buckets[bucket]
    sentence_desc = DataDesc(name=self.data_names[0], shape=sentences.shape, layout=self.layout)
    character_desc = DataDesc(name=self.data_names[1], shape=characters.shape, layout=self.layout)
    label_desc = DataDesc(name=self.label_name, shape=label.shape, layout=self.layout)

    return DataBatch(
        [sentences, characters],
        [label],
        pad=0,
        index=indices,
        bucket_key=bucket_key,
        provide_data=[sentence_desc, character_desc],
        provide_label=[label_desc],
    )
Example #2
Source File: fizbuz_service.py    From HandsOnDeepLearningWithPytorch with MIT License 6 votes vote down vote up
def inference(self, model_input):
    """Run one forward pass and return the fully evaluated model outputs.

    Args:
        model_input: the input handed to ``check_input_shape`` together
            with ``self.signature`` and then wrapped in a ``DataBatch``.

    Returns:
        ``None`` when ``self.error`` is set; otherwise whatever
        ``self.mx_model.get_outputs()`` returns, after ``wait_to_read()``
        has been called on every NDArray found in it.
    """
    if self.error is not None:
        return None

    # Check input shape
    check_input_shape(model_input, self.signature)
    self.mx_model.forward(DataBatch([model_input]))
    outputs = self.mx_model.get_outputs()

    # Bypass lazy evaluation: get_outputs returns either a list of NDArrays
    # or a list of lists of NDArrays, so both levels are walked.
    ndarray_type = mx.ndarray.ndarray.NDArray
    for d in outputs:
        if isinstance(d, list):
            # Iterate the nested list itself; the previous code looped over
            # the outer list again and never reached the nested arrays.
            for n in d:
                if isinstance(n, ndarray_type):
                    n.wait_to_read()
        elif isinstance(d, ndarray_type):
            d.wait_to_read()
    return outputs
Example #3
Source File: parall_module_local_v1.py    From insightface with MIT License 6 votes vote down vote up
def forward(self, data_batch, is_train=None):
    """Forward ``data_batch`` through ``self._curr_module`` and, when training,
    through every module in ``self._arcface_modules``.

    When ``is_train`` is truthy the iteration counter ``self._iter`` is
    incremented, the two outputs of ``self._curr_module`` (merged across
    contexts) are taken as features and labels, the labels are copied to
    ``self._ctx_cpu`` and stored in ``self.global_label``, and each module
    ``i`` receives the features together with the labels shifted down by
    ``self._ctx_class_start[i]``.

    Requires the module to be bound and its parameters initialized.
    """
    assert self.binded and self.params_initialized
    self._curr_module.forward(data_batch, is_train=is_train)
    if not is_train:
        return

    self._iter += 1
    fc1, label = self._curr_module.get_outputs(merge_multi_context=True)
    self.global_label = label.as_in_context(self._ctx_cpu)

    for i, _module in enumerate(self._arcface_modules):
        # Labels relative to the first class index of module i.
        shifted_label = self.global_label - self._ctx_class_start[i]
        _module.forward(io.DataBatch([fc1], [shifted_label]))  # fc7 with margin
Example #4
Source File: iterators.py    From training_results_v0.6 with Apache License 2.0 6 votes vote down vote up
def next(self):
    """Return the next batch, or raise StopIteration once ``self.idx`` is used up.

    ``self.idx[self.curr_idx]`` yields a pair: the bucket to read from and
    the first record of the batch inside that bucket.  ``batch_size``
    consecutive records are sliced from ``ndindex``, ``ndsent``, ``ndchar``
    and ``ndlabel`` for that bucket, and the cursor is advanced by one.

    Returns:
        A ``DataBatch`` with ``[sentences, characters]`` as data and
        ``[label]`` as label, keyed by ``self.buckets[bucket]``.

    Raises:
        StopIteration: when ``self.curr_idx`` equals ``len(self.idx)``.
    """
    if self.curr_idx == len(self.idx):
        raise StopIteration

    bucket, start = self.idx[self.curr_idx]
    self.curr_idx += 1

    # One slice object reused for all four per-bucket arrays.
    window = slice(start, start + self.batch_size)
    indices = self.ndindex[bucket][window]
    sentences = self.ndsent[bucket][window]
    characters = self.ndchar[bucket][window]
    label = self.ndlabel[bucket][window]

    bucket_key = self.buckets[bucket]
    sentence_desc = DataDesc(name=self.data_names[0], shape=sentences.shape, layout=self.layout)
    character_desc = DataDesc(name=self.data_names[1], shape=characters.shape, layout=self.layout)
    label_desc = DataDesc(name=self.label_name, shape=label.shape, layout=self.layout)

    return DataBatch(
        [sentences, characters],
        [label],
        pad=0,
        index=indices,
        bucket_key=bucket_key,
        provide_data=[sentence_desc, character_desc],
        provide_label=[label_desc],
    )
Example #5
Source File: parall_module_local_v1.py    From 1.FaceRecognition with MIT License 6 votes vote down vote up
def forward(self, data_batch, is_train=None):
    """Forward ``data_batch`` through ``self._curr_module`` and, when training,
    through every module in ``self._arcface_modules``.

    When ``is_train`` is truthy the iteration counter ``self._iter`` is
    incremented, the two outputs of ``self._curr_module`` (merged across
    contexts) are taken as features and labels, the labels are copied to
    ``self._ctx_cpu`` and stored in ``self.global_label``, and each module
    ``i`` receives the features together with the labels shifted down by
    ``self._ctx_class_start[i]``.

    Requires the module to be bound and its parameters initialized.
    """
    assert self.binded and self.params_initialized
    self._curr_module.forward(data_batch, is_train=is_train)
    if not is_train:
        return

    self._iter += 1
    fc1, label = self._curr_module.get_outputs(merge_multi_context=True)
    self.global_label = label.as_in_context(self._ctx_cpu)

    for i, _module in enumerate(self._arcface_modules):
        # Labels relative to the first class index of module i.
        shifted_label = self.global_label - self._ctx_class_start[i]
        _module.forward(io.DataBatch([fc1], [shifted_label]))  # fc7 with margin
Example #6
Source File: data.py    From deeplearning-benchmark with Apache License 2.0 5 votes vote down vote up
def next(self):
    """Advance the iteration counter and return a batch while under the limit.

    ``self.cur_iter`` is incremented on every call, including the call
    that raises.  The returned ``DataBatch`` always wraps the same
    ``self.data`` and ``self.label`` objects, with zero padding and no index.

    Raises:
        StopIteration: once the incremented counter exceeds ``self.max_iter``.
    """
    self.cur_iter += 1
    if self.cur_iter > self.max_iter:
        raise StopIteration
    return DataBatch(
        data=(self.data,),
        label=(self.label,),
        pad=0,
        index=None,
        provide_data=self.provide_data,
        provide_label=self.provide_label,
    )
Example #7
Source File: test_io.py    From SNIPER-mxnet with Apache License 2.0 5 votes vote down vote up
def test_DataBatch():
    """Check that ``str(DataBatch)`` reports the data and label shapes.

    Covers a batch with data only (label shapes shown as ``None``) and a
    batch with two data arrays plus one label array.  The ``L?`` in the
    patterns tolerates a long-integer suffix on the shape values.
    """
    from nose.tools import ok_
    from mxnet.io import DataBatch
    import re
    # Raw strings: sequences such as ``\[`` and ``\(`` are invalid escapes in
    # a normal string literal and trigger a warning on current Python.
    batch = DataBatch(data=[mx.nd.ones((2,3))])
    ok_(re.match(r'DataBatch: data shapes: \[\(2L?, 3L?\)\] label shapes: None', str(batch)))
    batch = DataBatch(data=[mx.nd.ones((2,3)), mx.nd.ones((7,8))], label=[mx.nd.ones((4,5))])
    ok_(re.match(r'DataBatch: data shapes: \[\(2L?, 3L?\), \(7L?, 8L?\)\] label shapes: \[\(4L?, 5L?\)\]', str(batch)))
Example #8
Source File: data.py    From SNIPER-mxnet with Apache License 2.0 5 votes vote down vote up
def next(self):
    """Advance the iteration counter and return a batch while under the limit.

    ``self.cur_iter`` is incremented on every call, including the call
    that raises.  The returned ``DataBatch`` always wraps the same
    ``self.data`` and ``self.label`` objects, with zero padding and no index.

    Raises:
        StopIteration: once the incremented counter exceeds ``self.max_iter``.
    """
    self.cur_iter += 1
    if self.cur_iter > self.max_iter:
        raise StopIteration
    return DataBatch(
        data=(self.data,),
        label=(self.label,),
        pad=0,
        index=None,
        provide_data=self.provide_data,
        provide_label=self.provide_label,
    )
Example #9
Source File: data.py    From SNIPER-mxnet with Apache License 2.0 5 votes vote down vote up
def next(self):
    """Advance the iteration counter and return a batch while under the limit.

    ``self.cur_iter`` is incremented on every call, including the call
    that raises.  The returned ``DataBatch`` always wraps the same
    ``self.data`` and ``self.label`` objects, with zero padding and no index.

    Raises:
        StopIteration: once the incremented counter exceeds ``self.max_iter``.
    """
    self.cur_iter += 1
    if self.cur_iter > self.max_iter:
        raise StopIteration
    return DataBatch(
        data=(self.data,),
        label=(self.label,),
        pad=0,
        index=None,
        provide_data=self.provide_data,
        provide_label=self.provide_label,
    )
Example #10
Source File: data.py    From dlbench with MIT License 5 votes vote down vote up
def next(self):
    """Advance the iteration counter and return a batch while under the limit.

    ``self.cur_iter`` is incremented on every call, including the call
    that raises.  The returned ``DataBatch`` always wraps the same
    ``self.data`` and ``self.label`` objects, with zero padding and no index.

    Raises:
        StopIteration: once the incremented counter exceeds ``self.max_iter``.
    """
    self.cur_iter += 1
    if self.cur_iter > self.max_iter:
        raise StopIteration
    return DataBatch(
        data=(self.data,),
        label=(self.label,),
        pad=0,
        index=None,
        provide_data=self.provide_data,
        provide_label=self.provide_label,
    )
Example #11
Source File: data_iterator.py    From dlcookbook-dlbs with Apache License 2.0 5 votes vote down vote up
def next(self):
    """Build and return a DataBatch around ``self.data`` and ``self.label``.

    There is no stop condition here: every call returns a batch wrapping
    the same two objects, with zero padding and no index.

    For the DataBatch definition, see:
    https://mxnet.incubator.apache.org/api/python/io.html#mxnet.io.DataBatch
    """
    batch = DataBatch(
        data=(self.data,),
        label=(self.label,),
        pad=0,
        index=None,
        provide_data=self.provide_data,
        provide_label=self.provide_label,
    )
    return batch
Example #12
Source File: data.py    From mxboard-demo with Apache License 2.0 5 votes vote down vote up
def next(self):
    """Advance the iteration counter and return a batch while under the limit.

    ``self.cur_iter`` is incremented on every call, including the call
    that raises.  The returned ``DataBatch`` always wraps the same
    ``self.data`` and ``self.label`` objects, with zero padding and no index.

    Raises:
        StopIteration: once the incremented counter exceeds ``self.max_iter``.
    """
    self.cur_iter += 1
    if self.cur_iter > self.max_iter:
        raise StopIteration
    return DataBatch(
        data=(self.data,),
        label=(self.label,),
        pad=0,
        index=None,
        provide_data=self.provide_data,
        provide_label=self.provide_label,
    )
Example #13
Source File: helpers_fileiter.py    From kaggle_ndsb2 with Apache License 2.0 5 votes vote down vote up
def next(self):
    """Read the next record and return it as a DataBatch.

    When ``self.iter_next()`` is truthy, ``self._read()`` refreshes
    ``self.data`` and ``self.label``; the element at ``[0][1]`` of each is
    converted to an NDArray and used as the batch's data and label.  The
    padding comes from ``self.getpad()``.

    Raises:
        StopIteration: when ``self.iter_next()`` is falsy.
    """
    if not self.iter_next():
        raise StopIteration

    self.data, self.label = self._read()
    data_arrays = [mx.nd.array(self.data[0][1])]
    label_arrays = [mx.nd.array(self.label[0][1])]
    return DataBatch(data=data_arrays, label=label_arrays, pad=self.getpad(), index=None)
Example #14
Source File: data.py    From uai-sdk with Apache License 2.0 5 votes vote down vote up
def next(self):
    """Advance the iteration counter and return a batch while under the limit.

    ``self.cur_iter`` is incremented on every call, including the call
    that raises.  The returned ``DataBatch`` always wraps the same
    ``self.data`` and ``self.label`` objects, with zero padding and no index.

    Raises:
        StopIteration: once the incremented counter exceeds ``self.max_iter``.
    """
    self.cur_iter += 1
    if self.cur_iter > self.max_iter:
        raise StopIteration
    return DataBatch(
        data=(self.data,),
        label=(self.label,),
        pad=0,
        index=None,
        provide_data=self.provide_data,
        provide_label=self.provide_label,
    )
Example #15
Source File: data.py    From uai-sdk with Apache License 2.0 5 votes vote down vote up
def next(self):
    """Advance the iteration counter and return a batch while under the limit.

    ``self.cur_iter`` is incremented on every call, including the call
    that raises.  The returned ``DataBatch`` always wraps the same
    ``self.data`` and ``self.label`` objects, with zero padding and no index.

    Raises:
        StopIteration: once the incremented counter exceeds ``self.max_iter``.
    """
    self.cur_iter += 1
    if self.cur_iter > self.max_iter:
        raise StopIteration
    return DataBatch(
        data=(self.data,),
        label=(self.label,),
        pad=0,
        index=None,
        provide_data=self.provide_data,
        provide_label=self.provide_label,
    )
Example #16
Source File: data.py    From dynamic-training-with-apache-mxnet-on-aws with Apache License 2.0 5 votes vote down vote up
def next(self):
    """Advance the iteration counter and return a batch while under the limit.

    ``self.cur_iter`` is incremented on every call, including the call
    that raises.  The returned ``DataBatch`` always wraps the same
    ``self.data`` and ``self.label`` objects, with zero padding and no index.

    Raises:
        StopIteration: once the incremented counter exceeds ``self.max_iter``.
    """
    self.cur_iter += 1
    if self.cur_iter > self.max_iter:
        raise StopIteration
    return DataBatch(
        data=(self.data,),
        label=(self.label,),
        pad=0,
        index=None,
        provide_data=self.provide_data,
        provide_label=self.provide_label,
    )
Example #17
Source File: data.py    From casia-surf-2019-codes with MIT License 5 votes vote down vote up
def next(self):
    """Advance the iteration counter and return a batch while under the limit.

    ``self.cur_iter`` is incremented on every call, including the call
    that raises.  The returned ``DataBatch`` always wraps the same
    ``self.data`` and ``self.label`` objects, with zero padding and no index.

    Raises:
        StopIteration: once the incremented counter exceeds ``self.max_iter``.
    """
    self.cur_iter += 1
    if self.cur_iter > self.max_iter:
        raise StopIteration
    return DataBatch(
        data=(self.data,),
        label=(self.label,),
        pad=0,
        index=None,
        provide_data=self.provide_data,
        provide_label=self.provide_label,
    )
Example #18
Source File: data.py    From training_results_v0.6 with Apache License 2.0 5 votes vote down vote up
def next(self):
    """Advance the iteration counter and return a batch while under the limit.

    ``self.cur_iter`` is incremented on every call, including the call
    that raises.  The returned ``DataBatch`` always wraps the same
    ``self.data`` and ``self.label`` objects, with zero padding and no index.

    Raises:
        StopIteration: once the incremented counter exceeds ``self.max_iter``.
    """
    self.cur_iter += 1
    if self.cur_iter > self.max_iter:
        raise StopIteration
    return DataBatch(
        data=(self.data,),
        label=(self.label,),
        pad=0,
        index=None,
        provide_data=self.provide_data,
        provide_label=self.provide_label,
    )
Example #19
Source File: data.py    From training_results_v0.6 with Apache License 2.0 5 votes vote down vote up
def next(self):
    """Advance the iteration counter and return a batch while under the limit.

    ``self.cur_iter`` is incremented on every call, including the call
    that raises.  The returned ``DataBatch`` always wraps the same
    ``self.data`` and ``self.label`` objects, with zero padding and no index.

    Raises:
        StopIteration: once the incremented counter exceeds ``self.max_iter``.
    """
    self.cur_iter += 1
    if self.cur_iter > self.max_iter:
        raise StopIteration
    return DataBatch(
        data=(self.data,),
        label=(self.label,),
        pad=0,
        index=None,
        provide_data=self.provide_data,
        provide_label=self.provide_label,
    )
Example #20
Source File: detector.py    From training_results_v0.6 with Apache License 2.0 5 votes vote down vote up
def create_batch(self, frame):
    """
    Convert one image into a single-sample DataBatch.

    The frame is resized with ``cv2.resize`` to
    ``(self.data_shape[0], self.data_shape[1])``, moved from channels-last
    to channels-first, has ``self.mean_pixels_nd`` subtracted, and gets a
    leading batch axis of length 1.

    NOTE(review): the original comments describe the frame as
    (w,h,channels); the axis order only matters when ``data_shape`` is not
    square -- confirm against the caller.

    :param frame: a numpy array (image) with channels as the last axis
    :return: DataBatch holding one array with a leading batch axis of 1,
             described under the name 'data'
    """
    target_size = (self.data_shape[0], self.data_shape[1])
    resized = mx.nd.array(cv2.resize(frame, target_size))
    # Channels-last -> channels-first.
    channels_first = mx.nd.transpose(resized, axes=(2, 0, 1))
    normalized = channels_first - self.mean_pixels_nd
    # Prepend the batch axis.
    batched = mx.nd.expand_dims(normalized, axis=0)
    return DataBatch(data=[batched], provide_data=[DataDesc('data', batched.shape)])
Example #21
Source File: data.py    From training_results_v0.6 with Apache License 2.0 5 votes vote down vote up
def next(self):
    """Advance the iteration counter and return a batch while under the limit.

    ``self.cur_iter`` is incremented on every call, including the call
    that raises.  The returned ``DataBatch`` always wraps the same
    ``self.data`` and ``self.label`` objects, with zero padding and no index.

    Raises:
        StopIteration: once the incremented counter exceeds ``self.max_iter``.
    """
    self.cur_iter += 1
    if self.cur_iter > self.max_iter:
        raise StopIteration
    return DataBatch(
        data=(self.data,),
        label=(self.label,),
        pad=0,
        index=None,
        provide_data=self.provide_data,
        provide_label=self.provide_label,
    )
Example #22
Source File: test_io.py    From dynamic-training-with-apache-mxnet-on-aws with Apache License 2.0 5 votes vote down vote up
def test_DataBatch():
    """Check that ``str(DataBatch)`` reports the data and label shapes.

    Covers a batch with data only (label shapes shown as ``None``) and a
    batch with two data arrays plus one label array.  The ``L?`` in the
    patterns tolerates a long-integer suffix on the shape values.
    """
    from nose.tools import ok_
    from mxnet.io import DataBatch
    import re
    # Raw strings: sequences such as ``\[`` and ``\(`` are invalid escapes in
    # a normal string literal and trigger a warning on current Python.
    batch = DataBatch(data=[mx.nd.ones((2, 3))])
    ok_(re.match(
        r'DataBatch: data shapes: \[\(2L?, 3L?\)\] label shapes: None', str(batch)))
    batch = DataBatch(data=[mx.nd.ones((2, 3)), mx.nd.ones(
        (7, 8))], label=[mx.nd.ones((4, 5))])
    ok_(re.match(
        r'DataBatch: data shapes: \[\(2L?, 3L?\), \(7L?, 8L?\)\] label shapes: \[\(4L?, 5L?\)\]', str(batch)))
Example #23
Source File: data.py    From dynamic-training-with-apache-mxnet-on-aws with Apache License 2.0 5 votes vote down vote up
def next(self):
    """Advance the iteration counter and return a batch while under the limit.

    ``self.cur_iter`` is incremented on every call, including the call
    that raises.  The returned ``DataBatch`` always wraps the same
    ``self.data`` and ``self.label`` objects, with zero padding and no index.

    Raises:
        StopIteration: once the incremented counter exceeds ``self.max_iter``.
    """
    self.cur_iter += 1
    if self.cur_iter > self.max_iter:
        raise StopIteration
    return DataBatch(
        data=(self.data,),
        label=(self.label,),
        pad=0,
        index=None,
        provide_data=self.provide_data,
        provide_label=self.provide_label,
    )
Example #24
Source File: detector.py    From dynamic-training-with-apache-mxnet-on-aws with Apache License 2.0 5 votes vote down vote up
def create_batch(self, frame):
    """
    Convert one image into a single-sample DataBatch.

    The frame is resized with ``cv2.resize`` to
    ``(self.data_shape[0], self.data_shape[1])``, moved from channels-last
    to channels-first, has ``self.mean_pixels_nd`` subtracted, and gets a
    leading batch axis of length 1.

    NOTE(review): the original comments describe the frame as
    (w,h,channels); the axis order only matters when ``data_shape`` is not
    square -- confirm against the caller.

    :param frame: a numpy array (image) with channels as the last axis
    :return: DataBatch holding one array with a leading batch axis of 1,
             described under the name 'data'
    """
    target_size = (self.data_shape[0], self.data_shape[1])
    resized = mx.nd.array(cv2.resize(frame, target_size))
    # Channels-last -> channels-first.
    channels_first = mx.nd.transpose(resized, axes=(2, 0, 1))
    normalized = channels_first - self.mean_pixels_nd
    # Prepend the batch axis.
    batched = mx.nd.expand_dims(normalized, axis=0)
    return DataBatch(data=[batched], provide_data=[DataDesc('data', batched.shape)])
Example #25
Source File: age_iter.py    From 1.FaceRecognition with MIT License 4 votes vote down vote up
def next(self):
    """Returns the next batch of data.

    On the first call the iterator is reset (``self.is_init`` guards this).
    Samples are pulled from ``self.next_sample()`` until ``batch_size``
    valid images have been collected.  Each image is decoded, optionally
    mirrored at random, optionally mean-subtracted and scaled by
    0.0078125, optionally has a random block of side ``self.cutoff``
    overwritten with 127.5, is validated, and is then written into the
    batch buffer through ``self.postprocess_data``.

    Returns:
        An ``io.DataBatch`` with the data buffer, the label buffer (or no
        label when ``self.provide_label`` is None) and a pad of
        ``batch_size - i``.

    Raises:
        StopIteration: when the samples run out before the batch is full.
    """
    if not self.is_init:
        self.reset()
        self.is_init = True
    self.nbatch += 1
    batch_size = self.batch_size
    c, h, w = self.data_shape
    batch_data = nd.empty((batch_size, c, h, w))
    # Bind the name even without labels; it used to be left undefined and
    # the first label write failed with an unbound-local error.
    batch_label = None
    if self.provide_label is not None:
        batch_label = nd.empty(self.provide_label[0][1])
    i = 0
    try:
        while i < batch_size:
            label, s, bbox, landmark = self.next_sample()
            _data = self.imdecode(s)
            if self.rand_mirror:
                _rd = random.randint(0, 1)
                if _rd == 1:
                    _data = mx.ndarray.flip(data=_data, axis=1)
            if self.nd_mean is not None:
                _data = _data.astype('float32')
                _data -= self.nd_mean
                _data *= 0.0078125
            if self.cutoff > 0:
                # Overwrite a block centred on a random pixel, clipped to
                # the image bounds.
                centerh = random.randint(0, _data.shape[0] - 1)
                centerw = random.randint(0, _data.shape[1] - 1)
                half = self.cutoff // 2
                starth = max(0, centerh - half)
                endh = min(_data.shape[0], centerh + half)
                startw = max(0, centerw - half)
                endw = min(_data.shape[1], centerw + half)
                _data = _data.astype('float32')
                _data[starth:endh, startw:endw, :] = 127.5
            data = [_data]
            try:
                self.check_valid_image(data)
            except RuntimeError as e:
                logging.debug('Invalid image, skipping:  %s', str(e))
                continue
            for datum in data:
                assert i < batch_size, 'Batch size must be multiples of augmenter output length'
                batch_data[i][:] = self.postprocess_data(datum)
                if batch_label is not None:
                    batch_label[i][:] = label
                i += 1
    except StopIteration:
        # A partially filled batch is dropped: the exhaustion is passed on.
        if i < batch_size:
            raise

    labels = None if batch_label is None else [batch_label]
    return io.DataBatch([batch_data], labels, batch_size - i)
Example #26
Source File: data.py    From 1.FaceRecognition with MIT License 4 votes vote down vote up
def reset_c2c(self):
    """Refresh ``self.idx2cos`` with each image's similarity to its identity mean.

    Calls ``self.select_triplets()`` first.  Then, for every
    ``(identity, range)`` entry of ``self.id2range``, the records in that
    index range are read, run through the model in batches of at most
    ``args.batch_size``, and normalised with
    ``sklearn.preprocessing.normalize``.  The dot product between each
    embedding and the normalised mean embedding is stored in
    ``self.idx2cos[idx]``.

    Uses names that are not defined in this function (``imgrec``, ``args``,
    ``image_size``, ``model``); they must be in scope where it is defined.
    """
    self.select_triplets()
    # ``items()`` and ``range`` run on both Python 2 and Python 3, unlike
    # ``iteritems()`` / ``xrange``.
    for identity, v in self.id2range.items():
        _list = range(*v)
        # Fresh buffer per identity, so the embedding rows line up
        # one-to-one with ``_list`` (asserted below).
        ocontents = [imgrec.read_idx(idx) for idx in _list]
        embeddings = None
        ba = 0
        while True:
            bb = min(ba + args.batch_size, len(ocontents))
            if ba >= bb:
                break
            _batch_size = bb - ba
            # A short batch is padded up to ``args.ctx_num`` rows --
            # presumably so every context receives a sample; TODO confirm.
            _batch_size2 = max(_batch_size, args.ctx_num)
            data = nd.zeros((_batch_size2, 3, image_size[0], image_size[1]))
            label = nd.zeros((_batch_size2,))
            ii = 0
            for i in range(ba, bb):
                header, img = mx.recordio.unpack(ocontents[i])
                img = mx.image.imdecode(img)
                img = nd.transpose(img, axes=(2, 0, 1))
                data[ii][:] = img
                label[ii][:] = header.label
                ii += 1
            # Padding rows repeat the first sample; their outputs are
            # discarded when the result is sliced to ``_batch_size``.
            while ii < _batch_size2:
                data[ii][:] = data[0][:]
                label[ii][:] = label[0][:]
                ii += 1
            db = mx.io.DataBatch(data=(data,), label=(label,))
            # NOTE(review): the output of ``self.mx_model`` is overwritten
            # by ``model`` right below; both calls are kept as found --
            # confirm which model is intended.
            self.mx_model.forward(db, is_train=False)
            net_out = self.mx_model.get_outputs()
            net_out = net_out[0].asnumpy()
            model.forward(db, is_train=False)
            net_out = model.get_outputs()
            net_out = net_out[0].asnumpy()
            if embeddings is None:
                embeddings = np.zeros((len(ocontents), net_out.shape[1]))
            embeddings[ba:bb, :] = net_out[0:_batch_size, :]
            ba = bb
        if embeddings is None:
            # Empty index range: nothing to score for this identity.
            continue
        embeddings = sklearn.preprocessing.normalize(embeddings)
        embedding = np.mean(embeddings, axis=0, keepdims=True)
        embedding = sklearn.preprocessing.normalize(embedding)
        # ``embedding`` is (1, D) because of keepdims, so it is transposed
        # to (D, 1) before being multiplied with the (N, D) matrix.
        sims = np.dot(embeddings, embedding.T).flatten()
        assert len(sims) == len(_list)
        for _idx, sim in zip(_list, sims):
            self.idx2cos[_idx] = sim
Example #27
Source File: data.py    From insightface with MIT License 4 votes vote down vote up
def reset_c2c(self):
    """Refresh ``self.idx2cos`` with each image's similarity to its identity mean.

    Calls ``self.select_triplets()`` first.  Then, for every
    ``(identity, range)`` entry of ``self.id2range``, the records in that
    index range are read, run through the model in batches of at most
    ``args.batch_size``, and normalised with
    ``sklearn.preprocessing.normalize``.  The dot product between each
    embedding and the normalised mean embedding is stored in
    ``self.idx2cos[idx]``.

    Uses names that are not defined in this function (``imgrec``, ``args``,
    ``image_size``, ``model``); they must be in scope where it is defined.
    """
    self.select_triplets()
    # ``items()`` and ``range`` run on both Python 2 and Python 3, unlike
    # ``iteritems()`` / ``xrange``.
    for identity, v in self.id2range.items():
        _list = range(*v)
        # Fresh buffer per identity, so the embedding rows line up
        # one-to-one with ``_list`` (asserted below).
        ocontents = [imgrec.read_idx(idx) for idx in _list]
        embeddings = None
        ba = 0
        while True:
            bb = min(ba + args.batch_size, len(ocontents))
            if ba >= bb:
                break
            _batch_size = bb - ba
            # A short batch is padded up to ``args.ctx_num`` rows --
            # presumably so every context receives a sample; TODO confirm.
            _batch_size2 = max(_batch_size, args.ctx_num)
            data = nd.zeros((_batch_size2, 3, image_size[0], image_size[1]))
            label = nd.zeros((_batch_size2,))
            ii = 0
            for i in range(ba, bb):
                header, img = mx.recordio.unpack(ocontents[i])
                img = mx.image.imdecode(img)
                img = nd.transpose(img, axes=(2, 0, 1))
                data[ii][:] = img
                label[ii][:] = header.label
                ii += 1
            # Padding rows repeat the first sample; their outputs are
            # discarded when the result is sliced to ``_batch_size``.
            while ii < _batch_size2:
                data[ii][:] = data[0][:]
                label[ii][:] = label[0][:]
                ii += 1
            db = mx.io.DataBatch(data=(data,), label=(label,))
            # NOTE(review): the output of ``self.mx_model`` is overwritten
            # by ``model`` right below; both calls are kept as found --
            # confirm which model is intended.
            self.mx_model.forward(db, is_train=False)
            net_out = self.mx_model.get_outputs()
            net_out = net_out[0].asnumpy()
            model.forward(db, is_train=False)
            net_out = model.get_outputs()
            net_out = net_out[0].asnumpy()
            if embeddings is None:
                embeddings = np.zeros((len(ocontents), net_out.shape[1]))
            embeddings[ba:bb, :] = net_out[0:_batch_size, :]
            ba = bb
        if embeddings is None:
            # Empty index range: nothing to score for this identity.
            continue
        embeddings = sklearn.preprocessing.normalize(embeddings)
        embedding = np.mean(embeddings, axis=0, keepdims=True)
        embedding = sklearn.preprocessing.normalize(embedding)
        # ``embedding`` is (1, D) because of keepdims, so it is transposed
        # to (D, 1) before being multiplied with the (N, D) matrix.
        sims = np.dot(embeddings, embedding.T).flatten()
        assert len(sims) == len(_list)
        for _idx, sim in zip(_list, sims):
            self.idx2cos[_idx] = sim
Example #28
Source File: data.py    From MaskInsightface with Apache License 2.0 4 votes vote down vote up
def reset_c2c(self):
    """Refresh ``self.idx2cos`` with each image's similarity to its identity mean.

    Calls ``self.select_triplets()`` first.  Then, for every
    ``(identity, range)`` entry of ``self.id2range``, the records in that
    index range are read, run through the model in batches of at most
    ``args.batch_size``, and normalised with
    ``sklearn.preprocessing.normalize``.  The dot product between each
    embedding and the normalised mean embedding is stored in
    ``self.idx2cos[idx]``.

    Uses names that are not defined in this function (``imgrec``, ``args``,
    ``image_size``, ``model``); they must be in scope where it is defined.
    """
    self.select_triplets()
    # The rest of this function already uses Python 3 ``range``;
    # ``iteritems()`` does not exist there, so ``items()`` is used.
    for identity, v in self.id2range.items():
        _list = range(*v)
        # Fresh buffer per identity, so the embedding rows line up
        # one-to-one with ``_list`` (asserted below).
        ocontents = [imgrec.read_idx(idx) for idx in _list]
        embeddings = None
        ba = 0
        while True:
            bb = min(ba + args.batch_size, len(ocontents))
            if ba >= bb:
                break
            _batch_size = bb - ba
            # A short batch is padded up to ``args.ctx_num`` rows --
            # presumably so every context receives a sample; TODO confirm.
            _batch_size2 = max(_batch_size, args.ctx_num)
            data = nd.zeros((_batch_size2, 3, image_size[0], image_size[1]))
            label = nd.zeros((_batch_size2,))
            ii = 0
            for i in range(ba, bb):
                header, img = mx.recordio.unpack(ocontents[i])
                img = mx.image.imdecode(img)
                img = nd.transpose(img, axes=(2, 0, 1))
                data[ii][:] = img
                label[ii][:] = header.label
                ii += 1
            # Padding rows repeat the first sample; their outputs are
            # discarded when the result is sliced to ``_batch_size``.
            while ii < _batch_size2:
                data[ii][:] = data[0][:]
                label[ii][:] = label[0][:]
                ii += 1
            db = mx.io.DataBatch(data=(data,), label=(label,))
            # NOTE(review): the output of ``self.mx_model`` is overwritten
            # by ``model`` right below; both calls are kept as found --
            # confirm which model is intended.
            self.mx_model.forward(db, is_train=False)
            net_out = self.mx_model.get_outputs()
            net_out = net_out[0].asnumpy()
            model.forward(db, is_train=False)
            net_out = model.get_outputs()
            net_out = net_out[0].asnumpy()
            if embeddings is None:
                embeddings = np.zeros((len(ocontents), net_out.shape[1]))
            embeddings[ba:bb, :] = net_out[0:_batch_size, :]
            ba = bb
        if embeddings is None:
            # Empty index range: nothing to score for this identity.
            continue
        embeddings = sklearn.preprocessing.normalize(embeddings)
        embedding = np.mean(embeddings, axis=0, keepdims=True)
        embedding = sklearn.preprocessing.normalize(embedding)
        # ``embedding`` is (1, D) because of keepdims, so it is transposed
        # to (D, 1) before being multiplied with the (N, D) matrix.
        sims = np.dot(embeddings, embedding.T).flatten()
        assert len(sims) == len(_list)
        for _idx, sim in zip(_list, sims):
            self.idx2cos[_idx] = sim
Example #29
Source File: age_iter.py    From insightface with MIT License 4 votes vote down vote up
def next(self):
    """Returns the next batch of data.

    On the first call the iterator is reset (``self.is_init`` guards this).
    Samples are pulled from ``self.next_sample()`` until ``batch_size``
    valid images have been collected.  Each image is decoded, optionally
    mirrored at random, optionally mean-subtracted and scaled by
    0.0078125, optionally has a random block of side ``self.cutoff``
    overwritten with 127.5, is validated, and is then written into the
    batch buffer through ``self.postprocess_data``.

    Returns:
        An ``io.DataBatch`` with the data buffer, the label buffer (or no
        label when ``self.provide_label`` is None) and a pad of
        ``batch_size - i``.

    Raises:
        StopIteration: when the samples run out before the batch is full.
    """
    if not self.is_init:
        self.reset()
        self.is_init = True
    self.nbatch += 1
    batch_size = self.batch_size
    c, h, w = self.data_shape
    batch_data = nd.empty((batch_size, c, h, w))
    # Bind the name even without labels; it used to be left undefined and
    # the first label write failed with an unbound-local error.
    batch_label = None
    if self.provide_label is not None:
        batch_label = nd.empty(self.provide_label[0][1])
    i = 0
    try:
        while i < batch_size:
            label, s, bbox, landmark = self.next_sample()
            _data = self.imdecode(s)
            if self.rand_mirror:
                _rd = random.randint(0, 1)
                if _rd == 1:
                    _data = mx.ndarray.flip(data=_data, axis=1)
            if self.nd_mean is not None:
                _data = _data.astype('float32')
                _data -= self.nd_mean
                _data *= 0.0078125
            if self.cutoff > 0:
                # Overwrite a block centred on a random pixel, clipped to
                # the image bounds.
                centerh = random.randint(0, _data.shape[0] - 1)
                centerw = random.randint(0, _data.shape[1] - 1)
                half = self.cutoff // 2
                starth = max(0, centerh - half)
                endh = min(_data.shape[0], centerh + half)
                startw = max(0, centerw - half)
                endw = min(_data.shape[1], centerw + half)
                _data = _data.astype('float32')
                _data[starth:endh, startw:endw, :] = 127.5
            data = [_data]
            try:
                self.check_valid_image(data)
            except RuntimeError as e:
                logging.debug('Invalid image, skipping:  %s', str(e))
                continue
            for datum in data:
                assert i < batch_size, 'Batch size must be multiples of augmenter output length'
                batch_data[i][:] = self.postprocess_data(datum)
                if batch_label is not None:
                    batch_label[i][:] = label
                i += 1
    except StopIteration:
        # A partially filled batch is dropped: the exhaustion is passed on.
        if i < batch_size:
            raise

    labels = None if batch_label is None else [batch_label]
    return io.DataBatch([batch_data], labels, batch_size - i)
Example #30
Source File: data.py    From insightocr with MIT License 4 votes vote down vote up
def next(self):
    """Returns the next batch of data.

    Samples come from ``self.next_sample()``; each one supplies an
    ``'image_path'`` and a ``'label'``.  The file is read and decoded with
    OpenCV (grayscale when ``config.to_gray`` is set, otherwise colour
    converted from BGR to RGB), resized to
    ``config.img_width`` x ``config.img_height`` when its size differs,
    converted to float32, shifted by 127.5 and scaled by 0.0078125, then
    written into the batch buffer through ``self.postprocess_data``.
    Files that cannot be decoded are logged and skipped.

    Returns:
        An ``io.DataBatch`` with the data buffer (followed by
        ``self.init_state_arrays`` when ``config.use_lstm`` is set), the
        label buffer, and a pad of ``batch_size - i``.

    Raises:
        StopIteration: when the samples run out before the batch is full.
    """
    batch_size = self.batch_size
    batch_data = nd.empty(self.provide_data[0][1])
    batch_label = nd.empty(self.provide_label[0][1])
    i = 0
    try:
        while i < batch_size:
            item = self.next_sample()
            with open(item['image_path'], 'rb') as fin:
                img = fin.read()
            try:
                # frombuffer replaces the deprecated binary mode of
                # np.fromstring; it views the bytes without copying.
                img = np.frombuffer(img, np.uint8)
                if config.to_gray:
                    _data = cv2.imdecode(img, cv2.IMREAD_GRAYSCALE)
                else:
                    _data = cv2.imdecode(img, cv2.IMREAD_COLOR)
                if _data is None:
                    # imdecode signals an undecodable buffer by returning
                    # None rather than raising, so it is checked here.
                    logging.debug('Invalid image, skipping:  %s', item['image_path'])
                    continue
                if not config.to_gray:
                    _data = cv2.cvtColor(_data, cv2.COLOR_BGR2RGB)
                if _data.shape[0] != config.img_height or _data.shape[1] != config.img_width:
                    _data = cv2.resize(_data, (config.img_width, config.img_height))
            except RuntimeError as e:
                logging.debug('Invalid image, skipping:  %s', str(e))
                continue
            _data = mx.nd.array(_data)
            _data = _data.astype('float32')
            _data -= 127.5
            _data *= 0.0078125
            data = [_data]
            label = item['label']
            for datum in data:
                assert i < batch_size, 'Batch size must be multiples of augmenter output length'
                batch_data[i][:] = self.postprocess_data(datum)
                batch_label[i][:] = label
                i += 1
    except StopIteration:
        # A partially filled batch is dropped: the exhaustion is passed on.
        if i < batch_size:
            raise

    data_all = [batch_data]
    if config.use_lstm:
        data_all += self.init_state_arrays
    return io.DataBatch(data_all, [batch_label], batch_size - i)