Java Code Examples for burlap.behavior.singleagent.Episode#numTimeSteps()
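The following examples show how to use
burlap.behavior.singleagent.Episode#numTimeSteps(). Note that numTimeSteps() counts the states recorded in an episode, so the number of actions taken is numTimeSteps() - 1, which is why many of the examples below subtract one.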
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example. You may check out the related API usage on the sidebar.
Example 1
Source File: LSPI.java From burlap with Apache License 2.0 | 6 votes |
/**
 * Runs one learning episode in the given environment by following the current learning policy,
 * adds the episode's transitions to the LSPI dataset, and reruns policy iteration when
 * {@link #shouldRereunPolicyIteration(Episode)} says enough new steps have accumulated.
 * The episode is also stored in the bounded episode history.
 *
 * @param env the environment in which to act
 * @param maxSteps the step bound passed to the rollout; pass -1 to use the rollout overload that takes no step bound
 * @return the recorded learning episode
 */
@Override
public Episode runLearningEpisode(Environment env, int maxSteps) {
    final Episode episode;
    if (maxSteps == -1) {
        episode = PolicyUtils.rollout(this.learningPolicy, env);
    } else {
        episode = PolicyUtils.rollout(this.learningPolicy, env, maxSteps);
    }

    this.updateDatasetWithLearningEpisode(episode);

    if (!this.shouldRereunPolicyIteration(episode)) {
        // Accumulate actions, not states: an episode with n recorded states contains n-1 actions.
        this.numStepsSinceLastLearningPI += episode.numTimeSteps() - 1;
    } else {
        this.runPolicyIteration(this.maxNumPlanningIterations, this.maxChange);
        this.numStepsSinceLastLearningPI = 0;
    }

    // Keep the history bounded: evict the oldest stored episode before adding the new one.
    if (episodeHistory.size() >= numEpisodesToStore) {
        episodeHistory.poll();
    }
    episodeHistory.offer(episode);

    return episode;
}
Example 2
Source File: LearningAlgorithmExperimenter.java From burlap with Apache License 2.0 | 6 votes |
/** * Runs a trial for an agent generated by the given factor when interpreting trial length as a number of total steps. * @param agentFactory the agent factory used to generate the agent to test. */ protected void runStepBoundTrial(LearningAgentFactory agentFactory){ //temporarily disable plotter data collection to avoid possible contamination for any actions taken by the agent generation //(e.g., if there is pre-test training) this.plotter.toggleDataCollection(false); LearningAgent agent = agentFactory.generateAgent(); this.plotter.toggleDataCollection(true); //turn it back on to begin this.plotter.startNewTrial(); int stepsRemaining = this.trialLength; while(stepsRemaining > 0){ Episode ea = agent.runLearningEpisode(this.environmentSever, stepsRemaining); stepsRemaining -= ea.numTimeSteps()-1; //-1 because we want to subtract the number of actions, not the number of states seen this.plotter.endEpisode(); this.environmentSever.resetEnvironment(); } this.plotter.endTrial(); }
Example 3
Source File: PolicyUtils.java From burlap with Apache License 2.0 | 5 votes |
/**
 * Follows the policy in the given {@link burlap.mdp.singleagent.environment.Environment}. The policy will stop being followed once a terminal state
 * in the environment is reached or when the provided number of steps (actions) has been taken.
 * The policy is always followed at least once (the loop is a do-while), even if {@code numSteps} is less than one.
 * @param p the {@link Policy}
 * @param env The {@link burlap.mdp.singleagent.environment.Environment} in which this policy is to be evaluated.
 * @param numSteps the maximum number of steps (actions) to take in the environment.
 * @return An {@link Episode} object specifying the interaction with the environment.
 */
public static Episode rollout(Policy p, Environment env, int numSteps){
    Episode ea = new Episode(env.currentObservation());
    int actionsTaken;
    do{
        followAndRecordPolicy(p, env, ea);
        // numTimeSteps() counts recorded states, which is one more than the number of actions taken.
        // Comparing the raw state count against numSteps would stop one action early.
        actionsTaken = ea.numTimeSteps() - 1;
    }while(!env.isInTerminalState() && actionsTaken < numSteps);
    return ea;
}
Example 4
Source File: MLIRL.java From burlap with Apache License 2.0 | 5 votes |
/**
 * Computes the log-likelihood of a trajectory under the current reward function parameters,
 * multiplied by the given weight. Each action's probability comes from a Boltzmann policy built on
 * the request's planner. The planner replans from each state before that state's action is scored.
 * @param ea the trajectory
 * @param weight the factor by which the trajectory's log-likelihood is scaled
 * @return the log-likelihood of the trajectory multiplied by {@code weight}
 */
public double logLikelihoodOfTrajectory(Episode ea, double weight){
    // The second constructor argument is the reciprocal of the request's Boltzmann beta
    // (presumably used as the policy's temperature).
    Policy boltzmannPolicy = new BoltzmannQPolicy((QProvider)this.request.getPlanner(), 1./this.request.getBoltzmannBeta());

    final int numActions = ea.numTimeSteps() - 1;
    double sumOfLogProbs = 0.;
    int t = 0;
    while(t < numActions){
        // Plan from this state before querying the policy for the action taken in it.
        this.request.getPlanner().planFromState(ea.state(t));
        sumOfLogProbs += Math.log(boltzmannPolicy.actionProb(ea.state(t), ea.action(t)));
        t++;
    }
    return sumOfLogProbs * weight;
}
Example 5
Source File: MLIRL.java From burlap with Apache License 2.0 | 5 votes |
/** * Computes and returns the gradient of the log-likelihood of all trajectories * @return the gradient of the log-likelihood of all trajectories */ public FunctionGradient logLikelihoodGradient(){ HashedAggregator<Integer> gradientSum = new HashedAggregator<Integer>(); double [] weights = this.request.getEpisodeWeights(); List<Episode> exampleTrajectories = this.request.getExpertEpisodes(); for(int i = 0; i < exampleTrajectories.size(); i++){ Episode ea = exampleTrajectories.get(i); double weight = weights[i]; for(int t = 0; t < ea.numTimeSteps()-1; t++){ this.request.getPlanner().planFromState(ea.state(t)); FunctionGradient policyGrad = this.logPolicyGrad(ea.state(t), ea.action(t)); //weigh it by trajectory strength for(FunctionGradient.PartialDerivative pd : policyGrad.getNonZeroPartialDerivatives()){ double newVal = pd.value * weight; gradientSum.add(pd.parameterId, newVal); } } } FunctionGradient gradient = new FunctionGradient.SparseGradient(gradientSum.size()); for(Map.Entry<Integer, Double> e : gradientSum.entrySet()){ gradient.put(e.getKey(), e.getValue()); } return gradient; }
Example 6
Source File: LSPI.java From burlap with Apache License 2.0 | 5 votes |
/**
 * Adds every transition of a learning episode to this object's {@link SARSData}, creating the
 * dataset first if it does not exist yet.
 * @param ea the learning episode as an {@link Episode} object.
 */
protected void updateDatasetWithLearningEpisode(Episode ea){
    final int numTransitions = ea.numTimeSteps() - 1;
    if(this.dataset == null){
        // Create the dataset lazily; the int argument is presumably an initial capacity sized to this episode.
        this.dataset = new SARSData(numTransitions);
    }
    for(int t = 0; t < numTransitions; t++){
        // Pair the transition state(t) -> state(t+1) with reward(t+1).
        this.dataset.add(ea.state(t), ea.action(t), ea.reward(t+1), ea.state(t+1));
    }
}
Example 7
Source File: LSPI.java From burlap with Apache License 2.0 | 5 votes |
/**
 * Returns whether LSPI should be rerun given the latest learning episode results. Default behavior is to return true
 * if the number of learning episode steps plus the number of steps since the last run is greater than the
 * {@link #minNewStepsForLearningPI} threshold.
 * @param ea the most recent learning episode
 * @return true if LSPI should be rerun; false otherwise.
 */
protected boolean shouldRereunPolicyIteration(Episode ea){
    // An episode with n recorded states contributes n-1 steps (actions).
    int newSteps = ea.numTimeSteps() - 1;
    return this.numStepsSinceLastLearningPI + newSteps > this.minNewStepsForLearningPI;
}
Example 8
Source File: TrainingHelper.java From burlap_caffe with Apache License 2.0 | 4 votes |
public void run() { int testCountDown = testInterval; int snapshotCountDown = snapshotInterval; long trainingStart = System.currentTimeMillis(); int trainingSteps = 0; while (stepCounter < totalTrainingSteps) { long epStartTime = 0; if (verbose) { System.out.println(String.format("Training Episode %d at step %d", episodeCounter, stepCounter)); epStartTime = System.currentTimeMillis(); } // Set variables needed for training prepareForTraining(); env.resetEnvironment(); // run learning episode Episode ea = learner.runLearningEpisode(env, Math.min(totalTrainingSteps - stepCounter, maxEpisodeSteps)); // add up episode reward double totalReward = 0; for (double r : ea.rewardSequence) { totalReward += r; } if (verbose) { // output episode data long epEndTime = System.currentTimeMillis(); double timeInterval = (epEndTime - epStartTime)/1000.0; System.out.println(String.format("Episode reward: %.2f -- %.1f steps/sec", totalReward, ea.numTimeSteps()/timeInterval)); System.out.println(); } // take snapshot every snapshotCountDown steps stepCounter += ea.numTimeSteps(); trainingSteps += ea.numTimeSteps(); episodeCounter++; if (snapshotPrefix != null) { snapshotCountDown -= ea.numTimeSteps(); if (snapshotCountDown <= 0) { saveLearningState(snapshotPrefix); snapshotCountDown += snapshotInterval; } } // take test set every testCountDown steps testCountDown -= ea.numTimeSteps(); if (testCountDown <= 0) { double trainingTimeInterval = (System.currentTimeMillis() - trainingStart)/1000.0; // run test set runTestSet(); testCountDown += testInterval; // output training rate System.out.printf("Training rate: %.1f steps/sec\n\n", testInterval/trainingTimeInterval); // restart training timer trainingStart = System.currentTimeMillis(); } } if (testOutput != null) { testOutput.printf("Final best: %.2f\n", highestAverageReward); testOutput.flush(); } System.out.println("Done Training!"); }
Example 9
Source File: TrainingHelper.java From burlap_caffe with Apache License 2.0 | 4 votes |
public void runTestSet() { long testStart = System.currentTimeMillis(); int numSteps = 0; int numEpisodes = 0; // Change any learning variables to test values (i.e. experience memory) prepareForTesting(); // Run the test policy on test episodes System.out.println("Running Test Set..."); double totalTestReward = 0; while (true) { env.resetEnvironment(); Episode e = tester.runTestEpisode(env, Math.min(maxEpisodeSteps, totalTestSteps - numSteps)); double totalReward = 0; for (double reward : e.rewardSequence) { totalReward += reward; } if (verbose) { System.out.println(String.format("%d: Reward = %.2f, Steps = %d", numEpisodes, totalReward, numSteps)); } numSteps += e.numTimeSteps(); if (numSteps >= totalTestSteps) { if (numEpisodes == 0) { totalTestReward = totalReward; numEpisodes = 1; } break; } totalTestReward += totalReward; numEpisodes += 1; } double averageReward = totalTestReward/numEpisodes; if (averageReward > highestAverageReward) { if (resultsPrefix != null) { vfa.snapshot(new File(resultsPrefix, "best_net.caffemodel").toString(), null); } highestAverageReward = averageReward; } double testTimeInterval = (System.currentTimeMillis() - testStart)/1000.0; System.out.printf("Average Test Reward: %.2f -- highest: %.2f, Test rate: %.1f\n\n", averageReward, highestAverageReward, numSteps/testTimeInterval); if (testOutput != null) { testOutput.printf("Frame %d: %.2f\n", stepCounter, averageReward); testOutput.flush(); } }
Example 10
Source File: BeliefSparseSampling.java From burlap with Apache License 2.0 | 3 votes |
/**
 * Demo: plans in the Tiger POMDP with belief sparse sampling and acts greedily with respect to the
 * planner's Q-values, starting from the initial belief with the tiger behind the left door.
 * It prints each action taken together with the reward that followed it.
 */
public static void main(String [] args){
    TigerDomain tigerDomain = new TigerDomain(true);
    PODomain domain = (PODomain)tigerDomain.generateDomain();
    BeliefState startBelief = TigerDomain.getInitialBeliefState(domain);

    // discount 0.99, height 10, and -1 as the final constructor argument
    BeliefSparseSampling planner = new BeliefSparseSampling(domain, 0.99, new ReflectiveHashableStateFactory(), 10, -1);
    Policy greedy = new GreedyQPolicy(planner);

    SimulatedPOEnvironment env = new SimulatedPOEnvironment(domain);
    env.setCurStateTo(new TigerState(TigerDomain.VAL_LEFT));

    BeliefPolicyAgent agent = new BeliefPolicyAgent(domain, env, greedy);
    agent.setBeliefState(startBelief);
    agent.setEnvironment(env);

    Episode episode = agent.actUntilTerminalOrMaxSteps(30);

    // action t-1 produced reward t, so iterate over rewards 1..numTimeSteps()-1
    for(int t = 1; t < episode.numTimeSteps(); t++){
        System.out.println(episode.action(t - 1) + " " + episode.reward(t));
    }
}