Tutorial: Creating a Planning and Learning Algorithm

Q-Learning Overview

For our learning algorithm example, we'll be implementing Q-learning. The difference between a learning algorithm and a planning algorithm is that a planning algorithm has access to a model of the world, or at least a simulator, whereas a learning algorithm involves determining behavior when the agent does not know how the world works and must learn how to behave from direct experience with the world. In general, there are two approaches to reinforcement learning: (1) to learn a model of the world from experience and then use planning with that learned model to dictate behavior (model-based) and (2) to learn a policy or value function directly from experience (model-free). Q-learning belongs to the latter.

As the name suggests, Q-learning learns estimates of the optimal Q-values of an MDP, which means that behavior can be dictated by taking actions greedily with respect to the learned Q-values. Q-learning can be summarized in the following pseudocode.

Q-Learning

  1. Initialize Q-values ($Q(s,a)$) arbitrarily for all state-action pairs.
  2. For life or until learning is stopped...
  3.     Choose an action ($a$) in the current world state ($s$) based on current Q-value estimates ($Q(s,\cdot)$).
  4.     Take the action ($a$) and observe the outcome state ($s'$) and reward ($r$).
  5.     Update $Q(s,a) := Q(s,a) + \alpha \left[ r + \gamma \max_{a'} Q(s', a') - Q(s,a) \right]$

The two key steps in the above pseudocode are steps 3 and 5. There are many ways to choose actions based on the current Q-value estimates (step 3), but one of the most common is to use an $\epsilon$-greedy policy. In this policy, the action is selected greedily with respect to the Q-value estimates a fraction ($1-\epsilon$) of the time (where $\epsilon$ is a fraction between 0 and 1), and randomly selected among all actions a fraction $\epsilon$ of the time. In general, you want a policy that has some randomness to it so that it promotes exploration of the state space.
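As a rough illustration of the $\epsilon$-greedy rule just described, the following plain-Java sketch selects among actions indexed 0 to n-1 from an array of Q-value estimates. It does not use BURLAP, and the class and method names (EpsilonGreedySketch, greedyAction, selectAction) are hypothetical names chosen only for this example.

```java
import java.util.Random;

public class EpsilonGreedySketch {

    /** Returns the index of the largest Q-value; the first index wins ties. */
    public static int greedyAction(double[] qValues) {
        int best = 0;
        for (int a = 1; a < qValues.length; a++) {
            if (qValues[a] > qValues[best]) {
                best = a;
            }
        }
        return best;
    }

    /**
     * With probability epsilon, picks uniformly among all actions
     * (the greedy action included); otherwise picks the greedy action.
     */
    public static int selectAction(double[] qValues, double epsilon, Random rng) {
        if (rng.nextDouble() < epsilon) {
            return rng.nextInt(qValues.length);
        }
        return greedyAction(qValues);
    }

    public static void main(String[] args) {
        double[] q = {0.0, 1.5, -0.3, 0.7}; // illustrative estimates for four actions
        Random rng = new Random(42);
        int[] counts = new int[q.length];
        for (int i = 0; i < 10000; i++) {
            counts[selectAction(q, 0.1, rng)]++;
        }
        // Print how often each action was chosen
        for (int a = 0; a < q.length; a++) {
            System.out.println("action " + a + ": " + counts[a]);
        }
    }
}
```

Because the random branch draws from all actions, the greedy action is chosen slightly more often than a fraction $1-\epsilon$ of the time.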

The update rule: $$ \large Q(s,a) := Q(s,a) + \alpha \left[ r + \gamma \max_{a'} Q(s', a') - Q(s,a) \right] $$ updates the Q-value of the last state-action pair ($s,a$) with respect to the observed outcome state ($s'$) and reward ($r$), where $\alpha \in (0, 1)$ is a learning rate parameter.
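To make the arithmetic concrete, assume (purely for illustration) $\alpha = 0.1$, $\gamma = 0.99$, all Q-values initialized to 0, and a reward of $-1$. The first update of a state-action pair then gives $$ Q(s,a) := 0 + 0.1 \left[ -1 + 0.99 \cdot 0 - 0 \right] = -0.1 $$ and, if the same action is retried in the same state with the same outcome while the next state's Q-values are still 0, the second update gives $$ Q(s,a) := -0.1 + 0.1 \left[ -1 + 0.99 \cdot 0 - (-0.1) \right] = -0.19 $$ Each update moves the estimate a fraction $\alpha$ of the way from its old value toward the target $r + \gamma \max_{a'} Q(s', a')$.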

To unpack this update, recall from the Bellman equation that the value of a state is the maximum Q-value, and the Q-value is the expected sum of the reward and the discounted value of the next state, where the expectation is with respect to the probability of each state transition. In the Q-learning update rule, the reward plus the discounted max Q-value in the observed next state is effectively what the Bellman equation tells us the Q-value is, except that in this case we're not marginalizing over all possible outcome states; we only have the one observed state and reward that we happened to get. However, because our learning rate only allows our Q-value to change slightly from its old estimate in the direction of the observed state and reward, as long as we keep retrying that action in the same state, we'll see the other possible states that could have occurred and move in their direction too. In aggregate over multiple tries of the action, then, the Q-value will move toward the true expected value, even though we never directly used the transition probabilities. To have guaranteed convergence to the true Q-value, we should actually be slowly decreasing the learning rate parameter over time. However, in practice, it's often sufficient to simply use a small learning rate parameter, so for simplicity in our implementation, we'll use a fixed value for the learning rate rather than one that changes with time (though in the full Q-learning algorithm provided in BURLAP, you can use different schedules for decreasing the learning rate, including client-provided custom schedules with the LearningRate interface).
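The averaging effect described above can be seen in a small BURLAP-free sketch. It assumes a hypothetical one-step action that ends the episode with reward 1 with probability 0.8 and reward 0 otherwise, so the true Q-value is 0.8 and the update target is just the sampled reward. The class name, the outcome probabilities, and the $1/n$ schedule are all assumptions made for this illustration; $1/n$ is only one common decreasing schedule.

```java
import java.util.Random;

public class LearningRateSketch {

    /** Samples the reward of the hypothetical action: 1 with probability 0.8, else 0. */
    public static double sampleReward(Random rng) {
        return rng.nextDouble() < 0.8 ? 1.0 : 0.0;
    }

    /**
     * Repeatedly applies Q := Q + alpha * (target - Q) to sampled outcomes.
     * The action always terminates the episode, so the target is the reward alone.
     * If decay is true, alpha = 1/n on the n-th try; otherwise alpha is fixed at 0.1.
     */
    public static double estimate(int tries, boolean decay, long seed) {
        Random rng = new Random(seed);
        double q = 0.0;
        for (int n = 1; n <= tries; n++) {
            double alpha = decay ? 1.0 / n : 0.1;
            double target = sampleReward(rng);
            q = q + alpha * (target - q);
        }
        return q;
    }

    public static void main(String[] args) {
        System.out.println("true expected value: 0.8");
        System.out.println("fixed alpha = 0.1:   " + estimate(100000, false, 1L));
        System.out.println("alpha = 1/n:         " + estimate(100000, true, 1L));
    }
}
```

With $\alpha = 1/n$ the estimate is exactly the running average of the sampled targets, which settles on the expected value. With a fixed $\alpha$ the estimate keeps fluctuating around the expected value, with smaller fluctuations for smaller $\alpha$.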

Q-Learning Code

Let's begin implementing our Q-learning algorithm code. Our class, called QLTutorial, will extend MDPSolver and implement the LearningAgent and QFunction interfaces. The LearningAgent interface specifies the common methods a learning algorithm is expected to implement so that it can be used by other BURLAP tools.

Below is the skeleton code that is created when we create our class.

import burlap.behavior.singleagent.EpisodeAnalysis;
import burlap.behavior.singleagent.MDPSolver;
import burlap.behavior.singleagent.learning.LearningAgent;
import burlap.behavior.valuefunction.QFunction;
import burlap.behavior.valuefunction.QValue;
import burlap.oomdp.core.AbstractGroundedAction;
import burlap.oomdp.core.states.State;
import burlap.oomdp.singleagent.environment.Environment;

import java.util.List;


public class QLTutorial extends MDPSolver implements LearningAgent, QFunction {

	@Override
	public EpisodeAnalysis runLearningEpisode(Environment env) {
		return null;
	}

	@Override
	public EpisodeAnalysis runLearningEpisode(Environment env, int maxSteps) {
		return null;
	}

	@Override
	public void resetSolver() {

	}

	@Override
	public List<QValue> getQs(State s) {
		return null;
	}

	@Override
	public QValue getQ(State s, AbstractGroundedAction a) {
		return null;
	}

	@Override
	public double value(State s) {
		return 0;
	}
}
				

Similar to VI, the primary data we will want to store is a set of estimated Q-values for each state and action pair. We'll also again let the user specify the Q-value function initialization with a ValueFunctionInitialization object. We'll also need a learning rate parameter to be set. Finally, we'll need a learning policy to follow; that is, a policy that dictates how the agent chooses actions at each step. For this tutorial, we'll assume an $\epsilon$-greedy policy and let the client specify the value for $\epsilon$. Let's add data members for those elements now.

protected Map<HashableState, List<QValue>> qValues;
protected ValueFunctionInitialization qinit;
protected double learningRate;
protected Policy learningPolicy;
				

Let's also add a constructor to initialize our data members and some of those that we inherit from MDPSolver.

public QLTutorial(Domain domain, double gamma, HashableStateFactory hashingFactory,
				  ValueFunctionInitialization qinit, double learningRate,  double epsilon){

	this.solverInit(domain, null, null, gamma, hashingFactory);
	this.qinit = qinit;
	this.learningRate = learningRate;
	this.qValues = new HashMap<>();
	this.learningPolicy = new EpsilonGreedy(this, epsilon);

}
				

Note that we did not have to initialize the reward function or terminal function for the MDPSolver (the two null parameters) since these will be handled by the Environment object with which the learning algorithm will interact. The EpsilonGreedy policy object we create takes as input a QFunction, which this class will provide, and the value for epsilon to use.


One of the primary tools we'll need is a method that grabs our Q-values, or creates and stores them with the proper initialization value if the state has not been seen before. Let's implement our Q-value methods now, as well as the value method, which returns the state value: the maximum Q-value of all actions applicable in the state.

@Override
public List<QValue> getQs(State s) {
	
	//first get hashed state
	HashableState sh = this.hashingFactory.hashState(s);
	
	//check if we already have stored values
	List<QValue> qs = this.qValues.get(sh);
	
	//create and add initialized Q-values if we don't have them stored for this state
	if(qs == null){
		List<GroundedAction> actions = this.getAllGroundedActions(s);
		qs = new ArrayList<QValue>(actions.size());
		//create a Q-value for each action
		for(GroundedAction ga : actions){
			//add q with initialized value
			qs.add(new QValue(s, ga, this.qinit.qValue(s, ga)));
		}
		//store this for later
		this.qValues.put(sh, qs);
	}
	
	return qs;
}

@Override
public QValue getQ(State s, AbstractGroundedAction a) {
	
	//first get all Q-values
	List<QValue> qs = this.getQs(s);

	//iterate through stored Q-values to find a match for the input action
	for(QValue q : qs){
		if(q.a.equals(a)){
			return q;
		}
	}

	throw new RuntimeException("Could not find matching Q-value.");

}

@Override
public double value(State s) {
	return QFunctionHelper.getOptimalValue(this, s);
}
				

Note that for the value method, we used the QFunction helper class QFunctionHelper, whose getOptimalValue method queries the getQs method of a QFunction object and returns the maximum Q-value.
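Conceptually, the state value is just the largest entry among the state's Q-values. The sketch below shows that computation on a plain list of numbers; it is not BURLAP's implementation, and the class and method names (MaxQSketch, maxQ) are hypothetical. Rejecting an empty list is a choice made for this sketch only.

```java
import java.util.Arrays;
import java.util.List;

public class MaxQSketch {

    /** Returns the largest value in a non-empty list of Q-value estimates. */
    public static double maxQ(List<Double> qValues) {
        if (qValues.isEmpty()) {
            // A choice for this sketch: a state with no actions has no defined maximum
            throw new IllegalArgumentException("no Q-values to maximize over");
        }
        double max = Double.NEGATIVE_INFINITY;
        for (double q : qValues) {
            max = Math.max(max, q);
        }
        return max;
    }

    public static void main(String[] args) {
        // Illustrative estimates for three actions in one state
        List<Double> qs = Arrays.asList(-1.0, -0.5, -2.0);
        System.out.println("state value: " + maxQ(qs));
    }
}
```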


Now that we have all of our helper methods, let's implement the learning algorithm. The LearningAgent interface requires us to implement two methods that cause learning to be run for one episode in some Environment: one that will run learning until the agent reaches a terminal state, and one that will run learning for a maximum number of steps or until a terminal state is reached. We will have the former call the latter with a -1 for the maximum number of steps to indicate that it should never stop until the agent reaches a terminal state. Both methods also require returning an EpisodeAnalysis object, which is a recording of all the states, actions, and rewards that occurred in an episode, so as we write the code to have the agent iteratively select actions, we'll record the results to an EpisodeAnalysis object.

Below is the learning algorithm code for Q-learning.

@Override
public EpisodeAnalysis runLearningEpisode(Environment env) {
	return this.runLearningEpisode(env, -1);
}

@Override
public EpisodeAnalysis runLearningEpisode(Environment env, int maxSteps) {
	//initialize our episode analysis object with the initial state of the environment
	EpisodeAnalysis ea = new EpisodeAnalysis(env.getCurrentObservation());

	//behave until a terminal state or max steps is reached
	State curState = env.getCurrentObservation();
	int steps = 0;
	while(!env.isInTerminalState() && (steps < maxSteps || maxSteps == -1)){

		//select an action
		GroundedAction a = (GroundedAction)this.learningPolicy.getAction(curState);

		//take the action and observe outcome
		EnvironmentOutcome eo = a.executeIn(env);

		//record result
		ea.recordTransitionTo(a, eo.op, eo.r);

		//get the max Q value of the resulting state if it's not terminal, 0 otherwise
		double maxQ = eo.terminated ? 0. : this.value(eo.op);

		//update the old Q-value
		QValue oldQ = this.getQ(curState, a);
		oldQ.q = oldQ.q + this.learningRate * (eo.r + this.gamma * maxQ - oldQ.q);


		//move on to next state
		curState = eo.op;
		steps++;

	}

	return ea;
}
				

The beginning of the code is fairly straightforward; we construct a new EpisodeAnalysis object rooted in the current state of the environment, which we get back from the Environment method getCurrentObservation(). We then begin an execution loop that lasts either until the Environment reaches a terminal state or until the number of steps we've taken reaches the number requested.

Inside the execution loop, we first select an action using our learning policy. Then we execute the action in the environment using the GroundedAction method executeIn(Environment), which returns to us an EnvironmentOutcome object. This object is a tuple with the following data members: o, the observation before the action was taken; a, the action that was taken; op, the observation after the action; r, the reward received; and terminated, a flag indicating whether the new observation is a terminal state.

Environment Observations

You may have noticed that the Environment uses "observation" terminology instead of "state" terminology. This choice is because Environment objects are not under obligation to return to the agent a full state, only an observation. Typically, for MDP domains you can expect it to be a full State, and regardless of whether it is a partial observation or not, the observation itself will always be represented by a BURLAP State object. Note that the use of this terminology is especially useful if you begin using BURLAP's POMDP framework.

Using the new observations from the environment, we record the transition in our EpisodeAnalysis and update the previous Q-value. To update the previous Q-value, we need to get the maximum Q-value for the next state we encountered. However, if that state is a terminal state, then the value should always be zero, because the agent cannot act further from that state. Otherwise, we can get the maximum value by using the value method that we previously defined.
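The terminal-state rule can be isolated in a few lines. The following sketch computes the update target $r + \gamma \max_{a'} Q(s', a')$ with the bootstrapped term forced to zero for terminal outcomes; the class and method names are hypothetical, and nextStateValue stands in for whatever the value method would return for the next state.

```java
public class TerminalTargetSketch {

    /**
     * Returns the Q-learning update target. When the outcome is terminal,
     * the next state's value is ignored and treated as zero.
     */
    public static double target(double reward, double gamma,
                                double nextStateValue, boolean terminated) {
        double bootstrap = terminated ? 0.0 : nextStateValue;
        return reward + gamma * bootstrap;
    }

    public static void main(String[] args) {
        // Same reward and stored next-state value, with and without termination
        System.out.println("non-terminal target: " + target(-1.0, 0.99, 5.0, false));
        System.out.println("terminal target:     " + target(-1.0, 0.99, 5.0, true));
    }
}
```

The explicit check matters most when Q-values are initialized to something other than zero: the Q-values of a terminal state are never updated, so without the check their initial values would leak into the estimates of the states leading to it.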

Finally, we can implement the resetSolver method, which only needs to clear our Q-values.

@Override
public void resetSolver() {
	this.qValues.clear();
}
				

Testing Q-Learning

As before, you can now test your learning algorithm with the previous code developed in the Basic Planning and Learning tutorial. Alternatively, you can use the main method below, which creates a Grid World domain and task similar to the test code we wrote for our VI implementation, except that it applies the Q-learning algorithm in a simulated environment. The results of each learning episode will be presented to you after learning completes. Note that because the domain is stochastic and the agent follows a noisy exploration policy, it can take much longer to learn, and the resulting behavior will not be a straight shot to the goal.

public static void main(String [] args){
		
	GridWorldDomain gwd = new GridWorldDomain(11, 11);
	gwd.setMapToFourRooms();
	gwd.setProbSucceedTransitionDynamics(0.8);

	Domain domain = gwd.generateDomain();

	//get initial state with agent in 0,0
	State s = GridWorldDomain.getOneAgentNoLocationState(domain);
	GridWorldDomain.setAgent(s, 0, 0);

	//all transitions return -1
	RewardFunction rf = new UniformCostRF();

	//terminate in top right corner
	TerminalFunction tf = new GridWorldTerminalFunction(10, 10);

	//create environment
	SimulatedEnvironment env = new SimulatedEnvironment(domain,rf, tf, s);

	//create Q-learning
	QLTutorial agent = new QLTutorial(domain, 0.99, new SimpleHashableStateFactory(),
			new ValueFunctionInitialization.ConstantValueFunctionInitialization(), 0.1, 0.1);

	//run Q-learning and store results in a list
	List<EpisodeAnalysis> episodes = new ArrayList<EpisodeAnalysis>(1000);
	for(int i = 0; i < 1000; i++){
		episodes.add(agent.runLearningEpisode(env));
		env.resetEnvironment();
	}

	Visualizer v = GridWorldVisualizer.getVisualizer(gwd.getMap());
	new EpisodeSequenceVisualizer(v, domain, episodes);
	
}