Python wandb.log() Examples

The following are 30 code examples of wandb.log().
Example #1
Source File:    From disentanglement-pytorch with GNU General Public License v3.0 6 votes vote down vote up
def visualize_recon(self, input_image, recon_image, test=False):
        """Save and/or log a side-by-side image of inputs and reconstructions.

        Both batches are tiled with ``torchvision.utils.make_grid`` and joined
        along dim 2 with a 10-column strip of ones between them.

        Args:
            input_image: batch of input images accepted by ``make_grid``.
            recon_image: batch of reconstructions accepted by ``make_grid``.
            test: when true, the file is written to ``self.test_output_dir`` and
                the current iteration is part of its name; otherwise a single
                file in ``self.train_output_dir`` is overwritten on every call.
        """
        input_image = torchvision.utils.make_grid(input_image)
        recon_image = torchvision.utils.make_grid(recon_image)

        # The separator is built lazily and cached; its height follows the
        # first grid that is rendered.
        if self.white_line is None:
            self.white_line = torch.ones((3, input_image.size(1), 10)).to(self.device)

        # The three pieces differ in width, so they are concatenated (not
        # stacked) along the last dimension.
        samples = torch.cat([input_image, self.white_line, recon_image], dim=2)

        if self.file_save:
            if test:
                file_name = os.path.join(self.test_output_dir, '{}_{}.{}'.format(c.RECON, self.iter, c.JPG))
            else:
                file_name = os.path.join(self.train_output_dir, '{}.{}'.format(c.RECON, c.JPG))
            torchvision.utils.save_image(samples, file_name)

        if self.use_wandb:
            # Imported here so wandb is only required when logging is enabled.
            import wandb
            # NOTE(review): the excerpt was cut after the first argument; a
            # further argument (possibly ``step=``) may have followed - confirm
            # against the upstream project.
            wandb.log({c.RECON_IMAGE: wandb.Image(samples, caption=str(self.iter))})
Example #2
Source File:    From rl_algorithms with MIT License 6 votes vote down vote up
def run(self):
        """Run main logging loop; continuously receive data and log."""
        # NOTE(review): the statement guarded by this `if` is missing from the
        # excerpt, so the block does not parse as shown; restore it from the
        # upstream project before use.
        if self.args.log:

        # Keep looping until the update step carried by the most recent log
        # entry reaches the configured maximum.
        while self.update_step < self.args.max_update_step:
            if self.log_info_queue:  # if non-empty
                log_info_id = self.log_info_queue.pop()
                # presumably `pa` is pyarrow; verify against the file's imports
                log_info = pa.deserialize(log_info_id)
                # NOTE(review): state_dict is read but not used in the visible lines.
                state_dict = log_info["state_dict"]
                log_value = log_info["log_value"]
                self.update_step = log_value["update_step"]

                # Call self.test for the current update step and store its
                # result in the entry under "avg_score".
                avg_score = self.test(self.update_step)
                log_value["avg_score"] = avg_score
                # NOTE(review): the excerpt ends here; whatever consumes
                # log_value afterwards is not visible - confirm upstream.
Example #3
Source File:    From NeMo with Apache License 2.0 6 votes vote down vote up
def __init__(
        self, train_tensors=[], wandb_name=None, wandb_project=None, args=None, update_freq=25,
):
        """Store the Weights & Biases logging configuration.

        Args:
            train_tensors: list of tensors to evaluate and log based on training batches
            wandb_name: wandb experiment name
            wandb_project: wandb project name
            args: argparse flags - will be logged as hyperparameters
            update_freq: frequency with which to log updates
        """
        # NOTE(review): if the enclosing class has a base-class initializer it
        # must be invoked here; no such call is visible in this excerpt.

        # Only reported, not raised: construction continues without wandb.
        if not _WANDB_AVAILABLE:
            logging.error("Could not import wandb. Did you install it (pip install --upgrade wandb)?")

        self._update_freq = update_freq
        # The shared default list is stored as-is and never mutated here.
        self._train_tensors = train_tensors
        self._name = wandb_name
        self._project = wandb_project
        self._args = args
