Building a Domain

Tutorial: Building an OO-MDP Domain

Tutorials > Building an OO-MDP Domain > Part 4



Conclusion

We've now walked you through what an OO-MDP is and how to build OO-MDP domains, using the grid world MDP from our Building a Domain tutorial as a starting point. We also showed you how to use BURLAP's OO-MDP visualization tools. All of the code from this tutorial is included below and, as always, can also be found in the burlap_examples repository. The listings omit their package declarations; all three classes belong to the package edu.brown.cs.burlap.tutorials.domain.oo, which is the package their static imports refer to.

Final Code

ExGridAgent.java

import burlap.mdp.core.oo.state.ObjectInstance;
import burlap.mdp.core.state.MutableState;
import burlap.mdp.core.state.StateUtilities;
import burlap.mdp.core.state.UnknownKeyException;
import burlap.mdp.core.state.annotations.DeepCopyState;

import java.util.Arrays;
import java.util.List;

import static edu.brown.cs.burlap.tutorials.domain.oo.ExampleOOGridWorld.CLASS_AGENT;
import static edu.brown.cs.burlap.tutorials.domain.oo.ExampleOOGridWorld.VAR_X;
import static edu.brown.cs.burlap.tutorials.domain.oo.ExampleOOGridWorld.VAR_Y;


/**
 * OO-MDP object instance for the grid-world agent: an object of class {@code CLASS_AGENT}
 * with two integer variables, {@code VAR_X} and {@code VAR_Y}.
 * <p>
 * The position fields are public so the domain's state model can modify them directly on a
 * copied state. {@link #copy()} builds a new, independent instance (all fields are primitives
 * or immutable Strings), consistent with the {@code @DeepCopyState} annotation.
 */
@DeepCopyState
public class ExGridAgent implements ObjectInstance, MutableState {

	/** Horizontal grid coordinate (first index of the domain's map). */
	public int x;
	/** Vertical grid coordinate (second index of the domain's map). */
	public int y;

	/**
	 * Object name. It defaults to "agent", which is also the name the domain's state model
	 * uses to look the agent up in a state.
	 */
	public String name = "agent";

	/**
	 * Variable keys of every agent instance. The list is wrapped as unmodifiable because the
	 * same static list is returned by {@link #variableKeys()}. A bare {@code Arrays.asList}
	 * would let any caller {@code set} an element and corrupt the keys of all agents.
	 */
	private final static List<Object> keys =
			java.util.Collections.unmodifiableList(Arrays.<Object>asList(VAR_X, VAR_Y));

	/** Creates an agent at (0, 0) named "agent". */
	public ExGridAgent() {
	}

	/**
	 * Creates an agent named "agent" at the given position.
	 *
	 * @param x horizontal grid coordinate
	 * @param y vertical grid coordinate
	 */
	public ExGridAgent(int x, int y) {
		this.x = x;
		this.y = y;
	}

	/**
	 * Creates an agent at the given position with the given object name.
	 *
	 * @param x    horizontal grid coordinate
	 * @param y    vertical grid coordinate
	 * @param name object name
	 */
	public ExGridAgent(int x, int y, String name) {
		this.x = x;
		this.y = y;
		this.name = name;
	}

	@Override
	public String className() {
		return CLASS_AGENT;
	}

	@Override
	public String name() {
		return name;
	}

	/**
	 * Returns a copy of this agent with the same position but a different name.
	 *
	 * @param objectName name of the copy
	 * @return a new agent instance
	 */
	@Override
	public ObjectInstance copyWithName(String objectName) {
		return new ExGridAgent(x, y, objectName);
	}

	/**
	 * Sets a position variable in place.
	 *
	 * @param variableKey {@code VAR_X} or {@code VAR_Y}
	 * @param value       a Number or a String parseable as a number; converted to int
	 * @return this object
	 * @throws UnknownKeyException if the key is not one of this object's variables
	 */
	@Override
	public MutableState set(Object variableKey, Object value) {
		// constant-first comparison so a null key is reported as an unknown key, not an NPE
		if(VAR_X.equals(variableKey)){
			this.x = StateUtilities.stringOrNumber(value).intValue();
		}
		else if(VAR_Y.equals(variableKey)){
			this.y = StateUtilities.stringOrNumber(value).intValue();
		}
		else{
			throw new UnknownKeyException(variableKey);
		}
		return this;
	}

