Python torch.cumprod() Examples

The following are 19 code examples of torch.cumprod(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module torch, or try the search function.
Example #1
Source File: functions.py    From fairseq with MIT License 6 votes vote down vote up
def safe_cumprod(tensor, dim: int, eps: float = 1e-10):
    """
    Cumulative product evaluated in log space to limit precision issues.

    cumprod(x)
    = [x1, x1x2, x1x2x3, ....]
    = [exp(log(x1)), exp(log(x1) + log(x2)), exp(log(x1) + log(x2) + log(x3)), ...]
    = exp(cumsum(log(x)))

    Args:
        tensor: input tensor; every element must satisfy ``x + eps >= 0``.
        dim: dimension along which the product is accumulated.
        eps: small constant added to every element before the logarithm so
            that ``log(0)`` is never evaluated. Because of it the result is
            the cumulative product of ``tensor + eps``, not of ``tensor``.

    Returns:
        ``exp(cumsum(log(tensor + eps), dim))``.

    Raises:
        RuntimeError: if any element of ``tensor + eps`` is negative.
    """

    if (tensor + eps < 0).any().item():
        # The two literals are concatenated, so the first one must end with
        # a space to keep the sentences apart in the final message.
        raise RuntimeError(
            "Safe cumprod can only take non-negative tensors as input. "
            "Consider use torch.cumprod if you want to calculate negative values."
        )

    log_tensor = torch.log(tensor + eps)
    cumsum_log_tensor = torch.cumsum(log_tensor, dim)
    exp_cumsum_log_tensor = torch.exp(cumsum_log_tensor)
    return exp_cumsum_log_tensor
Example #2
Source File: memory.py    From pytorch-dnc with MIT License 6 votes vote down vote up
def allocate(self, usage, write_gate):
    """Compute allocation weights from the current memory usage.

    Args:
        usage: usage tensor; indexed as 2-D here (``size(0)`` is used as the
            batch size and sorting runs along ``dim=1``).
            NOTE(review): presumably (batch, mem_size) -- confirm with caller.
        write_gate: not used by this method.

    Returns:
        A pair ``(allocation_weights.unsqueeze(1), usage)`` where ``usage``
        is the rescaled usage computed on the first line.
    """
    # ensure values are not too small prior to cumprod.
    usage = δ + (1 - δ) * usage
    batch_size = usage.size(0)
    # free list: usage sorted in ascending order together with its indices
    sorted_usage, φ = T.topk(usage, self.mem_size, dim=1, largest=False)

    # cumprod with exclusive=True
    # https://discuss.pytorch.org/t/cumprod-exclusive-true-equivalences/2614/8
    v = var(sorted_usage.data.new(batch_size, 1).fill_(1))
    cat_sorted_usage = T.cat((v, sorted_usage), 1)
    prod_sorted_usage = T.cumprod(cat_sorted_usage, 1)[:, :-1]

    # prod_sorted_usage already has the same 2-D shape as sorted_usage.
    # A bare .squeeze() here would drop the memory axis when mem_size == 1
    # and make the product broadcast to (batch, batch), so none is applied.
    sorted_allocation_weights = (1 - sorted_usage) * prod_sorted_usage

    # construct the reverse sorting index https://stackoverflow.com/questions/2483696/undo-or-reverse-argsort-python
    _, φ_rev = T.topk(φ, k=self.mem_size, dim=1, largest=False)
    allocation_weights = sorted_allocation_weights.gather(1, φ_rev.long())

    return allocation_weights.unsqueeze(1), usage
Example #3
Source File: fake_ops.py    From pytorch-dnc with MIT License 6 votes vote down vote up
def fake_cumprod(vb):
    """
    Emulate a cumulative product along the width axis with a mask and prod.

    args:
        vb:  [hei x wid]
          -> NOTE: only a cumprod along wid is supported
    returns:
        [1 x hei x wid] tensor; entry (0, h, w) is the product of vb[h, :w+1]
    """
    hei, wid = vb.size(0), vb.size(1)
    # keep[i, :, j] is 1 for j <= i: slice i selects the first i+1 columns
    keep = torch.tril(torch.ones(wid, wid)).unsqueeze(1).expand(wid, hei, wid)
    keep = Variable(keep).type_as(vb)
    # columns outside the prefix are replaced by 1 so they do not affect prod
    fill = 1 - keep
    stacked = vb.unsqueeze(0).expand_as(keep) * keep + fill
    # reduce every prefix, then move the prefix index to the last axis
    return torch.prod(stacked, 2, keepdim=True).transpose(0, 2)
Example #4
Source File: functions.py    From fairseq with MIT License 6 votes vote down vote up
def exclusive_cumprod(tensor, dim: int, eps: float = 1e-10):
    """
    Implementing exclusive cumprod.
    There is cumprod in pytorch, however there is no exclusive mode.
    cumprod(x) = [x1, x1x2, x1x2x3, ..., prod_{i=1}^n x_i]
    exclusive means cumprod(x) = [1, x1, x1x2, x1x2x3, ..., prod_{i=1}^{n-1} x_i]

    Args:
        tensor: input tensor.
        dim: dimension along which the product is accumulated; any valid
            dimension of ``tensor`` (negative values included) is accepted.
        eps: forwarded unchanged to ``safe_cumprod``.

    Returns:
        A tensor with the same size as ``tensor`` along every dimension.
    """
    length = tensor.size(dim)
    tensor_size = list(tensor.size())
    tensor_size[dim] = 1
    # Prepend a slice of ones so every product is shifted by one position.
    return_tensor = safe_cumprod(
        torch.cat([torch.ones(tensor_size).type_as(tensor), tensor], dim=dim), dim=dim, eps=eps
    )

    # Drop the last entry along `dim`; narrow works for any dimension, so
    # there is no need to special-case dim 0/1/2.
    return return_tensor.narrow(dim, 0, length)
Example #5
Source File: fxc.py    From ASNG-NAS with MIT License 5 votes vote down vote up
def forward(self, c):
        """Evaluate the objective for `c` and count the evaluation.

        Increments ``self.eval``, then returns the same value twice as a
        pair ``(fx, fx)``.
        """
        self.eval += 1
        squashed = torch.mean(self.sigmod(self.x), 1)
        # fx = self.nc - torch.sum(c * xx, 0)  # onemax like
        leading = torch.cumprod(c * squashed, 0).sum()  # leading ones like
        fx = self.nc - leading
        # small quadratic term on the raw parameters
        fx += 1e-8 * torch.sum(self.x ** 2)
        return fx, fx
Example #6
Source File: ipd_DiCE.py    From LOLA_DiCE with MIT License 5 votes vote down vote up
def dice_objective(self):
        """Build the negated DiCE objective from the stored trajectories.

        The per-step lists ``self_logprobs``, ``other_logprobs``, ``values``
        and ``rewards`` are stacked along dim 1. A baseline term is added
        when ``hp.use_baseline`` is set. The result is negated so that it
        can be minimised.
        """
        logp_self = torch.stack(self.self_logprobs, dim=1)
        logp_other = torch.stack(self.other_logprobs, dim=1)
        values = torch.stack(self.values, dim=1)
        rewards = torch.stack(self.rewards, dim=1)

        # discount factors along dim 1: [1, gamma, gamma^2, ...]
        cum_discount = torch.cumprod(hp.gamma * torch.ones(*rewards.size()), dim=1)/hp.gamma
        discounted_rewards = rewards * cum_discount
        discounted_values = values * cum_discount

        # joint log-probability of the stochastic nodes at each step ...
        stochastic_nodes = logp_self + logp_other
        # ... and its running sum: the nodes each reward depends on
        dependencies = torch.cumsum(stochastic_nodes, dim=1)

        objective = torch.mean(torch.sum(magic_box(dependencies) * discounted_rewards, dim=1))

        if hp.use_baseline:
            # variance reduction
            objective = objective + torch.mean(
                torch.sum((1 - magic_box(stochastic_nodes)) * discounted_values, dim=1))

        return -objective  # want to minimize -objective
Example #7
Source File: ipd_DiCE_om.py    From LOLA_DiCE with MIT License 5 votes vote down vote up
def dice_objective(self):
        """Return the negated DiCE objective for the recorded steps.

        Stacks the per-step ``self_logprobs``, ``other_logprobs``,
        ``values`` and ``rewards`` along dim 1, discounts rewards and
        values, and combines them through ``magic_box``. The baseline term
        is included only when ``hp.use_baseline`` is true.
        """
        own_lp = torch.stack(self.self_logprobs, dim=1)
        opp_lp = torch.stack(self.other_logprobs, dim=1)
        values = torch.stack(self.values, dim=1)
        rewards = torch.stack(self.rewards, dim=1)

        # apply discount: gamma^t along dim 1, starting at gamma^0
        gamma_powers = torch.cumprod(hp.gamma * torch.ones(*rewards.size()), dim=1)/hp.gamma
        rewards_disc = rewards * gamma_powers
        values_disc = values * gamma_powers

        # logprob of each stochastic node, and the cumulative sum over the
        # nodes involved in each reward's dependencies
        node_lp = own_lp + opp_lp
        deps_lp = torch.cumsum(node_lp, dim=1)

        # dice objective
        per_sample = torch.sum(magic_box(deps_lp) * rewards_disc, dim=1)
        result = torch.mean(per_sample)

        if hp.use_baseline:
            # variance reduction
            baseline_per_sample = torch.sum((1 - magic_box(node_lp)) * values_disc, dim=1)
            result = result + torch.mean(baseline_per_sample)

        return -result  # want to minimize -objective
Example #8
Source File: dynamic_write_head.py    From pytorch-dnc with MIT License 5 votes vote down vote up
def _allocation(self, usage_vb, epsilon=1e-6):
        r"""
        computes allocation by sorting usage, a = a_t[\phi_t[j]]
        variables needed:
            usage_vb: [batch_size x mem_hei]
                   -> indicating current memory usage, this is equal to u_t in
                      the paper when we only have one write head, but for
                      multiple write heads, one should update the usage while
                      iterating through the write heads to take into account the
                      allocation returned by this function
            epsilon:  lower bound mixed into the usage before the cumprod
        returns:
            alloc_vb: [batch_size x num_write_heads x mem_hei]
                   -> NOTE(review): the code below gathers along dim 1 of a
                      2-D tensor and adds no head axis -- confirm this shape
        """
        # ensure values are not too small prior to cumprod
        usage_vb = epsilon + (1 - epsilon) * usage_vb
        # NOTE: we sort usage in ascending order
        sorted_usage_vb, indices_vb = torch.topk(usage_vb, k=self.mem_hei, dim=1, largest=False)
        # to imitate tf.cumprod(exclusive=True) https://discuss.pytorch.org/t/cumprod-exclusive-true-equivalences/2614/8
        cat_sorted_usage_vb = torch.cat((Variable(torch.ones(self.batch_size, 1)).type(self.dtype), sorted_usage_vb), 1)[:, :-1]
        # TODO: seems we have to wait for this PR: https://github.com/pytorch/pytorch/pull/1439
        prod_sorted_usage_vb = fake_cumprod(cat_sorted_usage_vb)
        # prod_sorted_usage_vb = torch.cumprod(cat_sorted_usage_vb, dim=1) # TODO: use this once the PR is ready
        # fake_cumprod returns [1 x batch_size x mem_hei]; remove only that
        # leading axis. A bare .squeeze() would also drop batch_size == 1 or
        # mem_hei == 1 axes and change how the product broadcasts.
        alloc_weight_vb = (1 - sorted_usage_vb) * prod_sorted_usage_vb.squeeze(0)  # equ. (1)
        # sorting the sort indices yields the permutation that undoes the sort
        _, indices_vb = torch.topk(indices_vb, k=self.mem_hei, dim=1, largest=False)
        alloc_weight_vb = alloc_weight_vb.gather(1, indices_vb)
        return alloc_weight_vb
Example #9
Source File: mocha.py    From neural_sp with Apache License 2.0 5 votes vote down vote up
def exclusive_cumprod(x):
    """Exclusive cumulative product [a, b, c] => [1, a, a * b].

    The product runs along the last dimension. The input may have any rank
    of at least one; `[B, H, qlen, klen]` is the shape used by the callers
    this was written for.

        Args:
            x (FloatTensor): `[..., klen]`, e.g. `[B, H, qlen, klen]`
        Returns:
            x (FloatTensor): same leading dimensions as the input

    """
    # A column of ones with the same leading dimensions, dtype and device.
    ones = x.new_ones(tuple(x.shape[:-1]) + (1,))
    # Shift the input right by one position, then take an ordinary cumprod.
    return torch.cumprod(torch.cat([ones, x[..., :-1]], dim=-1), dim=-1)
Example #10
Source File: arithmetics.py    From heat with MIT License 5 votes vote down vote up
def cumprod(a, axis, dtype=None, out=None):
    """
    Return the cumulative product of elements along a given axis.

    Parameters
    ----------
    a : DNDarray
        Input array.
    axis : int
        Axis along which the cumulative product is computed.
    dtype : dtype, optional
        Type of the returned array, as well as of the accumulator in which
        the elements are multiplied.  If *dtype* is not specified, it
        defaults to the dtype of `a`, unless `a` has an integer dtype with
        a precision less than that of the default platform integer.  In
        that case, the default platform integer is used instead.
    out : DNDarray, optional
        Alternative output array in which to place the result. It must
        have the same shape and buffer length as the expected output
        but the type of the resulting values will be cast if necessary.

    Returns
    -------
    cumprod : DNDarray
        A new array holding the result is returned unless `out` is
        specified, in which case a reference to out is returned.

    Examples
    --------
    >>> a = ht.full((3,3), 2)
    >>> ht.cumprod(a, 0)
    tensor([[2., 2., 2.],
            [4., 4., 4.],
            [8., 8., 8.]])
    """
    # Delegates to the shared cumulative-operation helper, passing
    # torch.cumprod, MPI.PROD, torch.mul and the constant 1 ahead of the
    # caller's axis/dtype/out.
    # NOTE(review): presumably local op, cross-process reduction, combine
    # function and neutral element -- confirm against operations.__cum_op.
    return operations.__cum_op(a, torch.cumprod, MPI.PROD, torch.mul, 1, axis, dtype, out)


# Alias support 
Example #11
Source File: metric.py    From pt-ranking.github.io with MIT License 5 votes vote down vote up
def tor_err_at_ks(sys_sorted_labels, ks=None, multi_level_rele=True, max_rele_level=None):
    '''
    Expected reciprocal rank (ERR) at each requested cutoff.

    :param sys_sorted_labels: the standard labels sorted in descending order according to predicted relevance scores
    :param ks: cutoff positions; cutoffs beyond the number of labels are dropped and zero-padded in the result
    :param multi_level_rele: only True is supported, otherwise NotImplementedError is raised
    :param max_rele_level: maximum relevance level used to normalise the satisfaction probabilities
    :return: tensor with one value per entry of ks
    '''
    valid_max = sys_sorted_labels.size(0)
    truncated = valid_max < max(ks)
    used_ks = [k for k in ks if k <= valid_max] if truncated else ks

    max_cutoff = max(used_ks)
    inds = torch.from_numpy(np.asarray(used_ks) - 1)
    if not multi_level_rele:
        raise NotImplementedError

    positions = torch.arange(max_cutoff) + 1.0
    expt_ranks = 1.0 / positions    # expected stop positions

    tor_max_rele = torch.Tensor([max_rele_level]).float()
    gains = torch.pow(2.0, sys_sorted_labels[0:max_cutoff]) - 1.0
    satis_pros = gains/torch.pow(2.0, tor_max_rele)
    non_satis_pros = torch.ones(max_cutoff) - satis_pros
    cum_non_satis_pros = torch.cumprod(non_satis_pros, dim=0)

    # probability of reaching each position: 1 for the first one, then the
    # running product of non-satisfaction up to the previous position
    cascad_non_satis_pros = torch.ones_like(positions)
    cascad_non_satis_pros[1:max_cutoff] = cum_non_satis_pros[0:max_cutoff-1]

    expt_satis_ranks = expt_ranks * satis_pros * cascad_non_satis_pros  # w.r.t. all rank positions
    err_at_ks = torch.cumsum(expt_satis_ranks, dim=0)[inds]

    if not truncated:
        return err_at_ks

    padded_err_at_ks = torch.zeros(len(ks))
    padded_err_at_ks[0:len(used_ks)] = err_at_ks
    return padded_err_at_ks
Example #12
Source File: adhoc_metric.py    From pt-ranking.github.io with MIT License 5 votes vote down vote up
def torch_ideal_err(sorted_labels, k=10, point=True, gpu=False):
	"""ERR of the first k entries of `sorted_labels`.

	The satisfaction probabilities are normalised by the maximum over all of
	`sorted_labels`, not only the first k. With `point=True` a single summed
	value is returned, otherwise the running sum at every position up to k.
	When `gpu` is set the rank tensor is cast with `.type(tensor)`.
	"""
	assert sorted_labels.size(0) >= k

	max_label = torch.max(sorted_labels)

	labels = sorted_labels[0:k]
	satis_pros = (torch.pow(2.0, labels) - 1.0) / torch.pow(2.0, max_label)

	unsatis_pros = torch.ones_like(labels) - satis_pros
	cum_unsatis_pros = torch.cumprod(unsatis_pros, dim=0)

	ranks = torch.arange(k)
	if gpu:
		ranks = ranks.type(tensor)
	ranks = ranks + 1.0
	expt_ranks = 1.0 / ranks

	# probability of reaching each rank: 1 at the top, then the cumulative
	# non-satisfaction of all earlier ranks
	cascad_unsatis_pros = torch.ones_like(ranks)
	cascad_unsatis_pros[1:k] = cum_unsatis_pros[0:k-1]

	expt_satis_ranks = expt_ranks * satis_pros * cascad_unsatis_pros  # w.r.t. all rank positions

	if not point:
		return torch.cumsum(expt_satis_ranks, dim=0)

	# a specific position
	return torch.sum(expt_satis_ranks, dim=0)
Example #13
Source File: adhoc_metric.py    From pt-ranking.github.io with MIT License 5 votes vote down vote up
def torch_batch_ideal_err(batch_sorted_labels, k=10, gpu=False, point=True):
	"""Batched ERR over the first k entries of every row.

	Args:
		batch_sorted_labels: 2-D tensor of labels, one ranking per row, with
			at least k columns.
		k: cutoff position.
		gpu: unused; intermediate tensors are created on the device of the
			input. Kept for interface compatibility.
		point: if True return one value per row (ERR at k), otherwise the
			running ERR at every position 1..k.

	Returns:
		Tensor of shape (batch,) when `point` is True, else (batch, k).
	"""
	assert batch_sorted_labels.size(1) >= k

	# torch.max with `dim` returns (values, indices); only the values are
	# needed, and keepdim makes them broadcast against the (batch, k) labels.
	batch_max, _ = torch.max(batch_sorted_labels, dim=1, keepdim=True)

	batch_labels = batch_sorted_labels[:, 0:k]
	batch_satis_pros = (torch.pow(2.0, batch_labels) - 1.0) / torch.pow(2.0, batch_max)

	batch_unsatis_pros = torch.ones_like(batch_labels) - batch_satis_pros
	batch_cum_unsatis_pros = torch.cumprod(batch_unsatis_pros, dim=1)

	positions = torch.arange(k, dtype=batch_satis_pros.dtype, device=batch_satis_pros.device) + 1.0
	batch_expt_ranks = 1.0 / positions  # broadcasts over the batch dimension

	# probability of reaching each rank: 1 for the first rank, then the
	# cumulative non-satisfaction of all earlier ranks
	cascad_unsatis_pros = torch.cat(
		[torch.ones_like(batch_cum_unsatis_pros[:, :1]), batch_cum_unsatis_pros[:, :k-1]], dim=1)

	expt_satis_ranks = batch_expt_ranks * batch_satis_pros * cascad_unsatis_pros  # w.r.t. all rank positions

	if point:
		return torch.sum(expt_satis_ranks, dim=1)
	return torch.cumsum(expt_satis_ranks, dim=1)
Example #14
Source File: adhoc_metric.py    From pt-ranking.github.io with MIT License 5 votes vote down vote up
def torch_nerr_at_ks(sys_sorted_labels, ideal_sorted_labels, ks=None, multi_level_rele=True):
	'''
	Normalised ERR at each requested cutoff: the ERR of the system ranking
	divided by the value torch_ideal_err returns for the ideal ranking.

	:param sys_sorted_labels: the standard labels sorted in descending order according to predicted relevance scores
	:param ideal_sorted_labels: labels passed to torch_ideal_err to obtain the normaliser
	:param ks: cutoff positions; cutoffs beyond the number of labels are dropped and zero-padded in the result
	:param multi_level_rele: only True is supported, otherwise NotImplementedError is raised
	:return: tensor with one value per entry of ks
	'''
	valid_max = sys_sorted_labels.size(0)
	truncated = valid_max < max(ks)
	used_ks = [k for k in ks if k <= valid_max] if truncated else ks

	max_cutoff = max(used_ks)
	inds = torch.from_numpy(np.asarray(used_ks) - 1)
	if not multi_level_rele:
		raise NotImplementedError

	positions = torch.arange(max_cutoff) + 1.0
	expt_ranks = 1.0 / positions    # expected stop positions

	tor_max_rele = torch.max(sys_sorted_labels)
	gains = torch.pow(2.0, sys_sorted_labels[0:max_cutoff]) - 1.0
	satis_pros = gains/torch.pow(2.0, tor_max_rele)
	non_satis_pros = torch.ones(max_cutoff) - satis_pros
	cum_non_satis_pros = torch.cumprod(non_satis_pros, dim=0)

	# probability of reaching each position: 1 for the first one, then the
	# running product of non-satisfaction up to the previous position
	cascad_non_satis_pros = torch.ones_like(positions)
	cascad_non_satis_pros[1:max_cutoff] = cum_non_satis_pros[0:max_cutoff-1]

	expt_satis_ranks = expt_ranks * satis_pros * cascad_non_satis_pros  # w.r.t. all rank positions
	err_at_ks = torch.cumsum(expt_satis_ranks, dim=0)

	ideal_err_at_ks = torch_ideal_err(ideal_sorted_labels, k=max_cutoff, point=False)
	nerr_at_ks = (err_at_ks/ideal_err_at_ks)[inds]

	if not truncated:
		return nerr_at_ks

	padded_nerr_at_ks = torch.zeros(len(ks))
	padded_nerr_at_ks[0:len(used_ks)] = nerr_at_ks
	return padded_nerr_at_ks
Example #15
Source File: kcenter.py    From clusternet with MIT License 5 votes vote down vote up
def __call__(self, x):
        r'''
        Evaluates E_S[softmax_{customers} min_{i \in S} dist(customer, i)] where
        the expectation is over the set of facility locations S. Every
        location is included in S independently with probability x_i.

        When ``self.hardmax`` is set, the maximum over the per-row values is
        returned instead of the softmax-weighted combination.

        NOTE(review): ``self.order`` is used as an index into ``x`` and the
        result is reduced along dim 1, so it is presumably a 2-D index
        tensor -- confirm where it is built.
        '''
        x_sort = x[self.order]
        # running probability that at least one of the first j entries of
        # each row is included
        probs = 1 - torch.cumprod(1 - x_sort, dim=1)
        vals = self.dmax + (self.m*probs).sum(dim=1)
        if self.hardmax:
            return vals.max()
        weights = torch.softmax(self.temp*vals, dim=0)
        return torch.dot(vals, weights)
Example #16
Source File: ParsingNetwork.py    From PRPN-Analysis with MIT License 5 votes vote down vote up
def forward(self, emb, parser_state):
        """Compute memory gates for the current chunk and the next state.

        Args:
            emb: embeddings of the current chunk; ``emb.size(0)`` is taken as
                the number of time steps.
            parser_state: pair ``(emb_last, cum_gate)`` carried over from the
                previous call.

        Returns:
            ``((memory_gate, memory_gate_next), gate, (emb_last, cum_gate))``
            where both memory gates are tuples with one tensor per time step
            (the result of ``torch.unbind`` along dim 1) and the last element
            is the state to pass to the next call.
        """
        emb_last, cum_gate = parser_state
        ntimestep = emb.size(0)

        emb_last = torch.cat([emb_last, emb], dim=0)
        emb = emb_last.transpose(0, 1).transpose(1, 2)  # bsz, ninp, ntimestep + nlookback

        gates = self.gate(emb)  # bsz, 2, ntimestep
        gate = gates[:, 0, :]
        gate_next = gates[:, 1, :]
        cum_gate = torch.cat([cum_gate, gate], dim=1)
        gate_hat = torch.stack([cum_gate[:, i:i + ntimestep] for i in range(self.nslots, 0, -1)],
                               dim=2)  # bsz, ntimestep, nslots

        def _memory_gate(g):
            # Compare each gate value with the preceding slots, squash the
            # difference into [0, 1], then accumulate along the slot axis.
            diff = (g[:, :, None] - gate_hat) / self.resolution
            if self.hard:
                soft = (F.hardtanh(diff * 2 + 1) + 1) / 2
            else:
                # torch.sigmoid replaces the deprecated F.sigmoid
                soft = torch.sigmoid(diff * 10 + 5)  # bsz, ntimestep, nslots
            return torch.unbind(torch.cumprod(soft, dim=2), dim=1)

        memory_gate = _memory_gate(gate)
        memory_gate_next = _memory_gate(gate_next)

        return (memory_gate, memory_gate_next), gate, (emb_last[-self.nlookback:], cum_gate[:, -self.nslots:])
Example #17
Source File: ParsingNetwork.py    From PRPN with MIT License 5 votes vote down vote up
def forward(self, emb, parser_state):
        """Compute memory gates for the current chunk and the next state.

        Args:
            emb: embeddings of the current chunk; ``emb.size(0)`` is taken as
                the number of time steps.
            parser_state: pair ``(emb_last, cum_gate)`` carried over from the
                previous call.

        Returns:
            ``((memory_gate, memory_gate_next), gate, (emb_last, cum_gate))``
            where both memory gates are tuples with one tensor per time step
            (the result of ``torch.unbind`` along dim 1) and the last element
            is the state to pass to the next call.
        """
        emb_last, cum_gate = parser_state
        ntimestep = emb.size(0)

        emb_last = torch.cat([emb_last, emb], dim=0)
        emb = emb_last.transpose(0, 1).transpose(1, 2)  # bsz, ninp, ntimestep + nlookback

        gates = self.gate(emb)  # bsz, 2, ntimestep
        gate = gates[:, 0, :]
        gate_next = gates[:, 1, :]
        cum_gate = torch.cat([cum_gate, gate], dim=1)
        gate_hat = torch.stack([cum_gate[:, i:i + ntimestep] for i in range(self.nslots, 0, -1)],
                               dim=2)  # bsz, ntimestep, nslots

        def _memory_gate(g):
            # Compare each gate value with the preceding slots, squash the
            # difference into [0, 1], then accumulate along the slot axis.
            diff = (g[:, :, None] - gate_hat) / self.resolution
            if self.hard:
                soft = (F.hardtanh(diff * 2 + 1) + 1) / 2
            else:
                # torch.sigmoid replaces the deprecated F.sigmoid
                soft = torch.sigmoid(diff * 10 + 5)  # bsz, ntimestep, nslots
            return torch.unbind(torch.cumprod(soft, dim=2), dim=1)

        memory_gate = _memory_gate(gate)
        memory_gate_next = _memory_gate(gate_next)

        return (memory_gate, memory_gate_next), gate, (emb_last[-self.nlookback:], cum_gate[:, -self.nslots:])
Example #18
Source File: rewards.py    From texar-pytorch with Apache License 2.0 5 votes vote down vote up
def _discount_reward_tensor_1d(reward: torch.Tensor,
                               sequence_length: Optional[torch.LongTensor],
                               discount: float = 1.) -> torch.Tensor:
    r"""Spread a per-sequence reward over time steps with discounting.

    Args:
        reward: 1D Tensor with shape `[batch_size]`.
        sequence_length: A Tensor of shape `[batch_size]` (anything else is
            converted with :func:`torch.tensor`). Time steps beyond the
            respective sequence lengths will be masked.
        discount (float): A scalar. The discount factor.

    Returns:
        A 2D Tensor of the discounted reward.

    Raises:
        ValueError: if :attr:`sequence_length` is `None`.
    """
    if sequence_length is None:
        raise ValueError('sequence_length must not be `None` for 1D reward.')

    if not isinstance(sequence_length, torch.Tensor):
        sequence_length = torch.tensor(
            sequence_length, dtype=torch.int64, device=reward.device)

    num_rows = reward.shape[0]
    longest = torch.max(sequence_length)
    dtype: torch.dtype = reward.dtype
    reward_column = reward.unsqueeze(-1)

    if discount == 1.:
        # No discounting: repeat every reward across the time axis.
        disc_reward = reward_column.expand(num_rows, longest)
    else:
        steps = sequence_mask(sequence_length, dtype=dtype)
        # Shift the mask left by one step, padding the end with zeros.
        steps = torch.cat((steps[:, 1:], torch.zeros_like(steps[:, -1:])), dim=1)
        # Make each row = [discount, ..., discount, 1, ..., 1]
        factors = steps * discount + (1 - steps)
        # A cumulative product taken from the right end: flip, cumprod, flip.
        factors = torch.flip(
            torch.cumprod(torch.flip(factors, (1,)), dim=1), (1,))
        disc_reward = factors * reward_column

    return mask_sequences(disc_reward, sequence_length, dtype=dtype)
Example #19
Source File: _functions.py    From garage with MIT License 4 votes vote down vote up
def compute_advantages(discount, gae_lambda, max_path_length, baselines,
                       rewards):
    """Calculate advantages with Generalized Advantage Estimation (GAE).

    The advantage at each step is a discounted cumulative sum of the
    temporal-difference residuals
    ``rewards + discount * next_baseline - baselines``.

    The sum is evaluated with a single conv2d whose filter is
        [1, (discount * gae_lambda), (discount * gae_lambda) ^ 2, ...]
    with length max_path_length.

    baselines and rewards have the same shape:
        [ [v_11, v_12, v_13, ... v_1n],
          [v_21, v_22, v_23, ... v_2n],
          ...
          [v_m1, v_m2, v_m3, ... v_mn] ]

    Args:
        discount (float): RL discount factor (i.e. gamma).
        gae_lambda (float): Lambda, as used for Generalized Advantage
            Estimation (GAE).
        max_path_length (int): Maximum length of a single rollout.
        baselines (torch.Tensor): A 2D vector of value function estimates with
            shape (N, T), where N is the batch dimension (number of episodes)
            and T is the maximum path length experienced by the agent. If an
            episode terminates in fewer than T time steps, the remaining
            elements in that episode should be set to 0.
        rewards (torch.Tensor): A 2D vector of per-step rewards with shape
            (N, T), padded with 0 in the same way as baselines.

    Returns:
        torch.Tensor: A 2D vector of calculated advantage values with shape
            (N, T), i.e. the shape of rewards.

    """
    tail = max_path_length - 1

    # Filter [1, d, d^2, ...] with d = discount * gae_lambda: a leading 1
    # followed by `tail` copies of d, turned into powers by cumprod.
    decay = torch.full((1, 1, 1, tail), discount * gae_lambda)
    kernel = torch.cumprod(F.pad(decay, (1, 0), value=1), dim=-1)

    # Baseline of the following step; the step after the last one counts as 0.
    next_baselines = F.pad(baselines, (0, 1))[:, 1:]
    residuals = (rewards + discount * next_baselines - baselines)

    # Zero-pad on the right so the filter can slide over every time step,
    # then add the two leading axes conv2d expects.
    padded = F.pad(residuals, (0, tail))[None, None]

    return F.conv2d(padded, kernel, stride=1).reshape(rewards.shape)