Example #4
Source File:    From rl_algorithms with MIT License 6 votes vote down vote up
def write_log(self, log_value: tuple):
        """Print one episode's summary and, if enabled, send the losses to W&B.

        Args:
            log_value: ``(episode index, score, policy loss, value loss)``.
        """
        i, score, policy_loss, value_loss = log_value
        total_loss = policy_loss + value_loss

        print(
            "[INFO] episode %d\tepisode step: %d\ttotal score: %d\n"
            "total loss: %.4f\tpolicy loss: %.4f\tvalue loss: %.4f\n"
            % (i, self.episode_step, score, total_loss, policy_loss, value_loss)
        )

        # Remote logging is opt-in through the run arguments.
        if self.args.log:
            wandb.log(
                {
                    "total loss": total_loss,
                    "policy loss": policy_loss,
                    "value loss": value_loss,
                    "score": score,
                }
            )
Example #5
Source File:    From rl_algorithms with MIT License 6 votes vote down vote up
def write_log(self, log_value: tuple):
        """Print one episode's summary and, if enabled, send the losses to W&B.

        Args:
            log_value: ``(episode index, step count, score, actor loss,
                critic loss, total loss)``.
        """
        i_episode, n_step, score, actor_loss, critic_loss, total_loss = log_value

        print(
            "[INFO] episode %d\tepisode steps: %d\ttotal score: %d\n"
            "total loss: %f\tActor loss: %f\tCritic loss: %f\n"
            % (i_episode, n_step, score, total_loss, actor_loss, critic_loss)
        )

        # Remote logging is opt-in through the run arguments.
        if self.args.log:
            wandb.log(
                {
                    "total loss": total_loss,
                    "actor loss": actor_loss,
                    "critic loss": critic_loss,
                    "score": score,
                }
            )
Example #6
Source File:    From atari-representation-learning with MIT License 6 votes vote down vote up
def remove_duplicates(tr_eps, val_eps, test_eps, test_labels):
    """
    Remove any items in test_eps (&test_labels) which are present in tr/val_eps

    Observations are compared by the raw bytes of ``obs.numpy()``. ``test_eps``
    and ``test_labels`` are modified in place and also returned.

    Args:
        tr_eps: iterable of training episodes (each an iterable of observations).
        val_eps: iterable of validation episodes.
        test_eps: list of test episodes to filter.
        test_labels: list of per-episode label lists, parallel to ``test_eps``.

    Returns:
        The filtered ``(test_eps, test_labels)``.
    """
    def _key(obs):
        # ndarray.tostring() is deprecated and removed in NumPy 2.0;
        # tobytes() returns the same bytes.
        return obs.numpy().tobytes()

    tr_val_set = {_key(x) for x in chain(chain.from_iterable(tr_eps), chain.from_iterable(val_eps))}
    n_before = len(list(chain.from_iterable(test_eps)))

    for i, episode in enumerate(test_eps):
        # Compute each observation's verdict once and apply it to both lists.
        keep = [_key(obs) not in tr_val_set for obs in episode]
        test_labels[i] = [label for label, ok in zip(test_labels[i], keep) if ok]
        test_eps[i] = [obs for obs, ok in zip(episode, keep) if ok]

    test_len = len(list(chain.from_iterable(test_eps)))
    dups = n_before - test_len
    print('Duplicates: {}, Test Len: {}'.format(dups, test_len))
    #wandb.log({'Duplicates': dups, 'Test Len': test_len})
    return test_eps, test_labels
Example #7
Source File:    From NeMo with Apache License 2.0 5 votes vote down vote up
def _wandb_log(tensors_logged, step):
        """Forward ``tensors_logged`` to ``wandb.log`` at ``step``.

        Does nothing when the module-level ``_WANDB_AVAILABLE`` flag is falsy.
        """
        if not _WANDB_AVAILABLE:
            return
        wandb.log(tensors_logged, step=step)
