Python mlflow.active_run() Examples
The following are 30 code examples of mlflow.active_run().
Each example is taken from an open-source project; the Source File line above each snippet names the file, the project it comes from, and its license, so you can refer back to the original for full context. You may also want to check out the other available functions and classes of the mlflow module.
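Before working through the examples, it helps to keep the basic contract of mlflow.active_run() in mind: it returns the run most recently started with mlflow.start_run() (a run object whose .info carries fields such as run_id, experiment_id, and artifact_uri), or None when no run is active. The minimal sketch below illustrates this behavior; the parameter name is purely illustrative.

import mlflow

assert mlflow.active_run() is None              # no run has been started yet

with mlflow.start_run() as run:                 # start_run() makes this run the active one
    current = mlflow.active_run()
    assert current.info.run_id == run.info.run_id
    mlflow.log_param("example_param", 1)        # fluent logging calls target the active run

assert mlflow.active_run() is None              # leaving the context ends the run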
Example #1
Source File: test_keras_model_export.py From mlflow with Apache License 2.0
def test_model_log(model, data, predicted):
    x, _ = data
    # should_start_run tests whether or not calling log_model() automatically starts a run.
    for should_start_run in [False, True]:
        try:
            if should_start_run:
                mlflow.start_run()
            artifact_path = "keras_model"
            mlflow.keras.log_model(model, artifact_path=artifact_path)
            model_uri = "runs:/{run_id}/{artifact_path}".format(
                run_id=mlflow.active_run().info.run_id,
                artifact_path=artifact_path)

            # Load model
            model_loaded = mlflow.keras.load_model(model_uri=model_uri)
            assert all(model_loaded.predict(x) == predicted)

            # Loading pyfunc model
            pyfunc_loaded = mlflow.pyfunc.load_model(model_uri=model_uri)
            assert all(pyfunc_loaded.predict(x).values == predicted)
        finally:
            mlflow.end_run()
Example #2
Source File: experiment.py From nyaggle with MIT License
def start(self):
    """
    Start a new experiment.
    """
    if self.with_mlflow:
        import mlflow

        if mlflow.active_run() is not None:
            active_run = mlflow.active_run()
            self.inherit_existing_run = True
        else:
            active_run = mlflow.start_run(run_name=self.mlflow_run_name, run_id=self.mlflow_run_id)
        mlflow_metadata = {
            'artifact_uri': active_run.info.artifact_uri,
            'experiment_id': active_run.info.experiment_id,
            'run_id': active_run.info.run_id
        }
        self.mlflow_run_id = active_run.info.run_id
        with open(os.path.join(self.logging_directory, 'mlflow.json'), 'w') as f:
            json.dump(mlflow_metadata, f, indent=4)
Example #3
Source File: tensorflow.py From mlflow with Apache License 2.0
def _log_event(event):
    """
    Extracts metric information from the event protobuf
    """
    if not mlflow.active_run():
        try_mlflow_log(mlflow.start_run)
        global _AUTOLOG_RUN_ID
        _AUTOLOG_RUN_ID = mlflow.active_run().info.run_id
    if event.WhichOneof('what') == 'summary':
        summary = event.summary
        for v in summary.value:
            if v.HasField('simple_value'):
                if (event.step-1) % _LOG_EVERY_N_STEPS == 0:
                    _thread_pool.submit(_add_to_queue, key=v.tag,
                                        value=v.simple_value, step=event.step,
                                        time=int(time.time() * 1000),
                                        run_id=mlflow.active_run().info.run_id)
Example #4
Source File: test_tensorflow2_model_export.py From mlflow with Apache License 2.0
def test_log_model_without_specified_conda_env_uses_default_env_with_expected_dependencies(
        saved_tf_iris_model):
    artifact_path = "model"
    with mlflow.start_run():
        mlflow.tensorflow.log_model(tf_saved_model_dir=saved_tf_iris_model.path,
                                    tf_meta_graph_tags=saved_tf_iris_model.meta_graph_tags,
                                    tf_signature_def_key=saved_tf_iris_model.signature_def_key,
                                    artifact_path=artifact_path,
                                    conda_env=None)
        model_uri = "runs:/{run_id}/{artifact_path}".format(
            run_id=mlflow.active_run().info.run_id,
            artifact_path=artifact_path)

    model_path = _download_artifact_from_uri(artifact_uri=model_uri)
    pyfunc_conf = _get_flavor_configuration(model_path=model_path, flavor_name=pyfunc.FLAVOR_NAME)
    conda_env_path = os.path.join(model_path, pyfunc_conf[pyfunc.ENV])
    with open(conda_env_path, "r") as f:
        conda_env = yaml.safe_load(f)

    assert conda_env == mlflow.tensorflow.get_default_conda_env()
Example #5
Source File: test_h2o_model_export.py From mlflow with Apache License 2.0
def test_model_log(h2o_iris_model):
    h2o_model = h2o_iris_model.model
    old_uri = mlflow.get_tracking_uri()
    # should_start_run tests whether or not calling log_model() automatically starts a run.
    for should_start_run in [False, True]:
        with TempDir(chdr=True, remove_on_exit=True):
            try:
                artifact_path = "gbm_model"
                mlflow.set_tracking_uri("test")
                if should_start_run:
                    mlflow.start_run()
                mlflow.h2o.log_model(h2o_model=h2o_model, artifact_path=artifact_path)
                model_uri = "runs:/{run_id}/{artifact_path}".format(
                    run_id=mlflow.active_run().info.run_id,
                    artifact_path=artifact_path)

                # Load model
                h2o_model_loaded = mlflow.h2o.load_model(model_uri=model_uri)
                assert all(
                    h2o_model_loaded.predict(h2o_iris_model.inference_data).as_data_frame() ==
                    h2o_model.predict(h2o_iris_model.inference_data).as_data_frame())
            finally:
                mlflow.end_run()
                mlflow.set_tracking_uri(old_uri)
Example #6
Source File: test_tensorflow2_model_export.py From mlflow with Apache License 2.0
def test_log_model_persists_specified_conda_env_in_mlflow_model_directory(
        saved_tf_iris_model, tf_custom_env):
    artifact_path = "model"
    with mlflow.start_run():
        mlflow.tensorflow.log_model(tf_saved_model_dir=saved_tf_iris_model.path,
                                    tf_meta_graph_tags=saved_tf_iris_model.meta_graph_tags,
                                    tf_signature_def_key=saved_tf_iris_model.signature_def_key,
                                    artifact_path=artifact_path,
                                    conda_env=tf_custom_env)
        model_uri = "runs:/{run_id}/{artifact_path}".format(
            run_id=mlflow.active_run().info.run_id,
            artifact_path=artifact_path)

    model_path = _download_artifact_from_uri(artifact_uri=model_uri)
    pyfunc_conf = _get_flavor_configuration(model_path=model_path, flavor_name=pyfunc.FLAVOR_NAME)
    saved_conda_env_path = os.path.join(model_path, pyfunc_conf[pyfunc.ENV])
    assert os.path.exists(saved_conda_env_path)
    assert saved_conda_env_path != tf_custom_env

    with open(tf_custom_env, "r") as f:
        tf_custom_env_text = f.read()
    with open(saved_conda_env_path, "r") as f:
        saved_conda_env_text = f.read()
    assert saved_conda_env_text == tf_custom_env_text
Example #7
Source File: test_h2o_model_export.py From mlflow with Apache License 2.0
def test_model_log_persists_specified_conda_env_in_mlflow_model_directory(
        h2o_iris_model, h2o_custom_env):
    artifact_path = "model"
    with mlflow.start_run():
        mlflow.h2o.log_model(h2o_model=h2o_iris_model.model,
                             artifact_path=artifact_path,
                             conda_env=h2o_custom_env)
        model_path = _download_artifact_from_uri("runs:/{run_id}/{artifact_path}".format(
            run_id=mlflow.active_run().info.run_id,
            artifact_path=artifact_path))

    pyfunc_conf = _get_flavor_configuration(model_path=model_path, flavor_name=pyfunc.FLAVOR_NAME)
    saved_conda_env_path = os.path.join(model_path, pyfunc_conf[pyfunc.ENV])
    assert os.path.exists(saved_conda_env_path)
    assert saved_conda_env_path != h2o_custom_env

    with open(h2o_custom_env, "r") as f:
        h2o_custom_env_text = f.read()
    with open(saved_conda_env_path, "r") as f:
        saved_conda_env_text = f.read()
    assert saved_conda_env_text == h2o_custom_env_text
Example #8
Source File: test_tensorflow_model_export.py From mlflow with Apache License 2.0
def test_log_model_persists_specified_conda_env_in_mlflow_model_directory(
        saved_tf_iris_model, tf_custom_env):
    artifact_path = "model"
    with mlflow.start_run():
        mlflow.tensorflow.log_model(tf_saved_model_dir=saved_tf_iris_model.path,
                                    tf_meta_graph_tags=saved_tf_iris_model.meta_graph_tags,
                                    tf_signature_def_key=saved_tf_iris_model.signature_def_key,
                                    artifact_path=artifact_path,
                                    conda_env=tf_custom_env)
        model_uri = "runs:/{run_id}/{artifact_path}".format(
            run_id=mlflow.active_run().info.run_id,
            artifact_path=artifact_path)

    model_path = _download_artifact_from_uri(artifact_uri=model_uri)
    pyfunc_conf = _get_flavor_configuration(model_path=model_path, flavor_name=pyfunc.FLAVOR_NAME)
    saved_conda_env_path = os.path.join(model_path, pyfunc_conf[pyfunc.ENV])
    assert os.path.exists(saved_conda_env_path)
    assert saved_conda_env_path != tf_custom_env

    with open(tf_custom_env, "r") as f:
        tf_custom_env_text = f.read()
    with open(saved_conda_env_path, "r") as f:
        saved_conda_env_text = f.read()
    assert saved_conda_env_text == tf_custom_env_text
Example #9
Source File: test_tensorflow_model_export.py From mlflow with Apache License 2.0
def test_log_model_without_specified_conda_env_uses_default_env_with_expected_dependencies(
        saved_tf_iris_model, model_path):
    artifact_path = "model"
    with mlflow.start_run():
        mlflow.tensorflow.log_model(tf_saved_model_dir=saved_tf_iris_model.path,
                                    tf_meta_graph_tags=saved_tf_iris_model.meta_graph_tags,
                                    tf_signature_def_key=saved_tf_iris_model.signature_def_key,
                                    artifact_path=artifact_path,
                                    conda_env=None)
        model_uri = "runs:/{run_id}/{artifact_path}".format(
            run_id=mlflow.active_run().info.run_id,
            artifact_path=artifact_path)

    model_path = _download_artifact_from_uri(artifact_uri=model_uri)
    pyfunc_conf = _get_flavor_configuration(model_path=model_path, flavor_name=pyfunc.FLAVOR_NAME)
    conda_env_path = os.path.join(model_path, pyfunc_conf[pyfunc.ENV])
    with open(conda_env_path, "r") as f:
        conda_env = yaml.safe_load(f)

    assert conda_env == mlflow.tensorflow.get_default_conda_env()
Example #10
Source File: test_run.py From nyaggle with MIT License
def test_inherit_outer_scope_run(tmpdir_name):
    mlflow.start_run()
    mlflow.log_param('foo', 1)

    params = {
        'objective': 'binary',
        'max_depth': 8
    }
    X, y = make_classification_df()

    run_experiment(params, X, y, with_mlflow=True, logging_directory=tmpdir_name)

    assert mlflow.active_run() is not None  # still valid

    client = mlflow.tracking.MlflowClient()
    data = client.get_run(mlflow.active_run().info.run_id).data
    assert data.metrics['Overall'] > 0  # recorded

    mlflow.end_run()
Example #11
Source File: test_run.py From nyaggle with MIT License
def test_ignore_errors_in_mlflow_params(tmpdir_name):
    mlflow.start_run()
    mlflow.log_param('features', 'ABC')
    mlflow.log_metric('Overall', -99)

    params = {
        'objective': 'binary',
        'max_depth': 8
    }
    X, y = make_classification_df()

    result = run_experiment(params, X, y, with_mlflow=True, logging_directory=tmpdir_name,
                            feature_list=[])

    client = mlflow.tracking.MlflowClient()
    data = client.get_run(mlflow.active_run().info.run_id).data

    assert data.metrics['Overall'] == result.metrics[-1]
    assert data.params['features'] == 'ABC'  # params cannot be overwritten

    mlflow.end_run()
Example #12
Source File: mlflow_utils.py From nucleus7 with Mozilla Public License 2.0
def log_project_artifacts_to_mlflow(function: Callable):
    """
    Log the artifact to mlflow

    Parameters
    ----------
    function
        function to wrap
    """

    @wraps(function)
    def wrapped(*args, **kwargs):
        if mlflow.active_run() is None:
            _warn_about_no_run()
            return function(*args, **kwargs)
        artifacts_path = project.get_active_artifacts_directory()
        artifacts_path_realpath = os.path.realpath(artifacts_path)
        mlflow.log_artifacts(artifacts_path_realpath)
        return function(*args, **kwargs)

    return wrapped

# pylint: disable=invalid-name
# this is method, not a constant, and is used inside of the patch
Example #13
Source File: test_spark_datasource_autologging.py From mlflow with Apache License 2.0
def test_autologging_of_datasources_with_different_formats(
        spark_session, format_to_file_path):
    mlflow.spark.autolog()
    for data_format, file_path in format_to_file_path.items():
        base_df = spark_session.read.format(data_format).option("header", "true").\
            option("inferSchema", "true").load(file_path)
        base_df.createOrReplaceTempView("temptable")
        table_df0 = spark_session.table("temptable")
        table_df1 = spark_session.sql("SELECT number1, number2 from temptable LIMIT 5")
        dfs = [
            base_df, table_df0, table_df1,
            base_df.filter("number1 > 0"), base_df.select("number1"),
            base_df.limit(2),
            base_df.filter("number1 > 0").select("number1").limit(2)]
        for df in dfs:
            with mlflow.start_run():
                run_id = mlflow.active_run().info.run_id
                df.collect()
                time.sleep(1)
            run = mlflow.get_run(run_id)
            _assert_spark_data_logged(run=run, path=file_path, data_format=data_format)
Example #14
Source File: test_gluon_model_export.py From mlflow with Apache License 2.0
def test_gluon_model_serving_and_scoring_as_pyfunc(gluon_model, model_data):
    _, _, test_data = model_data
    expected = nd.argmax(gluon_model(test_data), axis=1)

    artifact_path = "model"
    with mlflow.start_run():
        mlflow.gluon.log_model(gluon_model, artifact_path=artifact_path)
        model_uri = "runs:/{run_id}/{artifact_path}".format(
            run_id=mlflow.active_run().info.run_id,
            artifact_path=artifact_path)

    scoring_response = pyfunc_serve_and_score_model(
        model_uri=model_uri,
        data=pd.DataFrame(test_data.asnumpy()),
        content_type=pyfunc_scoring_server.CONTENT_TYPE_JSON_SPLIT_ORIENTED)
    response_values = \
        pd.read_json(scoring_response.content, orient="records").values.astype(np.float32)
    assert all(
        np.argmax(response_values, axis=1) == expected.asnumpy())
Example #15
Source File: test_gluon_model_export.py From mlflow with Apache License 2.0
def test_model_log_load(gluon_model, model_data, model_path):
    _, _, test_data = model_data
    expected = nd.argmax(gluon_model(test_data), axis=1)

    artifact_path = "model"
    with mlflow.start_run():
        mlflow.gluon.log_model(gluon_model, artifact_path=artifact_path)
        model_uri = "runs:/{run_id}/{artifact_path}".format(
            run_id=mlflow.active_run().info.run_id,
            artifact_path=artifact_path)

    # Loading Gluon model
    model_loaded = mlflow.gluon.load_model(model_uri, ctx.cpu())
    actual = nd.argmax(model_loaded(test_data), axis=1)
    assert all(expected == actual)

    # Loading pyfunc model
    pyfunc_loaded = mlflow.pyfunc.load_model(model_uri)
    test_pyfunc_data = pd.DataFrame(test_data.asnumpy())
    pyfunc_preds = pyfunc_loaded.predict(test_pyfunc_data)
    assert all(
        np.argmax(pyfunc_preds.values, axis=1) == expected.asnumpy())
Example #16
Source File: test_tensorflow2_autolog.py From mlflow with Apache License 2.0
def test_tf_keras_autolog_persists_manually_created_run(random_train_data,
                                                        random_one_hot_labels,
                                                        fit_variant):
    mlflow.tensorflow.autolog()
    with mlflow.start_run() as run:
        data = random_train_data
        labels = random_one_hot_labels

        model = create_tf_keras_model()

        if fit_variant == 'fit_generator':
            def generator():
                while True:
                    yield data, labels
            model.fit_generator(generator(), epochs=10, steps_per_epoch=1)
        else:
            model.fit(data, labels, epochs=10)

        assert mlflow.active_run()
        assert mlflow.active_run().info.run_id == run.info.run_id
Example #17
Source File: test_tensorflow2_autolog.py From mlflow with Apache License 2.0
def test_tf_keras_autolog_ends_auto_created_run(random_train_data, random_one_hot_labels,
                                                fit_variant):
    mlflow.tensorflow.autolog()

    data = random_train_data
    labels = random_one_hot_labels

    model = create_tf_keras_model()

    if fit_variant == 'fit_generator':
        def generator():
            while True:
                yield data, labels
        model.fit_generator(generator(), epochs=10, steps_per_epoch=1)
    else:
        model.fit(data, labels, epochs=10)

    assert mlflow.active_run() is None
Example #18
Source File: test_image_creation.py From mlflow with Apache License 2.0
def test_cli_build_image_with_runs_uri_calls_expected_azure_routines(sklearn_model):
    artifact_path = "model"
    with mlflow.start_run():
        mlflow.sklearn.log_model(sk_model=sklearn_model, artifact_path=artifact_path)
        run_id = mlflow.active_run().info.run_id
    model_uri = "runs:/{run_id}/{artifact_path}".format(
        run_id=run_id, artifact_path=artifact_path)

    with AzureMLMocks() as aml_mocks:
        result = CliRunner(env={"LC_ALL": "en_US.UTF-8", "LANG": "en_US.UTF-8"}).invoke(
            mlflow.azureml.cli.commands,
            [
                'build-image',
                '-m', model_uri,
                '-w', 'test_workspace',
                '-i', 'image_name',
                '-n', 'model_name',
            ])
        assert result.exit_code == 0

        assert aml_mocks["register_model"].call_count == 1
        assert aml_mocks["create_image"].call_count == 1
        assert aml_mocks["load_workspace"].call_count == 1
Example #19
Source File: test_gluon_autolog.py From mlflow with Apache License 2.0
def test_autolog_ends_auto_created_run():
    mlflow.gluon.autolog()

    data = DataLoader(LogsDataset(), batch_size=128, last_batch="discard")

    model = HybridSequential()
    model.add(Dense(64, activation="relu"))
    model.add(Dense(64, activation="relu"))
    model.add(Dense(10))
    model.initialize()
    model.hybridize()
    trainer = Trainer(model.collect_params(), "adam",
                      optimizer_params={"learning_rate": .001, "epsilon": 1e-07})
    est = estimator.Estimator(net=model, loss=SoftmaxCrossEntropyLoss(),
                              metrics=Accuracy(), trainer=trainer)

    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        est.fit(data, epochs=3)

    assert mlflow.active_run() is None
Example #20
Source File: test_fluent.py From mlflow with Apache License 2.0
def test_delete_tag():
    """
    Confirm that fluent API delete tags actually works
    :return:
    """
    mlflow.set_tag('a', 'b')
    run = MlflowClient().get_run(mlflow.active_run().info.run_id)
    print(run.info.run_id)
    assert 'a' in run.data.tags
    mlflow.delete_tag('a')
    run = MlflowClient().get_run(mlflow.active_run().info.run_id)
    assert 'a' not in run.data.tags
    with pytest.raises(MlflowException):
        mlflow.delete_tag('a')
    with pytest.raises(MlflowException):
        mlflow.delete_tag('b')
    mlflow.end_run()
Example #21
Source File: test_spark_model_export.py From mlflow with Apache License 2.0
def test_sparkml_model_log_persists_specified_conda_env_in_mlflow_model_directory(
        spark_model_iris, model_path, spark_custom_env):
    artifact_path = "model"
    with mlflow.start_run():
        sparkm.log_model(
            spark_model=spark_model_iris.model,
            artifact_path=artifact_path,
            conda_env=spark_custom_env)
        model_uri = "runs:/{run_id}/{artifact_path}".format(
            run_id=mlflow.active_run().info.run_id,
            artifact_path=artifact_path)

    model_path = _download_artifact_from_uri(artifact_uri=model_uri)
    pyfunc_conf = _get_flavor_configuration(model_path=model_path, flavor_name=pyfunc.FLAVOR_NAME)
    saved_conda_env_path = os.path.join(model_path, pyfunc_conf[pyfunc.ENV])
    assert os.path.exists(saved_conda_env_path)
    assert saved_conda_env_path != spark_custom_env

    with open(spark_custom_env, "r") as f:
        spark_custom_env_parsed = yaml.safe_load(f)
    with open(saved_conda_env_path, "r") as f:
        saved_conda_env_parsed = yaml.safe_load(f)
    assert saved_conda_env_parsed == spark_custom_env_parsed
Example #22
Source File: test_spark_model_export.py From mlflow with Apache License 2.0
def test_sparkml_model_log_without_specified_conda_env_uses_default_env_with_expected_dependencies(
        spark_model_iris):
    artifact_path = "model"
    with mlflow.start_run():
        sparkm.log_model(
            spark_model=spark_model_iris.model,
            artifact_path=artifact_path,
            conda_env=None)
        model_uri = "runs:/{run_id}/{artifact_path}".format(
            run_id=mlflow.active_run().info.run_id,
            artifact_path=artifact_path)

    model_path = _download_artifact_from_uri(artifact_uri=model_uri)
    pyfunc_conf = _get_flavor_configuration(model_path=model_path, flavor_name=pyfunc.FLAVOR_NAME)
    conda_env_path = os.path.join(model_path, pyfunc_conf[pyfunc.ENV])
    with open(conda_env_path, "r") as f:
        conda_env = yaml.safe_load(f)

    assert conda_env == sparkm.get_default_conda_env()
Example #23
Source File: test_spark_model_export.py From mlflow with Apache License 2.0
def test_mleap_model_log(spark_model_iris):
    artifact_path = "model"
    register_model_patch = mock.patch("mlflow.register_model")
    with mlflow.start_run(), register_model_patch:
        sparkm.log_model(spark_model=spark_model_iris.model,
                         sample_input=spark_model_iris.spark_df,
                         artifact_path=artifact_path,
                         registered_model_name="Model1")
        model_uri = "runs:/{run_id}/{artifact_path}".format(
            run_id=mlflow.active_run().info.run_id,
            artifact_path=artifact_path)
        mlflow.register_model.assert_called_once_with(model_uri, "Model1")

    model_path = _download_artifact_from_uri(artifact_uri=model_uri)
    config_path = os.path.join(model_path, "MLmodel")
    mlflow_model = Model.load(config_path)
    assert sparkm.FLAVOR_NAME in mlflow_model.flavors
    assert mleap.FLAVOR_NAME in mlflow_model.flavors
Example #24
Source File: test_tracking.py From mlflow with Apache License 2.0
def test_parent_create_run():
    with mlflow.start_run() as parent_run:
        parent_run_id = parent_run.info.run_id
    os.environ[_RUN_ID_ENV_VAR] = parent_run_id
    with mlflow.start_run() as parent_run:
        assert parent_run.info.run_id == parent_run_id
        with pytest.raises(Exception, match='To start a nested run'):
            mlflow.start_run()
        with mlflow.start_run(nested=True) as child_run:
            assert child_run.info.run_id != parent_run_id
            with mlflow.start_run(nested=True) as grand_child_run:
                pass

        def verify_has_parent_id_tag(child_id, expected_parent_id):
            tags = tracking.MlflowClient().get_run(child_id).data.tags
            assert tags[MLFLOW_PARENT_RUN_ID] == expected_parent_id

        verify_has_parent_id_tag(child_run.info.run_id, parent_run.info.run_id)
        verify_has_parent_id_tag(grand_child_run.info.run_id, child_run.info.run_id)

    assert mlflow.active_run() is None
Example #25
Source File: test_tracking.py From mlflow with Apache License 2.0
def test_log_batch_validates_entity_names_and_values():
    bad_kwargs = {
        "metrics": [
            [Metric(key="../bad/metric/name", value=0.3, timestamp=3, step=0)],
            [Metric(key="ok-name", value="non-numerical-value", timestamp=3, step=0)],
            [Metric(key="ok-name", value=0.3, timestamp="non-numerical-timestamp", step=0)],
        ],
        "params": [[Param(key="../bad/param/name", value="my-val")]],
        "tags": [[Param(key="../bad/tag/name", value="my-val")]],
    }
    with start_run() as active_run:
        for kwarg, bad_values in bad_kwargs.items():
            for bad_kwarg_value in bad_values:
                final_kwargs = {
                    "run_id": active_run.info.run_id,
                    "metrics": [],
                    "params": [],
                    "tags": [],
                }
                final_kwargs[kwarg] = bad_kwarg_value
                with pytest.raises(MlflowException) as e:
                    tracking.MlflowClient().log_batch(**final_kwargs)
                assert e.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)
Example #26
Source File: test_tracking.py From mlflow with Apache License 2.0
def test_log_metrics_uses_millisecond_timestamp_resolution_client():
    with start_run() as active_run, mock.patch("time.time") as time_mock:
        time_mock.side_effect = lambda: 123
        mlflow_client = tracking.MlflowClient()
        run_id = active_run.info.run_id

        mlflow_client.log_metric(run_id=run_id, key="name_1", value=25)
        mlflow_client.log_metric(run_id=run_id, key="name_2", value=-3)
        mlflow_client.log_metric(run_id=run_id, key="name_1", value=30)
        mlflow_client.log_metric(run_id=run_id, key="name_1", value=40)

    metric_history_name1 = mlflow_client.get_metric_history(run_id, "name_1")
    assert set([(m.value, m.timestamp) for m in metric_history_name1]) == set([
        (25, 123 * 1000),
        (30, 123 * 1000),
        (40, 123 * 1000),
    ])
    metric_history_name2 = mlflow_client.get_metric_history(run_id, "name_2")
    assert set([(m.value, m.timestamp) for m in metric_history_name2]) == set([
        (-3, 123 * 1000),
    ])
Example #27
Source File: test_tracking.py From mlflow with Apache License 2.0
def test_log_params():
    expected_params = {"name_1": "c", "name_2": "b", "nested/nested/name": 5}
    with start_run() as active_run:
        run_id = active_run.info.run_id
        mlflow.log_params(expected_params)
    finished_run = tracking.MlflowClient().get_run(run_id)
    # Validate params
    assert finished_run.data.params == {"name_1": "c", "name_2": "b", "nested/nested/name": "5"}
Example #28
Source File: test_tracking.py From mlflow with Apache License 2.0
def test_get_artifact_uri_uses_currently_active_run_id():
    artifact_path = "artifact"
    with mlflow.start_run() as active_run:
        assert mlflow.get_artifact_uri(artifact_path=artifact_path) == \
            tracking.artifact_utils.get_artifact_uri(
                run_id=active_run.info.run_id, artifact_path=artifact_path)
Example #29
Source File: test_fluent.py From mlflow with Apache License 2.0
def is_from_run(active_run, run):
    return active_run.info == run.info and active_run.data == run.data
Example #30
Source File: test_tracking.py From mlflow with Apache License 2.0
def test_start_run_exp_id_0():
    mlflow.set_experiment("some-experiment")
    # Create a run and verify that the current active experiment is the one we just set
    with mlflow.start_run() as active_run:
        exp_id = active_run.info.experiment_id
        assert exp_id != FileStore.DEFAULT_EXPERIMENT_ID
        assert MlflowClient().get_experiment(exp_id).name == "some-experiment"
    # Set experiment ID to 0 when creating a run, verify that the specified experiment ID is honored
    with mlflow.start_run(experiment_id=0) as active_run:
        assert active_run.info.experiment_id == FileStore.DEFAULT_EXPERIMENT_ID