public class DifferentiableSparseSampling extends OOMDPPlanner implements QGradientPlanner
MLIRL) [2]. Additionally, the value of the leaf nodes of this planner may also be parametrized using a DifferentiableVInit object and learned with MLIRL, enabling a nice separation of shaping features/rewards and the learned (or known) reward function.
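As a rough illustration of the role a parametrized leaf value plays, the sketch below (LinearLeafValueSketch is a hypothetical class, not BURLAP API) shows a leaf value that is linear in state features, V(s) = phi(s) . theta, together with its parameter gradient, which is what an MLIRL-style learner needs to tune leaf values:

```java
// Illustrative sketch only: LinearLeafValueSketch is not part of BURLAP.
// It mimics the role a DifferentiableVInit plays: supply a leaf value
// V(s) = phi(s) . theta together with its gradient w.r.t. theta, so a
// gradient-based learner can adjust the leaf values.
public class LinearLeafValueSketch {

    // Value of a leaf state with feature vector phi under parameters theta.
    static double value(double[] phi, double[] theta) {
        double v = 0.0;
        for (int i = 0; i < phi.length; i++) v += phi[i] * theta[i];
        return v;
    }

    // Gradient of the linear value w.r.t. theta is simply phi(s).
    static double[] gradient(double[] phi) {
        return phi.clone();
    }

    public static void main(String[] args) {
        double[] phi = {1.0, 2.0};
        double[] theta = {3.0, 4.0};
        System.out.println(value(phi, theta)); // 1*3 + 2*4 = 11.0
    }
}
```

Because the value is linear, its gradient does not depend on theta, which keeps the leaf-value learning step cheap.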
Modifier and Type | Class and Description
---|---
class | DifferentiableSparseSampling.DiffStateNode: A class for value differentiable state nodes.
protected static class | DifferentiableSparseSampling.QAndQGradient: A tuple for storing Q-values and their gradients.
protected static class | DifferentiableSparseSampling.VAndVGradient: A tuple for storing a state value and its gradient.
Nested classes/interfaces inherited from interface QComputablePlanner: QComputablePlanner.QComputablePlannerHelper
Modifier and Type | Field and Description
---|---
protected double | boltzBeta: The Boltzmann beta parameter that defines the differentiable Bellman equation.
protected int | c: The number of transition dynamics samples (for the root, if depth-variable C is used).
protected boolean | forgetPreviousPlanResults: Whether previous planning results should be forgotten or reused; default is reused (false).
protected int | h: The height of the tree.
protected java.util.Map<SparseSampling.HashedHeightState,DifferentiableSparseSampling.DiffStateNode> | nodesByHeight: The tree nodes indexed by state and height.
protected int | numUpdates: The total number of pseudo-Bellman updates.
protected int | rfDim: The dimensionality of the differentiable reward function.
protected java.util.Map<StateHashTuple,DifferentiableSparseSampling.QAndQGradient> | rootLevelQValues: The root state node Q-values that have been estimated by previous planning calls.
protected boolean | useVariableC: Whether the number of transition dynamics samples should scale with the depth of the node.
protected DifferentiableVInit | vinit: The state value used for leaf nodes; default is zero.
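The boltzBeta field parameterizes the differentiable Bellman equation, which replaces the hard max over Q-values with a Boltzmann (log-sum-exp) soft maximum. A minimal, self-contained sketch of that backup (SoftBackupSketch and softValue are illustrative names, not BURLAP API):

```java
// Sketch of the differentiable (Boltzmann/softmax) backup that boltzBeta controls.
// softValue computes V(s) = (1/beta) * log( sum_a exp(beta * Q(s,a)) ),
// a smooth stand-in for max_a Q(s,a).
public class SoftBackupSketch {

    static double softValue(double[] qs, double beta) {
        // Subtract the max for numerical stability before exponentiating.
        double maxQ = Double.NEGATIVE_INFINITY;
        for (double q : qs) maxQ = Math.max(maxQ, q);
        double sum = 0.0;
        for (double q : qs) sum += Math.exp(beta * (q - maxQ));
        return maxQ + Math.log(sum) / beta;
    }

    public static void main(String[] args) {
        double[] qs = {1.0, 2.0, 3.0};
        // Large beta: the soft value approaches the hard max of 3.0.
        System.out.println(softValue(qs, 100.0));
        // Small beta: the backup is softer (and exceeds the max).
        System.out.println(softValue(qs, 0.5));
    }
}
```

As beta grows the soft value approaches max_a Q(s,a); as beta shrinks toward zero the backup becomes softer, which is what makes the Bellman equation differentiable with respect to reward parameters.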
Fields inherited from class OOMDPPlanner: actions, containsParameterizedActions, debugCode, domain, gamma, hashingFactory, mapToStateIndex, rf, tf
Constructor and Description
---
DifferentiableSparseSampling(Domain domain, DifferentiableRF rf, TerminalFunction tf, double gamma, StateHashFactory hashingFactory, int h, int c, double boltzBeta): Initializes.
Modifier and Type | Method and Description
---|---
java.util.List<QGradientTuple> | getAllQGradients(State s): Returns the list of Q-value gradients (returned as QGradientTuple objects) for each action permissible in the given state.
int | getC(): Returns the number of state transition samples used.
protected int | getCAtHeight(int height): Returns the value of C for a node at the given height (height from a leaf node).
int | getDebugCode(): Returns the debug code used for logging plan results with DPrint.
int | getH(): Returns the height of the tree.
int | getNumberOfValueEsitmates(): Returns the total number of state value estimates performed since the last resetPlannerResults() call.
QValue | getQ(State s, AbstractGroundedAction a): Returns the QValue for the given state-action pair.
QGradientTuple | getQGradient(State s, GroundedAction a): Returns the Q-value gradient (QGradientTuple) for the given state and action.
java.util.List<QValue> | getQs(State s): Returns a List of QValue objects for every permissible action for the given input state.
protected DifferentiableSparseSampling.DiffStateNode | getStateNode(State s, int height): Either returns, or creates, indexes, and returns, the state node for the given state at the given height in the tree.
void | planFromState(State initialState): Causes the planner to begin planning from the specified initial state.
void | resetPlannerResults(): Resets all planner results so that planning can be started fresh with a call to OOMDPPlanner.planFromState(State), as if no planning had ever been performed before.
void | setBoltzmannBetaParameter(double beta): Sets this planner's Boltzmann beta parameter used to compute gradients.
void | setC(int c): Sets the number of state transition samples used.
void | setDebugCode(int debugCode): Sets the debug code used for logging plan results with DPrint.
void | setForgetPreviousPlanResults(boolean forgetPreviousPlanResults): Sets whether previous planning results should be forgotten or reused in subsequent planning.
void | setH(int h): Sets the height of the tree.
void | setUseVariableCSize(boolean useVariableC): Sets whether the number of state transition samples (C) should vary with the depth of the node.
void | setValueForLeafNodes(ValueFunctionInitialization vinit): Sets the ValueFunctionInitialization object used to set the value of leaf nodes.
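getAllQGradients and getQGradient return gradients of Q-values with respect to the differentiable reward function's parameters. Under the soft (Boltzmann) backup, the gradient of the soft state value is the Boltzmann-policy-weighted combination of the per-action Q-gradients. The sketch below shows that combination step; class and method names are illustrative, not BURLAP API:

```java
// Illustrative sketch, not BURLAP API. For the soft value
// V = (1/beta) * log(sum_a exp(beta * Q_a)), the gradient w.r.t. the reward
// parameters theta is dV/dtheta = sum_a pi(a) * dQ_a/dtheta, where pi is the
// Boltzmann (softmax) distribution over the Q-values.
public class SoftGradientSketch {

    // Boltzmann (softmax) distribution over Q-values with parameter beta.
    static double[] boltzmann(double[] qs, double beta) {
        double maxQ = Double.NEGATIVE_INFINITY;
        for (double q : qs) maxQ = Math.max(maxQ, q);
        double[] pi = new double[qs.length];
        double sum = 0.0;
        for (int a = 0; a < qs.length; a++) {
            pi[a] = Math.exp(beta * (qs[a] - maxQ)); // stabilized exponent
            sum += pi[a];
        }
        for (int a = 0; a < qs.length; a++) pi[a] /= sum;
        return pi;
    }

    // Combine per-action Q-gradients (qGrads[a][d], with d indexing the
    // reward-parameter dimensions, i.e. rfDim) into the soft value gradient.
    static double[] valueGradient(double[] qs, double[][] qGrads, double beta) {
        double[] pi = boltzmann(qs, beta);
        double[] grad = new double[qGrads[0].length];
        for (int a = 0; a < qs.length; a++)
            for (int d = 0; d < grad.length; d++)
                grad[d] += pi[a] * qGrads[a][d];
        return grad;
    }
}
```

With equal Q-values the Boltzmann weights are uniform, so the value gradient is simply the average of the action gradients.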
Methods inherited from class OOMDPPlanner: addNonDomainReferencedAction, getActions, getAllGroundedActions, getDomain, getGamma, getHashingFactory, getRf, getRF, getTf, getTF, plannerInit, setActions, setDomain, setGamma, setRf, setTf, stateHash, toggleDebugPrinting, translateAction
protected int h
protected int c
protected boolean useVariableC
protected boolean forgetPreviousPlanResults
protected DifferentiableVInit vinit
protected java.util.Map<SparseSampling.HashedHeightState,DifferentiableSparseSampling.DiffStateNode> nodesByHeight
protected java.util.Map<StateHashTuple,DifferentiableSparseSampling.QAndQGradient> rootLevelQValues
protected double boltzBeta
protected int rfDim
protected int numUpdates
public DifferentiableSparseSampling(Domain domain, DifferentiableRF rf, TerminalFunction tf, double gamma, StateHashFactory hashingFactory, int h, int c, double boltzBeta)
domain - the problem domain
rf - the differentiable reward function
tf - the terminal function
gamma - the discount factor
hashingFactory - the hashing factory used to compare state equality
h - the planning horizon
c - how many samples from the transition dynamics to use; set to -1 to use the full (unsampled) transition dynamics
boltzBeta - the Boltzmann beta parameter for the differentiable Boltzmann (softmax) backup equation. The larger the value, the more deterministic the backup; the closer to zero, the softer.

public void setUseVariableCSize(boolean useVariableC)
useVariableC - if true, then depth-variable C will be used; if false, all state nodes use the same number of samples.

public void setC(int c)
c - the number of state transition samples used.

public void setH(int h)
h - the height of the tree.

public int getC()
public int getH()
public void setForgetPreviousPlanResults(boolean forgetPreviousPlanResults)
forgetPreviousPlanResults - if true, then previous planning results will be forgotten; if false, they will be remembered and reused in subsequent planning.

public void setValueForLeafNodes(ValueFunctionInitialization vinit)
Sets the ValueFunctionInitialization object to use for setting the value of leaf nodes.
vinit - the ValueFunctionInitialization object to use for setting the value of leaf nodes.

public int getDebugCode()
Returns the debug code used for logging plan results with DPrint.
Overrides: getDebugCode in class OOMDPPlanner
Returns: the debug code used for logging plan results with DPrint.

public void setDebugCode(int debugCode)
Sets the debug code used for logging plan results with DPrint.
Overrides: setDebugCode in class OOMDPPlanner
debugCode - the debug code to use.

public int getNumberOfValueEsitmates()
Returns the total number of state value estimates performed since the last resetPlannerResults() call.
Returns: the total number of state value estimates performed since the last resetPlannerResults() call.

public void setBoltzmannBetaParameter(double beta)
Description copied from interface: QGradientPlanner
Sets this planner's Boltzmann beta parameter used to compute gradients.
Specified by: setBoltzmannBetaParameter in interface QGradientPlanner
beta - the value to which this planner's Boltzmann beta parameter will be set

public java.util.List<QValue> getQs(State s)
Description copied from interface: QComputablePlanner
Returns a List of QValue objects for every permissible action for the given input state.
Specified by: getQs in interface QComputablePlanner
s - the state for which Q-values are to be returned.
Returns: a List of QValue objects for every permissible action for the given input state.

public QValue getQ(State s, AbstractGroundedAction a)
Description copied from interface: QComputablePlanner
Returns the QValue for the given state-action pair.
Specified by: getQ in interface QComputablePlanner
s - the input state
a - the input action
Returns: the QValue for the given state-action pair.

public java.util.List<QGradientTuple> getAllQGradients(State s)
Description copied from interface: QGradientPlanner
Returns the list of Q-value gradients (returned as QGradientTuple objects) for each action permissible in the given state.
Specified by: getAllQGradients in interface QGradientPlanner
s - the state for which Q-value gradients are to be returned.

public QGradientTuple getQGradient(State s, GroundedAction a)
Description copied from interface: QGradientPlanner
Returns the Q-value gradient (QGradientTuple) for the given state and action.
Specified by: getQGradient in interface QGradientPlanner
s - the state for which the Q-value gradient is to be returned
a - the action for which the Q-value gradient is to be returned

public void planFromState(State initialState)
Description copied from class: OOMDPPlanner
This method will cause the planner to begin planning from the specified initial state.
Specified by: planFromState in class OOMDPPlanner
initialState - the initial state of the planning problem

public void resetPlannerResults()
Description copied from class: OOMDPPlanner
Use this method to reset all planner results so that planning can be started fresh with a call to OOMDPPlanner.planFromState(State), as if no planning had ever been performed before. Specifically, data produced from calls to OOMDPPlanner.planFromState(State) will be cleared, but all other planner settings will remain the same. This is useful if the reward function or transition dynamics have changed, requiring new results to be computed. If there were other objects this planner was provided that may have changed and need to be reset, you will need to reset them yourself. For instance, if you told a planner to follow a policy whose temperature parameter decreases with time, you will need to reset the policy's temperature yourself.
Overrides: resetPlannerResults in class OOMDPPlanner
protected int getCAtHeight(int height)
Returns the value of C for a node at the given height (height from a leaf node).
height - the height from a leaf node.

protected DifferentiableSparseSampling.DiffStateNode getStateNode(State s, int height)
s - the state
height - the height (distance from leaf node) of the node.
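When useVariableC is enabled, getCAtHeight returns a depth-dependent sample count rather than a constant C. The exact schedule this class uses is not reproduced on this page; as a hedged sketch, the sparse sampling analysis of Kearns, Mansour, and Ng lets the per-node sample count shrink geometrically with depth, for example C_i = C * gamma^(2(h - i)) where i is the height above the leaves. The cAtHeight function below is an illustrative free function under that assumption, not the class's actual implementation:

```java
// Hedged sketch of a depth-variable sample count. Deeper nodes (smaller
// height) contribute less to the root value under discounting, so they can
// be estimated with fewer transition samples.
public class VariableCSketch {

    static int cAtHeight(int c, int h, int height, double gamma) {
        // Full C at the root (height == h); geometrically fewer samples
        // as height decreases, but never fewer than one sample.
        return Math.max(1, (int) (c * Math.pow(gamma, 2.0 * (h - height))));
    }

    public static void main(String[] args) {
        // With c = 10, h = 3, gamma = 0.9: C shrinks as we descend the tree.
        for (int height = 3; height >= 1; height--)
            System.out.println("height " + height + ": C = " + cAtHeight(10, 3, height, 0.9));
    }
}
```

Shrinking C with depth reduces the tree's branching cost while keeping most of the sampling budget where it matters, near the root.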