Example #8
Source File:    From tape with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
def log_config(self, config: typing.Dict[str, typing.Any]) -> None:
        """Warn that config logging is unsupported by this visualizer.

        Args:
            config: the configuration that would have been logged; it is ignored.
        """
        # Logger.warn is a deprecated alias; Logger.warning is the supported name.
        logger.warning("Cannot log config when using a TBVisualizer. "
                       "Configure wandb for this functionality")
Example #9
Source File:    From tape with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
def log_metrics(self,
                    metrics_dict: typing.Dict[str, float],
                    split: str,
                    step: int):
        """Log every metric to wandb under a "<Split> <Name>" key.

        Both the split and the metric name are passed through
        ``str.capitalize`` before being joined with a single space.

        Args:
            metrics_dict: metric name -> value.
            split: name of the data split the metrics belong to.
            step: step passed to ``wandb.log``.
        """
        payload = {}
        for metric_name, metric_value in metrics_dict.items():
            label = f"{split.capitalize()} {metric_name.capitalize()}"
            payload[label] = metric_value
        wandb.log(payload, step=step)
Example #10
Source File:    From keras-rl with MIT License 5 votes vote down vote up
def on_episode_end(self, episode, logs):
        """ Compute and log training statistics of the episode when done """
        duration = timeit.default_timer() - self.episode_start[episode]
        episode_steps = len(self.observations[episode])

        metrics = np.array(self.metrics[episode])
        metrics_dict = {}
        with warnings.catch_warnings():
            # Escalate warnings to exceptions so that nanmean's warning on an
            # all-NaN column can be caught below; without this the
            # `except Warning` branch is unreachable.
            warnings.filterwarnings('error')
            for idx, name in enumerate(self.metrics_names):
                try:
                    metrics_dict[name] = np.nanmean(metrics[:, idx])
                except Warning:
                    metrics_dict[name] = float('nan')

        # NOTE(review): the call wrapping this dict was lost in the excerpt.
        # wandb.log is assumed, and metrics_dict (otherwise unused) is assumed
        # to be merged into the payload - confirm against upstream.
        wandb.log({
            'step': self.step,
            'episode': episode + 1,
            'duration': duration,
            'episode_steps': episode_steps,
            'sps': float(episode_steps) / duration,
            'episode_reward': np.sum(self.rewards[episode]),
            'reward_mean': np.mean(self.rewards[episode]),
            'reward_min': np.min(self.rewards[episode]),
            'reward_max': np.max(self.rewards[episode]),
            'action_mean': np.mean(self.actions[episode]),
            'action_min': np.min(self.actions[episode]),
            'action_max': np.max(self.actions[episode]),
            'obs_mean': np.mean(self.observations[episode]),
            'obs_min': np.min(self.observations[episode]),
            'obs_max': np.max(self.observations[episode]),
            **metrics_dict,
        })

        # Free up resources.
        del self.episode_start[episode]
        del self.observations[episode]
        del self.rewards[episode]
        del self.actions[episode]
        del self.metrics[episode]
Example #11
Source File:    From firedup with MIT License 5 votes vote down vote up
def log(self, msg, color="green"):
        """Print ``msg`` in bold ``color`` to stdout, only where ``proc_id()`` is 0."""
        if proc_id() != 0:
            return
        rendered = colorize(msg, color, bold=True)
        print(rendered)
