Java Code Examples for weka.core.Utils#maxIndex()
The following examples show how to use
weka.core.Utils#maxIndex().
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example. You may check out the related API usage on the sidebar.
Example 1
Source File: Metrics.java From meka with GNU General Public License v3.0 | 6 votes |
/** * OneError - */ public static double L_OneError(int Y[][], double Rpred[][]) { // works with missing int N = Y.length; int one_error = 0; int missing = 0; for(int i = 0; i < N; i++) { if(allMissing(Y[i])){ missing ++; continue; } if(Y[i][Utils.maxIndex(Rpred[i])] == 0) one_error++; } N-= missing; if (N == 0) { return Double.NaN; } return (double)one_error/(double)N; }
Example 2
Source File: DMNBtext.java From tsml with GNU General Public License v3.0 | 6 votes |
/** * Calculates the class membership probabilities for the given test * instance. * * @param instance the instance to be classified * @return predicted class probability distribution * @exception Exception if there is a problem generating the prediction */ public double[] distributionForInstance(Instance instance) throws Exception { if (m_numClasses == 2) { // System.out.println(m_binaryClassifiers[0].getProbForTargetClass(instance)); return m_binaryClassifiers[0].distributionForInstance(instance); } double[] logDocGivenClass = new double[instance.numClasses()]; for (int i = 0; i < m_numClasses; i++) logDocGivenClass[i] = m_binaryClassifiers[i].getLogProbForTargetClass(instance); double max = logDocGivenClass[Utils.maxIndex(logDocGivenClass)]; for(int i = 0; i<m_numClasses; i++) logDocGivenClass[i] = Math.exp(logDocGivenClass[i] - max); try { Utils.normalize(logDocGivenClass); } catch (Exception e) { e.printStackTrace(); } return logDocGivenClass; }
Example 3
Source File: CR.java From meka with GNU General Public License v3.0 | 6 votes |
@Override public double[] distributionForInstance(Instance x) throws Exception { int L = x.classIndex(); double y[] = new double[L*2]; for (int j = 0; j < L; j++) { Instance x_j = (Instance)x.copy(); x_j.setDataset(null); x_j = MLUtils.keepAttributesAt(x_j,new int[]{j},L); x_j.setDataset(m_Templates[j]); double w[] = m_MultiClassifiers[j].distributionForInstance(x_j); // e.g. [0.1, 0.8, 0.1] y[j] = Utils.maxIndex(w); // e.g. 1 y[L+j] = w[(int)y[j]]; // e.g. 0.8 } return y; }
Example 4
Source File: PSUtils.java From meka with GNU General Public License v3.0 | 6 votes |
/**
 * Convert Distribution - Given the posterior across combinations, return the distribution across labels.
 * Only the combination with the highest posterior is used: every position of its
 * decoded array that is greater than zero is set to 1.0 in the result.
 * <br>
 * TODO Use recombination!!!
 * @see PSUtils#recombination(double[],int,LabelSet[])
 * @param p the posterior of the super classes (combinations), e.g., P([1,3],[2]) = [1,0]
 * @param L the number of labels
 * @param iTemplate the dataset whose class attribute values name the combinations
 * @return the distribution across labels, e.g., P(1,2,3) = [1,0,1]
 */
@Deprecated
public static double[] convertDistribution(double p[], int L, Instances iTemplate) {
  final double[] labels = new double[L];

  final int best = Utils.maxIndex(p);
  final String combination = iTemplate.classAttribute().value(best);
  final double[] decoded = toDoubleArray(combination, L);

  int pos = 0;
  while (pos < decoded.length) {
    if (decoded[pos] > 0.0) {
      labels[pos] = 1.0;
    }
    pos++;
  }
  return labels;
}
Example 5
Source File: Iadem3.java From moa with GNU General Public License v3.0 | 5 votes |
/**
 * Feeds a 0/1 error for the majority-class votes and for the naive Bayes
 * prediction into their respective error estimators, then delegates to the
 * superclass implementation.
 *
 * @param inst the training instance
 * @return whatever the superclass returns for this instance
 */
@Override
public Node learnFromInstance(Instance inst) {
  double[] majorityVotes = getMajorityClassVotes(inst);
  boolean majorityHit = Utils.maxIndex(majorityVotes) == (int) inst.classValue();
  this.majorityClassError.input(majorityHit ? 0.0 : 1.0);

  double[] bayesVotes = getNaiveBayesPrediction(inst);
  boolean bayesHit = Utils.maxIndex(bayesVotes) == (int) inst.classValue();
  this.naiveBayesError.input(bayesHit ? 0.0 : 1.0);

  return super.learnFromInstance(inst);
}
Example 6
Source File: DynamicWeightedMajority.java From moa with GNU General Public License v3.0 | 5 votes |
@Override public double[] getVotesForInstance(Instance inst) { double[] Pr = new double[inst.numClasses()]; for (int i = 0; i < this.experts.size(); i++) { double[] pr = this.experts.get(i).getVotesForInstance(inst); int yHat = Utils.maxIndex(pr); Pr[yHat] += this.weights.get(i); } // for Utils.normalize(Pr); return Pr; }
Example 7
Source File: NSR.java From meka with GNU General Public License v3.0 | 5 votes |
@Override public double[] distributionForInstance(Instance x) throws Exception { int L = x.classIndex(); //if there is only one class (as for e.g. in some hier. mtds) predict it //if(L == 1) return new double[]{1.0}; Instance x_sl = PSUtils.convertInstance(x,L,m_InstancesTemplate); // the sl instance //x_sl.setDataset(m_InstancesTemplate); // where y in {comb_1,comb_2,...,comb_k} double w[] = m_Classifier.distributionForInstance(x_sl); // w[j] = p(y_j) for each j = 1,...,L int max_j = Utils.maxIndex(w); // j of max w[j] //int max_j = (int)m_Classifier.classifyInstance(x_sl); // where comb_i is selected String y_max = m_InstancesTemplate.classAttribute().value(max_j); // comb_i e.g. "0+3+0+0+1+2+0+0" double y[] = Arrays.copyOf(A.toDoubleArray(MLUtils.decodeValue(y_max)),L*2); // "0+3+0+0+1+2+0+0" -> [0.0,3.0,0.0,...,0.0] HashMap<Double,Double> votes[] = new HashMap[L]; for(int j = 0; j < L; j++) { votes[j] = new HashMap<Double,Double>(); } for(int i = 0; i < w.length; i++) { double y_i[] = A.toDoubleArray(MLUtils.decodeValue(m_InstancesTemplate.classAttribute().value(i))); for(int j = 0; j < y_i.length; j++) { votes[j].put(y_i[j] , votes[j].containsKey(y_i[j]) ? votes[j].get(y_i[j]) + w[i] : w[i]); } } // some confidence information for(int j = 0; j < L; j++) { y[j+L] = votes[j].size() > 0 ? Collections.max(votes[j].values()) : 0.0; } return y; }
Example 8
Source File: RandomTree.java From tsml with GNU General Public License v3.0 | 5 votes |
/**
 * Outputs a leaf as " : &lt;class value&gt; (&lt;sum&gt;/&lt;sum minus largest count&gt;)".
 * Without a class distribution the first class value and zero counts are used.
 *
 * @return the leaf as string
 * @throws Exception if generation fails
 */
protected String leafString() throws Exception {
  double total = 0;
  double largest = 0;
  int largestIdx = 0;

  if (m_ClassDistribution != null) {
    total = Utils.sum(m_ClassDistribution);
    largestIdx = Utils.maxIndex(m_ClassDistribution);
    largest = m_ClassDistribution[largestIdx];
  }

  String label = m_Info.classAttribute().value(largestIdx);
  String totalText = Utils.doubleToString(total, 2);
  String restText = Utils.doubleToString(total - largest, 2);
  return " : " + label + " (" + totalText + "/" + restText + ")";
}
Example 9
Source File: Logistic.java From tsml with GNU General Public License v3.0 | 5 votes |
/** * Evaluate objective function * @param x the current values of variables * @return the value of the objective function */ protected double objectiveFunction(double[] x){ double nll = 0; // -LogLikelihood int dim = m_NumPredictors+1; // Number of variables per class for(int i=0; i<cls.length; i++){ // ith instance double[] exp = new double[m_NumClasses-1]; int index; for(int offset=0; offset<m_NumClasses-1; offset++){ index = offset * dim; for(int j=0; j<dim; j++) exp[offset] += m_Data[i][j]*x[index + j]; } double max = exp[Utils.maxIndex(exp)]; double denom = Math.exp(-max); double num; if (cls[i] == m_NumClasses - 1) { // Class of this instance num = -max; } else { num = exp[cls[i]] - max; } for(int offset=0; offset<m_NumClasses-1; offset++){ denom += Math.exp(exp[offset] - max); } nll -= weights[i]*(num - Math.log(denom)); // Weighted NLL } // Ridge: note that intercepts NOT included for(int offset=0; offset<m_NumClasses-1; offset++){ for(int r=1; r<dim; r++) nll += m_Ridge*x[offset*dim+r]*x[offset*dim+r]; } return nll; }
Example 10
Source File: Iadem3.java From moa with GNU General Public License v3.0 | 5 votes |
/**
 * Feeds the 0/1 error of the current class votes into the change estimator
 * and resets the drift variables when the estimator reports a change.
 * Nothing is fed when there is no estimator or the tree does not restart at drift.
 *
 * @param inst the instance whose prediction is compared with its class value
 */
private void updateCountersForChange(Instance inst) {
  double[] classVotes = this.getClassVotes(inst);
  boolean correct = Utils.maxIndex(classVotes) == (int) inst.classValue();

  if (estimador == null || !((Iadem3) this.tree).restartAtDrift) {
    return;
  }

  this.estimador.input(correct ? 0.0 : 1.0);
  if (this.estimador.getChange()) {
    this.resetVariablesAtDrift();
  }
}
Example 11
Source File: PLSFilter.java From tsml with GNU General Public License v3.0 | 5 votes |
/**
 * Determines the dominant eigenvector for the given matrix and returns it,
 * i.e. the column of the eigenvector matrix that belongs to the largest
 * real eigenvalue.
 *
 * @param m the matrix to determine the dominant eigenvector for
 * @return the dominant eigenvector
 */
protected Matrix getDominantEigenVector(Matrix m) {
  EigenvalueDecomposition decomposition = m.eig();
  int dominant = Utils.maxIndex(decomposition.getRealEigenvalues());
  return columnAsVector(decomposition.getV(), dominant);
}
Example 12
Source File: Iadem3.java From moa with GNU General Public License v3.0 | 5 votes |
/**
 * Feeds the 0/1 error of the current class votes into the change estimator
 * and restarts the drift variables when the estimator reports a change.
 * Nothing is fed when there is no estimator or the tree does not restart at drift.
 *
 * @param experiencia the instance whose prediction is compared with its class value
 */
private void updateCounters(Instance experiencia) {
  double[] classVotes = this.getClassVotes(experiencia);
  boolean correct = Utils.maxIndex(classVotes) == (int) experiencia.classValue();

  if (estimator == null || !((Iadem3) this.tree).restartAtDrift) {
    return;
  }

  this.estimator.input(correct ? 0.0 : 1.0);
  if (this.estimator.getChange()) {
    this.restartVariablesAtDrift();
  }
}
Example 13
Source File: Iadem3.java From moa with GNU General Public License v3.0 | 4 votes |
/**
 * Returns the superclass' class votes and remembers the index of the
 * highest vote in lastPrediction.
 *
 * @param instance the instance to obtain votes for
 * @return the votes produced by the superclass
 */
@Override
public double[] getClassVotes(Instance instance) {
  final double[] result = super.getClassVotes(instance);
  this.lastPrediction = Utils.maxIndex(result);
  return result;
}
Example 14
Source File: XML.java From tsml with GNU General Public License v3.0 | 4 votes |
/** * Store the prediction made by the classifier as a string. * * @param dist the distribution to use * @param inst the instance to generate text from * @param index the index in the dataset * @throws Exception if something goes wrong */ protected void doPrintClassification(double[] dist, Instance inst, int index) throws Exception { int prec = m_NumDecimals; Instance withMissing = (Instance)inst.copy(); withMissing.setDataset(inst.dataset()); double predValue = 0; if (Utils.sum(dist) == 0) { predValue = Utils.missingValue(); } else { if (inst.classAttribute().isNominal()) { predValue = Utils.maxIndex(dist); } else { predValue = dist[0]; } } // opening tag append(" <" + TAG_PREDICTION + " " + ATT_INDEX + "=\"" + (index+1) + "\">\n"); if (inst.dataset().classAttribute().isNumeric()) { // actual append(" <" + TAG_ACTUAL_VALUE + ">"); if (inst.classIsMissing()) append("?"); else append(Utils.doubleToString(inst.classValue(), prec)); append("</" + TAG_ACTUAL_VALUE + ">\n"); // predicted append(" <" + TAG_PREDICTED_VALUE + ">"); if (inst.classIsMissing()) append("?"); else append(Utils.doubleToString(predValue, prec)); append("</" + TAG_PREDICTED_VALUE + ">\n"); // error append(" <" + TAG_ERROR + ">"); if (Utils.isMissingValue(predValue) || inst.classIsMissing()) append("?"); else append(Utils.doubleToString(predValue - inst.classValue(), prec)); append("</" + TAG_ERROR + ">\n"); } else { // actual append(" <" + TAG_ACTUAL_LABEL + " " + ATT_INDEX + "=\"" + ((int) inst.classValue()+1) + "\"" + ">"); append(sanitize(inst.toString(inst.classIndex()))); append("</" + TAG_ACTUAL_LABEL + ">\n"); // predicted append(" <" + TAG_PREDICTED_LABEL + " " + ATT_INDEX + "=\"" + ((int) predValue+1) + "\"" + ">"); if (Utils.isMissingValue(predValue)) append("?"); else append(sanitize(inst.dataset().classAttribute().value((int)predValue))); append("</" + TAG_PREDICTED_LABEL + ">\n"); // error? 
append(" <" + TAG_ERROR + ">"); if (!Utils.isMissingValue(predValue) && !inst.classIsMissing() && ((int) predValue+1 != (int) inst.classValue()+1)) append(VAL_YES); else append(VAL_NO); append("</" + TAG_ERROR + ">\n"); // prediction/distribution if (m_OutputDistribution) { append(" <" + TAG_DISTRIBUTION + ">\n"); for (int n = 0; n < dist.length; n++) { append(" <" + TAG_CLASS_LABEL + " " + ATT_INDEX + "=\"" + (n+1) + "\""); if (!Utils.isMissingValue(predValue) && (n == (int) predValue)) append(" " + ATT_PREDICTED + "=\"" + VAL_YES + "\""); append(">"); append(Utils.doubleToString(dist[n], prec)); append("</" + TAG_CLASS_LABEL + ">\n"); } append(" </" + TAG_DISTRIBUTION + ">\n"); } else { append(" <" + TAG_PREDICTION + ">"); if (Utils.isMissingValue(predValue)) append("?"); else append(Utils.doubleToString(dist[(int)predValue], prec)); append("</" + TAG_PREDICTION + ">\n"); } } // attributes if (m_Attributes != null) append(attributeValuesString(withMissing)); // closing tag append(" </" + TAG_PREDICTION + ">\n"); }
Example 15
Source File: MINND.java From tsml with GNU General Public License v3.0 | 4 votes |
/**
 * Pre-process the given exemplar according to the other exemplars
 * in the given exemplars. It also updates noise data statistics.
 * <p>
 * An exemplar whose class value is 0 is returned unchanged (its noise
 * statistics are cleared). Otherwise every instance of the exemplar's
 * relational attribute is voted on by the classes of the m_Select exemplars
 * with the smallest distance; instances whose voted class differs from the
 * exemplar's class are collected as noise, the rest form the returned exemplar.
 *
 * @param data the whole exemplars
 * @param pos the position of given exemplar in data
 * @return the processed exemplar
 * @throws Exception if the returned exemplar is wrong
 */
public Instance preprocess(Instances data, int pos) throws Exception {
  Instance original = data.instance(pos);
  if ((int) original.classValue() == 0) {
    m_NoiseM[pos] = null;
    m_NoiseV[pos] = null;
    return original;
  }

  Instances keptBag = original.attribute(1).relation().stringFreeStructure();
  Instances noiseBag = original.attribute(1).relation().stringFreeStructure();

  Instances header = m_Attributes;
  Instance kept = new DenseInstance(original.numAttributes());
  Instance noise = new DenseInstance(original.numAttributes());
  kept.setDataset(header);
  noise.setDataset(header);

  for (int g = 0; g < original.relationalValue(1).numInstances(); g++) {
    Instance datum = original.relationalValue(1).instance(g);

    // distance to every other exemplar; the exemplar itself is excluded
    double[] dists = new double[data.numInstances()];
    for (int i = 0; i < data.numInstances(); i++) {
      dists[i] = (i == pos)
        ? Double.POSITIVE_INFINITY
        : distance(datum, m_Mean[i], m_Variance[i], i);
    }

    // class votes of the m_Select closest exemplars (array starts out all zero)
    int[] votes = new int[m_NumClasses];
    for (int picked = 0; picked < m_Select; picked++) {
      int nearest = Utils.minIndex(dists);
      votes[(int) m_Class[nearest]]++;
      dists[nearest] = Double.POSITIVE_INFINITY;
    }

    int votedClass = Utils.maxIndex(votes);
    if ((int) original.classValue() == votedClass) {
      keptBag.add(datum);
    } else {
      noiseBag.add(datum);
    }
  }

  int relationValue;
  relationValue = noise.attribute(1).addRelation(noiseBag);
  noise.setValue(0, original.value(0));
  noise.setValue(1, relationValue);
  noise.setValue(2, original.classValue());

  relationValue = kept.attribute(1).addRelation(keptBag);
  kept.setValue(0, original.value(0));
  kept.setValue(1, relationValue);
  kept.setValue(2, original.classValue());

  if (!Utils.gr(noise.relationalValue(1).sumOfWeights(), 0)) {
    // no noise collected for this exemplar
    m_NoiseM[pos] = null;
    m_NoiseV[pos] = null;
  } else {
    for (int d = 0; d < m_Dimension; d++) {
      m_NoiseM[pos][d] = noise.relationalValue(1).meanOrMode(d);
      m_NoiseV[pos][d] = noise.relationalValue(1).variance(d);
      // a zero variance is replaced by m_ZERO
      if (Utils.eq(m_NoiseV[pos][d], 0.0)) {
        m_NoiseV[pos][d] = m_ZERO;
      }
    }
  }

  return kept;
}
Example 16
Source File: Grading.java From tsml with GNU General Public License v3.0 | 4 votes |
/** * Returns class probabilities for a given instance using the stacked classifier. * One class will always get all the probability mass (i.e. probability one). * * @param instance the instance to be classified * @throws Exception if instance could not be classified * successfully * @return the class distribution for the given instance */ public double[] distributionForInstance(Instance instance) throws Exception { double maxPreds; int numPreds=0; int numClassifiers=m_Classifiers.length; int idxPreds; double [] predConfs = new double[numClassifiers]; double [] preds; for (int i=0; i<numClassifiers; i++) { preds = m_MetaClassifiers[i].distributionForInstance(metaInstance(instance,i)); if (m_MetaClassifiers[i].classifyInstance(metaInstance(instance,i))==1) predConfs[i]=preds[1]; else predConfs[i]=-preds[0]; } if (predConfs[Utils.maxIndex(predConfs)]<0.0) { // no correct classifiers for (int i=0; i<numClassifiers; i++) // use neg. confidences instead predConfs[i]=1.0+predConfs[i]; } else { for (int i=0; i<numClassifiers; i++) // otherwise ignore neg. 
conf if (predConfs[i]<0) predConfs[i]=0.0; } /*System.out.print(preds[0]); System.out.print(":"); System.out.print(preds[1]); System.out.println("#");*/ preds=new double[instance.numClasses()]; for (int i=0; i<instance.numClasses(); i++) preds[i]=0.0; for (int i=0; i<numClassifiers; i++) { idxPreds=(int)(m_Classifiers[i].classifyInstance(instance)); preds[idxPreds]+=predConfs[i]; } maxPreds=preds[Utils.maxIndex(preds)]; int MaxInstPerClass=-100; int MaxClass=-1; for (int i=0; i<instance.numClasses(); i++) { if (preds[i]==maxPreds) { numPreds++; if (m_InstPerClass[i]>MaxInstPerClass) { MaxInstPerClass=(int)m_InstPerClass[i]; MaxClass=i; } } } int predictedIndex; if (numPreds==1) predictedIndex = Utils.maxIndex(preds); else { // System.out.print("?"); // System.out.print(instance.toString()); // for (int i=0; i<instance.numClasses(); i++) { // System.out.print("/"); // System.out.print(preds[i]); // } // System.out.println(MaxClass); predictedIndex = MaxClass; } double[] classProbs = new double[instance.numClasses()]; classProbs[predictedIndex] = 1.0; return classProbs; }
Example 17
Source File: ReplaceMissingValues.java From tsml with GNU General Public License v3.0 | 4 votes |
/** * Signify that this batch of input to the filter is finished. * If the filter requires all instances prior to filtering, * output() may now be called to retrieve the filtered instances. * * @return true if there are instances pending output * @throws IllegalStateException if no input structure has been defined */ public boolean batchFinished() { if (getInputFormat() == null) { throw new IllegalStateException("No input instance format defined"); } if (m_ModesAndMeans == null) { // Compute modes and means double sumOfWeights = getInputFormat().sumOfWeights(); double[][] counts = new double[getInputFormat().numAttributes()][]; for (int i = 0; i < getInputFormat().numAttributes(); i++) { if (getInputFormat().attribute(i).isNominal()) { counts[i] = new double[getInputFormat().attribute(i).numValues()]; if (counts[i].length > 0) counts[i][0] = sumOfWeights; } } double[] sums = new double[getInputFormat().numAttributes()]; for (int i = 0; i < sums.length; i++) { sums[i] = sumOfWeights; } double[] results = new double[getInputFormat().numAttributes()]; for (int j = 0; j < getInputFormat().numInstances(); j++) { Instance inst = getInputFormat().instance(j); for (int i = 0; i < inst.numValues(); i++) { if (!inst.isMissingSparse(i)) { double value = inst.valueSparse(i); if (inst.attributeSparse(i).isNominal()) { if (counts[inst.index(i)].length > 0) { counts[inst.index(i)][(int)value] += inst.weight(); counts[inst.index(i)][0] -= inst.weight(); } } else if (inst.attributeSparse(i).isNumeric()) { results[inst.index(i)] += inst.weight() * inst.valueSparse(i); } } else { if (inst.attributeSparse(i).isNominal()) { if (counts[inst.index(i)].length > 0) { counts[inst.index(i)][0] -= inst.weight(); } } else if (inst.attributeSparse(i).isNumeric()) { sums[inst.index(i)] -= inst.weight(); } } } } m_ModesAndMeans = new double[getInputFormat().numAttributes()]; for (int i = 0; i < getInputFormat().numAttributes(); i++) { if (getInputFormat().attribute(i).isNominal()) { if 
(counts[i].length == 0) m_ModesAndMeans[i] = Utils.missingValue(); else m_ModesAndMeans[i] = (double)Utils.maxIndex(counts[i]); } else if (getInputFormat().attribute(i).isNumeric()) { if (Utils.gr(sums[i], 0)) { m_ModesAndMeans[i] = results[i] / sums[i]; } } } // Convert pending input instances for(int i = 0; i < getInputFormat().numInstances(); i++) { convertInstance(getInputFormat().instance(i)); } } // Free memory flushInput(); m_NewBatch = true; return (numPendingOutput() != 0); }
Example 18
Source File: Logistic.java From tsml with GNU General Public License v3.0 | 4 votes |
/** * Evaluate Jacobian vector * @param x the current values of variables * @return the gradient vector */ protected double[] evaluateGradient(double[] x){ double[] grad = new double[x.length]; int dim = m_NumPredictors+1; // Number of variables per class for(int i=0; i<cls.length; i++){ // ith instance double[] num=new double[m_NumClasses-1]; // numerator of [-log(1+sum(exp))]' int index; for(int offset=0; offset<m_NumClasses-1; offset++){ // Which part of x double exp=0.0; index = offset * dim; for(int j=0; j<dim; j++) exp += m_Data[i][j]*x[index + j]; num[offset] = exp; } double max = num[Utils.maxIndex(num)]; double denom = Math.exp(-max); // Denominator of [-log(1+sum(exp))]' for(int offset=0; offset<m_NumClasses-1; offset++){ num[offset] = Math.exp(num[offset] - max); denom += num[offset]; } Utils.normalize(num, denom); // Update denominator of the gradient of -log(Posterior) double firstTerm; for(int offset=0; offset<m_NumClasses-1; offset++){ // Which part of x index = offset * dim; firstTerm = weights[i] * num[offset]; for(int q=0; q<dim; q++){ grad[index + q] += firstTerm * m_Data[i][q]; } } if(cls[i] != m_NumClasses-1){ // Not the last class for(int p=0; p<dim; p++){ grad[cls[i]*dim+p] -= weights[i]*m_Data[i][p]; } } } // Ridge: note that intercepts NOT included for(int offset=0; offset<m_NumClasses-1; offset++){ for(int r=1; r<dim; r++) grad[offset*dim+r] += 2*m_Ridge*x[offset*dim+r]; } return grad; }
Example 19
Source File: DecisionTreeModel.java From collective-classification-weka-package with GNU General Public License v3.0 | 4 votes |
/** * Outputs one node for graph. * * @param text the buffer to add the data to * @param num the node number * @param node the node * @return the new node number * @throws Exception if something goes wrong */ public int toGraph(StringBuffer text, int num, DecisionTreeNode node) throws Exception { double[] classprobs = node.getClassProbabilities(); int att = node.getAttribute(); double splitpoint = node.getSplitPoint(); Instances info = node.getInformation(); int maxIndex = Utils.maxIndex(classprobs); String classValue = info.classAttribute().value(maxIndex); num++; if (att == -1) { text.append("N" + Integer.toHexString(hashCode()) + " [label=\"" + num + ": " + classValue + "\"" + "shape=box]\n"); } else { text.append("N" + Integer.toHexString(hashCode()) + " [label=\"" + num + ": " + classValue + "\"]\n"); for (int i = 0; i < node.getChildCount(); i++) { text.append("N" + Integer.toHexString(hashCode()) + "->" + "N" + Integer.toHexString(node.getNodeAt(i).hashCode()) + " [label=\"" + info.attribute(att).name()); if (info.attribute(att).isNumeric()) { if (i == 0) { text.append(" < " + Utils.doubleToString(splitpoint, 2)); } else { text.append(" >= " + Utils.doubleToString(splitpoint, 2)); } } else { // split for every nominal value? if (node.getNominalSplit() == null) { text.append( " = " + info.attribute(att).value(i)); } else { text.append(" = ["); for (int n = 0; n < node.getNominalSplit()[i].length; n++) { if (n > 0) text.append(", "); text.append( info.attribute(att).value((int) node.getNominalSplit()[i][n])); } text.append("]"); } } text.append("\"]\n"); num = toGraph(text, num, node.getNodeAt(i)); } } return num; }
Example 20
Source File: REPTree.java From tsml with GNU General Public License v3.0 | 4 votes |
/** * Inserts an instance from the hold-out set into the tree. * * @param inst the instance to insert * @param weight the weight of the instance * @param parent the parent of the node * @throws Exception if insertion fails */ protected void insertHoldOutInstance(Instance inst, double weight, Tree parent) throws Exception { // Insert instance into hold-out class distribution if (inst.classAttribute().isNominal()) { // Nominal case m_HoldOutDist[(int)inst.classValue()] += weight; int predictedClass = 0; if (m_ClassProbs == null) { predictedClass = Utils.maxIndex(parent.m_ClassProbs); } else { predictedClass = Utils.maxIndex(m_ClassProbs); } if (predictedClass != (int)inst.classValue()) { m_HoldOutError += weight; } } else { // Numeric case m_HoldOutDist[0] += weight; m_HoldOutDist[1] += weight * inst.classValue(); double diff = 0; if (m_ClassProbs == null) { diff = parent.m_ClassProbs[0] - inst.classValue(); } else { diff = m_ClassProbs[0] - inst.classValue(); } m_HoldOutError += diff * diff * weight; } // The process is recursive if (m_Attribute != -1) { // If node is not a leaf if (inst.isMissing(m_Attribute)) { // Distribute instance for (int i = 0; i < m_Successors.length; i++) { if (m_Prop[i] > 0) { m_Successors[i].insertHoldOutInstance(inst, weight * m_Prop[i], this); } } } else { if (m_Info.attribute(m_Attribute).isNominal()) { // Treat nominal attributes m_Successors[(int)inst.value(m_Attribute)]. insertHoldOutInstance(inst, weight, this); } else { // Treat numeric attributes if (inst.value(m_Attribute) < m_SplitPoint) { m_Successors[0].insertHoldOutInstance(inst, weight, this); } else { m_Successors[1].insertHoldOutInstance(inst, weight, this); } } } } }