	/**
	 * @return the unmodifiable list {@code [VAR_X, VAR_Y]}
	 */
	@Override
	public List<Object> variableKeys() {
		return keys;
	}

	/**
	 * Returns the value of a position variable.
	 *
	 * @param variableKey {@code VAR_X} or {@code VAR_Y}
	 * @return the boxed int coordinate
	 * @throws UnknownKeyException if the key is not one of this object's variables
	 */
	@Override
	public Object get(Object variableKey) {
		if(VAR_X.equals(variableKey)){
			return x;
		}
		else if(VAR_Y.equals(variableKey)){
			return y;
		}
		throw new UnknownKeyException(variableKey);
	}

	/**
	 * @return a new agent with the same position and name
	 */
	@Override
	public ExGridAgent copy() {
		return new ExGridAgent(x, y, name);
	}

	@Override
	public String toString() {
		return StateUtilities.stateToString(this);
	}

}

				

EXGridLocation.java

import burlap.mdp.core.oo.state.ObjectInstance;
import burlap.mdp.core.state.MutableState;
import burlap.mdp.core.state.StateUtilities;

import java.util.Arrays;
import java.util.List;

import static edu.brown.cs.burlap.tutorials.domain.oo.ExampleOOGridWorld.*;

/**
 * OO-MDP object instance for a grid location: an object of class {@code CLASS_LOCATION} with
 * a grid position (the {@code x}/{@code y} fields inherited from {@link ExGridAgent}) plus an
 * integer {@code type} variable.
 * <p>
 * NOTE(review): because this reuses ExGridAgent through inheritance, a location is also an
 * ExGridAgent (casts and instanceof checks succeed). The inheritance is kept so that existing
 * code keeps compiling.
 */
public class EXGridLocation extends ExGridAgent{

	/** Location type. It is 0 unless set by the four-argument constructor or {@link #set}. */
	public int type;

	/**
	 * Variable keys of every location instance. The list is wrapped as unmodifiable because the
	 * same static list is returned by {@link #variableKeys()}.
	 */
	private final static List<Object> keys =
			java.util.Collections.unmodifiableList(Arrays.<Object>asList(VAR_X, VAR_Y, VAR_TYPE));

	/**
	 * Creates a location at (0, 0) of type 0.
	 * <p>
	 * The name is set to {@code CLASS_LOCATION}. Otherwise it would keep the inherited default
	 * "agent", which collides with the agent object's name, and the domain's state model looks
	 * the agent up by that name.
	 */
	public EXGridLocation() {
		this.name = CLASS_LOCATION;
	}

	/**
	 * Creates a location of type 0 at the given position.
	 *
	 * @param x    horizontal grid coordinate
	 * @param y    vertical grid coordinate
	 * @param name object name
	 */
	public EXGridLocation(int x, int y, String name) {
		super(x, y, name);
	}

	/**
	 * Creates a location with the given position, type and name.
	 *
	 * @param x    horizontal grid coordinate
	 * @param y    vertical grid coordinate
	 * @param type location type
	 * @param name object name
	 */
	public EXGridLocation(int x, int y, int type, String name) {
		super(x, y, name);
		this.type = type;
	}


	/**
	 * @return the unmodifiable list {@code [VAR_X, VAR_Y, VAR_TYPE]}
	 */
	@Override
	public List<Object> variableKeys() {
		return keys;
	}

	/**
	 * Returns the value of {@code VAR_TYPE}, or delegates position keys to the superclass.
	 *
	 * @param variableKey {@code VAR_X}, {@code VAR_Y} or {@code VAR_TYPE}
	 * @return the boxed int value
	 */
	@Override
	public Object get(Object variableKey) {
		// constant-first comparison so a null key is not dereferenced here
		if(VAR_TYPE.equals(variableKey)){
			return this.type;
		}
		return super.get(variableKey);
	}

	/**
	 * Sets {@code VAR_TYPE} in place, or delegates position keys to the superclass.
	 *
	 * @param variableKey {@code VAR_X}, {@code VAR_Y} or {@code VAR_TYPE}
	 * @param value       a Number or a String parseable as a number; converted to int
	 * @return this object
	 */
	@Override
	public MutableState set(Object variableKey, Object value) {
		if(VAR_TYPE.equals(variableKey)){
			this.type = StateUtilities.stringOrNumber(value).intValue();
		}
		else{
			super.set(variableKey, value);
		}
		return this;

	}

	@Override
	public String className() {
		return CLASS_LOCATION;
	}

