public class UCT extends MDPSolver implements Planner, QFunction
A goal condition can optionally be set (via a StateConditionTest) that will cause the planning algorithm to terminate early once it has found a path to the goal. This may be useful if randomly finding the goal state is rare.
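The early-termination behavior can be pictured with the minimal sketch below. It is not the BURLAP implementation: `StopCriteriaSketch` and `observe` are hypothetical names, a `java.util.function.Predicate` stands in for `StateConditionTest`, and folding the rollout budget into `stopPlanning()` is an assumption. The real class may check the budget in its planning loop instead.

```java
import java.util.function.Predicate;

/**
 * Hypothetical sketch of UCT's stopping test (not the BURLAP source).
 * A Predicate stands in for StateConditionTest.
 */
class StopCriteriaSketch<S> {
    final int maxRollOutsFromRoot;   // rollout budget (nRollouts in the constructor)
    int numRollOutsFromRoot = 0;     // rollouts performed so far
    Predicate<S> goalCondition;      // null means no early stopping
    boolean foundGoal = false;       // set once any rollout reaches a goal state

    StopCriteriaSketch(int maxRollOutsFromRoot) {
        this.maxRollOutsFromRoot = maxRollOutsFromRoot;
    }

    /** Mirrors useGoalConditionStopCriteria: enables early stopping. */
    void useGoalConditionStopCriteria(Predicate<S> gc) {
        this.goalCondition = gc;
    }

    /** Called for each state a rollout visits; records whether a goal was reached. */
    void observe(S state) {
        if (goalCondition != null && goalCondition.test(state)) {
            foundGoal = true;
        }
    }

    /** Assumption: the budget check is folded in here alongside the goal check. */
    boolean stopPlanning() {
        if (numRollOutsFromRoot >= maxRollOutsFromRoot) {
            return true;
        }
        return goalCondition != null && foundGoal;
    }
}
```

A planning loop built on this sketch would run one rollout at a time, increment `numRollOutsFromRoot`, and exit as soon as `stopPlanning()` returns true; without a goal condition only the budget ends planning.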
The class also implements the QFunction interface. However, it will only return the Q-value
for a state if that state is the root node of the tree. If it is not the root node of the tree, UCT automatically resets its planning results,
replans with that state as the new root node, and then returns the result. This allows the client to use a GreedyQPolicy
with this QFunction that replans at each step in the world, thereby forcing the Q-values for every state to be computed for the same horizon.
Replanning fresh after each step in the world is the standard UCT approach. If you instead want a policy that walks
through the tree generated from some source state
(so that each step computes a Q-value for a shorter horizon than the step before), you can use
UCTTreeWalkPolicy. UCTTreeWalkPolicy
is more computationally efficient than replanning at each step, but its performance may degrade with each step, since
each step has a shorter horizon from which to plan and may have fewer samples from which to estimate its Q-value.
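The root-only contract for Q-values can be illustrated with a small, hypothetical wrapper. In this sketch a plain function from a root state to a map of action values stands in for a full UCT planning run; `RootOnlyQSketch`, `planningRuns`, and `greedyAction` are illustrative names, not BURLAP API, and ties in `greedyAction` are broken by map iteration order rather than uniformly at random as GreedyQPolicy does.

```java
import java.util.HashMap;
import java.util.Map;
import java.util.function.Function;

/**
 * Hypothetical sketch of the "Q-values only at the root" contract.
 * The planner function stands in for a full UCT planning run from a root state.
 */
class RootOnlyQSketch<S, A> {
    private final Function<S, Map<A, Double>> planner;
    private S root = null;
    private Map<A, Double> rootQs = new HashMap<>();
    int planningRuns = 0; // counts fresh plans, for illustration

    RootOnlyQSketch(Function<S, Map<A, Double>> planner) {
        this.planner = planner;
    }

    /** Analogue of resetSolver(): discard the previous tree and its Q-values. */
    void resetSolver() {
        root = null;
        rootQs = new HashMap<>();
    }

    /** Analogue of getQs(s): if s is not the current root, reset and replan from s. */
    Map<A, Double> getQs(S s) {
        if (root == null || !root.equals(s)) {
            resetSolver();
            root = s;
            rootQs = planner.apply(s);
            planningRuns++;
        }
        return rootQs;
    }

    /** Greedy action at s; ties go to whichever entry the map yields first. */
    A greedyAction(S s) {
        A best = null;
        double bestQ = Double.NEGATIVE_INFINITY;
        for (Map.Entry<A, Double> e : getQs(s).entrySet()) {
            if (e.getValue() > bestQ) {
                bestQ = e.getValue();
                best = e.getKey();
            }
        }
        return best;
    }
}
```

Querying the same state twice reuses the existing plan, while querying a different state triggers a reset and a fresh plan; that is what keeps every step's Q-values on the same horizon when a greedy policy replans after each step.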
Nested classes/interfaces inherited from interface QFunction: QFunction.QFunctionHelper

| Modifier and Type | Field |
|---|---|
| `protected UCTActionNode.UCTActionConstructor` | `actionNodeConstructor` |
| `protected double` | `explorationBias` |
| `protected boolean` | `foundGoal` |
| `protected boolean` | `foundGoalOnRollout` |
| `protected StateConditionTest` | `goalCondition` |
| `protected int` | `maxHorizon` |
| `protected int` | `maxRollOutsFromRoot` |
| `protected int` | `numRollOutsFromRoot` |
| `protected int` | `numVisits` |
| `protected java.util.Random` | `rand` |
| `protected UCTStateNode` | `root` |
| `protected java.util.List<java.util.Map<HashableState,UCTStateNode>>` | `stateDepthIndex` |
| `protected UCTStateNode.UCTStateConstructor` | `stateNodeConstructor` |
| `protected java.util.Map<HashableState,java.util.List<UCTStateNode>>` | `statesToStateNodes` |
| `protected int` | `treeSize` |
| `protected java.util.Set<HashableState>` | `uniqueStatesInTree` |

Fields inherited from class MDPSolver: actions, debugCode, domain, gamma, hashingFactory, mapToStateIndex, rf, tf

| Constructor and Description |
|---|
| `UCT(Domain domain, RewardFunction rf, TerminalFunction tf, double gamma, HashableStateFactory hashingFactory, int horizon, int nRollouts, int explorationBias)` Initializes UCT. |
| Modifier and Type | Method and Description |
|---|---|
| `protected void` | `addNodeToIndexTree(UCTStateNode snode)` Adds a UCTStateNode to the UCT tree. |
| `protected UCTActionNode` | `bestReturnAction(UCTStateNode snode)` Returns the UCTActionNode with the highest average sample Q-value. |
| `protected double` | `computeUCTQ(UCTStateNode snode, UCTActionNode anode)` Returns the upper confidence Q-value for a given state node and action node. |
| `protected boolean` | `containsActionPreference(UCTStateNode snode)` Returns true if the sample returns of any of the node's actions differ. |
| `protected double` | `explorationQBoost(int ns, int na)` Returns the extra value added to the average sample Q-value that is used to produce the upper confidence Q-value. |
| `QValue` | `getQ(State s, AbstractGroundedAction a)` Returns the QValue for the given state-action pair. |
| `java.util.List<QValue>` | `getQs(State s)` Returns a List of QValue objects for every permissible action for the given input state. |
| `UCTStateNode` | `getRoot()` Returns the root node of the UCT tree. |
| `protected void` | `initializeRollOut()` |
| `GreedyQPolicy` | `planFromState(State initialState)` Plans from the input state and then returns a GreedyQPolicy that greedily selects the action with the highest Q-value and breaks ties uniformly randomly. |
| `protected UCTStateNode` | `queryTreeIndex(HashableState sh, int d)` Returns the UCTStateNode for the given (hashed) state at the given depth. |
| `void` | `resetSolver()` This method resets all solver results so that a solver can be restarted fresh as if it had never solved the MDP. |
| `protected UCTActionNode` | `selectActionNode(UCTStateNode snode)` Selects which action to take. |
| `boolean` | `stopPlanning()` Returns true if rollouts and planning should cease. |
| `double` | `treeRollOut(UCTStateNode node, int depth, int childrenLeftToAdd)` Performs a rollout in the UCT tree from the given node, keeping track of how many new nodes can be added to the tree. |
| `protected void` | `UCTInit(Domain domain, RewardFunction rf, TerminalFunction tf, double gamma, HashableStateFactory hashingFactory, int horizon, int nRollouts, int explorationBias)` |
| `void` | `useGoalConditionStopCriteria(StateConditionTest gc)` Tells the planner to stop planning if a goal state is ever found. |
| `double` | `value(State s)` Returns the value function evaluation of the given state. |
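`computeUCTQ`, `explorationQBoost`, and `selectActionNode` together implement the upper-confidence action choice. The sketch below assumes the standard UCB1 form used by UCT, average return + c * sqrt(ln(ns) / na), where c is the exploration bias, ns the state-node visits, and na the action-node visits. Working on plain arrays instead of UCTStateNode/UCTActionNode objects is a simplification, `UcbSketch` is a hypothetical name, and the exact constant and the handling of untried actions in BURLAP may differ (for example, it may choose among untried actions at random rather than taking the first).

```java
/**
 * Hypothetical sketch of UCB1-style action selection, assuming
 * boost = c * sqrt(ln(ns) / na). Not the BURLAP source.
 */
class UcbSketch {
    /** Bonus added to the average return: grows slowly with ns, shrinks with na. */
    static double explorationQBoost(double c, int ns, int na) {
        return c * Math.sqrt(Math.log(ns) / na);
    }

    /** Upper-confidence Q: average sampled return plus the exploration bonus. */
    static double computeUCTQ(double c, int ns, double sumReturn, int na) {
        return sumReturn / na + explorationQBoost(c, ns, na);
    }

    /**
     * Returns the index of the action to try. An untried action (na == 0)
     * is taken immediately, since its bonus would be unbounded; otherwise
     * the highest upper-confidence Q wins, with ties going to the first index.
     */
    static int selectAction(double c, int ns, double[] sumReturns, int[] counts) {
        int best = -1;
        double bestQ = Double.NEGATIVE_INFINITY;
        for (int a = 0; a < counts.length; a++) {
            if (counts[a] == 0) {
                return a;
            }
            double q = computeUCTQ(c, ns, sumReturns[a], counts[a]);
            if (q > bestQ) {
                bestQ = q;
                best = a;
            }
        }
        return best;
    }
}
```

With c = 2 and ns = 10, an action averaging 1.0 over 5 tries scores about 2.36, while one averaging 2.0 over 2 tries scores about 4.15, so the less-sampled, better-looking action is chosen.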
Methods inherited from class MDPSolver: addNonDomainReferencedAction, getActions, getAllGroundedActions, getDebugCode, getDomain, getGamma, getHashingFactory, getRf, getRF, getTf, getTF, setActions, setDebugCode, setDomain, setGamma, setHashingFactory, setRf, setTf, solverInit, stateHash, toggleDebugPrinting, translateAction

Methods inherited from class java.lang.Object: clone, equals, finalize, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait

Methods inherited from interface MDPSolverInterface: addNonDomainReferencedAction, getActions, getDebugCode, getDomain, getGamma, getHashingFactory, getRf, getRF, getTf, getTF, setActions, setDebugCode, setDomain, setGamma, setHashingFactory, setRf, setTf, solverInit, toggleDebugPrinting

protected java.util.List<java.util.Map<HashableState,UCTStateNode>> stateDepthIndex
protected java.util.Map<HashableState,java.util.List<UCTStateNode>> statesToStateNodes
protected UCTStateNode root
protected int maxHorizon
protected int maxRollOutsFromRoot
protected int numRollOutsFromRoot
protected double explorationBias
protected UCTStateNode.UCTStateConstructor stateNodeConstructor
protected UCTActionNode.UCTActionConstructor actionNodeConstructor
protected StateConditionTest goalCondition
protected boolean foundGoal
protected boolean foundGoalOnRollout
protected java.util.Set<HashableState> uniqueStatesInTree
protected int treeSize
protected int numVisits
protected java.util.Random rand
public UCT(Domain domain, RewardFunction rf, TerminalFunction tf, double gamma, HashableStateFactory hashingFactory, int horizon, int nRollouts, int explorationBias)
Parameters:
domain - the domain in which to plan
rf - the reward function to use
tf - the terminal function to use
gamma - the discount factor
hashingFactory - the state hashing factory
horizon - the planning horizon
nRollouts - the number of rollouts to perform
explorationBias - the exploration bias constant (suggested >2)
protected void UCTInit(Domain domain, RewardFunction rf, TerminalFunction tf, double gamma, HashableStateFactory hashingFactory, int horizon, int nRollouts, int explorationBias)
public UCTStateNode getRoot()
Returns the root node of the UCT tree.
public void useGoalConditionStopCriteria(StateConditionTest gc)
Tells the planner to stop planning if a goal state is ever found.
Parameters:
gc - a StateConditionTest object used to specify goal states (wherever it evaluates as true).
public GreedyQPolicy planFromState(State initialState)
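Plans from the input state and then returns a GreedyQPolicy that greedily
selects the action with the highest Q-value and breaks ties uniformly randomly.
Specified by: planFromState in interface Planner
Parameters:
initialState - the initial state of the planning problem
Returns:
a GreedyQPolicy.
public java.util.List<QValue> getQs(State s)
Description copied from interface: QFunction
Returns a List of QValue objects for every permissible action for the given input state.
public QValue getQ(State s, AbstractGroundedAction a)
Description copied from interface: QFunction
Returns the QValue for the given state-action pair.
public double value(State s)
Description copied from interface: ValueFunction
Returns the value function evaluation of the given state.
Specified by: value in interface ValueFunction
Parameters:
s - the state to evaluate.
public void resetSolver()
Description copied from interface: MDPSolverInterface
This method resets all solver results so that a solver can be restarted fresh as if it had never solved the MDP.
Specified by: resetSolver in interface MDPSolverInterface; resetSolver in class MDPSolver
protected void initializeRollOut()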
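public double treeRollOut(UCTStateNode node, int depth, int childrenLeftToAdd)
Performs a rollout in the UCT tree from the given node, keeping track of how many new nodes can be added to the tree.
Parameters:
node - the node from which to roll out
depth - the depth of the node
childrenLeftToAdd - the number of new subsequent nodes that can be connected to the tree
public boolean stopPlanning()
Returns true if rollouts and planning should cease.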
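The rollout recursion behind treeRollOut can be sketched in a compact, self-contained form. Everything here is an assumption made for illustration rather than BURLAP's code: a deterministic chain MDP replaces Domain, RewardFunction, and TerminalFunction; integer states replace HashableState; per-depth maps imitate the stateDepthIndex field; each rollout may add one new node (childrenLeftToAdd = 1); the root sits at depth 0; and a fixed rollout count stands in for the stopPlanning() check. `RollOutSketch`, `Node`, `next`, `select`, and `plan` are hypothetical names. The sketch shows the core idea: select with UCB1 inside the tree, recurse, then back the sampled discounted return up into the visit counts and return sums.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/** Hypothetical UCT rollout sketch on a chain MDP: states 0..goal, reward 1 on entering goal (terminal). */
class RollOutSketch {
    static final int NUM_ACTIONS = 2; // action 0 = left, action 1 = right

    /** Per-(state, depth) statistics, loosely playing the role of UCTStateNode. */
    static final class Node {
        int visits;
        final int[] actionVisits = new int[NUM_ACTIONS];
        final double[] actionReturnSums = new double[NUM_ACTIONS];
    }

    final int goal;
    final int maxHorizon;
    final double gamma;
    final double c;
    // One map per depth (0..maxHorizon-1), in the spirit of stateDepthIndex.
    final List<Map<Integer, Node>> stateDepthIndex = new ArrayList<>();

    RollOutSketch(int goal, int maxHorizon, double gamma, double c) {
        this.goal = goal;
        this.maxHorizon = maxHorizon;
        this.gamma = gamma;
        this.c = c;
        for (int d = 0; d < maxHorizon; d++) {
            stateDepthIndex.add(new HashMap<>());
        }
    }

    /** Deterministic dynamics: right moves toward the goal, left moves away (clamped at 0). */
    int next(int s, int a) {
        return a == 1 ? Math.min(s + 1, goal) : Math.max(s - 1, 0);
    }

    /** UCB1 choice: untried actions first, then the highest upper-confidence Q. */
    int select(Node n) {
        int best = 0;
        double bestQ = Double.NEGATIVE_INFINITY;
        for (int a = 0; a < NUM_ACTIONS; a++) {
            if (n.actionVisits[a] == 0) {
                return a;
            }
            double q = n.actionReturnSums[a] / n.actionVisits[a]
                    + c * Math.sqrt(Math.log(n.visits) / n.actionVisits[a]);
            if (q > bestQ) {
                bestQ = q;
                best = a;
            }
        }
        return best;
    }

    /** One rollout from non-goal state s stored at depth d; returns the sampled discounted return. */
    double treeRollOut(int s, int d, int childrenLeftToAdd) {
        Node node = stateDepthIndex.get(d).get(s);
        int a = select(node);
        int sp = next(s, a);
        double reward = (sp == goal) ? 1.0 : 0.0;
        double future = 0.0;
        if (sp != goal && d + 1 < maxHorizon) {
            Map<Integer, Node> level = stateDepthIndex.get(d + 1);
            if (level.containsKey(sp)) {
                future = treeRollOut(sp, d + 1, childrenLeftToAdd);
            } else if (childrenLeftToAdd > 0) {
                level.put(sp, new Node()); // grow the tree by one node
                future = treeRollOut(sp, d + 1, childrenLeftToAdd - 1);
            }
            // Otherwise the node budget is spent and the return is truncated here.
        }
        double ret = reward + gamma * future;
        node.visits++;
        node.actionVisits[a]++;
        node.actionReturnSums[a] += ret;
        return ret;
    }

    /** Runs nRollouts rollouts from s and returns the root action with the best average return. */
    int plan(int s, int nRollouts) {
        Node root = stateDepthIndex.get(0).computeIfAbsent(s, k -> new Node());
        for (int i = 0; i < nRollouts; i++) {
            treeRollOut(s, 0, 1); // assumption: one new node per rollout
        }
        int best = -1;
        double bestAvg = Double.NEGATIVE_INFINITY;
        for (int a = 0; a < NUM_ACTIONS; a++) {
            if (root.actionVisits[a] > 0) {
                double avg = root.actionReturnSums[a] / root.actionVisits[a];
                if (avg > bestAvg) {
                    bestAvg = avg;
                    best = a;
                }
            }
        }
        return best;
    }
}
```

With the goal one step to the right of state 0, every sample of "right" returns exactly 1, while "left" can return at most gamma, so `plan(0, n)` picks "right" once both actions have been tried.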
protected UCTActionNode selectActionNode(UCTStateNode snode)
Selects which action to take.
Parameters:
snode - the UCT node from which to select an action.
Returns:
the UCTActionNode to be taken.
protected double computeUCTQ(UCTStateNode snode, UCTActionNode anode)
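Returns the upper confidence Q-value for a given state node and action node.
Parameters:
snode - the state node
anode - the action node
protected double explorationQBoost(int ns, int na)
Returns the extra value added to the average sample Q-value that is used to produce the upper confidence Q-value.
Parameters:
ns - the number of times the state node has been visited
na - the number of times the action node has been visited
protected UCTStateNode queryTreeIndex(HashableState sh, int d)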
Returns the UCTStateNode for the given (hashed) state at the given depth.
Parameters:
sh - the state whose node should be returned
d - the depth of the state
Returns:
the UCTStateNode
protected void addNodeToIndexTree(UCTStateNode snode)
Adds a UCTStateNode to the UCT tree.
Parameters:
snode - the UCTStateNode to add
protected UCTActionNode bestReturnAction(UCTStateNode snode)
Returns the UCTActionNode with the highest average sample Q-value. Ties are broken by returning the first UCTActionNode with the highest value.
Parameters:
snode - the UCTStateNode to query
Returns:
the UCTActionNode with the highest average sample Q-value
protected boolean containsActionPreference(UCTStateNode snode)
Returns true if the sample returns of any of the node's actions differ.
Parameters:
snode - the node to check for an action preference
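`bestReturnAction` and `containsActionPreference` both read the per-action sample averages of a single node. In this hypothetical sketch, arrays of return sums and visit counts stand in for the node's UCTActionNode children, `ActionPreferenceSketch` and `average` are illustrative names, and treating an unvisited action as having average 0 is an assumption. Ties go to the first action, as the bestReturnAction description specifies.

```java
/**
 * Hypothetical sketch of bestReturnAction / containsActionPreference over
 * one node's per-action statistics (not the BURLAP source).
 */
class ActionPreferenceSketch {
    /** Average sample return; assumption: an unvisited action counts as 0. */
    static double average(double sum, int n) {
        return n == 0 ? 0.0 : sum / n;
    }

    /** Index of the highest average sample return; ties go to the first index. */
    static int bestReturnAction(double[] sums, int[] counts) {
        int best = 0;
        double bestAvg = average(sums[0], counts[0]);
        for (int a = 1; a < sums.length; a++) {
            double avg = average(sums[a], counts[a]);
            if (avg > bestAvg) { // strict '>' keeps the earliest action on ties
                bestAvg = avg;
                best = a;
            }
        }
        return best;
    }

    /** True if the actions' average sample returns are not all identical. */
    static boolean containsActionPreference(double[] sums, int[] counts) {
        double first = average(sums[0], counts[0]);
        for (int a = 1; a < sums.length; a++) {
            if (average(sums[a], counts[a]) != first) {
                return true;
            }
        }
        return false;
    }
}
```

When `containsActionPreference` is false, every action looks equally good from the samples, so the choice made by `bestReturnAction` is only a tie-break and carries no real preference.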