Python tensorflow.python.ops.control_flow_ops._AddNextAndBackEdge() Examples

The following are 5 code examples of tensorflow.python.ops.control_flow_ops._AddNextAndBackEdge(), drawn from open-source projects; the source file, project, and license are noted above each example. You may also want to check out all available functions/classes of the module tensorflow.python.ops.control_flow_ops.
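Before the examples, a minimal sketch of where this private helper comes into play. It uses only public TF1 graph-mode APIs: taking a gradient through tf.while_loop builds a backprop loop, and it is during that construction that gradient functions like the _SwitchGrad variants below call control_flow_ops._AddNextAndBackEdge() internally.

# Minimal sketch (TF1 graph mode): differentiating through a while loop
# is the situation in which _AddNextAndBackEdge() gets invoked internally.
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

x = tf.constant(2.0)
i0 = tf.constant(0)
# Square the accumulator three times: result = x ** 8.
_, result = tf.while_loop(lambda i, a: i < 3,
                          lambda i, a: (i + 1, a * a),
                          [i0, x])

# Building this gradient visits each loop Switch twice; the second visit
# is where _AddNextAndBackEdge() closes the backprop loop's cycle.
grad = tf.gradients(result, x)[0]

with tf.Session() as sess:
    print(sess.run(grad))  # 8 * x**7 = 1024.0 at x = 2.0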
Example #1
Source File: control_flow_grad.py    From deep_image_model with Apache License 2.0
# Imports as in the module header of control_flow_grad.py, reproduced so
# this excerpt is self-contained (the wildcard import supplies merge,
# switch, CondContext, and WhileContext):
from tensorflow.python.framework import ops
from tensorflow.python.ops import control_flow_ops
# pylint: disable=wildcard-import
from tensorflow.python.ops.control_flow_ops import *
# pylint: enable=wildcard-import


def _SwitchGrad(op, *grad):
  """Gradients for a Switch op is calculated using a Merge op.

  If the switch is a loop switch, it will be visited twice. We create
  the merge on the first visit, and update the other input of the merge
  on the second visit. A next_iteration is also added on the second visit.
  """
  graph = ops.get_default_graph()
  # pylint: disable=protected-access
  op_ctxt = op._get_control_flow_context()
  grad_ctxt = graph._get_control_flow_context()
  # pylint: enable=protected-access
  if isinstance(op_ctxt, WhileContext):
    merge_grad = grad_ctxt.grad_state.switch_map.get(op)
    if merge_grad is not None:
      # This is the second time this Switch is visited. It comes from
      # the non-exit branch of the Switch, so update the second input
      # to the Merge.
      # TODO: Perform shape inference with this new input.
      if grad[1] is not None:
        # pylint: disable=protected-access
        control_flow_ops._AddNextAndBackEdge(merge_grad, grad[1])
        # pylint: enable=protected-access
      return None, None
    else:
      # This is the first time this Switch is visited. It always comes from
      # the Exit branch, which is grad[0]. grad[1] is empty at this point.
      # Use grad[0] for both inputs to merge for now, but update the second
      # input of merge when we see this Switch the second time.
      merge_grad = merge([grad[0], grad[0]], name="b_switch")[0]
      grad_ctxt.grad_state.switch_map[op] = merge_grad
      return merge_grad, None
  elif isinstance(op_ctxt, CondContext):
    good_grad = grad[op_ctxt.branch]
    zero_grad = grad[1 - op_ctxt.branch]
    # At this point, we have created zero_grad guarded by the right switch.
    return merge([good_grad, zero_grad], name="cond_grad")[0], None
  else:
    false_grad = switch(grad[0], op.inputs[1])[0]
    true_grad = switch(grad[1], op.inputs[1])[1]
    return merge([false_grad, true_grad])[0], None 
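These gradient functions are not called directly; in the TensorFlow source they are registered against the op type at module level. The sketch below uses the real ops.RegisterGradient API, but the exact set of op types registered is my assumption based on the TF source, and re-running it in user code would clash with the built-in registration.

# Sketch: how _SwitchGrad is wired into autodiff in control_flow_grad.py.
# The RefSwitch registration is assumed from the surrounding module.
ops.RegisterGradient("Switch")(_SwitchGrad)
ops.RegisterGradient("RefSwitch")(_SwitchGrad)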
Example #2
Source File: control_flow_grad.py    From lambda-packs with MIT License
def _SwitchGrad(op, *grad):
  """Gradients for a Switch op is calculated using a Merge op.

  If the switch is a loop switch, it will be visited twice. We create
  the merge on the first visit, and update the other input of the merge
  on the second visit. A next_iteration is also added on the second visit.
  """
  graph = ops.get_default_graph()
  # pylint: disable=protected-access
  op_ctxt = op._get_control_flow_context()
  grad_ctxt = graph._get_control_flow_context()
  # pylint: enable=protected-access
  if isinstance(op_ctxt, WhileContext):
    merge_grad = grad_ctxt.grad_state.switch_map.get(op)
    if merge_grad is not None:
      # This is the second time this Switch is visited. It comes from
      # the non-exit branch of the Switch, so update the second input
      # to the Merge.
      # TODO(yuanbyu): Perform shape inference with this new input.
      if grad[1] is not None:
        # pylint: disable=protected-access
        control_flow_ops._AddNextAndBackEdge(merge_grad, grad[1])
        # pylint: enable=protected-access
      return None, None
    elif grad[0] is not None:
      # This is the first time this Switch is visited. It comes from
      # the Exit branch, which is grad[0]. grad[1] is empty at this point.
      # Use grad[0] for both inputs to merge for now, but update the second
      # input of merge when we see this Switch the second time.
      merge_grad = merge([grad[0], grad[0]], name="b_switch")[0]
      grad_ctxt.grad_state.switch_map[op] = merge_grad
      return merge_grad, None
    else:
      # This is the first time this Switch is visited. It comes from the
      # Identity branch. Such a Switch has `None` gradient for the Exit branch,
      # meaning the output is not differentiable.
      return None, None
  elif isinstance(op_ctxt, CondContext):
    good_grad = grad[op_ctxt.branch]
    zero_grad = grad[1 - op_ctxt.branch]
    # At this point, we have created zero_grad guarded by the right switch.
    return merge([good_grad, zero_grad], name="cond_grad")[0], None
  else:
    false_grad = switch(grad[0], op.inputs[1])[0]
    true_grad = switch(grad[1], op.inputs[1])[1]
    return merge([false_grad, true_grad])[0], None 
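For orientation, here is a conceptual sketch of what the _AddNextAndBackEdge(merge_grad, grad[1]) call above accomplishes, inferred from the surrounding comments. The helper name and exact mechanics below are assumptions, not the real implementation; the actual helper in control_flow_ops.py also handles ref types, IndexedSlices, and shape invariants.

# Conceptual sketch only -- assumed names, not the real implementation.
def _add_next_and_back_edge_sketch(merge_tensor, value):
    # Wrap the newly arrived gradient in a NextIteration op ...
    next_it = control_flow_ops.next_iteration(value)  # assumed helper name
    # ... and splice it in as the Merge's second input, replacing the
    # grad[0] placeholder wired up on the first visit. This closes the
    # cycle of the backprop while loop.
    merge_tensor.op._update_input(1, next_it)  # pylint: disable=protected-access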
Example #3
Source File: control_flow_grad.py    From auto-alt-text-lambda-api with MIT License
def _SwitchGrad(op, *grad):
  """Gradients for a Switch op is calculated using a Merge op.

  If the switch is a loop switch, it will be visited twice. We create
  the merge on the first visit, and update the other input of the merge
  on the second visit. A next_iteration is also added on the second visit.
  """
  graph = ops.get_default_graph()
  # pylint: disable=protected-access
  op_ctxt = op._get_control_flow_context()
  grad_ctxt = graph._get_control_flow_context()
  # pylint: enable=protected-access
  if isinstance(op_ctxt, WhileContext):
    merge_grad = grad_ctxt.grad_state.switch_map.get(op)
    if merge_grad is not None:
      # This is the second time this Switch is visited. It comes from
      # the non-exit branch of the Switch, so update the second input
      # to the Merge.
      # TODO: Perform shape inference with this new input.
      if grad[1] is not None:
        # pylint: disable=protected-access
        control_flow_ops._AddNextAndBackEdge(merge_grad, grad[1])
        # pylint: enable=protected-access
      return None, None
    elif grad[0] is not None:
      # This is the first time this Switch is visited. It comes from
      # the Exit branch, which is grad[0]. grad[1] is empty at this point.
      # Use grad[0] for both inputs to merge for now, but update the second
      # input of merge when we see this Switch the second time.
      merge_grad = merge([grad[0], grad[0]], name="b_switch")[0]
      grad_ctxt.grad_state.switch_map[op] = merge_grad
      return merge_grad, None
    else:
      # This is the first time this Switch is visited. It comes from the
      # Identity branch. Such a Switch has `None` gradient for the Exit branch,
      # meaning the output is not differentiable.
      return None, None
  elif isinstance(op_ctxt, CondContext):
    good_grad = grad[op_ctxt.branch]
    zero_grad = grad[1 - op_ctxt.branch]
    # At this point, we have created zero_grad guarded by the right switch.
    return merge([good_grad, zero_grad], name="cond_grad")[0], None
  else:
    false_grad = switch(grad[0], op.inputs[1])[0]
    true_grad = switch(grad[1], op.inputs[1])[1]
    return merge([false_grad, true_grad])[0], None 
Example #4
Source File: control_flow_grad.py    From Serverless-Deep-Learning-with-TensorFlow-and-AWS-Lambda with MIT License
def _SwitchGrad(op, *grad):
  """Gradients for a Switch op is calculated using a Merge op.

  If the switch is a loop switch, it will be visited twice. We create
  the merge on the first visit, and update the other input of the merge
  on the second visit. A next_iteration is also added on the second visit.
  """
  graph = ops.get_default_graph()
  # pylint: disable=protected-access
  op_ctxt = op._get_control_flow_context()
  grad_ctxt = graph._get_control_flow_context()
  # pylint: enable=protected-access
  if isinstance(op_ctxt, WhileContext):
    merge_grad = grad_ctxt.grad_state.switch_map.get(op)
    if merge_grad is not None:
      # This is the second time this Switch is visited. It comes from
      # the non-exit branch of the Switch, so update the second input
      # to the Merge.
      # TODO(yuanbyu): Perform shape inference with this new input.
      if grad[1] is not None:
        # pylint: disable=protected-access
        control_flow_ops._AddNextAndBackEdge(merge_grad, grad[1])
        # pylint: enable=protected-access
      return None, None
    elif grad[0] is not None:
      # This is the first time this Switch is visited. It comes from
      # the Exit branch, which is grad[0]. grad[1] is empty at this point.
      # Use grad[0] for both inputs to merge for now, but update the second
      # input of merge when we see this Switch the second time.
      merge_grad = merge([grad[0], grad[0]], name="b_switch")[0]
      grad_ctxt.grad_state.switch_map[op] = merge_grad
      return merge_grad, None
    else:
      # This is the first time this Switch is visited. It comes from the
      # Identity branch. Such a Switch has `None` gradient for the Exit branch,
      # meaning the output is not differentiable.
      return None, None
  elif isinstance(op_ctxt, CondContext):
    good_grad = grad[op_ctxt.branch]
    zero_grad = grad[1 - op_ctxt.branch]
    # At this point, we have created zero_grad guarded by the right switch.
    # Unfortunately, we may still get None here for non-trainable data types.
    if zero_grad is None:
      return None, None
    return merge([good_grad, zero_grad], name="cond_grad")[0], None
  else:
    false_grad = switch(grad[0], op.inputs[1])[0]
    true_grad = switch(grad[1], op.inputs[1])[1]
    return merge([false_grad, true_grad])[0], None 
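Example #4 adds a guard for zero_grad being None, which arises for non-trainable data types. The CondContext branch that all the variants share is exercised whenever you differentiate through tf.cond; a minimal sketch using only public TF1 APIs:

# Minimal sketch (TF1 graph mode): taking a gradient through tf.cond
# routes through the CondContext branch of _SwitchGrad above.
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

x = tf.constant(2.0)
pred = tf.placeholder(tf.bool)
y = tf.cond(pred, lambda: x * x, lambda: x + 1.0)
g = tf.gradients(y, x)[0]

with tf.Session() as sess:
    print(sess.run(g, {pred: True}))   # 4.0, d(x*x)/dx at x = 2
    print(sess.run(g, {pred: False}))  # 1.0, d(x+1)/dx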
Example #5
Source File: control_flow_grad.py    From keras-lambda with MIT License
def _SwitchGrad(op, *grad):
  """Gradients for a Switch op is calculated using a Merge op.

  If the switch is a loop switch, it will be visited twice. We create
  the merge on the first visit, and update the other input of the merge
  on the second visit. A next_iteration is also added on the second visit.
  """
  graph = ops.get_default_graph()
  # pylint: disable=protected-access
  op_ctxt = op._get_control_flow_context()
  grad_ctxt = graph._get_control_flow_context()
  # pylint: enable=protected-access
  if isinstance(op_ctxt, WhileContext):
    merge_grad = grad_ctxt.grad_state.switch_map.get(op)
    if merge_grad is not None:
      # This is the second time this Switch is visited. It comes from
      # the non-exit branch of the Switch, so update the second input
      # to the Merge.
      # TODO: Perform shape inference with this new input.
      if grad[1] is not None:
        # pylint: disable=protected-access
        control_flow_ops._AddNextAndBackEdge(merge_grad, grad[1])
        # pylint: enable=protected-access
      return None, None
    elif grad[0] is not None:
      # This is the first time this Switch is visited. It comes from
      # the Exit branch, which is grad[0]. grad[1] is empty at this point.
      # Use grad[0] for both inputs to merge for now, but update the second
      # input of merge when we see this Switch the second time.
      merge_grad = merge([grad[0], grad[0]], name="b_switch")[0]
      grad_ctxt.grad_state.switch_map[op] = merge_grad
      return merge_grad, None
    else:
      # This is the first time this Switch is visited. It comes from the
      # Identity branch. Such a Switch has `None` gradient for the Exit branch,
      # meaning the output is not differentiable.
      return None, None
  elif isinstance(op_ctxt, CondContext):
    good_grad = grad[op_ctxt.branch]
    zero_grad = grad[1 - op_ctxt.branch]
    # At this point, we have created zero_grad guarded by the right switch.
    return merge([good_grad, zero_grad], name="cond_grad")[0], None
  else:
    false_grad = switch(grad[0], op.inputs[1])[0]
    true_grad = switch(grad[1], op.inputs[1])[1]
    return merge([false_grad, true_grad])[0], None
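Finally, the trailing else branch covers a Switch that lives outside any tf.cond/tf.while_loop context, i.e. a switch/merge pair built by hand. A hedged sketch of such a graph follows; the gradient values in the comments are what the symmetric switch/merge gradient should produce, shown as an illustration rather than verified output.

# Sketch (TF1 graph mode): a hand-built Switch/Merge pair, whose gradient
# is handled by the final else branch of _SwitchGrad above.
import tensorflow.compat.v1 as tf
from tensorflow.python.ops import control_flow_ops

tf.disable_eager_execution()

x = tf.constant(5.0)
pred = tf.placeholder(tf.bool)
# switch() returns (output_false, output_true); merge() picks whichever
# branch is alive.
false_out, true_out = control_flow_ops.switch(x, pred)
y = control_flow_ops.merge([false_out * 2.0, true_out * 3.0])[0]
g = tf.gradients(y, x)[0]

with tf.Session() as sess:
    print(sess.run(g, {pred: False}))  # expected 2.0
    print(sess.run(g, {pred: True}))   # expected 3.0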