Example #12
Source File:    From rtrl with MIT License 5 votes vote down vote up
def run_wandb(entity, project, run_id, run_cls: type = Training, checkpoint_path: str = None):
  """Run `run_cls` and send its config and per-episode stats to Weights & Biases.

  Args:
    entity: wandb entity passed to `wandb.init`.
    project: wandb project passed to `wandb.init`.
    run_id: wandb run id passed to `wandb.init`.
    run_cls: run definition; its config is obtained with `partial_to_dict`.
    checkpoint_path: optional checkpoint location; the wandb run is resumed
      only when this path is set and exists.
  """
  wandb_dir = mkdtemp()  # prevent wandb from polluting the home directory
  atexit.register(shutil.rmtree, wandb_dir, ignore_errors=True)  # clean up after wandb atexit handler finishes
  import wandb
  config = partial_to_dict(run_cls)
  config['seed'] = config['seed'] or randrange(1, 1000000)  # if seed == 0 replace with random
  config['environ'] = log_environment_variables()
  config['git'] = git_info()
  resume = checkpoint_path and exists(checkpoint_path)
  wandb.init(dir=wandb_dir, entity=entity, project=project, id=run_id, resume=resume, config=config)
  for stats in iterate_episodes(run_cls, checkpoint_path):
    # A plain loop: the calls are made for their side effect only, so no
    # throwaway list of return values is built.
    for s in stats:
      wandb.log(json.loads(s.to_json()))
Example #13
Source File:    From firedup with MIT License 5 votes vote down vote up
def save_state(self, state_dict, model, itr=None):
        """
        Saves the state of an experiment.

        To be clear: this is about saving *state*, not logging diagnostics.
        All diagnostic logging is separate from this function. This function
        will save whatever is in ``state_dict``---usually just a copy of the
        environment---and the most recent copy of the model via ``model``.

        Call with any frequency you prefer. If you only want to maintain a
        single state and overwrite it at each call with the most recent
        version, leave ``itr=None``. If you want to keep all of the states you
        save, provide unique (increasing) values for 'itr'.

        Args:
            state_dict (dict): Dictionary containing essential elements to
                describe the current state of training.
            model (nn.Module): A model which contains the policy.
            itr: An int, or None. Current iteration of training.
        """
        if proc_id() == 0:
            fname = "vars.pkl" if itr is None else "vars%d.pkl" % itr
            # Best effort: a state_dict that cannot be pickled only produces a
            # warning, and the model is still saved afterwards.
            try:
                joblib.dump(state_dict, osp.join(self.output_dir, fname))
            except Exception:
                self.log("Warning: could not pickle state_dict.", color="red")
            self._torch_save(model, itr)
Example #14
Source File:    From firedup with MIT License 5 votes vote down vote up
def dump_tabular(self):
        """
        Write all of the diagnostics from the current iteration.

        Writes both to stdout, and to the output file. Each value is also sent
        to wandb, with only the last header committing the step.
        """
        if proc_id() == 0:
            vals = []
            # Seeding the list with 15 keeps max() valid when there are no headers.
            max_key_len = max([15] + [len(key) for key in self.log_headers])
            keystr = "%" + "%d" % max_key_len
            fmt = "| " + keystr + "s | %15s |"
            n_slashes = 22 + max_key_len
            # The step is the same for every key, so look it up once.
            total_env_interacts = self.log_current_row.get("TotalEnvInteracts", "")
            last_index = len(self.log_headers) - 1
            print("-" * n_slashes)
            for i, key in enumerate(self.log_headers):
                val = self.log_current_row.get(key, "")
                wandb.log({key: val}, step=total_env_interacts, commit=(i == last_index))
                valstr = "%8.3g" % val if hasattr(val, "__float__") else val
                print(fmt % (key, valstr))
                # Collect the value so the file row below is not empty.
                vals.append(val)
            print("-" * n_slashes)
            if self.output_file is not None:
                if self.first_row:
                    self.output_file.write("\t".join(self.log_headers) + "\n")
                self.output_file.write("\t".join(map(str, vals)) + "\n")
        self.first_row = False
