Python tensorflow.python.framework.ops.get_collection() Examples

The following are 30 code examples of `tensorflow.python.framework.ops.get_collection()`. You can vote up the ones you like or vote down the ones you don't like; the line above each example names its original project and source file. You may also want to check out all available functions/classes of the module `tensorflow.python.framework.ops`, or try the search function.
Example #1
Source File: supervisor.py, from lambda-packs (MIT License), 6 votes
def start_queue_runners(self, sess, queue_runners=None):
    """Create and start threads for a set of `QueueRunner`s.

    Queue runners registered in the graph under `GraphKeys.QUEUE_RUNNERS`
    are already started automatically when the supervisor creates a session.
    Call this explicitly only for runners that were never added to that
    collection.

    Args:
      sess: The `Session` the runner threads execute in.
      queue_runners: Optional list of `QueueRunner`s. When `None`, the runners
        stored in this supervisor's graph under `GraphKeys.QUEUE_RUNNERS`
        are used instead.

    Returns:
      A flat list of every thread started, in runner order.
    """
    runners = queue_runners
    if runners is None:
      runners = self._graph.get_collection(ops.GraphKeys.QUEUE_RUNNERS)
    started = []
    for runner in runners:
      # Every runner shares the supervisor's coordinator; threads are created
      # as daemons and started immediately.
      started += runner.create_threads(
          sess, coord=self._coord, daemon=True, start=True)
    return started
Example #2
Source File: utils.py, from tensornets (MIT License), 6 votes
def convert_collection_to_dict(collection, clear_collection=False):
  """Map every alias of every tensor in `collection` to that tensor.

  Args:
    collection: Key of the collection to read via `ops.get_collection`.
    clear_collection: If True, the collection is cleared from the default
      graph once the mapping has been built.

  Returns:
    An `OrderedDict` of `{alias: tensor}`, ordered by collection order and
    then alias order. If two tensors share an alias, the later tensor's
    value is kept (at the alias's first position).
  """
  output = OrderedDict()
  for tensor in ops.get_collection(collection):
    for alias in get_tensor_aliases(tensor):
      output[alias] = tensor
  if clear_collection:
    ops.get_default_graph().clear_collection(collection)
  return output
Example #3
Source File: supervisor.py, from auto-alt-text-lambda-api (MIT License), 6 votes
def _get_first_op_from_collection(self, key):
    """Return the first item stored in the collection `key`.

    Args:
      key: A string collection key.

    Returns:
      The first element of the collection, or `None` when the collection is
      empty or the lookup raises `LookupError`. An info message is logged
      when the collection holds more than one element.
    """
    first_op = None
    try:
      candidates = ops.get_collection(key)
      if len(candidates) > 1:
        logging.info("Found %d %s operations. Returning the first one.",
                     len(candidates), key)
      if candidates:
        first_op = candidates[0]
    except LookupError:
      # A failed lookup is treated the same as an empty collection.
      pass
    return first_op
Example #4
Source File: supervisor.py, from auto-alt-text-lambda-api (MIT License), 6 votes
def start_queue_runners(self, sess, queue_runners=None):
    """Create and start threads for a set of `QueueRunner`s.

    Queue runners registered in the graph under `GraphKeys.QUEUE_RUNNERS`
    are already started automatically when the supervisor creates a session.
    Call this explicitly only for runners that were never added to that
    collection.

    Args:
      sess: The `Session` the runner threads execute in.
      queue_runners: Optional list of `QueueRunner`s. When `None`, the runners
        stored in this supervisor's graph under `GraphKeys.QUEUE_RUNNERS`
        are used instead.

    Returns:
      A flat list of every thread started, in runner order.
    """
    runners = queue_runners
    if runners is None:
      runners = self._graph.get_collection(ops.GraphKeys.QUEUE_RUNNERS)
    started = []
    for runner in runners:
      # Every runner shares the supervisor's coordinator; threads are created
      # as daemons and started immediately.
      started += runner.create_threads(
          sess, coord=self._coord, daemon=True, start=True)
    return started
