Python mlflow.start_run() Examples

The following are 30 code examples of mlflow.start_run(), collected from open-source projects. The source file, project, and license are noted above each example. You may also want to check out the other available functions and classes of the mlflow module.
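Before the examples, here is a minimal, self-contained sketch of the pattern most of them share; the parameter and metric names are arbitrary placeholders.

import mlflow

# start_run() returns an ActiveRun that works as a context manager:
# the run is marked FINISHED on normal exit and FAILED if the block raises.
with mlflow.start_run() as run:
    mlflow.log_param("alpha", 0.01)   # params are immutable key/value pairs
    mlflow.log_metric("rmse", 0.87)   # metrics are numeric and can be logged repeatedly
    print("run_id:", run.info.run_id)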
Example #1
Source File: test_tensorflow2_autolog.py    From mlflow with Apache License 2.0
def test_tf_keras_autolog_persists_manually_created_run(random_train_data, random_one_hot_labels,
                                                        fit_variant):
    mlflow.tensorflow.autolog()
    with mlflow.start_run() as run:
        data = random_train_data
        labels = random_one_hot_labels

        model = create_tf_keras_model()

        if fit_variant == 'fit_generator':
            def generator():
                while True:
                    yield data, labels
            model.fit_generator(generator(), epochs=10, steps_per_epoch=1)
        else:
            model.fit(data, labels, epochs=10)

        assert mlflow.active_run()
        assert mlflow.active_run().info.run_id == run.info.run_id 
Example #2
Source File: test_tracking.py    From mlflow with Apache License 2.0
def test_start_run_context_manager():
    with start_run() as first_run:
        first_uuid = first_run.info.run_id
        # Check that start_run() causes the run information to be persisted in the store
        persisted_run = tracking.MlflowClient().get_run(first_uuid)
        assert persisted_run is not None
        assert persisted_run.info == first_run.info
    finished_run = tracking.MlflowClient().get_run(first_uuid)
    assert finished_run.info.status == RunStatus.to_string(RunStatus.FINISHED)
    # Launch a separate run that fails, verify the run status is FAILED and the run UUID is
    # different
    with pytest.raises(Exception):
        with start_run() as second_run:
            second_run_id = second_run.info.run_id
            raise Exception("Failing run!")
    assert second_run_id != first_uuid
    finished_run2 = tracking.MlflowClient().get_run(second_run_id)
    assert finished_run2.info.status == RunStatus.to_string(RunStatus.FAILED) 
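Example #2 pins down the context-manager semantics of start_run(): FINISHED on clean exit, FAILED when the block raises. A rough manual equivalent (a sketch, not taken from the test suite) drives the same transitions with end_run() and an explicit status:

import mlflow

run = mlflow.start_run()
try:
    mlflow.log_metric("loss", 0.5)
    mlflow.end_run()                   # status becomes FINISHED
except Exception:
    mlflow.end_run(status="FAILED")    # mirror the context manager's failure path
    raise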
Example #3
Source File: test_tracking.py    From mlflow with Apache License 2.0
def test_log_metrics_uses_millisecond_timestamp_resolution_client():
    with start_run() as active_run, mock.patch("time.time") as time_mock:
        time_mock.side_effect = lambda: 123
        mlflow_client = tracking.MlflowClient()
        run_id = active_run.info.run_id

        mlflow_client.log_metric(run_id=run_id, key="name_1", value=25)
        mlflow_client.log_metric(run_id=run_id, key="name_2", value=-3)
        mlflow_client.log_metric(run_id=run_id, key="name_1", value=30)
        mlflow_client.log_metric(run_id=run_id, key="name_1", value=40)

    metric_history_name1 = mlflow_client.get_metric_history(run_id, "name_1")
    assert set([(m.value, m.timestamp) for m in metric_history_name1]) == set([
        (25, 123 * 1000),
        (30, 123 * 1000),
        (40, 123 * 1000),
    ])
    metric_history_name2 = mlflow_client.get_metric_history(run_id, "name_2")
    assert set([(m.value, m.timestamp) for m in metric_history_name2]) == set([
        (-3, 123 * 1000),
    ]) 
Example #4
Source File: test_keras_model_export.py    From mlflow with Apache License 2.0
def test_model_log_persists_specified_conda_env_in_mlflow_model_directory(model, keras_custom_env):
    artifact_path = "model"
    with mlflow.start_run():
        mlflow.keras.log_model(
            keras_model=model, artifact_path=artifact_path, conda_env=keras_custom_env)
        model_path = _download_artifact_from_uri("runs:/{run_id}/{artifact_path}".format(
            run_id=mlflow.active_run().info.run_id, artifact_path=artifact_path))

    pyfunc_conf = _get_flavor_configuration(model_path=model_path, flavor_name=pyfunc.FLAVOR_NAME)
    saved_conda_env_path = os.path.join(model_path, pyfunc_conf[pyfunc.ENV])
    assert os.path.exists(saved_conda_env_path)
    assert saved_conda_env_path != keras_custom_env

    with open(keras_custom_env, "r") as f:
        keras_custom_env_parsed = yaml.safe_load(f)
    with open(saved_conda_env_path, "r") as f:
        saved_conda_env_parsed = yaml.safe_load(f)
    assert saved_conda_env_parsed == keras_custom_env_parsed 
Example #5
Source File: test_artifact_utils.py    From mlflow with Apache License 2.0
def test_download_artifact_from_absolute_uri_persists_data_to_specified_output_directory(tmpdir):
    artifact_file_name = "artifact.txt"
    artifact_text = "Sample artifact text"
    local_artifact_path = tmpdir.join(artifact_file_name).strpath
    with open(local_artifact_path, "w") as out:
        out.write(artifact_text)

    logged_artifact_subdir = "logged_artifact"
    with mlflow.start_run():
        mlflow.log_artifact(local_path=local_artifact_path, artifact_path=logged_artifact_subdir)
        artifact_uri = mlflow.get_artifact_uri(artifact_path=logged_artifact_subdir)

    artifact_output_path = tmpdir.join("artifact_output").strpath
    os.makedirs(artifact_output_path)
    _download_artifact_from_uri(artifact_uri=artifact_uri, output_path=artifact_output_path)
    assert logged_artifact_subdir in os.listdir(artifact_output_path)
    assert artifact_file_name in os.listdir(
        os.path.join(artifact_output_path, logged_artifact_subdir))
    with open(os.path.join(
            artifact_output_path, logged_artifact_subdir, artifact_file_name), "r") as f:
        assert f.read() == artifact_text 
Example #6
Source File: load_raw_data.py    From mlflow with Apache License 2.0
def load_raw_data(url):
    with mlflow.start_run() as mlrun:
        local_dir = tempfile.mkdtemp()
        local_filename = os.path.join(local_dir, "ml-20m.zip")
        print("Downloading %s to %s" % (url, local_filename))
        r = requests.get(url, stream=True)
        with open(local_filename, 'wb') as f:
            for chunk in r.iter_content(chunk_size=1024):
                if chunk:  # filter out keep-alive new chunks
                    f.write(chunk)

        extracted_dir = os.path.join(local_dir, 'ml-20m')
        print("Extracting %s into %s" % (local_filename, extracted_dir))
        with zipfile.ZipFile(local_filename, 'r') as zip_ref:
            zip_ref.extractall(local_dir)

        ratings_file = os.path.join(extracted_dir, 'ratings.csv')

        print("Uploading ratings: %s" % ratings_file)
        mlflow.log_artifact(ratings_file, "ratings-csv-dir") 
Example #7
Source File: etl_data.py    From mlflow with Apache License 2.0
def etl_data(ratings_csv, max_row_limit):
    with mlflow.start_run() as mlrun:
        tmpdir = tempfile.mkdtemp()
        ratings_parquet_dir = os.path.join(tmpdir, 'ratings-parquet')
        spark = pyspark.sql.SparkSession.builder.getOrCreate()
        print("Converting ratings CSV %s to Parquet %s" % (ratings_csv, ratings_parquet_dir))
        ratings_df = spark.read \
            .option("header", "true") \
            .option("inferSchema", "true") \
            .csv(ratings_csv) \
            .drop("timestamp")  # Drop unused column
        ratings_df.show()
        if max_row_limit != -1:
            ratings_df = ratings_df.limit(max_row_limit)
        ratings_df.write.parquet(ratings_parquet_dir)
        print("Uploading Parquet ratings: %s" % ratings_parquet_dir)
        mlflow.log_artifacts(ratings_parquet_dir, "ratings-parquet-dir") 
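Examples #6 and #7 are steps of MLflow's multistep-workflow example rather than tests: each pipeline stage opens its own run with start_run() and logs its output (the downloaded ratings CSV, then the converted Parquet directory) as run artifacts, so that downstream stages can locate the data through the tracking store.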
Example #8
Source File: test_tracking.py    From mlflow with Apache License 2.0
def test_log_batch_validates_entity_names_and_values():
    bad_kwargs = {
        "metrics": [
            [Metric(key="../bad/metric/name", value=0.3, timestamp=3, step=0)],
            [Metric(key="ok-name", value="non-numerical-value", timestamp=3, step=0)],
            [Metric(key="ok-name", value=0.3, timestamp="non-numerical-timestamp", step=0)],
        ],
        "params": [[Param(key="../bad/param/name", value="my-val")]],
        "tags": [[Param(key="../bad/tag/name", value="my-val")]],
    }
    with start_run() as active_run:
        for kwarg, bad_values in bad_kwargs.items():
            for bad_kwarg_value in bad_values:
                final_kwargs = {
                    "run_id":  active_run.info.run_id, "metrics": [], "params": [], "tags": [],
                }
                final_kwargs[kwarg] = bad_kwarg_value
                with pytest.raises(MlflowException) as e:
                    tracking.MlflowClient().log_batch(**final_kwargs)
                assert e.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE) 
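Each bad value above targets a different validation rule: keys containing path traversals ("../") are rejected for metrics, params, and tags alike, as are non-numeric metric values and timestamps; in every case log_batch raises an MlflowException carrying the INVALID_PARAMETER_VALUE error code.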
Example #9
Source File: test_image_creation.py    From mlflow with Apache License 2.0
def test_cli_build_image_with_runs_uri_calls_expected_azure_routines(sklearn_model):
    artifact_path = "model"
    with mlflow.start_run():
        mlflow.sklearn.log_model(sk_model=sklearn_model, artifact_path=artifact_path)
        run_id = mlflow.active_run().info.run_id
    model_uri = "runs:/{run_id}/{artifact_path}".format(
        run_id=run_id, artifact_path=artifact_path)

    with AzureMLMocks() as aml_mocks:
        result = CliRunner(env={"LC_ALL": "en_US.UTF-8", "LANG": "en_US.UTF-8"}).invoke(
            mlflow.azureml.cli.commands,
            [
                'build-image',
                '-m', model_uri,
                '-w', 'test_workspace',
                '-i', 'image_name',
                '-n', 'model_name',
            ])
        assert result.exit_code == 0

        assert aml_mocks["register_model"].call_count == 1
        assert aml_mocks["create_image"].call_count == 1
        assert aml_mocks["load_workspace"].call_count == 1 
Example #10
Source File: test_spark_model_export.py    From mlflow with Apache License 2.0
def test_mleap_model_log(spark_model_iris):
    artifact_path = "model"
    register_model_patch = mock.patch("mlflow.register_model")
    with mlflow.start_run(), register_model_patch:
        sparkm.log_model(spark_model=spark_model_iris.model,
                         sample_input=spark_model_iris.spark_df,
                         artifact_path=artifact_path,
                         registered_model_name="Model1")
        model_uri = "runs:/{run_id}/{artifact_path}".format(
            run_id=mlflow.active_run().info.run_id,
            artifact_path=artifact_path)
        mlflow.register_model.assert_called_once_with(model_uri, "Model1")

    model_path = _download_artifact_from_uri(artifact_uri=model_uri)
    config_path = os.path.join(model_path, "MLmodel")
    mlflow_model = Model.load(config_path)
    assert sparkm.FLAVOR_NAME in mlflow_model.flavors
    assert mleap.FLAVOR_NAME in mlflow_model.flavors 
Example #11
Source File: test_tracking.py    From mlflow with Apache License 2.0
def test_parent_create_run():
    with mlflow.start_run() as parent_run:
        parent_run_id = parent_run.info.run_id
    os.environ[_RUN_ID_ENV_VAR] = parent_run_id
    with mlflow.start_run() as parent_run:
        assert parent_run.info.run_id == parent_run_id
        with pytest.raises(Exception, match='To start a nested run'):
            mlflow.start_run()
        with mlflow.start_run(nested=True) as child_run:
            assert child_run.info.run_id != parent_run_id
            with mlflow.start_run(nested=True) as grand_child_run:
                pass

    def verify_has_parent_id_tag(child_id, expected_parent_id):
        tags = tracking.MlflowClient().get_run(child_id).data.tags
        assert tags[MLFLOW_PARENT_RUN_ID] == expected_parent_id

    verify_has_parent_id_tag(child_run.info.run_id, parent_run.info.run_id)
    verify_has_parent_id_tag(grand_child_run.info.run_id, child_run.info.run_id)
    assert mlflow.active_run() is None 
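Two behaviors are worth calling out here: calling start_run() while another run is active raises unless nested=True is passed, and the parent/child relationship is persisted as the MLFLOW_PARENT_RUN_ID tag on the child run, which is what verify_has_parent_id_tag reads back.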
Example #12
Source File: test_cli.py    From mlflow with Apache License 2.0
def test_prepare_env_passes(sk_model):
    if no_conda:
        pytest.skip("This test requires conda.")

    with TempDir(chdr=True):
        with mlflow.start_run() as active_run:
            mlflow.sklearn.log_model(sk_model, "model")
            model_uri = "runs:/{run_id}/model".format(run_id=active_run.info.run_id)

        # Test with no conda
        p = subprocess.Popen(["mlflow", "models", "prepare-env", "-m", model_uri,
                              "--no-conda"], stderr=subprocess.PIPE)
        assert p.wait() == 0

        # With conda
        p = subprocess.Popen(["mlflow", "models", "prepare-env", "-m", model_uri],
                             stderr=subprocess.PIPE)
        assert p.wait() == 0

        # Should be idempotent
        p = subprocess.Popen(["mlflow", "models", "prepare-env", "-m", model_uri],
                             stderr=subprocess.PIPE)
        assert p.wait() == 0 
Example #13
Source File: tensorflow.py    From mlflow with Apache License 2.0
def _log_event(event):
    """
    Extracts metric information from the event protobuf.
    """
    if not mlflow.active_run():
        # Autologging can fire before any run exists; start one and remember
        # its ID so the autologging machinery can close it later.
        try_mlflow_log(mlflow.start_run)
        global _AUTOLOG_RUN_ID
        _AUTOLOG_RUN_ID = mlflow.active_run().info.run_id
    if event.WhichOneof('what') == 'summary':
        summary = event.summary
        for v in summary.value:
            if v.HasField('simple_value'):
                # Only log every _LOG_EVERY_N_STEPS-th step to avoid flooding the store.
                if (event.step - 1) % _LOG_EVERY_N_STEPS == 0:
                    _thread_pool.submit(_add_to_queue, key=v.tag,
                                        value=v.simple_value, step=event.step,
                                        time=int(time.time() * 1000),
                                        run_id=mlflow.active_run().info.run_id)
Example #14
Source File: test_run.py    From nyaggle with MIT License
def test_ignore_errors_in_mlflow_params(tmpdir_name):
    mlflow.start_run()
    mlflow.log_param('features', 'ABC')
    mlflow.log_metric('Overall', -99)

    params = {
        'objective': 'binary',
        'max_depth': 8
    }
    X, y = make_classification_df()

    result = run_experiment(params, X, y, with_mlflow=True, logging_directory=tmpdir_name, feature_list=[])

    client = mlflow.tracking.MlflowClient()
    data = client.get_run(mlflow.active_run().info.run_id).data

    assert data.metrics['Overall'] == result.metrics[-1]
    assert data.params['features'] == 'ABC'  # params cannot be overwritten

    mlflow.end_run() 
Example #15
Source File: test_run.py    From nyaggle with MIT License
def test_inherit_outer_scope_run(tmpdir_name):
    mlflow.start_run()
    mlflow.log_param('foo', 1)

    params = {
        'objective': 'binary',
        'max_depth': 8
    }
    X, y = make_classification_df()

    run_experiment(params, X, y, with_mlflow=True, logging_directory=tmpdir_name)

    assert mlflow.active_run() is not None  # still valid

    client = mlflow.tracking.MlflowClient()
    data = client.get_run(mlflow.active_run().info.run_id).data

    assert data.metrics['Overall'] > 0  # recorded

    mlflow.end_run() 
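In both nyaggle tests, start_run() is called without a context manager, so the run remains active after run_experiment returns; run_experiment recognizes the externally created run and logs into it rather than starting (and finishing) its own, which is why the caller stays responsible for end_run().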
Example #16
Source File: experiment.py    From nyaggle with MIT License
def start(self):
    """
    Start a new experiment.
    """
    if self.with_mlflow:
        import mlflow

        if mlflow.active_run() is not None:
            active_run = mlflow.active_run()
            self.inherit_existing_run = True
        else:
            active_run = mlflow.start_run(run_name=self.mlflow_run_name, run_id=self.mlflow_run_id)
        mlflow_metadata = {
            'artifact_uri': active_run.info.artifact_uri,
            'experiment_id': active_run.info.experiment_id,
            'run_id': active_run.info.run_id
        }
        self.mlflow_run_id = active_run.info.run_id
        with open(os.path.join(self.logging_directory, 'mlflow.json'), 'w') as f:
            json.dump(mlflow_metadata, f, indent=4)
Example #17
Source File: test_spark_model_export.py    From mlflow with Apache License 2.0
def test_sparkml_model_log_without_specified_conda_env_uses_default_env_with_expected_dependencies(
        spark_model_iris):
    artifact_path = "model"
    with mlflow.start_run():
        sparkm.log_model(
            spark_model=spark_model_iris.model, artifact_path=artifact_path, conda_env=None)
        model_uri = "runs:/{run_id}/{artifact_path}".format(
            run_id=mlflow.active_run().info.run_id,
            artifact_path=artifact_path)

    model_path = _download_artifact_from_uri(artifact_uri=model_uri)
    pyfunc_conf = _get_flavor_configuration(model_path=model_path, flavor_name=pyfunc.FLAVOR_NAME)
    conda_env_path = os.path.join(model_path, pyfunc_conf[pyfunc.ENV])
    with open(conda_env_path, "r") as f:
        conda_env = yaml.safe_load(f)

    assert conda_env == sparkm.get_default_conda_env() 
Example #18
Source File: test_model.py    From mlflow with Apache License 2.0
def test_model_log():
    with TempDir(chdr=True) as tmp:
        experiment_id = mlflow.create_experiment("test")
        sig = ModelSignature(inputs=Schema([ColSpec("integer", "x"), ColSpec("integer", "y")]),
                             outputs=Schema([ColSpec(name=None, type="double")]))
        input_example = {"x": 1, "y": 2}
        with mlflow.start_run(experiment_id=experiment_id) as r:
            Model.log("some/path", TestFlavor,
                      signature=sig,
                      input_example=input_example)

        local_path = _download_artifact_from_uri("runs:/{}/some/path".format(r.info.run_id),
                                                 output_path=tmp.path(""))
        loaded_model = Model.load(os.path.join(local_path, "MLmodel"))
        assert loaded_model.run_id == r.info.run_id
        assert loaded_model.artifact_path == "some/path"
        assert loaded_model.flavors == {
            "flavor1": {"a": 1, "b": 2},
            "flavor2": {"x": 1, "y": 2},
        }
        assert loaded_model.signature == sig
        path = os.path.join(local_path, loaded_model.saved_input_example_info["artifact_path"])
        x = _dataframe_from_json(path)
        assert x.to_dict(orient="records")[0] == input_example 
Example #19
Source File: test_gluon_autolog.py    From mlflow with Apache License 2.0
def test_autolog_persists_manually_created_run():
    mlflow.gluon.autolog()

    data = DataLoader(LogsDataset(), batch_size=128, last_batch="discard")

    with mlflow.start_run() as run:

        model = HybridSequential()
        model.add(Dense(64, activation="relu"))
        model.add(Dense(64, activation="relu"))
        model.add(Dense(10))
        model.initialize()
        model.hybridize()
        trainer = Trainer(model.collect_params(), "adam",
                          optimizer_params={"learning_rate": .001, "epsilon": 1e-07})
        est = estimator.Estimator(net=model, loss=SoftmaxCrossEntropyLoss(),
                                  metrics=Accuracy(), trainer=trainer)

        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            est.fit(data, epochs=3)

        assert mlflow.active_run().info.run_id == run.info.run_id 
Example #20
Source File: test_cli.py    From mlflow with Apache License 2.0
def test_prepare_env_fails(sk_model):
    if no_conda:
        pytest.skip("This test requires conda.")

    with TempDir(chdr=True):
        with mlflow.start_run() as active_run:
            mlflow.sklearn.log_model(sk_model, "model",
                                     conda_env={"dependencies": ["mlflow-does-not-exist-dep==abc"]})
            model_uri = "runs:/{run_id}/model".format(run_id=active_run.info.run_id)

        # Test with no conda
        p = subprocess.Popen(["mlflow", "models", "prepare-env", "-m", model_uri,
                              "--no-conda"])
        assert p.wait() == 0

        # With conda - should fail due to bad conda environment.
        p = subprocess.Popen(["mlflow", "models", "prepare-env", "-m", model_uri])
        assert p.wait() != 0 
Example #21
Source File: test_tensorflow_autolog.py    From mlflow with Apache License 2.0
def tf_core_random_tensors():
    mlflow.tensorflow.autolog(every_n_iter=4)
    with mlflow.start_run() as run:
        sess = tf.Session()
        a = tf.constant(3.0, dtype=tf.float32)
        b = tf.constant(4.0)
        total = a + b
        tf.summary.scalar('a', a)
        tf.summary.scalar('b', b)
        merged = tf.summary.merge_all()
        dir = tempfile.mkdtemp()
        writer = tf.summary.FileWriter(dir, sess.graph)
        with sess.as_default():
            for i in range(40):
                summary, _ = sess.run([merged, total])
                writer.add_summary(summary, global_step=i)
        shutil.rmtree(dir)
        writer.close()
        sess.close()

    client = mlflow.tracking.MlflowClient()
    return client.get_run(run.info.run_id) 
Example #22
Source File: test_lightgbm_autolog.py    From mlflow with Apache License 2.0
def test_lgb_autolog_persists_manually_created_run(bst_params, train_set):
    mlflow.lightgbm.autolog()
    with mlflow.start_run() as run:
        lgb.train(bst_params, train_set, num_boost_round=1)
        assert mlflow.active_run()
        assert mlflow.active_run().info.run_id == run.info.run_id 
Example #23
Source File: test_tracking.py    From mlflow with Apache License 2.0
def test_start_run_exp_id_0():
    mlflow.set_experiment("some-experiment")
    # Create a run and verify that the current active experiment is the one we just set
    with mlflow.start_run() as active_run:
        exp_id = active_run.info.experiment_id
        assert exp_id != FileStore.DEFAULT_EXPERIMENT_ID
        assert MlflowClient().get_experiment(exp_id).name == "some-experiment"
    # Set experiment ID to 0 when creating a run, verify that the specified experiment ID is honored
    with mlflow.start_run(experiment_id=0) as active_run:
        assert active_run.info.experiment_id == FileStore.DEFAULT_EXPERIMENT_ID 
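FileStore.DEFAULT_EXPERIMENT_ID identifies the built-in "Default" experiment, so the second block verifies that an experiment_id passed explicitly to start_run() takes precedence over the experiment previously selected with set_experiment().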
Example #24
Source File: test_spark_model_export.py    From mlflow with Apache License 2.0
def test_log_model_calls_register_model(tmpdir, spark_model_iris):
    artifact_path = "model"
    dfs_tmp_dir = os.path.join(str(tmpdir), "test")
    try:
        register_model_patch = mock.patch("mlflow.register_model")
        with mlflow.start_run(), register_model_patch:
            sparkm.log_model(artifact_path=artifact_path, spark_model=spark_model_iris.model,
                             dfs_tmpdir=dfs_tmp_dir, registered_model_name="AdsModel1")
            model_uri = "runs:/{run_id}/{artifact_path}".format(
                run_id=mlflow.active_run().info.run_id, artifact_path=artifact_path)
            mlflow.register_model.assert_called_once_with(model_uri, "AdsModel1")
    finally:
        x = dfs_tmp_dir or sparkm.DFS_TMP
        shutil.rmtree(x) 
Example #25
Source File: test_tracking.py    From mlflow with Apache License 2.0
def test_set_experiment_with_deleted_experiment_name():
    name = "dead_exp"
    mlflow.set_experiment(name)
    with start_run() as run:
        exp_id = run.info.experiment_id

    tracking.MlflowClient().delete_experiment(exp_id)

    with pytest.raises(MlflowException):
        mlflow.set_experiment(name) 
Example #26
Source File: test_cli.py    From mlflow with Apache License 2.0
def test_build_docker(iris_data, sk_model):
    with mlflow.start_run() as active_run:
        mlflow.sklearn.log_model(sk_model, "model")
        model_uri = "runs:/{run_id}/model".format(run_id=active_run.info.run_id)
    x, _ = iris_data
    df = pd.DataFrame(x)
    image_name = pyfunc_build_image(model_uri, extra_args=["--install-mlflow"])
    host_port = get_safe_port()
    scoring_proc = pyfunc_serve_from_docker_image(image_name, host_port)
    _validate_with_rest_endpoint(scoring_proc, host_port, df, x, sk_model) 
Example #27
Source File: test_tracking.py    From mlflow with Apache License 2.0
def test_set_experiment():
    with pytest.raises(TypeError):
        mlflow.set_experiment()  # pylint: disable=no-value-for-parameter

    with pytest.raises(Exception):
        mlflow.set_experiment(None)

    with pytest.raises(Exception):
        mlflow.set_experiment("")

    name = "random_exp"
    exp_id = mlflow.create_experiment(name)
    mlflow.set_experiment(name)
    with start_run() as run:
        assert run.info.experiment_id == exp_id

    another_name = "another_experiment"
    mlflow.set_experiment(another_name)
    exp_id2 = mlflow.tracking.MlflowClient().get_experiment_by_name(another_name)
    with start_run() as another_run:
        assert another_run.info.experiment_id == exp_id2.experiment_id 
Example #28
Source File: test_spark_model_export.py    From mlflow with Apache License 2.0
def test_log_model_no_registered_model_name(tmpdir, spark_model_iris):
    artifact_path = "model"
    dfs_tmp_dir = os.path.join(str(tmpdir), "test")
    try:
        register_model_patch = mock.patch("mlflow.register_model")
        with mlflow.start_run(), register_model_patch:
            sparkm.log_model(artifact_path=artifact_path, spark_model=spark_model_iris.model,
                             dfs_tmpdir=dfs_tmp_dir)
            mlflow.register_model.assert_not_called()
    finally:
        x = dfs_tmp_dir or sparkm.DFS_TMP
        shutil.rmtree(x) 
Example #29
Source File: test_xgboost_autolog.py    From mlflow with Apache License 2.0
def test_xgb_autolog_persists_manually_created_run(bst_params, dtrain):
    mlflow.xgboost.autolog()
    with mlflow.start_run() as run:
        xgb.train(bst_params, dtrain)
        assert mlflow.active_run()
        assert mlflow.active_run().info.run_id == run.info.run_id 
Example #30
Source File: test_spark_model_export.py    From mlflow with Apache License 2.0
def test_default_conda_env_strips_dev_suffix_from_pyspark_version(spark_model_iris, model_path):
    mock_version_standard = mock.PropertyMock(return_value="2.4.0")
    with mock.patch("pyspark.__version__", new_callable=mock_version_standard):
        default_conda_env_standard = sparkm.get_default_conda_env()

    for dev_version in ["2.4.0.dev0", "2.4.0.dev", "2.4.0.dev1", "2.4.0dev.a", "2.4.0.devb"]:
        mock_version_dev = mock.PropertyMock(return_value=dev_version)
        with mock.patch("pyspark.__version__", new_callable=mock_version_dev):
            default_conda_env_dev = sparkm.get_default_conda_env()
            assert (default_conda_env_dev == default_conda_env_standard)

            with mlflow.start_run():
                sparkm.log_model(
                    spark_model=spark_model_iris.model, artifact_path="model", conda_env=None)
                model_uri = "runs:/{run_id}/{artifact_path}".format(
                    run_id=mlflow.active_run().info.run_id,
                    artifact_path="model")

            model_path = _download_artifact_from_uri(artifact_uri=model_uri)
            pyfunc_conf = _get_flavor_configuration(
                model_path=model_path, flavor_name=pyfunc.FLAVOR_NAME)
            conda_env_path = os.path.join(model_path, pyfunc_conf[pyfunc.ENV])
            with open(conda_env_path, "r") as f:
                persisted_conda_env_dev = yaml.safe_load(f)
            assert (persisted_conda_env_dev == default_conda_env_standard)

    for unaffected_version in ["2.0", "2.3.4", "2"]:
        mock_version = mock.PropertyMock(return_value=unaffected_version)
        with mock.patch("pyspark.__version__", new_callable=mock_version):
            assert unaffected_version in yaml.safe_dump(sparkm.get_default_conda_env())
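The intent, presumably, is that get_default_conda_env() pins the pyspark version it records, and a ".dev" suffix would make that pin unresolvable against released packages, so it is stripped; released versions such as 2.3.4 pass through unchanged, as the final loop checks.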