	/**
	 * Returns a copy of this location with the same position and type but a different name.
	 *
	 * @param objectName name of the copy
	 * @return a new location instance
	 */
	@Override
	public ObjectInstance copyWithName(String objectName) {
		return new EXGridLocation(x, y, type, objectName);
	}

	/**
	 * @return a new location with the same position, type and name
	 */
	@Override
	public EXGridLocation copy() {
		return new EXGridLocation(x, y, type, name);
	}
}
				

ExampleOOGridWorld.java

import burlap.mdp.auxiliary.DomainGenerator;
import burlap.mdp.auxiliary.common.SinglePFTF;
import burlap.mdp.core.StateTransitionProb;
import burlap.mdp.core.TerminalFunction;
import burlap.mdp.core.action.Action;
import burlap.mdp.core.action.UniversalActionType;
import burlap.mdp.core.oo.OODomain;
import burlap.mdp.core.oo.propositional.PropositionalFunction;
import burlap.mdp.core.oo.state.OOState;
import burlap.mdp.core.oo.state.ObjectInstance;
import burlap.mdp.core.oo.state.generic.GenericOOState;
import burlap.mdp.core.state.State;
import burlap.mdp.singleagent.common.SingleGoalPFRF;
import burlap.mdp.singleagent.environment.SimulatedEnvironment;
import burlap.mdp.singleagent.model.FactoredModel;
import burlap.mdp.singleagent.model.RewardFunction;
import burlap.mdp.singleagent.model.statemodel.FullStateModel;
import burlap.mdp.singleagent.oo.OOSADomain;
import burlap.shell.visual.VisualExplorer;
import burlap.visualizer.*;

import java.awt.*;
import java.awt.geom.Ellipse2D;
import java.awt.geom.Rectangle2D;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/**
 * Domain generator for an object-oriented (OO-MDP) version of the tutorial grid world.
 * <p>
 * The world is the 11x11 {@link #map}, where 1 marks a wall. States hold an agent object
 * ({@link ExGridAgent}, class {@code agent}) and location objects ({@link EXGridLocation},
 * class {@code location}). Each of the four compass actions moves the agent in the intended
 * direction with probability 0.8 and in each of the other three directions with probability
 * 0.2/3. A move into a wall or off the map leaves the agent where it is.
 * <p>
 * Rewards come from a {@code SingleGoalPFRF} over the {@code at} propositional function
 * (constructed with 100 and -1). Termination comes from a {@code SinglePFTF} over the same
 * function.
 */
public class ExampleOOGridWorld implements DomainGenerator{

	public static final String VAR_X = "x";
	public static final String VAR_Y = "y";
	public static final String VAR_TYPE = "type";

	public static final String CLASS_AGENT = "agent";
	public static final String CLASS_LOCATION = "location";

	public static final String ACTION_NORTH = "north";
	public static final String ACTION_SOUTH = "south";
	public static final String ACTION_EAST = "east";
	public static final String ACTION_WEST = "west";

	public static final String PF_AT = "at";


	/**
	 * Grid layout, indexed as {@code map[x][y]} (so the first dimension is x). A 1 marks a wall
	 * and a 0 an open cell. When rendered, y = 0 is the bottom row.
	 */
	protected int [][] map = new int[][]{
			{0,0,0,0,0,1,0,0,0,0,0},
			{0,0,0,0,0,0,0,0,0,0,0},
			{0,0,0,0,0,1,0,0,0,0,0},
			{0,0,0,0,0,1,0,0,0,0,0},
			{0,0,0,0,0,1,0,0,0,0,0},
			{1,0,1,1,1,1,1,1,0,1,1},
			{0,0,0,0,1,0,0,0,0,0,0},
			{0,0,0,0,1,0,0,0,0,0,0},
			{0,0,0,0,0,0,0,0,0,0,0},
			{0,0,0,0,1,0,0,0,0,0,0},
			{0,0,0,0,1,0,0,0,0,0,0},
	};

	/**
	 * Creates the propositional functions of this domain.
	 *
	 * @return a list containing only the {@code at} function
	 */
	public List<PropositionalFunction> generatePfs(){
		return Arrays.<PropositionalFunction>asList(new AtLocation());
	}

