Let's start by creating our class for VI, which we'll call VITutorial. Our class will extend MDPSolver to gain many of the useful data structures used in solving an MDP, and it will implement the Planner and QProvider interfaces: the former because we will implement the planFromState method, and the latter because value iteration computes the value function from which Q-values can be computed. (The QProvider interface extends the QFunction interface, which in turn extends the ValueFunction interface. QFunction adds to ValueFunction a method for getting the Q-value of a state-action pair, and QProvider adds a method for returning all Q-values of an input state.) We will also add all the imports we will need in developing this class.
import burlap.behavior.policy.GreedyQPolicy;
import burlap.behavior.policy.Policy;
import burlap.behavior.policy.PolicyUtils;
import burlap.behavior.singleagent.Episode;
import burlap.behavior.singleagent.MDPSolver;
import burlap.behavior.singleagent.auxiliary.EpisodeSequenceVisualizer;
import burlap.behavior.singleagent.auxiliary.StateReachability;
import burlap.behavior.singleagent.planning.Planner;
import burlap.behavior.valuefunction.ConstantValueFunction;
import burlap.behavior.valuefunction.QProvider;
import burlap.behavior.valuefunction.QValue;
import burlap.behavior.valuefunction.ValueFunction;
import burlap.domain.singleagent.gridworld.GridWorldDomain;
import burlap.domain.singleagent.gridworld.GridWorldTerminalFunction;
import burlap.domain.singleagent.gridworld.GridWorldVisualizer;
import burlap.domain.singleagent.gridworld.state.GridAgent;
import burlap.domain.singleagent.gridworld.state.GridWorldState;
import burlap.mdp.core.action.Action;
import burlap.mdp.core.state.State;
import burlap.mdp.singleagent.SADomain;
import burlap.mdp.singleagent.model.FullModel;
import burlap.mdp.singleagent.model.TransitionProb;
import burlap.statehashing.HashableState;
import burlap.statehashing.HashableStateFactory;
import burlap.statehashing.simple.SimpleHashableStateFactory;
import burlap.visualizer.Visualizer;

import java.util.*;

public class VITutorial extends MDPSolver implements Planner, QProvider {

    @Override
    public double value(State s) {
        return 0.;
    }

    @Override
    public List<QValue> qValues(State s) {
        // TODO Auto-generated method stub
        return null;
    }

    @Override
    public double qValue(State s, Action a) {
        // TODO Auto-generated method stub
        return 0.;
    }

    @Override
    public Policy planFromState(State initialState) {
        // TODO Auto-generated method stub
        return null;
    }

    @Override
    public void resetSolver() {
        // TODO Auto-generated method stub
    }

}
Because we are subclassing MDPSolver, our object will automatically have data members that define our domain and task (the domain, the discount factor, and the HashableStateFactory that is used to hash and check the equality of states). However, the other critical data that VI needs to store is its estimate of the value function! A value function is ultimately a mapping from states to real values, so for fast access we can use a HashMap, with a HashableStateFactory providing HashableState instances from states to use as keys. One way to make VI run faster is to initialize its value function to something close to the optimal value function, so we will also accept another ValueFunction to use as the initial value function. Finally, we'll have a parameter that specifies how many iterations value iteration should run before it terminates (there are other ways to test for convergence that we will not cover here). Let's create data members for these elements and write a constructor.
protected Map<HashableState, Double> valueFunction;
protected ValueFunction vinit;
protected int numIterations;

public VITutorial(SADomain domain, double gamma,
                  HashableStateFactory hashingFactory, ValueFunction vinit,
                  int numIterations){
    this.solverInit(domain, gamma, hashingFactory);
    this.vinit = vinit;
    this.numIterations = numIterations;
    this.valueFunction = new HashMap<HashableState, Double>();
}
Note that since our MDPSolver superclass will hold our data members for the domain, discount factor, and HashableStateFactory, we can initialize them with its solverInit method.
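Conceptually (this is only a sketch of what the superclass does for us, based on the data members the rest of this tutorial relies on; it is not code you should add), solverInit amounts to bookkeeping like the following:

    //sketch of the bookkeeping solverInit performs (an assumption about MDPSolver internals)
    this.domain = domain;                 //used later for reachability analysis
    this.gamma = gamma;                   //discount factor used in the Bellman backup
    this.hashingFactory = hashingFactory; //used to hash states for the value function map
    this.model = domain.getModel();       //the domain's model, unpacked for our use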
There is one other critical component VI needs that isn't part of the data we've given it in the constructor: the full state space! One reason we might not want to demand this up front is that in an MDP the state space can be infinite, even though for any given input state only a finite set of states may be reachable. We could require the user to tell our algorithm up front what the state space is, but it's much easier on the client if we determine the set of reachable states for any given seed state ourselves, and only perform this procedure when planning is requested for a previously unseen state. Let's define a method that gets all states reachable from an input state and initializes their values with the ValueFunction we use for initialization. Add the below method.
public void performReachabilityFrom(State seedState){

    Set<HashableState> hashedStates = StateReachability.getReachableHashedStates(
            seedState, this.domain, this.hashingFactory);

    //initialize the value function for all states
    for(HashableState hs : hashedStates){
        if(!this.valueFunction.containsKey(hs)){
            this.valueFunction.put(hs, this.vinit.value(hs.s()));
        }
    }

}
In the first line, we make use of BURLAP's StateReachability tool to do the heavy lifting of finding all reachable states. Then we simply iterate through the returned set and, for every HashableState for which we do not already have an entry, initialize its value with the value returned by the ValueFunction we use for initialization. You may notice that the value function is passed hs.s(). Since our set of states is actually a set of HashableState instances, we retrieve the underlying State object stored in each HashableState with its s() method.
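As a small illustration of the same pattern (someState here is a hypothetical variable, not part of the tutorial code), the factory is also how we will later look values back up for plain State objects:

    //illustrative only: hash a plain State with the same factory to look up its stored value
    HashableState key = this.hashingFactory.hashState(someState);
    double storedValue = this.valueFunction.get(key);
    State underlying = key.s(); //recovers the State object wrapped by the HashableState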
The other piece we'll need to implement is the Bellman equation. As noted on the previous page, the Bellman equation is just a max over the Q-values, and since implementing the QProvider interface already requires us to define methods for getting the Q-values of states, we will implement those Q-value methods next; the Bellman backup then reduces to taking the maximum over them.
@Override
public List<QValue> qValues(State s) {

    List<Action> applicableActions = this.applicableActions(s);
    List<QValue> qs = new ArrayList<QValue>(applicableActions.size());
    for(Action a : applicableActions){
        qs.add(new QValue(s, a, this.qValue(s, a)));
    }
    return qs;
}

@Override
public double qValue(State s, Action a) {

    if(this.model.terminal(s)){
        return 0.;
    }

    //what are the possible outcomes?
    List<TransitionProb> tps = ((FullModel)this.model).transitions(s, a);

    //aggregate over each possible outcome
    double q = 0.;
    for(TransitionProb tp : tps){
        //what is the reward for this transition?
        double r = tp.eo.r;

        //what is the value for the next state?
        double vp = this.valueFunction.get(this.hashingFactory.hashState(tp.eo.op));

        //add the contribution weighted by the transition probability,
        //discounting the value of the next state
        q += tp.p * (r + this.gamma * vp);
    }

    return q;
}
You'll note that the qValues method returns a list of QValue objects, which are just triples consisting of a State object, an Action object, and a double for the Q-value associated with them.
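As a purely illustrative aside (it assumes QValue exposes its triple through public s, a, and q fields, and uses hypothetical vi and s variables), a caller could inspect the Q-values for a state like this:

    //illustrative only: print each action and its Q-value for a state
    //(assumes QValue's public a and q fields; vi and s are hypothetical variables)
    for(QValue qv : vi.qValues(s)){
        System.out.println(qv.a + ": " + qv.q);
    }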
In the qValues method, we simply find all applicable actions for the state (using a method inherited from the MDPSolver class we extended; alternatively, we could use an ActionUtils method that takes a list of the domain's ActionType objects and a State and returns all applicable actions, as sketched below), ask our qValue method what each of their Q-values is, and then return the list of all those Q-values.
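For reference, that ActionUtils alternative would look roughly like the sketch below; it assumes the ActionUtils.allApplicableActionsForTypes helper and the domain's getActionTypes method, and would require importing burlap.mdp.core.action.ActionUtils.

    //rough sketch of the ActionUtils alternative for enumerating applicable actions
    List<Action> applicableActions = ActionUtils.allApplicableActionsForTypes(
            this.domain.getActionTypes(), s);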
In the qValue method, we first ask our model whether the input state is terminal. If it is, then the Q-value must be 0, because that is the definition of a terminal state. Otherwise, we find all possible transitions from the input state-action pair and weight the value of each outcome by the probability of that transition occurring. The value of each outcome is the reward received plus the discounted value we have estimated for the outcome state.
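As a concrete illustration with made-up numbers: suppose an action from some state has two possible outcomes, one with probability 0.8 leading to a next state whose current value estimate is 10, and one with probability 0.2 leading to a next state with estimated value 2, and both transitions yield a reward of -1. With a discount factor of 0.99, the loop in qValue would compute q = 0.8 * (-1 + 0.99 * 10) + 0.2 * (-1 + 0.99 * 2) = 7.12 + 0.196 = 7.316.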
You might wonder where the model data member comes from. Because we are extending the MDPSolver class, when we called the solverInit method it automatically unpacked the model included with the domain into a model data member that we can use. This is convenient, and it also lets a client change the model the solver uses to something other than what comes out of the domain object via the setModel method. Note that the model data member is typed to the super interface SampleModel. To perform dynamic programming we require a FullModel, and we assume the model is of that type, so we cast it to FullModel and call its transitions method.
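If you want a clearer failure when a domain's model cannot enumerate its transitions, a small defensive variant of that cast might look like the sketch below; the instanceof check and error message are illustrative additions, not part of the tutorial code.

    //illustrative sketch: fail with a clear message if the model is only a SampleModel
    if(!(this.model instanceof FullModel)){
        throw new RuntimeException("VITutorial requires a FullModel that can enumerate " +
                "transitions, but this domain's model only implements SampleModel.");
    }
    List<TransitionProb> tps = ((FullModel)this.model).transitions(s, a);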
We now have all the tools we need to do planning, so it's time to implement the planFromState method. This method is called whenever a client wants to run planning from a given initial (or seed) state. What we'll do is first check whether we've already performed planning that includes that state; if so, we do nothing, since we will already have computed a value for it. If we haven't seen it before, we first find all states reachable from it and then run value iteration for the given number of iterations. As a reminder, running value iteration means making iterative sweeps over the entire state space in which the value of each state is re-estimated to what the Bellman equation says it should be, given the previously estimated values of the other states. The Bellman equation is just the maximum Q-value, and we can call the QProvider helper method to get the maximum Q-value from objects that implement QProvider, which our class does! Finally, all planFromState methods are required to return a suitable Policy object for using the planning results. For value iteration, assuming it has converged, the optimal policy is to select the action with the highest Q-value; therefore, we'll return a GreedyQPolicy object. GreedyQPolicy objects need to be told what their QFunction source is, which in this case is the instance of our class.
@Override
public GreedyQPolicy planFromState(State initialState) {

    HashableState hashedInitialState = this.hashingFactory.hashState(initialState);
    if(this.valueFunction.containsKey(hashedInitialState)){
        return new GreedyQPolicy(this); //already performed planning here!
    }

    //if the state is new, then find all reachable states from it first
    this.performReachabilityFrom(initialState);

    //now perform multiple iterations over the whole state space
    for(int i = 0; i < this.numIterations; i++){
        //iterate over each state
        for(HashableState sh : this.valueFunction.keySet()){
            //update its value using the bellman equation
            this.valueFunction.put(sh, QProvider.Helper.maxQ(this, sh.s()));
        }
    }

    return new GreedyQPolicy(this);

}
We're now just about finished! The only thing left is that each MDPSolver instance is asked to implement the method resetSolver, which when called should have the effect of resetting all data so that it's as if no planning calls had ever been made. For our VI implementation, all this requires is clearing our value function.
@Override
public void resetSolver() {
    this.valueFunction.clear();
}
To test our code, you can try using this planning algorithm with the grid world task created in the previous Basic Planning and Learning tutorial. Alternatively, below is a main method you can add to test your VI implementation; it creates a stochastic grid world, plans for it, evaluates a single rollout of the resulting policy, and visualizes the results.
public static void main(String [] args){

    GridWorldDomain gwd = new GridWorldDomain(11, 11);
    gwd.setTf(new GridWorldTerminalFunction(10, 10));
    gwd.setMapToFourRooms();

    //only go in the intended direction 80% of the time
    gwd.setProbSucceedTransitionDynamics(0.8);

    SADomain domain = gwd.generateDomain();

    //get initial state with agent in 0,0
    State s = new GridWorldState(new GridAgent(0, 0));

    //set up VI with a 0.99 discount factor, a value function
    //initialization that initializes all states to value 0, and which will
    //run for 30 iterations over the state space
    VITutorial vi = new VITutorial(domain, 0.99, new SimpleHashableStateFactory(),
            new ConstantValueFunction(0.0), 30);

    //run planning from our initial state
    Policy p = vi.planFromState(s);

    //evaluate the policy with one roll out and visualize the trajectory
    Episode ea = PolicyUtils.rollout(p, s, domain.getModel());

    Visualizer v = GridWorldVisualizer.getVisualizer(gwd.getMap());
    new EpisodeSequenceVisualizer(v, domain, Arrays.asList(ea));

}
If you're looking to extend this tutorial on VI a little more, you might consider implementing a more intelligent termination condition so that rather than always running VI for a fixed number of iterations, it terminates when the maximum change in the value function across a sweep is smaller than some small threshold (a sketch of this idea follows below). Otherwise, it's now time to move on to our Q-learning example! If you'd like to see the full code we wrote all together, jump to the end of this tutorial.
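As a starting point, here is a minimal sketch of such a convergence test. It would replace the fixed-iteration loop in planFromState; maxDelta is a hypothetical threshold parameter (for example, another constructor argument) that is not part of the tutorial code above.

    //sketch: sweep the state space until the largest value change falls below
    //a hypothetical maxDelta threshold (not defined in the tutorial code above)
    double delta;
    do{
        delta = 0.;
        for(HashableState sh : this.valueFunction.keySet()){
            double oldV = this.valueFunction.get(sh);
            double newV = QProvider.Helper.maxQ(this, sh.s());
            this.valueFunction.put(sh, newV);
            delta = Math.max(delta, Math.abs(newV - oldV));
        }
    }while(delta > maxDelta);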