Python sacred.Experiment() Examples

The following are 24 code examples of sacred.Experiment(). You can go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module sacred, or try the search function.
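For orientation, a minimal Experiment looks like the sketch below. This is a generic illustration in the spirit of sacred's own quickstart, not code from any of the projects listed here:

from sacred import Experiment

ex = Experiment("hello_config")

@ex.config
def my_config():
    recipient = "world"  # config values are collected from local variables

@ex.automain
def my_main(recipient):
    # config entries are injected by name into captured functions
    print("Hello {}!".format(recipient))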
Example #1
Source File: test_commands.py    From sacred with MIT License
def test_format_named_configs():
    ingred = Ingredient("ingred")
    ex = Experiment(name="experiment", ingredients=[ingred])

    @ingred.named_config
    def named_config1():
        pass

    @ex.named_config
    def named_config2():
        """named config with doc"""
        pass

    dict_config = dict(v=42)
    ingred.add_named_config("dict_config", dict_config)

    named_configs_text = _format_named_configs(OrderedDict(ex.gather_named_configs()))
    assert named_configs_text.startswith(
        "Named Configurations (" + COLOR_DOC + "doc" + ENDC + "):"
    )
    assert "named_config2" in named_configs_text
    assert "# named config with doc" in named_configs_text
    assert "ingred.named_config1" in named_configs_text
    assert "ingred.dict_config" in named_configs_text 
Example #2
Source File: test_exceptions.py    From sacred with MIT License
def test_format_filtered_stacktrace_true():
    ex = Experiment("exp")

    @ex.capture
    def f():
        raise Exception()

    try:
        f()
    except Exception:
        st = format_filtered_stacktrace(filter_traceback="default")
        assert "captured_function" not in st
        assert "WITHOUT Sacred internals" in st

    try:
        f()
    except Exception:
        st = format_filtered_stacktrace(filter_traceback="always")
        assert "captured_function" not in st
        assert "WITHOUT Sacred internals" in st 
Example #3
Source File: test_queue_observer.py    From sacred with MIT License
def test_run_waits_for_running_queue_observer():

    queue_observer_with_long_interval = QueueObserver(
        mock.MagicMock(), interval=1, retry_interval=0.01
    )

    ex = Experiment("ator3000")
    ex.observers.append(queue_observer_with_long_interval)

    @ex.main
    def main():
        print("do nothing")

    ex.run()
    assert (
        queue_observer_with_long_interval._covered_observer.method_calls[-1][0]
        == "completed_event"
    ) 
Example #4
Source File: test_queue_observer.py    From sacred with MIT License
def test_run_waits_for_running_queue_observer_after_failure():

    queue_observer_with_long_interval = QueueObserver(
        mock.MagicMock(), interval=1, retry_interval=0.01
    )

    ex = Experiment("ator3000")
    ex.observers.append(queue_observer_with_long_interval)

    @ex.main
    def main():
        raise Exception("fatal error")

    try:
        ex.run()
    except Exception:
        pass

    assert (
        queue_observer_with_long_interval._covered_observer.method_calls[-1][0]
        == "failed_event"
    ) 
Example #5
Source File: experiment_demosaicking.py    From learn_prox_ops with GNU General Public License v3.0
def grid_search_all_images(dataset, elemental):
    """
    Command-line command which starts a grid search for all images of the dataset.

    :param dataset: Dataset name
    :type dataset: String
    :param elemental: General experiment configuration parameters
    :type elemental: Dict
    """
    # pylint:disable=no-value-for-parameter
    # pylint:disable=unused-variable
    grid_params = init_grid_params()
    start_grid_search(ex.path, experiment_all_images_wrapper, [dataset], grid_params)


##
## Experiment
## 
Example #6
Source File: run.py    From pb_chime5 with MIT License
def run(_run, chime6, test_run=False):
    if dlp_mpi.IS_MASTER:
        print_config(_run)
        _dir = get_dir()
        print('Experiment dir:', _dir)
    else:
        _dir = None

    _dir = dlp_mpi.bcast(_dir, dlp_mpi.MASTER)

    if chime6:
        enhancer = get_enhancer_chime6()
    else:
        enhancer = get_enhancer()

    if test_run:
        print('Database', enhancer.db)

    session_ids = get_session_ids()
    if dlp_mpi.IS_MASTER:
        print('Enhancer:', enhancer)
        print(session_ids)

    enhancer.enhance_session(
        session_ids,
        _dir / 'audio',
        dataset_slice=test_run,
        audio_dir_exist_ok=True
    )
    if dlp_mpi.IS_MASTER:
        print('Finished experiment dir:', _dir) 
Example #7
Source File: modular.py    From sacred with MIT License
def foo(basepath, filename, paths, settings):
    print(paths)
    print(settings)
    return basepath + filename


# ============== Experiment ============================== 
Example #8
Source File: sacred_trainer.py    From sanet_relocal_demo with GNU General Public License v3.0
def __call__(self, ex: sacred.Experiment, mode, k, v):
    if mode == 'train':
        # exponential moving average of the training metric
        self.train_emas[k] = (
            self.ema_beta * v +
            (1.0 - self.ema_beta) * self.train_emas.get(k, v)
        )
        self.train_vals[k] = self.train_vals.get(k, []) + [v]
        ex.log_scalar(f'training.{k}', self.train_emas[k])

    elif mode == 'val':
        # log the epoch mean of both the validation and training values
        ex.log_scalar(f'val.{k}', np.mean(np.array(v)))
        ex.log_scalar(f'train.{k}', np.mean(np.array(self.train_vals[k])))
        self.train_vals[k] = []
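This callback (and the near-identical one in Example #10) is built on Experiment.log_scalar, which records a named metric value for the current run. A minimal sketch of the call; the metric name and values are illustrative:

ex.log_scalar("training.loss", 0.25)          # step assigned automatically
ex.log_scalar("training.loss", 0.21, step=2)  # explicit step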
Example #9
Source File: experiment.py    From chordrec with MIT License
def setup(name):
    ex = Experiment(name)
    ex.observers.append(PickleAndSymlinkObserver())
    data.add_sacred_config(ex)
    features.add_sacred_config(ex)
    targets.add_sacred_config(ex)
    augmenters.add_sacred_config(ex)
    return ex 
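A hedged usage sketch for this helper (the experiment name is illustrative; PickleAndSymlinkObserver and the add_sacred_config helpers are project-specific):

ex = setup("chord_recognition")

@ex.main
def main(_config):
    print(_config)  # config assembled by the add_sacred_config calls above

ex.run_commandline()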
Example #10
Source File: sacred_trainer.py    From PVN3D with MIT License
def __call__(self, ex, mode, k, v):
    # type: (_DefaultExCallback, sacred.Experiment, Any, Any, Any) -> None
    if mode == "train":
        self.train_emas[k] = self.ema_beta * v + (
            1.0 - self.ema_beta
        ) * self.train_emas.get(k, v)
        self.train_vals[k] = self.train_vals.get(k, []) + [v]
        # str.format needs the key as a keyword argument; passing a dict
        # positionally would raise a KeyError for {k}
        ex.log_scalar("training.{k}".format(k=k), self.train_emas[k])

    elif mode == "val":
        ex.log_scalar("val.{k}".format(k=k), np.mean(np.array(v)))
        ex.log_scalar(
            "train.{k}".format(k=k), np.mean(np.array(self.train_vals[k]))
        )
        self.train_vals[k] = []
Example #11
Source File: nstep_run.py    From treeqn with MIT License
def fetch_parents(current_path, parents=[]):
    tmp_ex = Experiment('treeqn')
    tmp_ex.add_config(current_path)
    with suppress_stdout():
        tmp_ex.run("print_config")
    if tmp_ex.current_run is not None and "parent_config" in tmp_ex.current_run.config:
        return fetch_parents(tmp_ex.current_run.config["parent_config"], [current_path] + parents)
    else:
        return [current_path] + parents 
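The recursion walks a chain of config files linked by a parent_config key and returns their paths root-first. A hypothetical chain to illustrate (file names are placeholders):

# child.yaml contains:
#     parent_config: base.yaml
#     lr: 0.001
# base.yaml has no parent_config key, so:
fetch_parents("child.yaml")  # -> ["base.yaml", "child.yaml"]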
Example #12
Source File: train_3d.py    From margipose with Apache License 2.0
def setup_showoff_output(self, notebook):
    """Setup Showoff reporting output."""

    from tele.showoff import views

    if self.with_val:
        maybe_val_views = [
            views.Images(['val_examples'], 'Validation example images', images_per_row=2),
            views.PlotlyLineGraph(['train_loss', 'val_loss'], 'Loss'),
            views.PlotlyLineGraph(['train_mpjpe', 'val_mpjpe'], '3D MPJPE'),
            views.PlotlyLineGraph(['train_pck', 'val_pck'], '3D PCK@150mm'),
        ]
    else:
        maybe_val_views = [
            views.PlotlyLineGraph(['train_loss'], 'Loss'),
            views.PlotlyLineGraph(['train_mpjpe'], '3D MPJPE'),
            views.PlotlyLineGraph(['train_pck'], '3D PCK@150mm'),
        ]
    self.telemetry.sink(tele.showoff.Conf(notebook), [
        views.Inspect(['config'], 'Experiment configuration', flatten=True),
        views.Inspect(['host_info'], 'Host information', flatten=True),
        views.Images(['train_examples'], 'Training example images', images_per_row=2),
        *maybe_val_views,
        views.PlotlyLineGraph(
            ['data_load_time', 'data_transfer_time', 'forward_time',
             'backward_time', 'optim_time', 'eval_time'],
            'Training time breakdown'
        )
    ])
Example #13
Source File: experiments.py    From dts with MIT License
def run_grid_search(experimentclass, db_name, ex_name, f_main, f_metrics, f_config, observer_type, log_dir=None):
    """
    Run multiple experiments exploring all the possible combinations of the given hyper-parameters.
    Each combination of parameters is an experiment and they will be stored as separate documents.
    Still, they all share the same experiment name.

    :param experimentclass: the wrapper class for the Sacred Experiment.
        Use DTSExperiment (see dts.experiment.DTSExperiment)
    :param db_name: str
        Name of the DB where all sort of information regarding the experiment should be stored.
        To be used only when observer_type is 'mongodb'.
    :param ex_name: str
        Experiment name
    :param f_main: the main function. Have a look in dts.examples to understand this better.
    :param f_config: str
        fullpath to the yaml file containing the parameters
    :param observer_type: 'mongodb' or 'file' depending what you want to use.
        If 'file' is used the results/logs are stored in the logs folder
        otherwise everything is stored in the DB.
    """
    parameters = yaml.safe_load(open(f_config))  # plain yaml.load without a Loader is deprecated in PyYAML >= 5.1
    keys = list(parameters.keys())
    values = list(parameters.values())
    for vals in product(*values):
        _run_params = dict(sorted(list(zip(keys, vals))))
        run_single_experiment(
            experimentclass=experimentclass,
            db_name=db_name,
            ex_name=ex_name,
            f_main=f_main,
            f_config=_run_params,
            f_metrics=f_metrics,
            observer_type=observer_type,
            log_dir=log_dir
        ) 
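A hedged sketch of how this could be invoked; the YAML content, names, and paths are illustrative, and f_main/f_metrics stand for functions with the signatures described in the docstring:

# params.yaml -- each key maps to a list of candidate values,
# and itertools.product expands them into one run per combination:
#     batch_size: [32, 64]
#     learning_rate: [0.001, 0.0001]

run_grid_search(
    experimentclass=DTSExperiment,
    db_name="dts_experiments",
    ex_name="ffnn_grid",
    f_main=main,
    f_metrics=log_metrics,
    f_config="params.yaml",
    observer_type="file",
    log_dir="logs/",
)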
Example #14
Source File: experiments.py    From dts with MIT License
def main_wrapper(f_main, ex, f_ex_capture, curr_db_name, _run):
    """
    Wrapper for the main function of an experiment.
    Ensures that the DB does not already contain an experiment with the same config as this one.

    :param f_main: function
        updates the main experiment function arguments, calls it and save the
        experiment results and artifacts.
        f_main should have the following signature: f_main(ex, _run, f_log_metrics)
    :param ex: the experiment (an instance of the Experiment class)
    :param f_ex_capture: function
        The function that implements the metrics logging API with sacred
        (should be used with a Lambda callback in Keras, but it currently has problems and can be ignored)
    :param curr_db_name: str
        Name of the db in use
    :param _run: the run object for the current run
        For more details about the Run object see https://sacred.readthedocs.io/en/latest/experiment.html#run-the-experiment
    """
    client = MongoClient('localhost', 27017)
    print('db = ', curr_db_name)
    db = client[curr_db_name]
    duplicate_ex = check_for_completed_experiment(db, _run.config)
    if duplicate_ex is not None:
        raise ValueError('Aborting due to a duplicate experiment')
        # return f_main(ex, _run, f_ex_capture)
    else:
        return f_main(ex, _run, f_ex_capture) 
Example #15
Source File: run_single_job.py    From gnn-benchmark with MIT License
def get_experiment(name, db_host, db_port, db_name, ingredients=None, log_verbose=True):

    if ingredients is None:
        ex = Experiment(name)
    else:
        ex = Experiment(name, ingredients=ingredients)

    ex.observers.append(
        MongoObserver.create(url=f"mongodb://{db_host}:{db_port}", db_name=db_name)
    )
    ex.logger = _get_logger(log_verbose)
    return ex 
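A hedged usage sketch (host, port, and names are placeholders):

ex = get_experiment("gcn_cora", db_host="localhost", db_port=27017,
                    db_name="gnn_benchmark")

@ex.main
def run(_config):
    pass  # model training goes here

ex.run()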
Example #16
Source File: jack-train.py    From jack with MIT License
def fetch_parents(current_path):
    tmp_ex = Experiment('jack')
    if not isinstance(current_path, list):
        current_path = [current_path]
    all_paths = list(current_path)
    for p in current_path:
        tmp_ex.add_config(p)
        if "parent_config" in tmp_ex.configurations[-1]._conf:
            all_paths = fetch_parents(tmp_ex.configurations[-1]._conf["parent_config"]) + all_paths
    return all_paths 
Example #17
Source File: test_tinydb_observer_not_installed.py    From sacred with MIT License
@pytest.fixture
def ex():
    return Experiment("ator3000")
Example #18
Source File: test_method_interception.py    From sacred with MIT License
@pytest.fixture
def ex():
    return Experiment("tensorflow_tests")
Example #19
Source File: test_exceptions.py    From sacred with MIT License
def test_format_filtered_stacktrace_false():
    ex = Experiment("exp")

    @ex.capture
    def f():
        raise Exception()

    try:
        f()
    except Exception:
        st = format_filtered_stacktrace(filter_traceback="never")
        assert "captured_function" in st 
Example #20
Source File: test_exceptions.py    From sacred with MIT License
def test_named_config_not_found_raises():
    ex = Experiment("exp")
    ex.main(lambda: None)
    with pytest.raises(
        NamedConfigNotFoundError,
        match='Named config not found: "not_there". Available config values are:',
    ):
        ex.run(named_configs=("not_there",)) 
Example #21
Source File: test_exceptions.py    From sacred with MIT License
def test_missing_config_raises():
    ex = Experiment("exp")
    ex.main(lambda a: None)
    with pytest.raises(MissingConfigError):
        ex.run() 
Example #22
Source File: test_exceptions.py    From sacred with MIT License
def test_circular_dependency_raises():
    # create experiment with circular dependency
    ing = Ingredient("ing")
    ex = Experiment("exp", ingredients=[ing])
    ex.main(lambda: None)
    ing.ingredients.append(ex)

    # run and see if it raises
    with pytest.raises(CircularDependencyError, match="exp->ing->exp"):
        ex.run() 
Example #23
Source File: test_metrics_logger.py    From sacred with MIT License
@pytest.fixture
def ex():
    return Experiment("Test experiment")
Example #24
Source File: ingredient.py    From sacred with MIT License
def stats(filename, foo=12):
    print('Statistics for dataset "{}":'.format(filename))
    print("mean = 42.23")
    print("foo=", foo)


# ================== Experiment ===============================================