Python hparams.hparams.input_type() Examples
The following are 10 code examples showing how hparams.hparams.input_type is used. Each example notes the original project, source file, and license it comes from.
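The hparams.hparams module itself is not shown on this page. As a rough, hypothetical sketch (the field names are taken from the examples below, but the actual modules differ between projects and the defaults here are invented), it might look something like:

# Hypothetical sketch of an hparams module -- not any project's actual file.
class HParams:
    # WaveRNN-Pytorch uses 'raw', 'bits', 'mulaw' or 'mixture';
    # the WaveNet-style projects use 'raw', 'mulaw' or 'mulaw-quantize'.
    input_type = 'bits'
    bits = 9                        # used when input_type == 'bits'
    mulaw_quantize_channels = 256   # used for the mulaw variants
    # framing parameters referenced by the collate/preprocessing code
    sample_rate = 22050
    hop_size = 275
    seq_len = hop_size * 5

hparams = HParams()                 # the examples import this as `hp` or `hparams`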
Example #1
Source File: model.py From WaveRNN-Pytorch with MIT License
def build_model(): """build model with hparams settings """ if hp.input_type == 'raw': print('building model with Beta distribution output') elif hp.input_type == 'mixture': print("building model with mixture of logistic output") elif hp.input_type == 'bits': print("building model with quantized bit audio") elif hp.input_type == 'mulaw': print("building model with quantized mulaw encoding") else: raise ValueError('input_type provided not supported') model = Model(hp.rnn_dims, hp.fc_dims, hp.bits, hp.pad, hp.upsample_factors, hp.num_mels, hp.compute_dims, hp.res_out_dims, hp.res_blocks) return model
Example #2
Source File: preprocess.py From WaveRNN-Pytorch with MIT License
def get_wav_mel(path):
    """Given path to .wav file, get the quantized wav and mel spectrogram as numpy vectors"""
    wav = load_wav(path)
    mel = melspectrogram(wav)
    if hp.input_type == 'raw':
        return wav.astype(np.float32), mel
    elif hp.input_type == 'mulaw':
        quant = mulaw_quantize(wav, hp.mulaw_quantize_channels)
        return quant.astype(np.int), mel
    elif hp.input_type == 'bits':
        quant = quantize(wav)
        return quant.astype(np.int), mel
    else:
        raise ValueError("hp.input_type {} not recognized".format(hp.input_type))
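As an illustrative usage (the path is a placeholder), the dtype of the first return value depends on hp.input_type: a float32 waveform for 'raw', integer codes for 'bits' and 'mulaw':

quant, mel = get_wav_mel("wavs/example.wav")   # placeholder path
np.save("quant_example.npy", quant)
np.save("mel_example.npy", mel)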
Example #3
Source File: model.py From WaveRNN-Pytorch with MIT License
def __init__(self, rnn_dims, fc_dims, bits, pad, upsample_factors, feat_dims,
             compute_dims, res_out_dims, res_blocks):
    super().__init__()
    if hp.input_type == 'raw':
        self.n_classes = 2
    elif hp.input_type == 'mixture':
        # mixture requires a multiple of 3; default is a 10-component mixture, i.e. 3 x 10 = 30
        self.n_classes = 30
    elif hp.input_type == 'mulaw':
        self.n_classes = hp.mulaw_quantize_channels
    elif hp.input_type == 'bits':
        self.n_classes = 2**bits
    else:
        raise ValueError(f"input_type: {hp.input_type} not supported")
    self.rnn_dims = rnn_dims
    self.aux_dims = res_out_dims // 4
    self.upsample = UpsampleNetwork(feat_dims, upsample_factors, compute_dims,
                                    res_blocks, res_out_dims, pad)
    self.I = nn.Linear(feat_dims + self.aux_dims + 1, rnn_dims)
    self.rnn1 = nn.GRU(rnn_dims, rnn_dims, batch_first=True)
    self.rnn2 = nn.GRU(rnn_dims + self.aux_dims, rnn_dims, batch_first=True)
    self.fc1 = nn.Linear(rnn_dims + self.aux_dims, fc_dims)
    self.fc2 = nn.Linear(fc_dims + self.aux_dims, fc_dims)
    self.fc3 = nn.Linear(fc_dims, self.n_classes)
    num_params(self)
Example #4
Source File: train.py From representation_mixing with BSD 3-Clause "New" or "Revised" License
def build_model():
    if is_mulaw_quantize(hparams.input_type):
        if hparams.out_channels != hparams.quantize_channels:
            raise RuntimeError(
                "out_channels must equal quantize_channels if input_type is 'mulaw-quantize'")
    if hparams.upsample_conditional_features and hparams.cin_channels < 0:
        s = "Upsample conv layers were specified while local conditioning disabled. "
        s += "Notice that upsample conv layers will never be used."
        warn(s)

    model = getattr(builder, hparams.builder)(
        out_channels=hparams.out_channels,
        layers=hparams.layers,
        stacks=hparams.stacks,
        residual_channels=hparams.residual_channels,
        gate_channels=hparams.gate_channels,
        skip_out_channels=hparams.skip_out_channels,
        cin_channels=hparams.cin_channels,
        gin_channels=hparams.gin_channels,
        weight_normalization=hparams.weight_normalization,
        n_speakers=hparams.n_speakers,
        dropout=hparams.dropout,
        kernel_size=hparams.kernel_size,
        upsample_conditional_features=hparams.upsample_conditional_features,
        upsample_scales=hparams.upsample_scales,
        freq_axis_kernel_size=hparams.freq_axis_kernel_size,
        scalar_input=is_scalar_input(hparams.input_type),
        legacy=hparams.legacy,
    )
    return model
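The helpers is_mulaw_quantize, is_mulaw, and is_scalar_input are not defined on this page; in wavenet_vocoder-style code they are plain string checks on input_type, roughly like the sketch below (an approximation, not the project's exact source):

# Approximate sketch of the input_type predicates used above.
def is_mulaw_quantize(input_type):
    return input_type == "mulaw-quantize"

def is_mulaw(input_type):
    return input_type == "mulaw"

def is_raw(input_type):
    return input_type == "raw"

def is_scalar_input(input_type):
    # continuous-valued (scalar) input: raw or mu-law companded waveform
    return is_raw(input_type) or is_mulaw(input_type)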
Example #5
Source File: dataset.py From WaveRNN-Pytorch with MIT License
def discrete_collate(batch):
    """collate function used for discrete wav output, such as 9-bit, mulaw-discrete, etc."""
    pad = 2
    mel_win = hp.seq_len // hp.hop_size + 2 * pad
    max_offsets = [x[0].shape[-1] - (mel_win + 2 * pad) for x in batch]
    mel_offsets = [np.random.randint(0, offset) for offset in max_offsets]
    sig_offsets = [(offset + pad) * hp.hop_size for offset in mel_offsets]

    mels = [x[0][:, mel_offsets[i]:mel_offsets[i] + mel_win]
            for i, x in enumerate(batch)]
    coarse = [x[1][sig_offsets[i]:sig_offsets[i] + hp.seq_len + 1]
              for i, x in enumerate(batch)]

    mels = np.stack(mels).astype(np.float32)
    coarse = np.stack(coarse).astype(np.int64)
    mels = torch.FloatTensor(mels)
    coarse = torch.LongTensor(coarse)

    if hp.input_type == 'bits':
        x_input = 2 * coarse[:, :hp.seq_len].float() / (2**hp.bits - 1.) - 1.
    elif hp.input_type == 'mulaw':
        x_input = inv_mulaw_quantize(coarse[:, :hp.seq_len], hp.mulaw_quantize_channels)

    y_coarse = coarse[:, 1:]
    return x_input, mels, y_coarse
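A collate function like this is normally passed to a PyTorch DataLoader; the dataset and batch size below are placeholders:

# Illustrative usage only: `dataset` yields (mel, quantized_wav) pairs here.
from torch.utils.data import DataLoader

loader = DataLoader(dataset, batch_size=32, shuffle=True,
                    collate_fn=discrete_collate)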
Example #6
Source File: model.py From WaveRNN-Pytorch with MIT License
def forward(self, x, mels):
    bsize = x.size(0)
    h1 = torch.zeros(1, bsize, self.rnn_dims).cuda()
    h2 = torch.zeros(1, bsize, self.rnn_dims).cuda()
    mels, aux = self.upsample(mels)

    aux_idx = [self.aux_dims * i for i in range(5)]
    a1 = aux[:, :, aux_idx[0]:aux_idx[1]]
    a2 = aux[:, :, aux_idx[1]:aux_idx[2]]
    a3 = aux[:, :, aux_idx[2]:aux_idx[3]]
    a4 = aux[:, :, aux_idx[3]:aux_idx[4]]

    x = torch.cat([x.unsqueeze(-1), mels, a1], dim=2)
    x = self.I(x)
    res = x
    x, _ = self.rnn1(x, h1)
    x = x + res
    res = x
    x = torch.cat([x, a2], dim=2)
    x, _ = self.rnn2(x, h2)
    x = x + res

    x = torch.cat([x, a3], dim=2)
    x = F.relu(self.fc1(x))
    x = torch.cat([x, a4], dim=2)
    x = F.relu(self.fc2(x))
    x = self.fc3(x)

    if hp.input_type == 'raw':
        return x
    elif hp.input_type == 'mixture':
        return x
    elif hp.input_type == 'bits' or hp.input_type == 'mulaw':
        return F.log_softmax(x, dim=-1)
    else:
        raise ValueError(f"input_type: {hp.input_type} not supported")
Example #7
Source File: train.py From representation_mixing with BSD 3-Clause "New" or "Revised" License
def save_states(global_step, writer, y_hat, y, input_lengths, checkpoint_dir=None):
    print("Save intermediate states at step {}".format(global_step))
    idx = np.random.randint(0, len(y_hat))
    length = input_lengths[idx].data.cpu().item()

    # (B, C, T)
    if y_hat.dim() == 4:
        y_hat = y_hat.squeeze(-1)

    if is_mulaw_quantize(hparams.input_type):
        # (B, T)
        y_hat = F.softmax(y_hat, dim=1).max(1)[1]

        # (T,)
        y_hat = y_hat[idx].data.cpu().long().numpy()
        y = y[idx].view(-1).data.cpu().long().numpy()

        y_hat = P.inv_mulaw_quantize(y_hat, hparams.quantize_channels)
        y = P.inv_mulaw_quantize(y, hparams.quantize_channels)
    else:
        # (B, T)
        y_hat = sample_from_discretized_mix_logistic(
            y_hat, log_scale_min=hparams.log_scale_min)
        # (T,)
        y_hat = y_hat[idx].view(-1).data.cpu().numpy()
        y = y[idx].view(-1).data.cpu().numpy()

        if is_mulaw(hparams.input_type):
            y_hat = P.inv_mulaw(y_hat, hparams.quantize_channels)
            y = P.inv_mulaw(y, hparams.quantize_channels)

    # Mask by length
    y_hat[length:] = 0
    y[length:] = 0

    # Save audio
    audio_dir = join(checkpoint_dir, "audio")
    os.makedirs(audio_dir, exist_ok=True)
    path = join(audio_dir, "step{:09d}_predicted.wav".format(global_step))
    librosa.output.write_wav(path, y_hat, sr=hparams.sample_rate)
    path = join(audio_dir, "step{:09d}_target.wav".format(global_step))
    librosa.output.write_wav(path, y, sr=hparams.sample_rate)
Example #8
Source File: train.py From representation_mixing with BSD 3-Clause "New" or "Revised" License
def train_loop(device, model, data_loaders, optimizer, writer, checkpoint_dir=None):
    if is_mulaw_quantize(hparams.input_type):
        criterion = MaskedCrossEntropyLoss()
    else:
        criterion = DiscretizedMixturelogisticLoss()

    if hparams.exponential_moving_average:
        ema = ExponentialMovingAverage(hparams.ema_decay)
        for name, param in model.named_parameters():
            if param.requires_grad:
                ema.register(name, param.data)
    else:
        ema = None

    global global_step, global_epoch, global_test_step
    while global_epoch < hparams.nepochs:
        for phase, data_loader in data_loaders.items():
            train = (phase == "train")
            running_loss = 0.
            test_evaluated = False
            for step, (x, y, c, g, input_lengths) in tqdm(enumerate(data_loader)):
                # Whether to save eval (i.e., online decoding) result
                do_eval = False
                eval_dir = join(checkpoint_dir, "{}_eval".format(phase))
                # Do eval per eval_interval for train
                if train and global_step > 0 \
                        and global_step % hparams.train_eval_interval == 0:
                    do_eval = True
                # Do eval for test
                # NOTE: Decoding WaveNet is quite time consuming, so
                # do only once in a single epoch for testset
                if not train and not test_evaluated \
                        and global_epoch % hparams.test_eval_epoch_interval == 0:
                    do_eval = True
                    test_evaluated = True
                if do_eval:
                    print("[{}] Eval at train step {}".format(phase, global_step))

                # Do step
                running_loss += __train_step(device,
                                             phase, global_epoch, global_step, global_test_step,
                                             model, optimizer, writer, criterion, x, y, c, g,
                                             input_lengths, checkpoint_dir, eval_dir, do_eval, ema)

                # update global state
                if train:
                    global_step += 1
                else:
                    global_test_step += 1

            # log per epoch
            averaged_loss = running_loss / len(data_loader)
            writer.add_scalar("{} loss (per epoch)".format(phase),
                              averaged_loss, global_epoch)
            print("Step {} [{}] Loss: {}".format(
                global_step, phase, running_loss / len(data_loader)))

        global_epoch += 1
Example #9
Source File: ljspeech.py From representation_mixing with BSD 3-Clause "New" or "Revised" License
def _process_utterance(out_dir, index, wav_path, text):
    # Load the audio to a numpy array:
    wav = audio.load_wav(wav_path)

    if hparams.rescaling:
        wav = wav / np.abs(wav).max() * hparams.rescaling_max

    # Mu-law quantize
    if is_mulaw_quantize(hparams.input_type):
        # [0, quantize_channels)
        out = P.mulaw_quantize(wav, hparams.quantize_channels)

        # Trim silences
        start, end = audio.start_and_end_indices(out, hparams.silence_threshold)
        wav = wav[start:end]
        out = out[start:end]
        constant_values = P.mulaw_quantize(0, hparams.quantize_channels)
        out_dtype = np.int16
    elif is_mulaw(hparams.input_type):
        # [-1, 1]
        out = P.mulaw(wav, hparams.quantize_channels)
        constant_values = P.mulaw(0.0, hparams.quantize_channels)
        out_dtype = np.float32
    else:
        # [-1, 1]
        out = wav
        constant_values = 0.0
        out_dtype = np.float32

    # Compute a mel-scale spectrogram from the trimmed wav:
    # (N, D)
    mel_spectrogram = audio.melspectrogram(wav).astype(np.float32).T

    # lws pads zeros internally before performing stft
    # this is needed to adjust time resolution between audio and mel-spectrogram
    l, r = audio.lws_pad_lr(wav, hparams.fft_size, audio.get_hop_size())

    # zero pad for quantized signal
    out = np.pad(out, (l, r), mode="constant", constant_values=constant_values)
    N = mel_spectrogram.shape[0]
    assert len(out) >= N * audio.get_hop_size()

    # time resolution adjustment
    # ensure length of raw audio is multiple of hop_size so that we can use
    # transposed convolution to upsample
    out = out[:N * audio.get_hop_size()]
    assert len(out) % audio.get_hop_size() == 0

    timesteps = len(out)

    # Write the spectrograms to disk:
    audio_filename = 'ljspeech-audio-%05d.npy' % index
    mel_filename = 'ljspeech-mel-%05d.npy' % index
    np.save(os.path.join(out_dir, audio_filename),
            out.astype(out_dtype), allow_pickle=False)
    np.save(os.path.join(out_dir, mel_filename),
            mel_spectrogram.astype(np.float32), allow_pickle=False)

    # Return a tuple describing this training example:
    return (audio_filename, mel_filename, timesteps, text)
Example #10
Source File: train.py From WaveRNN-Pytorch with MIT License
def train_loop(device, model, data_loader, optimizer, checkpoint_dir):
    """Main training loop."""
    # create loss and put on device
    if hp.input_type == 'raw':
        if hp.distribution == 'beta':
            criterion = beta_mle_loss
        elif hp.distribution == 'gaussian':
            criterion = gaussian_loss
    elif hp.input_type == 'mixture':
        criterion = discretized_mix_logistic_loss
    elif hp.input_type in ["bits", "mulaw"]:
        criterion = nll_loss
    else:
        raise ValueError("input_type:{} not supported".format(hp.input_type))

    global global_step, global_epoch, global_test_step
    while global_epoch < hp.nepochs:
        running_loss = 0
        for i, (x, m, y) in enumerate(tqdm(data_loader)):
            x, m, y = x.to(device), m.to(device), y.to(device)
            y_hat = model(x, m)
            y = y.unsqueeze(-1)
            loss = criterion(y_hat, y)

            # calculate learning rate and update learning rate
            if hp.fix_learning_rate:
                current_lr = hp.fix_learning_rate
            elif hp.lr_schedule_type == 'step':
                current_lr = step_learning_rate_decay(hp.initial_learning_rate, global_step,
                                                      hp.step_gamma, hp.lr_step_interval)
            else:
                current_lr = noam_learning_rate_decay(hp.initial_learning_rate, global_step,
                                                      hp.noam_warm_up_steps)
            for param_group in optimizer.param_groups:
                param_group['lr'] = current_lr

            optimizer.zero_grad()
            loss.backward()
            # clip gradient norm
            nn.utils.clip_grad_norm_(model.parameters(), hp.grad_norm)
            optimizer.step()

            running_loss += loss.item()
            avg_loss = running_loss / (i + 1)

            # saving checkpoint if needed
            if global_step != 0 and global_step % hp.save_every_step == 0:
                save_checkpoint(device, model, optimizer, global_step, checkpoint_dir, global_epoch)

            # evaluate model if needed
            if global_step != 0 and global_test_step != True \
                    and global_step % hp.evaluate_every_step == 0:
                print("step {}, evaluating model: generating wav from mel...".format(global_step))
                evaluate_model(model, data_loader, checkpoint_dir)
                print("evaluation finished, resuming training...")

            # reset global_test_step status after evaluation
            if global_test_step is True:
                global_test_step = False
            global_step += 1

        print("epoch:{}, running loss:{}, average loss:{}, current lr:{}".format(
            global_epoch, running_loss, avg_loss, current_lr))
        global_epoch += 1