public class UCT extends MDPSolver implements Planner, QProvider
StateConditionTest
)
that will cause the planning algorithm to terminate early once it has found a path to the goal. This may be useful if randomly finding the goal state is rare.
The class also implements the QProvider
interface. However, it only returns Q-values directly
for a state that is the root node of the tree. For any other state, it automatically resets the planning results,
replans with that state as the root node, and then returns the result. This allows the client to use a GreedyQPolicy
with this planner so that it replans with each step in the world, which makes the Q-values for every state cover the same horizon.
Replanning fresh after each step in the world is the standard UCT approach. If you instead want a policy that walks
through the tree it generated from some source state
(so that each step computes a Q-value for a shorter horizon than the step before), you can use the
UCTTreeWalkPolicy. The UCTTreeWalkPolicy
will be more computationally efficient than replanning at each step, but its performance may degrade with each step, since
each step has a shorter horizon from which to plan and may have fewer samples behind its Q-value estimate.
1. Kocsis, Levente, and Csaba Szepesvari. "Bandit based monte-carlo planning." ECML (2006). 282-293.
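The reference above frames UCT as bandit-based planning, which this class exposes through computeUCTQ and explorationQBoost. As a rough illustration, the sketch below uses the textbook UCB1 form of the bonus, bias * sqrt(ln(ns) / na). Whether this class uses exactly that formula is an assumption. The names UcbSketch, explorationBoost, upperConfidenceQ and select are hypothetical, and giving untried actions an infinite bonus and breaking ties toward the first action are choices made only for this sketch.

```java
// Hypothetical stand-alone sketch; not the BURLAP implementation.
class UcbSketch {

    /** Assumed UCB1-style bonus: bias * sqrt(ln(ns) / na). */
    public static double explorationBoost(double bias, int ns, int na) {
        if (na == 0) {
            // Sketch choice: an untried action always looks best.
            return Double.POSITIVE_INFINITY;
        }
        return bias * Math.sqrt(Math.log(ns) / na);
    }

    /** Average sample return plus the exploration bonus. */
    public static double upperConfidenceQ(double avgReturn, double bias, int ns, int na) {
        return avgReturn + explorationBoost(bias, ns, na);
    }

    /** Index of the action with the highest upper confidence value (first one wins ties). */
    public static int select(double[] avgReturns, int[] actionVisits, int stateVisits, double bias) {
        int best = 0;
        double bestQ = Double.NEGATIVE_INFINITY;
        for (int i = 0; i < avgReturns.length; i++) {
            double q = upperConfidenceQ(avgReturns[i], bias, stateVisits, actionVisits[i]);
            if (q > bestQ) {
                bestQ = q;
                best = i;
            }
        }
        return best;
    }

    public static void main(String[] args) {
        double[] avgReturns = {1.0, 0.8};
        int[] actionVisits = {50, 2};
        // With no bias the better average wins; with bias the rarely tried action wins.
        System.out.println(select(avgReturns, actionVisits, 52, 0.0));
        System.out.println(select(avgReturns, actionVisits, 52, 2.0));
    }
}
```

A larger bias pushes selection toward rarely visited actions, which is consistent with the constructor's suggestion of a bias greater than 2.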
QProvider.Helper
Modifier and Type | Field and Description |
---|---|
protected UCTActionNode.UCTActionConstructor | actionNodeConstructor |
protected double | explorationBias |
protected boolean | foundGoal |
protected boolean | foundGoalOnRollout |
protected StateConditionTest | goalCondition |
protected int | maxHorizon |
protected int | maxRollOutsFromRoot |
protected int | numRollOutsFromRoot |
protected int | numVisits |
protected java.util.Random | rand |
protected UCTStateNode | root |
protected java.util.List<java.util.Map<HashableState,UCTStateNode>> | stateDepthIndex |
protected UCTStateNode.UCTStateConstructor | stateNodeConstructor |
protected java.util.Map<HashableState,java.util.List<UCTStateNode>> | statesToStateNodes |
protected int | treeSize |
protected java.util.Set<HashableState> | uniqueStatesInTree |
actionTypes, debugCode, domain, gamma, hashingFactory, model, usingOptionModel
Constructor and Description |
---|
UCT(SADomain domain, double gamma, HashableStateFactory hashingFactory, int horizon, int nRollouts, int explorationBias): Initializes UCT |
Modifier and Type | Method and Description |
---|---|
protected void | addNodeToIndexTree(UCTStateNode snode): Adds a UCTStateNode to the UCT tree. |
protected UCTActionNode | bestReturnAction(UCTStateNode snode): Returns the UCTActionNode with the highest average sample Q-value. |
protected double | computeUCTQ(UCTStateNode snode, UCTActionNode anode): Returns the upper confidence Q-value for a given state node and action node. |
protected boolean | containsActionPreference(UCTStateNode snode): Returns true if the sample returns for any actions are different. |
protected double | explorationQBoost(int ns, int na): Returns the extra value added to the average sample Q-value that is used to produce the upper confidence Q-value. |
UCTStateNode | getRoot(): Returns the root node of the UCT tree. |
protected void | initializeRollOut() |
GreedyQPolicy | planFromState(State initialState): Plans from the input state and then returns a GreedyQPolicy that greedily selects the action with the highest Q-value and breaks ties uniformly randomly. |
protected UCTStateNode | queryTreeIndex(HashableState sh, int d): Returns the UCTStateNode for the given (hashed) state at the given depth. |
double | qValue(State s, Action a): Returns the QValue for the given state-action pair. |
java.util.List<QValue> | qValues(State s): Returns a List of QValue objects for every permissible action for the given input state. |
void | resetSolver(): This method resets all solver results so that a solver can be restarted fresh as if it had never solved the MDP. |
protected UCTActionNode | selectActionNode(UCTStateNode snode): Selects which action to take. |
boolean | stopPlanning(): Returns true if rollouts and planning should cease. |
double | treeRollOut(UCTStateNode node, int depth, int childrenLeftToAdd): Performs a rollout in the UCT tree from the given node, keeping track of how many new nodes can be added to the tree. |
protected void | UCTInit(SADomain domain, double gamma, HashableStateFactory hashingFactory, int horizon, int nRollouts, int explorationBias) |
void | useGoalConditionStopCriteria(StateConditionTest gc): Tells the planner to stop planning if a goal state is ever found. |
double | value(State s): Returns the value function evaluation of the given state. |
addActionType, applicableActions, getActionTypes, getDebugCode, getDomain, getGamma, getHashingFactory, getModel, setActionTypes, setDebugCode, setDomain, setGamma, setHashingFactory, setModel, solverInit, stateHash, toggleDebugPrinting
clone, equals, finalize, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait
addActionType, getActionTypes, getDebugCode, getDomain, getGamma, getHashingFactory, getModel, setActionTypes, setDebugCode, setDomain, setGamma, setHashingFactory, setModel, solverInit, toggleDebugPrinting
protected java.util.List<java.util.Map<HashableState,UCTStateNode>> stateDepthIndex
protected java.util.Map<HashableState,java.util.List<UCTStateNode>> statesToStateNodes
protected UCTStateNode root
protected int maxHorizon
protected int maxRollOutsFromRoot
protected int numRollOutsFromRoot
protected double explorationBias
protected UCTStateNode.UCTStateConstructor stateNodeConstructor
protected UCTActionNode.UCTActionConstructor actionNodeConstructor
protected StateConditionTest goalCondition
protected boolean foundGoal
protected boolean foundGoalOnRollout
protected java.util.Set<HashableState> uniqueStatesInTree
protected int treeSize
protected int numVisits
protected java.util.Random rand
public UCT(SADomain domain, double gamma, HashableStateFactory hashingFactory, int horizon, int nRollouts, int explorationBias)
Parameters:
domain - the domain in which to plan
gamma - the discount factor
hashingFactory - the state hashing factory
horizon - the planning horizon
nRollouts - the number of rollouts to perform
explorationBias - the exploration bias constant (suggested >2)
protected void UCTInit(SADomain domain, double gamma, HashableStateFactory hashingFactory, int horizon, int nRollouts, int explorationBias)
public UCTStateNode getRoot()
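public void useGoalConditionStopCriteria(StateConditionTest gc)
Tells the planner to stop planning if a goal state is ever found.
Parameters:
gc - a StateConditionTest object used to specify goal states (wherever it evaluates as true).
public GreedyQPolicy planFromState(State initialState)
Plans from the input state and then returns a GreedyQPolicy that greedily selects the action with the highest Q-value and breaks ties uniformly randomly.
Specified by:
planFromState in interface Planner
Parameters:
initialState - the initial state of the planning problem
Returns:
a GreedyQPolicy.
public java.util.List<QValue> qValues(State s)
Description copied from interface: QProvider
Returns a List of QValue objects for every permissible action for the given input state.
public double qValue(State s, Action a)
Description copied from interface: QFunction
Returns the QValue for the given state-action pair.
public double value(State s)
Description copied from interface: ValueFunction
Returns the value function evaluation of the given state.
Specified by:
value in interface ValueFunction
Parameters:
s - the state to evaluate.
public void resetSolver()
Description copied from interface: MDPSolverInterface
This method resets all solver results so that a solver can be restarted fresh as if it had never solved the MDP.
Specified by:
resetSolver in interface MDPSolverInterface
Overrides:
resetSolver in class MDPSolver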
protected void initializeRollOut()
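public double treeRollOut(UCTStateNode node, int depth, int childrenLeftToAdd)
Performs a rollout in the UCT tree from the given node, keeping track of how many new nodes can be added to the tree.
Parameters:
node - the node from which to rollout
depth - the depth of the node
childrenLeftToAdd - the number of new subsequent nodes that can be connected to the tree
public boolean stopPlanning()
Returns true if rollouts and planning should cease.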
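To make the interaction between stopPlanning, the rollout counters, and useGoalConditionStopCriteria concrete, here is a minimal sketch. It assumes planning stops either when the rollout budget is used up or, if a goal condition was supplied, as soon as a rollout has found the goal. The actual logic in this class may differ. StopCriterionSketch, plan and goalAtRollout are hypothetical names, and a boolean flag stands in for a non-null goalCondition.

```java
// Hypothetical stand-alone sketch; not the BURLAP implementation.
class StopCriterionSketch {
    private final int maxRollOutsFromRoot;
    private final boolean useGoalStop; // stands in for "a goalCondition was set"
    private int numRollOutsFromRoot = 0;
    private boolean foundGoal = false;

    public StopCriterionSketch(int maxRollOutsFromRoot, boolean useGoalStop) {
        this.maxRollOutsFromRoot = maxRollOutsFromRoot;
        this.useGoalStop = useGoalStop;
    }

    /** Assumed rule: stop on an early goal hit, or when the rollout budget is spent. */
    public boolean stopPlanning() {
        if (useGoalStop && foundGoal) {
            return true;
        }
        return numRollOutsFromRoot >= maxRollOutsFromRoot;
    }

    /**
     * Simulates planning: the rollout numbered goalAtRollout reaches the goal
     * (pass a value below 1 for "never"). Returns the number of rollouts performed.
     */
    public int plan(int goalAtRollout) {
        while (!stopPlanning()) {
            numRollOutsFromRoot++;
            if (numRollOutsFromRoot == goalAtRollout) {
                foundGoal = true;
            }
        }
        return numRollOutsFromRoot;
    }

    public static void main(String[] args) {
        System.out.println(new StopCriterionSketch(100, false).plan(7)); // uses the full budget
        System.out.println(new StopCriterionSketch(100, true).plan(7));  // stops at the goal
    }
}
```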
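protected UCTActionNode selectActionNode(UCTStateNode snode)
Selects which action to take.
Parameters:
snode - the UCT node from which to select an action.
Returns:
the UCTActionNode to be taken.
protected double computeUCTQ(UCTStateNode snode, UCTActionNode anode)
Returns the upper confidence Q-value for a given state node and action node.
Parameters:
snode - the state node
anode - the action node
protected double explorationQBoost(int ns, int na)
Returns the extra value added to the average sample Q-value that is used to produce the upper confidence Q-value.
Parameters:
ns - the number of times the state node has been visited
na - the number of times the action node has been visited
protected UCTStateNode queryTreeIndex(HashableState sh, int d)
Returns the UCTStateNode for the given (hashed) state at the given depth.
Parameters:
sh - the state whose node should be returned
d - the depth of the state
Returns:
the UCTStateNode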
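The stateDepthIndex field is declared as a list of per-depth maps, which suggests that queryTreeIndex looks a state up in the map for its depth. The sketch below illustrates that layout under stated assumptions: String keys stand in for HashableState, DepthIndexSketch and Node are hypothetical names, and returning null for a missing entry is an assumption rather than documented behavior.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Hypothetical stand-alone sketch; not the BURLAP implementation.
class DepthIndexSketch {

    /** Minimal stand-in for a UCT state node. */
    public static class Node {
        public final String state; // stands in for HashableState
        public final int depth;

        public Node(String state, int depth) {
            this.state = state;
            this.depth = depth;
        }
    }

    // One map per tree depth, mirroring List<Map<HashableState, UCTStateNode>>.
    private final List<Map<String, Node>> stateDepthIndex = new ArrayList<>();

    /** Registers a node under its depth, growing the list of maps as needed. */
    public void add(Node node) {
        while (stateDepthIndex.size() <= node.depth) {
            stateDepthIndex.add(new HashMap<>());
        }
        stateDepthIndex.get(node.depth).put(node.state, node);
    }

    /** Returns the node for the state at the given depth, or null (assumed) if absent. */
    public Node query(String state, int depth) {
        if (depth < 0 || depth >= stateDepthIndex.size()) {
            return null;
        }
        return stateDepthIndex.get(depth).get(state);
    }
}
```

Keying by depth means the same state reached at two different depths gets two separate nodes, so each node's statistics refer to a single remaining horizon.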
protected void addNodeToIndexTree(UCTStateNode snode)
Adds a UCTStateNode to the UCT tree.
Parameters:
snode - the UCTStateNode to add
protected UCTActionNode bestReturnAction(UCTStateNode snode)
Returns the UCTActionNode with the highest average sample Q-value. Ties are broken by returning the first UCTActionNode with the highest value.
Parameters:
snode - the UCTStateNode to query
Returns:
the UCTActionNode with the highest average sample Q-value
protected boolean containsActionPreference(UCTStateNode snode)
Returns true if the sample returns for any actions are different.
Parameters:
snode - the node to check for an action preference
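The documented behavior of bestReturnAction (highest average sample Q-value, first node wins ties) and containsActionPreference (true if any sample returns differ) can be sketched over a plain array of average returns. This is only an illustration: ReturnActionSketch is a hypothetical name, the array stands in for a state node's action nodes, and rejecting an empty array is a choice made for the sketch.

```java
// Hypothetical stand-alone sketch; not the BURLAP implementation.
class ReturnActionSketch {

    /** Index of the highest average return; the first such index wins ties. */
    public static int bestReturnAction(double[] avgReturns) {
        if (avgReturns.length == 0) {
            throw new IllegalArgumentException("at least one action is required");
        }
        int best = 0;
        for (int i = 1; i < avgReturns.length; i++) {
            // Strict comparison keeps the earliest index on ties.
            if (avgReturns[i] > avgReturns[best]) {
                best = i;
            }
        }
        return best;
    }

    /** True if at least two actions have different average returns. */
    public static boolean containsActionPreference(double[] avgReturns) {
        for (int i = 1; i < avgReturns.length; i++) {
            if (avgReturns[i] != avgReturns[0]) {
                return true;
            }
        }
        return false;
    }

    public static void main(String[] args) {
        System.out.println(bestReturnAction(new double[]{0.5, 0.9, 0.9}));         // 1
        System.out.println(containsActionPreference(new double[]{0.3, 0.3, 0.3})); // false
    }
}
```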