Let's start by creating our class for VI, which we'll call VITutorial, along with stubs for all the methods it will need to implement to satisfy the OOMDPPlanner class and QComputablePlanner interface.
import java.util.List;

import burlap.behavior.singleagent.QValue;
import burlap.behavior.singleagent.planning.OOMDPPlanner;
import burlap.behavior.singleagent.planning.QComputablePlanner;
import burlap.oomdp.core.AbstractGroundedAction;
import burlap.oomdp.core.State;

public class VITutorial extends OOMDPPlanner implements QComputablePlanner {

    @Override
    public List<QValue> getQs(State s) {
        // TODO Auto-generated method stub
        return null;
    }

    @Override
    public QValue getQ(State s, AbstractGroundedAction a) {
        // TODO Auto-generated method stub
        return null;
    }

    @Override
    public void planFromState(State initialState) {
        // TODO Auto-generated method stub
    }

    @Override
    public void resetPlannerResults() {
        // TODO Auto-generated method stub
    }

}
Because we are subclassing OOMDPPlanner, our class will inherit data members that define our domain and task (the domain, reward function, terminal function, discount factor, and the state hashing factory used to hash states and test their equality). However, the other critical data that VI needs to store is its estimate of the value function! We can store the value function as a mapping from states to double values, using the provided StateHashFactory to create fast, hashable states (StateHashTuples). Furthermore, it is useful to let the client specify how the value of each state is initialized, which we can do by accepting a ValueFunctionInitialization object. Let's add those data members to our class, and make sure you also add the requisite imports. Finally, as a parameter to the algorithm, we'll let the client specify how many iterations VI will run, so we'll also need a data member for that. (We could instead terminate when the maximum change in the value function falls below a threshold, but for simplicity in this tutorial we'll just use a fixed number of iterations.)
import java.util.HashMap;
import java.util.Map;

import burlap.behavior.singleagent.ValueFunctionInitialization;
import burlap.behavior.statehashing.StateHashTuple;
protected Map<StateHashTuple, Double> valueFunction;
protected ValueFunctionInitialization vinit;
protected int numIterations;
Now let's add a constructor to accept and initialize all of our data. Again, we'll need to add some imports.
import burlap.behavior.statehashing.StateHashFactory;
import burlap.oomdp.core.Domain;
import burlap.oomdp.core.TerminalFunction;
import burlap.oomdp.singleagent.RewardFunction;
public VITutorial(Domain domain, RewardFunction rf, TerminalFunction tf, double gamma,
        StateHashFactory hashingFactory, ValueFunctionInitialization vinit, int numIterations){
    this.plannerInit(domain, rf, tf, gamma, hashingFactory);
    this.vinit = vinit;
    this.numIterations = numIterations;
    this.valueFunction = new HashMap<StateHashTuple, Double>();
}
Note that since our OOMDPPlanner superclass will hold our data members for the domain, reward function, terminal function, discount factor, and state hashing factory, we can initialize them with its plannerInit method.
There is one other critical component VI needs that isn't part of the data we've given it in the constructor: the full state space! One reason we might not want to demand this up front is that in an OO-MDP the state space may be infinite, even though only a finite set of states is reachable from any given input state. We could require the client to provide the state space to our algorithm up front, but it's much easier on the client if we determine the set of states reachable from a given seed state ourselves, and only perform this procedure when planning is requested from a previously unseen state. In fact, there are planner-independent tools in BURLAP that can find all reachable states from a seed state for us (see the StateReachability class for this purpose); however, for the purposes of illustration, we will not make use of those tools and will instead implement the reachability code ourselves.
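(For reference, if you did want to use the built-in utility instead, a usage sketch might look like the following. The getReachableStates signature shown here, taking the seed state, the SADomain, and the StateHashFactory, is an assumption, so check the StateReachability Javadoc for the exact package and signature before relying on it.)

//a sketch of using BURLAP's StateReachability utility instead of our own code;
//the getReachableStates signature is assumed here, so verify it against the Javadoc
List<State> reachable = StateReachability.getReachableStates(seedState,
        (SADomain)this.domain, this.hashingFactory);
for(State rs : reachable){
    StateHashTuple sh = this.hashingFactory.hashState(rs);
    if(!this.valueFunction.containsKey(sh)){
        this.valueFunction.put(sh, this.vinit.value(rs));
    }
}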
To find all states reachable from a source seed state, we perform a kind of breadth-first search in which we start with a queue containing only our seed state. We then dequeue a state and expand it by checking which outcome states are possible under each applicable action, adding any states we've never seen before to the queue. The search is complete when the queue is empty, and the set of expanded states is our reachable state space, which VI will then be able to iterate over. Let's implement that method now. In our code, we will use our valueFunction data member as our test for whether a state has been seen before, adding each expanded state with its initial value as we encounter it. We will also need to add a few new imports for this method.
import java.util.LinkedList;

import burlap.oomdp.core.TransitionProbability;
import burlap.oomdp.singleagent.GroundedAction;
public void performReachabilityFrom(State seedState){

    StateHashTuple hashedSeed = this.hashingFactory.hashState(seedState);

    //mark our seed state as seen and set its initial value function value
    this.valueFunction.put(hashedSeed, this.vinit.value(hashedSeed.s));

    LinkedList<StateHashTuple> open = new LinkedList<StateHashTuple>();
    open.offer(hashedSeed);

    while(open.size() > 0){

        //pop off a state and expand it
        StateHashTuple sh = open.poll();

        //which actions can be applied in this state?
        List<GroundedAction> applicableActions = this.getAllGroundedActions(sh.s);

        //for each action...
        for(GroundedAction ga : applicableActions){

            //what are the possible outcomes?
            List<TransitionProbability> tps = ga.action.getTransitions(sh.s, ga.params);

            //for each possible outcome...
            for(TransitionProbability tp : tps){

                //add previously unseen states to our open queue and
                //set their initial value function value
                StateHashTuple shp = this.hashingFactory.hashState(tp.s);
                if(!this.valueFunction.containsKey(shp)){
                    this.valueFunction.put(shp, this.vinit.value(shp.s));
                    open.offer(shp);
                }

            }

        }

    }

}
With the inline comments, most of this code should be self-explanatory. However, there are a couple of things to which you should pay closer attention. In each state, we want to know all of the possible actions, and for this we use the OOMDPPlanner superclass method getAllGroundedActions. We could instead have used the Action static method getAllApplicableActionsFromActionList in conjunction with the action list provided by our domain; however, it is possible to give a planner high-level actions (such as options) that are not strictly part of the domain, and the OOMDPPlanner method will always return the set of actions the planner was given rather than only the primitives defined with the domain, so using it is preferable.
The other thing to pay attention to is how we get all possible outcome states of applying an action in a state. For this, we can get the Action object reference from our GroundedAction object and ask it to return a list of transition probabilities for the source state in question (along with any action parameters, if there are any) using the getTransitions method. The returned list is made up of TransitionProbability objects, each of which specifies an outcome state and the probability of transitioning to it. As you may recall from the Building a Domain tutorial, the getTransitions method only returns transitions to states with non-zero probability, so you don't have to worry about iterating over an unnecessarily large list of impossible transitions. However, it is worth pointing out that some domains may not implement the getTransitions method, particularly if there are an infinite number of possible outcome states. Planners such as VI are not equipped to handle such domains (especially since we use the full set of possible transitions to compute Q-values), so we must assume that the method is implemented.
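To make the shape of this API concrete, here is a small illustrative snippet (not part of our planner class) that enumerates the stochastic outcomes of every applicable action in some state s; it uses only the calls that already appear in performReachabilityFrom above.

//illustrative only: print each applicable action's possible outcomes in state s
for(GroundedAction ga : this.getAllGroundedActions(s)){
    List<TransitionProbability> tps = ga.action.getTransitions(s, ga.params);
    for(TransitionProbability tp : tps){
        //tp.s holds the outcome state and tp.p its probability; the tp.p
        //values for a given action sum to 1
        System.out.println(ga.toString() + " -> outcome with probability " + tp.p);
    }
}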
The other method we'll need to implement is the Bellman equation. As noted on the previous page, the Bellman equation is just a max over the Q-values, and since we already have methods declared for getting the Q-values of states (a requirement of implementing the QComputablePlanner interface), we will implement those methods and a Bellman equation method next. We will also need to add an import for ArrayList for these methods.
import java.util.ArrayList;
@Override
public List<QValue> getQs(State s) {
    List<GroundedAction> applicableActions = this.getAllGroundedActions(s);
    List<QValue> qs = new ArrayList<QValue>(applicableActions.size());
    for(GroundedAction ga : applicableActions){
        qs.add(this.getQ(s, ga));
    }
    return qs;
}

@Override
public QValue getQ(State s, AbstractGroundedAction a) {

    //type cast to the type of grounded action we're using
    GroundedAction ga = (GroundedAction)a;

    //what are the possible outcomes?
    List<TransitionProbability> tps = ga.action.getTransitions(s, ga.params);

    //aggregate over each possible outcome
    double q = 0.;
    for(TransitionProbability tp : tps){
        //what is the reward for this transition?
        double r = this.rf.reward(s, ga, tp.s);
        //what is the value of the next state?
        double vp = this.valueFunction.get(this.hashingFactory.hashState(tp.s));
        //add the contribution weighted by the transition probability and
        //discounting the value of the next state
        q += tp.p * (r + this.gamma * vp);
    }

    //create a Q-value wrapper for the result
    QValue qValue = new QValue(s, ga, q);

    return qValue;
}

protected double bellmanEquation(State s){

    if(this.tf.isTerminal(s)){
        return 0.;
    }

    List<QValue> qs = this.getQs(s);
    double maxQ = Double.NEGATIVE_INFINITY;
    for(QValue q : qs){
        maxQ = Math.max(maxQ, q.q);
    }
    return maxQ;
}
You'll note that the Q-value methods return QValue objects, which are just triples consisting of a State object, an AbstractGroundedAction object, and a double for the Q-value associated with them.
In the getQs method, we simply find all possible grounded actions, ask our getQ method for the Q-value of each, and then return the list of those Q-values.
In the getQ method, we find all possible transitions from the input state and weigh the value of each outcome by the probability of the transition occurring. The value of each outcome is the reward received plus the discounted value we have estimated for the outcome state.
In the bellmanEquation method, we generally just return the maximum Q-value for the state; however, there is one catch: if the input state is a terminal state, then by definition its value is zero, because no action can follow from a terminal state. Therefore, if the state is terminal, we return a value of 0 and ignore whatever the domain object says the possible transitions would be. Note that this check is not just a performance saver; terminal states are specified by the TerminalFunction interface, so we must always consult it to handle terminal states and cannot expect that a domain's transition dynamics have terminal behavior baked in.
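For reference, here is the computation these three methods implement, written in equation form to match the code (this is just the Bellman optimality equation recapped from the previous page):

\[
Q(s,a) = \sum_{s'} T(s' \mid s, a)\,\bigl[\, R(s,a,s') + \gamma V(s') \,\bigr],
\qquad
V(s) =
\begin{cases}
0 & \text{if } s \text{ is terminal}\\
\max_{a} Q(s,a) & \text{otherwise}
\end{cases}
\]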
We now have all the tools we need to do planning, so it's time to implement the planFromState method. This method is called whenever a client wants to run planning from a given initial (or seed) state. We'll first check whether we've already performed planning that includes that state; if so, we do nothing, since we've already computed a value for it. If we haven't seen it before, we first find all states reachable from it and then run value iteration for the given number of iterations. As a reminder, running value iteration means making iterative sweeps over the entire state space in which the value of each state is re-estimated to be what the Bellman equation says it is, given the previously estimated values of the other states.
@Override
public void planFromState(State initialState) {

    StateHashTuple hashedInitialState = this.hashingFactory.hashState(initialState);
    if(this.valueFunction.containsKey(hashedInitialState)){
        return; //already performed planning here!
    }

    //if the state is new, then find all reachable states from it first
    this.performReachabilityFrom(initialState);

    //now perform multiple iterations over the whole state space
    for(int i = 0; i < this.numIterations; i++){
        //iterate over each state
        for(StateHashTuple sh : this.valueFunction.keySet()){
            //update its value using the Bellman equation
            this.valueFunction.put(sh, this.bellmanEquation(sh.s));
        }
    }

}
We're now just about finished! The only thing left is that each OOMDPPlanner instance is asked to implement the method resetPlannerResults, which when called should have the effect of resetting all data so that it's as if no planning calls had ever been made. For our VI implementation, all this requires is clearing our value function.
@Override
public void resetPlannerResults() {
    this.valueFunction.clear();
}
To test our code, you can try using this planning algorithm with the grid world task created in the previous Basic Planning and Learning tutorial. You'll note that since we implement the QComputablePlanner interface, we can use any existing Q-value derived policy, such as GreedyQ, EpsilonGreedy, etc. Alternatively, below is a main method that you can add to test your VI implementation that creates a stochastic grid world, plans for it, and evaluates a single rollout of the resulting policy. You should find that the agent takes north and east actions exclusively.
import burlap.behavior.singleagent.EpisodeAnalysis;
import burlap.behavior.singleagent.Policy;
import burlap.behavior.singleagent.planning.commonpolicies.GreedyQPolicy;
import burlap.behavior.statehashing.DiscreteStateHashFactory;
import burlap.domain.singleagent.gridworld.GridWorldDomain;
import burlap.domain.singleagent.gridworld.GridWorldTerminalFunction;
import burlap.oomdp.singleagent.common.UniformCostRF;
public static void main(String [] args){

    GridWorldDomain gwd = new GridWorldDomain(11, 11);
    gwd.setMapToFourRooms();

    //only go in the intended direction 80% of the time
    gwd.setProbSucceedTransitionDynamics(0.8);

    Domain domain = gwd.generateDomain();

    //get initial state with the agent in 0,0
    State s = GridWorldDomain.getOneAgentNoLocationState(domain);
    GridWorldDomain.setAgent(s, 0, 0);

    //all transitions return -1
    RewardFunction rf = new UniformCostRF();

    //terminate in top right corner
    TerminalFunction tf = new GridWorldTerminalFunction(10, 10);

    //set up VI with a 0.99 discount factor, a discrete state hashing factory, a value
    //function initialization that initializes all states to value 0, and which will
    //run for 30 iterations over the state space
    VITutorial vi = new VITutorial(domain, rf, tf, 0.99, new DiscreteStateHashFactory(),
            new ValueFunctionInitialization.ConstantValueFunctionInitialization(0.0), 30);

    //run planning from our initial state
    vi.planFromState(s);

    //get the greedy policy from it
    Policy p = new GreedyQPolicy(vi);

    //evaluate the policy with one roll out and print out the action sequence
    EpisodeAnalysis ea = p.evaluateBehavior(s, rf, tf);
    System.out.println(ea.getActionSequenceString("\n"));

}
If you're looking to extend this tutorial on VI a little more, you might consider implementing a more intelligent termination condition, so that rather than always running for a fixed number of iterations, VI terminates when the maximum change in the value function is smaller than some small threshold (a sketch of one possible approach is shown below). Otherwise, it's now time to move on to our Q-learning example! If you'd like to see the full code we wrote all together, jump to the end of this tutorial.
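If you do try that extension, here is a minimal sketch of how the fixed-iteration loop in planFromState could be replaced with a convergence test. The maxDelta threshold is a hypothetical extra parameter (or constant) you would add to the class, and this is just one way to structure the loop:

//a sketch of a convergence-based loop that could replace the fixed-iteration
//loop in planFromState; maxDelta is a hypothetical threshold you would add
double maxDelta = 1e-6;
double delta;
do{
    delta = 0.;
    for(StateHashTuple sh : this.valueFunction.keySet()){
        double oldV = this.valueFunction.get(sh);
        double newV = this.bellmanEquation(sh.s);
        this.valueFunction.put(sh, newV);
        //track the largest change made during this sweep
        delta = Math.max(delta, Math.abs(newV - oldV));
    }
}while(delta > maxDelta);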