Python tensorflow.python.ops.control_flow_ops._AddNextAndBackEdge() Examples
The following are 5 code examples of tensorflow.python.ops.control_flow_ops._AddNextAndBackEdge().
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example.
You may also want to check out all available functions/classes of the module tensorflow.python.ops.control_flow_ops, or try the search function.
Example #1
Source File: control_flow_grad.py From deep_image_model with Apache License 2.0 | 5 votes |
def _SwitchGrad(op, *grad):
  """Gradients for a Switch op is calculated using a Merge op.

  If the switch is a loop switch, it will be visited twice. We create
  the merge on the first visit, and update the other input of the merge
  on the second visit. A next_iteration is also added on second visit.

  Args:
    op: The Switch op being differentiated.
    *grad: Incoming gradients, indexed like the outputs of `op`. Individual
      entries may be None.

  Returns:
    A pair `(input_grad, None)`. `input_grad` is the gradient for the first
    input of `op`, or None when no gradient can be produced on this visit.
    The second input (used below as the switch predicate) never receives a
    gradient.
  """
  graph = ops.get_default_graph()
  # pylint: disable=protected-access
  op_ctxt = op._get_control_flow_context()
  grad_ctxt = graph._get_control_flow_context()
  # pylint: enable=protected-access
  if isinstance(op_ctxt, WhileContext):
    merge_grad = grad_ctxt.grad_state.switch_map.get(op)
    if merge_grad is not None:
      # This is the second time this Switch is visited. It comes from
      # the non-exit branch of the Switch, so update the second input
      # to the Merge.
      # TODO: Perform shape inference with this new input.
      if grad[1] is not None:
        # pylint: disable=protected-access
        control_flow_ops._AddNextAndBackEdge(merge_grad, grad[1])
        # pylint: enable=protected-access
      return None, None
    elif grad[0] is not None:
      # This is the first time this Switch is visited. It comes from
      # the Exit branch, which is grad[0]. grad[1] is empty at this point.
      # Use grad[0] for both inputs to merge for now, but update the second
      # input of merge when we see this Switch the second time.
      merge_grad = merge([grad[0], grad[0]], name="b_switch")[0]
      grad_ctxt.grad_state.switch_map[op] = merge_grad
      return merge_grad, None
    else:
      # First visit, but the Exit branch carries no gradient: there is
      # nothing to merge, so report that the output is not differentiable
      # instead of handing None inputs to merge.
      return None, None
  elif isinstance(op_ctxt, CondContext):
    good_grad = grad[op_ctxt.branch]
    zero_grad = grad[1 - op_ctxt.branch]
    # At this point, we have created zero_grad guarded by the right switch.
    # It may still be None (e.g. for data types that are not trainable), in
    # which case there is no gradient to merge.
    if zero_grad is None:
      return None, None
    return merge([good_grad, zero_grad], name="cond_grad")[0], None
  else:
    # Plain Switch outside any cond/while context: route each incoming
    # gradient through a switch on the same predicate, then merge the
    # false-side and true-side results.
    false_grad = switch(grad[0], op.inputs[1])[0]
    true_grad = switch(grad[1], op.inputs[1])[1]
    return merge([false_grad, true_grad])[0], None
Example #2
Source File: control_flow_grad.py From lambda-packs with MIT License | 4 votes |
def _SwitchGrad(op, *grad):
  """Gradients for a Switch op is calculated using a Merge op.

  If the switch is a loop switch, it will be visited twice. We create
  the merge on the first visit, and update the other input of the merge
  on the second visit. A next_iteration is also added on second visit.

  Args:
    op: The Switch op being differentiated.
    *grad: Incoming gradients, indexed like the outputs of `op`. Individual
      entries may be None.

  Returns:
    A pair `(input_grad, None)`. `input_grad` is the gradient for the first
    input of `op`, or None when no gradient can be produced on this visit.
    The second input (used below as the switch predicate) never receives a
    gradient.
  """
  graph = ops.get_default_graph()
  # pylint: disable=protected-access
  op_ctxt = op._get_control_flow_context()
  grad_ctxt = graph._get_control_flow_context()
  # pylint: enable=protected-access
  if isinstance(op_ctxt, WhileContext):
    merge_grad = grad_ctxt.grad_state.switch_map.get(op)
    if merge_grad is not None:
      # This is the second time this Switch is visited. It comes from
      # the non-exit branch of the Switch, so update the second input
      # to the Merge.
      # TODO(yuanbyu): Perform shape inference with this new input.
      if grad[1] is not None:
        # pylint: disable=protected-access
        control_flow_ops._AddNextAndBackEdge(merge_grad, grad[1])
        # pylint: enable=protected-access
      return None, None
    elif grad[0] is not None:
      # This is the first time this Switch is visited. It comes from
      # the Exit branch, which is grad[0]. grad[1] is empty at this point.
      # Use grad[0] for both inputs to merge for now, but update the second
      # input of merge when we see this Switch the second time.
      merge_grad = merge([grad[0], grad[0]], name="b_switch")[0]
      grad_ctxt.grad_state.switch_map[op] = merge_grad
      return merge_grad, None
    else:
      # This is the first time this Switch is visited. It comes from the
      # Identity branch. Such a Switch has `None` gradient for the Exit branch,
      # meaning the output is not differentiable.
      return None, None
  elif isinstance(op_ctxt, CondContext):
    good_grad = grad[op_ctxt.branch]
    zero_grad = grad[1 - op_ctxt.branch]
    # At this point, we have created zero_grad guarded by the right switch.
    # It may still be None (e.g. for data types that are not trainable), in
    # which case there is no gradient to merge.
    if zero_grad is None:
      return None, None
    return merge([good_grad, zero_grad], name="cond_grad")[0], None
  else:
    # Plain Switch outside any cond/while context: route each incoming
    # gradient through a switch on the same predicate, then merge the
    # false-side and true-side results.
    false_grad = switch(grad[0], op.inputs[1])[0]
    true_grad = switch(grad[1], op.inputs[1])[1]
    return merge([false_grad, true_grad])[0], None
