Java Code Examples for it.unimi.dsi.fastutil.ints.IntSet#contains()
The following examples show how to use it.unimi.dsi.fastutil.ints.IntSet#contains().
You can go to the original project or source file by following the links above each example.
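Before the project examples, here is a minimal, self-contained sketch of the call itself. The class name IntSetContainsDemo is only illustrative; IntOpenHashSet, add(int), and contains(int) are the fastutil API used throughout the examples below.

import it.unimi.dsi.fastutil.ints.IntOpenHashSet;
import it.unimi.dsi.fastutil.ints.IntSet;

public class IntSetContainsDemo {
    public static void main(String[] args) {
        // IntSet stores primitive ints, so contains(int) checks membership without autoboxing.
        IntSet seen = new IntOpenHashSet();
        seen.add(3);
        seen.add(7);

        System.out.println(seen.contains(3));  // true
        System.out.println(seen.contains(42)); // false
    }
}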
Example 1
Source File: ExpReplay.java From deeplearning4j with Apache License 2.0 | 6 votes |
public ArrayList<Transition<A>> getBatch(int size) {
    ArrayList<Transition<A>> batch = new ArrayList<>(size);
    int storageSize = storage.size();
    int actualBatchSize = Math.min(storageSize, size);

    int[] actualIndex = new int[actualBatchSize];
    IntSet set = new IntOpenHashSet();
    for (int i = 0; i < actualBatchSize; i++) {
        int next = rnd.nextInt(storageSize);
        while (set.contains(next)) {
            next = rnd.nextInt(storageSize);
        }
        set.add(next);
        actualIndex[i] = next;
    }

    for (int i = 0; i < actualBatchSize; i++) {
        Transition<A> trans = storage.get(actualIndex[i]);
        batch.add(trans.dup());
    }

    return batch;
}
Example 2
Source File: NbestListUtils.java From phrasal with GNU General Public License v3.0 | 6 votes |
/**
 * Baseline implementation. Augments the "standard" list with alternatives.
 *
 * @param standard
 * @param alt
 * @param maxAltItems
 * @return
 */
public static <TK,FV> List<RichTranslation<TK,FV>> mergeAndDedup(List<RichTranslation<TK,FV>> standard,
    List<RichTranslation<TK,FV>> alt, int maxAltItems) {

    IntSet hashCodeSet = new IntOpenHashSet(standard.size());
    for (RichTranslation<TK,FV> s : standard) {
        hashCodeSet.add(derivationHashCode(s.getFeaturizable().derivation));
    }

    List<RichTranslation<TK,FV>> returnList = new ArrayList<>(standard);
    for (int i = 0, sz = Math.min(maxAltItems, alt.size()); i < sz; ++i) {
        RichTranslation<TK,FV> t = alt.get(i);
        int hashCode = derivationHashCode(t.getFeaturizable().derivation);
        if (! hashCodeSet.contains(hashCode)) returnList.add(t);
    }
    Collections.sort(returnList);

    return returnList;
}
Example 3
Source File: IntToIntPairMapTestHelper.java From GraphJet with Apache License 2.0 | 6 votes |
public static KeyTestInfo generateRandomKeys(Random random, int maxNumKeys) {
    int maxKeyOrValue = maxNumKeys << 2;
    int[] keysAndValues = new int[maxNumKeys * 3];
    int[] nonKeys = new int[maxNumKeys];
    IntSet keySet = new IntOpenHashBigSet(maxNumKeys);
    for (int i = 0; i < maxNumKeys; i++) {
        int entry;
        do {
            entry = random.nextInt(maxKeyOrValue);
        } while (keySet.contains(entry));
        keysAndValues[i * 3] = entry;
        keysAndValues[i * 3 + 1] = random.nextInt(maxKeyOrValue);
        keysAndValues[i * 3 + 2] = random.nextInt(maxKeyOrValue);
        keySet.add(entry);
    }
    for (int i = 0; i < maxNumKeys; i++) {
        int nonKey;
        do {
            nonKey = random.nextInt(maxKeyOrValue);
        } while (keySet.contains(nonKey));
        nonKeys[i] = nonKey;
    }
    return new KeyTestInfo(keysAndValues, nonKeys);
}
Example 4
Source File: MAPEvaluator.java From jstarcraft-ai with Apache License 2.0 | 6 votes |
@Override
protected float measure(IntSet checkCollection, IntList rankList) {
    if (rankList.size() > size) {
        rankList = rankList.subList(0, size);
    }
    int count = 0;
    float map = 0F;
    for (int index = 0; index < rankList.size(); index++) {
        int itemIndex = rankList.get(index);
        if (checkCollection.contains(itemIndex)) {
            count++;
            map += 1F * count / (index + 1);
        }
    }
    return map / (checkCollection.size() < rankList.size() ? checkCollection.size() : rankList.size());
}
Example 5
Source File: PageToCategoryIDs.java From tagme with Apache License 2.0 | 5 votes |
@Override
protected int[][] parseSet() throws IOException {
    final Int2ObjectMap<IntSet> map = new Int2ObjectOpenHashMap<IntSet>(3000000);
    final IntSet hidden = DatasetLoader.get(new HiddenCategoriesWIDs(lang));
    File input = WikipediaFiles.CAT_LINKS.getSourceFile(lang);
    final Object2IntMap<String> categories = DatasetLoader.get(new CategoriesToWIDMap(lang));

    SQLWikiParser parser = new SQLWikiParser(log) {
        @Override
        public boolean compute(ArrayList<String> values) throws IOException {
            String c_title = cleanPageName(values.get(SQLWikiParser.CATLINKS_TITLE_TO));
            int id = Integer.parseInt(values.get(SQLWikiParser.CATLINKS_ID_FROM));
            if (categories.containsKey(c_title) && !hidden.contains(categories.get(c_title).intValue())) {
                if (map.containsKey(id)) {
                    map.get(id).add(categories.get(c_title).intValue());
                } else {
                    IntSet set = new IntOpenHashSet();
                    set.add(categories.get(c_title).intValue());
                    map.put(id, set);
                }
                return true;
            } else return false;
        }
    };

    InputStreamReader reader = new InputStreamReader(new FileInputStream(input), Charset.forName("UTF-8"));
    parser.compute(reader);
    reader.close();

    return createDump(map);
}
Example 6
Source File: ValueInTransformFunction.java From incubator-pinot with Apache License 2.0 | 5 votes |
private static int[] filterInts(IntSet intSet, int[] source) {
    IntList intList = new IntArrayList();
    for (int value : source) {
        if (intSet.contains(value)) {
            intList.add(value);
        }
    }
    if (intList.size() == source.length) {
        return source;
    } else {
        return intList.toIntArray();
    }
}
Example 7
Source File: FastFilters.java From RankSys with Mozilla Public License 2.0 | 5 votes |
/**
 * Item filter that discards items in the training preference data.
 *
 * @param <U> type of the users
 * @param <I> type of the items
 * @param trainData preference data
 * @return item filter for each user that returns true if the
 *         user-item pair was not observed in the preference data
 */
public static <U, I> Function<U, IntPredicate> notInTrain(FastPreferenceData<U, I> trainData) {
    return user -> {
        IntSet set = new IntOpenHashSet();
        trainData.getUidxPreferences(trainData.user2uidx(user))
                .mapToInt(IdxPref::v1)
                .forEach(set::add);

        return iidx -> !set.contains(iidx);
    };
}
Example 8
Source File: NegativeSamplingExpander.java From samantha with MIT License | 5 votes |
private IntList getSampledIndices(IntSet trues, int maxVal) {
    IntList samples = new IntArrayList();
    int num = trues.size();
    if (maxNumSample != null) {
        num = maxNumSample;
    }
    for (int i = 0; i < num; i++) {
        int dice = new Random().nextInt(maxVal);
        if (!trues.contains(dice)) {
            samples.add(dice);
        }
    }
    return samples;
}
Example 9
Source File: AgreeSetGenerator.java From metanome-algorithms with Apache License 2.0 | 5 votes |
private void intersect(IntSet positions, IntSet indexSet) {
    IntSet toRemove = new IntArraySet();
    for (int l : positions) {
        if (!indexSet.contains(l)) {
            toRemove.add(l);
        }
    }
    positions.removeAll(toRemove);
}
Example 10
Source File: RecallEvaluator.java From jstarcraft-ai with Apache License 2.0 | 5 votes |
@Override
protected float measure(IntSet checkCollection, IntList rankList) {
    if (rankList.size() > size) {
        rankList = rankList.subList(0, size);
    }
    int count = 0;
    for (int itemIndex : rankList) {
        if (checkCollection.contains(itemIndex)) {
            count++;
        }
    }
    return count / (checkCollection.size() + 0F);
}
Example 11
Source File: PrecisionEvaluator.java From jstarcraft-ai with Apache License 2.0 | 5 votes |
@Override
protected float measure(IntSet checkCollection, IntList rankList) {
    if (rankList.size() > size) {
        rankList = rankList.subList(0, size);
    }
    int count = 0;
    for (int itemIndex : rankList) {
        if (checkCollection.contains(itemIndex)) {
            count++;
        }
    }
    return count / (size + 0F);
}
Example 12
Source File: RankingTask.java From jstarcraft-rns with Apache License 2.0 | 5 votes |
@Override
protected IntList recommend(Model recommender, int userIndex) {
    ReferenceModule trainModule = trainModules[userIndex];
    ReferenceModule testModule = testModules[userIndex];
    IntSet itemSet = new IntOpenHashSet();
    for (DataInstance instance : trainModule) {
        itemSet.add(instance.getQualityFeature(itemDimension));
    }
    // TODO This code needs to be refactored
    ArrayInstance copy = new ArrayInstance(trainMarker.getQualityOrder(), trainMarker.getQuantityOrder());
    copy.copyInstance(testModule.getInstance(0));
    copy.setQualityFeature(userDimension, userIndex);
    List<Integer2FloatKeyValue> rankList = new ArrayList<>(itemSize - itemSet.size());
    for (int itemIndex = 0; itemIndex < itemSize; itemIndex++) {
        if (itemSet.contains(itemIndex)) {
            continue;
        }
        copy.setQualityFeature(itemDimension, itemIndex);
        recommender.predict(copy);
        rankList.add(new Integer2FloatKeyValue(itemIndex, copy.getQuantityMark()));
    }
    Collections.sort(rankList, (left, right) -> {
        return Float.compare(right.getValue(), left.getValue());
    });
    IntList recommendList = new IntArrayList(rankList.size());
    for (Integer2FloatKeyValue keyValue : rankList) {
        recommendList.add(keyValue.getKey());
    }
    return recommendList;
}
Example 13
Source File: SBPRModel.java From jstarcraft-rns with Apache License 2.0 | 5 votes |
@Override
public void prepare(Configurator configuration, DataModule model, DataSpace space) {
    super.prepare(configuration, model, space);

    regBias = configuration.getFloat("recommender.bias.regularization", 0.01F);
    // cacheSpec = conf.get("guava.cache.spec",
    // "maximumSize=5000,expireAfterAccess=50m");

    itemBiases = DenseVector.valueOf(itemSize);
    itemBiases.iterateElement(MathCalculator.SERIAL, (scalar) -> {
        scalar.setValue(RandomUtility.randomFloat(1F));
    });

    userItemSet = getUserItemSet(scoreMatrix);

    // TODO Consider refactoring
    // find items rated by trusted neighbors only
    socialItemList = new ArrayList<>(userSize);

    for (int userIndex = 0; userIndex < userSize; userIndex++) {
        SparseVector userVector = scoreMatrix.getRowVector(userIndex);
        IntSet itemSet = userItemSet.get(userIndex);

        // find items rated by trusted neighbors only
        SparseVector socialVector = socialMatrix.getRowVector(userIndex);
        List<Integer> socialList = new LinkedList<>();
        for (VectorScalar term : socialVector) {
            int socialIndex = term.getIndex();
            userVector = scoreMatrix.getRowVector(socialIndex);
            for (VectorScalar entry : userVector) {
                int itemIndex = entry.getIndex();
                // v's rated items
                if (!itemSet.contains(itemIndex) && !socialList.contains(itemIndex)) {
                    socialList.add(itemIndex);
                }
            }
        }
        socialItemList.add(new ArrayList<>(socialList));
    }
}
Example 14
Source File: PRankDModel.java From jstarcraft-rns with Apache License 2.0 | 4 votes |
/**
 * train model
 *
 * @throws ModelException if error occurs
 */
@Override
protected void doPractice() {
    List<IntSet> userItemSet = getUserItemSet(scoreMatrix);
    for (int epocheIndex = 0; epocheIndex < epocheSize; epocheIndex++) {
        totalError = 0F;
        // for each rated user-item (u,i) pair
        for (int userIndex = 0; userIndex < userSize; userIndex++) {
            SparseVector userVector = scoreMatrix.getRowVector(userIndex);
            if (userVector.getElementSize() == 0) {
                continue;
            }
            IntSet itemSet = userItemSet.get(userIndex);
            for (VectorScalar term : userVector) {
                // each rated item i
                int positiveItemIndex = term.getIndex();
                float positiveScore = term.getValue();
                int negativeItemIndex = -1;
                do {
                    // draw an item j with probability proportional to popularity
                    negativeItemIndex = SampleUtility.binarySearch(itemProbabilities, 0, itemProbabilities.getElementSize() - 1, RandomUtility.randomFloat(itemProbabilities.getValue(itemProbabilities.getElementSize() - 1)));
                    // ensure that it is unrated by user u
                } while (itemSet.contains(negativeItemIndex));
                float negativeScore = 0F;

                // compute predictions
                float positivePredict = predict(userIndex, positiveItemIndex), negativePredict = predict(userIndex, negativeItemIndex);
                float distance = (float) Math.sqrt(1 - Math.tanh(itemCorrelations.getValue(positiveItemIndex, negativeItemIndex) * similarityFilter));
                float itemWeight = itemWeights.getValue(negativeItemIndex);
                float error = itemWeight * (positivePredict - negativePredict - distance * (positiveScore - negativeScore));
                totalError += error * error;

                // update vectors
                float learnFactor = learnRatio * error;
                for (int factorIndex = 0; factorIndex < factorSize; factorIndex++) {
                    float userFactor = userFactors.getValue(userIndex, factorIndex);
                    float positiveItemFactor = itemFactors.getValue(positiveItemIndex, factorIndex);
                    float negativeItemFactor = itemFactors.getValue(negativeItemIndex, factorIndex);
                    userFactors.shiftValue(userIndex, factorIndex, -learnFactor * (positiveItemFactor - negativeItemFactor));
                    itemFactors.shiftValue(positiveItemIndex, factorIndex, -learnFactor * userFactor);
                    itemFactors.shiftValue(negativeItemIndex, factorIndex, learnFactor * userFactor);
                }
            }
        }

        totalError *= 0.5F;
        if (isConverged(epocheIndex) && isConverged) {
            break;
        }
        isLearned(epocheIndex);
        currentError = totalError;
    }
}
Example 15
Source File: WikipediaEdges.java From tagme with Apache License 2.0 | 4 votes |
@Override
protected void parseFile(File file) throws IOException {
    final Int2IntMap redirects = DatasetLoader.get(new RedirectMap(lang));
    final IntSet disambiguations = DatasetLoader.get(new DisambiguationWIDs(lang));
    final IntSet listpages = DatasetLoader.get(new ListPageWIDs(lang));
    final IntSet ignores = DatasetLoader.get(new IgnoreWIDs(lang));
    final IntSet valids = new AllWIDs(lang).getDataset(); //DatasetLoader.get(new AllWIDs(lang));
    valids.removeAll(redirects.keySet());
    //valids.removeAll(disambiguations);
    //valids.removeAll(listpages);
    valids.removeAll(ignores);
    final Object2IntMap<String> titles = DatasetLoader.get(new TitlesToWIDMap(lang));

    File tmp = Dataset.createTmpFile();
    final BufferedWriter out = new BufferedWriter(new FileWriter(tmp));

    SQLWikiParser parser = new SQLWikiParser(log) {
        @Override
        public boolean compute(ArrayList<String> values) throws IOException {
            int idFrom = Integer.parseInt(values.get(SQLWikiParser.PAGELINKS_ID_FROM));
            if (redirects.containsKey(idFrom)) idFrom = redirects.get(idFrom);
            int ns = Integer.parseInt(values.get(SQLWikiParser.PAGELINKS_NS));
            if (ns == SQLWikiParser.NS_ARTICLE &&
                    !redirects.containsKey(idFrom) && !ignores.contains(idFrom) &&
                    // this is necessary because some pages that are lists end up, in English,
                    // among the disambiguation pages (because of the All_set_index_articles category)
                    (listpages.contains(idFrom) || !disambiguations.contains(idFrom))
                    //!listpages.contains(idFrom) && !disambiguations.contains(idFrom)
                    && valids.contains(idFrom)
                    /**/
            ) {
                String titleTo = Dataset.cleanPageName(values.get(SQLWikiParser.PAGELINKS_TITLE_TO));
                int idTo = titles.getInt(titleTo);
                if (redirects.containsKey(idTo)) idTo = redirects.get(idTo);

                if (idTo >= 0 && !ignores.contains(idTo) &&
                        (listpages.contains(idFrom) || !disambiguations.contains(idFrom)) &&
                        valids.contains(idTo)) {
                    out.append(Integer.toString(idFrom));
                    out.append(SEP_CHAR);
                    out.append(Integer.toString(idTo));
                    out.append('\n');
                    return true;
                }
            }
            return false;
        }
    };
    File input = WikipediaFiles.PAGE_LINKS.getSourceFile(lang);
    parser.compute(input);
    out.close();

    log.info("Now sorting edges...");

    ExternalSort sorter = new ExternalSort();
    sorter.setUniq(true);
    sorter.setNumeric(true);
    sorter.setColumns(new int[]{0,1});
    sorter.setInFile(tmp.getAbsolutePath());
    sorter.setOutFile(file.getAbsolutePath());
    sorter.run();

    tmp.delete();

    log.info("Sorted. Done.");
}
Example 16
Source File: DiverseNbestDecoder.java From phrasal with GNU General Public License v3.0 | 4 votes |
/**
 * Extract the n-best list.
 *
 * @param size
 * @param distinct
 * @return
 */
public List<Derivation<TK,FV>> decode(int size, boolean distinct, int sourceInputId,
    FeatureExtractor<TK, FV> featurizer, Scorer<FV> scorer, SearchHeuristic<TK, FV> heuristic,
    OutputSpace<TK, FV> outputSpace) {
    if (isIncompleteLattice) return Collections.emptyList();

    List<Derivation<TK,FV>> returnList = new ArrayList<>(size);

    // WSGDEBUG
    // TODO(spenceg) Remaining bugs
    //
    // 1) Sometimes duplicate derivations can be extracted. Probably has to do with recombination.
    //
    for (int i = 0, sz = markedNodes.size(); i < sz; ++i) {
    // for (int i = 0, sz = Math.min(markedNodes.size(), size); i < sz; ++i) {
        Derivation<TK,FV> node = markedNodes.get(i);
        Derivation<TK,FV> finalDerivation = constructDerivation(node, sourceInputId, featurizer, scorer,
            heuristic, outputSpace);
        returnList.add(finalDerivation);
    }

    // Sort the return list
    returnList = returnList.stream().sorted().limit(size).collect(Collectors.toList());

    // Apply distinctness after the sort. The ordering of markedNodes doesn't account for
    // combination costs.
    if (distinct) {
        IntSet uniqSet = new IntOpenHashSet(markedNodes.size());
        List<Derivation<TK,FV>> uniqList = new ArrayList<>(returnList.size());
        for (Derivation<TK,FV> d : returnList) {
            int hashCode = d.targetSequence.hashCode();
            if (! uniqSet.contains(hashCode)) {
                uniqSet.add(hashCode);
                uniqList.add(d);
            }
        }
        returnList = uniqList;
    }

    // WSGDEBUG
    // System.err.printf("### %d: %d marked nodes ########%n", sourceInputId, markedNodes.size());
    // System.err.println(prefix);
    // System.err.println(oneBest);
    // System.err.println("-------");
    // returnList.stream().forEach(d -> {
    //     System.err.println(d);
    // });
    // if (returnList.get(0).score < oneBest.score) {
    //     System.err.println(returnList.get(0));
    //     System.err.println(oneBest);
    // }

    return returnList;
}
Example 17
Source File: RankSGDModel.java From jstarcraft-rns with Apache License 2.0 | 4 votes |
@Override
protected void doPractice() {
    List<IntSet> userItemSet = getUserItemSet(scoreMatrix);
    for (int epocheIndex = 0; epocheIndex < epocheSize; epocheIndex++) {
        totalError = 0F;
        // for each rated user-item (u,i) pair
        for (MatrixScalar term : scoreMatrix) {
            int userIndex = term.getRow();
            IntSet itemSet = userItemSet.get(userIndex);
            int positiveItemIndex = term.getColumn();
            float positiveScore = term.getValue();
            int negativeItemIndex = -1;

            do {
                // draw an item j with probability proportional to popularity
                negativeItemIndex = SampleUtility.binarySearch(itemProbabilities, 0, itemProbabilities.getElementSize() - 1, RandomUtility.randomFloat(itemProbabilities.getValue(itemProbabilities.getElementSize() - 1)));
                // ensure that it is unrated by user u
            } while (itemSet.contains(negativeItemIndex));
            float negativeScore = 0F;

            // compute predictions
            float error = (predict(userIndex, positiveItemIndex) - predict(userIndex, negativeItemIndex)) - (positiveScore - negativeScore);
            totalError += error * error;

            // update vectors
            float value = learnRatio * error;
            for (int factorIndex = 0; factorIndex < factorSize; factorIndex++) {
                float userFactor = userFactors.getValue(userIndex, factorIndex);
                float positiveItemFactor = itemFactors.getValue(positiveItemIndex, factorIndex);
                float negativeItemFactor = itemFactors.getValue(negativeItemIndex, factorIndex);
                userFactors.shiftValue(userIndex, factorIndex, -value * (positiveItemFactor - negativeItemFactor));
                itemFactors.shiftValue(positiveItemIndex, factorIndex, -value * userFactor);
                itemFactors.shiftValue(negativeItemIndex, factorIndex, value * userFactor);
            }
        }

        totalError *= 0.5D;
        if (isConverged(epocheIndex) && isConverged) {
            break;
        }
        isLearned(epocheIndex);
        currentError = totalError;
    }
}