public class SparseSampling extends MDPSolver implements QProvider, Planner
Sparse Sampling (SS) [1] can guarantee epsilon-optimal planning under certain conditions (use the setHAndCByMDPError(double, double, int) method to ensure this; however, the required horizon and C will probably be intractably large). SS must replan for every new state it sees, so an agent following it must, in general, replan after every step it takes in the real world. Using a Q-based Policy object ensures this behavior, because this class runs planFromState(State) whenever it is queried for the Q-value of a state it has not seen.
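As a rough illustration of why the epsilon-optimal settings are usually intractable, the sketch below computes H and C from the parameter choices given by Kearns, Mansour, and Ng [1], as we recall them: lambda = epsilon(1 - gamma)^2 / 4, Vmax = Rmax / (1 - gamma), H = ceil(log_gamma(lambda / Vmax)) and C = (Vmax^2 / lambda^2)(2H ln(kH Vmax^2 / lambda^2) + ln(Rmax / lambda)), where k is the number of actions. The class name SparseSamplingBounds is hypothetical, and this is not necessarily how setHAndCByMDPError computes its values; check the formulas against the paper before relying on them.

```java
/**
 * Hypothetical helper: epsilon-optimal sparse sampling parameters, following the
 * bounds of Kearns, Mansour, and Ng (2002) as recalled here (verify against the paper).
 */
final class SparseSamplingBounds {

    /** Logarithm of x in the given base: ln(x) / ln(base). */
    static double logBase(double base, double x) {
        return Math.log(x) / Math.log(base);
    }

    /** lambda = epsilon * (1 - gamma)^2 / 4 */
    static double lambda(double epsilon, double gamma) {
        return epsilon * (1.0 - gamma) * (1.0 - gamma) / 4.0;
    }

    /** Tree height H = ceil(log_gamma(lambda / Vmax)), with Vmax = Rmax / (1 - gamma). */
    static int height(double rmax, double epsilon, double gamma) {
        double vmax = rmax / (1.0 - gamma);
        return (int) Math.ceil(logBase(gamma, lambda(epsilon, gamma) / vmax));
    }

    /**
     * Samples per state-action pair,
     * C = (Vmax^2 / lambda^2) * (2 H ln(k H Vmax^2 / lambda^2) + ln(Rmax / lambda)).
     * Returned as a double because it easily exceeds the int range.
     */
    static double samples(double rmax, double epsilon, double gamma, int numActions) {
        double lam = lambda(epsilon, gamma);
        double vmax = rmax / (1.0 - gamma);
        int h = height(rmax, epsilon, gamma);
        double ratio = (vmax * vmax) / (lam * lam);
        return ratio * (2.0 * h * Math.log(numActions * h * ratio) + Math.log(rmax / lam));
    }
}
```

With Rmax = 1, epsilon = 0.5, gamma = 0.9 and two actions, this gives H = 86 and C above 10^11, far beyond what a tree of that height could ever expand.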
The algorithm builds a tree from the initial source state. For each possible state-action pair it samples C outcome states, which become new state nodes in the tree. The tree is built out to a fixed height H; then, recursively from the leaves back up to the root, the Q-values and state values are estimated with a Bellman update that treats the C samples as if they perfectly defined the transition dynamics. Because the values are based on a fixed horizon and computed recursively, only one Bellman update is required per node.
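The recursion just described can be sketched as follows. This is a minimal, uncached version over integer states and actions, assuming a hypothetical GenerativeModel interface that returns one sampled next state and reward; it is not BURLAP's implementation, which works with State objects and a HashableStateFactory and reuses nodes. A node at height 0 is a leaf valued at 0, and each Q-value averages r + gamma * V(s') over C samples.

```java
import java.util.Random;

/** Minimal sparse sampling sketch; node values are not cached, so each call rebuilds the tree. */
final class SparseSamplingSketch {

    /** One sampled transition: the next state and the reward received. */
    static final class Outcome {
        final int nextState;
        final double reward;

        Outcome(int nextState, double reward) {
            this.nextState = nextState;
            this.reward = reward;
        }
    }

    /** Hypothetical generative model: draws one outcome of taking action a in state s. */
    interface GenerativeModel {
        Outcome sample(int s, int a, Random rng);
    }

    private final GenerativeModel model;
    private final int numActions;
    private final double gamma;
    private final int c;
    private final Random rng;

    SparseSamplingSketch(GenerativeModel model, int numActions, double gamma, int c, long seed) {
        this.model = model;
        this.numActions = numActions;
        this.gamma = gamma;
        this.c = c;
        this.rng = new Random(seed);
    }

    /** V_h(s) = max_a Q_h(s, a); leaves (height 0) are valued at 0. */
    double value(int s, int height) {
        if (height == 0) {
            return 0.0;
        }
        double best = Double.NEGATIVE_INFINITY;
        for (int a = 0; a < numActions; a++) {
            best = Math.max(best, qValue(s, a, height));
        }
        return best;
    }

    /** Q_h(s, a): average of r + gamma * V_{h-1}(s') over C sampled outcomes. */
    double qValue(int s, int a, int height) {
        double total = 0.0;
        for (int i = 0; i < c; i++) {
            Outcome o = model.sample(s, a, rng);
            total += o.reward + gamma * value(o.nextState, height - 1);
        }
        return total / c;
    }
}
```

For a deterministic model every sample is identical, so the estimate equals the exact finite horizon value; with stochastic dynamics the estimate depends on the random seed.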
Although the complexity of the algorithm is independent of the state space size, it is exponential in the height of the tree. If a large tree height is required to make good value function estimates, for example when rewards are sparse or uniform except at a distant horizon, this algorithm may not be an appropriate choice.
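To make the growth concrete: each state node has |A| * C children, so without node reuse a tree of height H contains 1 + (|A|C) + ... + (|A|C)^H state nodes. The helper below (hypothetical name TreeSize) computes that count with overflow checks; the class itself may create fewer nodes when sampled states repeat, because nodes are indexed by state and height.

```java
/** Counts state nodes in a sparse sampling tree when no nodes are shared or reused. */
final class TreeSize {

    /**
     * Depth d holds (numActions * c)^d nodes; the tree of height h holds the sum over d = 0..h.
     * Math.multiplyExact and Math.addExact throw ArithmeticException on long overflow.
     */
    static long stateNodes(int numActions, int c, int h) {
        long branching = (long) numActions * c;
        long levelSize = 1; // the root alone at depth 0
        long total = 1;
        for (int d = 1; d <= h; d++) {
            levelSize = Math.multiplyExact(levelSize, branching);
            total = Math.addExact(total, levelSize);
        }
        return total;
    }
}
```

For example, 2 actions, C = 3 and H = 3 give 259 nodes, while 4 actions, C = 10 and H = 5 already give 105,025,641.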
By default, this class remembers the estimated Q-values for every state from which the planFromState(State) method was called (the Q-value query methods call it indirectly when they do not yet have the Q-values for a state). It also remembers the values of the state tree nodes it computed so that they can be reused in subsequent tree creations, limiting the additional computation required. However, if memory is scarce, the class can be told to forget all prior planning results, except the Q-value estimates for the most recently planned-for state, by using the setForgetPreviousPlanResults(boolean) method.
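A minimal sketch of this bookkeeping, assuming states are any objects with value-based equals and hashCode (standing in for HashableState): node values are keyed by a (state, height) pair, in the spirit of SparseSampling.HashedHeightState, and root Q-values are kept per planned-for state. PlanCache and its methods are hypothetical names; only the "forget everything but the latest root" behavior is taken from the description above.

```java
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;

/** Hypothetical cache of tree node values and root Q-values across planning calls. */
final class PlanCache<S> {

    /** Hashable (state, height) key for quick node retrieval. */
    static final class StateHeight {
        final Object state;
        final int height;

        StateHeight(Object state, int height) {
            this.state = state;
            this.height = height;
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) return true;
            if (!(o instanceof StateHeight)) return false;
            StateHeight other = (StateHeight) o;
            return height == other.height && state.equals(other.state);
        }

        @Override
        public int hashCode() {
            return Objects.hash(state, height);
        }
    }

    private final Map<StateHeight, Double> nodeValues = new HashMap<>();
    private final Map<S, double[]> rootQValues = new HashMap<>();
    private final boolean forgetPreviousPlanResults;

    PlanCache(boolean forgetPreviousPlanResults) {
        this.forgetPreviousPlanResults = forgetPreviousPlanResults;
    }

    /** Cached value of the node for this state at this height, or null if not computed yet. */
    Double nodeValue(S state, int height) {
        return nodeValues.get(new StateHeight(state, height));
    }

    void storeNodeValue(S state, int height, double value) {
        nodeValues.put(new StateHeight(state, height), value);
    }

    /** Records the root Q-values of a finished planning call. */
    void finishPlanning(S root, double[] qValues) {
        if (forgetPreviousPlanResults) {
            // Drop all earlier results; only this root's Q-values survive.
            nodeValues.clear();
            rootQValues.clear();
        }
        rootQValues.put(root, qValues);
    }

    double[] rootQ(S state) {
        return rootQValues.get(state);
    }
}
```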
By default, the C parameter (the number of state transition samples) is fixed for all nodes. It may instead be set to a variable C that shrinks the number of sampled states deeper in the tree according to C_i = C_0 * gamma^(2i), where i is the depth of the node from the root, C_0 is the value of C at the root, and gamma is the discount factor. Variable C is enabled with the setUseVariableCSize(boolean) method.
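The depth-variable schedule can be sketched as a function of a node's height above the leaves (the convention used by getCAtHeight), converting height to depth with i = H - height. Truncating to an int and never using fewer than one sample are assumptions of this sketch, not documented behavior; VariableC is a hypothetical name.

```java
/** Hypothetical sketch of the depth-variable sample count C_i = C_0 * gamma^(2i). */
final class VariableC {

    /**
     * @param c0       samples at the root
     * @param h        tree height
     * @param height   node height above the leaves (root has height h)
     * Assumption: the result is truncated to an int and clamped to at least 1.
     */
    static int cAtHeight(int c0, double gamma, int h, int height, boolean useVariableC) {
        if (!useVariableC) {
            return c0;
        }
        int depth = h - height; // distance from the root
        int ci = (int) (c0 * Math.pow(gamma, 2.0 * depth));
        return Math.max(1, ci);
    }
}
```

With C_0 = 64, gamma = 0.5 and H = 3, nodes at heights 3, 2 and 1 draw 64, 16 and 4 samples per action.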
By default, the state value of leaf nodes is 0, but this value can be changed by providing a ValueFunction object via the setValueForLeafNodes(ValueFunction) method. Using a non-zero heuristic value may reduce the need for a large tree height.
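One way to see why a good leaf heuristic helps: with exact Bellman backups, an error in the leaf values shrinks by a factor of gamma at each level, so the error at the root is at most gamma^H times the largest leaf error. The hypothetical helper below finds the smallest H that brings a given leaf error under a tolerance; it ignores sampling error, which sparse sampling adds on top.

```java
/** Hypothetical helper: smallest tree height H with gamma^H * leafError <= tolerance. */
final class LeafErrorHorizon {

    static int heightFor(double gamma, double leafError, double tolerance) {
        if (gamma <= 0.0 || gamma >= 1.0) {
            throw new IllegalArgumentException("gamma must be in (0, 1)");
        }
        if (tolerance <= 0.0) {
            throw new IllegalArgumentException("tolerance must be positive");
        }
        int h = 0;
        double remaining = leafError;
        // Each level of exact backups contracts the leaf error by gamma.
        while (remaining > tolerance) {
            remaining *= gamma;
            h++;
        }
        return h;
    }
}
```

For gamma = 0.9 and a tolerance of 0.1, a leaf error of 10 needs H = 44, whereas a heuristic accurate to within 1 needs only H = 22.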
This class works with Options, but including options necessarily *increases* the computational complexity, so they are not recommended.
This class requires a HashableStateFactory.
This class can optionally be set to not use sampling and instead use the full Bellman update, which results in the exact finite horizon Q-values being computed. However, this should only be done when the number of possible state transitions is small and when the full model for the domain is defined (that is, the model implements FullModel). To set this class to compute the exact finite horizon value function, use the setComputeExactValueFunction(boolean) method. Note that you cannot use Options when using the full Bellman update.
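For comparison, the full Bellman update replaces the average over C samples with an expectation over every possible outcome. The sketch below computes the exact finite horizon value recursively, assuming a hypothetical FullModelSketch interface that lists (next state, probability, reward) transitions; it stands in for a model that implements FullModel and is not BURLAP's API.

```java
import java.util.List;

/** Exact finite horizon values via full Bellman backups (no sampling). */
final class ExactFiniteHorizon {

    /** One possible outcome of an action: next state, its probability, and the reward. */
    static final class Transition {
        final int nextState;
        final double probability;
        final double reward;

        Transition(int nextState, double probability, double reward) {
            this.nextState = nextState;
            this.probability = probability;
            this.reward = reward;
        }
    }

    /** Hypothetical full model: enumerates every transition of taking action a in state s. */
    interface FullModelSketch {
        List<Transition> transitions(int s, int a);
    }

    /** V_h(s) = max_a Q_h(s, a), with V_0(s) = 0 at the leaves. */
    static double value(FullModelSketch model, int numActions, double gamma, int s, int height) {
        if (height == 0) {
            return 0.0;
        }
        double best = Double.NEGATIVE_INFINITY;
        for (int a = 0; a < numActions; a++) {
            best = Math.max(best, qValue(model, numActions, gamma, s, a, height));
        }
        return best;
    }

    /** Q_h(s, a) = sum over s' of P(s' | s, a) * (r + gamma * V_{h-1}(s')). */
    static double qValue(FullModelSketch model, int numActions, double gamma, int s, int a, int height) {
        double q = 0.0;
        for (Transition t : model.transitions(s, a)) {
            q += t.probability * (t.reward + gamma * value(model, numActions, gamma, t.nextState, height - 1));
        }
        return q;
    }
}
```

Because every transition is expanded at every level, the cost grows with the number of outcomes per action rather than with C, which is why this mode only suits domains with few possible transitions.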
1. Kearns, Michael, Yishay Mansour, and Andrew Y. Ng. "A sparse sampling algorithm for near-optimal planning in large Markov decision processes." Machine Learning 49.2-3 (2002): 193-208.
Modifier and Type | Class and Description |
---|---|
static class | SparseSampling.HashedHeightState: Tuple for a state and its height in a tree that can be hashed for quick retrieval. |
class | SparseSampling.StateNode: A class for state nodes. |
Nested classes/interfaces inherited from interface QProvider: QProvider.Helper
Modifier and Type | Field and Description |
---|---|
protected int | c: The number of transition dynamics samples (for the root if depth-variable C is used). |
protected boolean | computeExactValueFunction: Whether the exact finite horizon value function is computed or sparse sampling is used to estimate it. |
protected boolean | forgetPreviousPlanResults: Whether previous planning results should be forgotten or reused; default is reused (false). |
protected int | h: The height of the tree. |
protected java.util.Map<SparseSampling.HashedHeightState,SparseSampling.StateNode> | nodesByHeight: The tree nodes indexed by state and height. |
protected int | numUpdates: The total number of pseudo-Bellman updates. |
protected DPOperator | operator: The operator used for backups. |
protected java.util.Map<HashableState,java.util.List<QValue>> | rootLevelQValues: The root state node Q-values that have been estimated by previous planning calls. |
protected boolean | useVariableC: Whether the number of transition dynamics samples should scale with the depth of the node. |
protected ValueFunction | vinit: The state value used for leaf nodes; default is zero. |
Fields inherited from class MDPSolver: actionTypes, debugCode, domain, gamma, hashingFactory, model, usingOptionModel
Constructor and Description |
---|
SparseSampling(SADomain domain, double gamma, HashableStateFactory hashingFactory, int h, int c): Initializes. |
Modifier and Type | Method and Description |
---|---|
boolean | computesExactValueFunction(): Returns whether this solver computes the exact finite horizon value function (by using the full transition dynamics) or estimates the value function with sampling. |
int | getC(): Returns the number of state transition samples. |
protected int | getCAtHeight(int height): Returns the value of C for a node at the given height (height from a leaf node). |
int | getDebugCode(): Returns the debug code used for logging plan results with DPrint. |
int | getH(): Returns the height of the tree. |
int | getNumberOfStateNodesCreated(): Returns the total number of state nodes that have been created. |
int | getNumberOfValueEsitmates(): Returns the total number of state value estimates performed since the last resetSolver() call. |
DPOperator | getOperator() |
protected SparseSampling.StateNode | getStateNode(State s, int height): Returns the state node for the given state at the given height in the tree, creating and indexing it first if it does not exist. |
protected static double | logbase(double base, double x): Returns the logarithm of x in the given base. |
GreedyQPolicy | planFromState(State initialState): Plans from the input state and then returns a GreedyQPolicy that greedily selects the action with the highest Q-value and breaks ties uniformly randomly. |
double | qValue(State s, Action a): Returns the QValue for the given state-action pair. |
java.util.List<QValue> | qValues(State s): Returns a List of QValue objects for every permissible action for the given input state. |
void | resetSolver(): Resets all solver results so that a solver can be restarted fresh as if it had never solved the MDP. |
void | setC(int c): Sets the number of state transition samples used. |
void | setComputeExactValueFunction(boolean computeExactValueFunction): Sets whether this solver computes the exact finite horizon value function (using the full transition dynamics) or uses sampling to estimate it. |
void | setDebugCode(int debugCode): Sets the debug code used for logging plan results with DPrint. |
void | setForgetPreviousPlanResults(boolean forgetPreviousPlanResults): Sets whether previous planning results should be forgotten or reused in subsequent planning. |
void | setH(int h): Sets the height of the tree. |
void | setHAndCByMDPError(double rmax, double epsilon, int numActions): Sets the height and number of transition dynamics samples in a way that ensures epsilon optimality. |
void | setOperator(DPOperator operator) |
void | setUseVariableCSize(boolean useVariableC): Sets whether the number of state transition samples (C) should vary with the depth of the node. |
void | setValueForLeafNodes(ValueFunction vinit): Sets the ValueFunction object to use for setting the value of leaf nodes. |
double | value(State s): Returns the value function evaluation of the given state. |
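The planFromState and Q-value query behavior described above (plan when asked about a state without cached Q-values, then act greedily with uniform random tie-breaking) can be sketched as below. LazyGreedyPolicy and its planner function, which returns Q-values indexed by action, are hypothetical stand-ins for SparseSampling and GreedyQPolicy, not their actual code.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.function.Function;

/** Hypothetical greedy policy that plans lazily for states it has not seen. */
final class LazyGreedyPolicy<S> {
    private final Function<S, double[]> planner; // returns Q-values indexed by action
    private final Map<S, double[]> qCache = new HashMap<>();
    private final Random rng;
    private int planCount = 0;

    LazyGreedyPolicy(Function<S, double[]> planner, long seed) {
        this.planner = planner;
        this.rng = new Random(seed);
    }

    /** Plans only on a cache miss, then picks uniformly among the highest-Q actions. */
    int action(S state) {
        double[] q = qCache.computeIfAbsent(state, unseen -> {
            planCount++;
            return planner.apply(unseen);
        });
        double best = Double.NEGATIVE_INFINITY;
        List<Integer> ties = new ArrayList<>();
        for (int a = 0; a < q.length; a++) {
            if (q[a] > best) {
                best = q[a];
                ties.clear();
                ties.add(a);
            } else if (q[a] == best) {
                ties.add(a);
            }
        }
        return ties.get(rng.nextInt(ties.size()));
    }

    int planningCalls() {
        return planCount;
    }
}
```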
Methods inherited from class MDPSolver: addActionType, applicableActions, getActionTypes, getDomain, getGamma, getHashingFactory, getModel, setActionTypes, setDomain, setGamma, setHashingFactory, setModel, solverInit, stateHash, toggleDebugPrinting
Methods inherited from class java.lang.Object: clone, equals, finalize, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait
Methods inherited from interface MDPSolverInterface: addActionType, getActionTypes, getDomain, getGamma, getHashingFactory, getModel, setActionTypes, setDomain, setGamma, setHashingFactory, setModel, solverInit, toggleDebugPrinting
protected int h
protected int c
protected boolean useVariableC
protected boolean forgetPreviousPlanResults
protected ValueFunction vinit
protected boolean computeExactValueFunction
protected java.util.Map<SparseSampling.HashedHeightState,SparseSampling.StateNode> nodesByHeight
protected java.util.Map<HashableState,java.util.List<QValue>> rootLevelQValues
protected int numUpdates
protected DPOperator operator
public SparseSampling(SADomain domain, double gamma, HashableStateFactory hashingFactory, int h, int c)
Initializes. The h and c parameters can be set to values that ensure epsilon optimality by using the setHAndCByMDPError(double, double, int) method, but in general this will result in very large values that will be intractable. If you set c = -1, then the full transition dynamics will be used. You should only use the full transition dynamics if the number of possible transitions from each state is small and if the model implements FullModel.
domain - the planning domain
gamma - the discount factor
hashingFactory - the state hashing factory for matching generated states with their state nodes.
h - the height of the tree
c - the number of transition dynamics samples used. If set to -1, then the full transition dynamics are used.
public void setHAndCByMDPError(double rmax, double epsilon, int numActions)
rmax - the maximum reward value of the MDP
epsilon - the epsilon optimality (amount that the estimated value function may diverge from the true optimal)
numActions - the maximum number of actions that could be applied from a state
public void setUseVariableCSize(boolean useVariableC)
useVariableC - if true, then depth-variable C will be used; if false, all state nodes use the same number of samples.
public void setC(int c)
c - the number of state transition samples used. If -1, then the full transition dynamics are used.
public void setH(int h)
h - the height of the tree.
public int getC()
public int getH()
public void setComputeExactValueFunction(boolean computeExactValueFunction)
computeExactValueFunction - if true, the exact finite horizon value function is computed; if false, then sampling is used.
public boolean computesExactValueFunction()
public void setForgetPreviousPlanResults(boolean forgetPreviousPlanResults)
forgetPreviousPlanResults - if true, then previous planning results will be forgotten; if false, they will be remembered and reused in subsequent planning.
public void setValueForLeafNodes(ValueFunction vinit)
Sets the ValueFunction object to use for setting the value of leaf nodes.
vinit - the ValueFunction object to use for setting the value of leaf nodes.
public DPOperator getOperator()
public void setOperator(DPOperator operator)
public int getDebugCode()
Returns the debug code used for logging plan results with DPrint.
Specified by: getDebugCode in interface MDPSolverInterface
Overrides: getDebugCode in class MDPSolver
public void setDebugCode(int debugCode)
Sets the debug code used for logging plan results with DPrint.
Specified by: setDebugCode in interface MDPSolverInterface
Overrides: setDebugCode in class MDPSolver
debugCode - the debugCode to use.
public int getNumberOfValueEsitmates()
Returns the total number of state value estimates performed since the last resetSolver() call.
public int getNumberOfStateNodesCreated()
public GreedyQPolicy planFromState(State initialState)
Plans from the input state and then returns a GreedyQPolicy that greedily selects the action with the highest Q-value and breaks ties uniformly randomly.
Specified by: planFromState in interface Planner
initialState - the initial state of the planning problem
Returns: a GreedyQPolicy.
public void resetSolver()
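Description copied from interface: MDPSolverInterface
This method resets all solver results so that a solver can be restarted fresh as if it had never solved the MDP.
Specified by: resetSolver in interface MDPSolverInterface
Specified by: resetSolver in class MDPSolver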
public java.util.List<QValue> qValues(State s)
Description copied from interface: QProvider
Returns a List of QValue objects for every permissible action for the given input state.
public double qValue(State s, Action a)
Description copied from interface: QFunction
Returns the QValue for the given state-action pair.
public double value(State s)
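Description copied from interface: ValueFunction
Returns the value function evaluation of the given state.
Specified by: value in interface ValueFunction
s - the state to evaluate.
protected int getCAtHeight(int height)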
height - the height from a leaf node.
protected SparseSampling.StateNode getStateNode(State s, int height)
s - the state
height - the height (distance from leaf node) of the node.
protected static double logbase(double base, double x)
base - the log base
x - the input of the log