Python mxnet.nd.transpose() Examples

The following are 30 code examples of mxnet.nd.transpose(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module mxnet.nd, or try the search function.
Example #1
Source File: utils_final.py    From InsightFace_TF with MIT License 6 votes vote down vote up
def load_data_mnist(batch_size, resize=None, root="~/.mxnet/datasets/mnist"):
    """Build a shuffled training loader and an ordered test loader over Gluon MNIST.

    Args:
        batch_size: passed to both ``DataLoader`` instances.
        resize: when truthy, every image of a batch is resized to
            ``resize x resize`` before the axis reordering.
        root: dataset directory handed to ``gluon.data.vision.MNIST``.

    Returns:
        The tuple ``(train_data, test_data)``.
    """

    def transform_mnist(data, label):
        """Reorder a batch to channel-first float32 divided by 255; cast labels to float32."""
        if resize:
            batch = data.shape[0]
            resized = nd.zeros((batch, resize, resize, data.shape[3]))
            for idx in range(batch):
                resized[idx] = image.imresize(data[idx], resize, resize)
            data = resized
        # batch x height x width x channel -> batch x channel x height x width
        pixels = nd.transpose(data.astype('float32'), (0, 3, 1, 2)) / 255
        return pixels, label.astype('float32')

    # The datasets are built with transform=None; the per-batch transform is
    # given to the loaders instead (transform later to avoid memory explosion).
    train_set = gluon.data.vision.MNIST(root=root, train=True, transform=None)
    test_set = gluon.data.vision.MNIST(root=root, train=False, transform=None)
    train_data = DataLoader(train_set, batch_size, shuffle=True, transform=transform_mnist)
    test_data = DataLoader(test_set, batch_size, shuffle=False, transform=transform_mnist)
    return train_data, test_data
Example #2
Source File: train_imagenet.py    From ResidualAttentionNetwork with MIT License 6 votes vote down vote up
def transformer(data, label):
    """Run the training augmenter chain on ``data`` and move axis 2 to the front.

    ``label`` is returned unchanged alongside the transformed image.
    """
    jitter = 0.4
    pca_lighting = 0.1
    augmenters = image.CreateAugmenter(
        data_shape=(3, 224, 224),
        rand_crop=True,
        rand_resize=True,
        rand_mirror=True,
        brightness=jitter,
        saturation=jitter,
        contrast=jitter,
        pca_noise=pca_lighting,
        mean=True,
        std=True,
    )

    augmented = data
    for augment in augmenters:
        augmented = augment(augmented)

    # Axes (2, 0, 1): the last axis becomes the first one.
    return nd.transpose(augmented, (2, 0, 1)), label
Example #3
Source File: net.py    From comment_toxic_CapsuleNet with MIT License 6 votes vote down vote up
def net_define_eu():
    """Assemble the embedding / GRU / capsule stack and initialise it with Xavier.

    Returns:
        The initialised ``nn.Sequential`` network.
    """
    net = nn.Sequential()
    with net.name_scope():
        # Layers are created inside the name scope, in the order they are added.
        stack = [
            nn.Embedding(config.MAX_WORDS, config.EMBEDDING_DIM),
            rnn.GRU(128, layout='NTC', bidirectional=True, num_layers=1, dropout=0.2),
            transpose(axes=(0, 2, 1)),
            nn.GlobalMaxPool1D(),
            # FeatureBlock1() is intentionally left out (it was disabled here).
            extendDim(axes=3),
            PrimeConvCap(16, 32, kernel_size=(1, 1), padding=(0, 0), strides=(1, 1)),
            CapFullyNGBlock(16, num_cap=12, input_units=32, units=16, route_num=3),
            nn.Dropout(0.2),
            nn.Dense(6, activation='sigmoid'),
        ]
        for layer in stack:
            net.add(layer)
    net.initialize(init=init.Xavier())
    return net
Example #4
Source File: net.py    From comment_toxic_CapsuleNet with MIT License 6 votes vote down vote up
def net_define():
    """Assemble the embedding / two-layer GRU / capsule stack and initialise it.

    Returns:
        The ``nn.Sequential`` network after Xavier initialisation.
    """
    net = nn.Sequential()
    with net.name_scope():
        # Layers are created inside the name scope, in the order they are added.
        stack = [
            nn.Embedding(config.MAX_WORDS, config.EMBEDDING_DIM),
            rnn.GRU(128, layout='NTC', bidirectional=True, num_layers=2, dropout=0.2),
            transpose(axes=(0, 2, 1)),
            PrimeConvCap(8, 32, kernel_size=(1, 1), padding=(0, 0)),
            # NOTE(review): `/` is true division on Python 3, so the first
            # argument is a float there -- confirm CapFullyBlock accepts that.
            CapFullyBlock(8*(config.MAX_LENGTH)/2, num_cap=12, input_units=32, units=16, route_num=5),
            nn.Dropout(0.2),
            nn.Dense(6, activation='sigmoid'),
        ]
        for layer in stack:
            net.add(layer)
    net.initialize(init=init.Xavier())
    return net
Example #5
Source File: net.py    From comment_toxic_CapsuleNet with MIT License 6 votes vote down vote up
def forward(self, x):
        """Concatenate the pooled GRU branch with the pooled convolution branch.

        The convolution branch works on ``x`` with its last two axes swapped;
        the output of ``conv5`` and of ``conv7`` is each added to its own input.
        The GRU branch runs on ``x`` as given and swaps the same two axes
        afterwards. Both pooled results are joined along axis 1.
        """
        # Convolution branch.
        swapped = nd.transpose(x, axes=(0, 2, 1))
        c3 = self.conv3(swapped)
        c5 = self.conv5(c3) + c3
        c7 = self.conv7(c5) + c5
        conv_pooled = self.conv_maxpool(self.conv_drop(c7))

        # Recurrent branch.
        gru_swapped = nd.transpose(self.gru(x), axes=(0, 2, 1))
        gru_pooled = self.gru_maxpool(gru_swapped)

        return nd.concat(gru_pooled, conv_pooled, dim=1)
Example #6
Source File: utils.py    From EmotionClassifier with GNU General Public License v3.0 6 votes vote down vote up
def load_data_fashion_mnist(batch_size, resize=None, root="~/.mxnet/datasets/fashion-mnist"):
    """Create the FashionMNIST training and test loaders.

    Args:
        batch_size: passed to both ``DataLoader`` instances.
        resize: when truthy, each image in a batch is resized to
            ``resize x resize`` first.
        root: dataset directory for ``gluon.data.vision.FashionMNIST``.

    Returns:
        ``(train_data, test_data)``; only the training loader shuffles.
    """
    def transform_mnist(data, label):
        """Batch-level transform: optional resize, channel-first layout, /255."""
        if resize:
            count = data.shape[0]
            buffer = nd.zeros((count, resize, resize, data.shape[3]))
            for k in range(count):
                buffer[k] = image.imresize(data[k], resize, resize)
            data = buffer
        # batch x height x width x channel -> batch x channel x height x width
        as_float = data.astype('float32')
        return nd.transpose(as_float, (0, 3, 1, 2)) / 255, label.astype('float32')

    fashion = gluon.data.vision.FashionMNIST
    train_set = fashion(root=root, train=True, transform=None)
    test_set = fashion(root=root, train=False, transform=None)
    # The transform is applied by the loaders (later, to avoid memory explosion).
    train_data = DataLoader(train_set, batch_size, shuffle=True, transform=transform_mnist)
    test_data = DataLoader(test_set, batch_size, shuffle=False, transform=transform_mnist)
    return train_data, test_data
Example #7
Source File: utils_final.py    From InsightFace_TF with MIT License 6 votes vote down vote up
def load_data_fashion_mnist(batch_size, resize=None, root="~/.mxnet/datasets/fashion-mnist"):
    """Return ``(train_data, test_data)`` loaders over Gluon FashionMNIST.

    Args:
        batch_size: batch size for both loaders.
        resize: optional edge length; truthy values trigger a per-image resize.
        root: where ``gluon.data.vision.FashionMNIST`` keeps its files.
    """

    def transform_mnist(data, label):
        """Optionally resize a batch, then make it channel-first float32 in /255 units."""
        if resize:
            n_images = data.shape[0]
            scaled_up = nd.zeros((n_images, resize, resize, data.shape[3]))
            for pos in range(n_images):
                scaled_up[pos] = image.imresize(data[pos], resize, resize)
            data = scaled_up
        # batch x height x width x channel -> batch x channel x height x width
        channel_first = nd.transpose(data.astype('float32'), (0, 3, 1, 2))
        return channel_first / 255, label.astype('float32')

    # No dataset-level transform: it is deferred to the loaders to avoid
    # memory explosion.
    train_set = gluon.data.vision.FashionMNIST(root=root, train=True, transform=None)
    test_set = gluon.data.vision.FashionMNIST(root=root, train=False, transform=None)
    train_data = DataLoader(train_set, batch_size, shuffle=True, transform=transform_mnist)
    test_data = DataLoader(test_set, batch_size, shuffle=False, transform=transform_mnist)
    return train_data, test_data
Example #8
Source File: siamrpn_tracker.py    From gluon-cv with Apache License 2.0 6 votes vote down vote up
def _convert_score(self, score):
        """Turn the classification output into a flat NumPy score vector.

        Parameters
        ----------
            score : ndarray
                network output

        Returns
        -------
            numpy.ndarray
                column 1 of the row-wise softmax taken over the two
                reshaped rows of ``score``
        """
        moved = nd.transpose(score, axes=(1, 2, 3, 0))
        two_rows = nd.reshape(moved, shape=(2, -1))
        pairs = nd.transpose(two_rows, axes=(1, 0))
        probs = nd.softmax(pairs, axis=1)
        # Keep only column 1, then drop the now length-1 axis.
        picked = nd.slice_axis(probs, axis=1, begin=1, end=2)
        return nd.squeeze(picked, axis=1).asnumpy()
Example #9
Source File: siamrpn_tracker.py    From gluon-cv with Apache License 2.0 6 votes vote down vote up
def _convert_bbox(self, delta, anchor):
        """Turn the localisation output into predicted positions using the anchors.

        Parameters
        ----------
            delta : ndarray or np.ndarray
                network output
            anchor : np.ndarray
                generated anchor locations; columns 0-3 are used

        Returns
        -------
            np.ndarray
                four rows: rows 0 and 1 are scaled by anchor columns 2 and 3
                and shifted by columns 0 and 1; rows 2 and 3 are
                exponentiated and scaled by columns 2 and 3
        """
        moved = nd.transpose(delta, axes=(1, 2, 3, 0))
        rows = nd.reshape(moved, shape=(4, -1)).asnumpy()

        col0, col1 = anchor[:, 0], anchor[:, 1]
        col2, col3 = anchor[:, 2], anchor[:, 3]

        rows[0, :] = rows[0, :] * col2 + col0
        rows[1, :] = rows[1, :] * col3 + col1
        rows[2, :] = np.exp(rows[2, :]) * col2
        rows[3, :] = np.exp(rows[3, :]) * col3
        return rows
Example #10
Source File: utils.py    From CapsNet_Mxnet with Apache License 2.0 6 votes vote down vote up
def load_data_fashion_mnist(batch_size, resize=None):
    """Create FashionMNIST loaders rooted at ``./data``.

    Args:
        batch_size: batch size for both loaders.
        resize: when truthy, each image is resized to ``resize x resize``.

    Returns:
        ``(train_data, test_data)``; only the training loader shuffles.
    """
    def transform_mnist(data, label):
        """Per-sample transform: optional resize, channel-first layout, /255."""
        if resize:
            data = image.imresize(data, resize, resize)
        # height x width x channel -> channel x height x width
        channel_first = nd.transpose(data.astype('float32'), (2, 0, 1)) / 255
        return channel_first, label.astype('float32')

    vision = gluon.data.vision
    train_set = vision.FashionMNIST(root='./data', train=True, transform=transform_mnist)
    test_set = vision.FashionMNIST(root='./data', train=False, transform=transform_mnist)
    train_data = gluon.data.DataLoader(train_set, batch_size, shuffle=True)
    test_data = gluon.data.DataLoader(test_set, batch_size, shuffle=False)
    return train_data, test_data
Example #11
Source File: E2FAR.py    From mxnet-E2FAR with Apache License 2.0 6 votes vote down vote up
def __getitem__(self, idx):
        """Load sample ``idx``.

        Reads the image named in column 0 of ``self.data_frame``, crops it to
        the enlarged box built from columns 1-4, resizes it to
        ``self.img_size`` and subtracts 127.5. The parameters come from the
        ``.mat`` file that shares the image's name.

        Returns:
            ``(img, params_shape, params_exp)``: ``img`` is an NDArray with
            axis 2 moved to the front, the other two are flat float32 arrays.
        """
        img_path = self.data_frame.iloc[idx, 0]
        rgb = cv2.cvtColor(cv2.imread(img_path, 1), cv2.COLOR_BGR2RGB)

        x, y, w, h = self.data_frame.iloc[idx, 1:5]
        left, top, box_w, box_h = enlarge_bbox(x, y, w, h, self.enlarge_factor)
        right, bottom = left + box_w, top + box_h

        crop = rgb[top:bottom, left:right, :]
        crop = cv2.resize(crop, (self.img_size, self.img_size))
        centred = crop.astype(np.float32) - 127.5
        img = nd.transpose(nd.array(centred), (2, 0, 1))

        label = sio.loadmat(img_path.replace('.jpg', '.mat'))
        params_shape = label['Shape_Para'].astype(np.float32).ravel()
        params_exp = label['Exp_Para'].astype(np.float32).ravel()
        return img, params_shape, params_exp
Example #12
Source File: TextEXAM_multi-label.py    From AAAI_2019_EXAM with GNU General Public License v2.0 6 votes vote down vote up
def forward(self,x):
        """Score each of the 2000 topics for every row of ``x``.

        The batch size is taken from ``x.shape[0]`` rather than from a
        module-level ``batch_size`` variable, so a batch of any size (for
        example a smaller final batch) is reshaped correctly.

        Parameters
        ----------
        x : NDArray
            Input batch; only columns 0..29 of axis 1 are used.

        Returns
        -------
        NDArray
            Result of shape (x.shape[0], 2000).
        """
        n_batch = x.shape[0]

        # Encode layer
        question = x[:, 0:30]
        question = self.Embed(question)
        question = self.gru(question)

        # Interaction layer
        interaction = nd.dot(question, self.topic_embedding.data())
        interaction = nd.transpose(interaction, axes=(0, 2, 1))
        # One row per (sample, topic) pair.
        interaction = interaction.reshape((n_batch * 2000, -1))

        # Aggregation layer
        res = self.mlp_2(self.mlp_1(interaction))
        res = res.reshape((n_batch, 2000))

        return res

#Train Model 
Example #13
Source File: utils.py    From CapsNet_Mxnet with Apache License 2.0 6 votes vote down vote up
def load_data_mnist(batch_size, resize=None):
    """Create MNIST loaders rooted at ``./data``.

    Args:
        batch_size: batch size for both loaders.
        resize: when truthy, each image is resized to ``resize x resize``.

    Returns:
        ``(train_data, test_data)``; only the training loader shuffles.
    """
    def transform_mnist(data, label):
        """Per-sample transform: optional resize, channel-first layout, /255."""
        if resize:
            data = image.imresize(data, resize, resize)
        # height x width x channel -> channel x height x width
        as_float = data.astype('float32')
        return nd.transpose(as_float, (2, 0, 1)) / 255, label.astype('float32')

    mnist = gluon.data.vision.MNIST
    train_set = mnist(root='./data', train=True, transform=transform_mnist)
    test_set = mnist(root='./data', train=False, transform=transform_mnist)
    train_data = gluon.data.DataLoader(train_set, batch_size, shuffle=True)
    test_data = gluon.data.DataLoader(test_set, batch_size, shuffle=False)
    return train_data, test_data
Example #14
Source File: gen_submission.py    From ResidualAttentionNetwork with MIT License 5 votes vote down vote up
def trans_test(data):
    """Prepare a test image: divide by 255, run the augmenters, move axis 2 first."""
    sample = data.astype(np.float32) / 255.
    augmenters = image.CreateAugmenter(
        data_shape=(3, 32, 32),
        mean=mx.nd.array([0.485, 0.456, 0.406]),
        std=mx.nd.array([0.229, 0.224, 0.225]),
    )
    for augment in augmenters:
        sample = augment(sample)

    return nd.transpose(sample, (2, 0, 1))
Example #15
Source File: learn_nms.py    From Relation-Networks-for-Object-Detection with MIT License 5 votes vote down vote up
def extract_multi_position_matrix_nd(bbox):
    """Build the pairwise position features between boxes.

    ``bbox`` first has its first two axes swapped and is then split along
    axis 2 into xmin, ymin, xmax, ymax. Four pairwise features are stacked on
    a new axis 3: log of the clamped, width-normalised x-centre distance, the
    same for y with the height, the log width ratio and the log height ratio.
    """
    def _swap_last_two(arr):
        # Exchange axes 1 and 2 so broadcasting forms every pair.
        return nd.transpose(arr, axes=(0, 2, 1))

    def _log_centre_offset(centre, extent):
        # log(max(|centre_i - centre_j| / extent, 1e-3))
        diff = nd.broadcast_minus(lhs=centre, rhs=_swap_last_two(centre))
        diff = nd.broadcast_div(diff, extent)
        return nd.log(nd.maximum(nd.abs(diff), 1e-3))

    def _log_size_ratio(extent):
        # log(extent_i / extent_j)
        return nd.log(nd.broadcast_div(lhs=extent, rhs=_swap_last_two(extent)))

    bbox = nd.transpose(bbox, axes=(1, 0, 2))
    xmin, ymin, xmax, ymax = nd.split(data=bbox, num_outputs=4, axis=2)
    # each: [num_fg_classes, num_boxes, 1]
    width = xmax - xmin + 1.
    height = ymax - ymin + 1.
    centre_x = 0.5 * (xmin + xmax)
    centre_y = 0.5 * (ymin + ymax)

    # each: [num_fg_classes, num_boxes, num_boxes]
    features = [
        _log_centre_offset(centre_x, width),
        _log_centre_offset(centre_y, height),
        _log_size_ratio(width),
        _log_size_ratio(height),
    ]
    expanded = [nd.expand_dims(feature, axis=3) for feature in features]
    return nd.concat(*expanded, dim=3)
Example #16
Source File: utils.py    From gluon-face with MIT License 5 votes vote down vote up
def transform_train(data, label):
    """Cast to float32, divide by 255, subtract 0.5 and move axis 2 to the front.

    ``label`` is passed through unchanged.
    """
    scaled = data.astype('float32') / 255
    centred = scaled - 0.5
    return nd.transpose(centred, (2, 0, 1)), label
Example #17
Source File: utils.py    From gluon-face with MIT License 5 votes vote down vote up
def transform_val(data, label):
    """Validation transform: float32 cast, /255, -0.5, then axis 2 moved first.

    ``label`` is passed through unchanged.
    """
    shifted = data.astype('float32') / 255 - 0.5
    reordered = nd.transpose(shifted, (2, 0, 1))
    return reordered, label
Example #18
Source File: test_script.py    From gluon-face with MIT License 5 votes vote down vote up
def transform_test_flip(data, isf=False):
    """Return the transformed image together with its axis-1 flipped copy.

    With ``isf`` truthy both arrays only get axis 2 moved to the front and a
    float32 cast; otherwise both go through ``transform_test``.
    """
    flip_data = nd.flip(data, axis=1)
    if not isf:
        return transform_test(data), transform_test(flip_data)

    def _axis2_first_float(arr):
        return nd.transpose(arr, (2, 0, 1)).astype('float32')

    return _axis2_first_float(data), _axis2_first_float(flip_data)
Example #19
Source File: seq2seq.py    From ST-MetaNet with MIT License 5 votes vote down vote up
def forward(self, feature, data, label, is_training):
        """ Run the seq2seq network and compute its loss.

        Parameters
        ----------
        feature: NDArray with shape [b, n, d].
            The features of each node.
        data: NDArray with shape [b, t, n, d].
            The flow readings.
        label: NDArray with shape [b, t, n, d].
            The flow labels.
        is_training: bool.
            Forwarded to the decoder.

        Returns
        -------
        loss: mean absolute error, averaged over every axis except axis 1.
        [output, label]: list of two NDArrays in [n, b, t, d] layout; the
            label is cut to the decoder's first ``output_dim`` channels.

        """
        # Bring the node axis to the front: [b, t, n, d] -> [n, b, t, d].
        node_first = (2, 0, 1, 3)
        data = nd.transpose(data, axes=node_first)
        label = nd.transpose(label, axes=node_first)

        # Geo-feature embedding: average over axis 0, then encode -> [n, d].
        geo = self.geo_encoder(nd.mean(feature, axis=0))

        # Encode the readings, then decode against the full label tensor.
        states = self.encoder(geo, data)
        output = self.decoder(geo, label, states, is_training)  # [n, b, t, d]

        # Only the first output_dim channels of the label enter the loss.
        target = label[:, :, :, :self.decoder.output_dim]
        loss = nd.mean(nd.abs(output - target), axis=1, exclude=True)
        return loss, [output, target]
Example #20
Source File: train_imagenet.py    From ResidualAttentionNetwork with MIT License 5 votes vote down vote up
def trans_test(data, label):
    """Run the test-time augmenters on ``data`` and move axis 2 to the front.

    ``label`` is returned unchanged.
    """
    augmenters = image.CreateAugmenter(
        data_shape=(3, 224, 224),
        resize=256,
        mean=True,
        std=True,
    )
    sample = data
    for augment in augmenters:
        sample = augment(sample)

    return nd.transpose(sample, (2, 0, 1)), label
Example #21
Source File: train_cifar.py    From ResidualAttentionNetwork with MIT License 5 votes vote down vote up
def transformer(data, label):
    """Training transform: pad, random-erase, scale, augment, move axis 2 first.

    The image is zero-padded by 4 on each side of its first two axes, passed
    through ``random_eraser``, divided by 255 and run through the crop /
    mirror / normalisation augmenters. ``label`` is returned unchanged.
    """
    padding = ((4, 4), (4, 4), (0, 0))
    padded = np.pad(data.asnumpy(), pad_width=padding, mode='constant')
    erased = random_eraser(padded)
    sample = nd.array(erased) / 255.
    augmenters = image.CreateAugmenter(
        data_shape=(3, 32, 32),
        rand_crop=True,
        rand_mirror=True,
        mean=mx.nd.array([0.4914, 0.4824, 0.4467]),
        std=mx.nd.array([0.2471, 0.2435, 0.2616]),
    )
    for augment in augmenters:
        sample = augment(sample)
    return nd.transpose(sample, (2, 0, 1)), label
Example #22
Source File: train_cifar.py    From ResidualAttentionNetwork with MIT License 5 votes vote down vote up
def trans_test(data, label):
    """Test transform: divide by 255, run the augmenters, move axis 2 first.

    ``label`` is returned unchanged.
    """
    sample = data.astype(np.float32) / 255.
    augmenters = image.CreateAugmenter(
        data_shape=(3, 32, 32),
        mean=mx.nd.array([0.4914, 0.4824, 0.4467]),
        std=mx.nd.array([0.2471, 0.2435, 0.2616]),
    )
    for augment in augmenters:
        sample = augment(sample)

    return nd.transpose(sample, (2, 0, 1)), label
Example #23
Source File: data.py    From SNIPER-mxnet with Apache License 2.0 5 votes vote down vote up
def transform(data, target_wd, target_ht, is_train, box):
    """Crop, resize and normalize an image NDArray, returning it with a leading axis of 1.

    Args:
        data: image array indexed as [row, column, channel].
        target_wd, target_ht: size passed to ``mx.image.imresize``.
        is_train: truthy selects a random flip plus random crop, otherwise a
            center crop; both crops are 224 x 224.
        box: ``None`` or ``(x, y, w, h)``, clipped to the image bounds.
    """
    if box is not None:
        x, y, w, h = box
        bottom = min(y + h, data.shape[0])
        right = min(x + w, data.shape[1])
        data = data[y:bottom, x:right]

    # Resize to target_wd * target_ht.
    data = mx.image.imresize(data, target_wd, target_ht)

    # Normalize in the same way as the pre-trained model.
    data = data.astype(np.float32) / 255.0
    mean = mx.nd.array([0.485, 0.456, 0.406])
    std = mx.nd.array([0.229, 0.224, 0.225])
    data = (data - mean) / std

    if not is_train:
        data, _ = mx.image.center_crop(data, (224, 224))
    else:
        if random.random() < 0.5:
            data = nd.flip(data, axis=1)
        data, _ = mx.image.random_crop(data, (224, 224))

    # Move the channel axis (axis 2) to the front.
    data = nd.transpose(data, (2, 0, 1))

    # A single-channel image is repeated 3 times along the channel axis.
    if data.shape[0] == 1:
        data = nd.tile(data, (3, 1, 1))
    return data.reshape((1,) + data.shape)
Example #24
Source File: target.py    From cascade_rcnn_gluon with Apache License 2.0 5 votes vote down vote up
def forward(self, anchors, cls_preds, gt_boxes, gt_ids):
        """Generate the class targets, box targets and box masks for training."""
        corner_anchors = self._center_to_corner(anchors.reshape((-1, 4)))
        ious = nd.transpose(nd.contrib.box_iou(corner_anchors, gt_boxes), (1, 0, 2))
        matches = self._matcher(ious)

        # Negative sampling additionally needs the predictions and the IoUs.
        if self._use_negative_sampling:
            sampler_args = (matches, cls_preds, ious)
        else:
            sampler_args = (matches,)
        samples = self._sampler(*sampler_args)

        cls_targets = self._cls_encoder(samples, matches, gt_ids)
        box_targets, box_masks = self._box_encoder(
            samples, matches, corner_anchors, gt_boxes)
        return cls_targets, box_targets, box_masks
Example #25
Source File: get_data.py    From EmotionClassifier with GNU General Public License v3.0 5 votes vote down vote up
def transform(data):
    """Convert ``data`` to an NDArray, reorder its axes, augment it and mask it."""
    arr = nd.array(data)  # some of the augmentations accept `float32`
    arr = nd.transpose(arr, (2, 0, 1))  # reorder the dimensions to (c, w, h)
    arr = image_augmentaion(arr)
    # Apply random_mask for random occlusion.
    return random_mask(arr, 32, n_chanel= 1,flag=1)
Example #26
Source File: cifar10_dist.py    From dynamic-training-with-apache-mxnet-on-aws with Apache License 2.0 5 votes vote down vote up
def transform(data, label):
    """Return ``data`` as float32 with axis 2 first and divided by 255, plus a float32 label."""
    pixels = nd.transpose(data.astype(np.float32), (2, 0, 1)) / 255
    return pixels, label.astype(np.float32)
Example #27
Source File: tensor_utils.py    From mxnet-centernet with MIT License 5 votes vote down vote up
def symbolic_transpose_and_gather_feat(F, feat, ind, K, batch, cat, attri):
    """Move axis 1 of ``feat`` to the end, flatten to (batch, -1, cat) and gather.

    The gathering itself is delegated to ``symbolic_gather_feat`` with
    ``ind``, ``K`` and ``attri``.
    """
    moved = F.transpose(feat, axes=(0, 2, 3, 1))
    flattened = F.reshape(moved, shape=(batch, -1, cat))
    return symbolic_gather_feat(F, flattened, ind, K, attri)
Example #28
Source File: data.py    From dynamic-training-with-apache-mxnet-on-aws with Apache License 2.0 5 votes vote down vote up
def transform(data, target_wd, target_ht, is_train, box):
    """Crop, resize and normalize an image NDArray; the result has a leading axis of 1.

    Args:
        data: image array indexed as [row, column, channel].
        target_wd, target_ht: size passed to ``mx.image.imresize``.
        is_train: truthy means random horizontal flip and random crop,
            otherwise a center crop (224 x 224 in both cases).
        box: ``None`` or ``(x, y, w, h)``; the crop is clipped to the image.
    """
    crop_size = (224, 224)

    if box is not None:
        x, y, w, h = box
        row_end = min(y + h, data.shape[0])
        col_end = min(x + w, data.shape[1])
        data = data[y:row_end, x:col_end]

    # Resize to target_wd * target_ht.
    resized = mx.image.imresize(data, target_wd, target_ht)

    # Normalize in the same way as the pre-trained model.
    unit_range = resized.astype(np.float32) / 255.0
    normalized = (unit_range - mx.nd.array([0.485, 0.456, 0.406])) \
        / mx.nd.array([0.229, 0.224, 0.225])

    if is_train:
        if random.random() < 0.5:
            normalized = nd.flip(normalized, axis=1)
        cropped, _ = mx.image.random_crop(normalized, crop_size)
    else:
        cropped, _ = mx.image.center_crop(normalized, crop_size)

    # Move the channel axis (axis 2) to the front.
    result = nd.transpose(cropped, (2, 0, 1))

    # A single-channel image is repeated 3 times along the channel axis.
    if result.shape[0] == 1:
        result = nd.tile(result, (3, 1, 1))
    return result.reshape((1,) + result.shape)
Example #29
Source File: TextEXAM_multi-label.py    From AAAI_2019_EXAM with GNU General Public License v2.0 5 votes vote down vote up
def batch_attention(encoder,decoder):
    """Weight ``encoder`` by its softmax-normalised similarity to ``decoder``.

    The similarity is the batched product of ``encoder`` with ``decoder``
    (last two axes swapped), normalised with a softmax over axis 1.
    """
    scores = nd.batch_dot(encoder, nd.transpose(decoder, axes=(0, 2, 1)))
    attention = nd.softmax(scores, axis=1)
    # NOTE(review): axes=(0, 1, 2) is the identity permutation, so this
    # transpose reorders nothing -- confirm whether another order was intended.
    encoder_same_order = nd.transpose(encoder, axes=(0, 1, 2))
    return nd.batch_dot(attention, encoder_same_order)
Example #30
Source File: target.py    From gluon-cv with Apache License 2.0 5 votes vote down vote up
def forward(self, anchors, cls_preds, gt_boxes, gt_ids):
        """Produce ``(cls_targets, box_targets, box_masks)`` for training."""
        anchors = self._center_to_corner(anchors.reshape((-1, 4)))
        pairwise_iou = nd.contrib.box_iou(anchors, gt_boxes)
        ious = nd.transpose(pairwise_iou, (1, 0, 2))
        matches = self._matcher(ious)

        # The sampler takes the predictions and IoUs only for negative sampling.
        samples = (
            self._sampler(matches, cls_preds, ious)
            if self._use_negative_sampling
            else self._sampler(matches)
        )

        cls_targets = self._cls_encoder(samples, matches, gt_ids)
        box_targets, box_masks = self._box_encoder(samples, matches, anchors, gt_boxes)
        return cls_targets, box_targets, box_masks