Python ray.tune Examples
The following are 30 code examples of the ray.tune module, collected from open-source projects. The originating project, source file, and license are noted above each example. You may also want to check out the other available functions and classes of the ray module.
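Before the project-specific examples, here is a minimal sketch of the pattern most of them share: define a trainable that reports a metric, then launch it with tune.run() over a search space. The trainable name, metric, and search space below are illustrative only, and the reporter-style function API assumed here matches the older Ray versions used in these examples rather than current releases.

    from ray import tune

    def my_trainable(config, reporter):
        # toy objective: score derived from the sampled hyperparameter
        score = 1.0 - abs(config["lr"] - 0.01)
        reporter(mean_score=score, done=True)

    analysis = tune.run(
        my_trainable,
        name="minimal_example",
        config={"lr": tune.grid_search([0.001, 0.01, 0.1])},
        num_samples=1,
    )
    # analysis.trials / analysis.dataframe() give access to the results,
    # as several of the test examples below demonstrate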
Example #1
Source File: search.py From ConvLab with MIT License | 6 votes |
def ray_trainable(config, reporter):
    '''
    Create an instance of a trainable function for ray: https://ray.readthedocs.io/en/latest/tune-usage.html#training-api
    Lab needs a spec and a trial_index to be carried through config, pass them with config in ray.run() like so:
    config = {
        'spec': spec,
        'trial_index': tune.sample_from(lambda spec: gen_trial_index()),
        ...  # normal ray config with sample, grid search etc.
    }
    '''
    from convlab.experiment.control import Trial
    # restore data carried from ray.run() config
    spec = config.pop('spec')
    trial_index = config.pop('trial_index')
    spec['meta']['trial'] = trial_index
    spec = inject_config(spec, config)
    # run SLM Lab trial
    metrics = Trial(spec).run()
    metrics.update(config)  # carry config for analysis too
    # ray report to carry data in ray trial.last_result
    reporter(trial_data={trial_index: metrics})
Example #2
Source File: search.py From SLM-Lab with MIT License | 6 votes |
def run_param_specs(param_specs):
    '''Run the given param_specs in parallel trials using ray. Used for benchmarking.'''
    ray.init()
    ray_trials = tune.run(
        ray_trainable,
        name='param_specs',
        config={
            'spec': tune.grid_search(param_specs),
            'trial_index': 0,
        },
        resources_per_trial=infer_trial_resources(param_specs[0]),
        num_samples=1,
        reuse_actors=False,
        server_port=util.get_port(),
    )
    ray.shutdown()
Example #3
Source File: search.py From SLM-Lab with MIT License | 6 votes |
def ray_trainable(config, reporter):
    '''
    Create an instance of a trainable function for ray: https://ray.readthedocs.io/en/latest/tune-usage.html#training-api
    Lab needs a spec and a trial_index to be carried through config, pass them with config in ray.run() like so:
    config = {
        'spec': spec,
        'trial_index': tune.sample_from(lambda spec: gen_trial_index()),
        ...  # normal ray config with sample, grid search etc.
    }
    '''
    import os
    os.environ.pop('CUDA_VISIBLE_DEVICES', None)  # remove CUDA id restriction from ray
    from slm_lab.experiment.control import Trial
    # restore data carried from ray.run() config
    spec = config.pop('spec')
    spec = inject_config(spec, config)
    # tick trial_index with proper offset
    trial_index = config.pop('trial_index')
    spec['meta']['trial'] = trial_index - 1
    spec_util.tick(spec, 'trial')
    # run SLM Lab trial
    metrics = Trial(spec).run()
    metrics.update(config)  # carry config for analysis too
    # ray report to carry data in ray trial.last_result
    reporter(trial_data={trial_index: metrics})
Example #4
Source File: run_ray.py From iroko with Apache License 2.0 | 6 votes |
def get_tune_experiment(config, agent, episodes, root_dir, is_schedule):
    scheduler = None
    agent_class = get_agent(agent)
    ex_conf = {}
    ex_conf["name"] = agent
    ex_conf["run"] = agent_class
    ex_conf["local_dir"] = config["env_config"]["output_dir"]
    ex_conf["stop"] = {"episodes_total": episodes}

    if is_schedule:
        ex_conf["stop"] = {"time_total_s": 300}
        ex_conf["num_samples"] = 2
        config["env_config"]["parallel_envs"] = True
        # custom changes to experiment
        log.info("Performing tune experiment")
        config, scheduler = set_tuning_parameters(agent, config)
    ex_conf["config"] = config
    experiment = Experiment(**ex_conf)
    return experiment, scheduler
Example #5
Source File: hyperfind.py From incremental_learning.pytorch with MIT License | 6 votes |
def get_tune_config(tune_options, options_files):
    with open(tune_options) as f:
        options = yaml.load(f, Loader=yaml.FullLoader)

    if "epochs" in options and options["epochs"] == 1:
        raise ValueError("Using only 1 epoch, must be a mistake.")

    config = {}
    for k, v in options.items():
        if not k.startswith("var:"):
            config[k] = v
        else:
            config[k.replace("var:", "")] = tune.grid_search(v)

    if options_files is not None:
        print("Options files: {}".format(options_files))
        config["options"] = [os.path.realpath(op) for op in options_files]

    return config
Example #6
Source File: hyperfind.py From incremental_learning.pytorch with MIT License | 6 votes |
def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("-rd", "--ray-directory", default="/data/douillard/ray_results")
    parser.add_argument("-o", "--output-options")
    parser.add_argument("-t", "--tune")
    parser.add_argument("-g", "--gpus", nargs="+", default=["0"])
    parser.add_argument("-per", "--gpu-percent", type=float, default=0.5)
    parser.add_argument("-topn", "--topn", default=5, type=int)
    parser.add_argument("-earlystop", default="ucir", type=str)
    parser.add_argument("-options", "--options", default=None, nargs="+")
    parser.add_argument("-threads", default=2, type=int)
    parser.add_argument("-resume", default=False, action="store_true")
    parser.add_argument("-metric", default="avg_inc_acc", choices=["avg_inc_acc", "last_acc"])
    return parser.parse_args()
Example #7
Source File: test_sync.py From ray with Apache License 2.0 | 6 votes |
def testNoSync(self):
    """Sync should not run on a single node."""

    def sync_func(source, target):
        pass

    with patch.object(CommandBasedClient, "_execute") as mock_sync:
        [trial] = tune.run(
            "__fake",
            name="foo",
            max_failures=0,
            **{
                "stop": {
                    "training_iteration": 1
                },
                "sync_to_driver": sync_func
            }).trials
        self.assertEqual(mock_sync.call_count, 0)
Example #8
Source File: test_sync.py From ray with Apache License 2.0 | 6 votes |
def testCloudFunctions(self):
    tmpdir = tempfile.mkdtemp()
    tmpdir2 = tempfile.mkdtemp()
    os.mkdir(os.path.join(tmpdir2, "foo"))

    def sync_func(local, remote):
        for filename in glob.glob(os.path.join(local, "*.json")):
            shutil.copy(filename, remote)

    [trial] = tune.run(
        "__fake",
        name="foo",
        max_failures=0,
        local_dir=tmpdir,
        stop={
            "training_iteration": 1
        },
        upload_dir=tmpdir2,
        sync_to_cloud=sync_func).trials
    test_file_path = glob.glob(os.path.join(tmpdir2, "foo", "*.json"))
    self.assertTrue(test_file_path)
    shutil.rmtree(tmpdir)
    shutil.rmtree(tmpdir2)
Example #9
Source File: test_commands.py From ray with Apache License 2.0 | 6 votes |
def test_ls_with_cfg(start_ray, tmpdir):
    experiment_name = "test_ls_with_cfg"
    experiment_path = os.path.join(str(tmpdir), experiment_name)
    tune.run(
        "__fake",
        name=experiment_name,
        stop={"training_iteration": 1},
        config={"test_variable": tune.grid_search(list(range(5)))},
        local_dir=str(tmpdir))

    columns = [CONFIG_PREFIX + "test_variable", "trial_id"]
    limit = 4
    with Capturing() as output:
        commands.list_trials(experiment_path, info_keys=columns, limit=limit)
    lines = output.captured
    assert all(col in lines[1] for col in columns)
    assert lines[1].count("|") == len(columns) + 1
    assert len(lines) == 3 + limit + 1
Example #10
Source File: test_commands.py From ray with Apache License 2.0 | 6 votes |
def test_time(start_ray, tmpdir):
    experiment_name = "test_time"
    experiment_path = os.path.join(str(tmpdir), experiment_name)
    num_samples = 2
    tune.run_experiments({
        experiment_name: {
            "run": "__fake",
            "stop": {
                "training_iteration": 1
            },
            "num_samples": num_samples,
            "local_dir": str(tmpdir)
        }
    })

    times = []
    for i in range(5):
        start = time.time()
        subprocess.check_call(["tune", "ls", experiment_path])
        times += [time.time() - start]

    assert sum(times) / len(times) < 3.0, "CLI is taking too long!"
Example #11
Source File: test_api.py From ray with Apache License 2.0 | 6 votes |
def testLotsOfStops(self):
    class TestTrainable(Trainable):
        def step(self):
            result = {"name": self.trial_name, "trial_id": self.trial_id}
            return result

        def cleanup(self):
            time.sleep(2)
            open(os.path.join(self.logdir, "marker"), "a").close()
            return 1

    analysis = tune.run(
        TestTrainable, num_samples=10, stop={TRAINING_ITERATION: 1})
    ray.shutdown()
    for trial in analysis.trials:
        path = os.path.join(trial.logdir, "marker")
        assert os.path.exists(path)
Example #12
Source File: test_api.py From ray with Apache License 2.0 | 6 votes |
def testBadStoppingFunction(self):
    def train(config, reporter):
        for i in range(10):
            reporter(test=i)

    class CustomStopper:
        def stop(self, result):
            return result["test"] > 6

    def stop(result):
        return result["test"] > 6

    with self.assertRaises(TuneError):
        tune.run(train, stop=CustomStopper().stop)
    with self.assertRaises(TuneError):
        tune.run(train, stop=stop)
Example #13
Source File: test_api.py From ray with Apache License 2.0 | 6 votes |
def testStopper(self):
    def train(config, reporter):
        for i in range(10):
            reporter(test=i)

    class CustomStopper(Stopper):
        def __init__(self):
            self._count = 0

        def __call__(self, trial_id, result):
            print("called")
            self._count += 1
            return result["test"] > 6

        def stop_all(self):
            return self._count > 5

    trials = tune.run(train, num_samples=5, stop=CustomStopper()).trials
    self.assertTrue(all(t.status == Trial.TERMINATED for t in trials))
    self.assertTrue(
        any(
            t.last_result.get("training_iteration") is None
            for t in trials))
Example #14
Source File: test_api.py From ray with Apache License 2.0 | 6 votes |
def testTrainableCallable(self):
    def dummy_fn(config, reporter, steps):
        reporter(timesteps_total=steps, done=True)

    from functools import partial
    steps = 500
    register_trainable("test", partial(dummy_fn, steps=steps))
    [trial] = run_experiments({
        "foo": {
            "run": "test",
        }
    })
    self.assertEqual(trial.status, Trial.TERMINATED)
    self.assertEqual(trial.last_result[TIMESTEPS_TOTAL], steps)
    [trial] = tune.run(partial(dummy_fn, steps=steps)).trials
    self.assertEqual(trial.status, Trial.TERMINATED)
    self.assertEqual(trial.last_result[TIMESTEPS_TOTAL], steps)
Example #15
Source File: test_api.py From ray with Apache License 2.0 | 6 votes |
def testLongFilename(self):
    def train(config, reporter):
        assert os.path.join(ray.utils.get_user_temp_dir(), "logdir",
                            "foo") in os.getcwd(), os.getcwd()
        reporter(timesteps_total=1)

    register_trainable("f1", train)
    run_experiments({
        "foo": {
            "run": "f1",
            "local_dir": os.path.join(ray.utils.get_user_temp_dir(), "logdir"),
            "config": {
                "a" * 50: tune.sample_from(lambda spec: 5.0 / 7),
                "b" * 50: tune.sample_from(lambda spec: "long" * 40),
            },
        }
    })
Example #16
Source File: test_api.py From ray with Apache License 2.0 | 6 votes |
def testNestedStoppingReturn(self):
    def train(config, reporter):
        for i in range(10):
            reporter(test={"test1": {"test2": i}})

    with self.assertRaises(TuneError):
        [trial] = tune.run(
            train, stop={
                "test": {
                    "test1": {
                        "test2": 6
                    }
                }
            }).trials
    [trial] = tune.run(train, stop={"test/test1/test2": 6}).trials
    self.assertEqual(trial.last_result["training_iteration"], 7)
Example #17
Source File: search.py From SLM-Lab with MIT License | 5 votes |
def run_ray_search(spec):
    '''
    Method to run ray search from experiment. Uses RandomSearch now.
    TODO support for other ray search algorithms: https://ray.readthedocs.io/en/latest/tune-searchalg.html
    '''
    logger.info(f'Running ray search for spec {spec["name"]}')
    # generate trial index to pass into Lab Trial
    global trial_index  # make gen_trial_index passable into ray.run
    trial_index = -1

    def gen_trial_index():
        global trial_index
        trial_index += 1
        return trial_index

    ray.init()
    ray_trials = tune.run(
        ray_trainable,
        name=spec['name'],
        config={
            'spec': spec,
            'trial_index': tune.sample_from(lambda spec: gen_trial_index()),
            **build_config_space(spec)
        },
        resources_per_trial=infer_trial_resources(spec),
        num_samples=spec['meta']['max_trial'],
        reuse_actors=False,
        server_port=util.get_port(),
    )
    trial_data_dict = {}  # data for Lab Experiment to analyze
    for ray_trial in ray_trials:
        ray_trial_data = ray_trial.last_result['trial_data']
        trial_data_dict.update(ray_trial_data)
    ray.shutdown()
    return trial_data_dict
Example #18
Source File: search.py From ConvLab with MIT License | 5 votes |
def run_ray_search(spec):
    '''
    Method to run ray search from experiment. Uses RandomSearch now.
    TODO support for other ray search algorithms: https://ray.readthedocs.io/en/latest/tune-searchalg.html
    '''
    logger.info(f'Running ray search for spec {spec["name"]}')
    # generate trial index to pass into Lab Trial
    global trial_index  # make gen_trial_index passable into ray.run
    trial_index = -1

    def gen_trial_index():
        global trial_index
        trial_index += 1
        return trial_index

    ray.init()
    ray_trials = tune.run(
        ray_trainable,
        name=spec['name'],
        config={
            "spec": spec,
            "trial_index": tune.sample_from(lambda spec: gen_trial_index()),
            **build_config_space(spec)
        },
        resources_per_trial=infer_trial_resources(spec),
        num_samples=spec['meta']['max_trial'],
        queue_trials=True,
    )
    trial_data_dict = {}  # data for Lab Experiment to analyze
    for ray_trial in ray_trials:
        ray_trial_data = ray_trial.last_result['trial_data']
        trial_data_dict.update(ray_trial_data)
    ray.shutdown()
    return trial_data_dict
Example #19
Source File: search.py From SLM-Lab with MIT License | 5 votes |
def build_config_space(spec):
    '''
    Build ray config space from flattened spec.search
    Specify a config space in spec using `"{key}__{space_type}": {v}`.
    Where `{space_type}` is `grid_search` of `ray.tune`, or any function name of `np.random`:
    - `grid_search`: str/int/float. v = list of choices
    - `choice`: str/int/float. v = list of choices
    - `randint`: int. v = [low, high)
    - `uniform`: float. v = [low, high)
    - `normal`: float. v = [mean, stdev)
    For example:
    - `"explore_anneal_epi__randint": [10, 60]` will sample integers uniformly from 10 to 60 for `explore_anneal_epi`
    - `"lr__uniform": [0.001, 0.1]` will sample `lr` using `np.random.uniform(0.001, 0.1)`
    If any key uses `grid_search`, it will be combined exhaustively in combination with other random sampling.
    '''
    space_types = ('grid_search', 'choice', 'randint', 'uniform', 'normal')
    config_space = {}
    for k, v in util.flatten_dict(spec['search']).items():
        key, space_type = k.split('__')
        assert space_type in space_types, f'Please specify your search variable as {key}__<space_type> in one of {space_types}'
        if space_type == 'grid_search':
            config_space[key] = tune.grid_search(v)
        elif space_type == 'choice':
            config_space[key] = tune.sample_from(lambda spec, v=v: random.choice(v))
        else:
            np_fn = getattr(np.random, space_type)
            config_space[key] = tune.sample_from(lambda spec, v=v: np_fn(*v))
    return config_space
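To make the `{key}__{space_type}` convention in the docstring concrete, here is a small hypothetical sketch of a search section and the config space it would produce. The parameter names are made up, the search section is kept flat for simplicity, and the exact flattened key format in practice depends on util.flatten_dict in SLM-Lab.

    # hypothetical spec with a flat 'search' section using the key__space_type convention
    spec = {
        'search': {
            'agent.algorithm.lr__uniform': [0.001, 0.1],      # np.random.uniform(0.001, 0.1)
            'agent.algorithm.gamma__grid_search': [0.95, 0.99],  # exhaustive grid over the listed values
        }
    }
    config_space = build_config_space(spec)
    # roughly equivalent to:
    # {
    #     'agent.algorithm.lr': tune.sample_from(lambda spec: np.random.uniform(0.001, 0.1)),
    #     'agent.algorithm.gamma': tune.grid_search([0.95, 0.99]),
    # }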
Example #20
Source File: run_ray.py From iroko with Apache License 2.0 | 5 votes |
def get_args(args=None):
    p = argparse.ArgumentParser()
    p.add_argument('--topo', '-t', dest='topo', type=str.lower,
                   default='dumbbell', help='The topology to operate on.')
    p.add_argument('--num_hosts', dest='num_hosts', type=int, default='4',
                   help='The number of hosts in the topology.')
    p.add_argument('--agent', '-a', dest='agent', default="PG", type=str.lower,
                   help='must be string of either: PPO, DDPG, PG,'
                        ' DCTCP, TCP_NV, PCC, or TCP')
    p.add_argument('--episodes', '-e', dest='episodes', type=int, default=5,
                   help='Total number of episodes to train the RL agent.')
    p.add_argument('--iterations', '-i', dest='timesteps', type=int, default=10000,
                   help='Total number of episodes to train the RL agent.')
    p.add_argument('--pattern', '-p', dest='pattern_index', type=int, default=0,
                   help='Traffic pattern we are testing.')
    p.add_argument('--rate', '-r', dest='rate', default=10, type=int,
                   help='Maximum bandwidth in mbit that each link supports. ')
    p.add_argument('--output', dest='root_output', default=ROOT_OUTPUT_DIR,
                   help='Folder which contains all the collected metrics.')
    p.add_argument('--env', dest='env', type=str.lower, default='iroko',
                   help='The platform to run.')
    p.add_argument('--transport', dest='transport', default="udp", type=str.lower,
                   help='The transport protocol of the hosts.')
    p.add_argument('--tune', action="store_true", default=False,
                   help='Specify whether to run the tune framework')
    p.add_argument('--schedule', action="store_true", default=False,
                   help='Specify whether to perform hyperparameter tuning')
    return p.parse_args(args)
Example #21
Source File: run_ray.py From iroko with Apache License 2.0 | 5 votes |
def tune_run(config, episodes, root_dir, is_schedule):
    agent = config['env_config']['agent']
    experiment, scheduler = get_tune_experiment(
        config, agent, episodes, root_dir, is_schedule)
    tune.run(experiment, config=config, scheduler=scheduler, verbose=2)
    log.info("Tune run over. Clearing dc_env...")
Example #22
Source File: hyperfind.py From incremental_learning.pytorch with MIT License | 5 votes |
def main():
    args = parse_args()
    set_seen_gpus(args.gpus)

    if args.tune is not None:
        config = get_tune_config(args.tune, args.options)
        config["threads"] = args.threads

        try:
            os.system("echo '\ek{}_gridsearch\e\\'".format(args.tune))
        except:
            pass

        ray.init()
        tune.run(
            train_func,
            name=args.tune.rstrip("/").split("/")[-1],
            stop={"avg_inc_acc": 100},
            config=config,
            resources_per_trial={
                "cpu": 2,
                "gpu": args.gpu_percent
            },
            local_dir=args.ray_directory,
            resume=args.resume
        )

        args.ray_directory = os.path.join(args.ray_directory, args.tune.rstrip("/").split("/")[-1])

    if args.tune is not None:
        print("\n\n", args.tune, args.options, "\n\n")

    if args.ray_directory is not None:
        best_config = analyse_ray_dump(
            _get_abs_path(args.ray_directory), args.topn, metric=args.metric)

        if args.output_options:
            with open(args.output_options, "w+") as f:
                yaml.dump(_convert_config(best_config), f)
Example #23
Source File: search.py From fast-autoaugment with MIT License | 5 votes |
def step_w_log(self):
    original = gorilla.get_original_attribute(ray.tune.trial_runner.TrialRunner, 'step')

    # log
    cnts = OrderedDict()
    for status in [Trial.RUNNING, Trial.TERMINATED, Trial.PENDING, Trial.PAUSED, Trial.ERROR]:
        cnt = len(list(filter(lambda x: x.status == status, self._trials)))
        cnts[status] = cnt

    best_top1_acc = 0.
    for trial in filter(lambda x: x.status == Trial.TERMINATED, self._trials):
        if not trial.last_result:
            continue
        best_top1_acc = max(best_top1_acc, trial.last_result['top1_valid'])
    print('iter', self._iteration, 'top1_acc=%.3f' % best_top1_acc, cnts, end='\r')
    return original(self)
Example #24
Source File: search.py From ConvLab with MIT License | 5 votes |
def build_config_space(spec):
    '''
    Build ray config space from flattened spec.search
    Specify a config space in spec using `"{key}__{space_type}": {v}`.
    Where `{space_type}` is `grid_search` of `ray.tune`, or any function name of `np.random`:
    - `grid_search`: str/int/float. v = list of choices
    - `choice`: str/int/float. v = list of choices
    - `randint`: int. v = [low, high)
    - `uniform`: float. v = [low, high)
    - `normal`: float. v = [mean, stdev)
    For example:
    - `"explore_anneal_epi__randint": [10, 60]` will sample integers uniformly from 10 to 60 for `explore_anneal_epi`
    - `"lr__uniform": [0.001, 0.1]` will sample `lr` using `np.random.uniform(0.001, 0.1)`
    If any key uses `grid_search`, it will be combined exhaustively in combination with other random sampling.
    '''
    space_types = ('grid_search', 'choice', 'randint', 'uniform', 'normal')
    config_space = {}
    for k, v in util.flatten_dict(spec['search']).items():
        key, space_type = k.split('__')
        assert space_type in space_types, f'Please specify your search variable as {key}__<space_type> in one of {space_types}'
        if space_type == 'grid_search':
            config_space[key] = tune.grid_search(v)
        elif space_type == 'choice':
            config_space[key] = tune.sample_from(lambda spec, v=v: random.choice(v))
        else:
            np_fn = getattr(np.random, space_type)
            config_space[key] = tune.sample_from(lambda spec, v=v: np_fn(*v))
    return config_space
Example #25
Source File: tune_ray.py From blueoil with Apache License 2.0 | 5 votes |
def run(config_file, tunable_id, local_dir):
    register_trainable(tunable_id, TrainTunable)
    lm_config = config_util.load(config_file)

    def easydict_to_dict(config):
        if isinstance(config, EasyDict):
            config = dict(config)
        for key, value in config.items():
            if isinstance(value, EasyDict):
                value = dict(value)
                easydict_to_dict(value)
            config[key] = value
        return config

    tune_space = easydict_to_dict(lm_config['TUNE_SPACE'])
    tune_spec = easydict_to_dict(lm_config['TUNE_SPEC'])
    tune_spec['run'] = tunable_id
    tune_spec['config'] = {'lm_config': os.path.join(os.getcwd(), config_file)}
    tune_spec['local_dir'] = local_dir
    tune_spec['trial_name_creator'] = ray.tune.function(trial_str_creator)

    # Expecting use of gpus to do parameter search
    ray.init(num_cpus=multiprocessing.cpu_count() // 2, num_gpus=max(get_num_gpu(), 1))

    algo = HyperOptSearch(tune_space, max_concurrent=4, reward_attr="mean_accuracy")
    scheduler = AsyncHyperBandScheduler(time_attr="training_iteration",
                                        reward_attr="mean_accuracy", max_t=200)
    trials = run_experiments(experiments={'exp_tune': tune_spec},
                             search_alg=algo,
                             scheduler=scheduler)
    print("The best result is", get_best_result(trials, metric="mean_accuracy", param='config'))
Example #26
Source File: train.py From flow with MIT License | 5 votes |
def train_rllib(submodule, flags):
    """Train policies using the PPO algorithm in RLlib."""
    import ray
    from ray.tune import run_experiments

    flow_params = submodule.flow_params
    n_cpus = submodule.N_CPUS
    n_rollouts = submodule.N_ROLLOUTS
    policy_graphs = getattr(submodule, "POLICY_GRAPHS", None)
    policy_mapping_fn = getattr(submodule, "policy_mapping_fn", None)
    policies_to_train = getattr(submodule, "policies_to_train", None)

    alg_run, gym_name, config = setup_exps_rllib(
        flow_params, n_cpus, n_rollouts,
        policy_graphs, policy_mapping_fn, policies_to_train)

    ray.init(num_cpus=n_cpus + 1, object_store_memory=200 * 1024 * 1024)
    exp_config = {
        "run": alg_run,
        "env": gym_name,
        "config": {
            **config
        },
        "checkpoint_freq": 20,
        "checkpoint_at_end": True,
        "max_failures": 999,
        "stop": {
            "training_iteration": flags.num_steps,
        },
    }

    if flags.checkpoint_path is not None:
        exp_config['restore'] = flags.checkpoint_path
    run_experiments({flow_params["exp_tag"]: exp_config})
Example #27
Source File: registry.py From ray with Apache License 2.0 | 5 votes |
def register(self, category, key, value):
    if category not in KNOWN_CATEGORIES:
        from ray.tune import TuneError
        raise TuneError("Unknown category {} not among {}".format(
            category, KNOWN_CATEGORIES))
    self._to_flush[(category, key)] = pickle.dumps(value)
    if _internal_kv_initialized():
        self.flush_values()
Example #28
Source File: registry.py From ray with Apache License 2.0 | 5 votes |
def register_trainable(name, trainable):
    """Register a trainable function or class.

    This enables a class or function to be accessed on every Ray process
    in the cluster.

    Args:
        name (str): Name to register.
        trainable (obj): Function or tune.Trainable class. Functions must
            take (config, status_reporter) as arguments and will be
            automatically converted into a class during registration.
    """
    from ray.tune.trainable import Trainable
    from ray.tune.function_runner import wrap_function

    if isinstance(trainable, type):
        logger.debug("Detected class for trainable.")
    elif isinstance(trainable, FunctionType):
        logger.debug("Detected function for trainable.")
        trainable = wrap_function(trainable)
    elif callable(trainable):
        logger.warning(
            "Detected unknown callable for trainable. Converting to class.")
        trainable = wrap_function(trainable)

    if not issubclass(trainable, Trainable):
        raise TypeError("Second argument must be convertable to Trainable",
                        trainable)
    _global_registry.register(TRAINABLE_CLASS, name, trainable)
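As a small illustration of the registration API documented above, a function trainable can be registered under a string name and then launched by that name; the function and metric names here are hypothetical, and the (config, reporter) signature matches the function API used elsewhere on this page.

    from ray import tune
    from ray.tune import register_trainable

    def my_train_fn(config, reporter):
        # the function is wrapped into a Trainable class at registration time
        reporter(mean_accuracy=config["x"] ** 2, done=True)

    register_trainable("my_train_fn", my_train_fn)
    tune.run("my_train_fn", config={"x": tune.grid_search([1, 2, 3])})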
Example #29
Source File: test_api.py From ray with Apache License 2.0 | 5 votes |
def testEarlyStopping(self):
    def train(config, reporter):
        reporter(test=0)

    top = 3
    with self.assertRaises(ValueError):
        EarlyStopping("test", top=0)
    with self.assertRaises(ValueError):
        EarlyStopping("test", top="0")

    with self.assertRaises(ValueError):
        EarlyStopping("test", std=0)
    with self.assertRaises(ValueError):
        EarlyStopping("test", patience=-1)
    with self.assertRaises(ValueError):
        EarlyStopping("test", std="0")

    with self.assertRaises(ValueError):
        EarlyStopping("test", mode="0")

    stopper = EarlyStopping("test", top=top, mode="min")
    analysis = tune.run(train, num_samples=10, stop=stopper)
    self.assertTrue(
        all(t.status == Trial.TERMINATED for t in analysis.trials))
    self.assertTrue(len(analysis.dataframe()) <= top)

    patience = 5
    stopper = EarlyStopping("test", top=top, mode="min", patience=patience)
    analysis = tune.run(train, num_samples=20, stop=stopper)
    self.assertTrue(
        all(t.status == Trial.TERMINATED for t in analysis.trials))
    self.assertTrue(len(analysis.dataframe()) <= patience)

    stopper = EarlyStopping("test", top=top, mode="min")
    analysis = tune.run(train, num_samples=10, stop=stopper)
    self.assertTrue(
        all(t.status == Trial.TERMINATED for t in analysis.trials))
    self.assertTrue(len(analysis.dataframe()) <= top)
Example #30
Source File: test_sync.py From ray with Apache License 2.0 | 5 votes |
def testClusterSyncFunction(self):
    def sync_func_driver(source, target):
        assert ":" in source, "Source {} not a remote path.".format(source)
        assert ":" not in target, "Target is supposed to be local."
        with open(os.path.join(target, "test.log2"), "w") as f:
            print("writing to", f.name)
            f.write(source)

    [trial] = tune.run(
        "__fake",
        name="foo",
        max_failures=0,
        stop={
            "training_iteration": 1
        },
        sync_to_driver=sync_func_driver).trials
    test_file_path = os.path.join(trial.logdir, "test.log2")
    self.assertFalse(os.path.exists(test_file_path))

    with patch("ray.services.get_node_ip_address") as mock_sync:
        mock_sync.return_value = "0.0.0.0"
        [trial] = tune.run(
            "__fake",
            name="foo",
            max_failures=0,
            stop={
                "training_iteration": 1
            },
            sync_to_driver=sync_func_driver).trials
        test_file_path = os.path.join(trial.logdir, "test.log2")
        self.assertTrue(os.path.exists(test_file_path))
    os.remove(test_file_path)