public class DifferentiableSparseSampling extends MDPSolver implements QGradientPlanner, Planner
MLIRL) [2]. Additionally, the value of the leaf nodes of this planner's tree may also be parametrized using a DifferentiableVInit object and learned with MLIRL, enabling a clean separation of shaping features/rewards and the learned (or known) reward function.
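Because the leaf values are differentiable, their parameters can receive gradients alongside the reward parameters. The sketch below assumes a linear leaf-value model V_leaf(s) = θ·φ(s) over shaping features φ(s). The class LinearLeafValue is hypothetical and does not reproduce the actual DifferentiableVInit interface.

```java
// Hypothetical linear leaf-value model; it does not reproduce BURLAP's DifferentiableVInit API.
final class LinearLeafValue {
    private final double[] theta; // learnable leaf-value parameters

    LinearLeafValue(double[] theta) {
        this.theta = theta.clone();
    }

    /** V_leaf(s) = theta . phi(s), where phi(s) is a shaping feature vector for state s. */
    double value(double[] phi) {
        double v = 0.0;
        for (int j = 0; j < theta.length; j++) {
            v += theta[j] * phi[j];
        }
        return v;
    }

    /** For a linear model, dV_leaf/dtheta_j = phi_j(s), so the gradient is the feature vector. */
    double[] gradient(double[] phi) {
        return phi.clone();
    }
}
```

In a tree, a leaf gradient of this form would then be propagated upward through the backups in the same way as the reward gradients.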
Modifier and Type | Class and Description |
---|---|
class | DifferentiableSparseSampling.DiffStateNode: A class for value-differentiable state nodes. |
protected static class | DifferentiableSparseSampling.QAndQGradient: A tuple for storing Q-values and their gradients. |
protected static class | DifferentiableSparseSampling.VAndVGradient: A tuple for storing a state value and its gradient. |
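QAndQGradient and VAndVGradient pair a value with its gradient because each backup has to propagate both. A minimal sketch, assuming the Boltzmann backup takes the standard softmax form V(s) = Σ_a π(a|s) Q(s,a) with π(a|s) ∝ exp(β Q(s,a)). Applying the product rule and the softmax derivative gives ∇V = Σ_a π(a|s) [∇Q(s,a) + β Q(s,a) (∇Q(s,a) - Σ_b π(b|s) ∇Q(s,b))]. The class BoltzmannBackup and its dense-array gradients are illustrative assumptions, not BURLAP's implementation.

```java
// Hypothetical helper, not BURLAP's implementation: Boltzmann backup and its gradient.
final class BoltzmannBackup {
    private BoltzmannBackup() {}

    /** pi(a) proportional to exp(beta * Q(a)); the max is subtracted for numerical stability. */
    static double[] weights(double[] q, double beta) {
        double max = Double.NEGATIVE_INFINITY;
        for (double v : q) {
            max = Math.max(max, v);
        }
        double[] w = new double[q.length];
        double z = 0.0;
        for (int a = 0; a < q.length; a++) {
            w[a] = Math.exp(beta * (q[a] - max));
            z += w[a];
        }
        for (int a = 0; a < q.length; a++) {
            w[a] /= z;
        }
        return w;
    }

    /** Soft value: V = sum_a pi(a) * Q(a). */
    static double value(double[] q, double beta) {
        double[] w = weights(q, beta);
        double v = 0.0;
        for (int a = 0; a < q.length; a++) {
            v += w[a] * q[a];
        }
        return v;
    }

    /**
     * Gradient of V given per-action Q-gradients qGrad[a][j] = dQ(a)/dtheta_j:
     * dV/dtheta_j = sum_a pi(a) * (dQ(a)_j + beta * Q(a) * (dQ(a)_j - sum_b pi(b) * dQ(b)_j)).
     */
    static double[] valueGradient(double[] q, double[][] qGrad, double beta) {
        double[] w = weights(q, beta);
        int dim = qGrad[0].length;
        double[] mean = new double[dim]; // expected Q-gradient under pi
        for (int b = 0; b < q.length; b++) {
            for (int j = 0; j < dim; j++) {
                mean[j] += w[b] * qGrad[b][j];
            }
        }
        double[] g = new double[dim];
        for (int a = 0; a < q.length; a++) {
            for (int j = 0; j < dim; j++) {
                g[j] += w[a] * (qGrad[a][j] + beta * q[a] * (qGrad[a][j] - mean[j]));
            }
        }
        return g;
    }
}
```

With β = 0 the backup is a plain average of the Q-values, and as β grows it approaches the hard max used by ordinary sparse sampling.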
Nested classes/interfaces inherited from interface QFunction: QFunction.QFunctionHelper
Modifier and Type | Field and Description |
---|---|
protected double | boltzBeta: The Boltzmann beta parameter that defines the differentiable Bellman equation. |
protected int | c: The number of transition dynamics samples (for the root, if depth-variable C is used). |
protected boolean | forgetPreviousPlanResults: Whether previous planning results should be forgotten or reused; the default is reused (false). |
protected int | h: The height of the tree. |
protected java.util.Map<SparseSampling.HashedHeightState,DifferentiableSparseSampling.DiffStateNode> | nodesByHeight: The tree nodes indexed by state and height. |
protected int | numUpdates: The total number of pseudo-Bellman updates. |
protected int | rfDim: The dimensionality of the differentiable reward function. |
protected java.util.Map<HashableState,DifferentiableSparseSampling.QAndQGradient> | rootLevelQValues: The root state node Q-values that have been estimated by previous planning calls. |
protected boolean | useVariableC: Whether the number of transition dynamics samples should scale with the depth of the node. |
protected DifferentiableVInit | vinit: The state value used for leaf nodes; the default is zero. |
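The c and useVariableC fields interact through getCAtHeight. One plausible depth-variable rule, assumed here rather than taken from BURLAP's source, shrinks the sample count geometrically with depth from the root, C_d = max(1, floor(c · γ^(2d))), so the root keeps the full c. VariableC is a hypothetical helper, and passing c = -1 (full dynamics) through unchanged is a design choice of this sketch.

```java
// Hypothetical helper; BURLAP's actual depth-variable rule may differ.
final class VariableC {
    private VariableC() {}

    /**
     * Sample count for a node at the given height (height 0 = leaf, height h = root).
     * Assumed rule: C_d = max(1, floor(c * gamma^(2d))) with d = h - height.
     */
    static int cAtHeight(int c, int h, int height, double gamma, boolean useVariableC) {
        if (!useVariableC || c < 0) {
            return c; // fixed C, or c = -1 meaning full (unsampled) dynamics
        }
        int depth = h - height; // convert height-from-leaf to depth-from-root
        int vc = (int) (c * Math.pow(gamma, 2 * depth));
        return Math.max(vc, 1); // always keep at least one sample
    }
}
```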
Fields inherited from class MDPSolver: actions, debugCode, domain, gamma, hashingFactory, mapToStateIndex, rf, tf
Constructor and Description |
---|
DifferentiableSparseSampling(Domain domain, DifferentiableRF rf, TerminalFunction tf, double gamma, HashableStateFactory hashingFactory, int h, int c, double boltzBeta): Initializes. |
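The constructor's c parameter chooses between averaging c sampled transitions and using the full transition dynamics (c = -1). The difference for a single Q-value backup can be sketched as follows, assuming the next-state outcomes are given as explicit probability, reward, and value arrays. QBackup is a hypothetical helper, not a BURLAP class.

```java
import java.util.Random;

// Hypothetical helper, not part of BURLAP: one Q-value backup for a single action.
final class QBackup {
    private QBackup() {}

    /**
     * Estimates Q(s,a) = E[R(s,a,s') + gamma * V(s')] over next-state outcomes i with
     * probability probs[i], reward rewards[i], and next-state value vNext[i].
     * A negative c (e.g. -1) uses the exact expectation; c > 0 averages c samples.
     */
    static double q(double[] probs, double[] rewards, double[] vNext, double gamma, int c, Random rng) {
        if (c == 0) {
            throw new IllegalArgumentException("c must be positive, or -1 for full dynamics");
        }
        double sum = 0.0;
        if (c < 0) {
            for (int i = 0; i < probs.length; i++) {
                sum += probs[i] * (rewards[i] + gamma * vNext[i]);
            }
            return sum;
        }
        for (int k = 0; k < c; k++) {
            int i = sampleIndex(probs, rng);
            sum += rewards[i] + gamma * vNext[i];
        }
        return sum / c;
    }

    /** Inverse-CDF sampling of an outcome index. */
    private static int sampleIndex(double[] probs, Random rng) {
        double u = rng.nextDouble();
        double cumulative = 0.0;
        for (int i = 0; i < probs.length; i++) {
            cumulative += probs[i];
            if (u < cumulative) {
                return i;
            }
        }
        return probs.length - 1; // guards against rounding in the cumulative sum
    }
}
```

The sampled estimate is unbiased but noisy, and its cost does not depend on the number of possible next states, which is what makes sparse sampling usable when the full dynamics are large or unavailable in closed form.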
Modifier and Type | Method and Description |
---|---|
protected java.util.Set<java.lang.Integer> | combinedNonZeroPDParameters(FunctionGradient... gradients) |
java.util.List<QGradientTuple> | getAllQGradients(State s): Returns the list of Q-value gradients (returned as QGradientTuple objects) for each action permissible in the given state. |
int | getC(): Returns the number of state transition samples. |
protected int | getCAtHeight(int height): Returns the value of C for a node at the given height (height from a leaf node). |
int | getDebugCode(): Returns the debug code used for logging plan results with DPrint. |
int | getH(): Returns the height of the tree. |
int | getNumberOfValueEsitmates(): Returns the total number of state value estimates performed since the last resetSolver() call. |
QValue | getQ(State s, AbstractGroundedAction a): Returns the QValue for the given state-action pair. |
QGradientTuple | getQGradient(State s, GroundedAction a): Returns the Q-value gradient (QGradientTuple) for the given state and action. |
java.util.List<QValue> | getQs(State s): Returns a List of QValue objects for every permissible action for the given input state. |
protected DifferentiableSparseSampling.DiffStateNode | getStateNode(State s, int height): Returns the state node for the given state at the given height in the tree, first creating and indexing it if it does not exist. |
BoltzmannQPolicy | planFromState(State initialState): Plans from the input state and returns a BoltzmannQPolicy that follows the Boltzmann parameter used for the Boltzmann value backups in this planner. |
void | resetSolver(): Resets all solver results so that the solver can be restarted fresh, as if it had never solved the MDP. |
void | setBoltzmannBetaParameter(double beta): Sets this planner's Boltzmann beta parameter used to compute gradients. |
void | setC(int c): Sets the number of state transition samples used. |
void | setDebugCode(int debugCode): Sets the debug code used for logging plan results with DPrint. |
void | setForgetPreviousPlanResults(boolean forgetPreviousPlanResults): Sets whether previous planning results should be forgotten or reused in subsequent planning. |
void | setH(int h): Sets the height of the tree. |
void | setUseVariableCSize(boolean useVariableC): Sets whether the number of state transition samples (C) should vary with the depth of the node. |
void | setValueForLeafNodes(ValueFunctionInitialization vinit): Sets the ValueFunctionInitialization object to use for setting the value of leaf nodes. |
double | value(State s): Returns the value function evaluation of the given state. |
Methods inherited from class MDPSolver: addNonDomainReferencedAction, getActions, getAllGroundedActions, getDomain, getGamma, getHashingFactory, getRf, getRF, getTf, getTF, setActions, setDomain, setGamma, setHashingFactory, setRf, setTf, solverInit, stateHash, toggleDebugPrinting, translateAction
Methods inherited from class java.lang.Object: clone, equals, finalize, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait
Methods inherited from interface MDPSolverInterface: addNonDomainReferencedAction, getActions, getDomain, getGamma, getHashingFactory, getRf, getRF, getTf, getTF, setActions, setDomain, setGamma, setHashingFactory, setRf, setTf, solverInit, toggleDebugPrinting
protected int h
protected int c
protected boolean useVariableC
protected boolean forgetPreviousPlanResults
protected DifferentiableVInit vinit
protected java.util.Map<SparseSampling.HashedHeightState,DifferentiableSparseSampling.DiffStateNode> nodesByHeight
protected java.util.Map<HashableState,DifferentiableSparseSampling.QAndQGradient> rootLevelQValues
protected double boltzBeta
protected int rfDim
protected int numUpdates
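The nodesByHeight map means a node for a given state is built once per height and then reused, instead of being rebuilt every time that state is sampled again. A minimal sketch of that memoization, with a hypothetical (state, height) key standing in for SparseSampling.HashedHeightState. NodeCache and its clear() method (loosely mirroring what forgetting previous plan results would require) are assumptions for illustration.

```java
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import java.util.function.BiFunction;

// Hypothetical sketch of memoizing tree nodes by (state, height); not BURLAP's classes.
final class NodeCache<S, N> {

    /** Stand-in for a (hashed state, height) key such as SparseSampling.HashedHeightState. */
    private static final class Key<T> {
        final T state;
        final int height;

        Key(T state, int height) {
            this.state = state;
            this.height = height;
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) return true;
            if (!(o instanceof Key)) return false;
            Key<?> other = (Key<?>) o;
            return height == other.height && Objects.equals(state, other.state);
        }

        @Override
        public int hashCode() {
            return Objects.hash(state, height);
        }
    }

    private final Map<Key<S>, N> nodesByHeight = new HashMap<>();
    private final BiFunction<S, Integer, N> nodeFactory;
    private int created = 0;

    NodeCache(BiFunction<S, Integer, N> nodeFactory) {
        this.nodeFactory = nodeFactory;
    }

    /** Returns the cached node for (s, height), creating and indexing it on first request. */
    N getStateNode(S s, int height) {
        return nodesByHeight.computeIfAbsent(new Key<>(s, height), k -> {
            created++;
            return nodeFactory.apply(k.state, k.height);
        });
    }

    /** Drops all cached nodes, e.g. when previous planning results should be forgotten. */
    void clear() {
        nodesByHeight.clear();
    }

    int createdCount() {
        return created;
    }
}
```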
public DifferentiableSparseSampling(Domain domain, DifferentiableRF rf, TerminalFunction tf, double gamma, HashableStateFactory hashingFactory, int h, int c, double boltzBeta)
domain - the problem domain
rf - the differentiable reward function
tf - the terminal function
gamma - the discount factor
hashingFactory - the hashing factory used to compare state equality
h - the planning horizon
c - how many samples from the transition dynamics to use. Set to -1 to use the full (unsampled) transition dynamics.
boltzBeta - the Boltzmann beta parameter for the differentiable Boltzmann (softmax) backup equation. The larger the value, the more deterministic the backup; the smaller the value, the softer.
public void setUseVariableCSize(boolean useVariableC)
useVariableC - if true, then depth-variable C will be used; if false, all state nodes use the same number of samples.
public void setC(int c)
c - the number of state transition samples used.
public void setH(int h)
h - the height of the tree.
public int getC()
public int getH()
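The tree height h and the sample count c control planning cost. Ignoring the reuse of identical (state, height) nodes, a tree in which every node expands each of |A| actions with c sampled successors has at most Σ_{d=0}^{h} (|A|·c)^d state nodes. The hypothetical helper below computes that bound and fails loudly on overflow.

```java
// Hypothetical helper: worst-case size of a sparse sampling tree without node reuse.
final class SparseTreeSize {
    private SparseTreeSize() {}

    /**
     * Upper bound on state nodes when every node expands nActions actions with c sampled
     * successors each, down to height h: sum over depth d = 0..h of (nActions * c)^d.
     * Throws ArithmeticException if the count overflows a long.
     */
    static long worstCaseNodes(int nActions, int c, int h) {
        long branching = (long) nActions * c;
        long total = 0;
        long atDepth = 1; // the root
        for (int d = 0; d <= h; d++) {
            total = Math.addExact(total, atDepth);
            if (d < h) {
                atDepth = Math.multiplyExact(atDepth, branching);
            }
        }
        return total;
    }
}
```

For example, 5 actions, c = 10, and h = 3 already allow up to 127,551 nodes, which is why a small h, depth-variable C, and node reuse matter.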
public void setForgetPreviousPlanResults(boolean forgetPreviousPlanResults)
forgetPreviousPlanResults - if true, then previous planning results will be forgotten; if false, they will be remembered and reused in subsequent planning.
public void setValueForLeafNodes(ValueFunctionInitialization vinit)
Sets the ValueFunctionInitialization object to use for setting the value of leaf nodes.
vinit - the ValueFunctionInitialization object to use for setting the value of leaf nodes.
public int getDebugCode()
Returns the debug code used for logging plan results with DPrint.
Specified by: getDebugCode in interface MDPSolverInterface
Overrides: getDebugCode in class MDPSolver
public void setDebugCode(int debugCode)
Sets the debug code used for logging plan results with DPrint.
Specified by: setDebugCode in interface MDPSolverInterface
Overrides: setDebugCode in class MDPSolver
debugCode - the debugCode to use.
public int getNumberOfValueEsitmates()
Returns the total number of state value estimates performed since the last resetSolver() call.
public void setBoltzmannBetaParameter(double beta)
Specified by: setBoltzmannBetaParameter in interface QGradientPlanner
beta - the value to which this planner's Boltzmann beta parameter will be set.
public java.util.List<QValue> getQs(State s)
Description copied from interface: QFunction
Returns a List of QValue objects for every permissible action for the given input state.
public QValue getQ(State s, AbstractGroundedAction a)
Description copied from interface: QFunction
Returns the QValue for the given state-action pair.
public double value(State s)
Specified by: value in interface ValueFunction
s - the state to evaluate.
public java.util.List<QGradientTuple> getAllQGradients(State s)
Description copied from interface: QGradientPlanner
Returns the list of Q-value gradients (returned as QGradientTuple objects) for each action permissible in the given state.
Specified by: getAllQGradients in interface QGradientPlanner
s - the state for which Q-value gradients are to be returned.
public QGradientTuple getQGradient(State s, GroundedAction a)
Description copied from interface: QGradientPlanner
Returns the Q-value gradient (QGradientTuple) for the given state and action.
Specified by: getQGradient in interface QGradientPlanner
s - the state for which the Q-value gradient is to be returned.
a - the action for which the Q-value gradient is to be returned.
public BoltzmannQPolicy planFromState(State initialState)
Plans from the input state and returns a BoltzmannQPolicy that follows the Boltzmann parameter used for the Boltzmann value backups in this planner.
Specified by: planFromState in interface Planner
initialState - the initial state of the planning problem
Returns: a BoltzmannQPolicy
public void resetSolver()
Specified by: resetSolver in interface MDPSolverInterface
Overrides: resetSolver in class MDPSolver
protected int getCAtHeight(int height)
height - the height from a leaf node.
protected DifferentiableSparseSampling.DiffStateNode getStateNode(State s, int height)
s - the state
height - the height (distance from leaf node) of the node.
protected java.util.Set<java.lang.Integer> combinedNonZeroPDParameters(FunctionGradient... gradients)
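combinedNonZeroPDParameters has no description on this page. Judging only by its name and signature, it appears to collect the parameter indices with non-zero partial derivatives across several gradients, which would let sparse gradients be combined without touching every parameter. The sketch below rests on that assumption and models each FunctionGradient as a plain Map from parameter index to partial derivative; GradientSupport is hypothetical and BURLAP's FunctionGradient accessors are not shown.

```java
import java.util.Map;
import java.util.Set;
import java.util.TreeSet;

// Hypothetical sketch: sparse gradients modeled as Map<parameterIndex, partialDerivative>.
final class GradientSupport {
    private GradientSupport() {}

    /** Returns the sorted union of parameter indices whose partial derivative is non-zero in any gradient. */
    @SafeVarargs
    static Set<Integer> combinedNonZero(Map<Integer, Double>... gradients) {
        Set<Integer> indices = new TreeSet<>();
        for (Map<Integer, Double> gradient : gradients) {
            for (Map.Entry<Integer, Double> entry : gradient.entrySet()) {
                if (entry.getValue() != 0.0) {
                    indices.add(entry.getKey());
                }
            }
        }
        return indices;
    }
}
```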