Tutorial: Basic Planning and Learning

Experimenter Tools and Performance Plotting

In the previous section, you learned how to use an ActionObserver to visualize the agent in the state space live as it was learning. In this section, we will use another ActionObserver, called PerformancePlotter, to record a learning algorithm's performance and compare it to that of another learning algorithm. The PerformancePlotter provides a number of tools for displaying experimental results. To streamline the process, we will use the LearningAlgorithmExperimenter class, which is convenient for comparing the performance of multiple learning algorithms over many trials. If you'd prefer to run your own experiment using a different design flow, however, you could use the PerformancePlotter directly.

To demonstrate the LearningAlgorithmExperimenter, we will compare the learning performance of a Q-learning algorithm with that of a SARSA(λ) algorithm. To make the results a bit more interesting to visualize, we will also replace the reward function we have been using with one that returns a reward of 5 when the goal is reached and -0.1 for every other step. First, add the following imports:

import burlap.oomdp.auxiliary.StateGenerator;
import burlap.oomdp.auxiliary.common.ConstantStateGenerator;
import burlap.behavior.singleagent.auxiliary.performance.LearningAlgorithmExperimenter;
import burlap.behavior.singleagent.auxiliary.performance.PerformanceMetric;
import burlap.behavior.singleagent.auxiliary.performance.TrialMode;
				

Let's begin by adding a new method to our BasicBehavior class that we can call to test the experimenter tools. Inside the method, we will set up the new reward function using an instance of GoalBasedRF, which takes a StateConditionTest object that specifies the goal condition (which we already set up earlier in the tutorial for our search algorithms like BFS), a goal reward, and a default reward for all non-goal states.

public void experimenterAndPlotter(){
		
	//custom reward function for more interesting results
	final RewardFunction rf = new GoalBasedRF(this.goalCondition, 5., -0.1);


}
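To make the effect of this reward function concrete, here is a minimal plain-Java sketch of the same idea. It does not use BURLAP: the class GoalRewardSketch and its methods are hypothetical stand-ins, and it assumes the goal test is applied to the state the agent arrives in. It also shows how the per-step penalty makes shorter paths to the goal worth more.

```java
import java.util.function.Predicate;

public class GoalRewardSketch {

	// Hypothetical stand-in for a goal-based reward: goalReward when the
	// state reached satisfies the goal test, defaultReward otherwise.
	static <S> double reward(Predicate<S> goalTest, S nextState,
			double goalReward, double defaultReward) {
		return goalTest.test(nextState) ? goalReward : defaultReward;
	}

	// Undiscounted return of a path that needs `steps` actions (steps >= 1)
	// and reaches the goal on its last action.
	static double returnOfPath(int steps, double goalReward, double stepReward) {
		return (steps - 1) * stepReward + goalReward;
	}

	public static void main(String[] args) {
		// States are encoded as {x, y}; the goal cell (10, 10) is an assumption
		// chosen to mirror the grid world used in this tutorial.
		Predicate<int[]> atGoal = s -> s[0] == 10 && s[1] == 10;

		System.out.println("non-goal step: " + reward(atGoal, new int[]{3, 4}, 5.0, -0.1));
		System.out.println("goal step:     " + reward(atGoal, new int[]{10, 10}, 5.0, -0.1));

		// A shorter path collects fewer -0.1 penalties before the +5.
		System.out.println("20-step path return: " + returnOfPath(20, 5.0, -0.1));
		System.out.println("40-step path return: " + returnOfPath(40, 5.0, -0.1));
	}
}
```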
				

For the LearningAlgorithmExperimenter to report the average performance of each algorithm, it tests each algorithm over multiple trials. Therefore, at the start of each trial, a clean agent instance without knowledge of any previous trial must be generated. To easily get a clean version of each agent, the LearningAlgorithmExperimenter requests a sequence of LearningAgentFactory objects that can generate a clean version of each agent on demand. In the following code, we create a LearningAgentFactory for a Q-learning algorithm and for a SARSA(λ) algorithm.

/**
 * Create factories for Q-learning agent and SARSA agent to compare
 */

LearningAgentFactory qLearningFactory = new LearningAgentFactory() {
	
	@Override
	public String getAgentName() {
		return "Q-learning";
	}
	
	@Override
	public LearningAgent generateAgent() {
		return new QLearning(domain, rf, tf, 0.99, hashingFactory, 0.3, 0.1);
	}
};


LearningAgentFactory sarsaLearningFactory = new LearningAgentFactory() {
	
	@Override
	public String getAgentName() {
		return "SARSA";
	}
	
	@Override
	public LearningAgent generateAgent() {
		return new SarsaLam(domain, rf, tf, 0.99, hashingFactory, 0.0, 0.1, 1.);
	}
};
				

Note that the factory also requires a getAgentName() method to be implemented. The LearningAlgorithmExperimenter class will use this name to label the results of each learning algorithm's performance.
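The reason for using a factory rather than passing in a single agent can be illustrated without BURLAP. The following is a minimal sketch in which CountingAgent is a hypothetical toy learner that only remembers how many episodes it has seen, and java.util.function.Supplier plays the role of the factory. With a real factory each trial starts from zero; if the same instance is reused, experience leaks from one trial into the next.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.Supplier;

public class FreshAgentPerTrialSketch {

	// Hypothetical toy learner: its only "knowledge" is an episode counter.
	static final class CountingAgent {
		int episodesSeen = 0;

		void runEpisode() {
			episodesSeen++;
		}
	}

	// Asks the factory for an agent at the start of every trial and returns
	// how many episodes each trial's agent had seen when its trial ended.
	static List<Integer> runTrials(Supplier<CountingAgent> factory,
			int trials, int episodesPerTrial) {
		List<Integer> finalCounts = new ArrayList<>();
		for (int t = 0; t < trials; t++) {
			CountingAgent agent = factory.get();
			for (int e = 0; e < episodesPerTrial; e++) {
				agent.runEpisode();
			}
			finalCounts.add(agent.episodesSeen);
		}
		return finalCounts;
	}

	public static void main(String[] args) {
		// A factory that builds a new agent on every call: trials are independent.
		System.out.println(runTrials(CountingAgent::new, 3, 100)); // [100, 100, 100]

		// A "factory" that hands back the same instance: later trials inherit
		// everything learned in earlier ones, which would bias the averages.
		CountingAgent shared = new CountingAgent();
		System.out.println(runTrials(() -> shared, 3, 100)); // [100, 200, 300]
	}
}
```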

For experimentation, the LearningAlgorithmExperimenter will also need to be able to generate a new state for the beginning of each episode. In some tasks we might consider generating a random start state for each episode, but to keep things simple we will consider a task that always starts in the bottom left-hand corner of the world, as we did in the previous planning and learning examples of this tutorial. We can make a state generator that always returns the same initial state using the ConstantStateGenerator provided by BURLAP.

StateGenerator sg = new ConstantStateGenerator(this.initialState);
				

With the task and learning algorithms now decided, we are just about ready to create our experimenter, but first we will need to decide how many trials we want to test, the length of the trials, and what performance data we want to plot. There are six possible performance metrics we could plot:

  1. cumulative reward per step,
  2. cumulative reward per episode,
  3. average reward per episode,
  4. median reward per episode,
  5. cumulative steps per episode, and
  6. steps per episode.
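As a rough guide to what these six metrics measure, the sketch below computes each of them from raw per-step rewards in plain Java. The definitions are one plausible reading of the metric names, not the library's implementation, and the class EpisodeMetricsSketch and its input layout (rewards[e][t] is the reward at step t of episode e) are assumptions made for illustration.

```java
import java.util.Arrays;

public class EpisodeMetricsSketch {

	// Running reward total after every individual step, across episode boundaries.
	static double[] cumulativeRewardPerStep(double[][] rewards) {
		int totalSteps = 0;
		for (double[] episode : rewards) totalSteps += episode.length;
		double[] out = new double[totalSteps];
		double sum = 0.0;
		int i = 0;
		for (double[] episode : rewards) {
			for (double r : episode) {
				sum += r;
				out[i++] = sum;
			}
		}
		return out;
	}

	// Running reward total at the end of each episode.
	static double[] cumulativeRewardPerEpisode(double[][] rewards) {
		double[] out = new double[rewards.length];
		double sum = 0.0;
		for (int e = 0; e < rewards.length; e++) {
			sum += Arrays.stream(rewards[e]).sum();
			out[e] = sum;
		}
		return out;
	}

	// Mean of the step rewards within each episode (0 for an empty episode).
	static double[] averageRewardPerEpisode(double[][] rewards) {
		double[] out = new double[rewards.length];
		for (int e = 0; e < rewards.length; e++) {
			out[e] = Arrays.stream(rewards[e]).average().orElse(0.0);
		}
		return out;
	}

	// Median of the step rewards within each episode (0 for an empty episode).
	static double[] medianRewardPerEpisode(double[][] rewards) {
		double[] out = new double[rewards.length];
		for (int e = 0; e < rewards.length; e++) {
			double[] sorted = rewards[e].clone();
			Arrays.sort(sorted);
			int n = sorted.length;
			if (n == 0) out[e] = 0.0;
			else out[e] = (n % 2 == 1) ? sorted[n / 2] : (sorted[n / 2 - 1] + sorted[n / 2]) / 2.0;
		}
		return out;
	}

	// Number of steps taken in each episode.
	static int[] stepsPerEpisode(double[][] rewards) {
		int[] out = new int[rewards.length];
		for (int e = 0; e < rewards.length; e++) out[e] = rewards[e].length;
		return out;
	}

	// Running total of steps at the end of each episode.
	static int[] cumulativeStepsPerEpisode(double[][] rewards) {
		int[] out = stepsPerEpisode(rewards);
		for (int e = 1; e < out.length; e++) out[e] += out[e - 1];
		return out;
	}

	public static void main(String[] args) {
		// Two made-up episodes using this tutorial's rewards: -0.1 per step, 5 at the goal.
		double[][] rewards = { {-0.1, -0.1, -0.1, 5.0}, {-0.1, 5.0} };
		System.out.println(Arrays.toString(cumulativeRewardPerStep(rewards)));
		System.out.println(Arrays.toString(cumulativeRewardPerEpisode(rewards)));
		System.out.println(Arrays.toString(averageRewardPerEpisode(rewards)));
		System.out.println(Arrays.toString(medianRewardPerEpisode(rewards)));
		System.out.println(Arrays.toString(cumulativeStepsPerEpisode(rewards)));
		System.out.println(Arrays.toString(stepsPerEpisode(rewards)));
	}
}
```

A learner that is improving should show the steps-per-episode curve falling and the cumulative-steps curve flattening, which is why the cumulative steps metric is a convenient one to plot.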

Furthermore, we could have our plotter display the results from the most recent trial only, the average performance across all trials, or both. For this tutorial, we will plot both the most recent trial and the average trial performance for the cumulative steps per episode and the average reward per episode. We will test the algorithms for 10 trials that last 100 episodes each. The code below will create our experimenter, start it, and save the data for all six metrics to CSV files.

LearningAlgorithmExperimenter exp = new LearningAlgorithmExperimenter((SADomain)this.domain, 
	rf, sg, 10, 100, qLearningFactory, sarsaLearningFactory);

exp.setUpPlottingConfiguration(500, 250, 2, 1000, 
	TrialMode.MOSTRECENTANDAVERAGE, 
	PerformanceMetric.CUMULATIVESTEPSPEREPISODE, 
	PerformanceMetric.AVERAGEEPISODEREWARD);

exp.startExperiment();

exp.writeStepAndEpisodeDataToCSV("expData");
				

Note that in the constructor, the LearningAgentFactory objects are the last parameters and that any number of them may be provided; we could have tested just one agent, or many more, but at least one LearningAgentFactory should always be provided.

The other important part of the code is the setUpPlottingConfiguration method, which is used to define what results are plotted and how they are displayed. The first four parameters specify a plot's width and height, the number of columns of plots, and the maximum window height. In this case, plots are set to be 500x250, with two columns of plots. Plots are placed columns first, wrapping down to a new row as needed. The window width will be scaled to the width of the plots times the number of columns; the window height will scale to the height of the plots times the number of rows, unless that height is greater than the maximum window height, in which case the plots will be placed in a scroll view. The next parameter specifies whether to show plots for only the most recent trial, the average performance over all trials, or both. We have selected to show both. The remaining parameters are variable in number and specify which performance metrics will be plotted. The order of the performance metrics provided also dictates the order in which the plots will fill the window (again, filling columns first).
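The window sizing rule described above can be written out as a few lines of arithmetic. This is only a sketch of the rule as stated in the text, not the library's layout code; the class PlotWindowLayoutSketch is hypothetical, window decorations are ignored, and the count of four plots assumes that each of our two metrics gets one plot for the most recent trial and one for the trial average.

```java
public class PlotWindowLayoutSketch {

	// Rows needed to place numPlots plots in the given number of columns.
	static int rows(int numPlots, int columns) {
		return (numPlots + columns - 1) / columns; // integer ceiling
	}

	// Window width: plot width times the number of columns.
	static int windowWidth(int plotWidth, int columns) {
		return plotWidth * columns;
	}

	// Window height: plot height times the number of rows, capped at the maximum.
	static int windowHeight(int plotHeight, int rows, int maxWindowHeight) {
		return Math.min(plotHeight * rows, maxWindowHeight);
	}

	// A scroll view is needed only when the plots are taller than the cap.
	static boolean needsScrollView(int plotHeight, int rows, int maxWindowHeight) {
		return plotHeight * rows > maxWindowHeight;
	}

	public static void main(String[] args) {
		// Values from this tutorial: 500x250 plots, 2 columns, max height 1000.
		int plotWidth = 500, plotHeight = 250, columns = 2, maxHeight = 1000;

		int r = rows(4, columns); // assumed four plots
		System.out.println(windowWidth(plotWidth, columns) + "x"
				+ windowHeight(plotHeight, r, maxHeight)
				+ ", scroll=" + needsScrollView(plotHeight, r, maxHeight)); // 1000x500, scroll=false

		// With ten plots the content (5 rows x 250) exceeds the cap of 1000.
		int r10 = rows(10, columns);
		System.out.println(windowWidth(plotWidth, columns) + "x"
				+ windowHeight(plotHeight, r10, maxHeight)
				+ ", scroll=" + needsScrollView(plotHeight, r10, maxHeight)); // 1000x1000, scroll=true
	}
}
```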

The startExperiment method begins the experiment, which runs all trials for all of the learning algorithms provided.

Once you point our BasicBehavior main method to the new method we've created and run the code, a GUI should appear with the requested plots, displaying the performance data as it becomes available. When the experiment is complete, you should be left with an image like the one below.

In the trial average plots, you'll note that a translucent filled area around each of the curves is present. This filled area shows the 95% confidence interval. You can change the significance level used before running the experiment using the setPlotCISignificance method.
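For readers who want to reproduce such a band from their own numbers, the following plain-Java sketch computes a confidence interval for the mean of one metric value across trials. How the plotter computes its interval internally is not covered here, so this uses the textbook Student's t interval as an assumption; the class name, the sample values, and the hard-coded critical value 2.262 (two-sided 95% with 9 degrees of freedom, matching 10 trials) are all illustrative.

```java
public class ConfidenceIntervalSketch {

	static double mean(double[] x) {
		double sum = 0.0;
		for (double v : x) sum += v;
		return sum / x.length;
	}

	// Sample standard deviation (divides by n - 1); requires at least two values.
	static double sampleStdDev(double[] x) {
		double m = mean(x);
		double ss = 0.0;
		for (double v : x) ss += (v - m) * (v - m);
		return Math.sqrt(ss / (x.length - 1));
	}

	// Returns {lower, upper} = mean -/+ tCritical * s / sqrt(n).
	static double[] confidenceInterval(double[] x, double tCritical) {
		double m = mean(x);
		double halfWidth = tCritical * sampleStdDev(x) / Math.sqrt(x.length);
		return new double[]{m - halfWidth, m + halfWidth};
	}

	public static void main(String[] args) {
		// Made-up values of one metric at one episode, one value per trial (10 trials).
		double[] stepsAtEpisode50 = {31, 28, 35, 40, 27, 33, 30, 29, 36, 32};

		// Two-sided 95% critical value of Student's t with 10 - 1 = 9 degrees of freedom.
		double t95df9 = 2.262;

		double[] ci = confidenceInterval(stepsAtEpisode50, t95df9);
		System.out.println("mean = " + mean(stepsAtEpisode50));
		System.out.println("95% CI = [" + ci[0] + ", " + ci[1] + "]");
	}
}
```

Running more trials shrinks the interval roughly in proportion to the square root of the number of trials, which is one reason to prefer many short trials over a single long one when comparing algorithms.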

Note that these plots are not static; you can interact with them. If you click and drag over a region, the plot will zoom in to the selected area. If you right-click, you'll find a number of other options that you can set, including changing the labels. Another useful feature in the contextual menu is an option to save the plot image to disk.

Since we also told the LearningAlgorithmExperimenter object to save the data to CSV files, you should find two files that it created: expDataSteps.csv and expDataEpisodes.csv. The first contains all trial data for the cumulative reward per step metric. The second contains all of the episode-wise metric data (even for the metrics that we did not plot). These files make it convenient to work with the data in another program, such as R.
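If you would rather stay in Java, a small reader like the sketch below is enough to pull one column out of either file. The column layout of the files is not described in this tutorial, so the code assumes only a header row and comma-separated numeric cells without quoted commas; the class CsvColumnSketch and the column name passed on the command line are hypothetical. Running it with just a file name prints the header so you can see which columns are actually present.

```java
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.List;

public class CsvColumnSketch {

	// Returns the numeric values of the named column.
	// Assumes the first line is a header and that no cell contains a quoted comma.
	static List<Double> column(List<String> lines, String name) {
		if (lines.isEmpty()) {
			throw new IllegalArgumentException("no lines to read");
		}
		List<String> header = new ArrayList<>();
		for (String h : lines.get(0).split(",")) {
			header.add(h.trim());
		}
		int idx = header.indexOf(name);
		if (idx < 0) {
			throw new IllegalArgumentException("no column named " + name + " in " + header);
		}
		List<Double> values = new ArrayList<>();
		for (String line : lines.subList(1, lines.size())) {
			if (line.trim().isEmpty()) continue; // skip blank lines
			values.add(Double.parseDouble(line.split(",")[idx].trim()));
		}
		return values;
	}

	public static void main(String[] args) throws IOException {
		// args[0]: CSV path (defaults to the episode file named in this tutorial)
		// args[1]: optional column name to summarize
		Path file = Paths.get(args.length > 0 ? args[0] : "expDataEpisodes.csv");
		if (!Files.exists(file)) {
			System.out.println("No file at " + file + "; run the experiment first.");
			return;
		}
		List<String> lines = Files.readAllLines(file);
		if (lines.isEmpty()) {
			System.out.println(file + " is empty.");
			return;
		}
		System.out.println("Columns: " + lines.get(0));
		if (args.length > 1) {
			List<Double> values = column(lines, args[1]);
			double sum = 0.0;
			for (double v : values) sum += v;
			System.out.println(args[1] + ": " + values.size() + " rows, mean "
					+ (values.isEmpty() ? Double.NaN : sum / values.size()));
		}
	}
}
```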

Conclusion

This ends our tutorial on implementing basic planning and learning algorithms in BURLAP. There are other planning and learning algorithms in BURLAP, but hopefully this tutorial has explained the core concepts well enough that you should be able to try different algorithms easily. As a future exercise, we encourage you to try just that within this code you've created! The complete set of code that we wrote in this tutorial is shown below for your convenience.

You can also download the .java file directly from here.


import java.awt.Color;
import java.util.List;

import burlap.behavior.singleagent.*;
import burlap.domain.singleagent.gridworld.*;
import burlap.oomdp.core.*;
import burlap.oomdp.singleagent.*;
import burlap.oomdp.singleagent.common.*;
import burlap.behavior.statehashing.DiscreteStateHashFactory;
import burlap.behavior.singleagent.learning.*;
import burlap.behavior.singleagent.learning.tdmethods.*;
import burlap.behavior.singleagent.planning.*;
import burlap.behavior.singleagent.planning.commonpolicies.GreedyQPolicy;
import burlap.behavior.singleagent.planning.deterministic.*;
import burlap.behavior.singleagent.planning.deterministic.informed.Heuristic;
import burlap.behavior.singleagent.planning.deterministic.informed.astar.AStar;
import burlap.behavior.singleagent.planning.deterministic.uninformed.bfs.BFS;
import burlap.behavior.singleagent.planning.deterministic.uninformed.dfs.DFS;
import burlap.behavior.singleagent.planning.stochastic.valueiteration.ValueIteration;
import burlap.oomdp.visualizer.Visualizer;
import burlap.oomdp.auxiliary.StateGenerator;
import burlap.oomdp.auxiliary.StateParser;
import burlap.oomdp.auxiliary.common.ConstantStateGenerator;
import burlap.behavior.singleagent.EpisodeSequenceVisualizer;
import burlap.behavior.singleagent.auxiliary.StateReachability;
import burlap.behavior.singleagent.auxiliary.performance.LearningAlgorithmExperimenter;
import burlap.behavior.singleagent.auxiliary.performance.PerformanceMetric;
import burlap.behavior.singleagent.auxiliary.performance.TrialMode;
import burlap.behavior.singleagent.auxiliary.valuefunctionvis.ValueFunctionVisualizerGUI;
import burlap.behavior.singleagent.auxiliary.valuefunctionvis.common.*;
import burlap.behavior.singleagent.auxiliary.valuefunctionvis.common.PolicyGlyphPainter2D.PolicyGlyphRenderStyle;
import burlap.oomdp.singleagent.common.VisualActionObserver;

public class BasicBehavior {

	
	
	GridWorldDomain 			gwdg;
	Domain						domain;
	StateParser 				sp;
	RewardFunction 				rf;
	TerminalFunction			tf;
	StateConditionTest			goalCondition;
	State 						initialState;
	DiscreteStateHashFactory	hashingFactory;
	
	
	public static void main(String[] args) {
		
		
		BasicBehavior example = new BasicBehavior();
		String outputPath = "output/"; 
		
		
		//uncomment the example you want to see (and comment-out the rest)
		
		example.BFSExample(outputPath);
		//example.DFSExample(outputPath);
		//example.AStarExample(outputPath);
		//example.ValueIterationExample(outputPath);
		//example.QLearningExample(outputPath);
		//example.SarsaLearningExample(outputPath);
		//example.experimenterAndPlotter();
		
		
		//run the visualizer (only use if you don't use the experiment plotter example)
		example.visualize(outputPath);

	}
	
	
	public BasicBehavior(){
	
		//create the domain
		gwdg = new GridWorldDomain(11, 11);
		gwdg.setMapToFourRooms(); 
		domain = gwdg.generateDomain();
		
		//create the state parser
		sp = new GridWorldStateParser(domain); 
		
		//define the task
		rf = new UniformCostRF(); 
		tf = new SinglePFTF(domain.getPropFunction(GridWorldDomain.PFATLOCATION)); 
		goalCondition = new TFGoalCondition(tf);
		
		//set up the initial state of the task
		initialState = GridWorldDomain.getOneAgentOneLocationState(domain);
		GridWorldDomain.setAgent(initialState, 0, 0);
		GridWorldDomain.setLocation(initialState, 0, 10, 10);
		
		//set up the state hashing system
		hashingFactory = new DiscreteStateHashFactory();
		hashingFactory.setAttributesForClass(GridWorldDomain.CLASSAGENT, 
				domain.getObjectClass(GridWorldDomain.CLASSAGENT).attributeList); 
				
				
		//add visual observer
		VisualActionObserver observer = new VisualActionObserver(domain, 
			GridWorldVisualizer.getVisualizer(gwdg.getMap()));
		((SADomain)this.domain).setActionObserverForAllAction(observer);
		observer.initGUI();		
		
			
	}
	
	
	public void visualize(String outputPath){
		Visualizer v = GridWorldVisualizer.getVisualizer(gwdg.getMap());
		EpisodeSequenceVisualizer evis = new EpisodeSequenceVisualizer(v, 
								domain, sp, outputPath);
	}
	
	
	
	
	public void BFSExample(String outputPath){
		
		if(!outputPath.endsWith("/")){
			outputPath = outputPath + "/";
		}
		
		//BFS ignores reward; it just searches for a goal condition satisfying state
		DeterministicPlanner planner = new BFS(domain, goalCondition, hashingFactory); 
		planner.planFromState(initialState);
		
		//capture the computed plan in a partial policy
		Policy p = new SDPlannerPolicy(planner);
		
		//record the plan results to a file
		p.evaluateBehavior(initialState, rf, tf).writeToFile(outputPath + "planResult", sp);
			
	}				



	public void DFSExample(String outputPath){
		
		if(!outputPath.endsWith("/")){
			outputPath = outputPath + "/";
		}
		
		//DFS ignores reward; it just searches for a goal condition satisfying state
		DeterministicPlanner planner = new DFS(domain, goalCondition, hashingFactory);
		planner.planFromState(initialState);
		
		//capture the computed plan in a partial policy
		Policy p = new SDPlannerPolicy(planner);
		
		//record the plan results to a file
		p.evaluateBehavior(initialState, rf, tf).writeToFile(outputPath + "planResult", sp);
		
	}
	
	
	
	
	public void AStarExample(String outputPath){
		
		if(!outputPath.endsWith("/")){
			outputPath = outputPath + "/";
		}
		
		
		Heuristic mdistHeuristic = new Heuristic() {
			
			@Override
			public double h(State s) {
				
				String an = GridWorldDomain.CLASSAGENT;
				String ln = GridWorldDomain.CLASSLOCATION;
				
				ObjectInstance agent = s.getObjectsOfTrueClass(an).get(0); 
				ObjectInstance location = s.getObjectsOfTrueClass(ln).get(0); 
				
				//get agent position
				int ax = agent.getDiscValForAttribute(GridWorldDomain.ATTX);
				int ay = agent.getDiscValForAttribute(GridWorldDomain.ATTY);
				
				//get location position
				int lx = location.getDiscValForAttribute(GridWorldDomain.ATTX);
				int ly = location.getDiscValForAttribute(GridWorldDomain.ATTY);
				
				//compute Manhattan distance
				double mdist = Math.abs(ax-lx) + Math.abs(ay-ly);
				
				return -mdist;
			}
		};
		
		//provide A* the heuristic as well as the reward function so that it can keep
		//track of the actual cost
		DeterministicPlanner planner = new AStar(domain, rf, goalCondition, 
			hashingFactory, mdistHeuristic);
		planner.planFromState(initialState);
		
		
		//capture the computed plan in a partial policy
		Policy p = new SDPlannerPolicy(planner);
		
		//record the plan results to a file
		p.evaluateBehavior(initialState, rf, tf).writeToFile(outputPath + "planResult", sp);
		
	}	
	
	
	
	public void ValueIterationExample(String outputPath){
		
		if(!outputPath.endsWith("/")){
			outputPath = outputPath + "/";
		}
		
		
		OOMDPPlanner planner = new ValueIteration(domain, rf, tf, 0.99, hashingFactory,
								0.001, 100);
		planner.planFromState(initialState);
		
		//create a Q-greedy policy from the planner
		Policy p = new GreedyQPolicy((QComputablePlanner)planner);
		
		//record the plan results to a file
		p.evaluateBehavior(initialState, rf, tf).writeToFile(outputPath + "planResult", sp);
		
		//visualize the value function and policy
		this.valueFunctionVisualize((QComputablePlanner)planner, p);
		
	}
	
	
	public void QLearningExample(String outputPath){
		
		if(!outputPath.endsWith("/")){
			outputPath = outputPath + "/";
		}
		
		//discount= 0.99; initialQ=0.0; learning rate=0.9
		LearningAgent agent = new QLearning(domain, rf, tf, 0.99, hashingFactory, 0., 0.9);
		
		//run learning for 100 episodes
		for(int i = 0; i < 100; i++){
			EpisodeAnalysis ea = agent.runLearningEpisodeFrom(initialState);
			ea.writeToFile(String.format("%se%03d", outputPath, i), sp); 
			System.out.println(i + ": " + ea.numTimeSteps());
		}
		
	}
	
	
	public void SarsaLearningExample(String outputPath){
		
		if(!outputPath.endsWith("/")){
			outputPath = outputPath + "/";
		}
		
		//discount= 0.99; initialQ=0.0; learning rate=0.5; lambda=1.0
		LearningAgent agent = new SarsaLam(domain, rf, tf, 0.99, hashingFactory,
						0., 0.5, 1.0);
		
		//run learning for 100 episodes
		for(int i = 0; i < 100; i++){
			EpisodeAnalysis ea = agent.runLearningEpisodeFrom(initialState);
			ea.writeToFile(String.format("%se%03d", outputPath, i), sp);
			System.out.println(i + ": " + ea.numTimeSteps());
		}
		
	}	
	
	
	
	public void valueFunctionVisualize(QComputablePlanner planner, Policy p){
		List <State> allStates = StateReachability.getReachableStates(initialState, 
			(SADomain)domain, hashingFactory);
		LandmarkColorBlendInterpolation rb = new LandmarkColorBlendInterpolation();
		rb.addNextLandMark(0., Color.RED);
		rb.addNextLandMark(1., Color.BLUE);
		
		StateValuePainter2D svp = new StateValuePainter2D(rb);
		svp.setXYAttByObjectClass(GridWorldDomain.CLASSAGENT, GridWorldDomain.ATTX, 
			GridWorldDomain.CLASSAGENT, GridWorldDomain.ATTY);
		
		PolicyGlyphPainter2D spp = new PolicyGlyphPainter2D();
		spp.setXYAttByObjectClass(GridWorldDomain.CLASSAGENT, GridWorldDomain.ATTX, 
			GridWorldDomain.CLASSAGENT, GridWorldDomain.ATTY);
		spp.setActionNameGlyphPainter(GridWorldDomain.ACTIONNORTH, new ArrowActionGlyph(0));
		spp.setActionNameGlyphPainter(GridWorldDomain.ACTIONSOUTH, new ArrowActionGlyph(1));
		spp.setActionNameGlyphPainter(GridWorldDomain.ACTIONEAST, new ArrowActionGlyph(2));
		spp.setActionNameGlyphPainter(GridWorldDomain.ACTIONWEST, new ArrowActionGlyph(3));
		spp.setRenderStyle(PolicyGlyphRenderStyle.DISTSCALED);
		
		ValueFunctionVisualizerGUI gui = new ValueFunctionVisualizerGUI(allStates, svp, planner);
		gui.setSpp(spp);
		gui.setPolicy(p);
		gui.setBgColor(Color.GRAY);
		gui.initGUI();
	}



	public void experimenterAndPlotter(){
		
		//custom reward function for more interesting results
		final RewardFunction rf = new GoalBasedRF(this.goalCondition, 5., -0.1);

		/**
		 * Create factories for Q-learning agent and SARSA agent to compare
		 */

		LearningAgentFactory qLearningFactory = new LearningAgentFactory() {
			
			@Override
			public String getAgentName() {
				return "Q-learning";
			}
			
			@Override
			public LearningAgent generateAgent() {
				return new QLearning(domain, rf, tf, 0.99, hashingFactory, 0.3, 0.1);
			}
		};


		LearningAgentFactory sarsaLearningFactory = new LearningAgentFactory() {
			
			@Override
			public String getAgentName() {
				return "SARSA";
			}
			
			@Override
			public LearningAgent generateAgent() {
				return new SarsaLam(domain, rf, tf, 0.99, hashingFactory, 0.0, 0.1, 1.);
			}
		};

		StateGenerator sg = new ConstantStateGenerator(this.initialState);

		LearningAlgorithmExperimenter exp = new LearningAlgorithmExperimenter((SADomain)this.domain, 
			rf, sg, 10, 100, qLearningFactory, sarsaLearningFactory);

		exp.setUpPlottingConfiguration(500, 250, 2, 1000, 
			TrialMode.MOSTRECENTANDAVERAGE, 
			PerformanceMetric.CUMULATIVESTEPSPEREPISODE, 
			PerformanceMetric.AVERAGEEPISODEREWARD);

		exp.startExperiment();

		exp.writeStepAndEpisodeDataToCSV("expData");


	}


	


	
}

				