Python utils.create_dir() Examples
The following are 25 code examples of utils.create_dir(), collected from open source projects. You can go to the original project or source file by following the link above each example, or check out all available functions and classes of the utils module.
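For context, create_dir in these projects is a small convenience wrapper that creates a directory (and usually its parents) only if it does not already exist, so callers can invoke it unconditionally and repeatedly. A minimal sketch of such a helper, assuming it simply wraps os.makedirs (each project ships its own version, which may differ):

import os

def create_dir(path):
    """Create the directory at `path` (and parents) if it is missing."""
    if not os.path.exists(path):
        try:
            os.makedirs(path)
        except OSError:
            # another process may have created it between the check and makedirs
            if not os.path.isdir(path):
                raise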
Example #1
Source File: compute_softscore.py From ban-vqa with MIT License
def create_ans2label(occurence, name, cache_root='data/cache'):
    """Note that this will also create label2ans.pkl at the same time

    occurence: dict {answer -> whatever}
    name: prefix of the output file
    cache_root: str
    """
    ans2label = {}
    label2ans = []
    label = 0
    for answer in occurence:
        label2ans.append(answer)
        ans2label[answer] = label
        label += 1

    utils.create_dir(cache_root)
    cache_file = os.path.join(cache_root, name+'_ans2label.pkl')
    cPickle.dump(ans2label, open(cache_file, 'wb'))
    cache_file = os.path.join(cache_root, name+'_label2ans.pkl')
    cPickle.dump(label2ans, open(cache_file, 'wb'))
    return ans2label
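A toy invocation of this function (hypothetical data; in the project, occurence comes from the answer-filtering step earlier in compute_softscore.py) might look like:

occurence = {'yes': 120, 'no': 95, '2': 40}
ans2label = create_ans2label(occurence, 'trainval')
# ans2label == {'yes': 0, 'no': 1, '2': 2}  (labels assigned in iteration order)
# writes data/cache/trainval_ans2label.pkl and data/cache/trainval_label2ans.pkl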
Example #2
Source File: compute_softscore.py From Attention-on-Attention-for-VQA with MIT License
def create_ans2label(occurence, name, cache_root='data/cache'):
    """Note that this will also create label2ans.pkl at the same time

    occurence: dict {answer -> whatever}
    name: prefix of the output file
    cache_root: str
    """
    ans2label = {}
    label2ans = []
    label = 0
    for answer in occurence:
        label2ans.append(answer)
        ans2label[answer] = label
        label += 1

    utils.create_dir(cache_root)
    cache_file = os.path.join(cache_root, name+'_ans2label.pkl')
    cPickle.dump(ans2label, open(cache_file, 'wb'))
    cache_file = os.path.join(cache_root, name+'_label2ans.pkl')
    cPickle.dump(label2ans, open(cache_file, 'wb'))
    return ans2label
Example #3
Source File: compute_softscore.py From bottom-up-attention-tf with MIT License
def create_ans2label(occurence, name, cache_root='data/cache'):
    """Note that this will also create label2ans.pkl at the same time

    occurence: dict {answer -> whatever}
    name: prefix of the output file
    cache_root: str
    """
    ans2label = {}
    label2ans = []
    label = 0
    for answer in occurence:
        label2ans.append(answer)
        ans2label[answer] = label
        label += 1

    utils.create_dir(cache_root)
    cache_file = os.path.join(cache_root, name+'_ans2label.pkl')
    cPickle.dump(ans2label, open(cache_file, 'wb'))
    cache_file = os.path.join(cache_root, name+'_label2ans.pkl')
    cPickle.dump(label2ans, open(cache_file, 'wb'))
    return ans2label
Example #4
Source File: compute_softscore.py From bottom-up-attention-vqa with GNU General Public License v3.0
def create_ans2label(occurence, name, cache_root='data/cache'):
    """Note that this will also create label2ans.pkl at the same time

    occurence: dict {answer -> whatever}
    name: prefix of the output file
    cache_root: str
    """
    ans2label = {}
    label2ans = []
    label = 0
    for answer in occurence:
        label2ans.append(answer)
        ans2label[answer] = label
        label += 1

    utils.create_dir(cache_root)
    cache_file = os.path.join(cache_root, name+'_ans2label.pkl')
    cPickle.dump(ans2label, open(cache_file, 'wb'))
    cache_file = os.path.join(cache_root, name+'_label2ans.pkl')
    cPickle.dump(label2ans, open(cache_file, 'wb'))
    return ans2label
Example #5
Source File: compute_softscore.py From VQA_ReGAT with MIT License
def create_ans2label(occurence, name, cache_root='data/cache'):
    """Note that this will also create label2ans.pkl at the same time

    occurence: dict {answer -> whatever}
    name: prefix of the output file
    cache_root: str
    """
    ans2label = {}
    label2ans = []
    label = 0
    for answer in occurence:
        label2ans.append(answer)
        ans2label[answer] = label
        label += 1

    utils.create_dir(cache_root)
    cache_file = os.path.join(cache_root, name+'_ans2label.pkl')
    pickle.dump(ans2label, open(cache_file, 'wb'))
    cache_file = os.path.join(cache_root, name+'_label2ans.pkl')
    pickle.dump(label2ans, open(cache_file, 'wb'))
    return ans2label
Example #6
Source File: assemblyGet.py From enaBrowserTools with Apache License 2.0
def download_assembly(dest_dir, accession, output_format, fetch_wgs, extract_wgs, expanded, quiet=False):
    if output_format is None:
        output_format = utils.EMBL_FORMAT
    assembly_dir = os.path.join(dest_dir, accession)
    utils.create_dir(assembly_dir)
    # download xml
    utils.download_record(assembly_dir, accession, utils.XML_FORMAT)
    local_xml = utils.get_destination_file(assembly_dir, accession, utils.XML_FORMAT)
    # get wgs and sequence report info
    wgs_set, sequence_report = parse_assembly_xml(local_xml)
    has_sequence_report = False
    # download sequence report
    if sequence_report is not None:
        has_sequence_report = utils.get_ftp_file(sequence_report, assembly_dir)
    # parse sequence report and download sequences
    wgs_scaffolds = []
    wgs_scaffold_cnt = 0
    if has_sequence_report:
        wgs_scaffolds = download_sequences(sequence_report.split('/')[-1], assembly_dir,
                                           output_format, expanded, quiet)
        wgs_scaffold_cnt = len(wgs_scaffolds)
        if wgs_scaffold_cnt > 0:
            if not quiet:
                print('Assembly contains {} WGS scaffolds, will fetch WGS set'.format(wgs_scaffold_cnt))
            fetch_wgs = True
    else:
        fetch_wgs = True
    # download wgs set if needed
    if wgs_set is not None and fetch_wgs:
        if not quiet:
            print('fetching wgs set')
        sequenceGet.download_wgs(assembly_dir, wgs_set, output_format)
    # extract wgs scaffolds from WGS file
    if wgs_scaffold_cnt > 0 and extract_wgs:
        extract_wgs_scaffolds(assembly_dir, wgs_scaffolds, wgs_set, output_format, quiet)
Example #7
Source File: compute_softscore.py From ban-vqa with MIT License
def compute_target(answers_dset, ans2label, name, cache_root='data/cache'):
    """Augment answers_dset with soft score as label

    ***answers_dset should be preprocessed***

    Write result into a cache file
    """
    target = []
    for ans_entry in answers_dset:
        answers = ans_entry['answers']
        answer_count = {}
        for answer in answers:
            answer_ = answer['answer']
            answer_count[answer_] = answer_count.get(answer_, 0) + 1

        labels = []
        scores = []
        for answer in answer_count:
            if answer not in ans2label:
                continue
            labels.append(ans2label[answer])
            score = get_score(answer_count[answer])
            scores.append(score)
        target.append({
            'question_id': ans_entry['question_id'],
            'image_id': ans_entry['image_id'],
            'labels': labels,
            'scores': scores
        })

    utils.create_dir(cache_root)
    cache_file = os.path.join(cache_root, name+'_target.pkl')
    cPickle.dump(target, open(cache_file, 'wb'))
    return target
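get_score is defined elsewhere in compute_softscore.py and is not shown here; a sketch consistent with the standard VQA soft-accuracy scheme (roughly 0.3 per annotator who gave the answer, capped at 1.0) would be:

def get_score(occurences):
    # map how many of the 10 annotators gave this answer to a soft score
    if occurences == 0:
        return 0.0
    elif occurences == 1:
        return 0.3
    elif occurences == 2:
        return 0.6
    elif occurences == 3:
        return 0.9
    else:
        return 1.0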
Example #8
Source File: test.py From ban-vqa with MIT License
def process(args, model, eval_loader):
    model_path = args.input + '/model%s.pth' % \
        ('' if 0 > args.epoch else '_epoch%d' % args.epoch)
    print('loading %s' % model_path)
    model_data = torch.load(model_path)

    model = nn.DataParallel(model).cuda()
    model.load_state_dict(model_data.get('model_state', model_data))
    model.train(False)

    logits, qIds = get_logits(model, eval_loader)
    results = make_json(logits, qIds, eval_loader)
    model_label = '%s%s%d_%s' % (args.model, args.op, args.num_hid, args.label)

    if args.logits:
        utils.create_dir('logits/' + model_label)
        torch.save(logits, 'logits/' + model_label + '/logits%d.pth' % args.index)

    utils.create_dir(args.output)
    if 0 <= args.epoch:
        model_label += '_epoch%d' % args.epoch

    with open(args.output + '/%s_%s.json' % (args.split, model_label), 'w') as f:
        json.dump(results, f)
Example #9
Source File: compute_softscore.py From bottom-up-attention-tf with MIT License
def compute_target(answers_dset, ans2label, name, cache_root='data/cache'):
    """Augment answers_dset with soft score as label

    ***answers_dset should be preprocessed***

    Write result into a cache file
    """
    target = []
    for ans_entry in answers_dset:
        answers = ans_entry['answers']
        answer_count = {}
        for answer in answers:
            answer_ = answer['answer']
            answer_count[answer_] = answer_count.get(answer_, 0) + 1

        labels = []
        scores = []
        for answer in answer_count:
            if answer not in ans2label:
                continue
            labels.append(ans2label[answer])
            score = get_score(answer_count[answer])
            scores.append(score)
        target.append({
            'question_id': ans_entry['question_id'],
            'image_id': ans_entry['image_id'],
            'labels': labels,
            'scores': scores
        })

    utils.create_dir(cache_root)
    cache_file = os.path.join(cache_root, name+'_target.pkl')
    cPickle.dump(target, open(cache_file, 'wb'))
    return target
Example #10
Source File: compute_softscore.py From bottom-up-attention-vqa with GNU General Public License v3.0
def compute_target(answers_dset, ans2label, name, cache_root='data/cache'):
    """Augment answers_dset with soft score as label

    ***answers_dset should be preprocessed***

    Write result into a cache file
    """
    target = []
    for ans_entry in answers_dset:
        answers = ans_entry['answers']
        answer_count = {}
        for answer in answers:
            answer_ = answer['answer']
            answer_count[answer_] = answer_count.get(answer_, 0) + 1

        labels = []
        scores = []
        for answer in answer_count:
            if answer not in ans2label:
                continue
            labels.append(ans2label[answer])
            score = get_score(answer_count[answer])
            scores.append(score)
        target.append({
            'question_id': ans_entry['question_id'],
            'image_id': ans_entry['image_id'],
            'labels': labels,
            'scores': scores
        })

    utils.create_dir(cache_root)
    cache_file = os.path.join(cache_root, name+'_target.pkl')
    cPickle.dump(target, open(cache_file, 'wb'))
    return target
Example #11
Source File: compute_softscore.py From VQA_ReGAT with MIT License
def compute_target(answers_dset, ans2label, name, cache_root='data/cache'):
    """Augment answers_dset with soft score as label

    ***answers_dset should be preprocessed***

    Write result into a cache file
    """
    target = []
    for ans_entry in answers_dset:
        answers = ans_entry['answers']
        answer_count = {}
        for answer in answers:
            answer_ = answer['answer']
            answer_count[answer_] = answer_count.get(answer_, 0) + 1

        labels = []
        scores = []
        for answer in answer_count:
            if answer not in ans2label:
                continue
            labels.append(ans2label[answer])
            score = get_score(answer_count[answer])
            scores.append(score)
        target.append({
            'question_id': ans_entry['question_id'],
            'image_id': ans_entry['image_id'],
            'labels': labels,
            'scores': scores
        })

    utils.create_dir(cache_root)
    cache_file = os.path.join(cache_root, name+'_target.pkl')
    pickle.dump(target, open(cache_file, 'wb'))
    return target
Example #12
Source File: pytbcrawler.py From tor-browser-crawler with GNU General Public License v2.0
def build_crawl_dirs():
    # build crawl directory
    ut.create_dir(cm.RESULTS_DIR)
    ut.create_dir(cm.CRAWL_DIR)
    ut.create_dir(cm.LOGS_DIR)
    copyfile(cm.CONFIG_FILE, join(cm.LOGS_DIR, 'config.ini'))
    add_symlink(join(cm.RESULTS_DIR, 'latest_crawl'), basename(cm.CRAWL_DIR))
Example #13
Source File: crawler.py From tor-browser-crawler with GNU General Public License v2.0
def __do_instance(self):
    for self.job.visit in xrange(self.job.visits):
        ut.create_dir(self.job.path)
        wl_log.info("*** Visit #%s to %s ***", self.job.visit, self.job.url)
        with self.driver.launch():
            try:
                self.driver.set_page_load_timeout(cm.SOFT_VISIT_TIMEOUT)
            except WebDriverException as seto_exc:
                wl_log.error("Setting soft timeout %s", seto_exc)
            self.__do_visit()
            if self.screenshots:
                try:
                    self.driver.get_screenshot_as_file(self.job.png_file)
                except WebDriverException:
                    wl_log.error("Cannot get screenshot.")
        sleep(float(self.job.config['pause_between_visits']))
        self.post_visit()
Example #14
Source File: metrics.py From adagan with BSD 3-Clause "New" or "Revised" License
def _make_plots_2d(self, opts, step, real_points,
                   fake_points, weights=None, prefix=''):
    max_val = opts['gmm_max_val'] * 2
    if real_points is None:
        real = np.zeros([0, 2])
    else:
        num_real_points = len(real_points)
        real = np.reshape(real_points, [num_real_points, 2])
    if fake_points is None:
        fake = np.zeros([0, 2])
    else:
        num_fake_points = len(fake_points)
        fake = np.reshape(fake_points, [num_fake_points, 2])

    # Plotting the sample
    plt.clf()
    plt.axis([-max_val, max_val, -max_val, max_val])
    plt.scatter(real[:, 0], real[:, 1], color='red', s=20, label='real')
    plt.scatter(fake[:, 0], fake[:, 1], color='blue', s=20, label='fake')
    plt.legend(loc='upper left')
    filename = prefix + 'mixture{:02d}.png'.format(step)
    utils.create_dir(opts['work_dir'])
    plt.savefig(utils.o_gfile((opts["work_dir"], filename), 'wb'),
                format='png')

    # Plotting the weights, if provided
    if weights is not None:
        plt.clf()
        plt.axis([-max_val, max_val, -max_val, max_val])
        assert len(weights) == len(real)
        plt.scatter(real[:, 0], real[:, 1], c=weights, s=40,
                    edgecolors='face')
        plt.colorbar()
        filename = prefix + 'weights{:02d}.png'.format(step)
        utils.create_dir(opts['work_dir'])
        plt.savefig(utils.o_gfile((opts["work_dir"], filename), 'wb'),
                    format='png')
Example #15
Source File: enaGroupGet.py From enaBrowserTools with Apache License 2.0
def download_group(accession, group, output_format, dest_dir, fetch_wgs,
                   extract_wgs, fetch_meta, fetch_index, aspera, subtree, expanded):
    group_dir = os.path.join(dest_dir, accession)
    utils.create_dir(group_dir)
    if group == utils.SEQUENCE:
        download_sequence_group(accession, output_format, group_dir, subtree, expanded)
    else:
        download_data_group(group, accession, output_format, group_dir, fetch_wgs,
                            extract_wgs, fetch_meta, fetch_index, aspera, subtree, expanded)
Example #16
Source File: assemblyGet.py From enaBrowserTools with Apache License 2.0
def download_assembly(dest_dir, accession, output_format, fetch_wgs, extract_wgs, expanded, quiet=False):
    if output_format is None:
        output_format = utils.EMBL_FORMAT
    assembly_dir = os.path.join(dest_dir, accession)
    utils.create_dir(assembly_dir)
    # download xml
    utils.download_record(assembly_dir, accession, utils.XML_FORMAT)
    local_xml = utils.get_destination_file(assembly_dir, accession, utils.XML_FORMAT)
    # get wgs and sequence report info
    wgs_set, sequence_report = parse_assembly_xml(local_xml)
    has_sequence_report = False
    # download sequence report
    if sequence_report is not None:
        has_sequence_report = utils.get_ftp_file(sequence_report, assembly_dir)
    # parse sequence report and download sequences
    wgs_scaffolds = []
    wgs_scaffold_cnt = 0
    if has_sequence_report:
        wgs_scaffolds = download_sequences(sequence_report.split('/')[-1], assembly_dir,
                                           output_format, expanded, quiet)
        wgs_scaffold_cnt = len(wgs_scaffolds)
        if wgs_scaffold_cnt > 0:
            if not quiet:
                print 'Assembly contains {} WGS scaffolds, will fetch WGS set'.format(wgs_scaffold_cnt)
            fetch_wgs = True
    else:
        fetch_wgs = True
    # download wgs set if needed
    if wgs_set is not None and fetch_wgs:
        if not quiet:
            print 'fetching wgs set'
        sequenceGet.download_wgs(assembly_dir, wgs_set, output_format)
    # extract wgs scaffolds from WGS file
    if wgs_scaffold_cnt > 0 and extract_wgs:
        extract_wgs_scaffolds(assembly_dir, wgs_scaffolds, wgs_set, output_format, quiet)
Example #17
Source File: extractor.py From VGGFace2-pytorch with MIT License
def extract(self):
    batch_time = utils.AverageMeter()

    self.model.eval()
    end = time.time()
    for batch_idx, (imgs, target, img_files, class_ids) in tqdm.tqdm(
            enumerate(self.val_loader), total=len(self.val_loader),
            desc='Extract', ncols=80, leave=False):
        gc.collect()

        if self.cuda:
            imgs = imgs.cuda()
        imgs = Variable(imgs, volatile=True)
        output = self.model(imgs)  # N C H W torch.Size([1, 1, 401, 600])
        if self.flatten_feature:
            output = output.view(output.size(0), -1)
        output = output.data.cpu().numpy()
        assert output.shape[0] == len(img_files)

        for i, img_file in enumerate(img_files):
            base_name = os.path.splitext(img_file)[0]
            feature_file = os.path.join(self.feature_dir, base_name + ".npy")
            utils.create_dir(os.path.dirname(feature_file))
            np.save(feature_file, output[i])

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if batch_idx % self.print_freq == 0:
            log_str = 'Extract: [{0}/{1}]\tTime: {batch_time.val:.3f} ({batch_time.avg:.3f})'.format(
                batch_idx, len(self.val_loader), batch_time=batch_time)
            print(log_str)
            self.print_log(log_str)
Example #18
Source File: compute_softscore.py From Attention-on-Attention-for-VQA with MIT License
def compute_target(answers_dset, ans2label, name, cache_root='data/cache'):
    """Augment answers_dset with soft score as label

    ***answers_dset should be preprocessed***

    Write result into a cache file
    """
    target = []
    for ans_entry in answers_dset:
        answers = ans_entry['answers']
        answer_count = {}
        for answer in answers:
            answer_ = answer['answer']
            answer_count[answer_] = answer_count.get(answer_, 0) + 1

        labels = []
        scores = []
        for answer in answer_count:
            if answer not in ans2label:
                continue
            labels.append(ans2label[answer])
            score = get_score(answer_count[answer])
            scores.append(score)
        target.append({
            'question_id': ans_entry['question_id'],
            'image_id': ans_entry['image_id'],
            'labels': labels,
            'scores': scores
        })

    utils.create_dir(cache_root)
    cache_file = os.path.join(cache_root, name+'_target.pkl')
    cPickle.dump(target, open(cache_file, 'wb'))
    return target
Example #19
Source File: track.py From SoundCloud with GNU General Public License v3.0
def download(client, track, dir, override=False):
    """Download a track using the given SC client"""
    title = fix_title(track.title, track.user['username'])
    print '"%s"' % title

    if not dir:
        dir = 'mp3'
    utils.create_dir(dir)
    file_name = utils.build_file_name(dir, title)

    if not override and os.path.exists(file_name):
        print "File already exists, skipped"
        return False

    stream_url = client.get(track.stream_url, allow_redirects=False)
    urllib.urlretrieve(stream_url.location, file_name)
    return True
Example #20
Source File: playlist.py From SoundCloud with GNU General Public License v3.0
def download_from_url(client_id, url, base_dir, override=False):
    """Download the given playlist"""
    downloaded = 0
    skipped = 0
    errors = 0

    # Retrieve playlist data
    client = soundcloud.Client(client_id=client_id)
    playlist = client.get('/resolve', url=url)

    # Create dir
    playlist_title = playlist.title
    dir = os.path.join(base_dir, playlist_title)
    utils.create_dir(dir)

    # Download tracks
    for trak in playlist.tracks:
        try:
            #done = song.down(client, track, dir, override)
            done = track.download_from_id(client_id, trak['id'], dir, override)
            if done:
                downloaded = downloaded + 1
            else:
                skipped = skipped + 1
        except requests.exceptions.HTTPError, err:
            if err.response.status_code == 404:
                print 'Error: could not download'
                errors = errors + 1
            else:
                raise
Example #21
Source File: metrics.py From adagan with BSD 3-Clause "New" or "Revised" License
def _make_plots_1d(self, opts, step, real_points,
                   fake_points, weights=None, prefix=''):
    max_val = opts['gmm_max_val'] * 1.2
    if real_points is None:
        real = np.zeros([0, 2])
    else:
        num_real_points = len(real_points)
        real = np.reshape(real_points, [num_real_points, 1]).flatten()
    if fake_points is None:
        fake = np.zeros([0, 2])
    else:
        num_fake_points = len(fake_points)
        fake = np.reshape(fake_points, [num_fake_points, 1]).flatten()

    # Plotting the sample AND the weights simultaneously
    plt.clf()
    f, _, _ = plt.hist(real, bins=100, range=(-max_val, max_val),
                       normed=True, histtype='step',
                       lw=2, color='red', label='real')
    plt.hist(fake, bins=100, range=(-max_val, max_val),
             normed=True, histtype='step',
             lw=2, color='blue', label='fake')
    if weights is not None:
        assert len(real) == len(weights)
        weights_srt = np.array([y for (x, y) in sorted(zip(real, weights))])
        real_points_srt = np.array(sorted(real))
        max_pdf = np.max(f)
        weights_srt = weights_srt / np.max(weights_srt) * max_pdf * 0.8
        plt.plot(real_points_srt, weights_srt, lw=3,
                 color='green', label='weights')
    plt.legend(loc='upper left')
    filename = prefix + 'mixture{:02d}.png'.format(step)
    utils.create_dir(opts['work_dir'])
    plt.savefig(utils.o_gfile((opts["work_dir"], filename), 'wb'),
                format='png')
Example #22
Source File: readGet.py From enaBrowserTools with Apache License 2.0
def download_files(accession, output_format, dest_dir, fetch_index, fetch_meta, aspera):
    accession_dir = os.path.join(dest_dir, accession)
    utils.create_dir(accession_dir)
    # download experiment xml
    is_experiment = utils.is_experiment(accession)
    if fetch_meta and is_experiment:
        download_meta(accession, accession_dir)
    if fetch_meta and utils.is_run(accession):
        download_experiment_meta(accession, accession_dir)
    # download data files
    search_url = utils.get_file_search_query(accession, aspera)
    lines = utils.download_report_from_portal(search_url)
    for line in lines[1:]:
        data_accession, filelist, md5list, indexlist = utils.parse_file_search_result_line(
            line, accession, output_format)
        # create run directory if downloading all data for an experiment
        if is_experiment:
            run_dir = os.path.join(accession_dir, data_accession)
            utils.create_dir(run_dir)
            target_dir = run_dir
        else:
            target_dir = accession_dir
        # download run/analysis XML
        if fetch_meta:
            download_meta(data_accession, target_dir)
        if len(filelist) == 0:
            if output_format is None:
                print('No files available for {0}'.format(data_accession))
            else:
                print('No files of format {0} for {1}'.format(output_format, data_accession))
            continue
        for i in range(len(filelist)):
            file_url = filelist[i]
            md5 = md5list[i]
            if file_url != '':
                download_file(file_url, target_dir, md5, aspera)
        for index_file in indexlist:
            if index_file != '':
                download_file(index_file, target_dir, None, aspera)
        if utils.is_empty_dir(target_dir):
            print('Deleting directory ' + os.path.basename(target_dir))
            os.rmdir(target_dir)
Example #23
Source File: readGet.py From enaBrowserTools with Apache License 2.0
def download_files(accession, output_format, dest_dir, fetch_index, fetch_meta, aspera):
    accession_dir = os.path.join(dest_dir, accession)
    utils.create_dir(accession_dir)
    # download experiment xml
    is_experiment = utils.is_experiment(accession)
    if fetch_meta and is_experiment:
        download_meta(accession, accession_dir)
    if fetch_meta and utils.is_run(accession):
        download_experiment_meta(accession, accession_dir)
    # download data files
    search_url = utils.get_file_search_query(accession, aspera)
    temp_file = os.path.join(dest_dir, 'temp.txt')
    utils.download_report_from_portal(search_url, temp_file)
    f = open(temp_file)
    lines = f.readlines()
    f.close()
    os.remove(temp_file)
    for line in lines[1:]:
        data_accession, filelist, md5list, indexlist = utils.parse_file_search_result_line(
            line, accession, output_format)
        # create run directory if downloading all data for an experiment
        if is_experiment:
            run_dir = os.path.join(accession_dir, data_accession)
            utils.create_dir(run_dir)
            target_dir = run_dir
        else:
            target_dir = accession_dir
        # download run/analysis XML
        if fetch_meta:
            download_meta(data_accession, target_dir)
        if len(filelist) == 0:
            if output_format is None:
                print 'No files available for {0}'.format(data_accession)
            else:
                print 'No files of format {0} available for {1}'.format(
                    output_format, data_accession)
            continue
        for i in range(len(filelist)):
            file_url = filelist[i]
            md5 = md5list[i]
            if file_url != '':
                download_file(file_url, target_dir, md5, aspera)
        if fetch_index:
            for index_file in indexlist:
                if index_file != '':
                    download_file(index_file, target_dir, None, aspera)
        if utils.is_empty_dir(target_dir):
            print 'Deleting directory ' + os.path.basename(target_dir)
            os.rmdir(target_dir)
Example #24
Source File: train.py From bottom-up-attention-vqa with GNU General Public License v3.0
def train(model, train_loader, eval_loader, num_epochs, output):
    utils.create_dir(output)
    optim = torch.optim.Adamax(model.parameters())
    logger = utils.Logger(os.path.join(output, 'log.txt'))
    best_eval_score = 0

    for epoch in range(num_epochs):
        total_loss = 0
        train_score = 0
        t = time.time()

        for i, (v, b, q, a) in enumerate(train_loader):
            v = Variable(v).cuda()
            b = Variable(b).cuda()
            q = Variable(q).cuda()
            a = Variable(a).cuda()

            pred = model(v, b, q, a)
            loss = instance_bce_with_logits(pred, a)
            loss.backward()
            nn.utils.clip_grad_norm(model.parameters(), 0.25)
            optim.step()
            optim.zero_grad()

            batch_score = compute_score_with_logits(pred, a.data).sum()
            total_loss += loss.data[0] * v.size(0)
            train_score += batch_score

        total_loss /= len(train_loader.dataset)
        train_score = 100 * train_score / len(train_loader.dataset)
        model.train(False)
        eval_score, bound = evaluate(model, eval_loader)
        model.train(True)

        logger.write('epoch %d, time: %.2f' % (epoch, time.time()-t))
        logger.write('\ttrain_loss: %.2f, score: %.2f' % (total_loss, train_score))
        logger.write('\teval score: %.2f (%.2f)' % (100 * eval_score, 100 * bound))

        if eval_score > best_eval_score:
            model_path = os.path.join(output, 'model.pth')
            torch.save(model.state_dict(), model_path)
            best_eval_score = eval_score
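The loop above relies on two helpers defined elsewhere in the same repository, instance_bce_with_logits and compute_score_with_logits. Sketches consistent with how they are used here (multi-label binary cross-entropy rescaled by the number of answer candidates, and VQA-style soft scoring of the argmax prediction) would be:

def instance_bce_with_logits(logits, labels):
    # BCE over all answer candidates, rescaled by the number of candidates
    loss = nn.functional.binary_cross_entropy_with_logits(logits, labels)
    loss *= labels.size(1)
    return loss

def compute_score_with_logits(logits, labels):
    # one-hot the argmax prediction and read off its soft score
    logits = torch.max(logits, 1)[1].data  # argmax over answers
    one_hots = torch.zeros(*labels.size()).cuda()
    one_hots.scatter_(1, logits.view(-1, 1), 1)
    return one_hots * labels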
Example #25
Source File: train.py From Attention-on-Attention-for-VQA with MIT License
def train(model, train_loader, eval_loader, num_epochs, output, opt, wd):
    utils.create_dir(output)
    # Paper uses AdaDelta
    if opt == 'Adadelta':
        optim = torch.optim.Adadelta(model.parameters(), rho=0.95, eps=1e-6, weight_decay=wd)
    elif opt == 'RMSprop':
        optim = torch.optim.RMSprop(model.parameters(), lr=0.01, alpha=0.99, eps=1e-08,
                                    weight_decay=wd, momentum=0, centered=False)
    elif opt == 'Adam':
        optim = torch.optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.999),
                                 eps=1e-08, weight_decay=wd)
    else:
        optim = torch.optim.Adamax(model.parameters(), weight_decay=wd)
    logger = utils.Logger(os.path.join(output, 'log.txt'))
    best_eval_score = 0

    for epoch in range(num_epochs):
        total_loss = 0
        train_score = 0
        t = time.time()
        correct = 0

        for i, (v, b, q, a) in enumerate(train_loader):
            v = Variable(v).cuda()
            b = Variable(b).cuda()  # boxes not used
            q = Variable(q).cuda()
            a = Variable(a).cuda()  # true labels

            pred = model(v, b, q, a)
            loss = instance_bce_with_logits(pred, a)
            loss.backward()
            nn.utils.clip_grad_norm(model.parameters(), 0.25)
            optim.step()
            optim.zero_grad()

            batch_score = compute_score_with_logits(pred, a.data).sum()
            total_loss += loss.data[0] * v.size(0)
            train_score += batch_score

        total_loss /= len(train_loader.dataset)
        train_score = 100 * train_score / len(train_loader.dataset)
        model.train(False)
        eval_score, bound, V_loss = evaluate(model, eval_loader)
        model.train(True)

        logger.write('epoch %d, time: %.2f' % (epoch, time.time()-t))
        logger.write('\ttrain_loss: %.3f, score: %.3f' % (total_loss, train_score))
        logger.write('\teval loss: %.3f, score: %.3f (%.3f)' % (V_loss, 100 * eval_score, 100 * bound))

        if eval_score > best_eval_score:
            model_path = os.path.join(output, 'model.pth')
            torch.save(model.state_dict(), model_path)
            best_eval_score = eval_score