Python tensorflow.python.framework.tensor_shape.dimension_value() Examples
The following are 3 code examples of tensorflow.python.framework.tensor_shape.dimension_value().
Example #1
Source File: layers.py From RLs with Apache License 2.0
import tensorflow as tf
from tensorflow.python.framework import tensor_shape


def build(self, input_shape):
    super().build(input_shape)
    # Keras tracks build state in `self.built`; assigning to `self.build`
    # would shadow this method.
    self.built = False
    # Size of the last input axis as a plain int (or None if unknown).
    self.last_dim = tensor_shape.dimension_value(input_shape[-1])
    self.noisy_w = self.add_weight(
        'noise_kernel',
        shape=[self.last_dim, self.units],
        initializer=tf.random_normal_initializer(0.0, .1),
        regularizer=self.kernel_regularizer,
        constraint=self.kernel_constraint,
        dtype=self.dtype,
        trainable=True)
    if self.use_bias:
        self.noisy_b = self.add_weight(
            'noise_bias',
            shape=[self.units, ],
            initializer=tf.constant_initializer(
                self.noise_sigma / (self.units**0.5)),
            regularizer=self.bias_regularizer,
            constraint=self.bias_constraint,
            dtype=self.dtype,
            trainable=True)
    else:
        self.bias = None
    self.built = True
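The `build` method above passes `input_shape[-1]` through `dimension_value()` before using it in a weight shape. Indexing a `TensorShape` yields a `Dimension` object under TensorFlow 1.x behaviour and a plain `int` or `None` under 2.x behaviour, so the call gives the same result either way. The sketch below imitates that unwrapping without importing TensorFlow. The `Dimension` class is a simplified stand-in written for illustration, and `dimension_value_sketch` is a hypothetical name, not the library function.

```python
class Dimension:
    """Simplified stand-in for TensorFlow's Dimension (illustration only)."""

    def __init__(self, value):
        self.value = value  # an int, or None when the size is unknown


def dimension_value_sketch(dimension):
    """Return a plain int/None for either a Dimension wrapper or a raw value."""
    if isinstance(dimension, Dimension):
        return dimension.value
    return dimension


v1_style = dimension_value_sketch(Dimension(64))   # wrapped size
v2_style = dimension_value_sketch(64)              # already a plain int
unknown = dimension_value_sketch(Dimension(None))  # size not statically known
```

With a helper of this kind, code such as `shape=[last_dim, units]` does not need to branch on which TensorFlow behaviour is active.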
Example #2
Source File: gpt2_model.py From gpt-2-tensorflow2.0 with MIT License
from tensorflow.python.framework import tensor_shape


def build(self, input_shape):
    # Create an output projection only when no shared weights were supplied.
    if self.proj_weights is None:
        input_dim = tensor_shape.dimension_value(input_shape[-1])
        self.layer_weights = self.add_weight(
            'output_layer_weights',
            shape=[input_dim, self.output_dim],
            initializer=self.kernel_initializer,
            trainable=True)
    super(OutputLayer, self).build(input_shape)
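Because `dimension_value()` returns `None` when the last axis is not statically known, a `build` method like the ones above would pass `None` into the weight shape in that case. One defensive option is to check for it before creating weights. The helper below is a hypothetical sketch in plain Python: the name `require_last_dim` and the error message are assumptions, not part of either project, and it takes the shape as a list of ints/`None` rather than a `TensorShape`.

```python
def require_last_dim(input_shape):
    """Return the size of the last axis, or raise if it is unknown (None)."""
    last_dim = input_shape[-1]
    if last_dim is None:
        raise ValueError(
            "The last dimension of the input must be defined, got None.")
    return int(last_dim)


known = require_last_dim([None, 128])  # batch size unknown, features known
```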
Example #3
Source File: strcuture.py From BERT with Apache License 2.0
def convert_legacy_structure(output_types, output_shapes, output_classes):
    """Returns a `Structure` that represents the given legacy structure.

    This method provides a way to convert from the existing `Dataset` and
    `Iterator` structure-related properties to a `Structure` object. A "legacy"
    structure is represented by the `tf.data.Dataset.output_types`,
    `tf.data.Dataset.output_shapes`, and `tf.data.Dataset.output_classes`
    properties.

    TODO(b/110122868): Remove this function once `Structure` is used throughout
    `tf.data`.

    Args:
      output_types: A nested structure of `tf.DType` objects corresponding to
        each component of a structured value.
      output_shapes: A nested structure of `tf.TensorShape` objects
        corresponding to each component of a structured value.
      output_classes: A nested structure of Python `type` objects corresponding
        to each component of a structured value.

    Returns:
      A `Structure`.

    Raises:
      TypeError: If a structure cannot be built from the arguments, because one
        of the component classes in `output_classes` is not supported.
    """
    flat_types = nest.flatten(output_types)
    flat_shapes = nest.flatten(output_shapes)
    flat_classes = nest.flatten(output_classes)
    flat_ret = []
    for flat_type, flat_shape, flat_class in zip(flat_types, flat_shapes,
                                                 flat_classes):
        if isinstance(flat_class, Structure):
            flat_ret.append(flat_class)
        elif issubclass(flat_class, sparse_tensor_lib.SparseTensor):
            flat_ret.append(SparseTensorStructure(flat_type, flat_shape))
        elif issubclass(flat_class, ops.Tensor):
            flat_ret.append(TensorStructure(flat_type, flat_shape))
        elif issubclass(flat_class, tensor_array_ops.TensorArray):
            # We sneaked the dynamic_size and infer_shape into the legacy shape.
            flat_ret.append(
                TensorArrayStructure(
                    flat_type, flat_shape[2:],
                    dynamic_size=tensor_shape.dimension_value(flat_shape[0]),
                    infer_shape=tensor_shape.dimension_value(flat_shape[1])))
        else:
            # NOTE(mrry): Since legacy structures produced by iterators only
            # comprise Tensors, SparseTensors, and nests, we do not need to
            # support all structure types here.
            raise TypeError(
                "Could not build a structure for output class %r" % (flat_class,))

    ret = nest.pack_sequence_as(output_classes, flat_ret)
    if isinstance(ret, Structure):
        return ret
    else:
        return NestedStructure(ret)


# NOTE(mrry): The following classes make extensive use of non-public methods of
# their base class, so we disable the protected-access lint warning once here.
# pylint: disable=protected-access


# @tf_export("data.experimental.NestedStructure")
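The `TensorArray` branch above reads two flags from the front of the legacy shape and treats the remainder as the element shape. A minimal sketch of that unpacking, using a plain Python list in place of a `TensorShape`, might look like the following. The function name `split_legacy_tensor_array_shape` and the sample values are assumptions for illustration only.

```python
def split_legacy_tensor_array_shape(legacy_shape):
    """Split [dynamic_size, infer_shape, *element_shape] into its three parts."""
    dynamic_size = legacy_shape[0]           # first slot carries a flag
    infer_shape = legacy_shape[1]            # second slot carries a flag
    element_shape = list(legacy_shape[2:])   # the rest is the real shape
    return dynamic_size, infer_shape, element_shape


# Hypothetical legacy shape: two flags followed by a 3x4 element shape.
dynamic_size, infer_shape, element_shape = split_legacy_tensor_array_shape(
    [True, False, 3, 4])
```

In the example above, `dimension_value()` plays the role of reading the two flag slots as plain Python values regardless of whether indexing returned a `Dimension` wrapper.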