Example #3
Source File: control_flow_grad.py From auto-alt-text-lambda-api with MIT License | 4 votes |
def _SwitchGrad(op, *grad):
  """Gradients for a Switch op is calculated using a Merge op.

  If the switch is a loop switch, it will be visited twice. We create
  the merge on the first visit, and update the other input of the merge
  on the second visit. A next_iteration is also added on second visit.

  Args:
    op: The Switch op being differentiated.
    *grad: Incoming gradients, indexed like the outputs of `op`. Individual
      entries may be None.

  Returns:
    A pair `(input_grad, None)`. `input_grad` is the gradient for the first
    input of `op`, or None when no gradient can be produced on this visit.
    The second input (used below as the switch predicate) never receives a
    gradient.
  """
  graph = ops.get_default_graph()
  # pylint: disable=protected-access
  op_ctxt = op._get_control_flow_context()
  grad_ctxt = graph._get_control_flow_context()
  # pylint: enable=protected-access
  if isinstance(op_ctxt, WhileContext):
    merge_grad = grad_ctxt.grad_state.switch_map.get(op)
    if merge_grad is not None:
      # This is the second time this Switch is visited. It comes from
      # the non-exit branch of the Switch, so update the second input
      # to the Merge.
      # TODO: Perform shape inference with this new input.
      if grad[1] is not None:
        # pylint: disable=protected-access
        control_flow_ops._AddNextAndBackEdge(merge_grad, grad[1])
        # pylint: enable=protected-access
      return None, None
    elif grad[0] is not None:
      # This is the first time this Switch is visited. It comes from
      # the Exit branch, which is grad[0]. grad[1] is empty at this point.
      # Use grad[0] for both inputs to merge for now, but update the second
      # input of merge when we see this Switch the second time.
      merge_grad = merge([grad[0], grad[0]], name="b_switch")[0]
      grad_ctxt.grad_state.switch_map[op] = merge_grad
      return merge_grad, None
    else:
      # This is the first time this Switch is visited. It comes from the
      # Identity branch. Such a Switch has `None` gradient for the Exit branch,
      # meaning the output is not differentiable.
      return None, None
  elif isinstance(op_ctxt, CondContext):
    good_grad = grad[op_ctxt.branch]
    zero_grad = grad[1 - op_ctxt.branch]
    # At this point, we have created zero_grad guarded by the right switch.
    # It may still be None (e.g. for data types that are not trainable), in
    # which case there is no gradient to merge.
    if zero_grad is None:
      return None, None
    return merge([good_grad, zero_grad], name="cond_grad")[0], None
  else:
    # Plain Switch outside any cond/while context: route each incoming
    # gradient through a switch on the same predicate, then merge the
    # false-side and true-side results.
    false_grad = switch(grad[0], op.inputs[1])[0]
    true_grad = switch(grad[1], op.inputs[1])[1]
    return merge([false_grad, true_grad])[0], None