	/**
	 * Builds the OO-MDP domain. It registers the agent and location state classes, the four
	 * movement actions and the {@code at} propositional function, then installs a factored model
	 * made of the stochastic grid state model, the reward function and the terminal function.
	 *
	 * @return the generated domain
	 */
	@Override
	public OOSADomain generateDomain() {

		OOSADomain domain = new OOSADomain();

		domain.addStateClass(CLASS_AGENT, ExGridAgent.class)
				.addStateClass(CLASS_LOCATION, EXGridLocation.class);

		domain.addActionTypes(
				new UniversalActionType(ACTION_NORTH),
				new UniversalActionType(ACTION_SOUTH),
				new UniversalActionType(ACTION_EAST),
				new UniversalActionType(ACTION_WEST));


		OODomain.Helper.addPfsToDomain(domain, this.generatePfs());

		OOGridWorldStateModel smodel = new OOGridWorldStateModel();
		RewardFunction rf = new SingleGoalPFRF(domain.propFunction(PF_AT), 100, -1);
		TerminalFunction tf = new SinglePFTF(domain.propFunction(PF_AT));

		domain.setModel(new FactoredModel(smodel, rf, tf));


		return domain;
	}


	/**
	 * Stochastic movement model for the agent. Directions are encoded as 0: north, 1: south,
	 * 2: east, 3: west.
	 */
	protected class OOGridWorldStateModel implements FullStateModel {


		/**
		 * {@code transitionProbs[intended][actual]}: the probability of moving in direction
		 * {@code actual} when direction {@code intended} was selected.
		 */
		protected double [][] transitionProbs;

		/**
		 * Builds the transition table. The intended direction gets 0.8 and the remaining 0.2 is
		 * split evenly among the other three directions.
		 */
		public OOGridWorldStateModel() {
			this.transitionProbs = new double[4][4];
			for(int i = 0; i < 4; i++){
				for(int j = 0; j < 4; j++){
					double p = i != j ? 0.2/3 : 0.8;
					transitionProbs[i][j] = p;
				}
			}
		}

		/**
		 * Enumerates every possible outcome of applying {@code a} in {@code s}. Directions that
		 * would not move the agent are merged into a single "no change" outcome.
		 *
		 * @param s the source state; cast to GenericOOState
		 * @param a a movement action
		 * @return outcome states with their probabilities
		 * @throws IllegalArgumentException if {@code a} is not one of the four movement actions
		 */
		@Override
		public List<StateTransitionProb> stateTransitions(State s, Action a) {

			//get agent current position (the agent is looked up by its object name)
			GenericOOState gs = (GenericOOState)s;
			ExGridAgent agent = (ExGridAgent)gs.object(CLASS_AGENT);

			int curX = agent.x;
			int curY = agent.y;

			int adir = actionDir(a);

			List<StateTransitionProb> tps = new ArrayList<StateTransitionProb>(4);
			StateTransitionProb noChange = null;
			for(int i = 0; i < 4; i++){

				int [] newPos = this.moveResult(curX, curY, i);
				if(newPos[0] != curX || newPos[1] != curY){
					//new possible outcome
					GenericOOState ns = gs.copy();
					ExGridAgent nagent = (ExGridAgent)ns.touch(CLASS_AGENT);
					nagent.x = newPos[0];
					nagent.y = newPos[1];

					//create transition probability object and add to our list of outcomes
					tps.add(new StateTransitionProb(ns, this.transitionProbs[adir][i]));
				}
				else{
					//this direction didn't lead anywhere new
					//if there are existing possible directions
					//that wouldn't lead anywhere, aggregate with them
					if(noChange != null){
						noChange.p += this.transitionProbs[adir][i];
					}
					else{
						//otherwise create this new state and transition
						noChange = new StateTransitionProb(s.copy(), this.transitionProbs[adir][i]);
						tps.add(noChange);
					}
				}

			}


			return tps;
		}

		/**
		 * Samples one outcome of applying {@code a} in {@code s}. The input state is copied and
		 * is not modified.
		 *
		 * @param s the source state; cast to GenericOOState
		 * @param a a movement action
		 * @return the sampled next state
		 * @throws IllegalArgumentException if {@code a} is not one of the four movement actions
		 */
		@Override
		public State sample(State s, Action a) {

			s = s.copy();
			GenericOOState gs = (GenericOOState)s;
			ExGridAgent agent = (ExGridAgent)gs.touch(CLASS_AGENT);
			int curX = agent.x;
			int curY = agent.y;

			int adir = actionDir(a);
			double [] probs = this.transitionProbs[adir];

			//sample direction with random roll. Default to the last direction: floating-point
			//rounding can leave the row summing to slightly less than 1, and a roll above that
			//sum must still land in a valid bucket rather than silently becoming north (0)
			double r = Math.random();
			double sumProb = 0.;
			int dir = probs.length - 1;
			for(int i = 0; i < probs.length; i++){
				sumProb += probs[i];
				if(r < sumProb){
					dir = i;
					break; //found direction
				}
			}

			//get resulting position
			int [] newPos = this.moveResult(curX, curY, dir);

			//set the new position
			agent.x = newPos[0];
			agent.y = newPos[1];

			//return the state we just modified
			return gs;
		}

