Python config.max_len() Examples
The following are 5 code examples of config.max_len(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module config, or try the search function.
Example #1
Source File: prepare_data.py From dcase2017_task4_cvssp with MIT License | 6 votes |
def pad_trunc_seq(x, max_len):
    """Pad or truncate a sequence data to a fixed length.

    Args:
      x: ndarray, input sequence data.
      max_len: integer, length of sequence to be padded or truncated.

    Returns:
      ndarray, padded or truncated input sequence data.
    """
    L = len(x)
    shape = x.shape
    if L < max_len:
        pad_shape = (max_len - L,) + shape[1:]
        pad = np.zeros(pad_shape)
        x_new = np.concatenate((x, pad), axis=0)
    else:
        x_new = x[0:max_len]
    return x_new

### Load data & scale data
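As a quick usage sketch (not part of the original project), pad_trunc_seq zero-pads short sequences and truncates long ones along the first axis:

import numpy as np

x = np.ones((3, 2))                      # sequence of length 3, 2 features per step
padded = pad_trunc_seq(x, max_len=5)     # shape (5, 2): two all-zero rows appended
truncated = pad_trunc_seq(x, max_len=2)  # shape (2, 2): last row dropped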
Example #2
Source File: pre_process.py From Machine-Translation with Apache License 2.0 | 5 votes |
def build_samples():
    word_map_zh = json.load(open('data/WORDMAP_zh.json', 'r'))
    word_map_en = json.load(open('data/WORDMAP_en.json', 'r'))

    for usage in ['train', 'valid']:
        if usage == 'train':
            translation_path_en = os.path.join(train_translation_folder, train_translation_en_filename)
            translation_path_zh = os.path.join(train_translation_folder, train_translation_zh_filename)
            filename = 'data/samples_train.json'
        else:
            translation_path_en = os.path.join(valid_translation_folder, valid_translation_en_filename)
            translation_path_zh = os.path.join(valid_translation_folder, valid_translation_zh_filename)
            filename = 'data/samples_valid.json'

        print('loading {} texts and vocab'.format(usage))
        with open(translation_path_en, 'r') as f:
            data_en = f.readlines()
        with open(translation_path_zh, 'r') as f:
            data_zh = f.readlines()

        print('building {} samples'.format(usage))
        samples = []
        for idx in tqdm(range(len(data_en))):
            sentence_zh = data_zh[idx].strip()
            seg_list = jieba.cut(sentence_zh)
            input_zh = encode_text(word_map_zh, list(seg_list))

            sentence_en = data_en[idx].strip().lower()
            tokens = [normalizeString(s) for s in nltk.word_tokenize(sentence_en) if len(normalizeString(s)) > 0]
            output_en = encode_text(word_map_en, tokens)

            if len(input_zh) <= max_len and len(output_en) <= max_len and UNK_token not in input_zh and UNK_token not in output_en:
                samples.append({'input': list(input_zh), 'output': list(output_en)})

        with open(filename, 'w') as f:
            json.dump(samples, f, indent=4)

        print('{} {} samples created at: {}.'.format(len(samples), usage, filename))
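The role of max_len here is simply to drop sentence pairs whose encoded length exceeds the limit (and pairs containing UNK_token). A minimal, self-contained sketch of that filtering step, with made-up values standing in for the project's config:

max_len = 50      # assumed limit; the real value comes from the project's config
UNK_token = 3     # assumed id of the unknown-word token

pairs = [
    ([1, 2, 4], [5, 6]),  # kept: short and fully known
    ([1] * 60, [2]),      # dropped: source longer than max_len
    ([7, 3, 9], [8]),     # dropped: source contains UNK_token
]
samples = [{'input': src, 'output': trg}
           for src, trg in pairs
           if len(src) <= max_len and len(trg) <= max_len
           and UNK_token not in src and UNK_token not in trg]
# samples == [{'input': [1, 2, 4], 'output': [5, 6]}]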
Example #3
Source File: pre_process.py From Machine-Translation-v2 with Apache License 2.0 | 5 votes |
def build_samples():
    word_map_zh = json.load(open('data/WORDMAP_zh.json', 'r'))
    word_map_en = json.load(open('data/WORDMAP_en.json', 'r'))

    for usage in ['train', 'valid']:
        if usage == 'train':
            translation_path_en = os.path.join(train_translation_folder, train_translation_en_filename)
            translation_path_zh = os.path.join(train_translation_folder, train_translation_zh_filename)
            filename = 'data/samples_train.json'
        else:
            translation_path_en = os.path.join(valid_translation_folder, valid_translation_en_filename)
            translation_path_zh = os.path.join(valid_translation_folder, valid_translation_zh_filename)
            filename = 'data/samples_valid.json'

        print('loading {} texts and vocab'.format(usage))
        with open(translation_path_en, 'r') as f:
            data_en = f.readlines()
        with open(translation_path_zh, 'r') as f:
            data_zh = f.readlines()

        print('building {} samples'.format(usage))
        samples = []
        for idx in tqdm(range(len(data_en))):
            sentence_en = data_en[idx].strip().lower()
            tokens = [normalizeString(s) for s in nltk.word_tokenize(sentence_en)]
            input_en = encode_text(word_map_en, tokens)

            sentence_zh = data_zh[idx].strip()
            seg_list = jieba.cut(sentence_zh)
            output_zh = encode_text(word_map_zh, list(seg_list))

            if len(input_en) <= max_len and len(output_zh) <= max_len and UNK_token not in input_en and UNK_token not in output_zh:
                samples.append({'input': list(input_en), 'output': list(output_zh)})

        with open(filename, 'w') as f:
            json.dump(samples, f, indent=4)

        print('{} {} samples created at: {}.'.format(len(samples), usage, filename))
Example #4
Source File: data_utils.py From neural-question-generation with MIT License | 5 votes |
def get_loader(src_file, trg_file, word2idx,
               batch_size, use_tag=False,
               debug=False, shuffle=False):
    dataset = SQuadDatasetWithTag(src_file, trg_file, config.max_len,
                                  word2idx, debug)
    dataloader = data.DataLoader(dataset=dataset,
                                 batch_size=batch_size,
                                 shuffle=shuffle,
                                 collate_fn=collate_fn_tag)

    return dataloader
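A hypothetical call (the file paths and word2idx mapping below are placeholders, not from the original repository) might look like:

# word2idx maps tokens to integer ids, e.g. built from the training vocabulary
loader = get_loader('data/train.src', 'data/train.trg', word2idx,
                    batch_size=32, shuffle=True)
for batch in loader:
    ...  # batches come from SQuadDatasetWithTag, which receives config.max_len,
         # and are assembled by collate_fn_tag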
Example #5
Source File: infer.py From text-classifier with Apache License 2.0 | 4 votes |
def infer_deep_model(model_type='cnn',
                     data_path='',
                     model_save_path='',
                     label_vocab_path='',
                     max_len=300,
                     batch_size=128,
                     col_sep='\t',
                     pred_save_path=None):
    from keras.models import load_model
    # load data content
    data_set, true_labels = data_reader(data_path, col_sep)
    # init feature
    # the han model needs a [doc, sentence, dim] feature (3-dimensional); other models use a [sentence, dim] feature (2-dimensional)
    if model_type == 'han':
        feature_type = 'doc_vectorize'
    else:
        feature_type = 'vectorize'
    feature = Feature(data_set, feature_type=feature_type, is_infer=True, max_len=max_len)
    # get data feature
    data_feature = feature.get_feature()
    # load model
    model = load_model(model_save_path)
    # predict; in keras, predict_proba is the same as predict
    pred_label_probs = model.predict(data_feature, batch_size=batch_size)

    # label id map
    label_id = load_vocab(label_vocab_path)
    id_label = {v: k for k, v in label_id.items()}

    pred_labels = [prob.argmax() for prob in pred_label_probs]
    pred_labels = [id_label[i] for i in pred_labels]
    pred_output = [id_label[prob.argmax()] + col_sep + str(prob.max()) for prob in pred_label_probs]
    logger.info("save infer label and prob result to: %s" % pred_save_path)
    save_predict_result(pred_output, ture_labels=None, pred_save_path=pred_save_path, data_set=data_set)
    if true_labels:
        # evaluate
        assert len(pred_labels) == len(true_labels)
        for label, prob in zip(true_labels, pred_label_probs):
            logger.debug('label_true:%s\tprob_label:%s\tprob:%s' % (label, id_label[prob.argmax()], prob.max()))
        print('total eval:')
        try:
            print(classification_report(true_labels, pred_labels))
            print(confusion_matrix(true_labels, pred_labels))
        except UnicodeEncodeError:
            true_labels_id = [label_id[i] for i in true_labels]
            pred_labels_id = [label_id[i] for i in pred_labels]
            print(classification_report(true_labels_id, pred_labels_id))
            print(confusion_matrix(true_labels_id, pred_labels_id))
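A hypothetical invocation (all paths below are placeholders, not files shipped with the project) could look like:

infer_deep_model(model_type='cnn',
                 data_path='data/test.txt',
                 model_save_path='output/cnn_model.h5',
                 label_vocab_path='output/label_vocab.txt',
                 max_len=300,
                 batch_size=128,
                 pred_save_path='output/pred_result.txt')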