Python gin.configurable() Examples

The following are 19 code examples of gin.configurable(), collected from open-source projects. The source file, project, and license are listed above each example. You may also want to look at the other functions and classes available in the gin module.
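
Before the examples, here is a minimal self-contained sketch of the pattern they all build on: decorate a function (or class) with @gin.configurable, then override its parameters from a config string or file. The make_optimizer name and its parameters below are illustrative only, not taken from any of the projects.

import gin


@gin.configurable
def make_optimizer(learning_rate=0.001, momentum=0.9):
  # Any parameter can now be overridden from a gin config.
  return {'learning_rate': learning_rate, 'momentum': momentum}


# Bindings can come from a string (as here) or from a .gin file via
# gin.parse_config_file().
gin.parse_config('make_optimizer.learning_rate = 0.01')
print(make_optimizer())  # {'learning_rate': 0.01, 'momentum': 0.9}
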
Example #1
Source File: trax.py    From BERT with Apache License 2.0
def get_random_number_generator_and_set_seed(seed=None):
  """Get a JAX random number generator and set random seed everywhere."""
  random.seed(seed)
  # While python random accepts None as a seed (and then seeds from time/os),
  # some other functions expect integers, so we create one here.
  if seed is None:
    seed = random.randint(0, 2**31 - 1)
  tf.set_random_seed(seed)
  numpy.random.seed(seed)
  return jax_random.get_prng(seed)


# TODO(trax):
# * Make configurable:
#   * loss
#   * metrics
# * Training loop callbacks/hooks/...
# * Save/restore: pickle unsafe. Use np.array.savez + MessagePack?
# * Move metrics to metrics.py
# * Setup namedtuples for interfaces (e.g. lr fun constructors can take a
#   LearningRateInit, metric funs, etc.).
# * Allow disabling eval 
Example #2
Source File: test_util.py    From ml-fairness-gym with Apache License 2.0
def _step_impl(self, state, action):
    """Run one timestep of the environment's dynamics.

    At each timestep, x is flipped from zero to one or one to zero.

    Args:
      state: A `State` object containing the current state.
      action: An action in `action_space`.

    Returns:
      A `State` object containing the updated state.
    """
    del action  # Unused.
    state.x = [1 - x for x in state.x]
    return state


# TODO(): There isn't actually anything to configure in DummyMetric,
# but we mark it as configurable so that we can refer to it on the
# right-hand-side of expressions in gin configurations.  Find out whether
# there's a better way of indicating that than gin.configurable. 
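
As the comment above suggests, one common reason to mark a class configurable even when it has nothing to configure is so it can appear on the right-hand side of a gin binding. A minimal sketch of that pattern follows; run_experiment is a hypothetical configurable used only for illustration, not part of ml-fairness-gym.

import gin


@gin.configurable
class DummyMetric(object):
  """Nothing to configure; configurable only so `@DummyMetric` can be used
  on the right-hand side of gin bindings."""


@gin.configurable
def run_experiment(metric_class=None):
  # `@DummyMetric` (no parentheses) binds the class itself, not an instance.
  return metric_class()


gin.parse_config('run_experiment.metric_class = @DummyMetric')
metric = run_experiment()  # an instance of DummyMetric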
Example #3
Source File: transformer.py    From mesh with Apache License 2.0
def initialize(self):
    """Initialize the teacher model from the checkpoint.

    This function will be called after the graph has been constructed.
    """
    if self.fraction_soft == 0.0:
      # Do nothing if we do not need the teacher.
      return
    vars_to_restore = tf.get_collection(
        tf.GraphKeys.GLOBAL_VARIABLES, scope="teacher")
    tf.train.init_from_checkpoint(
        self.teacher_checkpoint,
        {v.name[len("teacher/"):].split(":")[0]: v for v in vars_to_restore})


# gin-configurable constructors 
Example #4
Source File: trainer.py    From trax with Apache License 2.0
def _output_dir_or_default():
  """Returns a path to the output directory."""
  if FLAGS.output_dir:
    output_dir = FLAGS.output_dir
    trainer_lib.log('Using --output_dir {}'.format(output_dir))
    return os.path.expanduser(output_dir)

  # Else, generate a default output dir (under the user's home directory).
  try:
    dataset_name = gin.query_parameter('data_streams.dataset_name')
  except ValueError:
    dataset_name = 'random'
  output_name = '{model_name}_{dataset_name}_{timestamp}'.format(
      model_name=gin.query_parameter('train.model').configurable.name,
      dataset_name=dataset_name,
      timestamp=datetime.datetime.now().strftime('%Y%m%d_%H%M'),
  )
  output_dir = os.path.join('~', 'trax', output_name)
  output_dir = os.path.expanduser(output_dir)
  print()
  trainer_lib.log('No --output_dir specified')
  trainer_lib.log('Using default output_dir: {}'.format(output_dir))
  return output_dir


# TODO(afrozm): Share between trainer.py and rl_trainer.py 
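
The example above relies on gin.query_parameter to read bindings back out of the active config. A small self-contained sketch of that behavior, using an illustrative data_streams configurable similar to the one queried above:

import gin


@gin.configurable
def data_streams(dataset_name='mnist'):
  return dataset_name


gin.parse_config("data_streams.dataset_name = 'cifar10'")

# query_parameter reads the bound value without calling the function.
# If the parameter has no binding it raises ValueError, which is why
# _output_dir_or_default() wraps the call in try/except above.
print(gin.query_parameter('data_streams.dataset_name'))  # 'cifar10'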
Example #5
Source File: utils.py    From text-to-text-transfer-transformer with Apache License 2.0
def rate_unsupervised(task, value=1e6):
  """Gin-configurable mixing rate for the unsupervised co-training task."""
  del task
  return value 
Example #6
Source File: environment_utilities.py    From agents with Apache License 2.0
def compute_optimal_reward_with_classification_environment(
    observation, environment):
  """Helper function for gin configurable Regret metric."""
  del observation
  return environment.compute_optimal_reward() 
Example #7
Source File: environment_utilities.py    From agents with Apache License 2.0
def compute_optimal_action_with_classification_environment(
    observation, environment):
  """Helper function for gin configurable SuboptimalArms metric."""
  del observation
  return environment.compute_optimal_action() 
Example #8
Source File: suite_mujoco.py    From agents with Apache License 2.0
def load(
    environment_name: Text,
    discount: types.Float = 1.0,
    max_episode_steps: Optional[types.Int] = None,
    gym_env_wrappers: Sequence[types.GymEnvWrapper] = (),
    env_wrappers: Sequence[types.PyEnvWrapper] = (),
    spec_dtype_map: Optional[Dict[gym.Space, np.dtype]] = None
) -> py_environment.PyEnvironment:
  """Loads the selected environment and wraps it with the specified wrappers.

  Note that by default a TimeLimit wrapper is used to limit episode lengths
  to the default benchmarks defined by the registered environments.

  Args:
    environment_name: Name for the environment to load.
    discount: Discount to use for the environment.
    max_episode_steps: If None the max_episode_steps will be set to the default
      step limit defined in the environment's spec. No limit is applied if set
      to 0 or if there is no timestep_limit set in the environment's spec.
    gym_env_wrappers: Iterable with references to wrapper classes to use
      directly on the gym environment.
    env_wrappers: Iterable with references to wrapper classes to use on the
      gym_wrapped environment.
    spec_dtype_map: A dict that maps gym spaces to np dtypes to use as the
      default dtype for the arrays. An easy way to configure a custom mapping
      through Gin is to define a gin-configurable function that returns the
      desired mapping and call it in your Gin config file, for example:
      `suite_gym.load.spec_dtype_map = @get_custom_mapping()`.

  Returns:
    A PyEnvironment instance.
  """
  return suite_gym.load(environment_name, discount, max_episode_steps,
                        gym_env_wrappers, env_wrappers, spec_dtype_map) 
Example #9
Source File: transformer_layers.py    From mesh with Apache License 2.0
def attention_internal(self, context, q, m, memory_length, bias):
    logits = mtf.layers.us_einsum(
        [q, m], reduced_dims=[context.model.model_dim])
    if bias is not None:
      logits += bias
    weights = mtf.softmax(logits, memory_length)
    # TODO(noam): make dropout_broadcast_dims configurable
    dropout_broadcast_dims = [context.length_dim]
    weights = mtf.dropout(
        weights, rate=self.dropout_rate if context.train else 0.0,
        noise_shape=weights.shape - dropout_broadcast_dims)
    u = mtf.einsum([weights, m], reduced_dims=[memory_length])
    return self.compute_y(context, u) 
Example #10
Source File: utils.py    From mesh with Apache License 2.0
def separate_vocabularies(inputs=gin.REQUIRED, targets=gin.REQUIRED):
  """Gin-configurable helper function to generate a tuple of vocabularies."""
  return (inputs, targets)


# TODO(katherinelee): Update layout_rules string when noam updates the
# definition in run 
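
The gin.REQUIRED defaults above make the call fail loudly unless both parameters are bound in the config. A self-contained sketch of that behavior; the string vocabularies below are placeholders for real vocabulary objects.

import gin


@gin.configurable
def separate_vocabularies(inputs=gin.REQUIRED, targets=gin.REQUIRED):
  return (inputs, targets)


# Calling separate_vocabularies() before binding both parameters raises an
# error listing the missing gin.REQUIRED values. With bindings in place:
gin.parse_config("""
separate_vocabularies.inputs = 'inputs_vocab'
separate_vocabularies.targets = 'targets_vocab'
""")
print(separate_vocabularies())  # ('inputs_vocab', 'targets_vocab')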
Example #11
Source File: utils.py    From mesh with Apache License 2.0
def parse_gin_defaults_and_flags():
  """Parses all default gin files and those provided via flags."""
  # Register .gin file search paths with gin
  for gin_file_path in FLAGS.gin_location_prefix:
    gin.add_config_file_search_path(gin_file_path)
  # Set up the default values for the configurable parameters. These values will
  # be overridden by any user provided gin files/parameters.
  gin.parse_config_file(
      pkg_resources.resource_filename(__name__, _DEFAULT_CONFIG_FILE))
  gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_param)


# TODO(noam): maybe add gin-config to mtf.get_variable so we can delete
#  this stupid VariableDtype class and stop passing it all over creation. 
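
parse_gin_defaults_and_flags assumes three absl flags exist. A hedged sketch of how such flags are typically defined; the help strings are illustrative and the real definitions live elsewhere in the project.

from absl import flags

flags.DEFINE_multi_string(
    'gin_file', None, 'Path(s) to gin configuration files to parse.')
flags.DEFINE_multi_string(
    'gin_param', None, 'Individual gin parameter bindings.')
flags.DEFINE_multi_string(
    'gin_location_prefix', [],
    'Directories to add to the gin config file search path.')

FLAGS = flags.FLAGS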
Example #12
Source File: vrgripper_env_models.py    From tensor2robot with Apache License 2.0
def loss_fn(self, labels, inference_outputs, mode, params=None):
    """This implements outer loss and configurable inner losses."""
    if params and params.get('is_outer_loss', False):
      pass
    if self._num_mixture_components > 1:
      gm = mdn.get_mixture_distribution(
          inference_outputs['dist_params'], self._num_mixture_components,
          self._action_size,
          self._output_mean if self._normalize_outputs else None)
      return -tf.reduce_mean(gm.log_prob(labels.action))
    else:
      return self._outer_loss_multiplier * tf.losses.mean_squared_error(
          labels=labels.action,
          predictions=inference_outputs['inference_output']) 
Example #13
Source File: maml_model.py    From tensor2robot with Apache License 2.0
def create_train_op(self,
                      loss,
                      optimizer,
                      update_ops=None,
                      train_outputs=None):
    """Create meta-training op.

    MAMLModel has a configurable var_scope used to select which variables to
    train on. Note that MAMLInnerLoopGradientDescent also has such a parameter
    to decide which variables to update in the *inner* loop. If you don't want
    to update a set of variables in both the inner and outer loop, you'll need
    to configure var_scope for both MAMLModel *and*
    MAMLInnerLoopGradientDescent.

    Args:
      loss: The loss we compute within model_train_fn.
      optimizer: An instance of `tf.train.Optimizer`.
      update_ops: List of update ops to execute alongside the training op.
      train_outputs: (Optional) A dict with additional tensors the training
        model generates.

    Returns:
      train_op: Op for the training step.
    """
    vars_to_train = tf.trainable_variables()
    if self._var_scope is not None:
      vars_to_train = [
          v for v in vars_to_train if v.op.name.startswith(self._var_scope)]
    summarize_gradients = self._summarize_gradients
    if self.is_device_tpu:
      # TPUs don't currently support summaries, so we override the
      # user-provided summarize_gradients option to False.
      if self._summarize_gradients:
        logging.info('We cannot use summarize_gradients on TPUs.')
      summarize_gradients = False
    return contrib_training.create_train_op(
        loss,
        optimizer,
        variables_to_train=vars_to_train,
        summarize_gradients=summarize_gradients,
        update_ops=update_ops) 
Example #14
Source File: transformer.py    From mesh with Apache License 2.0
def make_bi_student_teacher(input_vocab_size=gin.REQUIRED,
                            output_vocab_size=gin.REQUIRED,
                            layout=None,
                            mesh_shape=None):
  """Gin-configurable bitransformer student teacher constructor.

  In your config file you need to set the encoder and decoder layers like this:
    encoder_layers = [
        @mesh_tensorflow.transformer.transformer_layers.SelfAttention,
        @mesh_tensorflow.transformer.transformer_layers.DenseReluDense,
    ]
    decoder_layers = [
        @mesh_tensorflow.transformer.transformer_layers.SelfAttention,
        @mesh_tensorflow.transformer.transformer_layers.EncDecAttention,
        @mesh_tensorflow.transformer.transformer_layers.DenseReluDense,
    ]
    teacher/encoder/transformer.make_layer_stack.layers = %encoder_layers
    teacher/decoder/transformer.make_layer_stack.layers = %decoder_layers
    student/encoder/transformer.make_layer_stack.layers = %encoder_layers
    student/decoder/transformer.make_layer_stack.layers = %decoder_layers

  Args:
    input_vocab_size: an integer
    output_vocab_size: an integer
    layout: optional - an input to mtf.convert_to_layout_rules. Some layers
      (e.g. MoE layers) cheat by looking at layout and mesh_shape.
    mesh_shape: optional - an input to mtf.convert_to_shape. Some layers
      (e.g. MoE layers) cheat by looking at layout and mesh_shape.

  Returns:
    a StudentTeacher
  """
  with gin.config_scope("student"):
    student = make_bitransformer(
        input_vocab_size=input_vocab_size,
        output_vocab_size=output_vocab_size,
        layout=layout,
        mesh_shape=mesh_shape)
  with gin.config_scope("teacher"):
    teacher = make_bitransformer(
        input_vocab_size=input_vocab_size,
        output_vocab_size=output_vocab_size,
        layout=layout,
        mesh_shape=mesh_shape)
  return StudentTeacher(student=student, teacher=teacher) 
Example #15
Source File: suite_gym.py    From agents with Apache License 2.0
def load(
    environment_name: Text,
    discount: types.Float = 1.0,
    max_episode_steps: Optional[types.Int] = None,
    gym_env_wrappers: Sequence[types.GymEnvWrapper] = (),
    env_wrappers: Sequence[types.PyEnvWrapper] = (),
    spec_dtype_map: Optional[Dict[gym.Space, np.dtype]] = None,
    gym_kwargs=None) -> py_environment.PyEnvironment:
  """Loads the selected environment and wraps it with the specified wrappers.

  Note that by default a TimeLimit wrapper is used to limit episode lengths
  to the default benchmarks defined by the registered environments.

  Args:
    environment_name: Name for the environment to load.
    discount: Discount to use for the environment.
    max_episode_steps: If None the max_episode_steps will be set to the default
      step limit defined in the environment's spec. No limit is applied if set
      to 0 or if there is no max_episode_steps set in the environment's spec.
    gym_env_wrappers: Iterable with references to wrapper classes to use
      directly on the gym environment.
    env_wrappers: Iterable with references to wrapper classes to use on the
      gym_wrapped environment.
    spec_dtype_map: A dict that maps gym spaces to np dtypes to use as the
      default dtype for the arrays. An easy way to configure a custom mapping
      through Gin is to define a gin-configurable function that returns the
      desired mapping and call it in your Gin config file, for example:
      `suite_gym.load.spec_dtype_map = @get_custom_mapping()`.
    gym_kwargs: The kwargs to pass to the Gym environment class.

  Returns:
    A PyEnvironment instance.
  """
  gym_kwargs = gym_kwargs if gym_kwargs else {}
  gym_spec = gym.spec(environment_name)
  gym_env = gym_spec.make(**gym_kwargs)

  if max_episode_steps is None and gym_spec.max_episode_steps is not None:
    max_episode_steps = gym_spec.max_episode_steps

  return wrap_env(
      gym_env,
      discount=discount,
      max_episode_steps=max_episode_steps,
      gym_env_wrappers=gym_env_wrappers,
      env_wrappers=env_wrappers,
      spec_dtype_map=spec_dtype_map) 
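
For context, a short usage sketch of suite_gym.load, both called directly and driven by a gin binding; the environment name and step limit are illustrative.

import gin
from tf_agents.environments import suite_gym

# Direct call:
env = suite_gym.load('CartPole-v1', max_episode_steps=200)

# Equivalent configuration through gin:
gin.parse_config('suite_gym.load.max_episode_steps = 200')
env = suite_gym.load('CartPole-v1')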
Example #16
Source File: suite_gym.py    From agents with Apache License 2.0
def wrap_env(
    gym_env: gym.Env,
    discount: types.Float = 1.0,
    max_episode_steps: Optional[types.Int] = None,
    gym_env_wrappers: Sequence[types.GymEnvWrapper] = (),
    time_limit_wrapper: TimeLimitWrapperType = wrappers.TimeLimit,
    env_wrappers: Sequence[types.PyEnvWrapper] = (),
    spec_dtype_map: Optional[Dict[gym.Space, np.dtype]] = None,
    auto_reset: bool = True) -> py_environment.PyEnvironment:
  """Wraps given gym environment with TF Agent's GymWrapper.

  Note that by default a TimeLimit wrapper is used to limit episode lengths
  to the default benchmarks defined by the registered environments.

  Args:
    gym_env: An instance of OpenAI gym environment.
    discount: Discount to use for the environment.
    max_episode_steps: Used to create a TimeLimitWrapper. No limit is applied
      if set to None or 0. Usually set to `gym_spec.max_episode_steps` in `load`.
    gym_env_wrappers: Iterable with references to wrapper classes to use
      directly on the gym environment.
    time_limit_wrapper: Wrapper that accepts (env, max_episode_steps) params to
      enforce a TimeLimit. Usually this should be left as the default,
      wrappers.TimeLimit.
    env_wrappers: Iterable with references to wrapper classes to use on the
      gym_wrapped environment.
    spec_dtype_map: A dict that maps gym spaces to np dtypes to use as the
      default dtype for the arrays. An easy way to configure a custom mapping
      through Gin is to define a gin-configurable function that returns the
      desired mapping and call it in your Gin config file, for example:
      `suite_gym.load.spec_dtype_map = @get_custom_mapping()`.
    auto_reset: If True (default), reset the environment automatically after a
      terminal state is reached.

  Returns:
    A PyEnvironment instance.
  """

  for wrapper in gym_env_wrappers:
    gym_env = wrapper(gym_env)
  env = gym_wrapper.GymWrapper(
      gym_env,
      discount=discount,
      spec_dtype_map=spec_dtype_map,
      auto_reset=auto_reset,
  )

  if max_episode_steps is not None and max_episode_steps > 0:
    env = time_limit_wrapper(env, max_episode_steps)

  for wrapper in env_wrappers:
    env = wrapper(env)

  return env 
Example #17
Source File: transformer.py    From mesh with Apache License 2.0
def make_bitransformer(
    input_vocab_size=gin.REQUIRED,
    output_vocab_size=gin.REQUIRED,
    layout=None,
    mesh_shape=None,
    encoder_name="encoder",
    decoder_name="decoder"):
  """Gin-configurable bitransformer constructor.

  In your config file you need to set the encoder and decoder layers like this:
  encoder/make_layer_stack.layers = [
    @transformer_layers.SelfAttention,
    @transformer_layers.DenseReluDense,
  ]
  decoder/make_layer_stack.layers = [
    @transformer_layers.SelfAttention,
    @transformer_layers.EncDecAttention,
    @transformer_layers.DenseReluDense,
  ]

  Args:
    input_vocab_size: an integer
    output_vocab_size: an integer
    layout: optional - an input to mtf.convert_to_layout_rules.
      Some layers (e.g. MoE layers) cheat by looking at layout and mesh_shape.
    mesh_shape: optional - an input to mtf.convert_to_shape.
      Some layers (e.g. MoE layers) cheat by looking at layout and mesh_shape.
    encoder_name: optional - a string giving the Unitransformer encoder name.
    decoder_name: optional - a string giving the Unitransformer decoder name.
  Returns:
    a Bitransformer
  """
  with gin.config_scope("encoder"):
    encoder = Unitransformer(
        layer_stack=make_layer_stack(),
        input_vocab_size=input_vocab_size,
        output_vocab_size=None,
        autoregressive=False,
        name=encoder_name,
        layout=layout,
        mesh_shape=mesh_shape)
  with gin.config_scope("decoder"):
    decoder = Unitransformer(
        layer_stack=make_layer_stack(),
        input_vocab_size=output_vocab_size,
        output_vocab_size=output_vocab_size,
        autoregressive=True,
        name=decoder_name,
        layout=layout,
        mesh_shape=mesh_shape)
  return Bitransformer(encoder, decoder) 
Example #18
Source File: tfdata.py    From tensor2robot with Apache License 2.0
def default_input_fn_tmpl(
    file_patterns,
    batch_size,
    feature_spec,
    label_spec,
    num_parallel_calls=4,
    is_training=False,
    preprocess_fn=None,
    shuffle_buffer_size=500,
    prefetch_buffer_size=tf.data.experimental.AUTOTUNE,
    parallel_shards=10):
  """Generic gin-configurable tf.data input pipeline."""
  if isinstance(file_patterns, dict):
    file_patterns_map = file_patterns
  else:
    file_patterns_map = {'': file_patterns}
  datasets = {}
  # Read Each Dataset
  for dataset_key, file_patterns in file_patterns_map.items():
    data_format, filenames = get_data_format_and_filenames(file_patterns)
    filenames_dataset = tf.data.Dataset.list_files(
        filenames, shuffle=is_training)
    if is_training:
      cycle_length = min(parallel_shards, len(filenames))
    else:
      cycle_length = 1
    dataset = filenames_dataset.apply(
        tf.data.experimental.parallel_interleave(
            DATA_FORMAT[data_format],
            cycle_length=cycle_length,
            sloppy=is_training))

    if is_training:
      dataset = dataset.shuffle(buffer_size=shuffle_buffer_size).repeat()
    else:
      dataset = dataset.repeat()
    dataset = dataset.batch(batch_size, drop_remainder=True)
    datasets[dataset_key] = dataset
  # Merge dict of datasets of batched serialized examples into a single dataset
  # of dicts of batched serialized examples.
  dataset = tf.data.Dataset.zip(datasets)
  # Parse all datasets together.
  dataset = serialized_to_parsed(
      dataset, feature_spec, label_spec, num_parallel_calls=num_parallel_calls)
  if preprocess_fn is not None:
    # TODO(psanketi): Consider adding num_parallel calls here.
    dataset = dataset.map(preprocess_fn, num_parallel_calls=parallel_shards)
  if prefetch_buffer_size is not None:
    dataset = dataset.prefetch(prefetch_buffer_size)
  return dataset 
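
Because every argument of default_input_fn_tmpl is gin-configurable, the pipeline is normally tuned from a config file rather than by editing code. A hedged sketch of such a config; the values, the input_pipeline.gin file name, and the my_preprocessor reference are all illustrative.

import gin

# Contents of an input pipeline .gin file (values are illustrative):
#   default_input_fn_tmpl.shuffle_buffer_size = 1000
#   default_input_fn_tmpl.num_parallel_calls = 8
#   default_input_fn_tmpl.parallel_shards = 16
#   default_input_fn_tmpl.preprocess_fn = @my_preprocessor
#
# Parsed at startup, after importing the module that defines
# default_input_fn_tmpl so the configurable is registered:
gin.parse_config_file('input_pipeline.gin')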
Example #19
Source File: utility.py    From es_on_gke with Apache License 2.0
def load_env(env_package, env_name, **kwargs):
    """Load and return an environment.

    This function loads a gym environment.

    Args:
        env_package: str. Name of the package that contains the environment.
            If env_name is 'NULL', it is a built-in gym environment.
        env_name: str. Name of the environment.
        kwargs: dict. Environment configurations.
    Returns:
        gym.Environment.
    """

    if env_package == 'NULL':
        env = gym.make(env_name)
    elif env_package == 'CUSTOMIZED':
        # The customized env must be gin.configurable
        env = env_name(**kwargs)
    else:
        pkg = getattr(pybullet_envs.bullet, env_package)
        env = getattr(pkg, env_name)(**kwargs)
        if not hasattr(env, '_cam_dist'):
            env._cam_dist = 6
            env._cam_yaw = 0
            env._cam_pitch = -30

        # Some pybullet envs do not implement close(), so add one.
        def close():
            if hasattr(env, '_pybullet_client'):
                env._pybullet_client.resetSimulation()
                del env._pybullet_client

        env.close = close

        # Some pybullet envs do not implement seed(), so add one.
        def seed(rand_seed):
            np.random.seed(rand_seed)
            random.seed(rand_seed)

        env.seed = seed
    return env
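
Finally, a short usage sketch covering the first two branches of load_env above. MyEnv stands in for a hypothetical gin-configurable environment class and max_steps for one of its constructor arguments; neither is part of es_on_gke.

# Built-in gym environment: the 'NULL' package name means "use gym.make".
env = load_env('NULL', 'CartPole-v1')

# Customized environment: env_name is the gin-configurable class itself,
# and extra kwargs are forwarded to its constructor.
env = load_env('CUSTOMIZED', MyEnv, max_steps=1000)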