		/**
		 * Maps a movement action to its direction index.
		 *
		 * @param a the action
		 * @return 0 for north, 1 for south, 2 for east, 3 for west
		 * @throws IllegalArgumentException for any other action name. Previously this returned -1,
		 *         which later failed as an opaque array index error.
		 */
		protected int actionDir(Action a){
			String actionName = a.actionName();
			if(actionName.equals(ACTION_NORTH)){
				return 0;
			}
			else if(actionName.equals(ACTION_SOUTH)){
				return 1;
			}
			else if(actionName.equals(ACTION_EAST)){
				return 2;
			}
			else if(actionName.equals(ACTION_WEST)){
				return 3;
			}
			throw new IllegalArgumentException("Unknown grid world action: " + actionName);
		}


		/**
		 * Computes where a one-cell move from (curX, curY) in the given direction ends up.
		 *
		 * @param curX      current x
		 * @param curY      current y
		 * @param direction 0: north (+y), 1: south (-y), 2: east (+x), anything else: west (-x)
		 * @return {@code {nx, ny}}, which is the current position if the move would hit a wall
		 *         or leave the map
		 */
		protected int [] moveResult(int curX, int curY, int direction){

			//first get change in x and y from direction using 0: north; 1: south; 2:east; 3: west
			int xdelta = 0;
			int ydelta = 0;
			if(direction == 0){
				ydelta = 1;
			}
			else if(direction == 1){
				ydelta = -1;
			}
			else if(direction == 2){
				xdelta = 1;
			}
			else{
				xdelta = -1;
			}

			int nx = curX + xdelta;
			int ny = curY + ydelta;

			int width = ExampleOOGridWorld.this.map.length;
			int height = ExampleOOGridWorld.this.map[0].length;

			//make sure new position is valid (not a wall or off bounds)
			if(nx < 0 || nx >= width || ny < 0 || ny >= height ||
					ExampleOOGridWorld.this.map[nx][ny] == 1){
				nx = curX;
				ny = curY;
			}


			return new int[]{nx,ny};

		}
	}


	/**
	 * @return a visualizer that draws the walls, the locations and the agent
	 */
	public Visualizer getVisualizer(){
		return new Visualizer(this.getStateRenderLayer());
	}

	/**
	 * Builds the render layer. Walls are painted first, then the OO painter draws location and
	 * agent objects with their class painters.
	 *
	 * @return the configured render layer
	 */
	public StateRenderLayer getStateRenderLayer(){
		StateRenderLayer rl = new StateRenderLayer();
		rl.addStatePainter(new ExampleOOGridWorld.WallPainter());
		OOStatePainter ooStatePainter = new OOStatePainter();
		ooStatePainter.addObjectClassPainter(CLASS_LOCATION, new LocationPainter());
		ooStatePainter.addObjectClassPainter(CLASS_AGENT, new AgentPainter());
		rl.addStatePainter(ooStatePainter);


		return rl;
	}


	/**
	 * Propositional function {@code at(agent, location)}: true when the two objects have the
	 * same x and y values.
	 */
	protected class AtLocation extends PropositionalFunction {

		public AtLocation(){
			super(PF_AT, new String []{CLASS_AGENT, CLASS_LOCATION});
		}

		@Override
		public boolean isTrue(OOState s, String... params) {
			ObjectInstance agent = s.object(params[0]);
			ObjectInstance location = s.object(params[1]);

			int ax = (Integer)agent.get(VAR_X);
			int ay = (Integer)agent.get(VAR_Y);

			int lx = (Integer)location.get(VAR_X);
			int ly = (Integer)location.get(VAR_Y);

			return ax == lx && ay == ly;

		}

	}



	/** Paints every wall cell of {@link #map} as a black rectangle. */
	public class WallPainter implements StatePainter {