Example #15
Source File:    From firedup with MIT License 5 votes vote down vote up
def log_tabular(self, key, val=None, with_min_and_max=False, average_only=False):
        """
        Log a value or possibly the mean/std/min/max values of a diagnostic.

        Args:
            key (string): The name of the diagnostic. If you are logging a
                diagnostic whose state has previously been saved with
                ``store``, the key here has to match the key you used there.

            val: A value for the diagnostic. If you have previously saved
                values for this key via ``store``, do *not* provide a ``val``
                here.

            with_min_and_max (bool): If true, log min and max values of the
                diagnostic over the epoch.

            average_only (bool): If true, do not log the standard deviation
                of the diagnostic over the epoch.
        """
        if val is not None:
            super().log_tabular(key, val)
        else:
            v = self.epoch_dict[key]
            # NOTE(review): the head of this expression was lost in the
            # excerpt; np.concatenate is assumed for lists of non-scalar
            # arrays - confirm against upstream.
            vals = (
                np.concatenate(v)
                if isinstance(v[0], np.ndarray) and len(v[0].shape) > 0
                else v
            )
            stats = mpi_statistics_scalar(vals, with_min_and_max=with_min_and_max)
            super().log_tabular(key if average_only else "Average" + key, stats[0])
            if not (average_only):
                super().log_tabular("Std" + key, stats[1])
            if with_min_and_max:
                # stats[2] and stats[3] are logged as the min and the max.
                super().log_tabular("Max" + key, stats[3])
                super().log_tabular("Min" + key, stats[2])
        # Reset the stored values for this key in both branches.
        self.epoch_dict[key] = []
Example #16
Source File:    From NeMo with Apache License 2.0 5 votes vote down vote up
def on_action_start(self, state):
        """Start or reuse a wandb session on the main process.

        Runs only when ``state["global_rank"]`` is ``None`` or 0. When wandb
        is unavailable, logging is disabled by setting ``self._step_freq`` to -1.

        Args:
            state: mapping that provides ``"global_rank"``.
        """
        if state["global_rank"] is None or state["global_rank"] == 0:
            if _WANDB_AVAILABLE and wandb.run is None:
                wandb.init(name=self._name, project=self._project)
                if self._args is not None:
                    # NOTE(review): this body was lost in the excerpt; recording
                    # the args as run config is assumed - confirm upstream.
                    wandb.config.update(self._args)
            elif _WANDB_AVAILABLE and wandb.run is not None:
                logging.info("Re-using wandb session")
            else:
                logging.error("Could not import wandb. Did you install it (pip install --upgrade wandb)?")
                logging.info("Will not log data to weights and biases.")
                self._step_freq = -1
Example #17
Source File:    From NeMo with Apache License 2.0 5 votes vote down vote up
def on_step_end(self, state):
        """Log the configured training tensors every ``self._step_freq`` steps.

        Runs only when ``state["global_rank"]`` is ``None`` or 0. A
        non-positive ``self._step_freq`` disables logging.

        Args:
            state: mapping that provides ``"global_rank"``, ``"step"``,
                ``"tensors"`` (with ``get_tensor``) and ``"optimizers"``.
        """
        # log training metrics
        if state["global_rank"] is None or state["global_rank"] == 0:
            # Check the sign first: a frequency of 0 would otherwise raise
            # ZeroDivisionError in the modulo before the guard is evaluated.
            if self._step_freq > 0 and state["step"] % self._step_freq == 0:
                tensors_logged = {t: state["tensors"].get_tensor(t).cpu() for t in self._tensors_to_log}
                # Log the learning rate of the first optimizer's first param group
                if self._log_lr:
                    tensors_logged['LR'] = state["optimizers"][0].param_groups[0]['lr']
                self._wandb_log(tensors_logged, state["step"])