Example #5
Source File: loader_impl.py, from lambda-packs (MIT License), 6 votes
def _get_main_op_tensor(meta_graph_def_to_load):
  """Return the SavedModel main op tensor, or `None` if none is registered.

  Args:
    meta_graph_def_to_load: The meta graph def from the SavedModel to be loaded.

  Returns:
    The first element of the `constants.MAIN_OP_KEY` collection when the meta
    graph def has a collection def under that key, otherwise `None`.

  Raises:
    RuntimeError: If the main-op collection def does not list exactly one
      entry.
  """
  collection_def = meta_graph_def_to_load.collection_def
  if constants.MAIN_OP_KEY not in collection_def:
    return None
  listed_ops = collection_def[constants.MAIN_OP_KEY].node_list.value
  if len(listed_ops) != 1:
    raise RuntimeError("Expected exactly one SavedModel main op.")
  # NOTE(review): the count is validated on the serialized meta graph, but the
  # tensor is read via ops.get_collection (the default graph); presumably the
  # caller has already imported the meta graph into it — confirm.
  return ops.get_collection(constants.MAIN_OP_KEY)[0]
Example #6
Source File: loader_impl.py, from lambda-packs (MIT License), 6 votes
def _get_legacy_init_op_tensor(meta_graph_def_to_load):
  """Return the legacy init op tensor, or `None` if none is registered.

  Args:
    meta_graph_def_to_load: The meta graph def from the SavedModel to be loaded.

  Returns:
    The first element of the `constants.LEGACY_INIT_OP_KEY` collection when
    the meta graph def has a collection def under that key, otherwise `None`.

  Raises:
    RuntimeError: If the legacy-init-op collection def does not list exactly
      one entry.
  """
  collection_def = meta_graph_def_to_load.collection_def
  legacy_key = constants.LEGACY_INIT_OP_KEY
  if legacy_key not in collection_def:
    return None
  listed_ops = collection_def[legacy_key].node_list.value
  if len(listed_ops) != 1:
    raise RuntimeError("Expected exactly one legacy serving init op.")
  # NOTE(review): the count is validated on the serialized meta graph, but the
  # tensor is read via ops.get_collection (the default graph); presumably the
  # caller has already imported the meta graph into it — confirm.
  return ops.get_collection(legacy_key)[0]
Example #7
Source File: supervisor.py, from ctw-baseline (MIT License), 6 votes
def start_queue_runners(self, sess, queue_runners=None):
    """Create and start threads for a set of `QueueRunner`s.

    Queue runners registered in the graph under `GraphKeys.QUEUE_RUNNERS`
    are already started automatically when the supervisor creates a session.
    Call this explicitly only for runners that were never added to that
    collection.

    Args:
      sess: The `Session` the runner threads execute in.
      queue_runners: Optional list of `QueueRunner`s. When `None`, the runners
        stored in this supervisor's graph under `GraphKeys.QUEUE_RUNNERS`
        are used instead.

    Returns:
      A flat list of every thread started, in runner order.
    """
    runners = queue_runners
    if runners is None:
      runners = self._graph.get_collection(ops.GraphKeys.QUEUE_RUNNERS)
    started = []
    for runner in runners:
      # Every runner shares the supervisor's coordinator; threads are created
      # as daemons and started immediately.
      started += runner.create_threads(
          sess, coord=self._coord, daemon=True, start=True)
    return started
