Python wandb.log() Examples

The following are 30 code examples of wandb.log(), drawn from open-source projects. The source file, project, and license for each example are noted above it. You may also want to check out the other available functions and classes of the wandb module.
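All of the examples below reduce to the same core pattern: initialize a run, then pass wandb.log() a dict of named values. A minimal sketch, assuming wandb is installed and you are logged in (the project name and loss values are placeholders):

import wandb

wandb.init(project="my-project")  # placeholder project name

for step in range(100):
    loss = 1.0 / (step + 1)  # stand-in for a real training loss
    # Log a dict of named values; `step` pins the x-axis position.
    wandb.log({"loss": loss}, step=step)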
Example #1
Source File: base_disentangler.py    From disentanglement-pytorch with GNU General Public License v3.0
def visualize_recon(self, input_image, recon_image, test=False):
        input_image = torchvision.utils.make_grid(input_image)
        recon_image = torchvision.utils.make_grid(recon_image)

        if self.white_line is None:
            self.white_line = torch.ones((3, input_image.size(1), 10)).to(self.device)

        samples = torch.cat([input_image, self.white_line, recon_image], dim=2)

        if self.file_save:
            if test:
                file_name = os.path.join(self.test_output_dir, '{}_{}.{}'.format(c.RECON, self.iter, c.JPG))
            else:
                file_name = os.path.join(self.train_output_dir, '{}.{}'.format(c.RECON, c.JPG))
            torchvision.utils.save_image(samples, file_name)

        if self.use_wandb:
            import wandb
            wandb.log({c.RECON_IMAGE: wandb.Image(samples, caption=str(self.iter))},
                      step=self.iter) 
Example #2
Source File: distributed_logger.py    From rl_algorithms with MIT License
def run(self):
        """Run main logging loop; continuously receive data and log."""
        if self.args.log:
            self.set_wandb()

        while self.update_step < self.args.max_update_step:
            self.recv_log_info()
            if self.log_info_queue:  # if non-empty
                log_info_id = self.log_info_queue.pop()
                log_info = pa.deserialize(log_info_id)
                state_dict = log_info["state_dict"]
                log_value = log_info["log_value"]
                self.update_step = log_value["update_step"]

                self.synchronize(state_dict)
                avg_score = self.test(self.update_step)
                log_value["avg_score"] = avg_score
                self.write_log(log_value) 
Example #3
Source File: deprecated_callbacks.py    From NeMo with Apache License 2.0
def __init__(
        self, train_tensors=[], wandb_name=None, wandb_project=None, args=None, update_freq=25,
    ):
        """
        Args:
            train_tensors: list of tensors to evaluate and log based on training batches
            wandb_name: wandb experiment name
            wandb_project: wandb project name
            args: argparse flags - will be logged as hyperparameters
            update_freq: frequency with which to log updates
        """
        super().__init__()

        if not _WANDB_AVAILABLE:
            logging.error("Could not import wandb. Did you install it (pip install --upgrade wandb)?")

        self._update_freq = update_freq
        self._train_tensors = train_tensors
        self._name = wandb_name
        self._project = wandb_project
        self._args = args 
Example #4
Source File: agent.py    From rl_algorithms with MIT License
def write_log(self, log_value: tuple):
        i, score, policy_loss, value_loss = log_value
        total_loss = policy_loss + value_loss

        print(
            "[INFO] episode %d\tepisode step: %d\ttotal score: %d\n"
            "total loss: %.4f\tpolicy loss: %.4f\tvalue loss: %.4f\n"
            % (i, self.episode_step, score, total_loss, policy_loss, value_loss)
        )

        if self.args.log:
            wandb.log(
                {
                    "total loss": total_loss,
                    "policy loss": policy_loss,
                    "value loss": value_loss,
                    "score": score,
                }
            ) 
Example #5
Source File: agent.py    From rl_algorithms with MIT License
def write_log(
        self, log_value: tuple,
    ):
        i_episode, n_step, score, actor_loss, critic_loss, total_loss = log_value
        print(
            "[INFO] episode %d\tepisode steps: %d\ttotal score: %d\n"
            "total loss: %f\tActor loss: %f\tCritic loss: %f\n"
            % (i_episode, n_step, score, total_loss, actor_loss, critic_loss)
        )

        if self.args.log:
            wandb.log(
                {
                    "total loss": total_loss,
                    "actor loss": actor_loss,
                    "critic loss": critic_loss,
                    "score": score,
                }
            ) 
Example #6
Source File: label_preprocess.py    From atari-representation-learning with MIT License
def remove_duplicates(tr_eps, val_eps, test_eps, test_labels):
    """
    Remove any items in test_eps (and test_labels) which are also present in tr_eps or val_eps
    """
    flat_tr = list(chain.from_iterable(tr_eps))
    flat_val = list(chain.from_iterable(val_eps))
    tr_val_set = set([x.numpy().tostring() for x in flat_tr] + [x.numpy().tostring() for x in flat_val])
    flat_test = list(chain.from_iterable(test_eps))

    for i, episode in enumerate(test_eps[:]):
        test_labels[i] = [label for obs, label in zip(test_eps[i], test_labels[i]) if obs.numpy().tostring() not in tr_val_set]
        test_eps[i] = [obs for obs in episode if obs.numpy().tostring() not in tr_val_set]
    test_len = len(list(chain.from_iterable(test_eps)))
    dups = len(flat_test) - test_len
    print('Duplicates: {}, Test Len: {}'.format(dups, test_len))
    #wandb.log({'Duplicates': dups, 'Test Len': test_len})
    return test_eps, test_labels 
Example #7
Source File: callbacks.py    From NeMo with Apache License 2.0
def _wandb_log(tensors_logged, step):
        if _WANDB_AVAILABLE:
            wandb.log(tensors_logged, step=step) 
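One caveat worth noting about the step argument used in this wrapper: wandb requires step to be monotonically non-decreasing within a run, so values logged against a smaller step than the current one are dropped with a warning. A minimal sketch of that behavior (the project and metric names are placeholders):

import wandb

wandb.init(project="my-project")  # placeholder project name
wandb.log({"acc": 0.5}, step=10)
# A smaller step than the current one is rejected with a warning
# instead of being recorded.
wandb.log({"acc": 0.6}, step=5)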
Example #8
Source File: visualization.py    From tape with BSD 3-Clause "New" or "Revised" License
def log_config(self, config: typing.Dict[str, typing.Any]) -> None:
        logger.warn("Cannot log config when using a TBVisualizer. "
                    "Configure wandb for this functionality") 
Example #9
Source File: visualization.py    From tape with BSD 3-Clause "New" or "Revised" License
def log_metrics(self,
                    metrics_dict: typing.Dict[str, float],
                    split: str,
                    step: int):
        wandb.log({f"{split.capitalize()} {name.capitalize()}": value
                   for name, value in metrics_dict.items()}, step=step) 
Example #10
Source File: callbacks.py    From keras-rl with MIT License
def on_episode_end(self, episode, logs):
        """ Compute and log training statistics of the episode when done """
        duration = timeit.default_timer() - self.episode_start[episode]
        episode_steps = len(self.observations[episode])

        metrics = np.array(self.metrics[episode])
        metrics_dict = {}
        with warnings.catch_warnings():
            warnings.filterwarnings('error')
            for idx, name in enumerate(self.metrics_names):
                try:
                    metrics_dict[name] = np.nanmean(metrics[:, idx])
                except Warning:
                    metrics_dict[name] = float('nan')

        wandb.log({
            'step': self.step,
            'episode': episode + 1,
            'duration': duration,
            'episode_steps': episode_steps,
            'sps': float(episode_steps) / duration,
            'episode_reward': np.sum(self.rewards[episode]),
            'reward_mean': np.mean(self.rewards[episode]),
            'reward_min': np.min(self.rewards[episode]),
            'reward_max': np.max(self.rewards[episode]),
            'action_mean': np.mean(self.actions[episode]),
            'action_min': np.min(self.actions[episode]),
            'action_max': np.max(self.actions[episode]),
            'obs_mean': np.mean(self.observations[episode]),
            'obs_min': np.min(self.observations[episode]),
            'obs_max': np.max(self.observations[episode]),
            **metrics_dict
        })

        # Free up resources.
        del self.episode_start[episode]
        del self.observations[episode]
        del self.rewards[episode]
        del self.actions[episode]
        del self.metrics[episode] 
Example #11
Source File: logx.py    From firedup with MIT License
def log(self, msg, color="green"):
        """Print a colorized message to stdout."""
        if proc_id() == 0:
            print(colorize(msg, color, bold=True)) 
Example #12
Source File: __init__.py    From rtrl with MIT License
def run_wandb(entity, project, run_id, run_cls: type = Training, checkpoint_path: str = None):
  """run and save config and stats to https://wandb.com"""
  wandb_dir = mkdtemp()  # prevent wandb from polluting the home directory
  atexit.register(shutil.rmtree, wandb_dir, ignore_errors=True)  # clean up after wandb atexit handler finishes
  import wandb
  config = partial_to_dict(run_cls)
  config['seed'] = config['seed'] or randrange(1, 1000000)  # if seed == 0 replace with random
  config['environ'] = log_environment_variables()
  config['git'] = git_info()
  resume = checkpoint_path and exists(checkpoint_path)
  wandb.init(dir=wandb_dir, entity=entity, project=project, id=run_id, resume=resume, config=config)
  for stats in iterate_episodes(run_cls, checkpoint_path):
    [wandb.log(json.loads(s.to_json())) for s in stats] 
Example #13
Source File: logx.py    From firedup with MIT License
def save_state(self, state_dict, model, itr=None):
        """
        Saves the state of an experiment.

        To be clear: this is about saving *state*, not logging diagnostics.
        All diagnostic logging is separate from this function. This function
        will save whatever is in ``state_dict``---usually just a copy of the
        environment---and the most recent copy of the model via ``model``.

        Call with any frequency you prefer. If you only want to maintain a
        single state and overwrite it at each call with the most recent
        version, leave ``itr=None``. If you want to keep all of the states you
        save, provide unique (increasing) values for 'itr'.

        Args:
            state_dict (dict): Dictionary containing essential elements to
                describe the current state of training.
            model (nn.Module): A model which contains the policy.
            itr: An int, or None. Current iteration of training.
        """
        if proc_id() == 0:
            fname = "vars.pkl" if itr is None else "vars%d.pkl" % itr
            try:
                joblib.dump(state_dict, osp.join(self.output_dir, fname))
            except:
                self.log("Warning: could not pickle state_dict.", color="red")
            self._torch_save(model, itr) 
Example #14
Source File: logx.py    From firedup with MIT License
def dump_tabular(self):
        """
        Write all of the diagnostics from the current iteration.

        Writes both to stdout, and to the output file.
        """
        if proc_id() == 0:
            vals = []
            key_lens = [len(key) for key in self.log_headers]
            max_key_len = max(15, max(key_lens))
            keystr = "%" + "%d" % max_key_len
            fmt = "| " + keystr + "s | %15s |"
            n_slashes = 22 + max_key_len
            print("-" * n_slashes)
            for i, key in enumerate(self.log_headers):
                val = self.log_current_row.get(key, "")
                total_env_interacts = self.log_current_row.get("TotalEnvInteracts", "")
                wandb.log({key: val}, step=total_env_interacts, commit=(i+1)==len(self.log_headers))
                valstr = "%8.3g" % val if hasattr(val, "__float__") else val
                print(fmt % (key, valstr))
                vals.append(val)
            print("-" * n_slashes)
            if self.output_file is not None:
                if self.first_row:
                    self.output_file.write("\t".join(self.log_headers) + "\n")
                self.output_file.write("\t".join(map(str, vals)) + "\n")
                self.output_file.flush()
        self.log_current_row.clear()
        self.first_row = False 
Example #15
Source File: logx.py    From firedup with MIT License
def log_tabular(self, key, val=None, with_min_and_max=False, average_only=False):
        """
        Log a value or possibly the mean/std/min/max values of a diagnostic.

        Args:
            key (string): The name of the diagnostic. If you are logging a
                diagnostic whose state has previously been saved with
                ``store``, the key here has to match the key you used there.

            val: A value for the diagnostic. If you have previously saved
                values for this key via ``store``, do *not* provide a ``val``
                here.

            with_min_and_max (bool): If true, log min and max values of the
                diagnostic over the epoch.

            average_only (bool): If true, do not log the standard deviation
                of the diagnostic over the epoch.
        """
        if val is not None:
            super().log_tabular(key, val)
        else:
            v = self.epoch_dict[key]
            vals = (
                np.concatenate(v)
                if isinstance(v[0], np.ndarray) and len(v[0].shape) > 0
                else v
            )
            stats = mpi_statistics_scalar(vals, with_min_and_max=with_min_and_max)
            super().log_tabular(key if average_only else "Average" + key, stats[0])
            if not (average_only):
                super().log_tabular("Std" + key, stats[1])
            if with_min_and_max:
                super().log_tabular("Max" + key, stats[3])
                super().log_tabular("Min" + key, stats[2])
        self.epoch_dict[key] = [] 
Example #16
Source File: callbacks.py    From NeMo with Apache License 2.0
def on_action_start(self, state):
        if state["global_rank"] is None or state["global_rank"] == 0:
            if _WANDB_AVAILABLE and wandb.run is None:
                wandb.init(name=self._name, project=self._project)
                if self._args is not None:
                    wandb.config.update(self._args)
            elif _WANDB_AVAILABLE and wandb.run is not None:
                logging.info("Re-using wandb session")
            else:
                logging.error("Could not import wandb. Did you install it (pip install --upgrade wandb)?")
                logging.info("Will not log data to weights and biases.")
                self._step_freq = -1 
Example #17
Source File: callbacks.py    From NeMo with Apache License 2.0
def on_step_end(self, state):
        # log training metrics
        if state["global_rank"] is None or state["global_rank"] == 0:
            if state["step"] % self._step_freq == 0 and self._step_freq > 0:
                tensors_logged = {t: state["tensors"].get_tensor(t).cpu() for t in self._tensors_to_log}
                # Always log learning rate
                if self._log_lr:
                    tensors_logged['LR'] = state["optimizers"][0].param_groups[0]['lr']
                self._wandb_log(tensors_logged, state["step"]) 
Example #18
Source File: wandb_logger.py    From catalyst with Apache License 2.0
def on_stage_start(self, runner: IRunner):
        """Initialize Weights & Biases."""
        wandb.init(**self.logging_params, reinit=True, dir=str(runner.logdir))
        wandb.watch(
            models=runner.model, criterion=runner.criterion, log=self.log
        ) 
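As a usage note, wandb.watch() hooks the model so that gradients and/or parameters are recorded alongside the metrics sent through wandb.log(). A minimal sketch, assuming a PyTorch model (the model and settings here are illustrative):

import torch.nn as nn
import wandb

wandb.init(project="my-project")  # placeholder project name
model = nn.Linear(10, 1)  # stand-in for a real model
# log="all" records both gradients and parameters every
# `log_freq` batches once training starts calling backward().
wandb.watch(model, log="all", log_freq=100)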
Example #19
Source File: deprecated_callbacks.py    From NeMo with Apache License 2.0
def on_iteration_end(self):
        if self.global_rank is None or self.global_rank == 0:
            step = self.step
            if step % self._step_freq == 0:
                tensor_values = [self.registered_tensors[t.unique_name] for t in self.tensors]
                logging.info(f"Step: {step}")
                if self._print_func:
                    self._print_func(tensor_values)
                sys.stdout.flush()
                if self._swriter is not None:
                    if self._get_tb_values:
                        tb_objects = self._get_tb_values(tensor_values)
                        for name, value in tb_objects:
                            value = value.item()
                            self._swriter.add_scalar(name, value, step)
                    if self._log_to_tb_func:
                        self._log_to_tb_func(self._swriter, tensor_values, step)
                    run_time = time.time() - self._last_iter_start
                    self._swriter.add_scalar('misc/step_time', run_time, step)
                run_time = time.time() - self._last_iter_start
                logging.info(f"Step time: {run_time} seconds")

                # To keep support in line with the removal of learning rate logging from inside actions, log learning
                # rate to tensorboard. However, it now logs every self._step_freq steps as opposed to every step.
                if self._swriter is not None:
                    self._swriter.add_scalar('param/lr', self.learning_rate, step) 
Example #20
Source File: deprecated_callbacks.py    From NeMo with Apache License 2.0
def on_action_start(self):
        if self.global_rank is None or self.global_rank == 0:
            if self._wandb_name is not None or self._wandb_project is not None:
                if _WANDB_AVAILABLE and wandb.run is None:
                    wandb.init(name=self._wandb_name, project=self._wandb_project)
                elif _WANDB_AVAILABLE and wandb.run is not None:
                    logging.info("Re-using wandb session")
                else:
                    logging.error("Could not import wandb. Did you install it (pip install --upgrade wandb)?")
                    logging.info("Will not log data to weights and biases.")
                    self._wandb_name = None
                    self._wandb_project = None 
Example #21
Source File: deprecated_callbacks.py    From NeMo with Apache License 2.0
def on_action_start(self):
        if self.global_rank is None or self.global_rank == 0:
            if _WANDB_AVAILABLE and wandb.run is None:
                wandb.init(name=self._name, project=self._project)
                if self._args is not None:
                    logging.info('init wandb session and append args')
                    wandb.config.update(self._args)
            elif _WANDB_AVAILABLE and wandb.run is not None:
                logging.info("Re-using wandb session")
            else:
                logging.error("Could not import wandb. Did you install it (pip install --upgrade wandb)?")
                logging.info("Will not log data to weights and biases.")
                self._update_freq = -1 
Example #22
Source File: deprecated_callbacks.py    From NeMo with Apache License 2.0
def on_iteration_end(self):
        # log training metrics
        if self.global_rank is None or self.global_rank == 0:
            if self.step % self._update_freq == 0 and self._update_freq > 0:
                tensors_logged = {t.name: self.registered_tensors[t.unique_name].cpu() for t in self._train_tensors}
                # Always log learning rate
                tensors_logged['LR'] = self.learning_rate
                self.wandb_log(tensors_logged) 
Example #23
Source File: logger.py    From RLcycle with MIT License
def write_log(self, log_dict: dict, step: int = None):
        """Write to WandB log"""
        wandb.log(log_dict, step=step) 
Example #24
Source File: train.py    From catz with MIT License
def on_epoch_end(self, epoch, logs):
        validation_X, validation_y = next(
            my_generator(15, val_dir))
        output = self.model.predict(validation_X)
        wandb.log({
            "input": [wandb.Image(np.concatenate(np.split(c, 5, axis=2), axis=1)) for c in validation_X],
            "output": [wandb.Image(np.concatenate([validation_y[i], o], axis=1)) for i, o in enumerate(output)]
        }, commit=False) 
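The commit=False flag in this example is worth noting: it buffers the logged values without advancing the step, so they are written together with the next committing wandb.log() call. A minimal sketch of the pattern (project and metric names are placeholders):

import wandb

wandb.init(project="my-project")  # placeholder project name
# Buffered: nothing is written to the run history yet.
wandb.log({"val_accuracy": 0.9}, commit=False)
# This call commits both entries together on the same step.
wandb.log({"loss": 0.1})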
Example #25
Source File: deprecated_callbacks.py    From NeMo with Apache License 2.0
def on_epoch_end(self):
        if self.global_rank is None or self.global_rank == 0:
            # always log epoch num and epoch_time
            epoch_time = time.time() - self._last_epoch_start
            self.wandb_log({"epoch": self.epoch_num, "epoch_time": epoch_time}) 
Example #26
Source File: deprecated_callbacks.py    From NeMo with Apache License 2.0
def wandb_log(self, tensors_logged):
        if _WANDB_AVAILABLE:
            wandb.log(tensors_logged, step=self.step) 
Example #27
Source File: wandb_logger.py    From catalyst with Apache License 2.0
def _log_metrics(
        self,
        metrics: Dict[str, float],
        step: int,
        mode: str,
        suffix="",
        commit=True,
    ):
        if self.metrics_to_log is None:
            metrics_to_log = sorted(metrics.keys())
        else:
            metrics_to_log = self.metrics_to_log

        def key_locate(key: str):
            """
            Wandb reserves keys that start with ``_`` for its own service
            purposes, so we cannot send the original metric names unchanged.

            Args:
                key: metric name

            Returns:
                formatted metric name
            """
            if key.startswith("_"):
                return key[1:]
            return key

        metrics = {
            f"{key_locate(key)}/{mode}{suffix}": value
            for key, value in metrics.items()
            if key in metrics_to_log
        }
        wandb.log(metrics, step=step, commit=commit) 
Example #28
Source File: benchmark.py    From cherry with Apache License 2.0
def benchmark_log(original_log):
    def new_log(self, key, value):
        wandb.log({key: value}, step=self.num_steps)
        original_log(self, key, value)
    return new_log 
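To show how this wrapper is meant to be applied, here is a hypothetical monkey-patch of a logger class; the Logger class and its num_steps attribute are illustrative stand-ins for whatever cherry actually patches:

import wandb

class Logger:  # illustrative stand-in
    num_steps = 0

    def log(self, key, value):
        print(key, value)

wandb.init(project="my-project")  # placeholder project name
Logger.log = benchmark_log(Logger.log)  # wrap the original method
Logger().log("reward", 1.0)  # now also mirrored to wandb.log()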
Example #29
Source File: label_preprocess.py    From atari-representation-learning with MIT License
def remove_low_entropy_labels(episode_labels, entropy_threshold=0.3):
    flat_label_list = list(chain.from_iterable(episode_labels))
    counts = {}

    for label_dict in flat_label_list:
        for k in label_dict:
            counts[k] = counts.get(k, {})
            v = label_dict[k]
            counts[k][v] = counts[k].get(v, 0) + 1
    low_entropy_labels = []

    entropy_dict = {}
    for k in counts:
        entropy = torch.distributions.Categorical(
            torch.tensor([x / len(flat_label_list) for x in counts[k].values()])).entropy()
        entropy_dict['entropy_' + k] = entropy
        if entropy < entropy_threshold:
            print("Deleting {} for being too low in entropy! Sorry, dood!".format(k))
            low_entropy_labels.append(k)

    for e in episode_labels:
        for obs in e:
            for key in low_entropy_labels:
                del obs[key]
    # wandb.log(entropy_dict)
    return episode_labels, entropy_dict 
Example #30
Source File: pretrained_agents.py    From atari-representation-learning with MIT License
def get_pretrained_rl_representations(args, steps):
    checkpoint = checkpointed_steps_full_sorted[args.checkpoint_index]
    episodes, episode_labels, mean_reward = get_ppo_representations(args, steps, checkpoint)
    wandb.log({"reward": mean_reward, "checkpoint": checkpoint})
    return episodes, episode_labels