Tutorial: Creating a Planning and Learning Algorithm


Q-Learning Overview

For our learning algorithm example, we'll be implementing Q-learning. The difference between a learning algorithm and a planning algorithm is that a planning algorithm has access to a model of the world, or at least a simulator, whereas a learning algorithm involves determining behavior when the agent does not know how the world works and must learn how to behave from direct experience with the world. In general, there are two approaches to reinforcement learning: (1) to learn a model of the world from experience and then use planning with that learned model to dictate behavior (model-based) and (2) to learn a policy or value function directly from experience (model-free). Q-learning belongs to the latter.

As the name suggests, Q-learning learns estimates of the optimal Q-values of an MDP, which means that behavior can be dictated by taking actions greedily with respect to the learned Q-values. Q-learning can be summarized in the following pseudocode.

Q-Learning

  1. Initialize Q-values ($Q(s,a)$) arbitrarily for all state-action pairs.
  2. For life or until learning is stopped...
  3.     Choose an action ($a$) in the current world state ($s$) based on current Q-value estimates ($Q(s,\cdot)$).
  4.     Take the action ($a$) and observe the outcome state ($s'$) and reward ($r$).
  5.     Update $Q(s,a) := Q(s,a) + \alpha \left[ r + \gamma \max_{a'} Q(s', a') - Q(s,a) \right]$
  6.     Set the current state to the observed outcome state ($s := s'$).

The two key steps in the above pseudocode are steps 3 and 5. There are many ways to choose actions based on the current Q-value estimates (step 3), but one of the most common is to use an $\epsilon$-greedy policy. In this policy, the action is selected greedily with respect to the Q-value estimates a fraction ($1-\epsilon$) of the time (where $\epsilon$ is a fraction between 0 and 1), and randomly selected among all actions a fraction $\epsilon$ of the time. In general, you want a policy that has some randomness to it so that it promotes exploration of the state space.
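To make the selection rule concrete, here is a minimal, library-independent sketch of $\epsilon$-greedy selection. It is not BURLAP's EpsilonGreedy class. The sketch assumes actions are identified by index, with `qs[i]` holding the current Q-value estimate of action `i`. It breaks ties by taking the first maximum, which is a simplification.

```java
import java.util.Arrays;
import java.util.Random;

// Minimal epsilon-greedy sketch (not BURLAP's EpsilonGreedy). Actions are indices into qs,
// and qs[i] is the current Q-value estimate for action i in the current state.
class EpsilonGreedySketch {

    // With probability epsilon, returns a uniformly random action (which may also be the greedy one);
    // otherwise returns the greedy action. Ties are broken by taking the first maximum (a simplification).
    static int selectAction(double[] qs, double epsilon, Random rng) {
        if (rng.nextDouble() < epsilon) {
            return rng.nextInt(qs.length); // explore
        }
        int best = 0;
        for (int i = 1; i < qs.length; i++) {
            if (qs[i] > qs[best]) {
                best = i;
            }
        }
        return best; // exploit
    }

    public static void main(String[] args) {
        double[] qs = {0.1, 0.5, -0.2};
        Random rng = new Random(42);
        int[] counts = new int[qs.length];
        for (int t = 0; t < 10000; t++) {
            counts[selectAction(qs, 0.1, rng)]++;
        }
        // Action 1 is greedy, so it should be chosen about (1 - 0.1) + 0.1 / 3, roughly 93% of the time.
        System.out.println(Arrays.toString(counts));
    }
}
```

Because the random branch samples among all actions, the greedy action is actually chosen with probability $1-\epsilon+\epsilon/|A|$, where $|A|$ is the number of actions.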

The update rule: $$ \large Q(s,a) := Q(s,a) + \alpha \left[ r + \gamma \max_{a'} Q(s', a') - Q(s,a) \right] $$ updates the Q-value of the last state-action pair ($s,a$) with respect to the observed outcome state ($s'$) and reward ($r$). Here $\alpha \in (0, 1)$ is a learning rate parameter and $\gamma$ is the discount factor.
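As a quick worked example with made-up numbers, suppose $Q(s,a) = 0$, $\alpha = 0.1$, $\gamma = 0.99$, the observed reward is $r = 0$, and $\max_{a'} Q(s',a') = 1$. The sampled target is $r + \gamma \max_{a'} Q(s',a') = 0.99$, and a single update moves the estimate one tenth of the way toward it:

$$ Q(s,a) := 0 + 0.1 \left[ 0 + 0.99 \cdot 1 - 0 \right] = 0.099 $$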

To unpack this update, recall the Bellman equation. The value of a state is its maximum Q-value, and a Q-value is the expected sum of the reward and the discounted value of the next state, where the expectation is taken with respect to the probability of each state transition. In the Q-learning update rule, the reward plus the discounted max Q-value of the observed next state is effectively what the Bellman equation says the Q-value is. The difference is that we are not marginalizing over all possible outcome states; we only have the one outcome state and reward that we happened to observe. However, the learning rate only lets the Q-value move slightly from its old estimate toward this sampled target. As long as we keep retrying that action in the same state, we will also see the other outcome states that could have occurred and move toward them as well. In aggregate over many tries of the action, the Q-value therefore moves toward the true expected value, even though we never used the transition probabilities directly.

Guaranteed convergence to the true Q-values requires two things. Every state-action pair must keep being updated, and the learning rate must slowly decrease over time. Formally, the step sizes $\alpha_n$ used for the $n$-th update of a pair must satisfy $\sum_n \alpha_n = \infty$ and $\sum_n \alpha_n^2 < \infty$, as with $\alpha_n = 1/n$. In practice, however, a small fixed learning rate is often sufficient, so for simplicity our implementation uses a fixed value rather than one that changes with time. The full Q-learning algorithm provided in BURLAP does let you use different schedules for decreasing the learning rate, including client-provided custom schedules implemented with the LearningRate interface.
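The effect of the learning rate can be seen without any MDP at all. The sketch below is a hypothetical, library-independent illustration. It repeatedly applies the update $q := q + \alpha(\text{target} - q)$ to a stochastic target that is 1 with probability 0.8 and 0 otherwise. This target stands in for $r + \gamma \max_{a'} Q(s',a')$ under random transitions.

- With the decreasing schedule $\alpha_n = 1/n$, the estimate is exactly the running average of the targets seen so far.
- A fixed $\alpha$ instead gives an exponentially weighted average that keeps fluctuating around the true mean of 0.8.

```java
import java.util.Random;

// Hypothetical illustration (not BURLAP code): the update q := q + alpha * (target - q)
// applied to a noisy target that is 1.0 with probability 0.8 and 0.0 otherwise.
class LearningRateSketch {

    // With alpha_n = 1/n on the n-th update, the rule computes the exact running mean of the targets.
    static double decayingEstimate(double[] targets) {
        double q = 0.0;
        for (int n = 1; n <= targets.length; n++) {
            double alpha = 1.0 / n;
            q += alpha * (targets[n - 1] - q);
        }
        return q;
    }

    // With a fixed alpha, the estimate is an exponentially weighted average of recent targets.
    static double fixedEstimate(double[] targets, double alpha) {
        double q = 0.0;
        for (double t : targets) {
            q += alpha * (t - q);
        }
        return q;
    }

    // Draws noisy targets: 1.0 with probability 0.8, else 0.0.
    static double[] sampleTargets(int count, long seed) {
        Random rng = new Random(seed);
        double[] targets = new double[count];
        for (int i = 0; i < count; i++) {
            targets[i] = rng.nextDouble() < 0.8 ? 1.0 : 0.0;
        }
        return targets;
    }

    public static void main(String[] args) {
        double[] targets = sampleTargets(5000, 3L);
        System.out.println("1/n schedule: " + decayingEstimate(targets));
        System.out.println("fixed 0.1:    " + fixedEstimate(targets, 0.1));
    }
}
```

This is why a small fixed rate is usually good enough in practice. The estimate hovers near the expected value, just without the exact convergence that a decreasing schedule provides.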

Q-Learning Code

Let's begin implementing our Q-learning algorithm. Our class, called QLTutorial, will extend MDPSolver and implement the LearningAgent and QProvider interfaces. The LearningAgent interface specifies the common methods a learning algorithm is expected to implement so that it can be used by other BURLAP tools.

Below is the skeleton code generated when we create our class, along with all of the imports used in the rest of this tutorial.

import burlap.behavior.policy.EpsilonGreedy;
import burlap.behavior.policy.Policy;
import burlap.behavior.singleagent.Episode;
import burlap.behavior.singleagent.MDPSolver;
import burlap.behavior.singleagent.auxiliary.EpisodeSequenceVisualizer;
import burlap.behavior.singleagent.learning.LearningAgent;
import burlap.behavior.valuefunction.ConstantValueFunction;
import burlap.behavior.valuefunction.QFunction;
import burlap.behavior.valuefunction.QProvider;
import burlap.behavior.valuefunction.QValue;
import burlap.domain.singleagent.gridworld.GridWorldDomain;
import burlap.domain.singleagent.gridworld.GridWorldTerminalFunction;
import burlap.domain.singleagent.gridworld.GridWorldVisualizer;
import burlap.domain.singleagent.gridworld.state.GridAgent;
import burlap.domain.singleagent.gridworld.state.GridWorldState;
import burlap.mdp.core.action.Action;
import burlap.mdp.core.state.State;
import burlap.mdp.singleagent.SADomain;
import burlap.mdp.singleagent.environment.Environment;
import burlap.mdp.singleagent.environment.EnvironmentOutcome;
import burlap.mdp.singleagent.environment.SimulatedEnvironment;
import burlap.statehashing.HashableState;
import burlap.statehashing.HashableStateFactory;
import burlap.statehashing.simple.SimpleHashableStateFactory;
import burlap.visualizer.Visualizer;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;


public class QLTutorial extends MDPSolver implements LearningAgent, QProvider {

	@Override
	public Episode runLearningEpisode(Environment env) {
		return null;
	}

	@Override
	public Episode runLearningEpisode(Environment env, int maxSteps) {
		return null;
	}

	@Override
	public void resetSolver() {
		
	}

	@Override
	public List<QValue> qValues(State s) {
		return null;
	}

	@Override
	public double qValue(State s, Action a) {
		return 0.;
	}


	@Override
	public double value(State s) {
		return 0.;
	}
}

Similar to VI, the primary data we will store is a set of estimated Q-values for each state-action pair. We'll again let the user specify the Q-value initialization with a QFunction object. We'll also need a learning rate parameter. Finally, we'll need a learning policy to follow; that is, a policy that dictates how the agent chooses actions at each step. For this tutorial, we'll assume an $\epsilon$-greedy policy and let the client specify the value for $\epsilon$. Let's add data members for those elements now.

Map<HashableState, List<QValue>> qValues;
QFunction qinit;
double learningRate;
Policy learningPolicy;
				

Let's also add a constructor that initializes our data members, along with some of those we inherit from MDPSolver (via solverInit).

public QLTutorial(SADomain domain, double gamma, HashableStateFactory hashingFactory,
					  QFunction qinit, double learningRate, double epsilon){

	this.solverInit(domain, gamma, hashingFactory);
	this.qinit = qinit;
	this.learningRate = learningRate;
	this.qValues = new HashMap<HashableState, List<QValue>>();
	this.learningPolicy = new EpsilonGreedy(this, epsilon);

}
				

Note that the EpsilonGreedy policy object we create takes as input a QProvider, which this class implements, and the value for epsilon to use.


Getting and storing Q-values is the primary tool our algorithm needs, so let's implement the value function methods now.

@Override
public List<QValue> qValues(State s) {
	//first get hashed state
	HashableState sh = this.hashingFactory.hashState(s);

	//check if we already have stored values
	List<QValue> qs = this.qValues.get(sh);

	//create and add initialized Q-values if we don't have them stored for this state
	if(qs == null){
		List<Action> actions = this.applicableActions(s);
		qs = new ArrayList<QValue>(actions.size());
		//create a Q-value for each action
		for(Action a : actions){
			//add q with initialized value
			qs.add(new QValue(s, a, this.qinit.qValue(s, a)));
		}
		//store this for later
		this.qValues.put(sh, qs);
	}

	return qs;
}

@Override
public double qValue(State s, Action a) {
	return storedQ(s, a).q;
}


protected QValue storedQ(State s, Action a){
	//first get all Q-values
	List<QValue> qs = this.qValues(s);

	//iterate through stored Q-values to find a match for the input action
	for(QValue q : qs){
		if(q.a.equals(a)){
			return q;
		}
	}

	throw new RuntimeException("Could not find matching Q-value.");
}

@Override
public double value(State s) {
	return QProvider.Helper.maxQ(this, s);
}
				

Note that the qValues method checks whether we've already stored Q-values for the given state. If not, it creates them with the initial Q-values defined by our QFunction initialization object and stores them. The qValue method goes through a helper method, storedQ, that returns the stored QValue object for a given action. Because storedQ returns the stored object itself rather than a copy, it will be useful for updating our Q-values in the learning algorithm. The value method, as in our value iteration example, can simply return the result of the QProvider.Helper method maxQ.
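The lazy-initialization pattern used by qValues can be sketched without BURLAP. In this hypothetical sketch, states are plain Strings and Q-values are a `double[]` per state, standing in for BURLAP's HashableState keys and `List<QValue>`. The table is created on first access with `HashMap.computeIfAbsent`. The stored array is returned rather than a copy, which is the same idea that lets storedQ update Q-values in place.

```java
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;

// Library-independent sketch of the lazily initialized Q-table behind qValues(State).
// States are plain Strings here (an assumption for brevity) instead of BURLAP HashableState objects.
class LazyQTable {
    private final Map<String, double[]> table = new HashMap<>();
    private final int numActions;
    private final double initialQ;

    LazyQTable(int numActions, double initialQ) {
        this.numActions = numActions;
        this.initialQ = initialQ;
    }

    // Returns the stored Q-values for a state, creating and storing them with initialQ on first access.
    // The stored array itself is returned (not a copy), so callers can update entries in place.
    double[] qValues(String state) {
        return table.computeIfAbsent(state, k -> {
            double[] qs = new double[numActions];
            Arrays.fill(qs, initialQ);
            return qs;
        });
    }

    // The state value is the maximum Q-value, mirroring what QProvider.Helper.maxQ computes.
    double value(String state) {
        return Arrays.stream(qValues(state)).max().getAsDouble();
    }

    int numStoredStates() {
        return table.size();
    }

    public static void main(String[] args) {
        LazyQTable q = new LazyQTable(4, 0.0);
        q.qValues("s0")[2] += 0.5;               // update in place
        System.out.println(q.value("s0"));       // 0.5
        System.out.println(q.numStoredStates()); // 1
    }
}
```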


Now that we have all of our helper methods, let's implement the learning algorithm. The LearningAgent interface requires us to implement two methods that run learning for one episode in some Environment:

- one that runs learning until the agent reaches a terminal state;
- one that runs learning until either a maximum number of steps is taken or a terminal state is reached.

We will have the former call the latter with -1 for the maximum number of steps, which indicates that it should not stop until the agent reaches a terminal state. Both methods also return an Episode object, which is a recording of all the states, actions, and rewards that occurred in an episode. As we write the code that has the agent iteratively select actions, we'll record the results in an Episode object.
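The control flow of these two methods, including the convention that a maximum of -1 means "no limit", can be sketched without BURLAP. Several pieces of the sketch are hypothetical stand-ins:

- the corridor environment replaces BURLAP's Environment;
- the `Step` record type replaces Episode;
- the always-move-right placeholder policy replaces the learning policy.

The Q-value update is omitted so that only the loop and the recording are shown.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical stand-in environment: cells 0..length-1, where the last cell is terminal.
class CorridorEnv {
    final int length;
    int pos = 0;

    CorridorEnv(int length) {
        this.length = length;
    }

    boolean isInTerminalState() {
        return pos == length - 1;
    }

    // Moves by -1 or +1 (clamped to the corridor) and returns the reward: 1 on entering the goal, else 0.
    double executeAction(int move) {
        pos = Math.max(0, Math.min(length - 1, pos + move));
        return isInTerminalState() ? 1.0 : 0.0;
    }

    void reset() {
        pos = 0;
    }
}

class EpisodeLoopSketch {

    // One recorded transition, playing the role of an entry in a BURLAP Episode.
    static final class Step {
        final int s, a, sPrime;
        final double r;

        Step(int s, int a, double r, int sPrime) {
            this.s = s;
            this.a = a;
            this.r = r;
            this.sPrime = sPrime;
        }
    }

    // Mirrors runLearningEpisode(env): delegate with -1, meaning "run until a terminal state".
    static List<Step> runEpisode(CorridorEnv env) {
        return runEpisode(env, -1);
    }

    static List<Step> runEpisode(CorridorEnv env, int maxSteps) {
        List<Step> episode = new ArrayList<>();
        int steps = 0;
        while (!env.isInTerminalState() && (steps < maxSteps || maxSteps == -1)) {
            int s = env.pos;
            int a = +1; // placeholder policy: always move right (a learner would use its epsilon-greedy policy)
            double r = env.executeAction(a);
            episode.add(new Step(s, a, r, env.pos));
            // (Q-learning would update Q(s, a) here.)
            steps++;
        }
        return episode;
    }

    public static void main(String[] args) {
        CorridorEnv env = new CorridorEnv(5);
        System.out.println(runEpisode(env).size());    // 4: cells 0 -> 4
        env.reset();
        System.out.println(runEpisode(env, 2).size()); // 2: stopped by the step limit
    }
}
```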

Below is the learning algorithm code for Q-learning.

@Override
public Episode runLearningEpisode(Environment env) {
	return this.runLearningEpisode(env, -1);
}

@Override
public Episode runLearningEpisode(Environment env, int maxSteps) {
	//initialize our episode object with the initial state of the environment
	Episode e = new Episode(env.currentObservation());

	//behave until a terminal state or max steps is reached
	State curState = env.currentObservation();
	int steps = 0;
	while(!env.isInTerminalState() && (steps < maxSteps || maxSteps == -1)){

		//select an action
		Action a = this.learningPolicy.action(curState);

		//take the action and observe outcome
		EnvironmentOutcome eo = env.executeAction(a);

		//record result
		e.transition(eo);

		//get the max Q value of the resulting state if it's not terminal, 0 otherwise
		double maxQ = eo.terminated ? 0. : this.value(eo.op);

		//update the old Q-value
		QValue oldQ = this.storedQ(curState, a);
		oldQ.q = oldQ.q + this.learningRate * (eo.r + this.gamma * maxQ - oldQ.q);


		//update state pointer to next environment state observed
		curState = eo.op;
		steps++;

	}

	return e;
}
				

The beginning of the code is fairly straightforward. We construct a new Episode object rooted in the current state of the environment, which we get from the Environment method currentObservation(). We then begin an execution loop that lasts until either the Environment reaches a terminal state or the number of steps we've taken reaches the requested maximum (unless the maximum is -1).

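Inside the execution loop, we first select an action using our learning policy. We then execute the action with the Environment method executeAction(Action), which returns an EnvironmentOutcome object describing the result. This includes the reward received (r), the next observation (op), and whether the environment reached a terminal state (terminated).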

Environment Observations

You may have noticed that the Environment uses "observation" terminology instead of "state" terminology. This is because Environment objects are not obligated to return a full state to the agent, only an observation. For MDP domains you can typically expect the observation to be the full State. Whether or not it is a partial observation, it is always represented by a BURLAP State object. This terminology is especially useful if you begin using BURLAP's POMDP framework.

Using the new observation from the environment, we record the transition in our Episode and update the previous Q-value. To do so, we need the maximum Q-value of the next state we encountered. However, if that state is terminal, its value should always be zero, because the agent cannot act further from it. Otherwise, we get the maximum Q-value using the value method we defined previously. Finally, we update the state pointer to the newly observed state and increment the step counter.
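Putting the pieces together, the following self-contained sketch runs tabular Q-learning with $\epsilon$-greedy exploration. It uses a hypothetical six-cell corridor rather than a BURLAP domain. Cell 5 is terminal, entering it yields a reward of 1, and every other step yields 0. The line computing `maxQ` applies the rule just described. The agent never acts from the terminal cell, so that cell's Q-values are never updated, and zeroing the bootstrapped term keeps the update correct even if those untouched entries had been initialized to something nonzero. The parameter values are arbitrary choices for illustration.

```java
import java.util.Arrays;
import java.util.Random;

// Self-contained sketch of tabular Q-learning with epsilon-greedy exploration on a hypothetical
// corridor (not a BURLAP domain): cells 0..N-1, start in cell 0, cell N-1 is terminal.
// Actions: 0 = move left, 1 = move right (clamped). Reward: 1.0 on entering the terminal cell, else 0.0.
class TabularQLearningSketch {
    static final int N = 6;
    static final int[] MOVES = {-1, +1};

    static int nextState(int s, int a) {
        return Math.max(0, Math.min(N - 1, s + MOVES[a]));
    }

    static double max(double[] qs) {
        return Arrays.stream(qs).max().getAsDouble();
    }

    // Greedy action, breaking ties uniformly at random so untrained states are explored evenly.
    static int greedy(double[] qs, Random rng) {
        double best = max(qs);
        int ties = 0;
        for (double q : qs) {
            if (q == best) ties++;
        }
        int pick = rng.nextInt(ties);
        for (int a = 0; a < qs.length; a++) {
            if (qs[a] == best && pick-- == 0) return a;
        }
        throw new IllegalStateException("unreachable");
    }

    // Runs Q-learning for the given number of episodes and returns the table q[state][action].
    static double[][] learn(int episodes, double alpha, double gamma, double epsilon, long seed) {
        Random rng = new Random(seed);
        double[][] q = new double[N][MOVES.length]; // Q-values initialized to 0
        for (int ep = 0; ep < episodes; ep++) {
            int s = 0;
            while (s != N - 1) {
                // choose an action epsilon-greedily
                int a = rng.nextDouble() < epsilon ? rng.nextInt(MOVES.length) : greedy(q[s], rng);
                // take the action and observe the outcome
                int sPrime = nextState(s, a);
                boolean terminated = sPrime == N - 1;
                double r = terminated ? 1.0 : 0.0;
                // terminal states contribute no future value; the agent never acts from them
                double maxQ = terminated ? 0.0 : max(q[sPrime]);
                q[s][a] += alpha * (r + gamma * maxQ - q[s][a]);
                s = sPrime;
            }
        }
        return q;
    }

    public static void main(String[] args) {
        double[][] q = learn(500, 0.5, 0.9, 0.1, 1L);
        for (int s = 0; s < N - 1; s++) {
            System.out.printf("cell %d: left=%.4f right=%.4f%n", s, q[s][0], q[s][1]);
        }
    }
}
```

Because this corridor is deterministic, the learned value of moving right from cell $k$ should approach $0.9^{4-k}$, and the greedy action in every non-terminal cell should be "right".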

Finally, we can implement the resetSolver method, which only needs to clear our Q-values.

@Override
public void resetSolver() {
	this.qValues.clear();
}
				

Testing Q-Learning

As before, you can now test your learning algorithm with the code developed in the Basic Planning and Learning tutorial. Alternatively, you can use the main method below. It creates a Grid World domain and task similar to those in the test code we wrote for our VI implementation, except that it applies Q-learning in a simulated environment. The results of each learning episode are displayed in an EpisodeSequenceVisualizer after learning completes. Because the domain is stochastic (and the agent follows a noisy exploration policy), learning can take much longer, and the resulting behavior will not be a straight shot to the goal.

public static void main(String[] args) {

	GridWorldDomain gwd = new GridWorldDomain(11, 11);
	gwd.setMapToFourRooms();
	gwd.setProbSucceedTransitionDynamics(0.8);
	gwd.setTf(new GridWorldTerminalFunction(10, 10));

	SADomain domain = gwd.generateDomain();

	//get initial state with agent in 0,0
	State s = new GridWorldState(new GridAgent(0, 0));

	//create environment
	SimulatedEnvironment env = new SimulatedEnvironment(domain, s);

	//create Q-learning
	QLTutorial agent = new QLTutorial(domain, 0.99, new SimpleHashableStateFactory(),
			new ConstantValueFunction(), 0.1, 0.1);

	//run Q-learning and store results in a list
	List<Episode> episodes = new ArrayList<Episode>(1000);
	for(int i = 0; i < 1000; i++){
		episodes.add(agent.runLearningEpisode(env));
		env.resetEnvironment();
	}

	Visualizer v = GridWorldVisualizer.getVisualizer(gwd.getMap());
	new EpisodeSequenceVisualizer(v, domain, episodes);

}