Example #8
Source File: variables.py, from lambda-packs (MIT License), 6 votes
def get_variables(scope=None, suffix=None,
                  collection=ops.GraphKeys.GLOBAL_VARIABLES):
  """Gets the list of variables, filtered by scope and/or suffix.

  `scope` and `suffix` are combined into the filter pattern handed to
  `ops.get_collection`. With a suffix, the pattern is
  `<scope>.*<suffix>:` (the colon is appended unless the suffix already
  contains one).

  Args:
    scope: an optional scope for filtering the variables to return. Can be a
      variable scope or a string.
    suffix: an optional suffix for filtering the variables to return.
    collection: in which collection search for. Defaults to
      `GraphKeys.GLOBAL_VARIABLES`.

  Returns:
    a list of variables in collection with scope and suffix.
  """
  if isinstance(scope, variable_scope.VariableScope):
    scope = scope.name
  if suffix is not None:
    tail = suffix if ':' in suffix else suffix + ':'
    scope = (scope or '') + '.*' + tail
  return ops.get_collection(collection, scope)
Example #9
Source File: variables.py, from auto-alt-text-lambda-api (MIT License), 6 votes
def local_variables():
  """Returns local variables.

  Local variables are per-process variables, usually not saved to or
  restored from checkpoints, and used for temporary or intermediate values.
  For example, they can serve as counters for metrics computation or track
  the number of epochs this machine has read. `local_variable()`
  automatically adds each new variable to `GraphKeys.LOCAL_VARIABLES`, and
  this convenience function returns the contents of that collection.

  An alternative to local variables are global variables. See
  [`tf.global_variables()`](../../api_docs/python/state_ops.md#global_variables)

  Returns:
    A list of local `Variable` objects.
  """
  local_key = ops.GraphKeys.LOCAL_VARIABLES
  return ops.get_collection(local_key)
Example #10
Source File: supervisor.py, from ctw-baseline (MIT License), 6 votes
def _get_first_op_from_collection(self, key):
    """Return the first item stored in the collection `key`.

    Args:
      key: A string collection key.

    Returns:
      The first element of the collection, or `None` when the collection is
      empty or the lookup raises `LookupError`. An info message is logged
      when the collection holds more than one element.
    """
    first_op = None
    try:
      candidates = ops.get_collection(key)
      if len(candidates) > 1:
        logging.info("Found %d %s operations. Returning the first one.",
                     len(candidates), key)
      if candidates:
        first_op = candidates[0]
    except LookupError:
      # A failed lookup is treated the same as an empty collection.
      pass
    return first_op
Example #11
Source File: logging_ops.py, from auto-alt-text-lambda-api (MIT License), 6 votes
def get_summary_op():
  """Returns a single Summary op that would run all summaries.

  The op is either the first entry of the `SUMMARY_OP` collection or, when
  that collection yields nothing, the result of `merge_all_summaries()`. A
  newly merged op is added to `SUMMARY_OP` so later calls return it directly.

  Returns:
    If no summaries were collected, returns None. Otherwise returns a scalar
    `Tensor` of type `string` containing the serialized `Summary` protocol
    buffer resulting from the merging.
  """
  registered = ops.get_collection(ops.GraphKeys.SUMMARY_OP)
  summary_op = registered[0] if registered else None
  if summary_op is None:
    summary_op = merge_all_summaries()
    if summary_op is not None:
      ops.add_to_collection(ops.GraphKeys.SUMMARY_OP, summary_op)
  return summary_op
Example #12
Source File: logging_ops.py, from auto-alt-text-lambda-api (MIT License), 6 votes
def merge_all_summaries(key=ops.GraphKeys.SUMMARIES):
  """Merges all summaries collected in the default graph.

  This op is deprecated. Please switch to tf.summary.merge_all, which has
  identical behavior.

  Args:
    key: `GraphKey` used to collect the summaries.  Defaults to
      `GraphKeys.SUMMARIES`.

  Returns:
    If no summaries were collected, returns None.  Otherwise returns a scalar
    `Tensor` of type `string` containing the serialized `Summary` protocol
    buffer resulting from the merging.
  """
  summary_ops = ops.get_collection(key)
  return merge_summary(summary_ops) if summary_ops else None
