Python tensorflow.python.ops.array_ops.matrix_set_diag() Examples

The following are 30 code examples of tensorflow.python.ops.array_ops.matrix_set_diag(), drawn from open-source projects; each example notes its original source file, project, and license. You may also want to check out all available functions/classes of the module tensorflow.python.ops.array_ops.
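For orientation: matrix_set_diag(input, diagonal) returns a copy of input with its main diagonal(s) replaced by diagonal, leaving every off-diagonal entry untouched. A minimal sketch, assuming a TensorFlow 1.x environment where tf.matrix_set_diag is the public alias of this op:

import numpy as np
import tensorflow as tf

# Batch of two 2x2 matrices; replace each main diagonal with zeros.
mat = tf.constant(np.arange(8, dtype=np.float32).reshape(2, 2, 2))
zeroed = tf.matrix_set_diag(mat, tf.zeros([2, 2]))

with tf.Session() as sess:
    print(sess.run(zeroed))  # off-diagonal entries unchanged, diagonals now 0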
Example #1
Source File: linear_operator_identity.py    From keras-lambda with MIT License
def add_to_tensor(self, mat, name="add_to_tensor"):
    """Add matrix represented by this operator to `mat`.  Equiv to `I + mat`.

    Args:
      mat:  `Tensor` with same `dtype` and shape broadcastable to `self`.
      name:  A name to give this `Op`.

    Returns:
      A `Tensor` with broadcast shape and same `dtype` as `self`.
    """
    with self._name_scope(name, values=[mat]):
      # Shape [B1,...,Bb, 1]
      multiplier_vector = array_ops.expand_dims(self.multiplier, -1)

      # Shape [C1,...,Cc, M, M]
      mat = ops.convert_to_tensor(mat, name="mat")

      # Shape [C1,...,Cc, M]
      mat_diag = array_ops.matrix_diag_part(mat)

      # multiplier_vector broadcasts here.
      new_diag = multiplier_vector + mat_diag

      return array_ops.matrix_set_diag(mat, new_diag) 
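Note how the operator never materializes a dense multiplier * I: it adds the scalar to the diagonal only, which is O(M) work per matrix instead of O(M^2). A tiny self-contained sketch of the same pattern, using public TF 1.x aliases and illustrative values:

mat = tf.constant([[1., 2.], [3., 4.]])
multiplier = tf.constant(10.)
new_diag = tf.matrix_diag_part(mat) + multiplier  # scalar broadcasts over the diagonal
result = tf.matrix_set_diag(mat, new_diag)        # [[11., 2.], [3., 14.]]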
Example #2
Source File: array_grad.py    From keras-lambda with MIT License
def _MatrixSetDiagGrad(op, grad):
  input_shape = op.inputs[0].get_shape().merge_with(grad.get_shape())
  diag_shape = op.inputs[1].get_shape()
  batch_shape = input_shape[:-2].merge_with(diag_shape[:-1])
  matrix_shape = input_shape[-2:]
  if batch_shape.is_fully_defined() and matrix_shape.is_fully_defined():
    diag_shape = batch_shape.as_list() + [min(matrix_shape.as_list())]
  else:
    with ops.colocate_with(grad):
      grad_shape = array_ops.shape(grad)
      grad_rank = array_ops.rank(grad)
      batch_shape = array_ops.slice(grad_shape, [0], [grad_rank - 2])
      matrix_shape = array_ops.slice(grad_shape, [grad_rank - 2], [2])
      min_dim = math_ops.reduce_min(matrix_shape)
      diag_shape = array_ops.concat([batch_shape, [min_dim]], 0)
  grad_input = array_ops.matrix_set_diag(
      grad, array_ops.zeros(
          diag_shape, dtype=grad.dtype))
  grad_diag = array_ops.matrix_diag_part(grad)
  return (grad_input, grad_diag) 
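The gradient above splits the upstream grad in two: the original input receives grad with its diagonal zeroed (those entries were overwritten in the forward pass, so they contribute nothing), while the diagonal argument receives exactly the diagonal of grad. A quick sanity check, sketched for TF 1.x graph mode:

x = tf.ones([3, 3])
d = tf.zeros([3])
y = tf.matrix_set_diag(x, d)
gx, gd = tf.gradients(tf.reduce_sum(y), [x, d])
# gx: all ones with a zero diagonal; gd: ones of shape [3].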
Example #3
Source File: linalg_grad.py    From Serverless-Deep-Learning-with-TensorFlow-and-AWS-Lambda with MIT License
def _CholeskyGrad(op, grad):
  """Gradient for Cholesky."""

  # Gradient is l^{-H} @ ((l^{H} @ grad) * (tril(ones)-1/2*eye)) @ l^{-1}
  l = op.outputs[0]
  num_rows = array_ops.shape(l)[-1]
  batch_shape = array_ops.shape(l)[:-2]
  l_inverse = linalg_ops.matrix_triangular_solve(l,
                                                 linalg_ops.eye(
                                                     num_rows,
                                                     batch_shape=batch_shape,
                                                     dtype=l.dtype))

  middle = math_ops.matmul(l, grad, adjoint_a=True)
  middle = array_ops.matrix_set_diag(middle,
                                     0.5 * array_ops.matrix_diag_part(middle))
  middle = array_ops.matrix_band_part(middle, -1, 0)

  grad_a = math_ops.matmul(
      math_ops.matmul(l_inverse, middle, adjoint_a=True), l_inverse)

  grad_a += math_ops.conj(array_ops.matrix_transpose(grad_a))
  return grad_a * 0.5 
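Here matrix_set_diag is how the (tril(ones)-1/2*eye) mask from the comment is realized: matrix_band_part keeps the lower triangle and matrix_set_diag halves the diagonal. That masking step in isolation (a sketch with illustrative values):

m = tf.constant([[2., 8.], [6., 4.]])
masked = tf.matrix_band_part(
    tf.matrix_set_diag(m, 0.5 * tf.matrix_diag_part(m)), -1, 0)
# [[1., 0.], [6., 2.]]: strictly upper triangle zeroed, diagonal halved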
Example #4
Source File: array_grad.py    From Serverless-Deep-Learning-with-TensorFlow-and-AWS-Lambda with MIT License
def _MatrixSetDiagGrad(op, grad):
  """Gradient for MatrixSetDiag."""
  input_shape = op.inputs[0].get_shape().merge_with(grad.get_shape())
  diag_shape = op.inputs[1].get_shape()
  batch_shape = input_shape[:-2].merge_with(diag_shape[:-1])
  matrix_shape = input_shape[-2:]
  if batch_shape.is_fully_defined() and matrix_shape.is_fully_defined():
    diag_shape = batch_shape.as_list() + [min(matrix_shape.as_list())]
  else:
    with ops.colocate_with(grad):
      grad_shape = array_ops.shape(grad)
      grad_rank = array_ops.rank(grad)
      batch_shape = array_ops.slice(grad_shape, [0], [grad_rank - 2])
      matrix_shape = array_ops.slice(grad_shape, [grad_rank - 2], [2])
      min_dim = math_ops.reduce_min(matrix_shape)
      diag_shape = array_ops.concat([batch_shape, [min_dim]], 0)
  grad_input = array_ops.matrix_set_diag(
      grad, array_ops.zeros(
          diag_shape, dtype=grad.dtype))
  grad_diag = array_ops.matrix_diag_part(grad)
  return (grad_input, grad_diag) 
Example #5
Source File: linear_operator_identity.py    From auto-alt-text-lambda-api with MIT License
def add_to_tensor(self, mat, name="add_to_tensor"):
    """Add matrix represented by this operator to `mat`.  Equiv to `I + mat`.

    Args:
      mat:  `Tensor` with same `dtype` and shape broadcastable to `self`.
      name:  A name to give this `Op`.

    Returns:
      A `Tensor` with broadcast shape and same `dtype` as `self`.
    """
    with self._name_scope(name, values=[mat]):
      # Shape [B1,...,Bb, 1]
      multiplier_vector = array_ops.expand_dims(self.multiplier, -1)

      # Shape [C1,...,Cc, M, M]
      mat = ops.convert_to_tensor(mat, name="mat")

      # Shape [C1,...,Cc, M]
      mat_diag = array_ops.matrix_diag_part(mat)

      # multiplier_vector broadcasts here.
      new_diag = multiplier_vector + mat_diag

      return array_ops.matrix_set_diag(mat, new_diag) 
Example #6
Source File: array_grad.py    From auto-alt-text-lambda-api with MIT License
def _MatrixSetDiagGrad(op, grad):
  input_shape = op.inputs[0].get_shape().merge_with(grad.get_shape())
  diag_shape = op.inputs[1].get_shape()
  batch_shape = input_shape[:-2].merge_with(diag_shape[:-1])
  matrix_shape = input_shape[-2:]
  if batch_shape.is_fully_defined() and matrix_shape.is_fully_defined():
    diag_shape = batch_shape.as_list() + [min(matrix_shape.as_list())]
  else:
    with ops.colocate_with(grad):
      grad_shape = array_ops.shape(grad)
      grad_rank = array_ops.rank(grad)
      batch_shape = array_ops.slice(grad_shape, [0], [grad_rank - 2])
      matrix_shape = array_ops.slice(grad_shape, [grad_rank - 2], [2])
      min_dim = math_ops.reduce_min(matrix_shape)
      diag_shape = array_ops.concat([batch_shape, [min_dim]], 0)
  grad_input = array_ops.matrix_set_diag(
      grad, array_ops.zeros(
          diag_shape, dtype=grad.dtype))
  grad_diag = array_ops.matrix_diag_part(grad)
  return (grad_input, grad_diag) 
Example #7
Source File: array_grad.py    From deep_image_model with Apache License 2.0
def _MatrixSetDiagGrad(op, grad):
  input_shape = op.inputs[0].get_shape().merge_with(grad.get_shape())
  diag_shape = op.inputs[1].get_shape()
  batch_shape = input_shape[:-2].merge_with(diag_shape[:-1])
  matrix_shape = input_shape[-2:]
  if batch_shape.is_fully_defined() and matrix_shape.is_fully_defined():
    diag_shape = batch_shape.as_list() + [min(matrix_shape.as_list())]
  else:
    with ops.colocate_with(grad):
      grad_shape = array_ops.shape(grad)
      grad_rank = array_ops.rank(grad)
      batch_shape = array_ops.slice(grad_shape, [0], [grad_rank - 2])
      matrix_shape = array_ops.slice(grad_shape, [grad_rank - 2], [2])
      min_dim = math_ops.reduce_min(matrix_shape)
      diag_shape = array_ops.concat(0, [batch_shape, [min_dim]])
  grad_input = array_ops.matrix_set_diag(
      grad, array_ops.zeros(
          diag_shape, dtype=grad.dtype))
  grad_diag = array_ops.matrix_diag_part(grad)
  return (grad_input, grad_diag) 
Example #8
Source File: array_grad.py    From lambda-packs with MIT License
def _MatrixSetDiagGrad(op, grad):
  """Gradient for MatrixSetDiag."""
  input_shape = op.inputs[0].get_shape().merge_with(grad.get_shape())
  diag_shape = op.inputs[1].get_shape()
  batch_shape = input_shape[:-2].merge_with(diag_shape[:-1])
  matrix_shape = input_shape[-2:]
  if batch_shape.is_fully_defined() and matrix_shape.is_fully_defined():
    diag_shape = batch_shape.as_list() + [min(matrix_shape.as_list())]
  else:
    with ops.colocate_with(grad):
      grad_shape = array_ops.shape(grad)
      grad_rank = array_ops.rank(grad)
      batch_shape = array_ops.slice(grad_shape, [0], [grad_rank - 2])
      matrix_shape = array_ops.slice(grad_shape, [grad_rank - 2], [2])
      min_dim = math_ops.reduce_min(matrix_shape)
      diag_shape = array_ops.concat([batch_shape, [min_dim]], 0)
  grad_input = array_ops.matrix_set_diag(
      grad, array_ops.zeros(
          diag_shape, dtype=grad.dtype))
  grad_diag = array_ops.matrix_diag_part(grad)
  return (grad_input, grad_diag) 
Example #9
Source File: dirichlet_multinomial.py    From Serverless-Deep-Learning-with-TensorFlow-and-AWS-Lambda with MIT License
def _covariance(self):
    x = self._variance_scale_term() * self._mean()
    return array_ops.matrix_set_diag(
        -math_ops.matmul(x[..., array_ops.newaxis],
                         x[..., array_ops.newaxis, :]),  # outer prod
        self._variance()) 
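This is a recurring pattern in the distribution classes: build the covariance as a negative outer product (which gives the correct off-diagonal entries), then overwrite the diagonal, where the outer-product formula would be wrong, with the true variances. A numeric sketch of the pattern with made-up probabilities:

p = tf.constant([0.2, 0.3, 0.5])
cov = tf.matrix_set_diag(
    -tf.matmul(p[:, None], p[None, :]),  # off-diagonal: -p_i * p_j
    p * (1. - p))                        # diagonal: p_i * (1 - p_i)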
Example #10
Source File: linalg_grad.py    From keras-lambda with MIT License
def _SelfAdjointEigV2Grad(op, grad_e, grad_v):
  """Gradient for SelfAdjointEigV2."""
  e = op.outputs[0]
  v = op.outputs[1]
  # a = op.inputs[0], which satisfies
  # a[...,:,:] * v[...,:,i] = e[...,i] * v[...,:,i]
  with ops.control_dependencies([grad_e.op, grad_v.op]):
    if grad_v is not None:
      # Construct the matrix f(i,j) = (i != j ? 1 / (e_i - e_j) : 0).
      # Notice that because of the term involving f, the gradient becomes
      # infinite (or NaN in practice) when eigenvalues are not unique.
      # Mathematically this should not be surprising, since for (k-fold)
      # degenerate eigenvalues, the corresponding eigenvectors are only defined
      # up to arbitrary rotation in a (k-dimensional) subspace.
      f = array_ops.matrix_set_diag(
          math_ops.reciprocal(
              array_ops.expand_dims(e, -2) - array_ops.expand_dims(e, -1)),
          array_ops.zeros_like(e))
      grad_a = math_ops.matmul(
          v,
          math_ops.matmul(
              array_ops.matrix_diag(grad_e) + f * math_ops.matmul(
                  v, grad_v, adjoint_a=True),
              v,
              adjoint_b=True))
    else:
      grad_a = math_ops.matmul(
          v, math_ops.matmul(
              array_ops.matrix_diag(grad_e), v, adjoint_b=True))
    # The forward op only depends on the lower triangular part of a, so here we
    # symmetrize and take the lower triangle
    grad_a = array_ops.matrix_band_part(
        grad_a + array_ops.matrix_transpose(grad_a), -1, 0)
    grad_a = array_ops.matrix_set_diag(grad_a,
                                       0.5 * array_ops.matrix_diag_part(grad_a))
    return grad_a 
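The matrix_set_diag call builds f by zeroing the diagonal, where e_i - e_i = 0 would otherwise produce an inf. Just that construction, sketched on its own:

e = tf.constant([1., 2., 4.])
f = tf.matrix_set_diag(
    tf.reciprocal(e[None, :] - e[:, None]),  # entry (i, j) = 1 / (e_j - e_i)
    tf.zeros_like(e))                        # replace the inf diagonal with 0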
Example #11
Source File: linalg_grad.py    From Serverless-Deep-Learning-with-TensorFlow-and-AWS-Lambda with MIT License
def _SelfAdjointEigV2Grad(op, grad_e, grad_v):
  """Gradient for SelfAdjointEigV2."""
  e = op.outputs[0]
  compute_v = op.get_attr("compute_v")
  # a = op.inputs[0], which satisfies
  # a[...,:,:] * v[...,:,i] = e[...,i] * v[...,:,i]
  with ops.control_dependencies([grad_e, grad_v]):
    if compute_v:
      v = op.outputs[1]
      # Construct the matrix f(i,j) = (i != j ? 1 / (e_i - e_j) : 0).
      # Notice that because of the term involving f, the gradient becomes
      # infinite (or NaN in practice) when eigenvalues are not unique.
      # Mathematically this should not be surprising, since for (k-fold)
      # degenerate eigenvalues, the corresponding eigenvectors are only defined
      # up to arbitrary rotation in a (k-dimensional) subspace.
      f = array_ops.matrix_set_diag(
          math_ops.reciprocal(
              array_ops.expand_dims(e, -2) - array_ops.expand_dims(e, -1)),
          array_ops.zeros_like(e))
      grad_a = math_ops.matmul(
          v,
          math_ops.matmul(
              array_ops.matrix_diag(grad_e) +
              f * math_ops.matmul(v, grad_v, adjoint_a=True),
              v,
              adjoint_b=True))
    else:
      _, v = linalg_ops.self_adjoint_eig(op.inputs[0])
      grad_a = math_ops.matmul(v,
                               math_ops.matmul(
                                   array_ops.matrix_diag(grad_e),
                                   v,
                                   adjoint_b=True))
    # The forward op only depends on the lower triangular part of a, so here we
    # symmetrize and take the lower triangle
    grad_a = array_ops.matrix_band_part(
        grad_a + math_ops.conj(array_ops.matrix_transpose(grad_a)), -1, 0)
    grad_a = array_ops.matrix_set_diag(grad_a,
                                       0.5 * array_ops.matrix_diag_part(grad_a))
    return grad_a 
Example #12
Source File: bijector.py    From keras-lambda with MIT License
def _preprocess_tril(self, identity_multiplier, diag, tril, event_ndims):
    """Helper to preprocess a lower triangular matrix."""
    tril = array_ops.matrix_band_part(tril, -1, 0)  # Zero out TriU.
    if identity_multiplier is None and diag is None:
      return self._process_matrix(tril, min_rank=2, event_ndims=event_ndims)
    new_diag = array_ops.matrix_diag_part(tril)
    if identity_multiplier is not None:
      new_diag += identity_multiplier
    if diag is not None:
      new_diag += diag
    tril = array_ops.matrix_set_diag(tril, new_diag)
    return self._process_matrix(tril, min_rank=2, event_ndims=event_ndims) 
Example #13
Source File: linear_operator_identity.py    From keras-lambda with MIT License
def add_to_tensor(self, mat, name="add_to_tensor"):
    """Add matrix represented by this operator to `mat`.  Equiv to `I + mat`.

    Args:
      mat:  `Tensor` with same `dtype` and shape broadcastable to `self`.
      name:  A name to give this `Op`.

    Returns:
      A `Tensor` with broadcast shape and same `dtype` as `self`.
    """
    with self._name_scope(name, values=[mat]):
      mat = ops.convert_to_tensor(mat, name="mat")
      mat_diag = array_ops.matrix_diag_part(mat)
      new_diag = 1 + mat_diag
      return array_ops.matrix_set_diag(mat, new_diag) 
Example #14
Source File: linear_operator_test_util.py    From keras-lambda with MIT License
def random_tril_matrix(shape,
                       dtype,
                       force_well_conditioned=False,
                       remove_upper=True):
  """[batch] lower triangular matrix.

  Args:
    shape:  `TensorShape` or Python `list`.  Shape of the returned matrix.
    dtype:  `TensorFlow` `dtype` or Python dtype
    force_well_conditioned:  Python `bool`. If `True`, returned matrix will have
      eigenvalues with modulus in `(1, 2)`.  Otherwise, eigenvalues are unit
      normal random variables.
    remove_upper:  Python `bool`.
      If `True`, zero out the strictly upper triangle.
      If `False`, the lower triangle of returned matrix will have desired
      properties, but will not have the strictly upper triangle zero'd out.

  Returns:
    `Tensor` with desired shape and dtype.
  """
  with ops.name_scope("random_tril_matrix"):
    # Totally random matrix.  Has no nice properties.
    tril = random_normal(shape, dtype=dtype)
    if remove_upper:
      tril = array_ops.matrix_band_part(tril, -1, 0)

    # Create a diagonal with entries having modulus in [1, sqrt(2)].
    if force_well_conditioned:
      maxval = ops.convert_to_tensor(np.sqrt(2.), dtype=dtype.real_dtype)
      diag = random_sign_uniform(
          shape[:-1], dtype=dtype, minval=1., maxval=maxval)
      tril = array_ops.matrix_set_diag(tril, diag)

    return tril 
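Setting the diagonal to entries bounded away from zero is what keeps the triangular factor well conditioned. Hypothetical usage of the helper above (random_normal and random_sign_uniform are sibling helpers in the same test-util module, and dtypes is assumed imported from tensorflow.python.framework):

tril = random_tril_matrix([4, 3, 3], dtypes.float32, force_well_conditioned=True)
# Batch of four 3x3 lower-triangular matrices; each diagonal entry has
# modulus in [1, sqrt(2)], so every factor is safely invertible.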
Example #15
Source File: linear_operator_diag.py    From keras-lambda with MIT License
def _add_to_tensor(self, x):
    x_diag = array_ops.matrix_diag_part(x)
    new_diag = self._diag + x_diag
    return array_ops.matrix_set_diag(x, new_diag) 
Example #16
Source File: operator_pd_diag.py    From keras-lambda with MIT License
def _add_to_tensor(self, mat):
    mat_diag = array_ops.matrix_diag_part(mat)
    new_diag = math_ops.square(self._diag) + mat_diag
    return array_ops.matrix_set_diag(mat, new_diag) 
Example #17
Source File: operator_pd_diag.py    From keras-lambda with MIT License
def _add_to_tensor(self, mat):
    mat_diag = array_ops.matrix_diag_part(mat)
    new_diag = self._diag + mat_diag
    return array_ops.matrix_set_diag(mat, new_diag) 
Example #18
Source File: operator_pd_identity.py    From keras-lambda with MIT License
def _add_to_tensor(self, mat):
    # Add to a tensor in O(k) time!
    mat_diag = array_ops.matrix_diag_part(mat)
    new_diag = self._scale + mat_diag
    return array_ops.matrix_set_diag(mat, new_diag) 
Example #19
Source File: multinomial.py    From lambda-packs with MIT License
def _covariance(self):
    p = self.probs * array_ops.ones_like(
        self.total_count)[..., array_ops.newaxis]
    return array_ops.matrix_set_diag(
        -math_ops.matmul(self._mean_val[..., array_ops.newaxis],
                         p[..., array_ops.newaxis, :]),  # outer product
        self._variance()) 
Example #20
Source File: multinomial.py    From Serverless-Deep-Learning-with-TensorFlow-and-AWS-Lambda with MIT License
def _covariance(self):
    p = self.probs * array_ops.ones_like(
        self.total_count)[..., array_ops.newaxis]
    return array_ops.matrix_set_diag(
        -math_ops.matmul(self._mean_val[..., array_ops.newaxis],
                         p[..., array_ops.newaxis, :]),  # outer product
        self._variance()) 
Example #21
Source File: dirichlet_multinomial.py    From lambda-packs with MIT License
def _covariance(self):
    x = self._variance_scale_term() * self._mean()
    return array_ops.matrix_set_diag(
        -math_ops.matmul(x[..., array_ops.newaxis],
                         x[..., array_ops.newaxis, :]),  # outer prod
        self._variance()) 
Example #22
Source File: array_grad.py    From Serverless-Deep-Learning-with-TensorFlow-and-AWS-Lambda with MIT License
def _MatrixDiagPartGrad(op, grad):
  matrix_shape = op.inputs[0].get_shape()[-2:]
  if matrix_shape.is_fully_defined() and matrix_shape[0] == matrix_shape[1]:
    return array_ops.matrix_diag(grad)
  else:
    return array_ops.matrix_set_diag(array_ops.zeros_like(op.inputs[0]), grad) 
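For a non-square input, matrix_diag would build the wrong shape, so the gradient instead scatters grad onto the diagonal of a zero matrix shaped like the input. The rectangular branch in isolation (a sketch):

x = tf.zeros([2, 3])                          # rectangular forward input
g = tf.constant([1., 2.])                     # upstream grad, length min(2, 3)
gx = tf.matrix_set_diag(tf.zeros_like(x), g)
# [[1., 0., 0.],
#  [0., 2., 0.]]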
Example #23
Source File: dirichlet.py    From deep_image_model with Apache License 2.0
def _variance(self):
    scale = self.alpha_sum * math_ops.sqrt(1. + self.alpha_sum)
    alpha = self.alpha / scale
    outer_prod = -math_ops.batch_matmul(
        array_ops.expand_dims(alpha, dim=-1),  # column
        array_ops.expand_dims(alpha, dim=-2))  # row
    return array_ops.matrix_set_diag(outer_prod,
                                     alpha * (self.alpha_sum / scale - alpha)) 
Example #24
Source File: operator_pd_diag.py    From deep_image_model with Apache License 2.0
def _add_to_tensor(self, mat):
    mat_diag = array_ops.matrix_diag_part(mat)
    new_diag = math_ops.square(self._diag) + mat_diag
    return array_ops.matrix_set_diag(mat, new_diag) 
Example #25
Source File: operator_pd_diag.py    From deep_image_model with Apache License 2.0
def _add_to_tensor(self, mat):
    mat_diag = array_ops.matrix_diag_part(mat)
    new_diag = self._diag + mat_diag
    return array_ops.matrix_set_diag(mat, new_diag) 
Example #26
Source File: dirichlet_multinomial.py    From deep_image_model with Apache License 2.0
def _variance(self):
    alpha_sum = array_ops.expand_dims(self.alpha_sum, -1)
    normalized_alpha = self.alpha / alpha_sum
    variance = -math_ops.batch_matmul(
        array_ops.expand_dims(normalized_alpha, -1),
        array_ops.expand_dims(normalized_alpha, -2))
    variance = array_ops.matrix_set_diag(variance, normalized_alpha *
                                         (1. - normalized_alpha))
    shared_factor = (self.n * (alpha_sum + self.n) /
                     (alpha_sum + 1) * array_ops.ones_like(self.alpha))
    variance *= array_ops.expand_dims(shared_factor, -1)
    return variance 
Example #27
Source File: multinomial.py    From deep_image_model with Apache License 2.0
def _variance(self):
    p = self.p * array_ops.expand_dims(array_ops.ones_like(self.n), -1)
    outer_prod = math_ops.batch_matmul(
        array_ops.expand_dims(self._mean_val, -1),
        array_ops.expand_dims(p, -2))
    return array_ops.matrix_set_diag(-outer_prod,
                                     self._mean_val - self._mean_val * p) 
Example #28
Source File: onehot_categorical.py    From lambda-packs with MIT License
def _covariance(self):
    p = self.probs
    ret = -math_ops.matmul(p[..., None], p[..., None, :])
    return array_ops.matrix_set_diag(ret, self._variance()) 
Example #29
Source File: operator_pd_identity.py    From deep_image_model with Apache License 2.0
def _add_to_tensor(self, mat):
    # Add to a tensor in O(k) time!
    mat_diag = array_ops.matrix_diag_part(mat)
    new_diag = constant_op.constant(1, dtype=self.dtype) + mat_diag
    return array_ops.matrix_set_diag(mat, new_diag) 
Example #30
Source File: linalg_grad.py    From deep_image_model with Apache License 2.0
def _SelfAdjointEigV2Grad(op, grad_e, grad_v):
  """Gradient for SelfAdjointEigV2."""
  e = op.outputs[0]
  v = op.outputs[1]
  # a = op.inputs[0], which satisfies
  # a[...,:,:] * v[...,:,i] = e[...,i] * v[...,:,i]
  with ops.control_dependencies([grad_e.op, grad_v.op]):
    if grad_v is not None:
      # Construct the matrix f(i,j) = (i != j ? 1 / (e_i - e_j) : 0).
      # Notice that because of the term involving f, the gradient becomes
      # infinite (or NaN in practice) when eigenvalues are not unique.
      # Mathematically this should not be surprising, since for (k-fold)
      # degenerate eigenvalues, the corresponding eigenvectors are only defined
      # up to arbitrary rotation in a (k-dimensional) subspace.
      f = array_ops.matrix_set_diag(
          math_ops.inv(
              array_ops.expand_dims(e, -2) - array_ops.expand_dims(e, -1)),
          array_ops.zeros_like(e))
      grad_a = math_ops.batch_matmul(
          v,
          math_ops.batch_matmul(
              array_ops.matrix_diag(grad_e) + f * math_ops.batch_matmul(
                  v, grad_v, adj_x=True),
              v,
              adj_y=True))
    else:
      grad_a = math_ops.batch_matmul(
          v,
          math_ops.batch_matmul(
              array_ops.matrix_diag(grad_e), v, adj_y=True))
    # The forward op only depends on the lower triangular part of a, so here we
    # symmetrize and take the lower triangle
    grad_a = array_ops.matrix_band_part(
        grad_a + array_ops.matrix_transpose(grad_a), -1, 0)
    grad_a = array_ops.matrix_set_diag(grad_a,
                                       0.5 * array_ops.matrix_diag_part(grad_a))
    return grad_a