Python tqdm.tqdm.write() Examples

The following are 30 code examples of tqdm.tqdm.write(), collected from open-source projects; the source file and originating project are noted above each example. tqdm.write() prints a message without corrupting any active progress bars, which makes it the safe replacement for print() inside a tqdm loop.
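As a minimal illustrative sketch (not taken from any of the projects below), the basic pattern looks like this:

import time
from tqdm import tqdm

for i in tqdm(range(100), desc="Processing"):
    time.sleep(0.01)  # stand-in for real work
    if i % 25 == 0:
        # The bar is cleared, the message printed, and the bar redrawn,
        # so the output stays clean.
        tqdm.write("Reached item {}".format(i))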
Example #1
Source File: logger.py    From faceswap with GNU General Public License v3.0
def stream_handler(loglevel, is_gui):
    """ Add a logging cli handler """
    # Don't set stdout to lower than verbose
    loglevel = max(loglevel, 15)
    log_format = FaceswapFormatter("%(asctime)s %(levelname)-8s %(message)s",
                                   datefmt="%m/%d/%Y %H:%M:%S")

    if is_gui:
        # tqdm.write inserts extra lines in the GUI, and it is not needed
        # there, so use a plain standard-output handler instead.
        log_console = logging.StreamHandler(sys.stdout)
    else:
        log_console = TqdmHandler(sys.stdout)
    log_console.setFormatter(log_format)
    log_console.setLevel(loglevel)
    return log_console 
Example #2
Source File: run.py    From MobileNetV2-pytorch with MIT License
def test(model, loader, criterion, device, dtype):
    model.eval()
    test_loss = 0
    correct1, correct5 = 0, 0

    for batch_idx, (data, target) in enumerate(tqdm(loader)):
        data, target = data.to(device=device, dtype=dtype), target.to(device=device)
        with torch.no_grad():
            output = model(data)
            test_loss += criterion(output, target).item()  # sum up batch loss
            corr = correct(output, target, topk=(1, 5))
        correct1 += corr[0]
        correct5 += corr[1]

    test_loss /= len(loader)

    tqdm.write(
        '\nTest set: Average loss: {:.4f}, Top1: {}/{} ({:.2f}%), '
        'Top5: {}/{} ({:.2f}%)'.format(test_loss, int(correct1), len(loader.dataset),
                                       100. * correct1 / len(loader.dataset), int(correct5),
                                       len(loader.dataset), 100. * correct5 / len(loader.dataset)))
    return test_loss, correct1 / len(loader.dataset), correct5 / len(loader.dataset) 
Example #3
Source File: opensubsdata.py    From deepQA with Apache License 2.0
def loadConversations(self, dirName):
        """
        Args:
            dirName (str): folder to load
        Return:
            array(question, answer): the extracted QA pairs
        """
        conversations = []
        dirList = self.filesInDir(dirName)
        for filepath in tqdm(dirList, "OpenSubtitles data files"):
            if filepath.endswith('gz'):
                try:
                    doc = self.getXML(filepath)
                    conversations.extend(self.genList(doc))
                except ValueError:
                    tqdm.write("Skipping file %s with errors." % filepath)
                except:
                    print("Unexpected error:", sys.exc_info()[0])
                    raise
        return conversations 
Example #4
Source File: callbacks.py    From ngraph-python with Apache License 2.0
def __call__(self, transformer, callback_data, phase, data, idx):
        if phase == CallbackPhase.train_pre_:
            self.total_iterations = callback_data['config'].attrs['total_iterations']
            num_intervals = self.total_iterations // self.frequency
            for loss_name in self.interval_loss_comp.output_keys:
                callback_data.create_dataset("cost/{}".format(loss_name), (num_intervals,))
            callback_data.create_dataset("time/loss", (num_intervals,))
        elif phase == CallbackPhase.train_post:
            losses = loop_eval(self.dataset, self.interval_loss_comp)
            tqdm.write("Training complete.  Avg losses: {}".format(losses))
        elif phase == CallbackPhase.minibatch_post and ((idx + 1) % self.frequency == 0):
            start_loss = default_timer()
            interval_idx = idx // self.frequency

            losses = loop_eval(self.dataset, self.interval_loss_comp)

            for loss_name, loss in losses.items():
                callback_data["cost/{}".format(loss_name)][interval_idx] = loss

            callback_data["time/loss"][interval_idx] = (default_timer() - start_loss)
            tqdm.write("Interval {} Iteration {} complete.  Avg losses: {}".format(
                interval_idx + 1, idx + 1, losses)) 
Example #5
Source File: pascal_voc.py    From CIOD with MIT License
def gt_roidb(self):
        """
        Return the database of ground-truth regions of interest.

        This function loads/saves from/to a cache file to speed up future calls.
        """
        cache_file = os.path.join(self.cache_path, self.name + '_gt_roidb.pkl')
        if os.path.exists(cache_file):
            os.remove(cache_file)

        gt_roidb = [self._load_pascal_annotation(index)
                    for index in self.image_index]
        with open(cache_file, 'wb') as fid:
            pickle.dump(gt_roidb, fid, pickle.HIGHEST_PROTOCOL)
        tqdm.write('wrote gt roidb to {}'.format(cache_file))

        return gt_roidb 
Example #6
Source File: preference_learning.py    From ICML2019-TREX with MIT License
def train_with_dataset(self,dataset,batch_size,include_action=False,iter=10000,l2_reg=0.01,debug=False):
        sess = tf.get_default_session()

        for it in tqdm(range(iter),dynamic_ncols=True):
            b_x,b_y,x_split,y_split,b_l = dataset.batch(batch_size=batch_size,include_action=include_action)
            loss,l2_loss,acc,_ = sess.run([self.loss,self.l2_loss,self.acc,self.update_op],feed_dict={
                self.x:b_x,
                self.y:b_y,
                self.x_split:x_split,
                self.y_split:y_split,
                self.l:b_l,
                self.l2_reg:l2_reg,
            })

            if debug:
                if it % 100 == 0 or it < 10:
                    tqdm.write(('loss: %f (l2_loss: %f), acc: %f'%(loss,l2_loss,acc))) 
Example #7
Source File: tree.py    From pymerkle with GNU General Public License v3.0
def loadFromFile(cls, file_path):
        """
        Loads a Merkle-tree from the provided file, the latter being the result
        of an export (cf. the *MerkleTree.export()* method)

        :param file_path: relative path of the file to load from with
                respect to the current working directory
        :type file_path: str
        :returns: The tree loaded from the provided file
        :rtype: MerkleTree

        :raises WrongJSONFormat: if the JSON object loaded from within the
                    provided file is not a Merkle-tree export
        """
        with open(file_path, 'r') as __file:
            loaded_object = json.load(__file)
        try:
            header = loaded_object['header']
            tree = cls(
                hash_type=header['hash_type'],
                encoding=header['encoding'],
                raw_bytes=header['raw_bytes'],
                security=header['security'])
        except KeyError:
            raise WrongJSONFormat

        tqdm.write('\nFile has been loaded')
        update = tree.update
        for hash in tqdm(loaded_object['hashes'], desc='Retrieving tree...'):
            update(digest=hash)
        tqdm.write('Tree has been retrieved')
        return tree
Example #8
Source File: test_capsnet.py    From Pytorch-CapsuleNet with MIT License
def test(capsule_net, test_loader, epoch):
    capsule_net.eval()
    test_loss = 0
    correct = 0
    for batch_id, (data, target) in enumerate(test_loader):

        target = torch.sparse.torch.eye(10).index_select(dim=0, index=target)
        data, target = Variable(data), Variable(target)

        if USE_CUDA:
            data, target = data.cuda(), target.cuda()

        output, reconstructions, masked = capsule_net(data)
        loss = capsule_net.loss(data, output, target, reconstructions)

        test_loss += loss.data[0]  # legacy PyTorch (<0.4) idiom; use loss.item() on modern versions
        correct += sum(np.argmax(masked.data.cpu().numpy(), 1) ==
                       np.argmax(target.data.cpu().numpy(), 1))

    tqdm.write(
        "Epoch: [{}/{}], test accuracy: {:.6f}, loss: {:.6f}".format(
            epoch, N_EPOCHS, correct / len(test_loader.dataset),
            test_loss / len(test_loader))) 
Example #9
Source File: logger.py    From faceswap with GNU General Public License v3.0
def crash_log():
    """ Write debug_buffer to a crash log on crash """
    original_traceback = traceback.format_exc()
    path = os.path.dirname(os.path.realpath(sys.argv[0]))
    filename = os.path.join(path, datetime.now().strftime("crash_report.%Y.%m.%d.%H%M%S%f.log"))
    freeze_log = list(debug_buffer)
    try:
        from lib.sysinfo import sysinfo  # pylint:disable=import-outside-toplevel
    except Exception:  # pylint:disable=broad-except
        sysinfo = ("\n\nThere was an error importing System Information from lib.sysinfo. This is "
                   "probably a bug which should be fixed:\n{}".format(traceback.format_exc()))
    with open(filename, "w") as outfile:
        outfile.writelines(freeze_log)
        outfile.write(original_traceback)
        outfile.write(sysinfo)
    return filename 
Example #10
Source File: convert.py    From faceswap with GNU General Public License v3.0
def _check_alignments(self, frame_name):
        """ Ensure that we have alignments for the current frame.

        If we have no alignments for this image, skip it and output a message.

        Parameters
        ----------
        frame_name: str
            The name of the frame to check that we have alignments for

        Returns
        -------
        bool
            ``True`` if we have alignments for this face, otherwise ``False``
        """
        have_alignments = self._alignments.frame_exists(frame_name)
        if not have_alignments:
            tqdm.write("No alignment found for {}, "
                       "skipping".format(frame_name))
        return have_alignments 
Example #11
Source File: download.py    From open-images-downloader with MIT License
def download_objects_of_interest(download_list):
    def fetch_url(url):
        try:
            urllib.request.urlretrieve(url, os.path.join(OUTPUT_DIR, url.split("/")[-1]))
            return url, None
        except Exception as e:
            return None, e

    start = timer()
    results = ThreadPool(20).imap_unordered(fetch_url, download_list)

    df_pbar = tqdm(total=len(download_list), position=1, desc="Download %: ")

    for url, error in results:
        df_pbar.update(1)
        if error is None:
            pass  # TODO: find a way to do tqdm.write() with a refresh
            # print("{} fetched in {}s".format(url, timer() - start), end='\r')
        else:
            pass  # TODO: find a way to do tqdm.write() with a refresh
            # print("error fetching {}: {}".format(url, error), end='\r') 
Example #12
Source File: calc_dataloader_stats.py    From margipose with Apache License 2.0
def calculate_stats(stats, opts):
    model_desc = Default_MargiPose_Desc
    model = create_model(model_desc)
    skeleton = CanonicalSkeletonDesc
    loader = create_train_dataloader(
        [opts.dataset], model.data_specs, opts.batch_size, opts.examples_per_epoch, False)
    loader.dataset.without_image = not opts.with_image
    for epoch in range(opts.epochs):
        for batch in tqdm(loader, total=len(loader), leave=False, ascii=True):
            joints_3d = np.asarray(batch['target'])
            stats['root_x'].add_samples(joints_3d[:, skeleton.root_joint_id, 0])
            stats['root_y'].add_samples(joints_3d[:, skeleton.root_joint_id, 1])
            stats['root_z'].add_samples(joints_3d[:, skeleton.root_joint_id, 2])
            stats['lankle_x'].add_samples(joints_3d[:, skeleton.joint_names.index('left_ankle'), 0])
            stats['lankle_y'].add_samples(joints_3d[:, skeleton.joint_names.index('left_ankle'), 1])
            stats['lankle_z'].add_samples(joints_3d[:, skeleton.joint_names.index('left_ankle'), 2])
            if opts.with_image:
                image = np.asarray(batch['input'])
                stats['red'].add_samples(image[:, 0].ravel())
                stats['green'].add_samples(image[:, 1].ravel())
                stats['blue'].add_samples(image[:, 2].ravel())
            stats['index'].add_samples(np.asarray(batch['index'], dtype=np.float32) / (len(loader.dataset) - 1))
        tqdm.write(f'Epoch {epoch + 1:3d}')
        tqdm.write(repr(stats))
    tqdm.write('Done.') 
Example #13
Source File: logging.py    From flambe with MIT License
def colorize_exceptions() -> None:
    """Colorizes the system stderr ouput using pygments if installed"""
    try:
        import traceback
        from pygments import highlight
        from pygments.lexers import get_lexer_by_name
        from pygments.formatters import TerminalFormatter

        def colorized_excepthook(type_: Type[BaseException],
                                 value: BaseException,
                                 tb: TracebackType) -> None:
            tbtext = ''.join(traceback.format_exception(type_, value, tb))
            lexer = get_lexer_by_name("pytb", stripall=True)
            formatter = TerminalFormatter()
            sys.stderr.write(highlight(tbtext, lexer, formatter))

        sys.excepthook = colorized_excepthook  # type: ignore

    except ModuleNotFoundError:
        pass 
Example #14
Source File: viz.py    From PVN3D with MIT License
def flush(self):
        if len(self.flush_vals) == 0:
            return

        longest_win_name = max(map(lambda k: len(k), self.flush_vals.keys()))

        tqdm.write("=== Training Progress ===")

        for win, lines in self.flush_vals.items():
            if len(lines) == 0:
                continue

            _str = "{:<{width}} --- ".format(win, width=longest_win_name)
            for k, v in lines.items():
                _str += "{}: {:.4f}\t".format(k, v)

            tqdm.write(_str)

        tqdm.write(" ")
        tqdm.write(" ")

        self.flush_vals = collections.OrderedDict() 
Example #15
Source File: run.py    From MobileNetV3-pytorch with MIT License
def test(model, loader, criterion, device, dtype, child):
    model.eval()
    test_loss = 0
    correct1, correct5 = 0, 0

    enum_load = enumerate(loader) if child else enumerate(tqdm(loader))

    with torch.no_grad():
        for batch_idx, (data, target) in enum_load:
            data, target = data.to(device=device, dtype=dtype), target.to(device=device)
            output = model(data)
            test_loss += criterion(output, target).item()  # sum up batch loss
            corr = correct(output, target, topk=(1, 5))
            correct1 += corr[0]
            correct5 += corr[1]

    test_loss /= len(loader)
    if not child:
        tqdm.write(
            '\nTest set: Average loss: {:.4f}, Top1: {}/{} ({:.2f}%), '
            'Top5: {}/{} ({:.2f}%)'.format(test_loss, int(correct1), len(loader.sampler),
                                           100. * correct1 / len(loader.sampler), int(correct5),
                                           len(loader.sampler), 100. * correct5 / len(loader.sampler)))
    return test_loss, correct1 / len(loader.sampler), correct5 / len(loader.sampler) 
Example #16
Source File: main.py    From transferlearning with MIT License
def test(model, data_tar, e):
    total_loss_test = 0
    correct = 0
    criterion = nn.CrossEntropyLoss()
    with torch.no_grad():
        for batch_id, (data, target) in enumerate(data_tar):
            data, target = data.view(-1, 28 * 28).to(DEVICE), target.to(DEVICE)
            model.eval()
            ypred, _, _ = model(data, data)
            loss = criterion(ypred, target)
            pred = ypred.data.max(1)[1]  # get the index of the max log-probability
            correct += pred.eq(target.data.view_as(pred)).cpu().sum()
            total_loss_test += loss.data
        accuracy = correct * 100. / len(data_tar.dataset)
        res = 'Test: total loss: {:.6f}, correct: [{}/{}], testing accuracy: {:.4f}%'.format(
            total_loss_test, correct, len(data_tar.dataset), accuracy
        )
    tqdm.write(res)
    RESULT_TEST.append([e, total_loss_test, accuracy])
    log_test.write(res + '\n') 
Example #17
Source File: logging.py    From flambe with MIT License
def write(self, x: AnyStr) -> int:
        # Avoid print() second call (useless \n)
        if len(x.rstrip()) > 0:
            return tqdm.write(x, file=self.file)
        return 0 
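Example #17 shows only the write() method of a file-like wrapper (Example #25 below is a near-identical one). The usual companion, following the redirection recipe in tqdm's documentation, is a context manager that swaps sys.stdout for such a wrapper so that ordinary print() calls are routed through tqdm.write(). In this sketch the class and function names are illustrative, not flambe's:

import contextlib
import sys
from tqdm import tqdm

class TqdmFile:
    """File-like object that forwards writes to tqdm.write()."""

    def __init__(self, file):
        self.file = file

    def write(self, x):
        # Avoid print()'s separate write of the trailing newline
        if len(x.rstrip()) > 0:
            tqdm.write(x, file=self.file)

    def flush(self):
        return getattr(self.file, "flush", lambda: None)()

@contextlib.contextmanager
def redirect_stdout_to_tqdm():
    original = sys.stdout
    try:
        sys.stdout = TqdmFile(original)
        yield original
    finally:
        sys.stdout = original

# Usage: the bar writes to the real stdout, while print() calls
# inside the loop are rerouted above it.
with redirect_stdout_to_tqdm() as stdout:
    for i in tqdm(range(10), file=stdout):
        print("step {}".format(i))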
Example #18
Source File: logger.py    From faceswap with GNU General Public License v3.0
def write(self, buffer):
        """ Write line to buffer """
        for line in buffer.rstrip().splitlines():
            self.append(line + "\n") 
Example #19
Source File: dptrp1.py    From dpt-rp1-py with MIT License
def download_file(self, remote_path, local_path):
        local_folder = os.path.dirname(local_path)
        # Make sure that local_folder exists so that we can write data there.
        # If local_path is just a filename, local_folder will be '', and
        # we won't need to create any directories.
        if local_folder != "":
            os.makedirs(os.path.dirname(local_path), exist_ok=True)
        data = self.download(remote_path)
        with open(local_path, "wb") as f:
            f.write(data) 
Example #20
Source File: util.py    From MADAN with MIT License
def emit(self, record):
        msg = self.format(record)
        tqdm.write(msg) 
Example #21
Source File: logger.py    From faceswap with GNU General Public License v3.0
def emit(self, record):
        msg = self.format(record)
        tqdm.write(msg) 
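Examples #20 and #21 show only the emit() method. The enclosing class (the TqdmHandler referenced in Example #1) subclasses logging.StreamHandler; a minimal self-contained sketch, with the wiring assumed rather than copied from either project:

import logging
import sys
from tqdm import tqdm

class TqdmHandler(logging.StreamHandler):
    """Logging handler that emits records via tqdm.write() so that
    log lines appear above any active progress bar."""

    def emit(self, record):
        msg = self.format(record)
        tqdm.write(msg)

logger = logging.getLogger(__name__)
handler = TqdmHandler(sys.stdout)
handler.setFormatter(logging.Formatter("%(asctime)s %(levelname)-8s %(message)s"))
logger.addHandler(handler)
logger.setLevel(logging.INFO)

for i in tqdm(range(50)):
    if i == 25:
        logger.info("halfway there")  # printed above the bar, bar stays intact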
Example #22
Source File: trees.py    From iffse with MIT License
def build_annoy_tree(facial_embeddings, tree_path,
                    annoy_metric='euclidean', annoy_trees_no=256):
    """
    Builds an annoy tree

    Args:
        facial_embeddings: List of facial embeddings to be indexed in tree
        tree_path: where the annoy tree will be saved
        annoy_metric: euclidean / angular
        annoy_trees_no: how many trees in the annoy forest? Larger = more accurate
    """

    # Annoy tree
    tree = AnnoyIndex(128, metric=annoy_metric)

    # Don't wanna store entire db into memory
    for idx, f in enumerate(tqdm(facial_embeddings)):
        # SQLite errors sometimes?
        try:
            cur_np = string_to_np(f.latent_space)

            tree.add_item(idx, cur_np)

        except Exception as e:
            tqdm.write(str(e))

    tree.build(annoy_trees_no)
    tree.save(tree_path) 
Example #23
Source File: train_source.py    From pytorch-domain-adaptation with MIT License
def main(args):
    train_loader, val_loader = create_dataloaders(args.batch_size)

    model = Net().to(device)
    optim = torch.optim.Adam(model.parameters())
    lr_schedule = torch.optim.lr_scheduler.ReduceLROnPlateau(optim, patience=1, verbose=True)
    criterion = torch.nn.CrossEntropyLoss()

    best_accuracy = 0
    for epoch in range(1, args.epochs+1):
        model.train()
        train_loss, train_accuracy = do_epoch(model, train_loader, criterion, optim=optim)

        model.eval()
        with torch.no_grad():
            val_loss, val_accuracy = do_epoch(model, val_loader, criterion, optim=None)

        tqdm.write(f'EPOCH {epoch:03d}: train_loss={train_loss:.4f}, train_accuracy={train_accuracy:.4f} '
                   f'val_loss={val_loss:.4f}, val_accuracy={val_accuracy:.4f}')

        if val_accuracy > best_accuracy:
            print('Saving model...')
            best_accuracy = val_accuracy
            torch.save(model.state_dict(), 'trained_models/source.pt')

        lr_schedule.step(val_loss) 
Example #24
Source File: bow_trainer.py    From hedwig with Apache License 2.0
def train(self):
        train_data = StreamingSparseDataset(self.train_features, self.train_labels)
        train_dataloader = DataLoader(train_data, shuffle=True, batch_size=self.args.batch_size)

        print("Number of examples: ", len(self.train_labels))
        print("Batch size:", self.args.batch_size)

        for epoch in trange(int(self.args.epochs), desc="Epoch"):
            self.train_epoch(train_dataloader)
            dev_evaluator = BagOfWordsEvaluator(self.model, self.vectorizer, self.processor, self.args, split='dev')
            dev_acc, dev_precision, dev_recall, dev_f1, dev_loss = dev_evaluator.get_scores()[0]

            # Print validation results
            tqdm.write(self.log_header)
            tqdm.write(self.log_template.format(epoch + 1, self.nb_train_steps, epoch + 1, self.args.epochs,
                                                dev_acc, dev_precision, dev_recall, dev_f1, dev_loss))

            # Update validation results
            if dev_f1 > self.best_dev_f1:
                self.unimproved_iters = 0
                self.best_dev_f1 = dev_f1
                torch.save(self.model, self.snapshot_path)
            else:
                self.unimproved_iters += 1
                if self.unimproved_iters >= self.args.patience:
                    self.early_stop = True
                    tqdm.write("Early Stopping. Epoch: {}, Best Dev F1: {}".format(epoch, self.best_dev_f1))
                    break 
Example #25
Source File: tqdm.py    From CornerNet-Lite with BSD 3-Clause "New" or "Revised" License
def write(self, x):
        if len(x.rstrip()) > 0:
            tqdm.write(x, file=self.dummy_file) 
Example #26
Source File: progress.py    From gandissect with MIT License
def print_progress(*args):
    '''
    When within a progress loop, print_progress(*args) prints the
    given arguments via tqdm.write() so that the message appears
    above the progress bar instead of corrupting it.  Falls back to
    plain print() if tqdm is unavailable; does nothing unless
    default_verbosity is set.
    '''
    if default_verbosity:
        printfn = print if tqdm is None else tqdm.write
        printfn(' '.join(str(s) for s in args)) 
Example #27
Source File: solver.py    From neural_chat with MIT License
def generate_sentence(self, sentences, sentence_length,
                          input_conversation_length, input_sentences,
                          target_sentences, extra_context_inputs=None):
        """Generate output of decoder (single batch)"""
        self.model.eval()

        # [batch_size, max_seq_len, vocab_size]
        preds = self.model(
            sentences,
            sentence_length,
            input_conversation_length,
            target_sentences,
            decode=True,
            extra_context_inputs=extra_context_inputs)
        generated_sentences = preds[0]

        # write output to file
        with open(os.path.join(self.config.save_path, 'samples.txt'), 'a') as f:
            f.write(f'<Epoch {self.epoch_i}>\n\n')

            tqdm.write('\n<Samples>')
            for input_sent, target_sent, output_sent in zip(
                    input_sentences, target_sentences, generated_sentences):
                input_sent = self.vocab.decode(input_sent)
                target_sent = self.vocab.decode(target_sent)
                output_sent = '\n'.join([self.vocab.decode(sent) for sent in output_sent])
                s = '\n'.join(['Input sentence: ' + input_sent,
                               'Ground truth: ' + target_sent,
                               'Generated response: ' + output_sent + '\n'])
                f.write(s + '\n')
                print(s)
            print('') 
Example #28
Source File: solver.py    From neural_chat with MIT License
def generate_sentence(self, input_sentences, input_sentence_length,
                          input_conversation_length, target_sentences,
                          extra_context_inputs=None):
        self.model.eval()

        # [batch_size, max_seq_len, vocab_size]
        preds = self.model(
            input_sentences,
            input_sentence_length,
            input_conversation_length,
            target_sentences,
            decode=True,
            extra_context_inputs=extra_context_inputs)
        generated_sentences = preds[0]

        # write output to file
        with open(os.path.join(self.config.save_path, 'samples.txt'), 'a') as f:
            f.write(f'<Epoch {self.epoch_i}>\n\n')

            tqdm.write('\n<Samples>')
            for input_sent, target_sent, output_sent in zip(
                    input_sentences, target_sentences, generated_sentences):
                input_sent = self.vocab.decode(input_sent)
                target_sent = self.vocab.decode(target_sent)
                output_sent = '\n'.join([self.vocab.decode(sent) for sent in output_sent])
                s = '\n'.join(['Input sentence: ' + input_sent,
                               'Ground truth: ' + target_sent,
                               'Generated response: ' + output_sent + '\n'])
                f.write(s + '\n')
                print(s)
            print('') 
Example #29
Source File: model.py    From MelNet with MIT License
def sample(self, condition):
        x = None
        seq = torch.from_numpy(text_to_sequence(condition)).long().unsqueeze(0)
        input_lengths = torch.LongTensor([seq[0].shape[0]]).cuda()
        audio_lengths = torch.LongTensor([0]).cuda()

        ## Tier 1 ##
        tqdm.write('Tier 1')
        for t in tqdm(range(self.args.timestep // self.t_div)):
            audio_lengths += 1
            if x is None:
                x = torch.zeros((1, self.n_mels // self.f_div, 1)).cuda()
            else:
                x = torch.cat([x, torch.zeros((1, self.n_mels // self.f_div, 1)).cuda()], dim=-1)
            for m in tqdm(range(self.n_mels // self.f_div)):
                torch.cuda.synchronize()
                if self.infer_hp.conditional:
                    mu, std, pi, _ = self.tiers[1](x, seq, input_lengths, audio_lengths)
                else:
                    mu, std, pi = self.tiers[1](x, audio_lengths)
                temp = sample_gmm(mu, std, pi)
                x[:, m, t] = temp[:, m, t]

        ## Tier 2~N ##
        for tier in tqdm(range(2, self.hp.model.tier + 1)):
            tqdm.write('Tier %d' % tier)
            mu, std, pi = self.tiers[tier](x)
            temp = sample_gmm(mu, std, pi)
            x = self.tierutil.interleave(x, temp, tier + 1)

        return x 
Example #30
Source File: faceforensics_download.py    From DeepFake-Detection with MIT License
def download_file(url, out_file, report_progress=False):
    out_dir = os.path.dirname(out_file)
    if not os.path.isfile(out_file):
        fh, out_file_tmp = tempfile.mkstemp(dir=out_dir)
        f = os.fdopen(fh, 'w')
        f.close()
        if report_progress:
            urllib.request.urlretrieve(url, out_file_tmp,
                                       reporthook=reporthook)
        else:
            urllib.request.urlretrieve(url, out_file_tmp)
        os.rename(out_file_tmp, out_file)
    else:
        tqdm.write('WARNING: skipping download of existing file ' + out_file)