Example #13
Source File: variables.py, from lambda-packs (MIT License), 6 votes
def global_variables():
  """Returns global variables.

  Global variables are shared across machines in a distributed environment.
  Both the `Variable()` constructor and `get_variable()` automatically add
  new variables to the graph collection `GraphKeys.GLOBAL_VARIABLES`, and
  this convenience function returns the contents of that collection.

  An alternative to global variables are local variables. See
  @{tf.local_variables}

  Returns:
    A list of `Variable` objects.
  """
  global_key = ops.GraphKeys.GLOBAL_VARIABLES
  return ops.get_collection(global_key)
Example #14
Source File: variables.py, from auto-alt-text-lambda-api (MIT License), 6 votes
def global_variables():
  """Returns global variables.

  Global variables are shared across machines in a distributed environment.
  Both the `Variable()` constructor and `get_variable()` automatically add
  new variables to the graph collection `GraphKeys.GLOBAL_VARIABLES`, and
  this convenience function returns the contents of that collection.

  An alternative to global variables are local variables. See
  [`tf.local_variables()`](../../api_docs/python/state_ops.md#local_variables)

  Returns:
    A list of `Variable` objects.
  """
  global_key = ops.GraphKeys.GLOBAL_VARIABLES
  return ops.get_collection(global_key)
Example #15
Source File: supervisor.py, from lambda-packs (MIT License), 6 votes
def _get_first_op_from_collection(self, key):
    """Return the first item stored in the collection `key`.

    Args:
      key: A string collection key.

    Returns:
      The first element of the collection, or `None` when the collection is
      empty or the lookup raises `LookupError`. An info message is logged
      when the collection holds more than one element.
    """
    first_op = None
    try:
      candidates = ops.get_collection(key)
      if len(candidates) > 1:
        logging.info("Found %d %s operations. Returning the first one.",
                     len(candidates), key)
      if candidates:
        first_op = candidates[0]
    except LookupError:
      # A failed lookup is treated the same as an empty collection.
      pass
    return first_op
Example #16
Source File: summary.py, from lambda-packs (MIT License), 6 votes
def merge_all(key=_ops.GraphKeys.SUMMARIES):
  """Merges all summaries collected in the default graph.

  Args:
    key: `GraphKey` used to collect the summaries.  Defaults to
      `GraphKeys.SUMMARIES`.

  Returns:
    If no summaries were collected, returns None.  Otherwise returns a scalar
    `Tensor` of type `string` containing the serialized `Summary` protocol
    buffer resulting from the merging.
  """
  collected = _ops.get_collection(key)
  if collected:
    return merge(collected)
  return None
Example #17
Source File: rnn_cell.py, from lambda-packs (MIT License), 6 votes
def _get_concat_variable(name, shape, dtype, num_shards):
  """Get a sharded variable concatenated into one tensor.

  The concatenated tensor is cached in `GraphKeys.CONCATENATED_VARIABLES`.
  A later call that computes the same full name returns the cached tensor
  instead of building another `concat` op.

  Args:
    name: Base name; forwarded to `_get_sharded_variable` and used to name
      the concat op (`<name>/concat`).
    shape: Forwarded unchanged to `_get_sharded_variable`.
    dtype: Forwarded unchanged to `_get_sharded_variable`.
    num_shards: Forwarded unchanged to `_get_sharded_variable`.

  Returns:
    The single shard when `_get_sharded_variable` returns exactly one;
    otherwise a tensor concatenating the shards along axis 0.
  """
  sharded_variable = _get_sharded_variable(name, shape, dtype, num_shards)
  if len(sharded_variable) == 1:
    return sharded_variable[0]

  concat_name = name + "/concat"
  # The root variable scope has the empty name "". Joining with "/"
  # unconditionally would yield "/<name>/concat:0", which never matches a
  # real tensor name, so the cache below would miss and a duplicate concat
  # op would be created on every call.
  # NOTE(review): the lookup key uses the *variable* scope while the concat
  # op is named by the current name scope; this assumes the two coincide.
  scope_name = vs.get_variable_scope().name
  scope_prefix = scope_name + "/" if scope_name else ""
  concat_full_name = scope_prefix + concat_name + ":0"
  for value in ops.get_collection(ops.GraphKeys.CONCATENATED_VARIABLES):
    if value.name == concat_full_name:
      return value

  concat_variable = array_ops.concat(sharded_variable, 0, name=concat_name)
  ops.add_to_collection(ops.GraphKeys.CONCATENATED_VARIABLES,
                        concat_variable)
  return concat_variable
