In this tutorial we showed you how to implement your own planning and learning algorithms. Although these algorithms were simple, they exposed the BURLAP tools and mechanisms you will need when implementing your own algorithms, and they should enable you to start writing your own code. In practice, we highly recommend that you use BURLAP's existing implementations of Value Iteration and Q-learning, since they support a number of additional features (Options, learning rate decay schedules, etc.); a short sketch of how to invoke them is given after the listings. If you would like to see all of the code written in this tutorial, it is provided below (first the Value Iteration code, then the Q-learning code).
import burlap.behavior.policy.GreedyQPolicy;
import burlap.behavior.policy.Policy;
import burlap.behavior.singleagent.EpisodeAnalysis;
import burlap.behavior.singleagent.MDPSolver;
import burlap.behavior.singleagent.auxiliary.EpisodeSequenceVisualizer;
import burlap.behavior.singleagent.auxiliary.StateReachability;
import burlap.behavior.singleagent.planning.Planner;
import burlap.behavior.valuefunction.QFunction;
import burlap.behavior.valuefunction.QValue;
import burlap.behavior.valuefunction.ValueFunctionInitialization;
import burlap.domain.singleagent.gridworld.GridWorldDomain;
import burlap.domain.singleagent.gridworld.GridWorldTerminalFunction;
import burlap.domain.singleagent.gridworld.GridWorldVisualizer;
import burlap.oomdp.core.AbstractGroundedAction;
import burlap.oomdp.core.Domain;
import burlap.oomdp.core.TerminalFunction;
import burlap.oomdp.core.TransitionProbability;
import burlap.oomdp.core.states.State;
import burlap.oomdp.singleagent.GroundedAction;
import burlap.oomdp.singleagent.RewardFunction;
import burlap.oomdp.singleagent.SADomain;
import burlap.oomdp.singleagent.common.UniformCostRF;
import burlap.oomdp.statehashing.HashableState;
import burlap.oomdp.statehashing.HashableStateFactory;
import burlap.oomdp.statehashing.SimpleHashableStateFactory;
import burlap.oomdp.visualizer.Visualizer;
import java.util.*;
public class VITutorial extends MDPSolver implements Planner, QFunction{
//the current value function estimate, mapping hashed states to their value
protected Map<HashableState, Double> valueFunction;
//the value function initialization used for states not yet stored in the value function
protected ValueFunctionInitialization vinit;
//the number of sweeps over the state space to perform
protected int numIterations;
public VITutorial(Domain domain, RewardFunction rf, TerminalFunction tf, double gamma,
HashableStateFactory hashingFactory, ValueFunctionInitialization vinit,
int numIterations){
this.solverInit(domain, rf, tf, gamma, hashingFactory);
this.vinit = vinit;
this.numIterations = numIterations;
this.valueFunction = new HashMap<HashableState, Double>();
}
@Override
public double value(State s) {
Double d = this.valueFunction.get(hashingFactory.hashState(s));
if(d == null){
return vinit.value(s);
}
return d;
}
@Override
public List<QValue> getQs(State s) {
List<GroundedAction> applicableActions = this.getAllGroundedActions(s);
List<QValue> qs = new ArrayList<QValue>(applicableActions.size());
for(GroundedAction ga : applicableActions){
qs.add(this.getQ(s, ga));
}
return qs;
}
@Override
public QValue getQ(State s, AbstractGroundedAction a) {
//type cast to the type we're using
GroundedAction ga = (GroundedAction)a;
//what are the possible outcomes?
List<TransitionProbability> tps = ga.getTransitions(s);
//aggregate over each possible outcome
double q = 0.;
for(TransitionProbability tp : tps){
//what is the reward for this transition?
double r = this.rf.reward(s, ga, tp.s);
//what is the value for the next state?
double vp = this.valueFunction.get(this.hashingFactory.hashState(tp.s));
//add contribution weighted by transition probability and
//discounting the next state
q += tp.p * (r + this.gamma * vp);
}
//create Q-value wrapper
QValue qValue = new QValue(s, ga, q);
return qValue;
}
protected double bellmanEquation(State s){
if(this.tf.isTerminal(s)){
return 0.;
}
List<QValue> qs = this.getQs(s);
double maxQ = Double.NEGATIVE_INFINITY;
for(QValue q : qs){
maxQ = Math.max(maxQ, q.q);
}
return maxQ;
}
@Override
public GreedyQPolicy planFromState(State initialState) {
HashableState hashedInitialState = this.hashingFactory.hashState(initialState);
if(this.valueFunction.containsKey(hashedInitialState)){
return new GreedyQPolicy(this); //already performed planning here!
}
//if the state is new, then find all reachable states from it first
this.performReachabilityFrom(initialState);
//now perform multiple iterations over the whole state space
for(int i = 0; i < this.numIterations; i++){
//iterate over each state
for(HashableState sh : this.valueFunction.keySet()){
//update its value using the bellman equation
this.valueFunction.put(sh, this.bellmanEquation(sh.s));
}
}
return new GreedyQPolicy(this);
}
@Override
public void resetSolver() {
//forget any previous planning results by clearing the computed value function
this.valueFunction.clear();
}
public void performReachabilityFrom(State seedState){
Set<HashableState> hashedStates = StateReachability.getReachableHashedStates(seedState,
(SADomain) this.domain, this.hashingFactory);
//initialize the value function for all states
for(HashableState hs : hashedStates){
if(!this.valueFunction.containsKey(hs)){
this.valueFunction.put(hs, this.vinit.value(hs.s));
}
}
}
public static void main(String [] args){
GridWorldDomain gwd = new GridWorldDomain(11, 11);
gwd.setMapToFourRooms();
//only go in intended direction 80% of the time
gwd.setProbSucceedTransitionDynamics(0.8);
Domain domain = gwd.generateDomain();
//get initial state with agent in 0,0
State s = GridWorldDomain.getOneAgentNoLocationState(domain);
GridWorldDomain.setAgent(s, 0, 0);
//all transitions return -1
RewardFunction rf = new UniformCostRF();
//terminate in top right corner
TerminalFunction tf = new GridWorldTerminalFunction(10, 10);
//set up VI with a 0.99 discount factor, a value function initialization that
//initializes all states to a value of 0, and 30 iterations over the state space
VITutorial vi = new VITutorial(domain, rf, tf, 0.99, new SimpleHashableStateFactory(),
new ValueFunctionInitialization.ConstantValueFunctionInitialization(0.0), 30);
//run planning from our initial state
Policy p = vi.planFromState(s);
//evaluate the policy with one rollout and visualize the trajectory
EpisodeAnalysis ea = p.evaluateBehavior(s, rf, tf);
Visualizer v = GridWorldVisualizer.getVisualizer(gwd.getMap());
new EpisodeSequenceVisualizer(v, domain, Arrays.asList(ea));
}
}
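And here is the Q-learning code: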
import burlap.behavior.policy.EpsilonGreedy;
import burlap.behavior.policy.Policy;
import burlap.behavior.singleagent.EpisodeAnalysis;
import burlap.behavior.singleagent.MDPSolver;
import burlap.behavior.singleagent.auxiliary.EpisodeSequenceVisualizer;
import burlap.behavior.singleagent.learning.LearningAgent;
import burlap.behavior.valuefunction.QFunction;
import burlap.behavior.valuefunction.QValue;
import burlap.behavior.valuefunction.ValueFunctionInitialization;
import burlap.domain.singleagent.gridworld.GridWorldDomain;
import burlap.domain.singleagent.gridworld.GridWorldTerminalFunction;
import burlap.domain.singleagent.gridworld.GridWorldVisualizer;
import burlap.oomdp.core.AbstractGroundedAction;
import burlap.oomdp.core.Domain;
import burlap.oomdp.core.TerminalFunction;
import burlap.oomdp.core.states.State;
import burlap.oomdp.singleagent.GroundedAction;
import burlap.oomdp.singleagent.RewardFunction;
import burlap.oomdp.singleagent.common.UniformCostRF;
import burlap.oomdp.singleagent.environment.Environment;
import burlap.oomdp.singleagent.environment.EnvironmentOutcome;
import burlap.oomdp.singleagent.environment.SimulatedEnvironment;
import burlap.oomdp.statehashing.HashableState;
import burlap.oomdp.statehashing.HashableStateFactory;
import burlap.oomdp.statehashing.SimpleHashableStateFactory;
import burlap.oomdp.visualizer.Visualizer;
import java.util.*;
public class QLTutorial extends MDPSolver implements LearningAgent, QFunction {
//the learned Q-values, mapping hashed states to the Q-values of their applicable actions
Map<HashableState, List<QValue>> qValues;
//the initialization used for Q-values of previously unseen states
ValueFunctionInitialization qinit;
//the learning rate used in Q-value updates
double learningRate;
//the policy followed during learning (epsilon-greedy with respect to the current Q-values)
Policy learningPolicy;
public QLTutorial(Domain domain, double gamma, HashableStateFactory hashingFactory,
ValueFunctionInitialization qinit, double learningRate, double epsilon){
this.solverInit(domain, null, null, gamma, hashingFactory);
this.qinit = qinit;
this.learningRate = learningRate;
this.qValues = new HashMap<HashableState, List<QValue>>();
this.learningPolicy = new EpsilonGreedy(this, epsilon);
}
@Override
public EpisodeAnalysis runLearningEpisode(Environment env) {
return this.runLearningEpisode(env, -1);
}
@Override
public EpisodeAnalysis runLearningEpisode(Environment env, int maxSteps) {
//initialize our episode analysis object with the initial state of the environment
EpisodeAnalysis ea = new EpisodeAnalysis(env.getCurrentObservation());
//behave until a terminal state or max steps is reached
State curState = env.getCurrentObservation();
int steps = 0;
while(!env.isInTerminalState() && (steps < maxSteps || maxSteps == -1)){
//select an action
GroundedAction a = (GroundedAction)this.learningPolicy.getAction(curState);
//take the action and observe outcome
EnvironmentOutcome eo = a.executeIn(env);
//record result
ea.recordTransitionTo(a, eo.op, eo.r);
//get the max Q value of the resulting state if it's not terminal, 0 otherwise
double maxQ = eo.terminated ? 0. : this.value(eo.op);
//update the old Q-value
QValue oldQ = this.getQ(curState, a);
oldQ.q = oldQ.q + this.learningRate * (eo.r + this.gamma * maxQ - oldQ.q);
//move on to next state
curState = eo.op;
steps++;
}
return ea;
}
@Override
public void resetSolver() {
//forget everything that was learned by clearing the stored Q-values
this.qValues.clear();
}
@Override
public List<QValue> getQs(State s) {
//first get hashed state
HashableState sh = this.hashingFactory.hashState(s);
//check if we already have stored values
List<QValue> qs = this.qValues.get(sh);
//create and add initialized Q-values if we don't have them stored for this state
if(qs == null){
List<GroundedAction> actions = this.getAllGroundedActions(s);
qs = new ArrayList<QValue>(actions.size());
//create a Q-value for each action
for(GroundedAction ga : actions){
//add q with initialized value
qs.add(new QValue(s, ga, this.qinit.qValue(s, ga)));
}
//store this for later
this.qValues.put(sh, qs);
}
return qs;
}
@Override
public QValue getQ(State s, AbstractGroundedAction a) {
//first get all Q-values
List<QValue> qs = this.getQs(s);
//translate action parameters to source state for Q-values if needed
a = ((GroundedAction)a).translateParameters(s, qs.get(0).s);
//iterate through stored Q-values to find a match for the input action
for(QValue q : qs){
if(q.a.equals(a)){
return q;
}
}
throw new RuntimeException("Could not find matching Q-value.");
}
@Override
public double value(State s) {
//the value of a state is the maximum Q-value over the actions applicable in it
return QFunctionHelper.getOptimalValue(this, s);
}
public static void main(String[] args) {
GridWorldDomain gwd = new GridWorldDomain(11, 11);
gwd.setMapToFourRooms();
gwd.setProbSucceedTransitionDynamics(0.8);
Domain domain = gwd.generateDomain();
//get initial state with agent in 0,0
State s = GridWorldDomain.getOneAgentNoLocationState(domain);
GridWorldDomain.setAgent(s, 0, 0);
//all transitions return -1
RewardFunction rf = new UniformCostRF();
//terminate in top right corner
TerminalFunction tf = new GridWorldTerminalFunction(10, 10);
//create environment
SimulatedEnvironment env = new SimulatedEnvironment(domain,rf, tf, s);
//create our Q-learning agent with a 0.99 discount factor, Q-values initialized to 0,
//a 0.1 learning rate, and a 0.1 epsilon-greedy learning policy
QLTutorial agent = new QLTutorial(domain, 0.99, new SimpleHashableStateFactory(),
new ValueFunctionInitialization.ConstantValueFunctionInitialization(), 0.1, 0.1);
//run Q-learning and store results in a list
List<EpisodeAnalysis> episodes = new ArrayList<EpisodeAnalysis>(1000);
for(int i = 0; i < 1000; i++){
episodes.add(agent.runLearningEpisode(env));
env.resetEnvironment();
}
Visualizer v = GridWorldVisualizer.getVisualizer(gwd.getMap());
new EpisodeSequenceVisualizer(v, domain, episodes);
}
}
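As recommended above, for real applications you will usually want BURLAP's built-in implementations rather than these tutorial versions. The sketch below shows how the same grid world setup might be run with them. It is a minimal sketch that assumes the BURLAP 2 classes burlap.behavior.singleagent.planning.stochastic.valueiteration.ValueIteration and burlap.behavior.singleagent.learning.tdmethods.QLearning with their standard constructors; the class name BuiltInExample is ours, and you should verify the constructor signatures against the version of the library you are using.
import burlap.behavior.policy.Policy;
import burlap.behavior.singleagent.EpisodeAnalysis;
import burlap.behavior.singleagent.learning.tdmethods.QLearning;
import burlap.behavior.singleagent.planning.stochastic.valueiteration.ValueIteration;
import burlap.domain.singleagent.gridworld.GridWorldDomain;
import burlap.domain.singleagent.gridworld.GridWorldTerminalFunction;
import burlap.oomdp.core.Domain;
import burlap.oomdp.core.TerminalFunction;
import burlap.oomdp.core.states.State;
import burlap.oomdp.singleagent.RewardFunction;
import burlap.oomdp.singleagent.common.UniformCostRF;
import burlap.oomdp.singleagent.environment.SimulatedEnvironment;
import burlap.oomdp.statehashing.SimpleHashableStateFactory;
public class BuiltInExample {
public static void main(String[] args){
//the same four rooms grid world used by the tutorial code above
GridWorldDomain gwd = new GridWorldDomain(11, 11);
gwd.setMapToFourRooms();
gwd.setProbSucceedTransitionDynamics(0.8);
Domain domain = gwd.generateDomain();
State s = GridWorldDomain.getOneAgentNoLocationState(domain);
GridWorldDomain.setAgent(s, 0, 0);
RewardFunction rf = new UniformCostRF();
TerminalFunction tf = new GridWorldTerminalFunction(10, 10);
//BURLAP's Value Iteration: the assumed constructor takes a max value change
//threshold (0.001) and a maximum number of iterations (100)
ValueIteration vi = new ValueIteration(domain, rf, tf, 0.99,
new SimpleHashableStateFactory(), 0.001, 100);
Policy viPolicy = vi.planFromState(s);
//roll out the resulting policy, as in the tutorial main methods above
EpisodeAnalysis viEpisode = viPolicy.evaluateBehavior(s, rf, tf);
//BURLAP's Q-learning: the assumed constructor takes a constant Q-value
//initialization (0.) and a learning rate (0.1)
QLearning ql = new QLearning(domain, 0.99, new SimpleHashableStateFactory(), 0., 0.1);
SimulatedEnvironment env = new SimulatedEnvironment(domain, rf, tf, s);
for(int i = 0; i < 1000; i++){
ql.runLearningEpisode(env);
env.resetEnvironment();
}
}
}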