Python mlflow.start_run() Examples
The following are 30 code examples of mlflow.start_run(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module mlflow, or try the search function.
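As a quick orientation before the examples, here is a minimal sketch of the typical start_run() usage pattern; the experiment name, parameter, and metric below are illustrative placeholders rather than values taken from any of the examples on this page:

import mlflow

# Runs are usually opened as a context manager so they are closed cleanly.
mlflow.set_experiment("my-experiment")  # hypothetical experiment name

with mlflow.start_run() as run:
    mlflow.log_param("alpha", 0.5)   # hypothetical parameter
    mlflow.log_metric("rmse", 0.27)  # hypothetical metric
    print(run.info.run_id)           # the run is active inside this block

# Exiting the block ends the run: its status is FINISHED normally, or
# FAILED if an exception propagated out (see Example #2 below).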
Example #1
Source File: test_tensorflow2_autolog.py From mlflow with Apache License 2.0 | 6 votes
def test_tf_keras_autolog_persists_manually_created_run(random_train_data,
                                                        random_one_hot_labels,
                                                        fit_variant):
    mlflow.tensorflow.autolog()

    with mlflow.start_run() as run:
        data = random_train_data
        labels = random_one_hot_labels

        model = create_tf_keras_model()

        if fit_variant == 'fit_generator':
            def generator():
                while True:
                    yield data, labels
            model.fit_generator(generator(), epochs=10, steps_per_epoch=1)
        else:
            model.fit(data, labels, epochs=10)

        assert mlflow.active_run()
        assert mlflow.active_run().info.run_id == run.info.run_id
Example #2
Source File: test_tracking.py From mlflow with Apache License 2.0 | 6 votes
def test_start_run_context_manager():
    with start_run() as first_run:
        first_uuid = first_run.info.run_id
        # Check that start_run() causes the run information to be persisted in the store
        persisted_run = tracking.MlflowClient().get_run(first_uuid)
        assert persisted_run is not None
        assert persisted_run.info == first_run.info
    finished_run = tracking.MlflowClient().get_run(first_uuid)
    assert finished_run.info.status == RunStatus.to_string(RunStatus.FINISHED)
    # Launch a separate run that fails, verify the run status is FAILED and the run UUID is
    # different
    with pytest.raises(Exception):
        with start_run() as second_run:
            second_run_id = second_run.info.run_id
            raise Exception("Failing run!")
    assert second_run_id != first_uuid
    finished_run2 = tracking.MlflowClient().get_run(second_run_id)
    assert finished_run2.info.status == RunStatus.to_string(RunStatus.FAILED)
Example #3
Source File: test_tracking.py From mlflow with Apache License 2.0 | 6 votes
def test_log_metrics_uses_millisecond_timestamp_resolution_client():
    with start_run() as active_run, mock.patch("time.time") as time_mock:
        time_mock.side_effect = lambda: 123
        mlflow_client = tracking.MlflowClient()
        run_id = active_run.info.run_id

        mlflow_client.log_metric(run_id=run_id, key="name_1", value=25)
        mlflow_client.log_metric(run_id=run_id, key="name_2", value=-3)
        mlflow_client.log_metric(run_id=run_id, key="name_1", value=30)
        mlflow_client.log_metric(run_id=run_id, key="name_1", value=40)

    metric_history_name1 = mlflow_client.get_metric_history(run_id, "name_1")
    assert set([(m.value, m.timestamp) for m in metric_history_name1]) == set([
        (25, 123 * 1000),
        (30, 123 * 1000),
        (40, 123 * 1000),
    ])
    metric_history_name2 = mlflow_client.get_metric_history(run_id, "name_2")
    assert set([(m.value, m.timestamp) for m in metric_history_name2]) == set([
        (-3, 123 * 1000),
    ])
Example #4
Source File: test_keras_model_export.py From mlflow with Apache License 2.0 | 6 votes
def test_model_log_persists_specified_conda_env_in_mlflow_model_directory(
        model, keras_custom_env):
    artifact_path = "model"
    with mlflow.start_run():
        mlflow.keras.log_model(
            keras_model=model, artifact_path=artifact_path, conda_env=keras_custom_env)
        model_path = _download_artifact_from_uri("runs:/{run_id}/{artifact_path}".format(
            run_id=mlflow.active_run().info.run_id, artifact_path=artifact_path))

    pyfunc_conf = _get_flavor_configuration(model_path=model_path, flavor_name=pyfunc.FLAVOR_NAME)
    saved_conda_env_path = os.path.join(model_path, pyfunc_conf[pyfunc.ENV])
    assert os.path.exists(saved_conda_env_path)
    assert saved_conda_env_path != keras_custom_env

    with open(keras_custom_env, "r") as f:
        keras_custom_env_parsed = yaml.safe_load(f)
    with open(saved_conda_env_path, "r") as f:
        saved_conda_env_parsed = yaml.safe_load(f)
    assert saved_conda_env_parsed == keras_custom_env_parsed
Example #5
Source File: test_artifact_utils.py From mlflow with Apache License 2.0 | 6 votes
def test_download_artifact_from_absolute_uri_persists_data_to_specified_output_directory(tmpdir):
    artifact_file_name = "artifact.txt"
    artifact_text = "Sample artifact text"
    local_artifact_path = tmpdir.join(artifact_file_name).strpath
    with open(local_artifact_path, "w") as out:
        out.write(artifact_text)

    logged_artifact_subdir = "logged_artifact"
    with mlflow.start_run():
        mlflow.log_artifact(local_path=local_artifact_path, artifact_path=logged_artifact_subdir)
        artifact_uri = mlflow.get_artifact_uri(artifact_path=logged_artifact_subdir)

    artifact_output_path = tmpdir.join("artifact_output").strpath
    os.makedirs(artifact_output_path)
    _download_artifact_from_uri(artifact_uri=artifact_uri, output_path=artifact_output_path)
    assert logged_artifact_subdir in os.listdir(artifact_output_path)
    assert artifact_file_name in os.listdir(
        os.path.join(artifact_output_path, logged_artifact_subdir))
    with open(os.path.join(
            artifact_output_path, logged_artifact_subdir, artifact_file_name), "r") as f:
        assert f.read() == artifact_text
Example #6
Source File: load_raw_data.py From mlflow with Apache License 2.0 | 6 votes
def load_raw_data(url):
    with mlflow.start_run() as mlrun:
        local_dir = tempfile.mkdtemp()
        local_filename = os.path.join(local_dir, "ml-20m.zip")
        print("Downloading %s to %s" % (url, local_filename))
        r = requests.get(url, stream=True)
        with open(local_filename, 'wb') as f:
            for chunk in r.iter_content(chunk_size=1024):
                if chunk:  # filter out keep-alive new chunks
                    f.write(chunk)

        extracted_dir = os.path.join(local_dir, 'ml-20m')
        print("Extracting %s into %s" % (local_filename, extracted_dir))
        with zipfile.ZipFile(local_filename, 'r') as zip_ref:
            zip_ref.extractall(local_dir)

        ratings_file = os.path.join(extracted_dir, 'ratings.csv')

        print("Uploading ratings: %s" % ratings_file)
        mlflow.log_artifact(ratings_file, "ratings-csv-dir")
Example #7
Source File: etl_data.py From mlflow with Apache License 2.0 | 6 votes
def etl_data(ratings_csv, max_row_limit):
    with mlflow.start_run() as mlrun:
        tmpdir = tempfile.mkdtemp()
        ratings_parquet_dir = os.path.join(tmpdir, 'ratings-parquet')
        spark = pyspark.sql.SparkSession.builder.getOrCreate()
        print("Converting ratings CSV %s to Parquet %s" % (ratings_csv, ratings_parquet_dir))
        ratings_df = spark.read \
            .option("header", "true") \
            .option("inferSchema", "true") \
            .csv(ratings_csv) \
            .drop("timestamp")  # Drop unused column
        ratings_df.show()
        if max_row_limit != -1:
            ratings_df = ratings_df.limit(max_row_limit)
        ratings_df.write.parquet(ratings_parquet_dir)
        print("Uploading Parquet ratings: %s" % ratings_parquet_dir)
        mlflow.log_artifacts(ratings_parquet_dir, "ratings-parquet-dir")
Example #8
Source File: test_tracking.py From mlflow with Apache License 2.0 | 6 votes
def test_log_batch_validates_entity_names_and_values():
    bad_kwargs = {
        "metrics": [
            [Metric(key="../bad/metric/name", value=0.3, timestamp=3, step=0)],
            [Metric(key="ok-name", value="non-numerical-value", timestamp=3, step=0)],
            [Metric(key="ok-name", value=0.3, timestamp="non-numerical-timestamp", step=0)],
        ],
        "params": [[Param(key="../bad/param/name", value="my-val")]],
        "tags": [[Param(key="../bad/tag/name", value="my-val")]],
    }
    with start_run() as active_run:
        for kwarg, bad_values in bad_kwargs.items():
            for bad_kwarg_value in bad_values:
                final_kwargs = {
                    "run_id": active_run.info.run_id,
                    "metrics": [],
                    "params": [],
                    "tags": [],
                }
                final_kwargs[kwarg] = bad_kwarg_value
                with pytest.raises(MlflowException) as e:
                    tracking.MlflowClient().log_batch(**final_kwargs)
                assert e.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)
Example #9
Source File: test_image_creation.py From mlflow with Apache License 2.0 | 6 votes
def test_cli_build_image_with_runs_uri_calls_expected_azure_routines(sklearn_model):
    artifact_path = "model"
    with mlflow.start_run():
        mlflow.sklearn.log_model(sk_model=sklearn_model, artifact_path=artifact_path)
        run_id = mlflow.active_run().info.run_id
    model_uri = "runs:/{run_id}/{artifact_path}".format(
        run_id=run_id, artifact_path=artifact_path)

    with AzureMLMocks() as aml_mocks:
        result = CliRunner(env={"LC_ALL": "en_US.UTF-8", "LANG": "en_US.UTF-8"}).invoke(
            mlflow.azureml.cli.commands,
            [
                'build-image',
                '-m', model_uri,
                '-w', 'test_workspace',
                '-i', 'image_name',
                '-n', 'model_name',
            ])
        assert result.exit_code == 0

        assert aml_mocks["register_model"].call_count == 1
        assert aml_mocks["create_image"].call_count == 1
        assert aml_mocks["load_workspace"].call_count == 1
Example #10
Source File: test_spark_model_export.py From mlflow with Apache License 2.0 | 6 votes
def test_mleap_model_log(spark_model_iris):
    artifact_path = "model"
    register_model_patch = mock.patch("mlflow.register_model")
    with mlflow.start_run(), register_model_patch:
        sparkm.log_model(spark_model=spark_model_iris.model,
                         sample_input=spark_model_iris.spark_df,
                         artifact_path=artifact_path,
                         registered_model_name="Model1")
        model_uri = "runs:/{run_id}/{artifact_path}".format(
            run_id=mlflow.active_run().info.run_id,
            artifact_path=artifact_path)
        mlflow.register_model.assert_called_once_with(model_uri, "Model1")

    model_path = _download_artifact_from_uri(artifact_uri=model_uri)
    config_path = os.path.join(model_path, "MLmodel")
    mlflow_model = Model.load(config_path)
    assert sparkm.FLAVOR_NAME in mlflow_model.flavors
    assert mleap.FLAVOR_NAME in mlflow_model.flavors
Example #11
Source File: test_tracking.py From mlflow with Apache License 2.0 | 6 votes
def test_parent_create_run():
    with mlflow.start_run() as parent_run:
        parent_run_id = parent_run.info.run_id
    os.environ[_RUN_ID_ENV_VAR] = parent_run_id
    with mlflow.start_run() as parent_run:
        assert parent_run.info.run_id == parent_run_id
        with pytest.raises(Exception, match='To start a nested run'):
            mlflow.start_run()
        with mlflow.start_run(nested=True) as child_run:
            assert child_run.info.run_id != parent_run_id
            with mlflow.start_run(nested=True) as grand_child_run:
                pass

        def verify_has_parent_id_tag(child_id, expected_parent_id):
            tags = tracking.MlflowClient().get_run(child_id).data.tags
            assert tags[MLFLOW_PARENT_RUN_ID] == expected_parent_id

        verify_has_parent_id_tag(child_run.info.run_id, parent_run.info.run_id)
        verify_has_parent_id_tag(grand_child_run.info.run_id, child_run.info.run_id)

    assert mlflow.active_run() is None
Example #12
Source File: test_cli.py From mlflow with Apache License 2.0 | 6 votes
def test_prepare_env_passes(sk_model):
    if no_conda:
        pytest.skip("This test requires conda.")
    with TempDir(chdr=True):
        with mlflow.start_run() as active_run:
            mlflow.sklearn.log_model(sk_model, "model")
            model_uri = "runs:/{run_id}/model".format(run_id=active_run.info.run_id)

        # Test with no conda
        p = subprocess.Popen(["mlflow", "models", "prepare-env", "-m", model_uri,
                              "--no-conda"], stderr=subprocess.PIPE)
        assert p.wait() == 0

        # With conda
        p = subprocess.Popen(["mlflow", "models", "prepare-env", "-m", model_uri],
                             stderr=subprocess.PIPE)
        assert p.wait() == 0

        # Should be idempotent
        p = subprocess.Popen(["mlflow", "models", "prepare-env", "-m", model_uri],
                             stderr=subprocess.PIPE)
        assert p.wait() == 0
Example #13
Source File: tensorflow.py From mlflow with Apache License 2.0 | 6 votes
def _log_event(event):
    """
    Extracts metric information from the event protobuf
    """
    if not mlflow.active_run():
        try_mlflow_log(mlflow.start_run)
        global _AUTOLOG_RUN_ID
        _AUTOLOG_RUN_ID = mlflow.active_run().info.run_id
    if event.WhichOneof('what') == 'summary':
        summary = event.summary
        for v in summary.value:
            if v.HasField('simple_value'):
                if (event.step - 1) % _LOG_EVERY_N_STEPS == 0:
                    _thread_pool.submit(_add_to_queue, key=v.tag,
                                        value=v.simple_value, step=event.step,
                                        time=int(time.time() * 1000),
                                        run_id=mlflow.active_run().info.run_id)
Example #14
Source File: test_run.py From nyaggle with MIT License | 6 votes
def test_ignore_errors_in_mlflow_params(tmpdir_name):
    mlflow.start_run()
    mlflow.log_param('features', 'ABC')
    mlflow.log_metric('Overall', -99)

    params = {
        'objective': 'binary',
        'max_depth': 8
    }
    X, y = make_classification_df()

    result = run_experiment(params, X, y, with_mlflow=True,
                            logging_directory=tmpdir_name, feature_list=[])

    client = mlflow.tracking.MlflowClient()
    data = client.get_run(mlflow.active_run().info.run_id).data

    assert data.metrics['Overall'] == result.metrics[-1]
    assert data.params['features'] == 'ABC'  # params cannot be overwritten

    mlflow.end_run()
Example #15
Source File: test_run.py From nyaggle with MIT License | 6 votes
def test_inherit_outer_scope_run(tmpdir_name):
    mlflow.start_run()
    mlflow.log_param('foo', 1)

    params = {
        'objective': 'binary',
        'max_depth': 8
    }
    X, y = make_classification_df()

    run_experiment(params, X, y, with_mlflow=True, logging_directory=tmpdir_name)

    assert mlflow.active_run() is not None  # still valid

    client = mlflow.tracking.MlflowClient()
    data = client.get_run(mlflow.active_run().info.run_id).data

    assert data.metrics['Overall'] > 0  # recorded

    mlflow.end_run()
Example #16
Source File: experiment.py From nyaggle with MIT License | 6 votes
def start(self):
    """
    Start a new experiment.
    """
    if self.with_mlflow:
        import mlflow

        if mlflow.active_run() is not None:
            active_run = mlflow.active_run()
            self.inherit_existing_run = True
        else:
            active_run = mlflow.start_run(run_name=self.mlflow_run_name,
                                          run_id=self.mlflow_run_id)
        mlflow_metadata = {
            'artifact_uri': active_run.info.artifact_uri,
            'experiment_id': active_run.info.experiment_id,
            'run_id': active_run.info.run_id
        }
        self.mlflow_run_id = active_run.info.run_id
        with open(os.path.join(self.logging_directory, 'mlflow.json'), 'w') as f:
            json.dump(mlflow_metadata, f, indent=4)
Example #17
Source File: test_spark_model_export.py From mlflow with Apache License 2.0 | 6 votes
def test_sparkml_model_log_without_specified_conda_env_uses_default_env_with_expected_dependencies(
        spark_model_iris):
    artifact_path = "model"
    with mlflow.start_run():
        sparkm.log_model(
            spark_model=spark_model_iris.model, artifact_path=artifact_path, conda_env=None)
        model_uri = "runs:/{run_id}/{artifact_path}".format(
            run_id=mlflow.active_run().info.run_id, artifact_path=artifact_path)

    model_path = _download_artifact_from_uri(artifact_uri=model_uri)
    pyfunc_conf = _get_flavor_configuration(model_path=model_path, flavor_name=pyfunc.FLAVOR_NAME)
    conda_env_path = os.path.join(model_path, pyfunc_conf[pyfunc.ENV])
    with open(conda_env_path, "r") as f:
        conda_env = yaml.safe_load(f)

    assert conda_env == sparkm.get_default_conda_env()
Example #18
Source File: test_model.py From mlflow with Apache License 2.0 | 6 votes
def test_model_log():
    with TempDir(chdr=True) as tmp:
        experiment_id = mlflow.create_experiment("test")
        sig = ModelSignature(inputs=Schema([ColSpec("integer", "x"),
                                            ColSpec("integer", "y")]),
                             outputs=Schema([ColSpec(name=None, type="double")]))
        input_example = {"x": 1, "y": 2}
        with mlflow.start_run(experiment_id=experiment_id) as r:
            Model.log("some/path", TestFlavor,
                      signature=sig,
                      input_example=input_example)

        local_path = _download_artifact_from_uri("runs:/{}/some/path".format(r.info.run_id),
                                                 output_path=tmp.path(""))
        loaded_model = Model.load(os.path.join(local_path, "MLmodel"))
        assert loaded_model.run_id == r.info.run_id
        assert loaded_model.artifact_path == "some/path"
        assert loaded_model.flavors == {
            "flavor1": {"a": 1, "b": 2},
            "flavor2": {"x": 1, "y": 2},
        }
        assert loaded_model.signature == sig
        path = os.path.join(local_path, loaded_model.saved_input_example_info["artifact_path"])
        x = _dataframe_from_json(path)
        assert x.to_dict(orient="records")[0] == input_example
Example #19
Source File: test_gluon_autolog.py From mlflow with Apache License 2.0 | 6 votes
def test_autolog_persists_manually_created_run():
    mlflow.gluon.autolog()

    data = DataLoader(LogsDataset(), batch_size=128, last_batch="discard")

    with mlflow.start_run() as run:
        model = HybridSequential()
        model.add(Dense(64, activation="relu"))
        model.add(Dense(64, activation="relu"))
        model.add(Dense(10))
        model.initialize()
        model.hybridize()
        trainer = Trainer(model.collect_params(), "adam",
                          optimizer_params={"learning_rate": .001, "epsilon": 1e-07})
        est = estimator.Estimator(net=model, loss=SoftmaxCrossEntropyLoss(),
                                  metrics=Accuracy(), trainer=trainer)

        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            est.fit(data, epochs=3)

        assert mlflow.active_run().info.run_id == run.info.run_id
Example #20
Source File: test_cli.py From mlflow with Apache License 2.0 | 6 votes
def test_prepare_env_fails(sk_model):
    if no_conda:
        pytest.skip("This test requires conda.")
    with TempDir(chdr=True):
        with mlflow.start_run() as active_run:
            mlflow.sklearn.log_model(
                sk_model, "model",
                conda_env={"dependencies": ["mlflow-does-not-exist-dep==abc"]})
            model_uri = "runs:/{run_id}/model".format(run_id=active_run.info.run_id)

        # Test with no conda
        p = subprocess.Popen(["mlflow", "models", "prepare-env", "-m", model_uri, "--no-conda"])
        assert p.wait() == 0

        # With conda - should fail due to bad conda environment.
        p = subprocess.Popen(["mlflow", "models", "prepare-env", "-m", model_uri])
        assert p.wait() != 0
Example #21
Source File: test_tensorflow_autolog.py From mlflow with Apache License 2.0 | 6 votes
def tf_core_random_tensors():
    mlflow.tensorflow.autolog(every_n_iter=4)
    with mlflow.start_run() as run:
        sess = tf.Session()
        a = tf.constant(3.0, dtype=tf.float32)
        b = tf.constant(4.0)
        total = a + b
        tf.summary.scalar('a', a)
        tf.summary.scalar('b', b)
        merged = tf.summary.merge_all()
        dir = tempfile.mkdtemp()
        writer = tf.summary.FileWriter(dir, sess.graph)
        with sess.as_default():
            for i in range(40):
                summary, _ = sess.run([merged, total])
                writer.add_summary(summary, global_step=i)
        shutil.rmtree(dir)
        writer.close()
        sess.close()

    client = mlflow.tracking.MlflowClient()
    return client.get_run(run.info.run_id)
Example #22
Source File: test_lightgbm_autolog.py From mlflow with Apache License 2.0 | 5 votes
def test_lgb_autolog_persists_manually_created_run(bst_params, train_set):
    mlflow.lightgbm.autolog()
    with mlflow.start_run() as run:
        lgb.train(bst_params, train_set, num_boost_round=1)
        assert mlflow.active_run()
        assert mlflow.active_run().info.run_id == run.info.run_id
Example #23
Source File: test_tracking.py From mlflow with Apache License 2.0 | 5 votes
def test_start_run_exp_id_0():
    mlflow.set_experiment("some-experiment")
    # Create a run and verify that the current active experiment is the one we just set
    with mlflow.start_run() as active_run:
        exp_id = active_run.info.experiment_id
        assert exp_id != FileStore.DEFAULT_EXPERIMENT_ID
        assert MlflowClient().get_experiment(exp_id).name == "some-experiment"
    # Set experiment ID to 0 when creating a run, verify that the specified experiment ID is
    # honored
    with mlflow.start_run(experiment_id=0) as active_run:
        assert active_run.info.experiment_id == FileStore.DEFAULT_EXPERIMENT_ID
Example #24
Source File: test_spark_model_export.py From mlflow with Apache License 2.0 | 5 votes
def test_log_model_calls_register_model(tmpdir, spark_model_iris):
    artifact_path = "model"
    dfs_tmp_dir = os.path.join(str(tmpdir), "test")
    try:
        register_model_patch = mock.patch("mlflow.register_model")
        with mlflow.start_run(), register_model_patch:
            sparkm.log_model(artifact_path=artifact_path, spark_model=spark_model_iris.model,
                             dfs_tmpdir=dfs_tmp_dir, registered_model_name="AdsModel1")
            model_uri = "runs:/{run_id}/{artifact_path}".format(
                run_id=mlflow.active_run().info.run_id,
                artifact_path=artifact_path)
            mlflow.register_model.assert_called_once_with(model_uri, "AdsModel1")
    finally:
        x = dfs_tmp_dir or sparkm.DFS_TMP
        shutil.rmtree(x)
Example #25
Source File: test_tracking.py From mlflow with Apache License 2.0 | 5 votes
def test_set_experiment_with_deleted_experiment_name():
    name = "dead_exp"
    mlflow.set_experiment(name)
    with start_run() as run:
        exp_id = run.info.experiment_id

    tracking.MlflowClient().delete_experiment(exp_id)

    with pytest.raises(MlflowException):
        mlflow.set_experiment(name)
Example #26
Source File: test_cli.py From mlflow with Apache License 2.0 | 5 votes
def test_build_docker(iris_data, sk_model):
    with mlflow.start_run() as active_run:
        mlflow.sklearn.log_model(sk_model, "model")
        model_uri = "runs:/{run_id}/model".format(run_id=active_run.info.run_id)
    x, _ = iris_data
    df = pd.DataFrame(x)
    image_name = pyfunc_build_image(model_uri, extra_args=["--install-mlflow"])
    host_port = get_safe_port()
    scoring_proc = pyfunc_serve_from_docker_image(image_name, host_port)
    _validate_with_rest_endpoint(scoring_proc, host_port, df, x, sk_model)
Example #27
Source File: test_tracking.py From mlflow with Apache License 2.0 | 5 votes
def test_set_experiment():
    with pytest.raises(TypeError):
        mlflow.set_experiment()  # pylint: disable=no-value-for-parameter

    with pytest.raises(Exception):
        mlflow.set_experiment(None)

    with pytest.raises(Exception):
        mlflow.set_experiment("")

    name = "random_exp"
    exp_id = mlflow.create_experiment(name)
    mlflow.set_experiment(name)
    with start_run() as run:
        assert run.info.experiment_id == exp_id

    another_name = "another_experiment"
    mlflow.set_experiment(another_name)
    exp_id2 = mlflow.tracking.MlflowClient().get_experiment_by_name(another_name)
    with start_run() as another_run:
        assert another_run.info.experiment_id == exp_id2.experiment_id
Example #28
Source File: test_spark_model_export.py From mlflow with Apache License 2.0 | 5 votes
def test_log_model_no_registered_model_name(tmpdir, spark_model_iris):
    artifact_path = "model"
    dfs_tmp_dir = os.path.join(str(tmpdir), "test")
    try:
        register_model_patch = mock.patch("mlflow.register_model")
        with mlflow.start_run(), register_model_patch:
            sparkm.log_model(artifact_path=artifact_path, spark_model=spark_model_iris.model,
                             dfs_tmpdir=dfs_tmp_dir)
            mlflow.register_model.assert_not_called()
    finally:
        x = dfs_tmp_dir or sparkm.DFS_TMP
        shutil.rmtree(x)
Example #29
Source File: test_xgboost_autolog.py From mlflow with Apache License 2.0 | 5 votes
def test_xgb_autolog_persists_manually_created_run(bst_params, dtrain):
    mlflow.xgboost.autolog()
    with mlflow.start_run() as run:
        xgb.train(bst_params, dtrain)
        assert mlflow.active_run()
        assert mlflow.active_run().info.run_id == run.info.run_id
Example #30
Source File: test_spark_model_export.py From mlflow with Apache License 2.0 | 5 votes
def test_default_conda_env_strips_dev_suffix_from_pyspark_version(spark_model_iris, model_path):
    mock_version_standard = mock.PropertyMock(return_value="2.4.0")
    with mock.patch("pyspark.__version__", new_callable=mock_version_standard):
        default_conda_env_standard = sparkm.get_default_conda_env()

    for dev_version in ["2.4.0.dev0", "2.4.0.dev", "2.4.0.dev1", "2.4.0dev.a", "2.4.0.devb"]:
        mock_version_dev = mock.PropertyMock(return_value=dev_version)
        with mock.patch("pyspark.__version__", new_callable=mock_version_dev):
            default_conda_env_dev = sparkm.get_default_conda_env()
            assert (default_conda_env_dev == default_conda_env_standard)

            with mlflow.start_run():
                sparkm.log_model(
                    spark_model=spark_model_iris.model, artifact_path="model", conda_env=None)
                model_uri = "runs:/{run_id}/{artifact_path}".format(
                    run_id=mlflow.active_run().info.run_id, artifact_path="model")

            model_path = _download_artifact_from_uri(artifact_uri=model_uri)
            pyfunc_conf = _get_flavor_configuration(
                model_path=model_path, flavor_name=pyfunc.FLAVOR_NAME)
            conda_env_path = os.path.join(model_path, pyfunc_conf[pyfunc.ENV])
            with open(conda_env_path, "r") as f:
                persisted_conda_env_dev = yaml.safe_load(f)
            assert (persisted_conda_env_dev == default_conda_env_standard)

    for unaffected_version in ["2.0", "2.3.4", "2"]:
        mock_version = mock.PropertyMock(return_value=unaffected_version)
        with mock.patch("pyspark.__version__", new_callable=mock_version):
            assert unaffected_version in yaml.safe_dump(sparkm.get_default_conda_env())