Example #18
Source File: saved_model_test.py, from auto-alt-text-lambda-api (MIT License), 6 votes
def testClearDevices(self):
    """Saving with clear_devices=True yields a model loadable in a new graph."""
    export_dir = os.path.join(test.get_temp_dir(), "test_clear_devices")
    builder = saved_model_builder.SavedModelBuilder(export_dir)

    # Pin the variable to an explicit CPU device, then save it with
    # clear_devices=True so the placement is stripped from the export.
    ops.reset_default_graph()
    two_cpu_config = config_pb2.ConfigProto(device_count={"CPU": 2})
    with session.Session(target="", config=two_cpu_config) as sess:
      with sess.graph.device("/cpu:0"):
        self._init_and_validate_variable(sess, "v", 42)
        builder.add_meta_graph_and_variables(
            sess, [tag_constants.TRAINING], clear_devices=True)

    # Write the SavedModel to disk.
    builder.save()

    # Load the TRAINING-tagged graph, saved without device information, into
    # a fresh graph and check that the variable kept its saved value.
    with self.test_session(graph=ops.Graph()) as sess:
      loader.load(sess, [tag_constants.TRAINING], export_dir)
      restored = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)[0]
      self.assertEqual(42, restored.eval())
Example #19
Source File: loader_impl.py, from auto-alt-text-lambda-api (MIT License), 6 votes
def _get_legacy_init_op_tensor(meta_graph_def_to_load):
  """Return the legacy init op tensor, or `None` if none is registered.

  Args:
    meta_graph_def_to_load: The meta graph def from the SavedModel to be loaded.

  Returns:
    The first element of the `constants.LEGACY_INIT_OP_KEY` collection when
    the meta graph def has a collection def under that key, otherwise `None`.

  Raises:
    RuntimeError: If the legacy-init-op collection def does not list exactly
      one entry.
  """
  collection_def = meta_graph_def_to_load.collection_def
  legacy_key = constants.LEGACY_INIT_OP_KEY
  if legacy_key not in collection_def:
    return None
  listed_ops = collection_def[legacy_key].node_list.value
  if len(listed_ops) != 1:
    raise RuntimeError("Expected exactly one legacy serving init op.")
  # NOTE(review): the count is validated on the serialized meta graph, but the
  # tensor is read via ops.get_collection (the default graph); presumably the
  # caller has already imported the meta graph into it — confirm.
  return ops.get_collection(legacy_key)[0]