Example #18
Source File:    From catalyst with Apache License 2.0 5 votes vote down vote up
def on_stage_start(self, runner: IRunner):
        """Initialize Weights & Biases.

        Starts a run whose directory is the runner's logdir, then registers
        the runner's model and criterion with ``wandb.watch``.

        Args:
            runner: supplies ``logdir``, ``model`` and ``criterion``.
        """
        wandb.init(**self.logging_params, reinit=True, dir=str(runner.logdir))
        # NOTE(review): the call name was lost in the excerpt; wandb.watch is
        # inferred from the models/criterion/log keyword arguments.
        wandb.watch(
            models=runner.model, criterion=runner.criterion, log=self.log
        )
Example #19
Source File:    From NeMo with Apache License 2.0 5 votes vote down vote up
def on_iteration_end(self):
        """Report the registered tensors every ``self._step_freq`` steps.

        Runs only when ``self.global_rank`` is ``None`` or 0. Logs the step
        and step time, and writes scalars to ``self._swriter`` when one is set.
        """
        if self.global_rank is None or self.global_rank == 0:
            step = self.step
            if step % self._step_freq == 0:
                tensor_values = [self.registered_tensors[t.unique_name] for t in self.tensors]
                logging.info(f"Step: {step}")
                if self._print_func:
                    # NOTE(review): this body was lost in the excerpt; calling
                    # the print hook with the tensor values is assumed.
                    self._print_func(tensor_values)
                if self._swriter is not None:
                    if self._get_tb_values:
                        tb_objects = self._get_tb_values(tensor_values)
                        for name, value in tb_objects:
                            value = value.item()
                            self._swriter.add_scalar(name, value, step)
                    if self._log_to_tb_func:
                        self._log_to_tb_func(self._swriter, tensor_values, step)
                    run_time = time.time() - self._last_iter_start
                    self._swriter.add_scalar('misc/step_time', run_time, step)
                run_time = time.time() - self._last_iter_start
                logging.info(f"Step time: {run_time} seconds")

                # To keep support in line with the removal of learning rate logging from inside actions, log learning
                # rate to tensorboard. However it now logs every self._step_freq as opposed to every step
                if self._swriter is not None:
                    self._swriter.add_scalar('param/lr', self.learning_rate, step)
Example #20
Source File:    From NeMo with Apache License 2.0 5 votes vote down vote up
def on_action_start(self):
        """Start or reuse a wandb session when a wandb name or project is set.

        Runs only when ``self.global_rank`` is ``None`` or 0. When wandb is
        unavailable, both settings are reset to ``None``.
        """
        if self.global_rank is None or self.global_rank == 0:
            if self._wandb_name is not None or self._wandb_project is not None:
                if _WANDB_AVAILABLE and wandb.run is None:
                    wandb.init(name=self._wandb_name, project=self._wandb_project)
                elif _WANDB_AVAILABLE and wandb.run is not None:
                    logging.info("Re-using wandb session")
                else:
                    logging.error("Could not import wandb. Did you install it (pip install --upgrade wandb)?")
                    logging.info("Will not log data to weights and biases.")
                    self._wandb_name = None
                    self._wandb_project = None
Example #21
Source File:    From NeMo with Apache License 2.0 5 votes vote down vote up
def on_action_start(self):
        """Start or reuse a wandb session on the main process.

        Runs only when ``self.global_rank`` is ``None`` or 0. When wandb is
        unavailable, logging is disabled by setting ``self._update_freq`` to -1.
        """
        if self.global_rank is None or self.global_rank == 0:
            if _WANDB_AVAILABLE and wandb.run is None:
                wandb.init(name=self._name, project=self._project)
                if self._args is not None:
                    logging.info('init wandb session and append args')
                    # NOTE(review): the statement that records the args was lost
                    # in the excerpt; a config update is assumed - confirm upstream.
                    wandb.config.update(self._args)
            elif _WANDB_AVAILABLE and wandb.run is not None:
                logging.info("Re-using wandb session")
            else:
                logging.error("Could not import wandb. Did you install it (pip install --upgrade wandb)?")
                logging.info("Will not log data to weights and biases.")
                self._update_freq = -1