		@Override
		public void paint(Graphics2D g2, State s, float cWidth, float cHeight) {

			//walls will be filled in black
			g2.setColor(Color.BLACK);

			//set up floats for the width and height of our domain
			float fWidth = ExampleOOGridWorld.this.map.length;
			float fHeight = ExampleOOGridWorld.this.map[0].length;

			//determine the width of a single cell
			//on our canvas such that the whole map can be painted
			float width = cWidth / fWidth;
			float height = cHeight / fHeight;

			//pass through each cell of our map and if it's a wall, paint a black rectangle on our
			//canvas of dimension width x height
			for(int i = 0; i < ExampleOOGridWorld.this.map.length; i++){
				for(int j = 0; j < ExampleOOGridWorld.this.map[0].length; j++){

					//is there a wall here?
					if(ExampleOOGridWorld.this.map[i][j] == 1){

						//left coordinate of cell on our canvas
						float rx = i*width;

						//top coordinate of cell on our canvas
						//coordinate system adjustment because the java canvas
						//origin is in the top left, while grid y grows upward from the bottom left
						float ry = cHeight - height - j*height;

						//paint the rectangle
						g2.fill(new Rectangle2D.Float(rx, ry, width, height));

					}


				}
			}

		}


	}


	/** Paints the agent as a gray circle filling its grid cell. */
	public class AgentPainter implements ObjectPainter {

		@Override
		public void paintObject(Graphics2D g2, OOState s, ObjectInstance ob,
								float cWidth, float cHeight) {

			//agent will be filled in gray
			g2.setColor(Color.GRAY);

			//set up floats for the width and height of our domain
			float fWidth = ExampleOOGridWorld.this.map.length;
			float fHeight = ExampleOOGridWorld.this.map[0].length;

			//determine the width of a single cell on our canvas
			//such that the whole map can be painted
			float width = cWidth / fWidth;
			float height = cHeight / fHeight;

			int ax = (Integer)ob.get(VAR_X);
			int ay = (Integer)ob.get(VAR_Y);

			//left coordinate of cell on our canvas
			float rx = ax*width;

			//top coordinate of cell on our canvas
			//coordinate system adjustment because the java canvas
			//origin is in the top left, while grid y grows upward from the bottom left
			float ry = cHeight - height - ay*height;

			//paint the ellipse inscribed in the cell
			g2.fill(new Ellipse2D.Float(rx, ry, width, height));


		}



	}

	/** Paints a location as a blue rectangle filling its grid cell. */
	public class LocationPainter implements ObjectPainter {

		@Override
		public void paintObject(Graphics2D g2, OOState s, ObjectInstance ob,
								float cWidth, float cHeight) {

			//locations will be filled in blue
			g2.setColor(Color.BLUE);

			//set up floats for the width and height of our domain
			float fWidth = ExampleOOGridWorld.this.map.length;
			float fHeight = ExampleOOGridWorld.this.map[0].length;

			//determine the width of a single cell on our canvas
			//such that the whole map can be painted
			float width = cWidth / fWidth;
			float height = cHeight / fHeight;

			int ax = (Integer)ob.get(VAR_X);
			int ay = (Integer)ob.get(VAR_Y);

			//left coordinate of cell on our canvas
			float rx = ax*width;

			//top coordinate of cell on our canvas
			//coordinate system adjustment because the java canvas
			//origin is in the top left, while grid y grows upward from the bottom left
			float ry = cHeight - height - ay*height;

			//paint the rectangle
			g2.fill(new Rectangle2D.Float(rx, ry, width, height));


		}



	}


	/**
	 * Launches an interactive visual explorer. The agent starts at (0, 0) and one location sits
	 * at (10, 10); the keys w/s/d/a map to north/south/east/west.
	 */
	public static void main(String [] args){

		ExampleOOGridWorld gen = new ExampleOOGridWorld();
		OOSADomain domain = gen.generateDomain();
		State initialState = new GenericOOState(new ExGridAgent(0, 0), 
												new EXGridLocation(10, 10, "loc0"));
		SimulatedEnvironment env = new SimulatedEnvironment(domain, initialState);

		Visualizer v = gen.getVisualizer();
		VisualExplorer exp = new VisualExplorer(domain, env, v);

		exp.addKeyAction("w", ACTION_NORTH, "");
		exp.addKeyAction("s", ACTION_SOUTH, "");
		exp.addKeyAction("d", ACTION_EAST, "");
		exp.addKeyAction("a", ACTION_WEST, "");

		exp.initGUI();


	}


}

				
End.