Python functools.partial() Examples
The following are 30 code examples of functools.partial(), collected from open-source projects. The originating project, source file, and license are noted above each example. You may also want to explore the other functions and classes in the functools module.
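Before the project examples, here is a quick self-contained refresher (not taken from any of the projects below) on what functools.partial actually does: it wraps a callable with some arguments pre-bound and returns a new callable.

import functools

def power(base, exponent):
    """Return base raised to exponent."""
    return base ** exponent

# Freeze the exponent; square() now only needs the base.
square = functools.partial(power, exponent=2)

print(square(5))          # 25
print(square.func)        # <function power at ...>
print(square.keywords)    # {'exponent': 2}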
Example #1
Source File: misc.py From mmdetection with Apache License 2.0 | 8 votes |
def multi_apply(func, *args, **kwargs):
    """Apply function to a list of arguments.

    Note:
        This function applies the ``func`` to multiple inputs and
        map the multiple outputs of the ``func`` into different
        list. Each list contains the same type of outputs corresponding
        to different inputs.

    Args:
        func (Function): A function that will be applied to a list of
            arguments

    Returns:
        tuple(list): A tuple containing multiple list, each list contains
            a kind of returned results by the function
    """
    pfunc = partial(func, **kwargs) if kwargs else func
    map_results = map(pfunc, *args)
    return tuple(map(list, zip(*map_results)))
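To see how the partial-wrapped function fans its results out into per-type lists, here is a hypothetical call to multi_apply; scale_and_offset and its values are invented for the sketch:

from functools import partial  # used inside multi_apply above

def scale_and_offset(value, scale=1):
    return value * scale, value + scale

# multi_apply binds scale=10 via partial, maps over the inputs, then
# regroups the per-call tuples into one list per output type.
scaled, offsets = multi_apply(scale_and_offset, [1, 2, 3], scale=10)
print(scaled)    # [10, 20, 30]
print(offsets)   # [11, 12, 13]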
Example #2
Source File: dsl.py From gql with MIT License | 6 votes |
def get_arg_serializer(arg_type):
    if isinstance(arg_type, GraphQLNonNull):
        return get_arg_serializer(arg_type.of_type)
    if isinstance(arg_type, GraphQLInputField):
        return get_arg_serializer(arg_type.type)
    if isinstance(arg_type, GraphQLInputObjectType):
        serializers = {k: get_arg_serializer(v) for k, v in arg_type.fields.items()}
        return lambda value: ObjectValueNode(
            fields=FrozenList(
                ObjectFieldNode(name=NameNode(value=k), value=serializers[k](v))
                for k, v in value.items()
            )
        )
    if isinstance(arg_type, GraphQLList):
        inner_serializer = get_arg_serializer(arg_type.of_type)
        return partial(serialize_list, inner_serializer)
    if isinstance(arg_type, GraphQLEnumType):
        return lambda value: EnumValueNode(value=arg_type.serialize(value))
    return lambda value: ast_from_value(arg_type.serialize(value), arg_type)
Example #3
Source File: wspbus.py From cherrypy with BSD 3-Clause "New" or "Revised" License | 6 votes |
def subscribe(self, channel, callback=None, priority=None):
    """Add the given callback at the given channel (if not present).

    If callback is None, return a partial suitable for decorating
    the callback.
    """
    if callback is None:
        return functools.partial(
            self.subscribe,
            channel,
            priority=priority,
        )

    ch_listeners = self.listeners.setdefault(channel, set())
    ch_listeners.add(callback)

    if priority is None:
        priority = getattr(callback, 'priority', 50)
    self._priorities[(channel, callback)] = priority
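The point of returning a partial when callback is omitted is that the channel is remembered now and the callback supplied later, which is what makes the decorator form @bus.subscribe('start') work. Here is a standalone sketch of the same idea; the free function below is illustrative, not CherryPy's API:

import functools

def subscribe(channel, callback=None, priority=None):
    if callback is None:
        # No callback yet: hand back a partial that remembers the channel.
        return functools.partial(subscribe, channel, priority=priority)
    print("registered {} on {!r} (priority={})".format(
        callback.__name__, channel, priority))

register_start = subscribe('start', priority=10)   # a partial; nothing registered yet

def boot():
    pass

register_start(boot)   # registered boot on 'start' (priority=10)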
Example #4
Source File: window.py From LPHK with GNU General Public License v3.0 | 6 votes |
def popup_choice(self, window, title, image, text, choices):
    popup = tk.Toplevel(window)
    popup.resizable(False, False)
    if MAIN_ICON != None:
        if os.path.splitext(MAIN_ICON)[1].lower() == ".gif":
            dummy = None
            #popup.call('wm', 'iconphoto', popup._w, tk.PhotoImage(file=MAIN_ICON))
        else:
            popup.iconbitmap(MAIN_ICON)
    popup.wm_title(title)
    popup.tkraise(window)

    def run_end(func):
        popup.destroy()
        if func != None:
            func()

    picture_label = tk.Label(popup, image=image)
    picture_label.photo = image
    picture_label.grid(column=0, row=0, rowspan=2, padx=10, pady=10)

    tk.Label(popup, text=text, justify=tk.CENTER).grid(
        column=1, row=0, columnspan=len(choices), padx=10, pady=10)

    for idx, choice in enumerate(choices):
        run_end_func = partial(run_end, choice[1])
        tk.Button(popup, text=choice[0], command=run_end_func).grid(
            column=1 + idx, row=1, padx=10, pady=10)

    popup.wait_visibility()
    popup.grab_set()
    popup.wait_window()
Example #5
Source File: manager.py From wafw00f with BSD 3-Clause "New" or "Revised" License | 6 votes |
def load_plugins():
    here = os.path.abspath(os.path.dirname(__file__))
    get_path = partial(os.path.join, here)

    plugin_dir = get_path('plugins')

    plugin_base = PluginBase(
        package='wafw00f.plugins', searchpath=[plugin_dir]
    )
    plugin_source = plugin_base.make_plugin_source(
        searchpath=[plugin_dir], persist=True
    )

    plugin_dict = {}
    for plugin_name in plugin_source.list_plugins():
        plugin_dict[plugin_name] = plugin_source.load_plugin(plugin_name)

    return plugin_dict
Example #6
Source File: ecs.py From aegea with Apache License 2.0 | 6 votes |
def tasks(args):
    list_clusters = clients.ecs.get_paginator("list_clusters")
    list_tasks = clients.ecs.get_paginator("list_tasks")

    def list_tasks_worker(worker_args):
        cluster, status = worker_args
        return cluster, status, list(paginate(list_tasks, cluster=cluster, desiredStatus=status))

    def describe_tasks_worker(t, cluster=None):
        return clients.ecs.describe_tasks(cluster=cluster, tasks=t)["tasks"] if t else []

    task_descs = []
    if args.clusters is None:
        args.clusters = [__name__.replace(".", "_")] if args.tasks else list(paginate(list_clusters))
    if args.tasks:
        task_descs = describe_tasks_worker(args.tasks, cluster=args.clusters[0])
    else:
        with ThreadPoolExecutor() as executor:
            for cluster, status, tasks in executor.map(list_tasks_worker,
                                                       product(args.clusters, args.desired_status)):
                worker = partial(describe_tasks_worker, cluster=cluster)
                descs = executor.map(worker, (tasks[pos:pos + 100] for pos in range(0, len(tasks), 100)))
                task_descs += sum(descs, [])
    page_output(tabulate(task_descs, args))
Example #7
Source File: model.py From neural-fingerprinting with BSD 3-Clause "New" or "Revised" License | 6 votes |
def fprop(self, x, **kwargs):
    del kwargs
    my_conv = functools.partial(tf.layers.conv2d, kernel_size=3, strides=2,
                                padding='valid', activation=tf.nn.relu,
                                kernel_initializer=HeReLuNormalInitializer)
    my_dense = functools.partial(
        tf.layers.dense, kernel_initializer=HeReLuNormalInitializer)
    with tf.variable_scope(self.scope, reuse=tf.AUTO_REUSE):
        for depth in [96, 256, 384, 384, 256]:
            x = my_conv(x, depth)
        y = tf.layers.flatten(x)
        y = my_dense(y, 4096, tf.nn.relu)
        y = fc7 = my_dense(y, 4096, tf.nn.relu)
        y = my_dense(y, 1000)
        return {'fc7': fc7,
                self.O_LOGITS: y,
                self.O_PROBS: tf.nn.softmax(logits=y)}
Example #8
Source File: monitor.py From multibootusb with GNU General Public License v2.0 | 6 votes |
def run(self):
    self.monitor.start()
    notifier = poll.Poll.for_events(
        (self.monitor, 'r'), (self._stop_event.source, 'r'))
    while True:
        for file_descriptor, event in eintr_retry_call(notifier.poll):
            if file_descriptor == self._stop_event.source.fileno():
                # in case of a stop event, close our pipe side, and
                # return from the thread
                self._stop_event.source.close()
                return
            elif file_descriptor == self.monitor.fileno() and event == 'r':
                read_device = partial(eintr_retry_call,
                                      self.monitor.poll, timeout=0)
                for device in iter(read_device, None):
                    self._callback(device)
            else:
                raise EnvironmentError('Observed monitor hung up')
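Here the partial exists so it can be fed to the two-argument form of iter(), which keeps calling a zero-argument callable until it returns a sentinel (None). A self-contained sketch of the same pattern, with a plain list standing in for the udev monitor:

from functools import partial

events = ['event-1', 'event-2', 'event-3']

def poll(source, timeout=0):
    """Return the next item, or None when the source is drained."""
    return source.pop(0) if source else None

read_device = partial(poll, events, timeout=0)

# iter(callable, sentinel) calls read_device() until it returns None.
for device in iter(read_device, None):
    print(device)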
Example #9
Source File: train.py From spleeter with MIT License | 6 votes |
def _create_evaluation_spec(params, audio_adapter, audio_path):
    """ Setup eval spec evaluating every n seconds

    :param params: TF params to build spec from.
    :returns: Built evaluation spec.
    """
    input_fn = partial(
        get_validation_dataset,
        params,
        audio_adapter,
        audio_path)
    evaluation_spec = tf.estimator.EvalSpec(
        input_fn=input_fn,
        steps=None,
        throttle_secs=params['throttle_secs'])
    return evaluation_spec
Example #10
Source File: 2_simple_mnist.py From deep-learning-note with MIT License | 6 votes |
def __init__(self, learning_rate, max_iteration_steps, seed=None):
    """Initializes a `Generator` that builds `SimpleCNNs`.

    Args:
      learning_rate: The float learning rate to use.
      max_iteration_steps: The number of steps per iteration.
      seed: The random seed.

    Returns:
      An instance of `Generator`.
    """
    self._seed = seed
    self._cnn_builder_fn = functools.partial(
        SimpleCNNBuilder,
        learning_rate=learning_rate,
        max_iteration_steps=max_iteration_steps)
Example #11
Source File: rnn_cell.py From dynamic-training-with-apache-mxnet-on-aws with Apache License 2.0 | 6 votes |
def __init__(self, input_shape, num_hidden,
             h2h_kernel=(3, 3), h2h_dilate=(1, 1),
             i2h_kernel=(3, 3), i2h_stride=(1, 1),
             i2h_pad=(1, 1), i2h_dilate=(1, 1),
             i2h_weight_initializer=None, h2h_weight_initializer=None,
             i2h_bias_initializer='zeros', h2h_bias_initializer='zeros',
             activation=functools.partial(symbol.LeakyReLU, act_type='leaky', slope=0.2),
             prefix='ConvRNN_', params=None, conv_layout='NCHW'):
    super(ConvRNNCell, self).__init__(input_shape=input_shape,
                                      num_hidden=num_hidden,
                                      h2h_kernel=h2h_kernel, h2h_dilate=h2h_dilate,
                                      i2h_kernel=i2h_kernel, i2h_stride=i2h_stride,
                                      i2h_pad=i2h_pad, i2h_dilate=i2h_dilate,
                                      i2h_weight_initializer=i2h_weight_initializer,
                                      h2h_weight_initializer=h2h_weight_initializer,
                                      i2h_bias_initializer=i2h_bias_initializer,
                                      h2h_bias_initializer=h2h_bias_initializer,
                                      activation=activation,
                                      prefix=prefix, params=params,
                                      conv_layout=conv_layout)
Example #12
Source File: rnn_cell.py From dynamic-training-with-apache-mxnet-on-aws with Apache License 2.0 | 6 votes |
def __init__(self, input_shape, num_hidden,
             h2h_kernel=(3, 3), h2h_dilate=(1, 1),
             i2h_kernel=(3, 3), i2h_stride=(1, 1),
             i2h_pad=(1, 1), i2h_dilate=(1, 1),
             i2h_weight_initializer=None, h2h_weight_initializer=None,
             i2h_bias_initializer='zeros', h2h_bias_initializer='zeros',
             activation=functools.partial(symbol.LeakyReLU, act_type='leaky', slope=0.2),
             prefix='ConvLSTM_', params=None, conv_layout='NCHW'):
    super(ConvLSTMCell, self).__init__(input_shape=input_shape,
                                       num_hidden=num_hidden,
                                       h2h_kernel=h2h_kernel, h2h_dilate=h2h_dilate,
                                       i2h_kernel=i2h_kernel, i2h_stride=i2h_stride,
                                       i2h_pad=i2h_pad, i2h_dilate=i2h_dilate,
                                       i2h_weight_initializer=i2h_weight_initializer,
                                       h2h_weight_initializer=h2h_weight_initializer,
                                       i2h_bias_initializer=i2h_bias_initializer,
                                       h2h_bias_initializer=h2h_bias_initializer,
                                       activation=activation,
                                       prefix=prefix, params=params,
                                       conv_layout=conv_layout)
Example #13
Source File: rnn_cell.py From dynamic-training-with-apache-mxnet-on-aws with Apache License 2.0 | 6 votes |
def __init__(self, input_shape, num_hidden,
             h2h_kernel=(3, 3), h2h_dilate=(1, 1),
             i2h_kernel=(3, 3), i2h_stride=(1, 1),
             i2h_pad=(1, 1), i2h_dilate=(1, 1),
             i2h_weight_initializer=None, h2h_weight_initializer=None,
             i2h_bias_initializer='zeros', h2h_bias_initializer='zeros',
             activation=functools.partial(symbol.LeakyReLU, act_type='leaky', slope=0.2),
             prefix='ConvGRU_', params=None, conv_layout='NCHW'):
    super(ConvGRUCell, self).__init__(input_shape=input_shape,
                                      num_hidden=num_hidden,
                                      h2h_kernel=h2h_kernel, h2h_dilate=h2h_dilate,
                                      i2h_kernel=i2h_kernel, i2h_stride=i2h_stride,
                                      i2h_pad=i2h_pad, i2h_dilate=i2h_dilate,
                                      i2h_weight_initializer=i2h_weight_initializer,
                                      h2h_weight_initializer=h2h_weight_initializer,
                                      i2h_bias_initializer=i2h_bias_initializer,
                                      h2h_bias_initializer=h2h_bias_initializer,
                                      activation=activation,
                                      prefix=prefix, params=params,
                                      conv_layout=conv_layout)
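The three convolutional cells above share one idiom: a partial baked in as a default argument value, so the default activation arrives with its configuration (act_type, slope) already bound. A framework-free sketch of the idea, using a plain Python function in place of mxnet's symbol.LeakyReLU:

from functools import partial

def leaky_relu(x, slope=0.01):
    return x if x > 0 else slope * x

def build_cell(size, activation=partial(leaky_relu, slope=0.2)):
    # The default activation comes pre-configured; callers can still
    # pass their own callable to override it.
    return {'size': size, 'activation': activation}

cell = build_cell(64)
print(cell['activation'](-1.0))   # -0.2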
Example #14
Source File: eval.py From DOTA_models with Apache License 2.0 | 6 votes |
def main(unused_argv):
    assert FLAGS.checkpoint_dir, '`checkpoint_dir` is missing.'
    assert FLAGS.eval_dir, '`eval_dir` is missing.'
    if FLAGS.pipeline_config_path:
        model_config, eval_config, input_config = get_configs_from_pipeline_file()
    else:
        model_config, eval_config, input_config = get_configs_from_multiple_files()

    model_fn = functools.partial(
        model_builder.build,
        model_config=model_config,
        is_training=False)

    create_input_dict_fn = functools.partial(
        input_reader_builder.build, input_config)

    label_map = label_map_util.load_labelmap(input_config.label_map_path)
    max_num_classes = max([item.id for item in label_map.item])
    categories = label_map_util.convert_label_map_to_categories(
        label_map, max_num_classes)

    evaluator.evaluate(create_input_dict_fn, model_fn, eval_config, categories,
                       FLAGS.checkpoint_dir, FLAGS.eval_dir)
Example #15
Source File: losses.py From DOTA_models with Apache License 2.0 | 6 votes |
def mmd_loss(source_samples, target_samples, weight, scope=None):
    """Adds a similarity loss term, the MMD between two representations.

    This Maximum Mean Discrepancy (MMD) loss is calculated with a number of
    different Gaussian kernels.

    Args:
      source_samples: a tensor of shape [num_samples, num_features].
      target_samples: a tensor of shape [num_samples, num_features].
      weight: the weight of the MMD loss.
      scope: optional name scope for summary tags.

    Returns:
      a scalar tensor representing the MMD loss value.
    """
    sigmas = [
        1e-6, 1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 1, 5, 10, 15, 20, 25, 30, 35, 100,
        1e3, 1e4, 1e5, 1e6
    ]
    gaussian_kernel = partial(
        utils.gaussian_kernel_matrix, sigmas=tf.constant(sigmas))

    loss_value = maximum_mean_discrepancy(
        source_samples, target_samples, kernel=gaussian_kernel)
    loss_value = tf.maximum(1e-4, loss_value) * weight
    assert_op = tf.Assert(tf.is_finite(loss_value), [loss_value])
    with tf.control_dependencies([assert_op]):
        tag = 'MMD Loss'
        if scope:
            tag = scope + tag
        tf.summary.scalar(tag, loss_value)
        tf.losses.add_loss(loss_value)

    return loss_value
Example #16
Source File: minitaur_env_randomizer_from_config.py From soccer-matlab with BSD 2-Clause "Simplified" License | 6 votes |
def _build_randomization_function_dict(self, env):
    func_dict = {}
    func_dict["mass"] = functools.partial(
        self._randomize_masses, minitaur=env.minitaur)
    func_dict["inertia"] = functools.partial(
        self._randomize_inertia, minitaur=env.minitaur)
    func_dict["latency"] = functools.partial(
        self._randomize_latency, minitaur=env.minitaur)
    func_dict["joint friction"] = functools.partial(
        self._randomize_joint_friction, minitaur=env.minitaur)
    func_dict["motor friction"] = functools.partial(
        self._randomize_motor_friction, minitaur=env.minitaur)
    func_dict["restitution"] = functools.partial(
        self._randomize_contact_restitution, minitaur=env.minitaur)
    func_dict["lateral friction"] = functools.partial(
        self._randomize_contact_friction, minitaur=env.minitaur)
    func_dict["battery"] = functools.partial(
        self._randomize_battery_level, minitaur=env.minitaur)
    func_dict["motor strength"] = functools.partial(
        self._randomize_motor_strength, minitaur=env.minitaur)
    # Setting control step needs access to the environment.
    func_dict["control step"] = functools.partial(
        self._randomize_control_step, env=env)
    return func_dict
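Building a dict of partials like this produces a dispatch table whose entries are already bound to their target, so callers can apply any subset of randomizations by name. A minimal sketch of the pattern; the robot and randomizer names are invented for the illustration:

import functools

class Robot:
    mass = 1.0
    friction = 0.5

def randomize_mass(robot, scale):
    robot.mass *= scale

def randomize_friction(robot, scale):
    robot.friction *= scale

robot = Robot()
func_dict = {
    "mass": functools.partial(randomize_mass, robot),
    "friction": functools.partial(randomize_friction, robot),
}

# Apply only the randomizations requested by name.
for name in ["mass"]:
    func_dict[name](scale=1.1)
print(robot.mass, robot.friction)   # 1.1 0.5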
Example #17
Source File: wrappers.py From soccer-matlab with BSD 2-Clause "Simplified" License | 6 votes |
def step(self, action, blocking=True):
    """Step the environment.

    Args:
      action: The action to apply to the environment.
      blocking: Whether to wait for the result.

    Returns:
      Transition tuple when blocking, otherwise callable that returns the
      transition tuple.
    """
    self._conn.send((self._ACTION, action))
    if blocking:
        return self._receive(self._TRANSITION)
    else:
        return functools.partial(self._receive, self._TRANSITION)
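In the non-blocking branch the message has already been sent; the returned partial simply defers the receive until the caller wants the result. A toy sketch of the same blocking/non-blocking split, with a queue standing in for the worker connection:

import functools
import queue

replies = queue.Queue()

def receive(kind):
    return kind, replies.get()

def step(action, blocking=True):
    replies.put(action * 2)              # stand-in for the worker's reply
    if blocking:
        return receive('transition')
    return functools.partial(receive, 'transition')

print(step(3))                    # ('transition', 6)
pending = step(4, blocking=False)
print(pending())                  # ('transition', 8) -- fetched only when called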
Example #18
Source File: data.py From End-to-end-ASR-Pytorch with MIT License | 6 votes |
def load_textset(n_jobs, use_gpu, pin_memory, corpus, text):
    # Text tokenizer
    tokenizer = load_text_encoder(**text)
    # Dataset
    tr_set, dv_set, tr_loader_bs, dv_loader_bs, data_msg = create_textset(
        tokenizer, **corpus)
    collect_tr = partial(collect_text_batch, mode='train')
    collect_dv = partial(collect_text_batch, mode='dev')
    # Dataloader (Text data stored in RAM, no need num_workers)
    tr_set = DataLoader(tr_set, batch_size=tr_loader_bs, shuffle=True,
                        drop_last=True, collate_fn=collect_tr,
                        num_workers=0, pin_memory=use_gpu)
    dv_set = DataLoader(dv_set, batch_size=dv_loader_bs, shuffle=False,
                        drop_last=False, collate_fn=collect_dv,
                        num_workers=0, pin_memory=pin_memory)
    # Messages to show
    data_msg.append('I/O spec.  | Token type = {}\t| Vocab size = {}'
                    .format(tokenizer.token_type, tokenizer.vocab_size))

    return tr_set, dv_set, tokenizer.vocab_size, tokenizer, data_msg
Example #19
Source File: __init__.py From facebook-wda with MIT License | 6 votes |
def set_alert_callback(self, callback):
    """
    Args:
        callback (func): called when alert popup

    Example of callback:

        def callback(session):
            session.alert.accept()
    """
    if callable(callback):
        self.http.alert_callback = functools.partial(callback, self)
    else:
        self.http.alert_callback = None

# Not working
#def get_clipboard(self):
#    return self.http.post("/wda/getPasteboard").value

# Not working
#def siri_activate(self, text):
#    self.http.post("/wda/siri/activate", {"text": text})
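functools.partial(callback, self) pre-binds the session as the callback's first argument, so the HTTP layer can fire the callback later without knowing anything about sessions. A toy sketch of that binding; the class and names are invented for the illustration:

import functools

class Session:
    def __init__(self, name):
        self.name = name
        self.alert_callback = None

    def set_alert_callback(self, callback):
        # Bind this session as the callback's first argument.
        self.alert_callback = (functools.partial(callback, self)
                               if callable(callback) else None)

def on_alert(session):
    print("dismissing alert on " + session.name)

s = Session("iPhone")
s.set_alert_callback(on_alert)
s.alert_callback()   # dismissing alert on iPhone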
Example #20
Source File: __init__.py From aegea with Apache License 2.0 | 5 votes |
def add_time_bound_args(p, snap=0):
    t = partial(Timestamp, snap=snap)
    p.add_argument("--start-time", type=t, default=Timestamp("-7d", snap=snap),
                   help=Timestamp.__doc__, metavar="START")
    p.add_argument("--end-time", type=t, help=Timestamp.__doc__, metavar="END")
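Passing a partial as argparse's type= is a compact way to parameterize how raw strings are converted. A self-contained sketch using int with a fixed base in place of the project's Timestamp class:

import argparse
from functools import partial

hex_int = partial(int, base=16)   # str -> int, always parsed as hexadecimal

parser = argparse.ArgumentParser()
parser.add_argument("--mask", type=hex_int, default="ff")
args = parser.parse_args(["--mask", "1a"])
print(args.mask)   # 26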
Example #21
Source File: losses_test.py From DOTA_models with Apache License 2.0 | 5 votes |
def test_mmd_is_zero_when_inputs_are_same(self):
    with self.test_session():
        x = tf.random_uniform((2, 3), seed=1)
        kernel = partial(utils.gaussian_kernel_matrix, sigmas=tf.constant([1.]))
        self.assertEquals(0, losses.maximum_mean_discrepancy(x, x, kernel).eval())
Example #22
Source File: losses_test.py From DOTA_models with Apache License 2.0 | 5 votes |
def test_mmd_name(self):
    with self.test_session():
        x = tf.random_uniform((2, 3), seed=1)
        kernel = partial(utils.gaussian_kernel_matrix, sigmas=tf.constant([1.]))
        loss = losses.maximum_mean_discrepancy(x, x, kernel)

        self.assertEquals(loss.op.name, 'MaximumMeanDiscrepancy/value')
Example #23
Source File: losses_test.py From DOTA_models with Apache License 2.0 | 5 votes |
def test_fast_mmd_is_similar_to_slow_mmd(self):
    with self.test_session():
        x = tf.constant(np.random.normal(size=(2, 3)), tf.float32)
        y = tf.constant(np.random.rand(2, 3), tf.float32)

        cost_old = MaximumMeanDiscrepancySlow(x, y, [1.]).eval()
        kernel = partial(utils.gaussian_kernel_matrix, sigmas=tf.constant([1.]))
        cost_new = losses.maximum_mean_discrepancy(x, y, kernel).eval()

        self.assertAlmostEqual(cost_old, cost_new, delta=1e-5)
Example #24
Source File: losses_test.py From DOTA_models with Apache License 2.0 | 5 votes |
def test_mmd_is_zero_when_distributions_are_same(self):
    with self.test_session():
        x = tf.random_uniform((1000, 10), seed=1)
        y = tf.random_uniform((1000, 10), seed=3)

        kernel = partial(utils.gaussian_kernel_matrix, sigmas=tf.constant([100.]))
        loss = losses.maximum_mean_discrepancy(x, y, kernel=kernel).eval()

        self.assertAlmostEqual(0, loss, delta=1e-4)
Example #25
Source File: wrappers.py From soccer-matlab with BSD 2-Clause "Simplified" License | 5 votes |
def reset(self, blocking=True):
    """Reset the environment.

    Args:
      blocking: Whether to wait for the result.

    Returns:
      New observation when blocking, otherwise callable that returns the new
      observation.
    """
    self._conn.send((self._RESET, None))
    if blocking:
        return self._receive(self._OBSERV)
    else:
        return functools.partial(self._receive, self._OBSERV)
Example #26
Source File: wrappers_test.py From soccer-matlab with BSD 2-Clause "Simplified" License | 5 votes |
def test_close_no_hang_after_init(self):
    constructor = functools.partial(
        tools.MockEnvironment,
        observ_shape=(2, 3), action_shape=(2,),
        min_duration=2, max_duration=2)
    env = tools.wrappers.ExternalProcess(constructor)
    env.close()
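In these tests, partial turns a class plus its keyword arguments into a zero-argument factory that the wrapper can call later (for example inside a child process). A small sketch of the same idea with made-up names:

import functools

class MockEnvironment:
    def __init__(self, observ_shape, action_shape):
        self.observ_shape = observ_shape
        self.action_shape = action_shape

def build_later(constructor):
    # The wrapper only needs something callable with no arguments.
    return constructor()

constructor = functools.partial(MockEnvironment,
                                observ_shape=(2, 3), action_shape=(2,))
env = build_later(constructor)
print(env.observ_shape)   # (2, 3)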
Example #27
Source File: wrappers_test.py From soccer-matlab with BSD 2-Clause "Simplified" License | 5 votes |
def test_reraise_exception_in_step(self):
    constructor = functools.partial(
        MockEnvironmentCrashInStep, crash_at_step=3)
    env = tools.wrappers.ExternalProcess(constructor)
    env.reset()
    env.step(env.action_space.sample())
    env.step(env.action_space.sample())
    with self.assertRaises(Exception):
        env.step(env.action_space.sample())
Example #28
Source File: train_ppo_test.py From soccer-matlab with BSD 2-Clause "Simplified" License | 5 votes |
def test_no_crash_observation_shape(self):
    nets = networks.ForwardGaussianPolicy, networks.RecurrentGaussianPolicy
    observ_shapes = (1,), (2, 3), (2, 3, 4)
    for network, observ_shape in itertools.product(nets, observ_shapes):
        config = self._define_config()
        with config.unlocked:
            config.env = functools.partial(
                tools.MockEnvironment, observ_shape,
                action_shape=(3,), min_duration=15, max_duration=15)
            config.max_length = 20
            config.steps = 100
            config.network = network
        for score in train.train(config, env_processes=False):
            float(score)
Example #29
Source File: configs.py From soccer-matlab with BSD 2-Clause "Simplified" License | 5 votes |
def pybullet_racecar():
    """Configuration for Bullet MIT Racecar task."""
    locals().update(default())
    # Environment
    env = 'RacecarBulletEnv-v0'
    # functools.partial(racecarGymEnv.RacecarGymEnv, isDiscrete=False, renders=True)
    max_length = 10
    steps = 1e7  # 10M
    return locals()
Example #30
Source File: train_ppo_test.py From soccer-matlab with BSD 2-Clause "Simplified" License | 5 votes |
def test_no_crash_variable_duration(self):
    config = self._define_config()
    with config.unlocked:
        config.env = functools.partial(
            tools.MockEnvironment, observ_shape=(2, 3),
            action_shape=(3,), min_duration=5, max_duration=25)
        config.max_length = 25
        config.steps = 200
        config.network = networks.RecurrentGaussianPolicy
    for score in train.train(config, env_processes=False):
        float(score)