Python torch.remainder() Examples

The following are 30 code examples of torch.remainder(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module torch, or try the search function.
Example #1
Source File: encoder.py    From laserembeddings with BSD 3-Clause "New" or "Revised" License 6 votes vote down vote up
def convert_padding_direction(src_tokens,
                              padding_idx,
                              right_to_left=False,
                              left_to_right=False):
    """Move the padding of every row of a 2-D token batch to the other side.

    Exactly one of ``right_to_left`` / ``left_to_right`` must be truthy.
    Each row is rotated by its own number of ``padding_idx`` entries, so the
    non-pad tokens keep their relative order. ``src_tokens`` itself is
    returned when it contains no padding, or when no row has padding on the
    side the padding would be moved away from.
    """
    assert right_to_left ^ left_to_right
    is_pad = src_tokens.eq(padding_idx)
    if not is_pad.any():
        # nothing to move
        return src_tokens
    if left_to_right and not is_pad[:, 0].any():
        # no row starts with padding: already right padded
        return src_tokens
    if right_to_left and not is_pad[:, -1].any():
        # no row ends with padding: already left padded
        return src_tokens

    seq_len = src_tokens.size(1)
    positions = buffered_arange(seq_len).type_as(src_tokens).expand_as(src_tokens)
    pad_counts = is_pad.long().sum(dim=1, keepdim=True)
    # Rotate backwards to push padding left, forwards to push it right.
    shift = -pad_counts if right_to_left else pad_counts
    gather_index = torch.remainder(positions + shift, seq_len)
    return src_tokens.gather(1, gather_index)
Example #2
Source File: utils.py    From helo_word with Apache License 2.0 6 votes vote down vote up
def convert_padding_direction(src_tokens, padding_idx, right_to_left=False, left_to_right=False):
    """Flip which side of a 2-D token batch carries the padding.

    Exactly one direction flag must be truthy. Rows are rotated by their own
    count of ``padding_idx`` entries. The input tensor is returned unchanged
    if it has no padding or if no row has padding on the source side.
    """
    assert right_to_left ^ left_to_right
    is_pad = src_tokens.eq(padding_idx)
    if not is_pad.any():
        # nothing to move
        return src_tokens
    if left_to_right and not is_pad[:, 0].any():
        # no row starts with padding: already right padded
        return src_tokens
    if right_to_left and not is_pad[:, -1].any():
        # no row ends with padding: already left padded
        return src_tokens

    seq_len = src_tokens.size(1)
    positions = buffered_arange(seq_len).type_as(src_tokens).expand_as(src_tokens)
    pad_counts = is_pad.long().sum(dim=1, keepdim=True)
    # Rotate backwards to push padding left, forwards to push it right.
    shift = -pad_counts if right_to_left else pad_counts
    gather_index = torch.remainder(positions + shift, seq_len)
    return src_tokens.gather(1, gather_index)
Example #3
Source File: utils.py    From attn2d with MIT License 6 votes vote down vote up
def convert_padding_direction(
    src_tokens, padding_idx, right_to_left: bool = False, left_to_right: bool = False
):
    """Move the padding of every row of a 2-D token batch to the other side.

    Exactly one of ``right_to_left`` / ``left_to_right`` must be set. Each row
    is rotated by its own number of ``padding_idx`` entries. The input tensor
    is returned unchanged when it has no padding or when no row has padding
    on the side it would be moved away from.
    """
    assert right_to_left ^ left_to_right
    is_pad = src_tokens.eq(padding_idx)
    if not is_pad.any():
        # nothing to move
        return src_tokens
    if left_to_right and not is_pad[:, 0].any():
        # no row starts with padding: already right padded
        return src_tokens
    if right_to_left and not is_pad[:, -1].any():
        # no row ends with padding: already left padded
        return src_tokens

    seq_len = src_tokens.size(1)
    # Long-typed position buffer, filled only when there is at least one column.
    positions = torch.empty(0).long()
    if seq_len > 0:
        torch.arange(seq_len, out=positions)
    positions = positions.type_as(src_tokens).expand_as(src_tokens)

    pad_counts = is_pad.long().sum(dim=1, keepdim=True)
    # Rotate backwards to push padding left, forwards to push it right.
    shift = -pad_counts if right_to_left else pad_counts
    gather_index = torch.remainder(positions + shift, seq_len)
    return src_tokens.gather(1, gather_index)
Example #4
Source File: utils.py    From XSum with MIT License 6 votes vote down vote up
def convert_padding_direction(
    src_tokens,
    src_lengths,
    padding_idx,
    right_to_left=False,
    left_to_right=False,
):
    """Move the padding of every row of a 2-D token batch to the other side.

    Args:
        src_tokens: 2-D tensor of token ids (rows are sequences).
        src_lengths: unused; kept for call compatibility.
        padding_idx: token id that marks padding.
        right_to_left: move right-side padding to the left.
        left_to_right: move left-side padding to the right.

    Returns:
        ``src_tokens`` itself when there is nothing to do (no padding, or no
        row has padding on the side it would be moved away from); otherwise
        a new tensor in which each row is rotated by its own pad count.

    Exactly one of ``right_to_left`` / ``left_to_right`` must be truthy.
    """
    assert right_to_left ^ left_to_right
    pad_mask = src_tokens.eq(padding_idx)
    # ``any()`` is also well-defined for an empty batch, unlike ``max()``.
    if not pad_mask.any():
        # no padding, return early
        return src_tokens
    # Without these guards an already-converted batch would be rotated again,
    # moving real tokens into the padding positions.
    if left_to_right and not pad_mask[:, 0].any():
        # already right padded
        return src_tokens
    if right_to_left and not pad_mask[:, -1].any():
        # already left padded
        return src_tokens
    max_len = src_tokens.size(1)
    positions = buffered_arange(max_len).type_as(src_tokens).expand_as(src_tokens)
    num_pads = pad_mask.long().sum(dim=1, keepdim=True)
    if right_to_left:
        index = torch.remainder(positions - num_pads, max_len)
    else:
        index = torch.remainder(positions + num_pads, max_len)
    return src_tokens.gather(1, index)
Example #5
Source File: utils.py    From XSum with MIT License 6 votes vote down vote up
def convert_padding_direction(
    src_tokens,
    src_lengths,
    padding_idx,
    right_to_left=False,
    left_to_right=False,
):
    """Flip which side of a 2-D token batch carries the padding.

    Args:
        src_tokens: 2-D tensor of token ids (rows are sequences).
        src_lengths: unused; kept for call compatibility.
        padding_idx: token id that marks padding.
        right_to_left: move right-side padding to the left.
        left_to_right: move left-side padding to the right.

    Returns:
        ``src_tokens`` itself when there is nothing to do (no padding, or no
        row has padding on the side it would be moved away from); otherwise
        a new tensor in which each row is rotated by its own pad count.

    Exactly one of ``right_to_left`` / ``left_to_right`` must be truthy.
    """
    assert right_to_left ^ left_to_right
    pad_mask = src_tokens.eq(padding_idx)
    # ``any()`` is also well-defined for an empty batch, unlike ``max()``.
    if not pad_mask.any():
        # no padding, return early
        return src_tokens
    # Without these guards an already-converted batch would be rotated again,
    # moving real tokens into the padding positions.
    if left_to_right and not pad_mask[:, 0].any():
        # already right padded
        return src_tokens
    if right_to_left and not pad_mask[:, -1].any():
        # already left padded
        return src_tokens
    max_len = src_tokens.size(1)
    positions = buffered_arange(max_len).type_as(src_tokens).expand_as(src_tokens)
    num_pads = pad_mask.long().sum(dim=1, keepdim=True)
    if right_to_left:
        index = torch.remainder(positions - num_pads, max_len)
    else:
        index = torch.remainder(positions + num_pads, max_len)
    return src_tokens.gather(1, index)
Example #6
Source File: utils.py    From inversecooking with MIT License 6 votes vote down vote up
def convert_padding_direction(src_tokens, padding_idx, right_to_left=False, left_to_right=False):
    """Switch a 2-D token batch between left padding and right padding.

    Exactly one direction flag must be truthy. Every row is rotated by the
    number of ``padding_idx`` entries it contains. The input tensor is handed
    back as-is when it holds no padding or when no row has padding on the
    source side.
    """
    assert right_to_left ^ left_to_right
    is_pad = src_tokens.eq(padding_idx)
    if not is_pad.any():
        # nothing to move
        return src_tokens
    if left_to_right and not is_pad[:, 0].any():
        # no row starts with padding: already right padded
        return src_tokens
    if right_to_left and not is_pad[:, -1].any():
        # no row ends with padding: already left padded
        return src_tokens

    seq_len = src_tokens.size(1)
    positions = buffered_arange(seq_len).type_as(src_tokens).expand_as(src_tokens)
    pad_counts = is_pad.long().sum(dim=1, keepdim=True)
    # Rotate backwards to push padding left, forwards to push it right.
    shift = -pad_counts if right_to_left else pad_counts
    gather_index = torch.remainder(positions + shift, seq_len)
    return src_tokens.gather(1, gather_index)
Example #7
Source File: qr_embedding_bag.py    From dlrm with MIT License 6 votes vote down vote up
def forward(self, input, offsets=None, per_sample_weights=None):
        """Look up quotient and remainder embeddings and combine them.

        Each index in ``input`` is split into ``index // num_collisions`` and
        ``index % num_collisions``; the two parts are looked up in
        ``weight_q`` and ``weight_r`` with ``F.embedding_bag`` and merged
        according to ``self.operation`` ('concat', 'add' or 'mult').

        Args:
            input: tensor of category indices, forwarded to ``F.embedding_bag``.
            offsets: optional bag offsets, forwarded to ``F.embedding_bag``.
            per_sample_weights: optional weights, forwarded to ``F.embedding_bag``.

        Returns:
            The combined embedding tensor.

        Raises:
            ValueError: if ``self.operation`` is not one of the three modes.
        """
        # Use explicit floor division: ``input / n`` is true division on
        # current PyTorch, so the quotient would round-trip through floating
        # point and come out wrong for indices above 2**24.
        input_q = torch.div(input, self.num_collisions, rounding_mode='floor').long()
        input_r = torch.remainder(input, self.num_collisions).long()

        embed_q = F.embedding_bag(input_q, self.weight_q, offsets, self.max_norm,
                                  self.norm_type, self.scale_grad_by_freq, self.mode,
                                  self.sparse, per_sample_weights)
        embed_r = F.embedding_bag(input_r, self.weight_r, offsets, self.max_norm,
                                  self.norm_type, self.scale_grad_by_freq, self.mode,
                                  self.sparse, per_sample_weights)

        if self.operation == 'concat':
            embed = torch.cat((embed_q, embed_r), dim=1)
        elif self.operation == 'add':
            embed = embed_q + embed_r
        elif self.operation == 'mult':
            embed = embed_q * embed_r
        else:
            # Previously fell through to an unbound local.
            raise ValueError('Not valid operation!')

        return embed
Example #8
Source File: qr_embedding_bag.py    From optimized-models with Apache License 2.0 6 votes vote down vote up
def forward(self, input, offsets=None, per_sample_weights=None):
        """Combine the quotient-table and remainder-table embeddings of ``input``.

        Each index in ``input`` is split into ``index // num_collisions`` and
        ``index % num_collisions``; the two parts are looked up in
        ``weight_q`` and ``weight_r`` with ``F.embedding_bag`` and merged
        according to ``self.operation`` ('concat', 'add' or 'mult').

        Args:
            input: tensor of category indices, forwarded to ``F.embedding_bag``.
            offsets: optional bag offsets, forwarded to ``F.embedding_bag``.
            per_sample_weights: optional weights, forwarded to ``F.embedding_bag``.

        Returns:
            The combined embedding tensor.

        Raises:
            ValueError: if ``self.operation`` is not one of the three modes.
        """
        # Use explicit floor division: ``input / n`` is true division on
        # current PyTorch, so the quotient would round-trip through floating
        # point and come out wrong for indices above 2**24.
        input_q = torch.div(input, self.num_collisions, rounding_mode='floor').long()
        input_r = torch.remainder(input, self.num_collisions).long()

        embed_q = F.embedding_bag(input_q, self.weight_q, offsets, self.max_norm,
                                  self.norm_type, self.scale_grad_by_freq, self.mode,
                                  self.sparse, per_sample_weights)
        embed_r = F.embedding_bag(input_r, self.weight_r, offsets, self.max_norm,
                                  self.norm_type, self.scale_grad_by_freq, self.mode,
                                  self.sparse, per_sample_weights)

        if self.operation == 'concat':
            embed = torch.cat((embed_q, embed_r), dim=1)
        elif self.operation == 'add':
            embed = embed_q + embed_r
        elif self.operation == 'mult':
            embed = embed_q * embed_r
        else:
            # Previously fell through to an unbound local.
            raise ValueError('Not valid operation!')

        return embed
Example #9
Source File: buffer.py    From reversible-rnn with MIT License 6 votes vote down vote up
def forward(ctx, h, z, buf, mask, slice_dim=0):
        """Masked update of ``h`` by ``z / 2**forget_radix``, spilling bits into ``buf``.

        NOTE(review): ``forget_radix``, ``sign_bit`` and ``negative_bits`` are
        not defined in this block; presumably module-level constants -- confirm.

        Args:
            ctx: context object; ``h``, ``z`` and ``mask`` are saved on it.
            h: tensor that is shifted, multiplied and returned. Bitwise
                and/or/shift are applied to it, so presumably an integer
                tensor -- confirm.
            z: multiplier applied to ``h`` where ``mask`` selects.
            buf: object exposing ``overflow_mul``, ``add``, ``mod`` and
                ``div``; pass ``None`` to skip all buffer bookkeeping.
            mask: multiplicative selector used as ``mask*new + (1-mask)*old``;
                presumably 0/1 valued -- confirm.
            slice_dim: first column (dim 1) of ``h``/``z``/``mask`` that takes
                part in the buffer operations.

        Returns:
            The updated ``h``.
        """
        ctx.save_for_backward(h, z, mask)

        # Shift buffer left, enlarging if needed, then store modulus of h in buffer.
        if buf is not None:
            # Low bits of h (h mod 2**forget_radix) that the right shift below drops.
            h_mod = torch.remainder(h[:, slice_dim:], 2**forget_radix)
            buf.overflow_mul(2**forget_radix, mask[:, slice_dim:])
            buf.add(h_mod, mask[:, slice_dim:])

        # Multiply h by z/(2**forget_radix).
        # Have to do extra work in case h is negative.
        sign_bits = h.__and__(sign_bit)
        # NOTE(review): looks like this builds a pattern of high one-bits for
        # negative entries so the shift behaves arithmetically -- confirm
        # against the definitions of sign_bit / negative_bits.
        one_bits = negative_bits * -1 * torch.clamp(sign_bits, min=-1)
        # Shift amount is forget_radix where mask is 1 and 0 elsewhere.
        h = h.__rshift__(forget_radix * mask)
        h = h.__or__(one_bits * mask)
        # Apply z only where mask selects; other entries keep their value.
        h = mask*h*z + (1-mask)*h

        # Store modulus of buffer in h then divide buffer by z.
        if buf is not None:
            buf_mod = buf.mod(z[:,slice_dim:])
            h[:,slice_dim:] = h[:,slice_dim:] + buf_mod*mask[:, slice_dim:]
            buf.div(z[:,slice_dim:], mask[:, slice_dim:])

        return h
Example #10
Source File: utils.py    From crosentgec with GNU General Public License v3.0 6 votes vote down vote up
def convert_padding_direction(src_tokens, padding_idx, right_to_left=False, left_to_right=False):
    """Re-pad a 2-D token batch on the opposite side.

    Exactly one direction flag must be truthy. Each row is rotated by its
    own number of ``padding_idx`` entries, so token order is preserved. The
    input tensor is returned unchanged when it has no padding or when no row
    has padding on the source side.
    """
    assert right_to_left ^ left_to_right
    is_pad = src_tokens.eq(padding_idx)
    if not is_pad.any():
        # nothing to move
        return src_tokens
    if left_to_right and not is_pad[:, 0].any():
        # no row starts with padding: already right padded
        return src_tokens
    if right_to_left and not is_pad[:, -1].any():
        # no row ends with padding: already left padded
        return src_tokens

    seq_len = src_tokens.size(1)
    positions = buffered_arange(seq_len).type_as(src_tokens).expand_as(src_tokens)
    pad_counts = is_pad.long().sum(dim=1, keepdim=True)
    # Rotate backwards to push padding left, forwards to push it right.
    shift = -pad_counts if right_to_left else pad_counts
    gather_index = torch.remainder(positions + shift, seq_len)
    return src_tokens.gather(1, gather_index)
Example #11
Source File: utils.py    From fairseq with MIT License 6 votes vote down vote up
def convert_padding_direction(
    src_tokens, padding_idx, right_to_left: bool = False, left_to_right: bool = False
):
    """Flip which side of a 2-D token batch carries the padding.

    Exactly one of ``right_to_left`` / ``left_to_right`` must be set. Each row
    is rotated by its own number of ``padding_idx`` entries. The input tensor
    is returned unchanged when it has no padding or when no row has padding
    on the side it would be moved away from.
    """
    assert right_to_left ^ left_to_right
    is_pad = src_tokens.eq(padding_idx)
    if not is_pad.any():
        # nothing to move
        return src_tokens
    if left_to_right and not is_pad[:, 0].any():
        # no row starts with padding: already right padded
        return src_tokens
    if right_to_left and not is_pad[:, -1].any():
        # no row ends with padding: already left padded
        return src_tokens

    seq_len = src_tokens.size(1)
    # Long-typed position buffer, filled only when there is at least one column.
    positions = torch.empty(0).long()
    if seq_len > 0:
        torch.arange(seq_len, out=positions)
    positions = positions.type_as(src_tokens).expand_as(src_tokens)

    pad_counts = is_pad.long().sum(dim=1, keepdim=True)
    # Rotate backwards to push padding left, forwards to push it right.
    shift = -pad_counts if right_to_left else pad_counts
    gather_index = torch.remainder(positions + shift, seq_len)
    return src_tokens.gather(1, gather_index)
Example #12
Source File: utils.py    From training_results_v0.5 with Apache License 2.0 6 votes vote down vote up
def convert_padding_direction(src_tokens, padding_idx, right_to_left=False, left_to_right=False):
    """Move padding from one end of each row of a 2-D token batch to the other.

    Exactly one direction flag must be truthy. Rows are rotated by their own
    count of ``padding_idx`` entries. The input tensor is returned unchanged
    when it has no padding or when no row has padding on the source side.
    """
    assert right_to_left ^ left_to_right
    is_pad = src_tokens.eq(padding_idx)
    if not is_pad.any():
        # nothing to move
        return src_tokens
    if left_to_right and not is_pad[:, 0].any():
        # no row starts with padding: already right padded
        return src_tokens
    if right_to_left and not is_pad[:, -1].any():
        # no row ends with padding: already left padded
        return src_tokens

    seq_len = src_tokens.size(1)
    positions = buffered_arange(seq_len).type_as(src_tokens).expand_as(src_tokens)
    pad_counts = is_pad.long().sum(dim=1, keepdim=True)
    # Rotate backwards to push padding left, forwards to push it right.
    shift = -pad_counts if right_to_left else pad_counts
    gather_index = torch.remainder(positions + shift, seq_len)
    return src_tokens.gather(1, gather_index)
Example #13
Source File: buffer.py    From reversible-rnn with MIT License 6 votes vote down vote up
def mod_divide(self, forget_radix, mask):
        """Pop the low ``forget_radix`` bits off the current buffer.

        Decrements ``self.counter``, computes ``curr_buffer mod 2**forget_radix``,
        divides the masked entries of ``curr_buffer`` by ``2**forget_radix``,
        and, if the overflow bit recorded for this step is set, restores
        ``curr_buffer`` from the last slice of ``past_buffers``.

        Args:
            forget_radix: exponent of the power-of-two divisor.
            mask: selector converted with ``.long()``; entries where it is 0
                keep their buffer value.

        Returns:
            The remainder computed before the division, converted with ``.int()``.
        """
        # Only called in reverse process.
        mask = mask.long()
        self.counter -= 1

        buf_mod = torch.remainder(self.curr_buffer, 2**forget_radix).int()
        # NOTE(review): ``/`` is true division on current PyTorch; if
        # curr_buffer is an integer tensor this presumably relied on the old
        # integer-division behaviour -- confirm.
        self.curr_buffer = mask*(self.curr_buffer/(2**forget_radix)) + (1-mask)*self.curr_buffer

        # overflow_detect is read as a bit field: element counter // 8, bit counter % 8.
        overflowed = self.overflow_detect.__and__(2**(self.counter % 8))[self.counter // 8]
        if overflowed:
            # Restore the most recently stored buffer and drop it from the
            # stack, unless it is the only slice left.
            self.curr_buffer = self.past_buffers[:,:,-1]
            if self.past_buffers.size(2) > 1:
                self.past_buffers = self.past_buffers[:,:,:-1]

        return buf_mod

###############################################################################
# Multiply/divide fixed point numbers with buffer
############################################################################### 
Example #14
Source File: test.py    From Count-Sketch-Optimizers with Apache License 2.0 5 votes vote down vote up
def log_uniform_sample(N, size):
    """Draw integer samples in ``[0, N)`` that are uniform on a log scale.

    A uniform draw ``u`` in ``[0, 1)`` is mapped to ``floor(N ** u) - 1`` and
    then wrapped into range with ``torch.remainder``.
    """
    uniform = torch.Tensor(size).uniform_(0, 1)
    # N ** u, written as exp(u * log N)
    scaled = torch.exp(uniform * math.log(N))
    draws = scaled.long() - 1
    return torch.remainder(draws, N)
Example #15
Source File: util.py    From PyTorch_GBW_LM with Apache License 2.0 5 votes vote down vote up
def log_uniform_sample(N, size):
    """Sample integers in ``[0, N)`` with a log-uniform distribution.

    A uniform draw ``u`` in ``[0, 1)`` is mapped to ``floor(N ** u) - 1`` and
    then wrapped into range with ``torch.remainder``.
    """
    uniform = torch.Tensor(size).uniform_(0, 1)
    # N ** u, written as exp(u * log N)
    scaled = torch.exp(uniform * math.log(N))
    draws = scaled.long() - 1
    return torch.remainder(draws, N)
Example #16
Source File: test.py    From PyTorch_GBW_LM with Apache License 2.0 5 votes vote down vote up
def log_uniform_sample(N, size):
    """Return ``size`` integer samples in ``[0, N)``, uniform on a log scale.

    A uniform draw ``u`` in ``[0, 1)`` is mapped to ``floor(N ** u) - 1`` and
    then wrapped into range with ``torch.remainder``.
    """
    uniform = torch.Tensor(size).uniform_(0, 1)
    # N ** u, written as exp(u * log N)
    scaled = torch.exp(uniform * math.log(N))
    draws = scaled.long() - 1
    return torch.remainder(draws, N)
Example #17
Source File: test.py    From Graph-Transformer with Apache License 2.0 5 votes vote down vote up
def log_uniform_sample(N, size):
    """Generate log-uniformly distributed integer samples in ``[0, N)``.

    A uniform draw ``u`` in ``[0, 1)`` is mapped to ``floor(N ** u) - 1`` and
    then wrapped into range with ``torch.remainder``.
    """
    uniform = torch.Tensor(size).uniform_(0, 1)
    # N ** u, written as exp(u * log N)
    scaled = torch.exp(uniform * math.log(N))
    draws = scaled.long() - 1
    return torch.remainder(draws, N)
Example #18
Source File: arithmetics.py    From heat with MIT License 5 votes vote down vote up
def remainder(t1, t2):
    """
    Element-wise division remainder of values of operand t1 by values of operand t2 (i.e. t1 % t2), not commutative.
    Takes the two operands (scalar or tensor) whose elements are to be divided (operand 1 by operand 2) as arguments.

    The computation is delegated to torch.remainder.

    Parameters
    ----------
    t1: tensor or scalar
        The first operand, whose values are divided
    t2: tensor or scalar
        The second operand, by whose values t1 is divided

    Returns
    -------
    result: ht.DNDarray
        A tensor containing the remainder of the element-wise division of t1 by t2.
        It has the same sign as the divisor t2.

    Examples
    --------
    >>> import heat as ht
    >>> ht.mod(2, 2)
    tensor([0])

    >>> T1 = ht.int32([[1, 2], [3, 4]])
    >>> T2 = ht.int32([[2, 2], [2, 2]])
    >>> ht.mod(T1, T2)
    tensor([[1, 0],
            [1, 0]], dtype=torch.int32)

    >>> s = 2
    >>> ht.mod(s, T1)
    tensor([[0, 0],
            [2, 2]], dtype=torch.int32)
    """
    return operations.__binary_op(torch.remainder, t1, t2)
Example #19
Source File: arithmetics.py    From heat with MIT License 5 votes vote down vote up
def mod(t1, t2):
    """
    Element-wise division remainder of values of operand t1 by values of operand t2 (i.e. t1 % t2), not commutative.
    Takes the two operands (scalar or tensor) whose elements are to be divided (operand 1 by operand 2) as arguments.

    Currently t1 and t2 are just passed to remainder.

    Parameters
    ----------
    t1: tensor or scalar
        The first operand, whose values are divided
    t2: tensor or scalar
        The second operand, by whose values t1 is divided

    Returns
    -------
    result: ht.DNDarray
        A tensor containing the remainder of the element-wise division of t1 by t2.
        It has the same sign as the divisor t2.

    Examples
    --------
    >>> import heat as ht
    >>> ht.mod(2, 2)
    tensor([0])

    >>> T1 = ht.int32([[1, 2], [3, 4]])
    >>> T2 = ht.int32([[2, 2], [2, 2]])
    >>> ht.mod(T1, T2)
    tensor([[1, 0],
            [1, 0]], dtype=torch.int32)

    >>> s = 2
    >>> ht.mod(s, T1)
    tensor([[0, 0],
            [2, 2]], dtype=torch.int32)
    """
    return remainder(t1, t2)
Example #20
Source File: arithmetics.py    From heat with MIT License 5 votes vote down vote up
def fmod(t1, t2):
    """
    Element-wise division remainder of values of operand t1 by values of operand t2 (i.e. C Library function fmod), not commutative.
    Takes the two operands (scalar or tensor, both may contain floating point numbers) whose elements are to be
    divided (operand 1 by operand 2) as arguments.

    The computation is delegated to torch.fmod.

    Parameters
    ----------
    t1: tensor or scalar
        The first operand, whose values are divided (may be floats)
    t2: tensor or scalar
        The second operand, by whose values t1 is divided (may be floats)

    Returns
    -------
    result: ht.DNDarray
        A tensor containing the remainder of the element-wise division (i.e. floating point values) of t1 by t2.
        It has the same sign as the dividend t1.

    Examples
    --------
    >>> import heat as ht
    >>> ht.fmod(2.0, 2.0)
    tensor([0.])

    >>> T1 = ht.float32([[1, 2], [3, 4]])
    >>> T2 = ht.float32([[2, 2], [2, 2]])
    >>> ht.fmod(T1, T2)
    tensor([[1., 0.],
            [1., 0.]])

    >>> s = 2.0
    >>> ht.fmod(s, T1)
    tensor([[0., 0.],
            [2., 2.]])
    """
    return operations.__binary_op(torch.fmod, t1, t2)
Example #21
Source File: util.py    From Count-Sketch-Optimizers with Apache License 2.0 5 votes vote down vote up
def log_uniform_sample(N, size):
    """Draw log-uniform integer samples in ``[0, N)``.

    A uniform draw ``u`` in ``[0, 1)`` is mapped to ``floor(N ** u) - 1`` and
    then wrapped into range with ``torch.remainder``.
    """
    uniform = torch.Tensor(size).uniform_(0, 1)
    # N ** u, written as exp(u * log N)
    scaled = torch.exp(uniform * math.log(N))
    draws = scaled.long() - 1
    return torch.remainder(draws, N)
Example #22
Source File: toep_functions.py    From torchkbnufft with MIT License 5 votes vote down vote up
def reflect_conj_concat(kern, dim):
    """Append the reflected, conjugated copy of ``kern`` along one dimension.

    The size-2 axis at position 2 of ``kern`` is treated as (real, imag):
    conjugation multiplies it by ``(1, -1)``.

    Args:
        kern (tensor): One half of a full, Hermitian-symmetric kernel.
        dim (int): Selects the concatenation axis as ``-1 - dim``; that axis
            and every axis after it are reflected.

    Returns:
        tensor: ``kern`` followed, along the chosen axis, by a zero slice and
        the reflected conjugate with its first slice dropped.
    """
    dtype, device = kern.dtype, kern.device
    dim = -1 - dim

    # Conjugate: broadcast (1, -1) over the (real, imag) axis.
    sign = torch.tensor([1, -1], dtype=dtype, device=device)
    sign = sign.reshape([1, 1, 2] + [1] * (kern.ndim - 3))
    mirrored = sign * kern

    # Circularly reverse each trailing axis: index 0 stays, the rest flip.
    for offset in range(abs(dim)):
        axis = dim + offset
        length = mirrored.shape[axis]
        reversed_idx = torch.remainder(
            -1 * torch.arange(length, device=device), length)
        mirrored = mirrored.index_select(axis, reversed_idx)

    # Central zero slice: same shape as kern but size 1 along ``dim``.
    zero_shape = list(kern.shape)
    zero_shape[dim] = 1
    zero_block = torch.zeros(zero_shape, dtype=dtype, device=device)

    tail = mirrored.narrow(dim, 1, mirrored.shape[dim] - 1)
    second_half = torch.cat((zero_block, tail), dim)
    return torch.cat((kern, second_half), dim)
Example #23
Source File: toep_functions.py    From torchkbnufft with MIT License 5 votes vote down vote up
def hermitify(kern, dim):
    """Enforce Hermitian symmetry.

    Builds a copy of ``kern`` with its trailing coordinates circularly
    reversed and its (real, imag) axis conjugated, then returns the mean of
    that copy and the original.

    Args:
        kern (tensor): An approximately Hermitian-symmetric kernel.
        dim (int): The last imaging dimension; axes ``kern.ndim - 1 - dim``
            through the last axis are reversed.

    Returns:
        tensor: A Hermitian-symmetric kernel.
    """
    dtype, device = kern.dtype, kern.device
    first_axis = kern.ndim - 1 - dim

    original = kern.clone()

    # Circularly reverse each imaging axis: index 0 stays, the rest flip.
    flipped = kern
    for axis in range(first_axis, kern.ndim):
        length = flipped.shape[axis]
        reversed_idx = torch.remainder(
            -1 * torch.arange(length, device=device), length)
        flipped = flipped.index_select(axis, reversed_idx)

    # Conjugate: broadcast (1, -1) over the (real, imag) axis at position 2.
    sign = torch.tensor([1, -1], dtype=dtype, device=device)
    sign = sign.reshape([1, 1, 2] + [1] * (flipped.ndim - 3))
    flipped = sign * flipped

    return (original + flipped) / 2
Example #24
Source File: identifier.py    From kaggle-humpback with BSD 2-Clause "Simplified" License 5 votes vote down vote up
def loss(self, outputs, labels, **_):
        """Compute the training loss with each label's "flipped" class masked.

        In training mode the class ``(label + num_classes // 2) % num_classes``
        of every sample is zeroed in ``outputs`` and all other entries are
        reduced by 1 before ``self.criterion`` is applied. Outside training a
        constant zero tensor is returned.

        Args:
            outputs: 2-D score tensor (samples x classes).
            labels: class indices, 1-D or already shaped for ``scatter_``.
            **_: ignored extra keyword arguments.

        Returns:
            ``self.criterion(masked_outputs, labels)`` when training, else
            ``torch.FloatTensor([0])``.
        """
        if self.model.training:
            labels_flip = labels + self.num_classes // 2
            labels_flip = torch.remainder(labels_flip, self.num_classes)
            if labels_flip.dim() == 1:
                labels_flip = labels_flip.unsqueeze(-1)
            # Allocate next to ``outputs`` instead of hard-coding ``.cuda()``,
            # so this also works on CPU and on a non-default GPU.
            onehot = torch.zeros_like(outputs)
            onehot.scatter_(1, labels_flip, 1)
            onehot_invert = (onehot == 0).float()
            assert onehot_invert.size() == outputs.size()
            outputs = outputs * onehot_invert - onehot_invert
            return self.criterion(outputs, labels)
        return torch.FloatTensor([0])
Example #25
Source File: qr_embedding_bag.py    From dlrm with MIT License 5 votes vote down vote up
def __init__(self, num_categories, embedding_dim, num_collisions,
                 operation='mult', max_norm=None, norm_type=2.,
                 scale_grad_by_freq=False, mode='mean', sparse=False,
                 _weight=None):
        """Set up the quotient and remainder embedding tables.

        Args:
            num_categories: number of distinct categories to embed.
            embedding_dim: an int, a length-1 sequence (same size for both
                tables), or a two-element sequence ``[dim_q, dim_r]``.
            num_collisions: divisor used to split an index into quotient and
                remainder; also the number of rows of the remainder table.
            operation: 'concat', 'mult' or 'add'.
            max_norm, norm_type, scale_grad_by_freq, mode, sparse: stored on
                the instance unchanged.
            _weight: optional pair ``(weight_q, weight_r)`` of pre-built
                tables; shapes are checked against the computed table sizes.
        """
        super(QREmbeddingBag, self).__init__()

        assert operation in ['concat', 'mult', 'add'], 'Not valid operation!'

        self.num_categories = num_categories
        if isinstance(embedding_dim, int):
            self.embedding_dim = [embedding_dim, embedding_dim]
        elif len(embedding_dim) == 1:
            # Unwrap the single entry; duplicating the sequence itself would
            # produce a nested list instead of two ints.
            self.embedding_dim = [embedding_dim[0], embedding_dim[0]]
        else:
            self.embedding_dim = embedding_dim
        self.num_collisions = num_collisions
        self.operation = operation
        self.max_norm = max_norm
        self.norm_type = norm_type
        self.scale_grad_by_freq = scale_grad_by_freq

        if self.operation == 'add' or self.operation == 'mult':
            assert self.embedding_dim[0] == self.embedding_dim[1], \
                'Embedding dimensions do not match!'

        # Quotient table rows: ceil(num_categories / num_collisions);
        # remainder table rows: num_collisions.
        self.num_embeddings = [int(np.ceil(num_categories / num_collisions)),
            num_collisions]

        if _weight is None:
            self.weight_q = Parameter(torch.Tensor(self.num_embeddings[0], self.embedding_dim[0]))
            self.weight_r = Parameter(torch.Tensor(self.num_embeddings[1], self.embedding_dim[1]))
            self.reset_parameters()
        else:
            assert list(_weight[0].shape) == [self.num_embeddings[0], self.embedding_dim[0]], \
                'Shape of weight for quotient table does not match num_embeddings and embedding_dim'
            assert list(_weight[1].shape) == [self.num_embeddings[1], self.embedding_dim[1]], \
                'Shape of weight for remainder table does not match num_embeddings and embedding_dim'
            self.weight_q = Parameter(_weight[0])
            self.weight_r = Parameter(_weight[1])
        self.mode = mode
        self.sparse = sparse
Example #26
Source File: buffer.py    From reversible-rnn with MIT License 5 votes vote down vote up
def mod(self, divisor):
        """Return ``self.curr_buffer`` modulo ``divisor``, converted with ``.int()``.

        ``divisor`` is converted with ``.long()`` before the remainder is taken.
        """
        long_divisor = divisor.long()
        remainder = torch.remainder(self.curr_buffer, long_divisor)
        return remainder.int()
Example #27
Source File: qr_embedding_bag.py    From optimized-models with Apache License 2.0 5 votes vote down vote up
def __init__(self, num_categories, embedding_dim, num_collisions,
                 operation='mult', max_norm=None, norm_type=2.,
                 scale_grad_by_freq=False, mode='mean', sparse=False,
                 _weight=None):
        """Create the quotient and remainder embedding tables.

        Args:
            num_categories: number of distinct categories to embed.
            embedding_dim: an int, a length-1 sequence (same size for both
                tables), or a two-element sequence ``[dim_q, dim_r]``.
            num_collisions: divisor used to split an index into quotient and
                remainder; also the number of rows of the remainder table.
            operation: 'concat', 'mult' or 'add'.
            max_norm, norm_type, scale_grad_by_freq, mode, sparse: stored on
                the instance unchanged.
            _weight: optional pair ``(weight_q, weight_r)`` of pre-built
                tables; shapes are checked against the computed table sizes.
        """
        super(QREmbeddingBag, self).__init__()

        assert operation in ['concat', 'mult', 'add'], 'Not valid operation!'

        self.num_categories = num_categories
        if isinstance(embedding_dim, int):
            self.embedding_dim = [embedding_dim, embedding_dim]
        elif len(embedding_dim) == 1:
            # Unwrap the single entry; duplicating the sequence itself would
            # produce a nested list instead of two ints.
            self.embedding_dim = [embedding_dim[0], embedding_dim[0]]
        else:
            self.embedding_dim = embedding_dim
        self.num_collisions = num_collisions
        self.operation = operation
        self.max_norm = max_norm
        self.norm_type = norm_type
        self.scale_grad_by_freq = scale_grad_by_freq

        if self.operation == 'add' or self.operation == 'mult':
            assert self.embedding_dim[0] == self.embedding_dim[1], \
                'Embedding dimensions do not match!'

        # Quotient table rows: ceil(num_categories / num_collisions);
        # remainder table rows: num_collisions.
        self.num_embeddings = [int(np.ceil(num_categories / num_collisions)),
            num_collisions]

        if _weight is None:
            self.weight_q = Parameter(torch.Tensor(self.num_embeddings[0], self.embedding_dim[0]))
            self.weight_r = Parameter(torch.Tensor(self.num_embeddings[1], self.embedding_dim[1]))
            self.reset_parameters()
        else:
            assert list(_weight[0].shape) == [self.num_embeddings[0], self.embedding_dim[0]], \
                'Shape of weight for quotient table does not match num_embeddings and embedding_dim'
            assert list(_weight[1].shape) == [self.num_embeddings[1], self.embedding_dim[1]], \
                'Shape of weight for remainder table does not match num_embeddings and embedding_dim'
            self.weight_q = Parameter(_weight[0])
            self.weight_r = Parameter(_weight[1])
        self.mode = mode
        self.sparse = sparse
Example #28
Source File: buffer.py    From reversible-rnn with MIT License 5 votes vote down vote up
def forward(ctx, h, z, buf, mask, slice_dim=0):
        """Masked update of ``h`` by ``2**forget_radix / z``, exchanging bits with ``buf``.

        NOTE(review): ``forget_radix`` is not defined in this block;
        presumably a module-level constant -- confirm. ``ctx`` is not used here.

        Args:
            ctx: context object (unused in this block).
            h: tensor that is divided, scaled and returned; also modified in
                place for its negative entries.
            z: divisor applied to ``h`` where ``mask`` selects.
            buf: buffer object exposing ``mul``, ``add`` and ``mod_divide``.
            mask: multiplicative selector used as ``mask*new + (1-mask)*old``;
                presumably 0/1 valued -- confirm.
            slice_dim: first column (dim 1) of ``h``/``z``/``mask`` that takes
                part in the buffer operations.

        Returns:
            The updated ``h``.
        """
        # Make room in the buffer, then store h mod z (the part the division
        # below discards).
        buf.mul(z[:,slice_dim:], mask[:, slice_dim:])
        h_mod = torch.remainder(h[:,slice_dim:], z[:,slice_dim:])
        buf.add(h_mod, mask[:, slice_dim:])
        # Masked negative entries are lowered by (z - 1) first; presumably so
        # that a truncating division rounds them toward negative infinity --
        # confirm.
        h[h<0] = mask[h<0]*(h[h<0]-(z[h<0]-1)) + (1-mask[h<0])*h[h<0]
        # NOTE(review): ``/`` is true division on current PyTorch; this looks
        # like it expects integer division of integer tensors -- confirm.
        h = mask*(h / z) + (1-mask)*h

        # Scale the masked entries by 2**forget_radix, then OR in the low bits
        # popped from the buffer.
        h = mask*(h * (2**forget_radix)) + (1-mask)*h
        buf_mod = buf.mod_divide(forget_radix, mask[:, slice_dim:])
        h[:,slice_dim:] = mask[:, slice_dim:]*(h[:,slice_dim:].__or__(buf_mod)) +\
            (1-mask[:, slice_dim:])*h[:,slice_dim:]

        return h
Example #29
Source File: interp_functions.py    From torchkbnufft with MIT License 4 votes vote down vote up
def calc_coef_and_indices(tm, kofflist, Jval, table, centers, L, dims, conjcoef=False):
    """Calculates interpolation coefficients and on-grid indices.

    Args:
        tm (tensor): normalized frequency locations; dimension 0 indexes the
            spatial dimension, dimension 1 the sample.
        kofflist (tensor): A tensor with offset locations to first elements in
            list of nearest neighbors.
        Jval (tensor): A tuple-like tensor for how much to increment offsets.
        table (list): A list of tensors tabulating a Kaiser-Bessel
            interpolation kernel.
        centers (tensor): A tensor with the center locations of the table for
            each dimension.
        L (tensor): A tensor with the table size in each dimension.
        dims (tensor): A tensor with image dimensions.
        conjcoef (boolean, default=False): A boolean for whether to compute
            normal or complex conjugate interpolation coefficients
            (conjugate needed for adjoint).

    Returns:
        tuple: A tuple with interpolation coefficients and indices.
    """
    float_dtype, device = tm.dtype, tm.device
    index_dtype = torch.long
    ndims, M = tm.shape[0], tm.shape[1]

    # Grid location of the current neighbor, and its rounded distance from
    # each sample expressed in table steps.
    grid_float = (kofflist + Jval.unsqueeze(1)).to(float_dtype)
    table_offsets = torch.round(
        (tm - grid_float) * L.unsqueeze(1)).to(dtype=index_dtype)
    grid_index = grid_float.to(index_dtype)

    # Running flat index into the grid and running coefficient product,
    # which starts at 1 + 0j stored as stacked (real, imag) rows.
    flat_index = torch.zeros((M,), dtype=index_dtype, device=device)
    real_part = torch.ones(M, dtype=float_dtype, device=device)
    imag_part = torch.zeros(M, dtype=float_dtype, device=device)
    coef = torch.stack((real_part, imag_part))

    for d in range(ndims):  # spatial dimension
        kernel_vals = table[d][:, table_offsets[d, :] + centers[d]]
        if conjcoef:
            coef = conj_complex_mult(coef, kernel_vals, dim=0)
        else:
            coef = complex_mult(coef, kernel_vals, dim=0)

        # Wrap the grid index into [0, dims[d]) and weight it by the product
        # of the remaining dimension sizes.
        stride = torch.prod(dims[d + 1:])
        wrapped = torch.remainder(grid_index[d, :], dims[d]).view(-1)
        flat_index = flat_index + wrapped * stride

    return coef, flat_index
Example #30
Source File: trainer.py    From advex-uar with Apache License 2.0 4 votes vote down vote up
def _val_epoch(self, epoch):
        """Run one validation pass and log standard and adversarial metrics.

        For every batch of ``self.val_loader`` the clean loss and accuracy are
        recorded. When ``self.attack`` is set, a random target class different
        from the true label is drawn per sample, two adversarial batches are
        generated (with ``scale_eps=self.scale_eps`` and ``scale_eps=False``),
        and their loss and accuracy are recorded as well. Metrics are sent to
        ``self.logger`` on Horovod rank 0 and printed when ``self.verbose``.

        Args:
            epoch: epoch index passed through to the logger.
        """
        self.model.eval()

        val_std_loss = Metric('val_std_loss')
        val_std_acc = Metric('val_std_acc')

        val_adv_acc = Metric('val_adv_acc')
        val_adv_loss = Metric('val_adv_loss')
        val_max_adv_acc = Metric('val_max_adv_acc')
        val_max_adv_loss = Metric('val_max_adv_loss')

        for data, target in self.val_loader:
            if self.cuda:
                data, target = data.cuda(non_blocking=True), target.cuda(non_blocking=True)
            with torch.no_grad():
                output = self.model(data)
                val_std_loss.update(F.cross_entropy(output, target))
                val_std_acc.update(accuracy(output, target))
            if self.attack:
                num_classes = len(self.val_dataset.classes)
                # An offset in [1, num_classes - 1] added modulo num_classes
                # yields a class that is never the true label. The offset is
                # created on target's device rather than a hard-coded 'cuda'.
                rand_target = torch.randint(
                    0, num_classes - 1, target.size(),
                    dtype=target.dtype, device=target.device)
                rand_target = torch.remainder(target + rand_target + 1, num_classes)
                data_adv = self.attack(self.model, data, rand_target,
                                       avoid_target=False, scale_eps=self.scale_eps)
                data_max_adv = self.attack(self.model, data, rand_target, avoid_target=False, scale_eps=False)
                with torch.no_grad():
                    output_adv = self.model(data_adv)
                    val_adv_loss.update(F.cross_entropy(output_adv, target))
                    val_adv_acc.update(accuracy(output_adv, target))

                    output_max_adv = self.model(data_max_adv)
                    val_max_adv_loss.update(F.cross_entropy(output_max_adv, target))
                    val_max_adv_acc.update(accuracy(output_max_adv, target))
            self.model.eval()

        # Built on every rank so the verbose print below never hits an
        # undefined name on ranks other than 0.
        log_dict = {'val_std_loss': val_std_loss.avg.item(),
                    'val_std_acc': val_std_acc.avg.item(),
                    'val_adv_loss': val_adv_loss.avg.item(),
                    'val_adv_acc': val_adv_acc.avg.item(),
                    # Was keyed 'val_adv_loss', which overwrote the entry above.
                    'val_max_adv_loss': val_max_adv_loss.avg.item(),
                    'val_max_adv_acc': val_max_adv_acc.avg.item()}

        if hvd.rank() == 0:
            self.logger.log(log_dict, epoch)

        if self.verbose:
            print(log_dict)

        self.optimizer.synchronize()
        self.optimizer.zero_grad()