Python ray.tune Examples

The following are 30 code examples of ray.tune, collected from open source projects. Each example notes its original project, source file, and license. You may also want to check out all available functions and classes of the module ray.
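Most of the examples below follow the same pattern: define a trainable (a function taking config and reporter, or a Trainable subclass), then launch it with tune.run. As a minimal, self-contained sketch of that pattern (the trainable, its metric, and its search space here are hypothetical, not taken from any of the projects below):

from ray import tune

def trainable(config, reporter):
    # hypothetical objective: derive a score from the sampled learning rate
    reporter(score=config['lr'] * 100, done=True)

analysis = tune.run(
    trainable,
    config={'lr': tune.grid_search([0.001, 0.01, 0.1])},
)
print([t.last_result for t in analysis.trials])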
Example #1
Source File: search.py    From ConvLab with MIT License
def ray_trainable(config, reporter):
    '''
    Create an instance of a trainable function for ray: https://ray.readthedocs.io/en/latest/tune-usage.html#training-api
    Lab needs a spec and a trial_index to be carried through config; pass them via config in tune.run() like so:
    config = {
        'spec': spec,
        'trial_index': tune.sample_from(lambda spec: gen_trial_index()),
        ... # normal ray config with sample, grid search etc.
    }
    '''
    from convlab.experiment.control import Trial
    # restore data carried from tune.run() config
    spec = config.pop('spec')
    trial_index = config.pop('trial_index')
    spec['meta']['trial'] = trial_index
    spec = inject_config(spec, config)
    # run SLM Lab trial
    metrics = Trial(spec).run()
    metrics.update(config) # carry config for analysis too
    # ray report to carry data in ray trial.last_result
    reporter(trial_data={trial_index: metrics}) 
Example #2
Source File: search.py    From SLM-Lab with MIT License
def run_param_specs(param_specs):
    '''Run the given param_specs in parallel trials using ray. Used for benchmarking.'''
    ray.init()
    ray_trials = tune.run(
        ray_trainable,
        name='param_specs',
        config={
            'spec': tune.grid_search(param_specs),
            'trial_index': 0,
        },
        resources_per_trial=infer_trial_resources(param_specs[0]),
        num_samples=1,
        reuse_actors=False,
        server_port=util.get_port(),
    )
    ray.shutdown() 
Example #3
Source File: search.py    From SLM-Lab with MIT License
def ray_trainable(config, reporter):
    '''
    Create an instance of a trainable function for ray: https://ray.readthedocs.io/en/latest/tune-usage.html#training-api
    Lab needs a spec and a trial_index to be carried through config; pass them via config in tune.run() like so:
    config = {
        'spec': spec,
        'trial_index': tune.sample_from(lambda spec: gen_trial_index()),
        ... # normal ray config with sample, grid search etc.
    }
    '''
    import os
    os.environ.pop('CUDA_VISIBLE_DEVICES', None)  # remove CUDA id restriction from ray
    from slm_lab.experiment.control import Trial
    # restore data carried from tune.run() config
    spec = config.pop('spec')
    spec = inject_config(spec, config)
    # tick trial_index with proper offset
    trial_index = config.pop('trial_index')
    spec['meta']['trial'] = trial_index - 1
    spec_util.tick(spec, 'trial')
    # run SLM Lab trial
    metrics = Trial(spec).run()
    metrics.update(config)  # carry config for analysis too
    # ray report to carry data in ray trial.last_result
    reporter(trial_data={trial_index: metrics}) 
Example #4
Source File: run_ray.py    From iroko with Apache License 2.0
def get_tune_experiment(config, agent, episodes, root_dir, is_schedule):
    scheduler = None
    agent_class = get_agent(agent)
    ex_conf = {}
    ex_conf["name"] = agent
    ex_conf["run"] = agent_class
    ex_conf["local_dir"] = config["env_config"]["output_dir"]
    ex_conf["stop"] = {"episodes_total": episodes}

    if is_schedule:
        ex_conf["stop"] = {"time_total_s": 300}
        ex_conf["num_samples"] = 2
        config["env_config"]["parallel_envs"] = True
        # custom changes to experiment
        log.info("Performing tune experiment")
        config, scheduler = set_tuning_parameters(agent, config)
    ex_conf["config"] = config
    experiment = Experiment(**ex_conf)
    return experiment, scheduler 
Example #5
Source File: hyperfind.py    From incremental_learning.pytorch with MIT License
def get_tune_config(tune_options, options_files):
    with open(tune_options) as f:
        options = yaml.load(f, Loader=yaml.FullLoader)

    if "epochs" in options and options["epochs"] == 1:
        raise ValueError("Using only 1 epoch, must be a mistake.")

    config = {}
    for k, v in options.items():
        if not k.startswith("var:"):
            config[k] = v
        else:
            config[k.replace("var:", "")] = tune.grid_search(v)

    if options_files is not None:
        print("Options files: {}".format(options_files))
        config["options"] = [os.path.realpath(op) for op in options_files]

    return config 
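For instance, assuming a hypothetical options file containing epochs: 10 and "var:lr": [0.1, 0.01], the loop above would produce:

config = {
    'epochs': 10,
    'lr': tune.grid_search([0.1, 0.01]),  # 'var:' prefix stripped, values wrapped
}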
Example #6
Source File: hyperfind.py    From incremental_learning.pytorch with MIT License
def parse_args():
    parser = argparse.ArgumentParser()

    parser.add_argument("-rd", "--ray-directory", default="/data/douillard/ray_results")
    parser.add_argument("-o", "--output-options")
    parser.add_argument("-t", "--tune")
    parser.add_argument("-g", "--gpus", nargs="+", default=["0"])
    parser.add_argument("-per", "--gpu-percent", type=float, default=0.5)
    parser.add_argument("-topn", "--topn", default=5, type=int)
    parser.add_argument("-earlystop", default="ucir", type=str)
    parser.add_argument("-options", "--options", default=None, nargs="+")
    parser.add_argument("-threads", default=2, type=int)
    parser.add_argument("-resume", default=False, action="store_true")
    parser.add_argument("-metric", default="avg_inc_acc", choices=["avg_inc_acc", "last_acc"])

    return parser.parse_args() 
Example #7
Source File: test_sync.py    From ray with Apache License 2.0
def testNoSync(self):
        """Sync should not run on a single node."""

        def sync_func(source, target):
            pass

        with patch.object(CommandBasedClient, "_execute") as mock_sync:
            [trial] = tune.run(
                "__fake",
                name="foo",
                max_failures=0,
                **{
                    "stop": {
                        "training_iteration": 1
                    },
                    "sync_to_driver": sync_func
                }).trials
            self.assertEqual(mock_sync.call_count, 0) 
Example #8
Source File: test_sync.py    From ray with Apache License 2.0
def testCloudFunctions(self):
        tmpdir = tempfile.mkdtemp()
        tmpdir2 = tempfile.mkdtemp()
        os.mkdir(os.path.join(tmpdir2, "foo"))

        def sync_func(local, remote):
            for filename in glob.glob(os.path.join(local, "*.json")):
                shutil.copy(filename, remote)

        [trial] = tune.run(
            "__fake",
            name="foo",
            max_failures=0,
            local_dir=tmpdir,
            stop={
                "training_iteration": 1
            },
            upload_dir=tmpdir2,
            sync_to_cloud=sync_func).trials
        test_file_path = glob.glob(os.path.join(tmpdir2, "foo", "*.json"))
        self.assertTrue(test_file_path)
        shutil.rmtree(tmpdir)
        shutil.rmtree(tmpdir2) 
Example #9
Source File: test_commands.py    From ray with Apache License 2.0
def test_ls_with_cfg(start_ray, tmpdir):
    experiment_name = "test_ls_with_cfg"
    experiment_path = os.path.join(str(tmpdir), experiment_name)
    tune.run(
        "__fake",
        name=experiment_name,
        stop={"training_iteration": 1},
        config={"test_variable": tune.grid_search(list(range(5)))},
        local_dir=str(tmpdir))

    columns = [CONFIG_PREFIX + "test_variable", "trial_id"]
    limit = 4
    with Capturing() as output:
        commands.list_trials(experiment_path, info_keys=columns, limit=limit)
    lines = output.captured
    assert all(col in lines[1] for col in columns)
    assert lines[1].count("|") == len(columns) + 1
    assert len(lines) == 3 + limit + 1 
Example #10
Source File: test_commands.py    From ray with Apache License 2.0
def test_time(start_ray, tmpdir):
    experiment_name = "test_time"
    experiment_path = os.path.join(str(tmpdir), experiment_name)
    num_samples = 2
    tune.run_experiments({
        experiment_name: {
            "run": "__fake",
            "stop": {
                "training_iteration": 1
            },
            "num_samples": num_samples,
            "local_dir": str(tmpdir)
        }
    })
    times = []
    for i in range(5):
        start = time.time()
        subprocess.check_call(["tune", "ls", experiment_path])
        times += [time.time() - start]

    assert sum(times) / len(times) < 3.0, "CLI is taking too long!" 
Example #11
Source File: test_api.py    From ray with Apache License 2.0
def testLotsOfStops(self):
        class TestTrainable(Trainable):
            def step(self):
                result = {"name": self.trial_name, "trial_id": self.trial_id}
                return result

            def cleanup(self):
                time.sleep(2)
                open(os.path.join(self.logdir, "marker"), "a").close()
                return 1

        analysis = tune.run(
            TestTrainable, num_samples=10, stop={TRAINING_ITERATION: 1})
        ray.shutdown()
        for trial in analysis.trials:
            path = os.path.join(trial.logdir, "marker")
            assert os.path.exists(path) 
Example #12
Source File: test_api.py    From ray with Apache License 2.0
def testBadStoppingFunction(self):
        def train(config, reporter):
            for i in range(10):
                reporter(test=i)

        class CustomStopper:
            def stop(self, result):
                return result["test"] > 6

        def stop(result):
            return result["test"] > 6

        with self.assertRaises(TuneError):
            tune.run(train, stop=CustomStopper().stop)
        with self.assertRaises(TuneError):
            tune.run(train, stop=stop) 
Example #13
Source File: test_api.py    From ray with Apache License 2.0
def testStopper(self):
        def train(config, reporter):
            for i in range(10):
                reporter(test=i)

        class CustomStopper(Stopper):
            def __init__(self):
                self._count = 0

            def __call__(self, trial_id, result):
                print("called")
                self._count += 1
                return result["test"] > 6

            def stop_all(self):
                return self._count > 5

        trials = tune.run(train, num_samples=5, stop=CustomStopper()).trials
        self.assertTrue(all(t.status == Trial.TERMINATED for t in trials))
        self.assertTrue(
            any(
                t.last_result.get("training_iteration") is None
                for t in trials)) 
Example #14
Source File: test_api.py    From ray with Apache License 2.0
def testTrainableCallable(self):
        def dummy_fn(config, reporter, steps):
            reporter(timesteps_total=steps, done=True)

        from functools import partial
        steps = 500
        register_trainable("test", partial(dummy_fn, steps=steps))
        [trial] = run_experiments({
            "foo": {
                "run": "test",
            }
        })
        self.assertEqual(trial.status, Trial.TERMINATED)
        self.assertEqual(trial.last_result[TIMESTEPS_TOTAL], steps)
        [trial] = tune.run(partial(dummy_fn, steps=steps)).trials
        self.assertEqual(trial.status, Trial.TERMINATED)
        self.assertEqual(trial.last_result[TIMESTEPS_TOTAL], steps) 
Example #15
Source File: test_api.py    From ray with Apache License 2.0
def testLongFilename(self):
        def train(config, reporter):
            assert os.path.join(ray.utils.get_user_temp_dir(), "logdir",
                                "foo") in os.getcwd(), os.getcwd()
            reporter(timesteps_total=1)

        register_trainable("f1", train)
        run_experiments({
            "foo": {
                "run": "f1",
                "local_dir": os.path.join(ray.utils.get_user_temp_dir(),
                                          "logdir"),
                "config": {
                    "a" * 50: tune.sample_from(lambda spec: 5.0 / 7),
                    "b" * 50: tune.sample_from(lambda spec: "long" * 40),
                },
            }
        }) 
Example #16
Source File: test_api.py    From ray with Apache License 2.0
def testNestedStoppingReturn(self):
        def train(config, reporter):
            for i in range(10):
                reporter(test={"test1": {"test2": i}})

        with self.assertRaises(TuneError):
            [trial] = tune.run(
                train, stop={
                    "test": {
                        "test1": {
                            "test2": 6
                        }
                    }
                }).trials
        [trial] = tune.run(train, stop={"test/test1/test2": 6}).trials
        self.assertEqual(trial.last_result["training_iteration"], 7) 
Example #17
Source File: search.py    From SLM-Lab with MIT License
def run_ray_search(spec):
    '''
    Run a ray search from an experiment spec. Uses random search for now.
    TODO support for other ray search algorithms: https://ray.readthedocs.io/en/latest/tune-searchalg.html
    '''
    logger.info(f'Running ray search for spec {spec["name"]}')
    # generate trial index to pass into Lab Trial
    global trial_index  # make gen_trial_index passable into tune.run
    trial_index = -1

    def gen_trial_index():
        global trial_index
        trial_index += 1
        return trial_index

    ray.init()

    ray_trials = tune.run(
        ray_trainable,
        name=spec['name'],
        config={
            'spec': spec,
            'trial_index': tune.sample_from(lambda spec: gen_trial_index()),
            **build_config_space(spec)
        },
        resources_per_trial=infer_trial_resources(spec),
        num_samples=spec['meta']['max_trial'],
        reuse_actors=False,
        server_port=util.get_port(),
    )
    trial_data_dict = {}  # data for Lab Experiment to analyze
    for ray_trial in ray_trials:
        ray_trial_data = ray_trial.last_result['trial_data']
        trial_data_dict.update(ray_trial_data)

    ray.shutdown()
    return trial_data_dict 
Example #18
Source File: search.py    From ConvLab with MIT License
def run_ray_search(spec):
    '''
    Run a ray search from an experiment spec. Uses random search for now.
    TODO support for other ray search algorithms: https://ray.readthedocs.io/en/latest/tune-searchalg.html
    '''
    logger.info(f'Running ray search for spec {spec["name"]}')
    # generate trial index to pass into Lab Trial
    global trial_index  # make gen_trial_index passable into tune.run
    trial_index = -1

    def gen_trial_index():
        global trial_index
        trial_index += 1
        return trial_index

    ray.init()

    ray_trials = tune.run(
        ray_trainable,
        name=spec['name'],
        config={
            "spec": spec,
            "trial_index": tune.sample_from(lambda spec: gen_trial_index()),
            **build_config_space(spec)
        },
        resources_per_trial=infer_trial_resources(spec),
        num_samples=spec['meta']['max_trial'],
        queue_trials=True,
    )
    trial_data_dict = {}  # data for Lab Experiment to analyze
    for ray_trial in ray_trials:
        ray_trial_data = ray_trial.last_result['trial_data']
        trial_data_dict.update(ray_trial_data)

    ray.shutdown()
    return trial_data_dict 
Example #19
Source File: search.py    From SLM-Lab with MIT License
def build_config_space(spec):
    '''
    Build ray config space from flattened spec.search
    Specify a config space in spec using `"{key}__{space_type}": {v}`.
    Where `{space_type}` is `grid_search` from `ray.tune`, or any function name from `np.random`:
    - `grid_search`: str/int/float. v = list of choices
    - `choice`: str/int/float. v = list of choices
    - `randint`: int. v = [low, high)
    - `uniform`: float. v = [low, high)
    - `normal`: float. v = [mean, stdev]

    For example:
    - `"explore_anneal_epi__randint": [10, 60],` will sample integers uniformly from 10 to 60 for `explore_anneal_epi`,
    - `"lr__uniform": [0.001, 0.1]`, and it will sample `lr` using `np.random.uniform(0.001, 0.1)`

    If any key uses `grid_search`, its values are combined exhaustively (as a Cartesian product) with the other randomly sampled variables.
    '''
    space_types = ('grid_search', 'choice', 'randint', 'uniform', 'normal')
    config_space = {}
    for k, v in util.flatten_dict(spec['search']).items():
        key, space_type = k.split('__')
        assert space_type in space_types, f'Please specify your search variable as {key}__<space_type> in one of {space_types}'
        if space_type == 'grid_search':
            config_space[key] = tune.grid_search(v)
        elif space_type == 'choice':
            config_space[key] = tune.sample_from(lambda spec, v=v: random.choice(v))
        else:
            np_fn = getattr(np.random, space_type)
            config_space[key] = tune.sample_from(lambda spec, v=v: np_fn(*v))
    return config_space 
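A hedged usage sketch, assuming a spec whose search section is already flat (so that util.flatten_dict leaves the keys unchanged; the keys and ranges are hypothetical):

spec = {'search': {
    'lr__uniform': [0.001, 0.1],
    'gamma__grid_search': [0.95, 0.99],
}}
config_space = build_config_space(spec)
# config_space['lr'] draws np.random.uniform(0.001, 0.1) per trial;
# config_space['gamma'] is grid-searched over [0.95, 0.99]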
Example #20
Source File: run_ray.py    From iroko with Apache License 2.0
def get_args(args=None):
    p = argparse.ArgumentParser()
    p.add_argument('--topo', '-t', dest='topo', type=str.lower,
                   default='dumbbell', help='The topology to operate on.')
    p.add_argument('--num_hosts', dest='num_hosts', type=int,
                   default=4, help='The number of hosts in the topology.')
    p.add_argument('--agent', '-a', dest='agent', default="PG", type=str.lower,
                   help='must be string of either: PPO, DDPG, PG,'
                   ' DCTCP, TCP_NV, PCC, or TCP')
    p.add_argument('--episodes', '-e', dest='episodes', type=int, default=5,
                   help='Total number of episodes to train the RL agent.')
    p.add_argument('--iterations', '-i', dest='timesteps',
                   type=int, default=10000,
                   help='Total number of timesteps to train the RL agent.')
    p.add_argument('--pattern', '-p', dest='pattern_index', type=int,
                   default=0, help='Traffic pattern we are testing.')
    p.add_argument('--rate', '-r', dest='rate', default=10, type=int,
                   help='Maximum bandwidth in mbit that each link supports. ')
    p.add_argument('--output', dest='root_output', default=ROOT_OUTPUT_DIR,
                   help='Folder which contains all the collected metrics.')
    p.add_argument('--env', dest='env', type=str.lower,
                   default='iroko', help='The platform to run.')
    p.add_argument('--transport', dest='transport', default="udp",
                   type=str.lower, help='The transport protocol of the hosts.')
    p.add_argument('--tune', action="store_true", default=False,
                   help='Specify whether to run the tune framework')
    p.add_argument('--schedule', action="store_true", default=False,
                   help='Specify whether to perform hyperparameter tuning')
    return p.parse_args(args) 
Example #21
Source File: run_ray.py    From iroko with Apache License 2.0
def tune_run(config, episodes, root_dir, is_schedule):
    agent = config['env_config']['agent']
    experiment, scheduler = get_tune_experiment(
        config, agent, episodes, root_dir, is_schedule)
    tune.run(experiment, config=config, scheduler=scheduler,
             verbose=2)
    log.info("Tune run over. Clearing dc_env...") 
Example #22
Source File: hyperfind.py    From incremental_learning.pytorch with MIT License
def main():
    args = parse_args()

    set_seen_gpus(args.gpus)

    if args.tune is not None:
        config = get_tune_config(args.tune, args.options)
        config["threads"] = args.threads

        try:
            # set the GNU screen window title via screen's escape sequence
            os.system("echo '\ek{}_gridsearch\e\\'".format(args.tune))
        except Exception:
            pass

        ray.init()
        tune.run(
            train_func,
            name=args.tune.rstrip("/").split("/")[-1],
            stop={"avg_inc_acc": 100},
            config=config,
            resources_per_trial={
                "cpu": 2,
                "gpu": args.gpu_percent
            },
            local_dir=args.ray_directory,
            resume=args.resume
        )

        args.ray_directory = os.path.join(args.ray_directory, args.tune.rstrip("/").split("/")[-1])

    if args.tune is not None:
        print("\n\n", args.tune, args.options, "\n\n")

    if args.ray_directory is not None:
        best_config = analyse_ray_dump(
            _get_abs_path(args.ray_directory), args.topn, metric=args.metric
        )

        if args.output_options:
            with open(args.output_options, "w+") as f:
                yaml.dump(_convert_config(best_config), f) 
Example #23
Source File: search.py    From fast-autoaugment with MIT License
def step_w_log(self):
    original = gorilla.get_original_attribute(ray.tune.trial_runner.TrialRunner, 'step')

    # log
    cnts = OrderedDict()
    for status in [Trial.RUNNING, Trial.TERMINATED, Trial.PENDING, Trial.PAUSED, Trial.ERROR]:
        cnt = len(list(filter(lambda x: x.status == status, self._trials)))
        cnts[status] = cnt
    best_top1_acc = 0.
    for trial in filter(lambda x: x.status == Trial.TERMINATED, self._trials):
        if not trial.last_result:
            continue
        best_top1_acc = max(best_top1_acc, trial.last_result['top1_valid'])
    print('iter', self._iteration, 'top1_acc=%.3f' % best_top1_acc, cnts, end='\r')
    return original(self) 
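step_w_log only takes effect once it is patched over TrialRunner.step. A sketch of how that is typically wired up with the gorilla library (the patch call site is not part of this excerpt, so the following is an assumption about how it is applied):

import gorilla
import ray.tune.trial_runner

# overwrite TrialRunner.step with step_w_log; allow_hit permits replacing an
# existing attribute, and the original stays reachable through
# gorilla.get_original_attribute, as used inside step_w_log above
patch = gorilla.Patch(
    ray.tune.trial_runner.TrialRunner, 'step', step_w_log,
    settings=gorilla.Settings(allow_hit=True))
gorilla.apply(patch)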
Example #24
Source File: search.py    From ConvLab with MIT License
def build_config_space(spec):
    '''
    Build ray config space from flattened spec.search
    Specify a config space in spec using `"{key}__{space_type}": {v}`.
    Where `{space_type}` is `grid_search` from `ray.tune`, or any function name from `np.random`:
    - `grid_search`: str/int/float. v = list of choices
    - `choice`: str/int/float. v = list of choices
    - `randint`: int. v = [low, high)
    - `uniform`: float. v = [low, high)
    - `normal`: float. v = [mean, stdev]

    For example:
    - `"explore_anneal_epi__randint": [10, 60],` will sample integers uniformly from 10 to 60 for `explore_anneal_epi`,
    - `"lr__uniform": [0.001, 0.1]`, and it will sample `lr` using `np.random.uniform(0.001, 0.1)`

    If any key uses `grid_search`, its values are combined exhaustively (as a Cartesian product) with the other randomly sampled variables.
    '''
    space_types = ('grid_search', 'choice', 'randint', 'uniform', 'normal')
    config_space = {}
    for k, v in util.flatten_dict(spec['search']).items():
        key, space_type = k.split('__')
        assert space_type in space_types, f'Please specify your search variable as {key}__<space_type> in one of {space_types}'
        if space_type == 'grid_search':
            config_space[key] = tune.grid_search(v)
        elif space_type == 'choice':
            config_space[key] = tune.sample_from(lambda spec, v=v: random.choice(v))
        else:
            np_fn = getattr(np.random, space_type)
            config_space[key] = tune.sample_from(lambda spec, v=v: np_fn(*v))
    return config_space 
Example #25
Source File: tune_ray.py    From blueoil with Apache License 2.0
def run(config_file, tunable_id, local_dir):
    register_trainable(tunable_id, TrainTunable)
    lm_config = config_util.load(config_file)

    def easydict_to_dict(config):
        if isinstance(config, EasyDict):
            config = dict(config)

        for key, value in config.items():
            if isinstance(value, EasyDict):
                value = dict(value)
                easydict_to_dict(value)
            config[key] = value
        return config

    tune_space = easydict_to_dict(lm_config['TUNE_SPACE'])
    tune_spec = easydict_to_dict(lm_config['TUNE_SPEC'])
    tune_spec['run'] = tunable_id
    tune_spec['config'] = {'lm_config': os.path.join(os.getcwd(), config_file)}
    tune_spec['local_dir'] = local_dir
    tune_spec['trial_name_creator'] = ray.tune.function(trial_str_creator)

    # Expect GPUs to be used for the parameter search
    ray.init(num_cpus=multiprocessing.cpu_count() // 2, num_gpus=max(get_num_gpu(), 1))
    algo = HyperOptSearch(tune_space, max_concurrent=4, reward_attr="mean_accuracy")
    scheduler = AsyncHyperBandScheduler(time_attr="training_iteration", reward_attr="mean_accuracy", max_t=200)
    trials = run_experiments(experiments={'exp_tune': tune_spec},
                             search_alg=algo,
                             scheduler=scheduler)
    print("The best result is", get_best_result(trials, metric="mean_accuracy", param='config')) 
Example #26
Source File: train.py    From flow with MIT License
def train_rllib(submodule, flags):
    """Train policies using the PPO algorithm in RLlib."""
    import ray
    from ray.tune import run_experiments

    flow_params = submodule.flow_params
    n_cpus = submodule.N_CPUS
    n_rollouts = submodule.N_ROLLOUTS
    policy_graphs = getattr(submodule, "POLICY_GRAPHS", None)
    policy_mapping_fn = getattr(submodule, "policy_mapping_fn", None)
    policies_to_train = getattr(submodule, "policies_to_train", None)

    alg_run, gym_name, config = setup_exps_rllib(
        flow_params, n_cpus, n_rollouts,
        policy_graphs, policy_mapping_fn, policies_to_train)

    ray.init(num_cpus=n_cpus + 1, object_store_memory=200 * 1024 * 1024)
    exp_config = {
        "run": alg_run,
        "env": gym_name,
        "config": {
            **config
        },
        "checkpoint_freq": 20,
        "checkpoint_at_end": True,
        "max_failures": 999,
        "stop": {
            "training_iteration": flags.num_steps,
        },
    }

    if flags.checkpoint_path is not None:
        exp_config['restore'] = flags.checkpoint_path
    run_experiments({flow_params["exp_tag"]: exp_config}) 
Example #27
Source File: registry.py    From ray with Apache License 2.0
def register(self, category, key, value):
        if category not in KNOWN_CATEGORIES:
            from ray.tune import TuneError
            raise TuneError("Unknown category {} not among {}".format(
                category, KNOWN_CATEGORIES))
        self._to_flush[(category, key)] = pickle.dumps(value)
        if _internal_kv_initialized():
            self.flush_values() 
Example #28
Source File: registry.py    From ray with Apache License 2.0
def register_trainable(name, trainable):
    """Register a trainable function or class.

    This enables a class or function to be accessed on every Ray process
    in the cluster.

    Args:
        name (str): Name to register.
        trainable (obj): Function or tune.Trainable class. Functions must
            take (config, status_reporter) as arguments and will be
            automatically converted into a class during registration.
    """

    from ray.tune.trainable import Trainable
    from ray.tune.function_runner import wrap_function

    if isinstance(trainable, type):
        logger.debug("Detected class for trainable.")
    elif isinstance(trainable, FunctionType):
        logger.debug("Detected function for trainable.")
        trainable = wrap_function(trainable)
    elif callable(trainable):
        logger.warning(
            "Detected unknown callable for trainable. Converting to class.")
        trainable = wrap_function(trainable)

    if not issubclass(trainable, Trainable):
        raise TypeError("Second argument must be convertable to Trainable",
                        trainable)
    _global_registry.register(TRAINABLE_CLASS, name, trainable) 
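A minimal usage sketch of the registration path described above (the trainable name and function are hypothetical):

from ray import tune
from ray.tune import register_trainable

def my_train(config, reporter):
    # function trainables are wrapped into a Trainable subclass at registration
    reporter(timesteps_total=1, done=True)

register_trainable('my_trainable', my_train)
tune.run('my_trainable')  # the registered name resolves on every Ray process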
Example #29
Source File: test_api.py    From ray with Apache License 2.0
def testEarlyStopping(self):
        def train(config, reporter):
            reporter(test=0)

        top = 3

        with self.assertRaises(ValueError):
            EarlyStopping("test", top=0)
        with self.assertRaises(ValueError):
            EarlyStopping("test", top="0")
        with self.assertRaises(ValueError):
            EarlyStopping("test", std=0)
        with self.assertRaises(ValueError):
            EarlyStopping("test", patience=-1)
        with self.assertRaises(ValueError):
            EarlyStopping("test", std="0")
        with self.assertRaises(ValueError):
            EarlyStopping("test", mode="0")

        stopper = EarlyStopping("test", top=top, mode="min")

        analysis = tune.run(train, num_samples=10, stop=stopper)
        self.assertTrue(
            all(t.status == Trial.TERMINATED for t in analysis.trials))
        self.assertTrue(len(analysis.dataframe()) <= top)

        patience = 5
        stopper = EarlyStopping("test", top=top, mode="min", patience=patience)

        analysis = tune.run(train, num_samples=20, stop=stopper)
        self.assertTrue(
            all(t.status == Trial.TERMINATED for t in analysis.trials))
        self.assertTrue(len(analysis.dataframe()) <= patience)

        stopper = EarlyStopping("test", top=top, mode="min")

        analysis = tune.run(train, num_samples=10, stop=stopper)
        self.assertTrue(
            all(t.status == Trial.TERMINATED for t in analysis.trials))
        self.assertTrue(len(analysis.dataframe()) <= top) 
Example #30
Source File: test_sync.py    From ray with Apache License 2.0
def testClusterSyncFunction(self):
        def sync_func_driver(source, target):
            assert ":" in source, "Source {} not a remote path.".format(source)
            assert ":" not in target, "Target is supposed to be local."
            with open(os.path.join(target, "test.log2"), "w") as f:
                print("writing to", f.name)
                f.write(source)

        [trial] = tune.run(
            "__fake",
            name="foo",
            max_failures=0,
            stop={
                "training_iteration": 1
            },
            sync_to_driver=sync_func_driver).trials
        test_file_path = os.path.join(trial.logdir, "test.log2")
        self.assertFalse(os.path.exists(test_file_path))

        with patch("ray.services.get_node_ip_address") as mock_sync:
            mock_sync.return_value = "0.0.0.0"
            [trial] = tune.run(
                "__fake",
                name="foo",
                max_failures=0,
                stop={
                    "training_iteration": 1
                },
                sync_to_driver=sync_func_driver).trials
        test_file_path = os.path.join(trial.logdir, "test.log2")
        self.assertTrue(os.path.exists(test_file_path))
        os.remove(test_file_path)