Python util.flatten() Examples
The following are 27 code examples of util.flatten().
You can go to the original project or source file by following the links above each example.
You may also want to check out all available functions and classes of the module util.
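Note that util.flatten is not part of the standard library; each project below ships its own helper. In the coref-style repositories it is typically a one-level flatten that turns a list of token lists into a single token list, while other projects (for example the Minecraft and mHTM examples) use a variant that handles nested iterables. As a rough sketch only, not any particular project's implementation (check each project's util module for the exact definition), the one-level version usually looks like this:

def flatten(l):
    # One-level flatten: a list of lists becomes a single list.
    # Common definition for illustration; individual projects may differ.
    return [item for sublist in l for item in sublist]

print(flatten([["The", "cat"], ["sat", "."]]))  # ['The', 'cat', 'sat', '.']

This one-level behavior is what many of the examples below rely on when they call util.flatten(data['sentences']) or util.flatten(example['clusters']).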
Example #1
Source File: count.py From coref with Apache License 2.0 | 6 votes |
def count(data_file):
    f = open(data_file)
    max_num_sp = 0
    overlap, total = 0, 0
    for i, line in enumerate(f):
        # print('---', line)
        data = json.loads(line)
        clusters = util.flatten(data['clusters'])
        clusters = [tuple(c) for c in clusters]
        for c1 in clusters:
            for c2 in clusters:
                if c1 == c2:
                    continue
                total += 1
                if (is_overlap(c1, c2)) or (is_overlap(c2, c1)):
                    overlap += 1
                    # print('overlap', c1, c2)
                # else:
                #     print('non-overlap', c1, c2)
    print(overlap, total, overlap * 100.0 / total)
    print('max_num_sp', max_num_sp)
Example #2
Source File: base.py From mHTM with MIT License | 6 votes |
def encode(self, data=None):
    """
    Use a generator to yield each encoded bit. This supports being able to
    encode a list of values, where each value will be sequentially encoded.
    This function can also encode single values.

    @param data: The data to encode. If it isn't provided the encoder's
    data is used.
    """
    if data is None:
        data = self.raw_data

    # Make a list to account for single inputs
    for x in flatten([data]):
        for bit in flatten([self._encode(x)]):
            yield bit
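The flatten([data]) idiom above is what lets encode() accept either a single value or a list of values: the argument is wrapped in a one-element list and then flattened, yielding a flat iterable either way. A standalone sketch of the idiom (the flatten below is written only for this illustration and is assumed to recurse into nested lists; it is not the project's own helper):

def flatten(items):
    # Recursively yield leaves from arbitrarily nested lists/tuples (assumed behavior).
    for item in items:
        if isinstance(item, (list, tuple)):
            for sub in flatten(item):
                yield sub
        else:
            yield item

print(list(flatten([42])))         # [42]  -- a single value wrapped in a list
print(list(flatten([[1, 2], 3])))  # [1, 2, 3]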
Example #3
Source File: lsgn_data.py From lsgn with Apache License 2.0 | 6 votes |
def load_eval_data(self):
    eval_data = []
    eval_tensors = []
    coref_eval_data = []
    with open(self.config["eval_path"]) as f:
        eval_examples = [json.loads(jsonline) for jsonline in f.readlines()]
    populate_sentence_offset(eval_examples)
    for doc_id, example in enumerate(eval_examples):
        doc_tensors = []
        num_mentions_in_doc = 0
        for e in self.split_document_example(example):
            # Because each batch=1 document at test time, we do not need to offset cluster ids.
            e["cluster_id_offset"] = 0
            e["doc_id"] = doc_id + 1
            doc_tensors.append(self.tensorize_example(e, is_training=False))
            # num_mentions_in_doc += len(e["coref"])
        # assert num_mentions_in_doc == len(util.flatten(example["clusters"]))
        eval_tensors.append(doc_tensors)
        eval_data.extend(srl_eval_utils.split_example_for_eval(example))
        coref_eval_data.append(example)
    print("Loaded {} eval examples.".format(len(eval_data)))
    return eval_data, eval_tensors, coref_eval_data
Example #4
Source File: jira_util.py From project-dev-kpis with MIT License | 6 votes |
def issue_to_changelog(issue):
    return dict(
        [
            ('key', issue.key),
            (
                'changelog',
                [
                    (u'Created', parse_date(issue.fields.created))
                ] + flatten([
                    [
                        (i.toString, parse_date(h.created))
                        for i in h.items if i.field == 'status'
                    ]
                    for h in issue.changelog.histories
                ])
            )
        ]
    )
Example #5
Source File: minecraft.py From minecraft-starwars with MIT License | 5 votes |
def intFloor(*args):
    return [int(math.floor(x)) for x in flatten(args)]
Example #6
Source File: to_gap_tsv.py From coref with Apache License 2.0 | 5 votes |
def convert(json_file, tsv_file):
    data = read_json(json_file)
    tsv = read_tsv_file(tsv_file) if tsv_file is not None else None
    predictions = ['\t'.join(['ID', 'A-coref', 'B-coref'])]
    for key, datum in data.items():
        prediction = data[key]
        sents = util.flatten(prediction['sentences'])
        if tsv is not None:
            print(list(enumerate(tsv[key])))
            a_offset, b_offset, pronoun_offset = tuple(map(int, tsv[key][5].split(':'))), tuple(map(int, tsv[key][8].split(':'))), tuple(map(int, tsv[key][3].split(':')))
            assert ' '.join(sents[a_offset[0]:a_offset[1]]) == tsv[key][4], (sents[a_offset[0]:a_offset[1]], tsv[key][4])
            assert ' '.join(sents[b_offset[0]:b_offset[1]]) == tsv[key][7], (sents[b_offset[0]:b_offset[1]], tsv[key][7])
            assert ' '.join(sents[pronoun_offset[0]:pronoun_offset[1]]) == tsv[key][2], (sents[pronoun_offset[0]:pronoun_offset[1]], tsv[key][2])
            # continue
        pronoun_cluster = find_pronoun_cluster(prediction, prediction['pronoun_subtoken_span'])
        a_coref, b_coref = 'FALSE', 'FALSE'
        a_text, b_text = (tsv[key][4], tsv[key][7]) if tsv is not None else (None, None)
        for span in pronoun_cluster:
            a_aligned = is_aligned(span, prediction['a_subtoken_span']) if tsv is None else is_substring_aligned(span, sents, a_text)
            b_aligned = is_aligned(span, prediction['b_subtoken_span']) if tsv is None else is_substring_aligned(span, sents, b_text)
            if a_aligned:
                a_coref = 'TRUE'
            if b_aligned:
                b_coref = 'TRUE'
        predictions += ['\t'.join([key, a_coref, b_coref])]
    # write file
    with open(json_file.replace('jsonlines', 'tsv'), 'w') as f:
        f.write('\n'.join(predictions))
Example #7
Source File: count.py From coref with Apache License 2.0 | 5 votes |
def avg_len(data_file):
    f = open(data_file)
    total = 0
    max_num_sp = 0
    segments = []
    for i, line in enumerate(f):
        # print('---', line)
        data = json.loads(line)
        text = util.flatten(data['sentences'])
        segments.append(len(data['sentences']))
        total += len(text)
        max_num_sp = max(max_num_sp, len(text))
    print(total / i)
    print(max_num_sp)
    print(len(segments), sum(segments) / len(segments), max(segments), sum([1 for s in segments if s == 1]))
Example #8
Source File: compare.py From coref with Apache License 2.0 | 5 votes |
def compare_json(json1, json2):
    json1 = read_file(json1)
    json2 = read_file(json2)
    for i, (l1, l2) in enumerate(zip(json1, json2)):
        assert l1['doc_key'] == l2['doc_key']
        if tuple(util.flatten(l1['sentences'])) != tuple(util.flatten(l2['sentences'])):
            print(i, l1['doc_key'], list(enumerate(util.flatten(l1['sentences']))), list(enumerate(util.flatten(l2['sentences']))))
            for j, (w1, w2) in enumerate(zip(util.flatten(l1['sentences']), util.flatten(l2['sentences']))):
                if w1 != w2:
                    print(j, w1, w2)
                    break
Example #9
Source File: http_demo.py From coref with Apache License 2.0 | 5 votes |
def print_predictions(example):
    words = util.flatten(example["sentences"])
    for cluster in example["predicted_clusters"]:
        print(u"Predicted cluster: {}".format([" ".join(words[m[0]:m[1]+1]) for m in cluster]))
Example #10
Source File: print_clusters.py From coref with Apache License 2.0 | 5 votes |
def print_clusters(data_file):
    f = open(data_file)
    for i, line in enumerate(f):
        data = json.loads(line)
        text = util.flatten(data['sentences'])
        # clusters = [[text[s:e+1] for s,e in cluster] for cluster in data['clusters']]
        # print(text)
        for ci, cluster in enumerate(data['clusters']):
            spans = [text[s:e+1] for s, e in cluster]
            print(i, ci, spans)
        if i > 5:
            break
Example #11
Source File: pronoun_evaluation.py From coref with Apache License 2.0 | 5 votes |
def evaluate(fname):
    p, r, f1 = [], [], []
    pronoun_text = defaultdict(int)
    num_gold_pairs, num_pred_pairs = 0, 0
    total_gold_singletons, total_pred_singletons, total_singleton_intersection = 0, 0, 0
    with open(fname) as f:
        for line in f:
            datum = json.loads(line)
            tokens = flatten(datum['sentences'])
            # pronouns = flatten(datum['clusters'])
            pair_fn = get_mention_pairs
            # for pidx in pronouns:
            #     pronoun_text[(tokens[pidx].lower())] += 1
            gold_pronoun_mention_pairs, gold_singletons = pair_fn(datum['clusters'], flatten(datum['clusters']))
            pred_pronoun_mention_pairs, pred_singletons = pair_fn(datum['predicted_clusters'], flatten(datum['predicted_clusters']))
            total_gold_singletons += len(gold_singletons)
            total_pred_singletons += len(pred_singletons)
            total_singleton_intersection += len(gold_singletons.intersection(pred_singletons))
            intersection = gold_pronoun_mention_pairs.intersection(pred_pronoun_mention_pairs)
            num_gold_pairs += len(gold_pronoun_mention_pairs)
            num_pred_pairs += len(pred_pronoun_mention_pairs)
            this_recall = len(intersection) / len(gold_pronoun_mention_pairs) if len(gold_pronoun_mention_pairs) > 0 else 1.0
            this_prec = len(intersection) / len(pred_pronoun_mention_pairs) if len(pred_pronoun_mention_pairs) > 0 else 1.0
            this_f1 = 2 * this_recall * this_prec / (this_recall + this_prec) if this_recall + this_prec > 0 else 0
            p += [this_prec]
            r += [this_recall]
            f1 += [this_f1]
    print('gold_singletons: {}, pred_singletons: {} intersection: {}'.format(total_gold_singletons, total_pred_singletons, total_singleton_intersection))
    print('num_gold: {}, num_pred: {}, P: {}, R: {} F1: {}'.format(num_gold_pairs, num_pred_pairs, sum(p) / len(p), sum(r) / len(r), sum(f1) / len(f1)))
    # print(sum(pronoun_text.values()), sorted(list(pronoun_text.items()), key=lambda k: k[1]))
Example #12
Source File: demo.py From coref with Apache License 2.0 | 5 votes |
def print_predictions(example):
    words = util.flatten(example["sentences"])
    for cluster in example["predicted_clusters"]:
        print(u"Predicted cluster: {}".format([" ".join(words[m[0]:m[1]+1]) for m in cluster]))
Example #13
Source File: BlockStmt.py From Turing with MIT License | 5 votes |
def get_children(self) -> List[AstNode]:
    return util.flatten(x.get_children() for x in self.children)
Example #14
Source File: filtering.py From dblp with MIT License | 5 votes |
def papers_file(self):
    for file_obj in util.flatten(self.input()):
        if 'paper' in file_obj.path:
            return file_obj
Example #15
Source File: minimize.py From coref-ee with Apache License 2.0 | 5 votes |
def finalize(self):
    merged_clusters = []
    for c1 in self.clusters.values():
        existing = None
        for m in c1:
            for c2 in merged_clusters:
                if m in c2:
                    existing = c2
                    break
            if existing is not None:
                break
        if existing is not None:
            print("Merging clusters (shouldn't happen very often.)")
            existing.update(c1)
        else:
            merged_clusters.append(set(c1))
    merged_clusters = [list(c) for c in merged_clusters]
    all_mentions = util.flatten(merged_clusters)
    assert len(all_mentions) == len(set(all_mentions))
    return {
        "doc_key": self.doc_key,
        "sentences": self.sentences,
        "speakers": self.speakers,
        "constituents": self.span_dict_to_list(self.constituents),
        "ner": self.span_dict_to_list(self.ner),
        "clusters": merged_clusters
    }
Example #16
Source File: demo.py From coref-ee with Apache License 2.0 | 5 votes |
def print_predictions(example):
    words = util.flatten(example["sentences"])
    for cluster in example["predicted_clusters"]:
        print(u"Predicted cluster: {}".format([" ".join(words[m[0]:m[1]+1]) for m in cluster]))
Example #17
Source File: minimize.py From e2e-coref with Apache License 2.0 | 5 votes |
def finalize(self):
    merged_clusters = []
    for c1 in self.clusters.values():
        existing = None
        for m in c1:
            for c2 in merged_clusters:
                if m in c2:
                    existing = c2
                    break
            if existing is not None:
                break
        if existing is not None:
            print("Merging clusters (shouldn't happen very often.)")
            existing.update(c1)
        else:
            merged_clusters.append(set(c1))
    merged_clusters = [list(c) for c in merged_clusters]
    all_mentions = util.flatten(merged_clusters)
    assert len(all_mentions) == len(set(all_mentions))
    return {
        "doc_key": self.doc_key,
        "sentences": self.sentences,
        "speakers": self.speakers,
        "constituents": self.span_dict_to_list(self.constituents),
        "ner": self.span_dict_to_list(self.ner),
        "clusters": merged_clusters
    }
Example #18
Source File: demo.py From e2e-coref with Apache License 2.0 | 5 votes |
def print_predictions(example):
    words = util.flatten(example["sentences"])
    for cluster in example["predicted_clusters"]:
        print(u"Predicted cluster: {}".format([" ".join(words[m[0]:m[1]+1]) for m in cluster]))
Example #19
Source File: minimize.py From gap with MIT License | 5 votes |
def finalize(self):
    merged_clusters = []
    for c1 in self.clusters.values():
        existing = None
        for m in c1:
            for c2 in merged_clusters:
                if m in c2:
                    existing = c2
                    break
            if existing is not None:
                break
        if existing is not None:
            print("Merging clusters (shouldn't happen very often.)")
            existing.update(c1)
        else:
            merged_clusters.append(set(c1))
    merged_clusters = [list(c) for c in merged_clusters]
    all_mentions = util.flatten(merged_clusters)
    assert len(all_mentions) == len(set(all_mentions))
    return {
        "doc_key": self.doc_key,
        "sentences": self.sentences,
        "speakers": self.speakers,
        "constituents": self.span_dict_to_list(self.constituents),
        "ner": self.span_dict_to_list(self.ner),
        "clusters": merged_clusters
    }
Example #20
Source File: demo.py From gap with MIT License | 5 votes |
def print_predictions(example):
    words = util.flatten(example["sentences"])
    for cluster in example["predicted_clusters"]:
        print(u"Predicted cluster: {}".format([" ".join(words[m[0]:m[1]+1]) for m in cluster]))
Example #21
Source File: minecraft.py From TeachCraft-Challenges with Apache License 2.0 | 5 votes |
def intFloor(*args):
    return [int(math.floor(x)) for x in flatten(args)]
Example #22
Source File: coref_bert_model_2.py From coref-ee with Apache License 2.0 | 4 votes |
def evaluate(self, session, official_stdout=False, pprint=False, test=False):
    self.load_eval_data()

    coref_predictions = {}
    coref_evaluator = metrics.CorefEvaluator()

    for example_num, (tensorized_example, example) in enumerate(self.eval_data):
        feed_dict = {self.input_tensors[k]: tensorized_example[k] for k in self.input_tensors}
        candidate_starts, candidate_ends, candidate_mention_scores, top_span_starts, top_span_ends, top_antecedents, top_antecedent_scores = session.run(
            self.predictions, feed_dict=feed_dict)
        predicted_antecedents = self.get_predicted_antecedents(top_antecedents, top_antecedent_scores)
        coref_predictions[example["doc_key"]] = self.evaluate_coref(top_span_starts, top_span_ends,
                                                                    predicted_antecedents, example["clusters"],
                                                                    coref_evaluator)
        if pprint:
            tokens = util.flatten(example["sentences"])
            print("GOLD CLUSTERS:")
            util.coref_pprint(tokens, example["clusters"])
            print("PREDICTED CLUSTERS:")
            util.coref_pprint(tokens, coref_predictions[example["doc_key"]])
            print("==================================================================")
        if example_num % 10 == 0:
            print("Evaluated {}/{} examples.".format(example_num + 1, len(self.eval_data)))

    summary_dict = {}
    p, r, f = coref_evaluator.get_prf()
    average_f1 = f * 100
    summary_dict["Average F1 (py)"] = average_f1
    print("Average F1 (py): {:.2f}%".format(average_f1))
    summary_dict["Average precision (py)"] = p
    print("Average precision (py): {:.2f}%".format(p * 100))
    summary_dict["Average recall (py)"] = r
    print("Average recall (py): {:.2f}%".format(r * 100))

    # if test:
    #     conll_results = conll.evaluate_conll(self.config["conll_eval_path"], coref_predictions, official_stdout)
    #     average_f1 = sum(results["f"] for results in conll_results.values()) / len(conll_results)
    #     summary_dict["Average F1 (conll)"] = average_f1
    #     print("Average F1 (conll): {:.2f}%".format(average_f1))

    return util.make_summary(summary_dict), average_f1
Example #23
Source File: coref_model.py From coref-ee with Apache License 2.0 | 4 votes |
def evaluate(self, session, official_stdout=False, pprint=False, test=False):
    self.load_eval_data()

    coref_predictions = {}
    coref_evaluator = metrics.CorefEvaluator()

    if not test:
        session.run(self.switch_to_test_mode_op)

    for example_num, (tensorized_example, example) in enumerate(self.eval_data):
        _, _, _, _, _, _, _, _, _, gold_starts, gold_ends, _ = tensorized_example
        feed_dict = {i: t for i, t in zip(self.input_tensors, tensorized_example)}
        candidate_starts, candidate_ends, candidate_mention_scores, top_span_starts, top_span_ends, top_antecedents, top_antecedent_scores = session.run(
            self.predictions, feed_dict=feed_dict)
        predicted_antecedents = self.get_predicted_antecedents(top_antecedents, top_antecedent_scores)
        coref_predictions[example["doc_key"]] = self.evaluate_coref(top_span_starts, top_span_ends,
                                                                    predicted_antecedents, example["clusters"],
                                                                    coref_evaluator)
        if pprint:
            tokens = util.flatten(example["sentences"])
            print("GOLD CLUSTERS:")
            util.coref_pprint(tokens, example["clusters"])
            print("PREDICTED CLUSTERS:")
            util.coref_pprint(tokens, coref_predictions[example["doc_key"]])
            print('==================================================================')
        if example_num % 10 == 0:
            print("Evaluated {}/{} examples.".format(example_num + 1, len(self.eval_data)))

    if not test:
        session.run(self.switch_to_train_mode_op)

    summary_dict = {}
    p, r, f = coref_evaluator.get_prf()
    average_f1 = f * 100
    summary_dict["Average F1 (py)"] = average_f1
    print("Average F1 (py): {:.2f}%".format(average_f1))
    summary_dict["Average precision (py)"] = p
    print("Average precision (py): {:.2f}%".format(p * 100))
    summary_dict["Average recall (py)"] = r
    print("Average recall (py): {:.2f}%".format(r * 100))

    # if test:
    #     conll_results = conll.evaluate_conll(self.config["conll_eval_path"], coref_predictions, official_stdout)
    #     average_f1 = sum(results["f"] for results in conll_results.values()) / len(conll_results)
    #     summary_dict["Average F1 (conll)"] = average_f1
    #     print("Average F1 (conll): {:.2f}%".format(average_f1))

    return util.make_summary(summary_dict), average_f1
Example #24
Source File: coref_model.py From coref-ee with Apache License 2.0 | 4 votes |
def tensorize_example(self, example, is_training):
    clusters = example["clusters"]

    gold_mentions = sorted(tuple(m) for m in util.flatten(clusters))
    gold_mention_map = {m: i for i, m in enumerate(gold_mentions)}
    cluster_ids = np.zeros(len(gold_mentions))
    for cluster_id, cluster in enumerate(clusters):
        for mention in cluster:
            cluster_ids[gold_mention_map[tuple(mention)]] = cluster_id + 1

    sentences = example["sentences"]
    num_words = sum(len(s) for s in sentences)
    speakers = util.flatten(example["speakers"])

    assert num_words == len(speakers)

    max_sentence_length = max(len(s) for s in sentences)
    max_word_length = max(max(max(len(w) for w in s) for s in sentences), max(self.config["filter_widths"]))
    text_len = np.array([len(s) for s in sentences])
    tokens = [[""] * max_sentence_length for _ in sentences]
    context_word_emb = np.zeros([len(sentences), max_sentence_length, self.context_embeddings.size])
    head_word_emb = np.zeros([len(sentences), max_sentence_length, self.head_embeddings.size])
    char_index = np.zeros([len(sentences), max_sentence_length, max_word_length])
    for i, sentence in enumerate(sentences):
        for j, word in enumerate(sentence):
            tokens[i][j] = word
            context_word_emb[i, j] = self.context_embeddings[word]
            head_word_emb[i, j] = self.head_embeddings[word]
            char_index[i, j, :len(word)] = [self.char_dict[c] for c in word]
    tokens = np.array(tokens)

    speaker_dict = {s: i for i, s in enumerate(set(speakers))}
    speaker_ids = np.array([speaker_dict[s] for s in speakers])

    doc_key = example["doc_key"]
    genre = self.genres[doc_key[:2]]

    gold_starts, gold_ends = self.tensorize_mentions(gold_mentions)

    lm_emb = self.load_lm_embeddings(doc_key)

    example_tensors = (
        tokens, context_word_emb, head_word_emb, lm_emb, char_index, text_len, speaker_ids, genre,
        is_training, gold_starts, gold_ends, cluster_ids)

    if is_training and len(sentences) > self.config["max_training_sentences"]:
        return self.truncate_example(*example_tensors)
    else:
        return example_tensors
Example #25
Source File: coref_model.py From e2e-coref with Apache License 2.0 | 4 votes |
def tensorize_example(self, example, is_training):
    clusters = example["clusters"]

    gold_mentions = sorted(tuple(m) for m in util.flatten(clusters))
    gold_mention_map = {m: i for i, m in enumerate(gold_mentions)}
    cluster_ids = np.zeros(len(gold_mentions))
    for cluster_id, cluster in enumerate(clusters):
        for mention in cluster:
            cluster_ids[gold_mention_map[tuple(mention)]] = cluster_id + 1

    sentences = example["sentences"]
    num_words = sum(len(s) for s in sentences)
    speakers = util.flatten(example["speakers"])

    assert num_words == len(speakers)

    max_sentence_length = max(len(s) for s in sentences)
    max_word_length = max(max(max(len(w) for w in s) for s in sentences), max(self.config["filter_widths"]))
    text_len = np.array([len(s) for s in sentences])
    tokens = [[""] * max_sentence_length for _ in sentences]
    context_word_emb = np.zeros([len(sentences), max_sentence_length, self.context_embeddings.size])
    head_word_emb = np.zeros([len(sentences), max_sentence_length, self.head_embeddings.size])
    char_index = np.zeros([len(sentences), max_sentence_length, max_word_length])
    for i, sentence in enumerate(sentences):
        for j, word in enumerate(sentence):
            tokens[i][j] = word
            context_word_emb[i, j] = self.context_embeddings[word]
            head_word_emb[i, j] = self.head_embeddings[word]
            char_index[i, j, :len(word)] = [self.char_dict[c] for c in word]
    tokens = np.array(tokens)

    speaker_dict = {s: i for i, s in enumerate(set(speakers))}
    speaker_ids = np.array([speaker_dict[s] for s in speakers])

    doc_key = example["doc_key"]
    genre = self.genres[doc_key[:2]]

    gold_starts, gold_ends = self.tensorize_mentions(gold_mentions)

    lm_emb = self.load_lm_embeddings(doc_key)

    example_tensors = (tokens, context_word_emb, head_word_emb, lm_emb, char_index, text_len, speaker_ids, genre,
                       is_training, gold_starts, gold_ends, cluster_ids)

    if is_training and len(sentences) > self.config["max_training_sentences"]:
        return self.truncate_example(*example_tensors)
    else:
        return example_tensors
Example #26
Source File: gold_mentions.py From coref with Apache License 2.0 | 4 votes |
def tensorize_example(self, example, is_training):
    clusters = example["clusters"]

    gold_mentions = sorted(tuple(m) for m in util.flatten(clusters))
    gold_mention_map = {m: i for i, m in enumerate(gold_mentions)}
    cluster_ids = np.zeros(len(gold_mentions))
    for cluster_id, cluster in enumerate(clusters):
        for mention in cluster:
            cluster_ids[gold_mention_map[tuple(mention)]] = cluster_id + 1

    sentences = example["sentences"]
    num_words = sum(len(s) for s in sentences)
    speakers = example["speakers"]
    # assert num_words == len(speakers), (num_words, len(speakers))
    speaker_dict = self.get_speaker_dict(util.flatten(speakers))
    sentence_map = example['sentence_map']

    max_sentence_length = self.max_segment_len
    text_len = np.array([len(s) for s in sentences])

    input_ids, input_mask, speaker_ids = [], [], []
    for i, (sentence, speaker) in enumerate(zip(sentences, speakers)):
        sent_input_ids = self.tokenizer.convert_tokens_to_ids(sentence)
        sent_input_mask = [1] * len(sent_input_ids)
        sent_speaker_ids = [speaker_dict.get(s, 3) for s in speaker]
        while len(sent_input_ids) < max_sentence_length:
            sent_input_ids.append(0)
            sent_input_mask.append(0)
            sent_speaker_ids.append(0)
        input_ids.append(sent_input_ids)
        speaker_ids.append(sent_speaker_ids)
        input_mask.append(sent_input_mask)
    input_ids = np.array(input_ids)
    input_mask = np.array(input_mask)
    speaker_ids = np.array(speaker_ids)
    assert num_words == np.sum(input_mask), (num_words, np.sum(input_mask))

    # speaker_dict = { s:i for i,s in enumerate(set(speakers)) }
    # speaker_ids = np.array([speaker_dict[s] for s in speakers])
    doc_key = example["doc_key"]
    self.subtoken_maps[doc_key] = example["subtoken_map"]
    self.gold[doc_key] = example["clusters"]
    genre = self.genres.get(doc_key[:2], 0)

    gold_starts, gold_ends = self.tensorize_mentions(gold_mentions)

    example_tensors = (input_ids, input_mask, text_len, speaker_ids, genre, is_training, gold_starts, gold_ends,
                       cluster_ids, sentence_map)

    if is_training and len(sentences) > self.config["max_training_sentences"]:
        if self.config['single_example']:
            return self.truncate_example(*example_tensors)
        else:
            offsets = range(self.config['max_training_sentences'], len(sentences),
                            self.config['max_training_sentences'])
            tensor_list = [self.truncate_example(*(example_tensors + (offset,))) for offset in offsets]
            return tensor_list
    else:
        return example_tensors
Example #27
Source File: independent.py From coref with Apache License 2.0 | 4 votes |
def tensorize_example(self, example, is_training):
    clusters = example["clusters"]

    gold_mentions = sorted(tuple(m) for m in util.flatten(clusters))
    gold_mention_map = {m: i for i, m in enumerate(gold_mentions)}
    cluster_ids = np.zeros(len(gold_mentions))
    for cluster_id, cluster in enumerate(clusters):
        for mention in cluster:
            cluster_ids[gold_mention_map[tuple(mention)]] = cluster_id + 1

    sentences = example["sentences"]
    num_words = sum(len(s) for s in sentences)
    speakers = example["speakers"]
    # assert num_words == len(speakers), (num_words, len(speakers))
    speaker_dict = self.get_speaker_dict(util.flatten(speakers))
    sentence_map = example['sentence_map']

    max_sentence_length = self.max_segment_len
    text_len = np.array([len(s) for s in sentences])

    input_ids, input_mask, speaker_ids = [], [], []
    for i, (sentence, speaker) in enumerate(zip(sentences, speakers)):
        sent_input_ids = self.tokenizer.convert_tokens_to_ids(sentence)
        sent_input_mask = [1] * len(sent_input_ids)
        sent_speaker_ids = [speaker_dict.get(s, 3) for s in speaker]
        while len(sent_input_ids) < max_sentence_length:
            sent_input_ids.append(0)
            sent_input_mask.append(0)
            sent_speaker_ids.append(0)
        input_ids.append(sent_input_ids)
        speaker_ids.append(sent_speaker_ids)
        input_mask.append(sent_input_mask)
    input_ids = np.array(input_ids)
    input_mask = np.array(input_mask)
    speaker_ids = np.array(speaker_ids)
    assert num_words == np.sum(input_mask), (num_words, np.sum(input_mask))

    doc_key = example["doc_key"]
    self.subtoken_maps[doc_key] = example.get("subtoken_map", None)
    self.gold[doc_key] = example["clusters"]
    genre = self.genres.get(doc_key[:2], 0)

    gold_starts, gold_ends = self.tensorize_mentions(gold_mentions)

    example_tensors = (input_ids, input_mask, text_len, speaker_ids, genre, is_training, gold_starts, gold_ends,
                       cluster_ids, sentence_map)

    if is_training and len(sentences) > self.config["max_training_sentences"]:
        if self.config['single_example']:
            return self.truncate_example(*example_tensors)
        else:
            offsets = range(self.config['max_training_sentences'], len(sentences),
                            self.config['max_training_sentences'])
            tensor_list = [self.truncate_example(*(example_tensors + (offset,))) for offset in offsets]
            return tensor_list
    else:
        return example_tensors