Python torch.squeeze() Examples

The following are 30 code examples of torch.squeeze(), each taken from an open-source project. The source file, project, and license are noted above each example. You may also want to check out the other available functions and classes of the torch module.
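Before the project examples, here is a minimal standalone sketch of the basic behaviour: called without a dim argument, torch.squeeze removes every dimension of size 1; called with a dim argument, it removes only that dimension, and only if its size is 1. torch.unsqueeze is the inverse operation.

import torch

x = torch.zeros(2, 1, 3, 1)                  # shape [2, 1, 3, 1]

# No dim: every size-1 dimension is removed.
print(torch.squeeze(x).shape)                # torch.Size([2, 3])

# Explicit dim: only that dimension is removed, and only if it has size 1.
print(torch.squeeze(x, 1).shape)             # torch.Size([2, 3, 1])
print(torch.squeeze(x, 2).shape)             # torch.Size([2, 1, 3, 1])  (dim 2 has size 3, left unchanged)

# torch.unsqueeze inserts a new size-1 dimension at the given position.
print(torch.unsqueeze(torch.zeros(2, 3), 1).shape)  # torch.Size([2, 1, 3])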
Example #1
Source File: BiLSTM_CNN.py    From pytorch_NER_BiLSTM_CNN_CRF with Apache License 2.0
def _char_forward(self, inputs):
        """
        Args:
            inputs: 3D tensor, [bs, max_len, max_len_char]

        Returns:
            char_conv_outputs: 3D tensor, [bs, max_len, output_dim]
        """
        max_len, max_len_char = inputs.size(1), inputs.size(2)
        inputs = inputs.view(-1, max_len * max_len_char)  # [bs, -1]
        input_embed = self.char_embedding(inputs)  # [bs, ml*ml_c, feature_dim]
        # input_embed = self.dropout_embed(input_embed)
        # [bs, 1, max_len, max_len_char, feature_dim]
        input_embed = input_embed.view(-1, 1, max_len, max_len_char, self.char_dim)
        # conv
        char_conv_outputs = []
        for char_encoder in self.char_encoders:
            conv_output = char_encoder(input_embed)
            pool_output = torch.squeeze(torch.max(conv_output, -2)[0], -1)
            char_conv_outputs.append(pool_output)
        char_conv_outputs = torch.cat(char_conv_outputs, dim=1)
        char_conv_outputs = char_conv_outputs.permute(0, 2, 1)

        return char_conv_outputs 
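For readers tracing the shapes in Example #1: assuming each char_encoder is a 3-D convolution whose kernel covers the full character feature dimension (an assumption for illustration; the actual layers are defined elsewhere in the repository), the max over the character axis followed by the squeeze behaves like this:

import torch
import torch.nn as nn

bs, max_len, max_len_char, char_dim, out_dim = 2, 5, 7, 16, 8

# Hypothetical encoder: kernel spans 3 characters and the whole feature dimension.
char_encoder = nn.Conv3d(1, out_dim, kernel_size=(1, 3, char_dim))

input_embed = torch.randn(bs, 1, max_len, max_len_char, char_dim)
conv_output = char_encoder(input_embed)                 # [2, 8, 5, 5, 1]

# Max over the character windows (dim -2), then squeeze the trailing size-1 dim.
pool_output = torch.squeeze(torch.max(conv_output, -2)[0], -1)
print(pool_output.shape)                                # torch.Size([2, 8, 5])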
Example #2
Source File: pointwise.py    From pykg2vec with MIT License
def forward(self, h, r, t):
        h_emb, r_emb, t_emb = self.embed(h, r, t)
        first_dimen = list(h_emb.shape)[0]
        
        stacked_h = torch.unsqueeze(h_emb, dim=1)
        stacked_r = torch.unsqueeze(r_emb, dim=1)
        stacked_t = torch.unsqueeze(t_emb, dim=1)

        stacked_hrt = torch.cat([stacked_h, stacked_r, stacked_t], dim=1)
        stacked_hrt = torch.unsqueeze(stacked_hrt, dim=1)  # [b, 1, 3, k]

        stacked_hrt = [conv_layer(stacked_hrt) for conv_layer in self.conv_list]
        stacked_hrt = torch.cat(stacked_hrt, dim=3)
        stacked_hrt = stacked_hrt.view(first_dimen, -1)
        preds = self.fc1(stacked_hrt)
        preds = torch.squeeze(preds, dim=-1)
        return preds 
Example #3
Source File: MessageFunction.py    From nmp_qc with MIT License
def m_ggnn(self, h_v, h_w, e_vw, opt={}):

        m = Variable(torch.zeros(h_w.size(0), h_w.size(1), self.args['out']).type_as(h_w.data))

        for w in range(h_w.size(1)):
            if torch.nonzero(e_vw[:, w, :].data).size():
                for i, el in enumerate(self.args['e_label']):
                    ind = (el == e_vw[:,w,:]).type_as(self.learn_args[0][i])

                    parameter_mat = self.learn_args[0][i][None, ...].expand(h_w.size(0), self.learn_args[0][i].size(0),
                                                                            self.learn_args[0][i].size(1))

                    m_w = torch.transpose(torch.bmm(torch.transpose(parameter_mat, 1, 2),
                                                                        torch.transpose(torch.unsqueeze(h_w[:, w, :], 1),
                                                                                        1, 2)), 1, 2)
                    m_w = torch.squeeze(m_w)
                    m[:,w,:] = ind.expand_as(m_w)*m_w
        return m 
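One caution about the torch.squeeze(m_w) call in Example #3: with no dim argument, squeeze drops every size-1 dimension, so when the batch size happens to be 1 the batch dimension disappears as well. A small illustration, independent of the nmp_qc code, of why an explicit dim is usually safer:

import torch

m_w = torch.randn(4, 1, 8)                  # [batch, 1, out]
print(torch.squeeze(m_w).shape)             # torch.Size([4, 8])  -- as intended

m_w_single = torch.randn(1, 1, 8)           # batch size 1
print(torch.squeeze(m_w_single).shape)      # torch.Size([8])     -- batch dim silently dropped
print(torch.squeeze(m_w_single, 1).shape)   # torch.Size([1, 8])  -- explicit dim keeps it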
Example #4
Source File: utils.py    From Semantic-Aware-Scene-Recognition with MIT License
def semanticIoU(pred, label):
    """
    Computes the mean Intersection over Union for all the classes between two mini-batch tensors of semantic
    segmentation
    :param pred: Tensor of predictions
    :param label: Tensor of ground-truth
    :return: Mean semantic intersection over Union for all the classes
    """
    imPred = np.asarray(torch.squeeze(pred))
    imLab = np.asarray(torch.squeeze(label))

    area_intersection = []
    area_union = []

    for i in range(imLab.shape[0]):
        intersection, union = intersectionAndUnion(imPred[i], imLab[i])
        area_intersection.append(intersection)
        area_union.append(union)

    IoU = 1.0 * np.sum(area_intersection, axis=0) / np.sum(np.spacing(1)+area_union, axis=0)

    return np.mean(IoU) 
Example #5
Source File: utils.py    From Semantic-Aware-Scene-Recognition with MIT License
def MeanPixelAccuracy(pred, label):
    """
    Function to compute the mean pixel accuracy for semantic segmentation between mini-batch tensors
    :param pred: Tensor of predictions
    :param label: Tensor of ground-truth
    :return: Mean pixel accuracy for all the mini-bath
    """
    # Convert tensors to numpy arrays
    imPred = np.asarray(torch.squeeze(pred))
    imLab = np.asarray(torch.squeeze(label))

    # Create empty numpy arrays
    pixel_accuracy = np.empty(imLab.shape[0])
    pixel_correct = np.empty(imLab.shape[0])
    pixel_labeled = np.empty(imLab.shape[0])

    # Compute pixel accuracy for each pair of images in the batch
    for i in range(imLab.shape[0]):
        pixel_accuracy[i], pixel_correct[i], pixel_labeled[i] = pixelAccuracy(imPred[i], imLab[i])

    # Compute the final accuracy for the batch
    acc = 100.0 * np.sum(pixel_correct) / (np.spacing(1) + np.sum(pixel_labeled))

    return acc 
Example #6
Source File: segbase.py    From SegmenTron with Apache License 2.0
def _pad_image(img, crop_size):
    b, c, h, w = img.shape
    assert(c == 3)
    padh = crop_size[0] - h if h < crop_size[0] else 0
    padw = crop_size[1] - w if w < crop_size[1] else 0
    if padh == 0 and padw == 0:
        return img
    img_pad = F.pad(img, (0, padh, 0, padw))

    # TODO clean this code
    # mean = cfg.DATASET.MEAN
    # std = cfg.DATASET.STD
    # pad_values = -np.array(mean) / np.array(std)
    # img_pad = torch.zeros((b, c, h + padh, w + padw)).to(img.device)
    # for i in range(c):
    #     # print(img[:, i, :, :].unsqueeze(1).shape)
    #     img_pad[:, i, :, :] = torch.squeeze(
    #         F.pad(img[:, i, :, :].unsqueeze(1), (0, padh, 0, padw),
    #               'constant', value=pad_values[i]), 1)
    # assert(img_pad.shape[2] >= crop_size[0] and img_pad.shape[3] >= crop_size[1])

    return img_pad 
Example #7
Source File: Losses.py    From BMSG-GAN with MIT License
def dis_loss(self, real_samps, fake_samps):
        # small assertion:
        assert real_samps.device == fake_samps.device, \
            "Real and Fake samples are not on the same device"

        # device for computations:
        device = fake_samps.device

        # predictions for real images and fake images separately :
        r_preds = self.dis(real_samps)
        f_preds = self.dis(fake_samps)

        # calculate the real loss:
        real_loss = self.criterion(
            th.squeeze(r_preds),
            th.ones(real_samps.shape[0]).to(device))

        # calculate the fake loss:
        fake_loss = self.criterion(
            th.squeeze(f_preds),
            th.zeros(fake_samps.shape[0]).to(device))

        # return final losses
        return (real_loss + fake_loss) / 2 
Example #8
Source File: actor_critic_loss_function.py    From rlgraph with Apache License 2.0
def _graph_fn_state_value_function_loss_per_item(self, state_values, advantages, time_percentage=None):
        """
        Computes the loss for V(s).

        Args:
            state_values (SingleDataOp): Baseline predictions V(s).
            advantages (SingleDataOp): Advantage values.

        Returns:
            SingleDataOp: Baseline loss per item.
        """
        v_targets = None
        if get_backend() == "tf":
            state_values = tf.squeeze(input=state_values, axis=-1)
            v_targets = advantages + state_values
            v_targets = tf.stop_gradient(input=v_targets)
        elif get_backend() == "pytorch":
            state_values = torch.squeeze(state_values, dim=-1)
            v_targets = advantages + state_values
            v_targets = v_targets.detach()

        vf_loss = (v_targets - state_values) ** 2
        return self.weight_vf.get(time_percentage) * vf_loss 
Example #9
Source File: ReadoutFunction.py    From nmp_qc with MIT License
def r_duvenaud(self, h):
        # layers
        aux = []
        for l in range(len(h)):
            param_sz = self.learn_args[l].size()
            parameter_mat = torch.t(self.learn_args[l])[None, ...].expand(h[l].size(0), param_sz[1],
                                                                                      param_sz[0])

            aux.append(torch.transpose(torch.bmm(parameter_mat, torch.transpose(h[l], 1, 2)), 1, 2))

            for j in range(0, aux[l].size(1)):
                # Mask whole 0 vectors
                aux[l][:, j, :] = nn.Softmax()(aux[l][:, j, :].clone())*(torch.sum(aux[l][:, j, :] != 0, 1) > 0).expand_as(aux[l][:, j, :]).type_as(aux[l])

        aux = torch.sum(torch.sum(torch.stack(aux, 3), 3), 1)
        return self.learn_modules[0](torch.squeeze(aux)) 
Example #10
Source File: Patient2Vec.py    From Patient2Vec with MIT License
def convolutional_layer(self, inputs):
        convolution_all = []
        conv_wts = []
        for i in range(self.seq_len):
            convolution_one_month = []
            for j in range(self.pad_size):
                convolution = self.conv(torch.unsqueeze(inputs[:, i, j], dim=1))
                convolution_one_month.append(convolution)
            convolution_one_month = torch.stack(convolution_one_month)
            convolution_one_month = torch.squeeze(convolution_one_month, dim=3)
            convolution_one_month = torch.transpose(convolution_one_month, 0, 1)
            convolution_one_month = torch.transpose(convolution_one_month, 1, 2)
            convolution_one_month = torch.squeeze(convolution_one_month, dim=1)
            convolution_one_month = self.func_tanh(convolution_one_month)
            convolution_one_month = torch.unsqueeze(convolution_one_month, dim=1)
            vec = torch.bmm(convolution_one_month, inputs[:, i])
            convolution_all.append(vec)
            conv_wts.append(convolution_one_month)
        convolution_all = torch.stack(convolution_all, dim=1)
        convolution_all = torch.squeeze(convolution_all, dim=2)
        conv_wts = torch.squeeze(torch.stack(conv_wts, dim=1), dim=2)
        return convolution_all, conv_wts 
Example #11
Source File: core.py    From spinningup with MIT License
def forward(self, obs, act):
        q = self.q(torch.cat([obs, act], dim=-1))
        return torch.squeeze(q, -1) # Critical to ensure q has right shape. 
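The "right shape" comment in Example #11 is worth unpacking. The critic typically ends in a Linear layer with a single output unit, so q comes out as [batch, 1]; comparing that against a [batch] target without squeezing makes broadcasting expand the difference to [batch, batch] and silently corrupts the loss. A small illustration, independent of spinningup:

import torch
import torch.nn.functional as F

q = torch.randn(64, 1)           # raw critic output: [batch, 1]
target = torch.randn(64)         # targets: [batch]

# Without the squeeze, the element-wise difference broadcasts to [64, 64].
print((q - target).shape)                        # torch.Size([64, 64])

# Squeezing the trailing dimension restores the intended per-sample loss.
loss = F.mse_loss(torch.squeeze(q, -1), target)
print(loss.shape)                                # torch.Size([])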
Example #12
Source File: model_utils.py    From Extremely-Fine-Grained-Entity-Typing with MIT License
def forward(self, key, memory, lengths):
    '''
    key (bsz, hidden)
    memory (bsz, seq_len, hidden)
    '''
    scores = torch.bmm(memory, self.attn(key).unsqueeze(2)).squeeze(2)
    attn_scores = self.normalize(scores, lengths)
    retrieved = torch.sum(attn_scores.unsqueeze(2) * memory, dim=1) # (bsz, hidden)
    return retrieved 
Example #13
Source File: core.py    From spinningup with MIT License
def forward(self, obs, act):
        q = self.q(torch.cat([obs, act], dim=-1))
        return torch.squeeze(q, -1) # Critical to ensure q has right shape. 
Example #14
Source File: core.py    From spinningup with MIT License
def forward(self, obs):
        return torch.squeeze(self.v_net(obs), -1) # Critical to ensure v has right shape. 
Example #15
Source File: core.py    From spinningup with MIT License
def forward(self, obs, act):
        q = self.q(torch.cat([obs, act], dim=-1))
        return torch.squeeze(q, -1) # Critical to ensure q has right shape. 
Example #16
Source File: core.py    From spinningup with MIT License
def forward(self, obs):
        return torch.squeeze(self.v_net(obs), -1) # Critical to ensure v has right shape. 
Example #17
Source File: memory.py    From LSH_Memory with Apache License 2.0
def update(self, query, y, y_hat, y_hat_indices):
        batch_size, dims = query.size()

        # 1) Untouched: Increment memory by 1
        self.age += 1

        # Divide batch by correctness
        result = torch.squeeze(torch.eq(y_hat, torch.unsqueeze(y.data, dim=1))).float()
        incorrect_examples = torch.squeeze(torch.nonzero(1-result))
        correct_examples = torch.squeeze(torch.nonzero(result))

        incorrect = len(incorrect_examples.size()) > 0
        correct = len(correct_examples.size()) > 0

        # 2) Correct: if V[n1] = v
        # Update Key k[n1] <- normalize(q + K[n1]), Reset Age A[n1] <- 0
        if correct:
            correct_indices = y_hat_indices[correct_examples]
            correct_keys = self.keys[correct_indices]
            correct_query = query.data[correct_examples]

            new_correct_keys = F.normalize(correct_keys + correct_query, dim=1)
            self.keys[correct_indices] = new_correct_keys
            self.age[correct_indices] = 0

        # 3) Incorrect: if V[n1] != v
        # Select item with oldest age, Add random offset - n' = argmax_i(A[i]) + r_i 
        # K[n'] <- q, V[n'] <- v, A[n'] <- 0
        if incorrect:
            incorrect_size = incorrect_examples.size()[0]
            incorrect_query = query.data[incorrect_examples]
            incorrect_values = y.data[incorrect_examples]

            age_with_noise = self.age + random_uniform((self.memory_size, 1), -self.age_noise, self.age_noise, cuda=True)
            topk_values, topk_indices = torch.topk(age_with_noise, incorrect_size, dim=0)
            oldest_indices = torch.squeeze(topk_indices)

            self.keys[oldest_indices] = incorrect_query
            self.values[oldest_indices] = incorrect_values
            self.age[oldest_indices] = 0 
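The torch.squeeze(torch.nonzero(...)) pattern in Example #17 has a shape subtlety worth knowing: nonzero on a 1-D input returns a 2-D [n, 1] index tensor, and squeezing it yields a 1-D tensor when n > 1, a 0-d scalar when n == 1, and a 1-D empty tensor when n == 0, which appears to be what the len(...size()) > 0 checks in the function are distinguishing. A small demonstration, independent of the repository:

import torch

mask = torch.tensor([0., 1., 1.])
print(torch.squeeze(torch.nonzero(mask)).shape)    # torch.Size([2])  -- usable as an index

single = torch.tensor([0., 1., 0.])
print(torch.squeeze(torch.nonzero(single)).shape)  # torch.Size([])   -- 0-d scalar, len(size()) == 0

none = torch.zeros(3)
print(torch.squeeze(torch.nonzero(none)).shape)    # torch.Size([0])  -- empty but still 1-D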
Example #18
Source File: pairwise.py    From pykg2vec with MIT License
def embed(self, h, r, t):
        """Function to get the embedding value.

           Args:
               h (Tensor): Head entities ids.
               r (Tensor): Relation ids of the triple.
               t (Tensor): Tail entity ids of the triple.

            Returns:
                Tensors: Returns head, relation and tail embedding Tensors.
        """
        h_e = self.ent_embeddings(h)
        r_e = self.rel_embeddings(r)
        t_e = self.ent_embeddings(t)

        h_e = F.normalize(h_e, p=2, dim=-1)
        r_e = F.normalize(r_e, p=2, dim=-1)
        t_e = F.normalize(t_e, p=2, dim=-1)

        h_e = torch.unsqueeze(h_e, 1)
        t_e = torch.unsqueeze(t_e, 1)
        # [b, 1, k]

        matrix = self.rel_matrix(r)
        # [b, k, d]

        transform_h_e = self.transform(h_e, matrix)
        transform_t_e = self.transform(t_e, matrix)
        # [b, 1, d] = [b, 1, k] * [b, k, d]

        h_e = torch.squeeze(transform_h_e, axis=1)
        t_e = torch.squeeze(transform_t_e, axis=1)
        # [b, d]
        return h_e, r_e, t_e 
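The bracketed shape comments in Example #18 describe a batched matrix product followed by a squeeze of the singleton middle dimension. A standalone sketch of that shape flow, assuming transform is essentially a batched matmul (the actual pykg2vec implementation may differ):

import torch

b, k, d = 3, 10, 6
h_e = torch.randn(b, 1, k)                 # [b, 1, k]
matrix = torch.randn(b, k, d)              # [b, k, d]

transform_h_e = torch.bmm(h_e, matrix)     # [b, 1, d]
h_out = torch.squeeze(transform_h_e, 1)    # [b, d]
print(transform_h_e.shape, h_out.shape)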
Example #19
Source File: ppo_loss_function.py    From rlgraph with Apache License 2.0
def _graph_fn_value_function_loss_per_item(self, state_values, prev_state_values, advantages):
        """
        Computes the loss for V(s).

        Args:
            state_values (SingleDataOp): State value predictions V(s).
            prev_state_values (SingleDataOp): Previous state value predictions V(s) (before the update).
            advantages (SingleDataOp): GAE (advantage) values.

        Returns:
            SingleDataOp: Value function loss per item.
        """
        if get_backend() == "tf":
            state_values = tf.squeeze(input=state_values, axis=-1)
            prev_state_values = tf.squeeze(input=prev_state_values, axis=-1)
            v_targets = advantages + prev_state_values
            v_targets = tf.stop_gradient(input=v_targets)
            vf_loss = (state_values - v_targets) ** 2
            if self.value_function_clipping:
                vf_clipped = prev_state_values + tf.clip_by_value(
                    state_values - prev_state_values, -self.value_function_clipping, self.value_function_clipping
                )
                clipped_loss = (vf_clipped - v_targets) ** 2
                return tf.maximum(vf_loss, clipped_loss)
            else:
                return vf_loss

        elif get_backend() == "pytorch":
            state_values = torch.squeeze(state_values, dim=-1)
            prev_state_values = torch.squeeze(input=prev_state_values, dim=-1)
            v_targets = advantages + prev_state_values
            v_targets = v_targets.detach()
            vf_loss = (state_values - v_targets) ** 2
            if self.value_function_clipping:
                vf_clipped = prev_state_values + torch.clamp(
                    state_values - prev_state_values, -self.value_function_clipping, self.value_function_clipping
                )
                clipped_loss = (vf_clipped - v_targets) ** 2
                return torch.max(vf_loss, clipped_loss)
            else:
                return vf_loss 
Example #20
Source File: StereoNet.py    From DenseMatchingBenchmark with MIT License
def forward(self, raw_cost):
        # default down-sample to 1/8 resolution, it also can be 1/16
        # raw_cost: (BatchSize, Channels, MaxDisparity/8, Height/8, Width/8)
        for i in range(self.num):
            raw_cost = self.classify[i](raw_cost)

        # cost: (BatchSize, 1, MaxDisparity/8, Height/8, Width/8)
        cost = self.lastconv(raw_cost)

        # (BatchSize, MaxDisparity/8, Height/8, Width/8)
        cost = torch.squeeze(cost, 1)


        return [cost] 
Example #21
Source File: model_utils.py    From Extremely-Fine-Grained-Entity-Typing with MIT License
def forward(self, span_chars):
    char_embed = self.char_W(span_chars).transpose(1, 2)  # [batch_size, char_embedding, max_char_seq]
    conv_output = [self.conv1d(char_embed)]  # list of [batch_size, filter_dim, conv_seq_len]
    conv_output = [F.relu(c) for c in conv_output]  # [batch_size, filter_dim, conv_seq_len]
    cnn_rep = [F.max_pool1d(i, i.size(2)) for i in conv_output]  # [batch_size, filter_dim, 1]
    cnn_output = torch.squeeze(torch.cat(cnn_rep, 1), 2)  # [batch_size, num_convs * filter_dim]
    return cnn_output 
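A shape-only sketch of the pooling in Example #21 (values are arbitrary; only the sizes matter): max-pooling over the full character axis leaves a trailing size-1 dimension, which the final squeeze removes after the filter outputs are concatenated.

import torch
import torch.nn.functional as F

conv_out = torch.randn(4, 50, 25)                       # [batch, filter_dim, seq_len]
pooled = F.max_pool1d(conv_out, conv_out.size(2))       # [4, 50, 1]
cnn_output = torch.squeeze(torch.cat([pooled], 1), 2)   # [4, 50]
print(pooled.shape, cnn_output.shape)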
Example #22
Source File: selector.py    From DISTRE with Apache License 2.0
def forward(self, x, scopes=None, label=None):
        scopes = scopes or [(0, x.size(0))]

        if self.training:
            attention_logit = self._attention_train_logit(x, label)

            tower_repre = []
            for start, end in scopes:
                sen_matrix = x[start: end]
                attention_score = F.softmax(torch.transpose(attention_logit[start: end], 0, 1), 1)
                final_repre = torch.squeeze(torch.matmul(attention_score, sen_matrix))
                tower_repre.append(final_repre)
            stack_repre = torch.stack(tower_repre)
            stack_repre = self.dropout(stack_repre)
            logits = self.get_logits(stack_repre)
            return logits

        else:
            attention_logit = self._attention_test_logit(x)

            tower_output = []
            for start, end in scopes:
                sen_matrix = x[start: end]
                attention_score = F.softmax(torch.transpose(attention_logit[start: end], 0, 1), 1)
                final_repre = torch.matmul(attention_score, sen_matrix)
                logits = self.get_logits(final_repre)
                tower_output.append(torch.diag(F.softmax(logits, 1)))
            stack_output = torch.stack(tower_output)
            return stack_output 
Example #23
Source File: basic.py    From DSMnet with Apache License 2.0
def forward(self, left, right):

        refimg_fea     = self.feature_extraction(left)
        targetimg_fea  = self.feature_extraction(right)
 
        #matching
        # maxdisp // 4 must be an integer to be used as a tensor size and a range() bound (Python 3).
        cost = Variable(torch.FloatTensor(refimg_fea.size()[0], refimg_fea.size()[1]*2, self.maxdisp//4,  refimg_fea.size()[2],  refimg_fea.size()[3]).zero_(), volatile= not self.training).cuda()

        for i in range(self.maxdisp//4):
            if i > 0 :
             cost[:, :refimg_fea.size()[1], i, :,i:]   = refimg_fea[:,:,:,i:]
             cost[:, refimg_fea.size()[1]:, i, :,i:] = targetimg_fea[:,:,:,:-i]
            else:
             cost[:, :refimg_fea.size()[1], i, :,:]   = refimg_fea
             cost[:, refimg_fea.size()[1]:, i, :,:]   = targetimg_fea
        cost = cost.contiguous()

        cost0 = self.dres0(cost)
        cost0 = self.dres1(cost0) + cost0
        cost0 = self.dres2(cost0) + cost0 
        cost0 = self.dres3(cost0) + cost0 
        cost0 = self.dres4(cost0) + cost0

        cost = self.classify(cost0)
        cost = F.upsample(cost, [self.maxdisp,left.size()[2],left.size()[3]], mode='trilinear')
        cost = torch.squeeze(cost,1)
        pred = F.softmax(cost)
        pred = disparityregression(self.maxdisp)(pred)

        return pred 
Example #24
Source File: Patient2Vec.py    From Patient2Vec with MIT License
def add_beta_attention(self, states, batch_size):
        # beta attention
        att_wts = []
        for i in range(self.seq_len):
            m1 = self.conv2(torch.unsqueeze(states[:, i], dim=1))
            att_wts.append(torch.squeeze(m1, dim=2))
        att_wts = torch.stack(att_wts, dim=2)
        att_beta = []
        for i in range(self.n_filters):
            a0 = self.func_softmax(att_wts[:, i])
            att_beta.append(a0)
        att_beta = torch.stack(att_beta, dim=1)
        context = torch.bmm(att_beta, states)
        context = context.view(batch_size, -1)
        return att_beta, context 
Example #25
Source File: pytorch_ops.py    From deep_architect with MIT License
def global_pool2d():

    def compile_fn(di, dh):
        (_, _, height, width) = di['in'].size()

        def fn(di):
            x = F.avg_pool2d(di['in'], (height, width))
            x = torch.squeeze(x, 2)
            return {'out': torch.squeeze(x, 2)}

        return fn, []

    return siso_pytorch_module('GlobalAveragePool', compile_fn, {}) 
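The two consecutive squeezes in Example #25 turn the [batch, channels, 1, 1] output of the average pooling into [batch, channels]. An equivalent shape-only sketch, not part of deep_architect, using adaptive average pooling:

import torch
import torch.nn.functional as F

x = torch.randn(8, 32, 7, 7)                       # [batch, channels, height, width]
pooled = F.adaptive_avg_pool2d(x, 1)               # [8, 32, 1, 1]
out = torch.squeeze(torch.squeeze(pooled, 3), 2)   # [8, 32]
print(out.shape)

# torch.flatten(pooled, 1) yields the same [8, 32] result in one call.
print(torch.flatten(pooled, 1).shape)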
Example #26
Source File: Losses.py    From BMSG-GAN with MIT License
def gen_loss(self, _, fake_samps):
        preds, _, _ = self.dis(fake_samps)
        return self.criterion(th.squeeze(preds),
                              th.ones(fake_samps.shape[0]).to(fake_samps.device)) 
Example #27
Source File: nonlocal_layer.py    From Attention-Gated-Networks with MIT License
def _concatenation_proper_down(self, x):
        batch_size = x.size(0)

        # g=>(b, c, t, h, w)->(b, 0.5c, thw/s**2)
        g_x = self.g(x).view(batch_size, self.inter_channels, -1)

        # theta=>(b, c, t, h, w)[->(b, 0.5c, t, h, w)]->(b, 0.5c, thw)
        # phi  =>(b, c, t, h, w)[->(b, 0.5c, t, h, w)]->(b, 0.5c, thw/s**2)
        theta_x = self.theta(x)
        downsampled_size = theta_x.size()
        theta_x = theta_x.view(batch_size, self.inter_channels, -1)
        phi_x = self.phi(x).view(batch_size, self.inter_channels, -1)

        # theta => (b, 0.5c, thw) -> (expand) (b, 0.5c, thw/s**2, thw)
        # phi => (b, 0.5c, thw/s**2) ->  (expand) (b, 0.5c, thw/s**2, thw)
        # f=> RELU[(b, 0.5c, thw/s**2, thw) + (b, 0.5c, thw/s**2, thw)] = (b, 0.5c, thw/s**2, thw)
        f = theta_x.unsqueeze(dim=2).repeat(1,1,phi_x.size(2),1) + \
            phi_x.unsqueeze(dim=3).repeat(1,1,1,theta_x.size(2))
        f = F.relu(f, inplace=True)

        # psi -> W_psi^t * f -> (b, 0.5c, thw/s**2, thw) -> (b, 1, thw/s**2, thw) -> (b, thw/s**2, thw)
        f = torch.squeeze(self.psi(f), dim=1)

        # Normalise the relations
        f_div_c = F.softmax(f, dim=1)

        # g(x_j) * f(x_j, x_i)
        # (b, 0.5c, thw/s**2) * (b, thw/s**2, thw) -> (b, 0.5c, thw)
        y = torch.matmul(g_x, f_div_c)
        y = y.contiguous().view(batch_size, self.inter_channels, *downsampled_size[2:])

        # upsample the final featuremaps # (b,0.5c,t/s1,h/s2,w/s3)
        y = F.upsample(y, size=x.size()[2:], mode='trilinear')

        # attention block output
        W_y = self.W(y)
        z = W_y + x

        return z 
Example #28
Source File: nonlocal_layer.py    From Attention-Gated-Networks with MIT License
def _concatenation_proper(self, x):
        batch_size = x.size(0)

        # g=>(b, c, t, h, w)->(b, 0.5c, thw/s**2)
        g_x = self.g(x).view(batch_size, self.inter_channels, -1)

        # theta=>(b, c, t, h, w)[->(b, 0.5c, t, h, w)]->(b, 0.5c, thw)
        # phi  =>(b, c, t, h, w)[->(b, 0.5c, t, h, w)]->(b, 0.5c, thw/s**2)
        theta_x = self.theta(x).view(batch_size, self.inter_channels, -1)
        phi_x = self.phi(x).view(batch_size, self.inter_channels, -1)

        # theta => (b, 0.5c, thw) -> (expand) (b, 0.5c, thw/s**2, thw)
        # phi => (b, 0.5c, thw/s**2) ->  (expand) (b, 0.5c, thw/s**2, thw)
        # f=> RELU[(b, 0.5c, thw/s**2, thw) + (b, 0.5c, thw/s**2, thw)] = (b, 0.5c, thw/s**2, thw)
        f = theta_x.unsqueeze(dim=2).repeat(1,1,phi_x.size(2),1) + \
            phi_x.unsqueeze(dim=3).repeat(1,1,1,theta_x.size(2))
        f = F.relu(f, inplace=True)

        # psi -> W_psi^t * f -> (b, 1, thw/s**2, thw) -> (b, thw/s**2, thw)
        f = torch.squeeze(self.psi(f), dim=1)

        # Normalise the relations
        f_div_c = F.softmax(f, dim=1)

        # g(x_j) * f(x_j, x_i)
        # (b, 0.5c, thw/s**2) * (b, thw/s**2, thw) -> (b, 0.5c, thw)
        y = torch.matmul(g_x, f_div_c)
        y = y.contiguous().view(batch_size, self.inter_channels, *x.size()[2:])
        W_y = self.W(y)
        z = W_y + x

        return z 
Example #29
Source File: utils.py    From DPC with MIT License
def update(self, pred, tar):
        pred, tar = pred.cpu().numpy(), tar.cpu().numpy()
        pred = np.squeeze(pred)
        tar = np.squeeze(tar)
        for p,t in zip(pred.flat, tar.flat):
            self.mat[p][t] += 1 
Example #30
Source File: utils.py    From DPC with MIT License
def update(self, pred, tar):
        pred = torch.squeeze(pred)
        tar = torch.squeeze(tar)
        for i, j in zip(pred, tar):
            i = int(i)
            j = int(j)
            if j not in self.dict.keys():
                self.dict[j] = {'count':0,'correct':0}
            self.dict[j]['count'] += 1
            if i == j:
                self.dict[j]['correct'] += 1