Example #20
Source File: export.py, from lambda-packs (MIT License), 6 votes
def _export_graph(graph, saver, checkpoint_path, export_dir,
                  default_graph_signature, named_graph_signatures,
                  exports_to_keep):
  """Exports graph via session_bundle, by creating a Session.

  With `graph` as the default graph, this opens a session and restores
  `checkpoint_path` into it. It then writes an export whose init op groups
  the local-variable and table initializers.

  Args:
    graph: The graph to export; made the default graph during the export.
    saver: Saver used to restore the checkpoint; also handed to the exporter.
    checkpoint_path: Checkpoint restored into the session before exporting.
    export_dir: Directory passed to `export.export`.
    default_graph_signature: Default signature passed to `export.init`.
    named_graph_signatures: Named signatures passed to `export.init`.
    exports_to_keep: Forwarded to `export.export` as `exports_to_keep`.

  Returns:
    Whatever `export.export` returns (presumably the export path — confirm).
  """
  with graph.as_default():
    with tf_session.Session('') as session:
      # Initializers are not run here: they are grouped into the export's
      # init op below. Building them here without running or keeping them
      # would only leave orphan ops in the exported graph.
      saver.restore(session, checkpoint_path)

      export = exporter.Exporter(saver)
      export.init(
          init_op=control_flow_ops.group(
              variables.local_variables_initializer(),
              lookup_ops.tables_initializer()),
          default_graph_signature=default_graph_signature,
          named_graph_signatures=named_graph_signatures,
          assets_collection=ops.get_collection(ops.GraphKeys.ASSET_FILEPATHS))
      return export.export(export_dir, contrib_variables.get_global_step(),
                           session, exports_to_keep=exports_to_keep)
Example #21
Source File: loader_impl.py, from auto-alt-text-lambda-api (MIT License), 6 votes
def _get_main_op_tensor(meta_graph_def_to_load):
  """Return the SavedModel main op tensor, or `None` if none is registered.

  Args:
    meta_graph_def_to_load: The meta graph def from the SavedModel to be loaded.

  Returns:
    The first element of the `constants.MAIN_OP_KEY` collection when the meta
    graph def has a collection def under that key, otherwise `None`.

  Raises:
    RuntimeError: If the main-op collection def does not list exactly one
      entry.
  """
  collection_def = meta_graph_def_to_load.collection_def
  if constants.MAIN_OP_KEY not in collection_def:
    return None
  listed_ops = collection_def[constants.MAIN_OP_KEY].node_list.value
  if len(listed_ops) != 1:
    raise RuntimeError("Expected exactly one SavedModel main op.")
  # NOTE(review): the count is validated on the serialized meta graph, but the
  # tensor is read via ops.get_collection (the default graph); presumably the
  # caller has already imported the meta graph into it — confirm.
  return ops.get_collection(constants.MAIN_OP_KEY)[0]
Example #22
Source File: variables.py, from lambda-packs (MIT License), 6 votes
def local_variables():
  """Returns local variables.

  Local variables are per-process variables, usually not saved to or
  restored from checkpoints, and used for temporary or intermediate values.
  For example, they can serve as counters for metrics computation or track
  the number of epochs this machine has read. The
  `tf.contrib.framework.local_variable()` function automatically adds each
  new variable to `GraphKeys.LOCAL_VARIABLES`, and this convenience function
  returns the contents of that collection.

  An alternative to local variables are global variables. See
  @{tf.global_variables}

  Returns:
    A list of local `Variable` objects.
  """
  local_key = ops.GraphKeys.LOCAL_VARIABLES
  return ops.get_collection(local_key)
Example #23
Source File: logging_ops.py, from lambda-packs (MIT License), 6 votes
def merge_all_summaries(key=ops.GraphKeys.SUMMARIES):
  """Merges all summaries collected in the default graph.

  This op is deprecated. Please switch to tf.summary.merge_all, which has
  identical behavior.

  Args:
    key: `GraphKey` used to collect the summaries.  Defaults to
      `GraphKeys.SUMMARIES`.

  Returns:
    If no summaries were collected, returns None.  Otherwise returns a scalar
    `Tensor` of type `string` containing the serialized `Summary` protocol
    buffer resulting from the merging.
  """
  summary_ops = ops.get_collection(key)
  return merge_summary(summary_ops) if summary_ops else None
