`
We've now walked you through what an OO-MDP is and how to build your own OO-MDP domains, using the Grid World MDP from our earlier Building a Domain tutorial as a starting point. We also showed you how to use the OO-MDP visualization tools. All of the code from this tutorial is included below and, as always, can also be found in the burlap_examples repository.
import burlap.mdp.core.oo.state.ObjectInstance; import burlap.mdp.core.state.MutableState; import burlap.mdp.core.state.StateUtilities; import burlap.mdp.core.state.UnknownKeyException; import burlap.mdp.core.state.annotations.DeepCopyState; import java.util.Arrays; import java.util.List; import static edu.brown.cs.burlap.tutorials.domain.oo.ExampleOOGridWorld.CLASS_AGENT; import static edu.brown.cs.burlap.tutorials.domain.oo.ExampleOOGridWorld.VAR_X; import static edu.brown.cs.burlap.tutorials.domain.oo.ExampleOOGridWorld.VAR_Y; @DeepCopyState public class ExGridAgent implements ObjectInstance, MutableState { public int x; public int y; public String name = "agent"; private final static List<Object> keys = Arrays.<Object>asList(VAR_X, VAR_Y); public ExGridAgent() { } public ExGridAgent(int x, int y) { this.x = x; this.y = y; } public ExGridAgent(int x, int y, String name) { this.x = x; this.y = y; this.name = name; } @Override public String className() { return CLASS_AGENT; } @Override public String name() { return name; } @Override public ObjectInstance copyWithName(String objectName) { return new ExGridAgent(x, y, objectName); } @Override public MutableState set(Object variableKey, Object value) { if(variableKey.equals(VAR_X)){ this.x = StateUtilities.stringOrNumber(value).intValue(); } else if(variableKey.equals(VAR_Y)){ this.y = StateUtilities.stringOrNumber(value).intValue(); } else{ throw new UnknownKeyException(variableKey); } return this; } @Override public List<Object> variableKeys() { return keys; } @Override public Object get(Object variableKey) { if(variableKey.equals(VAR_X)){ return x; } else if(variableKey.equals(VAR_Y)){ return y; } throw new UnknownKeyException(variableKey); } @Override public ExGridAgent copy() { return new ExGridAgent(x, y, name); } @Override public String toString() { return StateUtilities.stateToString(this); } }
import burlap.mdp.core.oo.state.ObjectInstance; import burlap.mdp.core.state.MutableState; import burlap.mdp.core.state.StateUtilities; import java.util.Arrays; import java.util.List; import static edu.brown.cs.burlap.tutorials.domain.oo.ExampleOOGridWorld.*; public class EXGridLocation extends ExGridAgent{ public int type; private final static List<Object> keys = Arrays.<Object>asList(VAR_X, VAR_Y, VAR_TYPE); public EXGridLocation() { } public EXGridLocation(int x, int y, String name) { super(x, y, name); } public EXGridLocation(int x, int y, int type, String name) { super(x, y, name); this.type = type; } @Override public List<Object> variableKeys() { return keys; } @Override public Object get(Object variableKey) { if(variableKey.equals(VAR_TYPE)){ return this.type; } return super.get(variableKey); } @Override public MutableState set(Object variableKey, Object value) { if(variableKey.equals(VAR_TYPE)){ this.type = StateUtilities.stringOrNumber(value).intValue(); } else{ super.set(variableKey, value); } return this; } @Override public String className() { return CLASS_LOCATION; } @Override public ObjectInstance copyWithName(String objectName) { return new EXGridLocation(x, y, type, objectName); } @Override public EXGridLocation copy() { return new EXGridLocation(x, y, type, name); } }
import burlap.mdp.auxiliary.DomainGenerator; import burlap.mdp.auxiliary.common.SinglePFTF; import burlap.mdp.core.StateTransitionProb; import burlap.mdp.core.TerminalFunction; import burlap.mdp.core.action.Action; import burlap.mdp.core.action.UniversalActionType; import burlap.mdp.core.oo.OODomain; import burlap.mdp.core.oo.propositional.PropositionalFunction; import burlap.mdp.core.oo.state.OOState; import burlap.mdp.core.oo.state.ObjectInstance; import burlap.mdp.core.oo.state.generic.GenericOOState; import burlap.mdp.core.state.State; import burlap.mdp.singleagent.common.SingleGoalPFRF; import burlap.mdp.singleagent.environment.SimulatedEnvironment; import burlap.mdp.singleagent.model.FactoredModel; import burlap.mdp.singleagent.model.RewardFunction; import burlap.mdp.singleagent.model.statemodel.FullStateModel; import burlap.mdp.singleagent.oo.OOSADomain; import burlap.shell.visual.VisualExplorer; import burlap.visualizer.*; import java.awt.*; import java.awt.geom.Ellipse2D; import java.awt.geom.Rectangle2D; import java.util.ArrayList; import java.util.Arrays; import java.util.List; public class ExampleOOGridWorld implements DomainGenerator{ public static final String VAR_X = "x"; public static final String VAR_Y = "y"; public static final String VAR_TYPE = "type"; public static final String CLASS_AGENT = "agent"; public static final String CLASS_LOCATION = "location"; public static final String ACTION_NORTH = "north"; public static final String ACTION_SOUTH = "south"; public static final String ACTION_EAST = "east"; public static final String ACTION_WEST = "west"; public static final String PF_AT = "at"; //ordered so first dimension is x protected int [][] map = new int[][]{ {0,0,0,0,0,1,0,0,0,0,0}, {0,0,0,0,0,0,0,0,0,0,0}, {0,0,0,0,0,1,0,0,0,0,0}, {0,0,0,0,0,1,0,0,0,0,0}, {0,0,0,0,0,1,0,0,0,0,0}, {1,0,1,1,1,1,1,1,0,1,1}, {0,0,0,0,1,0,0,0,0,0,0}, {0,0,0,0,1,0,0,0,0,0,0}, {0,0,0,0,0,0,0,0,0,0,0}, {0,0,0,0,1,0,0,0,0,0,0}, {0,0,0,0,1,0,0,0,0,0,0}, }; public 
List<PropositionalFunction> generatePfs(){ return Arrays.<PropositionalFunction>asList(new AtLocation()); } @Override public OOSADomain generateDomain() { OOSADomain domain = new OOSADomain(); domain.addStateClass(CLASS_AGENT, ExGridAgent.class) .addStateClass(CLASS_LOCATION, EXGridLocation.class); domain.addActionTypes( new UniversalActionType(ACTION_NORTH), new UniversalActionType(ACTION_SOUTH), new UniversalActionType(ACTION_EAST), new UniversalActionType(ACTION_WEST)); OODomain.Helper.addPfsToDomain(domain, this.generatePfs()); OOGridWorldStateModel smodel = new OOGridWorldStateModel(); RewardFunction rf = new SingleGoalPFRF(domain.propFunction(PF_AT), 100, -1); TerminalFunction tf = new SinglePFTF(domain.propFunction(PF_AT)); domain.setModel(new FactoredModel(smodel, rf, tf)); return domain; } protected class OOGridWorldStateModel implements FullStateModel { protected double [][] transitionProbs; public OOGridWorldStateModel() { this.transitionProbs = new double[4][4]; for(int i = 0; i < 4; i++){ for(int j = 0; j < 4; j++){ double p = i != j ? 
0.2/3 : 0.8; transitionProbs[i][j] = p; } } } public List<StateTransitionProb> stateTransitions(State s, Action a) { //get agent current position GenericOOState gs = (GenericOOState)s; ExGridAgent agent = (ExGridAgent)gs.object(CLASS_AGENT); int curX = agent.x; int curY = agent.y; int adir = actionDir(a); List<StateTransitionProb> tps = new ArrayList<StateTransitionProb>(4); StateTransitionProb noChange = null; for(int i = 0; i < 4; i++){ int [] newPos = this.moveResult(curX, curY, i); if(newPos[0] != curX || newPos[1] != curY){ //new possible outcome GenericOOState ns = gs.copy(); ExGridAgent nagent = (ExGridAgent)ns.touch(CLASS_AGENT); nagent.x = newPos[0]; nagent.y = newPos[1]; //create transition probability object and add to our list of outcomes tps.add(new StateTransitionProb(ns, this.transitionProbs[adir][i])); } else{ //this direction didn't lead anywhere new //if there are existing possible directions //that wouldn't lead anywhere, aggregate with them if(noChange != null){ noChange.p += this.transitionProbs[adir][i]; } else{ //otherwise create this new state and transition noChange = new StateTransitionProb(s.copy(), this.transitionProbs[adir][i]); tps.add(noChange); } } } return tps; } public State sample(State s, Action a) { s = s.copy(); GenericOOState gs = (GenericOOState)s; ExGridAgent agent = (ExGridAgent)gs.touch(CLASS_AGENT); int curX = agent.x; int curY = agent.y; int adir = actionDir(a); //sample direction with random roll double r = Math.random(); double sumProb = 0.; int dir = 0; for(int i = 0; i < 4; i++){ sumProb += this.transitionProbs[adir][i]; if(r < sumProb){ dir = i; break; //found direction } } //get resulting position int [] newPos = this.moveResult(curX, curY, dir); //set the new position agent.x = newPos[0]; agent.y = newPos[1]; //return the state we just modified return gs; } protected int actionDir(Action a){ int adir = -1; if(a.actionName().equals(ACTION_NORTH)){ adir = 0; } else if(a.actionName().equals(ACTION_SOUTH)){ adir = 1; 
} else if(a.actionName().equals(ACTION_EAST)){ adir = 2; } else if(a.actionName().equals(ACTION_WEST)){ adir = 3; } return adir; } protected int [] moveResult(int curX, int curY, int direction){ //first get change in x and y from direction using 0: north; 1: south; 2:east; 3: west int xdelta = 0; int ydelta = 0; if(direction == 0){ ydelta = 1; } else if(direction == 1){ ydelta = -1; } else if(direction == 2){ xdelta = 1; } else{ xdelta = -1; } int nx = curX + xdelta; int ny = curY + ydelta; int width = ExampleOOGridWorld.this.map.length; int height = ExampleOOGridWorld.this.map[0].length; //make sure new position is valid (not a wall or off bounds) if(nx < 0 || nx >= width || ny < 0 || ny >= height || ExampleOOGridWorld.this.map[nx][ny] == 1){ nx = curX; ny = curY; } return new int[]{nx,ny}; } } public Visualizer getVisualizer(){ return new Visualizer(this.getStateRenderLayer()); } public StateRenderLayer getStateRenderLayer(){ StateRenderLayer rl = new StateRenderLayer(); rl.addStatePainter(new ExampleOOGridWorld.WallPainter()); OOStatePainter ooStatePainter = new OOStatePainter(); ooStatePainter.addObjectClassPainter(CLASS_LOCATION, new LocationPainter()); ooStatePainter.addObjectClassPainter(CLASS_AGENT, new AgentPainter()); rl.addStatePainter(ooStatePainter); return rl; } protected class AtLocation extends PropositionalFunction { public AtLocation(){ super(PF_AT, new String []{CLASS_AGENT, CLASS_LOCATION}); } @Override public boolean isTrue(OOState s, String... 
params) { ObjectInstance agent = s.object(params[0]); ObjectInstance location = s.object(params[1]); int ax = (Integer)agent.get(VAR_X); int ay = (Integer)agent.get(VAR_Y); int lx = (Integer)location.get(VAR_X); int ly = (Integer)location.get(VAR_Y); return ax == lx && ay == ly; } } public class WallPainter implements StatePainter { public void paint(Graphics2D g2, State s, float cWidth, float cHeight) { //walls will be filled in black g2.setColor(Color.BLACK); //set up floats for the width and height of our domain float fWidth = ExampleOOGridWorld.this.map.length; float fHeight = ExampleOOGridWorld.this.map[0].length; //determine the width of a single cell //on our canvas such that the whole map can be painted float width = cWidth / fWidth; float height = cHeight / fHeight; //pass through each cell of our map and if it's a wall, paint a black rectangle on our //cavas of dimension widthxheight for(int i = 0; i < ExampleOOGridWorld.this.map.length; i++){ for(int j = 0; j < ExampleOOGridWorld.this.map[0].length; j++){ //is there a wall here? 
if(ExampleOOGridWorld.this.map[i][j] == 1){ //left coordinate of cell on our canvas float rx = i*width; //top coordinate of cell on our canvas //coordinate system adjustment because the java canvas //origin is in the top left instead of the bottom right float ry = cHeight - height - j*height; //paint the rectangle g2.fill(new Rectangle2D.Float(rx, ry, width, height)); } } } } } public class AgentPainter implements ObjectPainter { @Override public void paintObject(Graphics2D g2, OOState s, ObjectInstance ob, float cWidth, float cHeight) { //agent will be filled in gray g2.setColor(Color.GRAY); //set up floats for the width and height of our domain float fWidth = ExampleOOGridWorld.this.map.length; float fHeight = ExampleOOGridWorld.this.map[0].length; //determine the width of a single cell on our canvas //such that the whole map can be painted float width = cWidth / fWidth; float height = cHeight / fHeight; int ax = (Integer)ob.get(VAR_X); int ay = (Integer)ob.get(VAR_Y); //left coordinate of cell on our canvas float rx = ax*width; //top coordinate of cell on our canvas //coordinate system adjustment because the java canvas //origin is in the top left instead of the bottom right float ry = cHeight - height - ay*height; //paint the rectangle g2.fill(new Ellipse2D.Float(rx, ry, width, height)); } } public class LocationPainter implements ObjectPainter { @Override public void paintObject(Graphics2D g2, OOState s, ObjectInstance ob, float cWidth, float cHeight) { //agent will be filled in blue g2.setColor(Color.BLUE); //set up floats for the width and height of our domain float fWidth = ExampleOOGridWorld.this.map.length; float fHeight = ExampleOOGridWorld.this.map[0].length; //determine the width of a single cell on our canvas //such that the whole map can be painted float width = cWidth / fWidth; float height = cHeight / fHeight; int ax = (Integer)ob.get(VAR_X); int ay = (Integer)ob.get(VAR_Y); //left coordinate of cell on our canvas float rx = ax*width; //top 
coordinate of cell on our canvas //coordinate system adjustment because the java canvas //origin is in the top left instead of the bottom right float ry = cHeight - height - ay*height; //paint the rectangle g2.fill(new Rectangle2D.Float(rx, ry, width, height)); } } public static void main(String [] args){ ExampleOOGridWorld gen = new ExampleOOGridWorld(); OOSADomain domain = gen.generateDomain(); State initialState = new GenericOOState(new ExGridAgent(0, 0), new EXGridLocation(10, 10, "loc0")); SimulatedEnvironment env = new SimulatedEnvironment(domain, initialState); Visualizer v = gen.getVisualizer(); VisualExplorer exp = new VisualExplorer(domain, env, v); exp.addKeyAction("w", ACTION_NORTH, ""); exp.addKeyAction("s", ACTION_SOUTH, ""); exp.addKeyAction("d", ACTION_EAST, ""); exp.addKeyAction("a", ACTION_WEST, ""); exp.initGUI(); } }