Example #4
Source File: control_flow_grad.py From Serverless-Deep-Learning-with-TensorFlow-and-AWS-Lambda with MIT License | 4 votes |
def _SwitchGrad(op, *grad):
  """Builds the gradient of a Switch op out of a Merge op.

  A Switch that belongs to a while loop is visited twice. The first visit
  creates the Merge and records it; the second visit fills in the Merge's
  other input and adds the next-iteration back edge.

  Args:
    op: The Switch op being differentiated.
    *grad: Incoming gradients, indexed like the outputs of `op`. Individual
      entries may be None.

  Returns:
    A pair `(input_grad, None)`. `input_grad` is the gradient for the first
    input of `op`, or None when no gradient is produced on this visit. The
    second input (used below as the switch predicate) never gets a gradient.
  """
  default_graph = ops.get_default_graph()
  # pylint: disable=protected-access
  fwd_ctxt = op._get_control_flow_context()
  bwd_ctxt = default_graph._get_control_flow_context()
  # pylint: enable=protected-access

  if isinstance(fwd_ctxt, WhileContext):
    switch_map = bwd_ctxt.grad_state.switch_map
    pending_merge = switch_map.get(op)

    if pending_merge is None:
      # First visit of this Switch.
      exit_grad = grad[0]
      if exit_grad is None:
        # Reached through the Identity branch: the Exit branch has a `None`
        # gradient, meaning the output is not differentiable.
        return None, None
      # Reached through the Exit branch (grad[0]); grad[1] is empty for now.
      # Feed grad[0] to both Merge inputs as a placeholder; the second input
      # is replaced when this Switch is seen again.
      pending_merge = merge([exit_grad, exit_grad], name="b_switch")[0]
      switch_map[op] = pending_merge
      return pending_merge, None

    # Second visit, arriving from the non-exit branch: patch the second
    # input of the Merge created on the first visit.
    # TODO(yuanbyu): Perform shape inference with this new input.
    loop_grad = grad[1]
    if loop_grad is not None:
      # pylint: disable=protected-access
      control_flow_ops._AddNextAndBackEdge(pending_merge, loop_grad)
      # pylint: enable=protected-access
    return None, None

  if isinstance(fwd_ctxt, CondContext):
    taken = fwd_ctxt.branch
    taken_grad = grad[taken]
    other_grad = grad[1 - taken]
    # other_grad was created guarded by the right switch, but it can still
    # be None for data types that are not trainable.
    if other_grad is None:
      return None, None
    return merge([taken_grad, other_grad], name="cond_grad")[0], None

  # Switch outside of any cond/while context: switch each incoming gradient
  # on the same predicate and merge the false-side and true-side results.
  from_false = switch(grad[0], op.inputs[1])[0]
  from_true = switch(grad[1], op.inputs[1])[1]
  return merge([from_false, from_true])[0], None
Example #5
Source File: control_flow_grad.py From keras-lambda with MIT License | 4 votes |
def _SwitchGrad(op, *grad):
  """Gradients for a Switch op is calculated using a Merge op.

  If the switch is a loop switch, it will be visited twice. We create
  the merge on the first visit, and update the other input of the merge
  on the second visit. A next_iteration is also added on second visit.

  Args:
    op: The Switch op being differentiated.
    *grad: Incoming gradients, indexed like the outputs of `op`. Individual
      entries may be None.

  Returns:
    A pair `(input_grad, None)`. `input_grad` is the gradient for the first
    input of `op`, or None when no gradient can be produced on this visit.
    The second input (used below as the switch predicate) never receives a
    gradient.
  """
  graph = ops.get_default_graph()
  # pylint: disable=protected-access
  op_ctxt = op._get_control_flow_context()
  grad_ctxt = graph._get_control_flow_context()
  # pylint: enable=protected-access
  if isinstance(op_ctxt, WhileContext):
    merge_grad = grad_ctxt.grad_state.switch_map.get(op)
    if merge_grad is not None:
      # This is the second time this Switch is visited. It comes from
      # the non-exit branch of the Switch, so update the second input
      # to the Merge.
      # TODO: Perform shape inference with this new input.
      if grad[1] is not None:
        # pylint: disable=protected-access
        control_flow_ops._AddNextAndBackEdge(merge_grad, grad[1])
        # pylint: enable=protected-access
      return None, None
    elif grad[0] is not None:
      # This is the first time this Switch is visited. It comes from
      # the Exit branch, which is grad[0]. grad[1] is empty at this point.
      # Use grad[0] for both inputs to merge for now, but update the second
      # input of merge when we see this Switch the second time.
      merge_grad = merge([grad[0], grad[0]], name="b_switch")[0]
      grad_ctxt.grad_state.switch_map[op] = merge_grad
      return merge_grad, None
    else:
      # This is the first time this Switch is visited. It comes from the
      # Identity branch. Such a Switch has `None` gradient for the Exit branch,
      # meaning the output is not differentiable.
      return None, None
  elif isinstance(op_ctxt, CondContext):
    good_grad = grad[op_ctxt.branch]
    zero_grad = grad[1 - op_ctxt.branch]
    # At this point, we have created zero_grad guarded by the right switch.
    # It may still be None (e.g. for data types that are not trainable), in
    # which case there is no gradient to merge.
    if zero_grad is None:
      return None, None
    return merge([good_grad, zero_grad], name="cond_grad")[0], None
  else:
    # Plain Switch outside any cond/while context: route each incoming
    # gradient through a switch on the same predicate, then merge the
    # false-side and true-side results.
    false_grad = switch(grad[0], op.inputs[1])[0]
    true_grad = switch(grad[1], op.inputs[1])[1]
    return merge([false_grad, true_grad])[0], None