Example #24
Source File: saved_model_test.py, from auto-alt-text-lambda-api (MIT License), 5 votes
def _build_asset_collection(self, asset_file_name, asset_file_contents,
                              asset_file_tensor_name):
    """Write an asset file and register its path in `ASSET_FILEPATHS`.

    Args:
      asset_file_name: Name of the asset file, created under the test temp dir.
      asset_file_contents: Contents written to the asset file.
      asset_file_tensor_name: Name of the constant tensor holding the path.

    Returns:
      The contents of the `GraphKeys.ASSET_FILEPATHS` collection after the
      new path tensor has been added to it.
    """
    temp_dir = compat.as_bytes(test.get_temp_dir())
    asset_filepath = os.path.join(temp_dir, compat.as_bytes(asset_file_name))
    file_io.write_string_to_file(asset_filepath, asset_file_contents)
    path_tensor = constant_op.constant(asset_filepath,
                                       name=asset_file_tensor_name)
    ops.add_to_collection(ops.GraphKeys.ASSET_FILEPATHS, path_tensor)
    return ops.get_collection(ops.GraphKeys.ASSET_FILEPATHS)
Example #25
Source File: queue_runner_impl.py, from auto-alt-text-lambda-api (MIT License), 5 votes
def start_queue_runners(sess=None, coord=None, daemon=True, start=True,
                        collection=ops.GraphKeys.QUEUE_RUNNERS):
  """Starts all queue runners collected in the graph.

  Companion to `add_queue_runner()`: it creates threads for every queue
  runner in `collection` of the session's graph and returns them all.

  Args:
    sess: `Session` used to run the queue ops.  Defaults to the
      default session.
    coord: Optional `Coordinator` for coordinating the started threads.
    daemon: Whether the threads should be marked as `daemons`, meaning
      they don't block program exit.
    start: Set to `False` to only create the threads, not start them.
    collection: A `GraphKey` specifying the graph collection to
      get the queue runners from.  Defaults to `GraphKeys.QUEUE_RUNNERS`.

  Returns:
    A list of threads.

  Raises:
    ValueError: If `sess` is None and no default session is registered.
  """
  if sess is None:
    sess = ops.get_default_session()
    if not sess:
      raise ValueError("Cannot start queue runners: No default session is "
                       "registered. Use `with sess.as_default()` or pass an "
                       "explicit session to tf.start_queue_runners(sess=sess)")
  threads = []
  # Read the collection with the session's graph as default, so runners come
  # from the graph the session actually executes.
  with sess.graph.as_default():
    for runner in ops.get_collection(collection):
      threads += runner.create_threads(sess, coord=coord, daemon=daemon,
                                       start=start)
  return threads
Example #26
Source File: stochastic_graph_impl.py, from lambda-packs (MIT License), 5 votes
def _stochastic_dependencies_map(fixed_losses, stochastic_tensors=None):
  """Map stochastic tensors to the fixed losses that depend on them.

  Args:
    fixed_losses: a list of `Tensor`s.
    stochastic_tensors: a list of `StochasticTensor`s to map to fixed losses.
      If `None`, all `StochasticTensor`s in the graph will be used.

  Returns:
    A dict `dependencies` that maps `StochasticTensor` objects to subsets of
    `fixed_losses`.

    If `loss in dependencies[st]`, for some `loss` in `fixed_losses` then there
    is a direct path from `st.value()` to `loss` in the graph.
  """
  stoch_value_collection = stochastic_tensors or ops.get_collection(
      stochastic_tensor_impl.STOCHASTIC_TENSOR_COLLECTION)

  if not stoch_value_collection:
    return {}

  stoch_value_map = dict(
      (node.value(), node) for node in stoch_value_collection)

  # Step backwards through the graph to see which surrogate losses correspond
  # to which fixed_losses.
  #
  # TODO(ebrevdo): Ensure that fixed_losses and stochastic values are in the
  # same frame.
  stoch_dependencies_map = collections.defaultdict(set)
  for loss in fixed_losses:
    boundary = set([loss])
    while boundary:
      edge = boundary.pop()
      edge_stoch_node = stoch_value_map.get(edge, None)
      if edge_stoch_node:
        stoch_dependencies_map[edge_stoch_node].add(loss)
      boundary.update(edge.op.inputs)

  return stoch_dependencies_map 