Example #22
Source File:    From NeMo with Apache License 2.0 5 votes vote down vote up
def on_iteration_end(self):
        """Log the training tensors and learning rate every ``self._update_freq`` steps.

        Runs only when ``self.global_rank`` is ``None`` or 0. A non-positive
        ``self._update_freq`` disables logging.
        """
        # log training metrics
        if self.global_rank is None or self.global_rank == 0:
            # Check the sign first so a frequency of 0 cannot raise
            # ZeroDivisionError in the modulo.
            if self._update_freq > 0 and self.step % self._update_freq == 0:
                # NOTE(review): the dict key was lost in the excerpt (the braces
                # built a set, which the 'LR' assignment below cannot index);
                # the tensor's `name` is assumed - confirm upstream.
                tensors_logged = {
                    t.name: self.registered_tensors[t.unique_name].cpu() for t in self._train_tensors
                }
                # Always log learning rate
                tensors_logged['LR'] = self.learning_rate
                # NOTE(review): the excerpt ends before the values are sent;
                # self.wandb_log is assumed to be the sink - confirm upstream.
                self.wandb_log(tensors_logged)
Example #23
Source File:    From RLcycle with MIT License 5 votes vote down vote up
def write_log(self, log_dict: dict, step: int = None):
        """Send ``log_dict`` to Weights & Biases.

        Args:
            log_dict: values to log.
            step: passed through to ``wandb.log`` as ``step``.
        """
        wandb.log(
            log_dict,
            step=step,
        )
Example #24
Source File:    From catz with MIT License 5 votes vote down vote up
def on_epoch_end(self, epoch, logs):
        """Log one validation batch and the model's predictions as wandb images.

        The values are logged with ``commit=False``, so this call does not
        advance the wandb step by itself.

        Args:
            epoch: unused here.
            logs: unused here.
        """
        validation_X, validation_y = next(
            my_generator(15, val_dir))
        output = self.model.predict(validation_X)
        # NOTE(review): the call name was lost in the excerpt; wandb.log is
        # inferred from the trailing `commit=False` and the wandb.Image values.
        wandb.log({
            "input": [wandb.Image(np.concatenate(np.split(c, 5, axis=2), axis=1)) for c in validation_X],
            "output": [wandb.Image(np.concatenate([validation_y[i], o], axis=1)) for i, o in enumerate(output)]
        }, commit=False)
Example #25
Source File:    From NeMo with Apache License 2.0 5 votes vote down vote up
def on_epoch_end(self):
        """Log the epoch number and its wall-clock duration.

        Runs only when ``self.global_rank`` is ``None`` or 0.
        """
        if self.global_rank is not None and self.global_rank != 0:
            return
        # always log epoch num and epoch_time
        elapsed = time.time() - self._last_epoch_start
        record = {"epoch": self.epoch_num, "epoch_time": elapsed}
        self.wandb_log(record)
Example #26
Source File:    From NeMo with Apache License 2.0 5 votes vote down vote up
def wandb_log(self, tensors_logged):
        """Forward ``tensors_logged`` to ``wandb.log`` at ``self.step``.

        Does nothing when the module-level ``_WANDB_AVAILABLE`` flag is falsy.
        """
        if not _WANDB_AVAILABLE:
            return
        wandb.log(tensors_logged, step=self.step)
