Tutorial: Creating a Planning and Learning Algorithm


VI Code

Let's start by creating our class for VI, which we'll call VITutorial. Our class will extend MDPSolver, to gain many of the useful data structures used in solving an MDP, and it will implement the Planner and QFunction interfaces: the former because we will implement the planFromState method, and the latter because value iteration computes both a value function and a Q-function (the QFunction interface extends the ValueFunction interface). We will also add all of the imports we will need in developing this class.

import burlap.behavior.policy.GreedyQPolicy;
import burlap.behavior.policy.Policy;
import burlap.behavior.singleagent.EpisodeAnalysis;
import burlap.behavior.singleagent.MDPSolver;
import burlap.behavior.singleagent.auxiliary.EpisodeSequenceVisualizer;
import burlap.behavior.singleagent.auxiliary.StateReachability;
import burlap.behavior.singleagent.planning.Planner;
import burlap.behavior.valuefunction.QFunction;
import burlap.behavior.valuefunction.QValue;
import burlap.behavior.valuefunction.ValueFunctionInitialization;
import burlap.domain.singleagent.gridworld.GridWorldDomain;
import burlap.domain.singleagent.gridworld.GridWorldTerminalFunction;
import burlap.domain.singleagent.gridworld.GridWorldVisualizer;
import burlap.oomdp.core.AbstractGroundedAction;
import burlap.oomdp.core.Domain;
import burlap.oomdp.core.TerminalFunction;
import burlap.oomdp.core.TransitionProbability;
import burlap.oomdp.core.states.State;
import burlap.oomdp.singleagent.GroundedAction;
import burlap.oomdp.singleagent.RewardFunction;
import burlap.oomdp.singleagent.SADomain;
import burlap.oomdp.singleagent.common.UniformCostRF;
import burlap.oomdp.statehashing.HashableState;
import burlap.oomdp.statehashing.HashableStateFactory;
import burlap.oomdp.statehashing.SimpleHashableStateFactory;
import burlap.oomdp.visualizer.Visualizer;

import java.util.*;

public class VITutorial extends MDPSolver implements Planner, QFunction{


	@Override
	public double value(State s) {
		return 0.;
	}

	@Override
	public List<QValue> getQs(State s) {
		// TODO Auto-generated method stub
		return null;
	}

	@Override
	public QValue getQ(State s, AbstractGroundedAction a) {
		// TODO Auto-generated method stub
		return null;
	}

	@Override
	public Policy planFromState(State initialState) {
		// TODO Auto-generated method stub
		return null;
	}

	@Override
	public void resetSolverResults() {
		// TODO Auto-generated method stub

	}

}
				

Because we are subclassing MDPSolver, our class will automatically have data members that define our domain and task (the Domain, RewardFunction, TerminalFunction, discount factor, and the HashableStateFactory used to hash and check equality of states). However, the other critical data that VI needs to store is its estimate of the value function! A value function is ultimately a mapping from states to real values, so for fast access we can use a HashMap and use a HashableStateFactory to provide HashableState instances from states. One way to make VI run faster is to initialize its value function to something close to the optimal value function; BURLAP has a ValueFunctionInitialization interface that can be provided to our code for choosing initialization values. We'll also have a parameter that specifies how many iterations value iteration should run before it terminates (there are other ways to test for convergence that we will not cover here). Let's create data members for these elements and create a constructor.

protected Map<HashableState, Double> valueFunction;
protected ValueFunctionInitialization vinit;
protected int numIterations;


public VITutorial(Domain domain, RewardFunction rf, TerminalFunction tf, double gamma,
				  HashableStateFactory hashingFactory, ValueFunctionInitialization vinit, 
				  int numIterations){
	this.solverInit(domain, rf, tf, gamma, hashingFactory);
	this.vinit = vinit;
	this.numIterations = numIterations;
	this.valueFunction = new HashMap<HashableState, Double>();
}
				

Note that since our MDPSolver superclass will hold our data members for the domain, reward function, terminal function, discount factor, and HashableStateFactory, we can initialize them with its solverInit method.

There is one other critical component VI needs that isn't part of the data we've given it in the constructor: the full state space! One reason we might not want to demand this up front is that in an OO-MDP, the state space may be infinite even though only a finite set of states is reachable from any given input state. We could require the client to tell our algorithm what the state space is, but it's much easier on the client if we determine the set of reachable states from any given seed state ourselves, and only perform this procedure when planning is requested from a previously unseen state. Let's define a method that finds all states reachable from an input state and initializes their values with our ValueFunctionInitialization object. Add the method below.

public void performReachabilityFrom(State seedState){

	Set<HashableState> hashedStates = StateReachability.getReachableHashedStates(seedState, 
											(SADomain) this.domain, this.hashingFactory);
	//initialize the value function for all states
	for(HashableState hs : hashedStates){
		if(!this.valueFunction.containsKey(hs)){
			this.valueFunction.put(hs, this.vinit.value(hs.s));
		}
	}

}
				

In the first line, we make use of BURLAP's StateReachability tool to do the heavy lifting of finding all reachable states. Then we simply iterate through the set and, for every HashableState for which we do not already have an entry, initialize it with the value returned by the ValueFunctionInitialization. You may notice that the value function is passed hs.s. Since our set of states is actually a set of HashableState instances, we retrieve the underlying State object stored in each HashableState through its .s member.


The other method we'll need is one that implements the Bellman equation. As noted on the previous page, the Bellman equation is just a max over the Q-values, and since the QFunction interface already requires us to define methods for getting the Q-values of states, we will implement those methods along with a Bellman equation method next.

@Override
public List<QValue> getQs(State s) {
	List<GroundedAction> applicableActions = this.getAllGroundedActions(s);
	List<QValue> qs = new ArrayList<QValue>(applicableActions.size());
	for(GroundedAction ga : applicableActions){
		qs.add(this.getQ(s, ga));
	}
	return qs;
}

@Override
public QValue getQ(State s, AbstractGroundedAction a) {

	//type cast to the type we're using
	GroundedAction ga = (GroundedAction)a;

	//what are the possible outcomes?
	List<TransitionProbability> tps = ga.getTransitions(s);

	//aggregate over each possible outcome
	double q = 0.;
	for(TransitionProbability tp : tps){
		//what is reward for this transition?
		double r = this.rf.reward(s, ga, tp.s);

		//what is the value for the next state?
		double vp = this.valueFunction.get(this.hashingFactory.hashState(tp.s));

		//add the contribution weighted by the transition probability,
		//discounting the value of the next state
		q += tp.p * (r + this.gamma * vp);
	}

	//create Q-value wrapper
	QValue qValue = new QValue(s, ga, q);

	return qValue;
}

protected double bellmanEquation(State s){

	if(this.tf.isTerminal(s)){
		return 0.;
	}

	List<QValue> qs = this.getQs(s);
	double maxQ = Double.NEGATIVE_INFINITY;
	for(QValue q : qs){
		maxQ = Math.max(maxQ, q.q);
	}
	return maxQ;
}
				

You'll note that the Q-value methods return QValue objects, which are just triples consisting of a State object, an AbstractGroundedAction object, and a double for the Q-value associated with them.

AbstractGroundedAction versus Action

You might wonder why we're using AbstractGroundedAction references for actions, rather than an Action instance like the one we subclassed to define actions in the Building a Domain tutorial. Recall that the Action class is used for defining actions, whereas the GroundedAction class is a reference to an Action definition that also contains any action parameter selections necessary to execute the action. Since actions may be parameterized, we use implementations of the general AbstractGroundedAction interface, of which GroundedAction is one, to reason about decisions, or in this case, to estimate the Q-value of an action selection.

In the getQs method, we simply find all possible grounded actions (using a method inherited from MDPSolver, which we extended; alternatively, we could use an Action static method that takes a list of Action objects and a State and returns all applicable groundings), ask our getQ method what the Q-value of each is, and then return the list of all those Q-values.

In the getQ method, we find all possible transitions from the input state and weight the value of each outcome by the probability of that transition occurring. The value of each outcome is the reward received plus the discounted value we have estimated for the outcome state.

In the bellmanEquation method, we generally just return the maximum Q-value for the state; however, there is a catch: if the input state is a terminal state, then by definition its value is zero, because no action can follow from it. Therefore, if the state is a terminal state, we return a value of 0 and ignore whatever the domain object says the possible transitions would be. Note that this check is not just a performance saver; terminal states are specified by the TerminalFunction interface, so we must always consult it to handle terminal states and cannot expect that a domain's transition dynamics have them baked in.
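As a recap of the quantities these methods compute (this is just the formulation from the previous part of this tutorial, restated here for convenience), getQ aggregates

$$Q(s,a) = \sum_{s'} T(s' \mid s, a)\big[R(s, a, s') + \gamma V(s')\big],$$

and bellmanEquation returns

$$V(s) = \begin{cases} 0 & \text{if } s \text{ is terminal}\\ \max_a Q(s,a) & \text{otherwise,}\end{cases}$$

where T(s' | s, a) is the transition probability, R(s, a, s') is the reward, and γ (gamma in the code) is the discount factor.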


We now have all the tools we need to do planning, so it's time to implement the planFromState method. This method is called whenever a client wants to run planning from a given initial (or seed) state. What we'll do, then, is first check whether we've already performed planning that includes that state. If so, we'll do nothing more, since we assume its value has already been computed. If we haven't seen it before, we'll first find all states reachable from it and then run value iteration for the given number of iterations. As a reminder, running value iteration means making iterative sweeps over the entire state space in which the value of each state is re-estimated to what the Bellman equation says it is, given the previously estimated values of the other states. Finally, every planFromState method is required to return a suitable Policy object for using the planning results. For value iteration, assuming it converged, the optimal policy is to select the action with the highest Q-value; therefore, we'll return a GreedyQPolicy object. A GreedyQPolicy needs to be told what its QFunction source is, which in this case is the instance of our class.

@Override
public GreedyQPolicy planFromState(State initialState) {
	
	HashableState hashedInitialState = this.hashingFactory.hashState(initialState);
	if(this.valueFunction.containsKey(hashedInitialState)){
		return new GreedyQPolicy(this); //already performed planning here!
	}

	//if the state is new, then find all reachable states from it first
	this.performReachabilityFrom(initialState);

	//now perform multiple iterations over the whole state space
	for(int i = 0; i < this.numIterations; i++){
		//iterate over each state
		for(HashableState sh : this.valueFunction.keySet()){
			//update its value using the bellman equation
			this.valueFunction.put(sh, this.bellmanEquation(sh.s));
		}
	}

	return new GreedyQPolicy(this);

}
				

We're now just about finished! The only thing left is that each MDPSolver instance is asked to implement the method resetSolverResults, which when called should have the effect of resetting all data so that it's as if no planning calls had ever been made. For our VI implementation, all this requires is clearing our value function.

@Override
public void resetSolverResults() {
	this.valueFunction.clear();
}
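One loose end you may also want to tie up is the value(State) method from our original class stub, which the QFunction interface inherits from ValueFunction but which we left returning 0. A minimal sketch of an implementation, assuming unseen states should fall back to the ValueFunctionInitialization value, might look like the following (this is not required for the GreedyQPolicy to work, since getQ reads the value function map directly):

@Override
public double value(State s) {
	//look up the planned value for this state, if we have one
	Double d = this.valueFunction.get(this.hashingFactory.hashState(s));
	if(d == null){
		//we have not planned for this state yet; fall back to the initialization value
		return this.vinit.value(s);
	}
	return d;
}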
				

Testing VI

To test our code, you can try using this planning algorithm with the grid world task created in the previous Basic Planning and Learning tutorial. Alternatively, below is a main method that you can add to test your VI implementation; it creates a stochastic grid world, plans for it, evaluates a single rollout of the resulting policy, and visualizes the results.

public static void main(String [] args){
		
	GridWorldDomain gwd = new GridWorldDomain(11, 11);
	gwd.setMapToFourRooms();

	//only go in the intended direction 80% of the time
	gwd.setProbSucceedTransitionDynamics(0.8);

	Domain domain = gwd.generateDomain();

	//get initial state with agent in 0,0
	State s = GridWorldDomain.getOneAgentNoLocationState(domain);
	GridWorldDomain.setAgent(s, 0, 0);

	//all transitions return -1
	RewardFunction rf = new UniformCostRF();

	//terminate in top right corner
	TerminalFunction tf = new GridWorldTerminalFunction(10, 10);

	//setup vi with 0.99 discount factor, a value
	//function initialization that initializes all states to value 0, and which will
	//run for 30 iterations over the state space
	VITutorial vi = new VITutorial(domain, rf, tf, 0.99, new SimpleHashableStateFactory(),
			new ValueFunctionInitialization.ConstantValueFunctionInitialization(0.0), 30);

	//run planning from our initial state
	Policy p = vi.planFromState(s);
	
	//evaluate the policy with one rollout and visualize the trajectory
	EpisodeAnalysis ea = p.evaluateBehavior(s, rf, tf);

	Visualizer v = GridWorldVisualizer.getVisualizer(gwd.getMap());
	new EpisodeSequenceVisualizer(v, domain, Arrays.asList(ea));
		
}
				

If you're looking to extend this tutorial on VI a little more, you might consider implementing a more intelligent termination condition so that, rather than always running VI for a fixed number of iterations, VI terminates when the maximum change in the value function during a sweep is smaller than some small threshold (a sketch of this idea appears below). Otherwise, it's now time to move on to our Q-learning example! If you'd like to see the full code we wrote all together, jump to the end of this tutorial.
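For reference, here is a minimal sketch of that extension (it is not part of the original code listing): the sweep loop in planFromState tracks the largest change made during a sweep and stops early once it drops below a hypothetical maxDelta threshold, which you would add as a new data member or constructor parameter.

//replacement for the iteration loop in planFromState; maxDelta is a
//hypothetical new data member holding the convergence threshold
for(int i = 0; i < this.numIterations; i++){
	double maxChange = 0.;
	for(HashableState sh : this.valueFunction.keySet()){
		double oldV = this.valueFunction.get(sh);
		double newV = this.bellmanEquation(sh.s);
		this.valueFunction.put(sh, newV);
		maxChange = Math.max(maxChange, Math.abs(newV - oldV));
	}
	//stop sweeping once no state's value changed by more than the threshold
	if(maxChange < this.maxDelta){
		break;
	}
}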