Example #27
Source File: loss_ops.py, from lambda-packs (MIT License), 5 votes
def get_regularization_losses(scope=None):
  """Return the contents of `GraphKeys.REGULARIZATION_LOSSES`.

  Args:
    scope: an optional scope for filtering the losses to return; forwarded
      to `ops.get_collection` as its scope filter.

  Returns:
    A list of regularization losses as Tensors.
  """
  losses_key = ops.GraphKeys.REGULARIZATION_LOSSES
  return ops.get_collection(losses_key, scope)
Example #28
Source File: core.py, from lambda-packs (MIT License), 5 votes
def get_axis_order():
  """Get the axis_order set by any containing axis_order_scope.

  Returns:
    List of strings giving an order to use for axis names, or None, if no axis
    order is set.

  Raises:
    ValueError: If the axis-order collection holds more than one entry (the
      single-element unpacking below fails).
  """
  # By storing axis_order in the graph, we can ensure that axis_order_scope is
  # thread-safe.
  axis_order_list = ops.get_collection(_AXIS_ORDER_KEY)
  if not axis_order_list:
    return None
  (axis_order,) = axis_order_list
  return axis_order
Example #29
Source File: variables.py, from lambda-packs (MIT License), 5 votes
def add_model_variable(var):
  """Adds a variable to the `GraphKeys.MODEL_VARIABLES` collection.

  The variable is added only if it is not already in the collection, so
  repeated calls do not create duplicate entries.

  Args:
    var: a variable.
  """
  model_vars_key = ops.GraphKeys.MODEL_VARIABLES
  already_added = var in ops.get_collection(model_vars_key)
  if not already_added:
    ops.add_to_collection(model_vars_key, var)
Example #30
Source File: variational_inference_impl.py, from lambda-packs (MIT License), 5 votes
def _find_variational_and_priors(model,
                                 variational_with_prior,
                                 require_prior=True):
  """Find upstream StochasticTensors and match with registered priors.

  Args:
    model: Node whose upstream stochastic nodes are looked up via
      `sg._upstream_stochastic_nodes` when `variational_with_prior` is None
      (presumably a `Tensor` — confirm).
    variational_with_prior: Optional dict mapping `StochasticTensor`s to
      priors. When None, it is built from `model`'s upstream stochastic nodes
      and the `(StochasticTensor, prior)` pairs in the `VI_PRIORS` collection.
    require_prior: If True, every upstream node must have a non-None prior
      registered in `VI_PRIORS`.

  Returns:
    A dict mapping each `StochasticTensor` to its prior (a `Distribution` or
    None).

  Raises:
    ValueError: If `model` has no upstream stochastic nodes, or a required
      prior is missing.
    TypeError: If a key is not a `StochasticTensor` or a value is neither
      None nor a `Distribution`.
  """
  if variational_with_prior is None:
    # pylint: disable=protected-access
    upstreams = sg._upstream_stochastic_nodes([model])
    # pylint: enable=protected-access
    upstreams = list(upstreams[model])
    if not upstreams:
      # Format eagerly: passing `model` as a second exception argument would
      # leave "%s" uninterpolated in the message.
      raise ValueError(
          "No upstream stochastic nodes found for tensor: %s" % (model,))
    prior_map = dict(ops.get_collection(VI_PRIORS))
    variational_with_prior = {}
    for q in upstreams:
      if require_prior and (q not in prior_map or prior_map[q] is None):
        raise ValueError(
            "No prior specified for StochasticTensor: %s" % (q,))
      variational_with_prior[q] = prior_map.get(q)

  if not all(isinstance(q, st.StochasticTensor)
             for q in variational_with_prior):
    raise TypeError("variationals must be StochasticTensors")
  if not all(p is None or isinstance(p, distribution.Distribution)
             for p in variational_with_prior.values()):
    raise TypeError("priors must be Distribution objects")

  return variational_with_prior