Example #27
Source File:    From catalyst with Apache License 2.0 5 votes vote down vote up
def _log_metrics(
        self,
        metrics: Dict[str, float],
        step: int,
        mode: str,
        suffix="",
        commit=True,
):
        """Log the selected metrics to wandb as ``<name>/<mode><suffix>``.

        NOTE(review): ``self``, ``suffix`` and ``commit`` are used by the body
        but were lost from the signature in this excerpt; their position and
        defaults here are assumptions - confirm against upstream.

        Args:
            metrics: metric name -> value.
            step: step passed to ``wandb.log``.
            mode: placed after the slash in every logged key.
            suffix: appended to every logged key.
            commit: passed through to ``wandb.log``.
        """
        if self.metrics_to_log is None:
            metrics_to_log = sorted(metrics.keys())
        else:
            metrics_to_log = self.metrics_to_log

        def key_locate(key: str):
            """
            Wandb uses first symbol _ for it service purposes
            because of that fact, we can not send original metric names

            Args:
                key: metric name

            Returns:
                formatted metric name
            """
            if key.startswith("_"):
                return key[1:]
            return key

        metrics = {
            f"{key_locate(key)}/{mode}{suffix}": value
            for key, value in metrics.items()
            if key in metrics_to_log
        }
        wandb.log(metrics, step=step, commit=commit)
Example #28
Source File:    From cherry with Apache License 2.0 5 votes vote down vote up
def benchmark_log(original_log):
    """Wrap a ``log(self, key, value)`` method so each value also goes to wandb.

    The wrapper logs ``{key: value}`` to wandb at ``self.num_steps`` and then
    delegates to ``original_log``.

    Args:
        original_log: the method being wrapped.

    Returns:
        The wrapping function, carrying ``original_log``'s name and docstring.
    """
    import functools

    @functools.wraps(original_log)
    def new_log(self, key, value):
        wandb.log({key: value}, step=self.num_steps)
        # Pass the wrapped method's result through instead of dropping it.
        return original_log(self, key, value)
    return new_log
Example #29
Source File:    From atari-representation-learning with MIT License 5 votes vote down vote up
def remove_low_entropy_labels(episode_labels, entropy_threshold=0.3):
    """Drop label keys whose value distribution has entropy below a threshold.

    The label dicts inside ``episode_labels`` are modified in place.

    Args:
        episode_labels: iterable of episodes, each an iterable of label dicts.
        entropy_threshold: keys with entropy strictly below this are removed.

    Returns:
        ``(episode_labels, entropy_dict)`` where ``entropy_dict`` maps
        ``'entropy_' + key`` to the entropy tensor computed for that key.
    """
    flat_label_list = list(chain.from_iterable(episode_labels))

    # counts[key][value] -> number of observations carrying that value
    counts = {}
    for label_dict in flat_label_list:
        for k, v in label_dict.items():
            per_key = counts.setdefault(k, {})
            per_key[v] = per_key.get(v, 0) + 1

    low_entropy_labels = []
    entropy_dict = {}
    n_obs = len(flat_label_list)
    for k, value_counts in counts.items():
        entropy = torch.distributions.Categorical(
            torch.tensor([x / n_obs for x in value_counts.values()])).entropy()
        entropy_dict['entropy_' + k] = entropy
        if entropy < entropy_threshold:
            print("Deleting {} for being too low in entropy! Sorry, dood!".format(k))
            # Record the key so the removal loop below actually deletes it.
            low_entropy_labels.append(k)

    for e in episode_labels:
        for obs in e:
            for key in low_entropy_labels:
                # pop() tolerates observations that never carried this key
                obs.pop(key, None)
    # wandb.log(entropy_dict)
    return episode_labels, entropy_dict
Example #30
Source File:    From atari-representation-learning with MIT License 5 votes vote down vote up
def get_pretrained_rl_representations(args, steps):
    """Collect PPO representations for the checkpoint selected by ``args``.

    The checkpoint is looked up in ``checkpointed_steps_full_sorted`` by
    ``args.checkpoint_index``; the mean reward and the checkpoint are logged
    to wandb.

    Returns:
        ``(episodes, episode_labels)`` from ``get_ppo_representations``.
    """
    checkpoint = checkpointed_steps_full_sorted[args.checkpoint_index]
    result = get_ppo_representations(args, steps, checkpoint)
    episodes, episode_labels, mean_reward = result
    summary = {"reward": mean_reward, "checkpoint": checkpoint}
    wandb.log(summary)
    return episodes, episode_labels