public class DifferentiableSparseSampling extends MDPSolver implements QGradientPlanner, Planner
MLIRL) [2]. Additionally, the value of the leaf
nodes of this valueFunction may also be parametrized using a DifferentiableVInit
object and learned with MLIRL,
which cleanly separates shaping features/rewards from the learned (or known) reward function.
| Modifier and Type | Class and Description |
|---|---|
| class | DifferentiableSparseSampling.DiffStateNode: A class for value differentiable state nodes. |
| protected static class | DifferentiableSparseSampling.QAndQGradient: A tuple for storing Q-values and their gradients. |
| protected static class | DifferentiableSparseSampling.VAndVGradient: A tuple for storing a state value and its gradient. |
Nested classes/interfaces inherited from interface QFunction: QFunction.QFunctionHelper

| Modifier and Type | Field and Description |
|---|---|
| protected double | boltzBeta: The Boltzmann beta parameter that defines the differentiable Bellman equation. |
| protected int | c: The number of transition dynamics samples (for the root if depth-variable C is used). |
| protected boolean | forgetPreviousPlanResults: Whether previous planning results should be forgotten or reused; default is reused (false). |
| protected int | h: The height of the tree. |
| protected java.util.Map<SparseSampling.HashedHeightState,DifferentiableSparseSampling.DiffStateNode> | nodesByHeight: The tree nodes indexed by state and height. |
| protected int | numUpdates: The total number of pseudo-Bellman updates. |
| protected int | rfDim: The dimensionality of the differentiable reward function. |
| protected java.util.Map<HashableState,DifferentiableSparseSampling.QAndQGradient> | rootLevelQValues: The root state node Q-values that have been estimated by previous planning calls. |
| protected boolean | useVariableC: Whether the number of transition dynamic samples should scale with the depth of the node. |
| protected DifferentiableVInit | vinit: The state value used for leaf nodes; default is zero. |
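The boltzBeta field parameterizes a Boltzmann (softmax) backup. As a rough illustration of why such a backup is differentiable, the sketch below assumes the MLIRL-style form V = sum_a pi(a) Q(a) with pi(a) proportional to exp(beta * Q(a)). That form, the class name BoltzmannBackupSketch, and its method names are assumptions made for illustration; they are not taken from this class's source.

```java
final class BoltzmannBackupSketch {

    /** Softmax action probabilities: pi(a) proportional to exp(beta * q[a]). */
    static double[] policy(double[] q, double beta) {
        double max = Double.NEGATIVE_INFINITY;
        for (double v : q) max = Math.max(max, v);
        double[] p = new double[q.length];
        double z = 0.0;
        for (int a = 0; a < q.length; a++) {
            p[a] = Math.exp(beta * (q[a] - max)); // shift by max for numerical stability
            z += p[a];
        }
        for (int a = 0; a < q.length; a++) p[a] /= z;
        return p;
    }

    /** Boltzmann-weighted state value: V = sum_a pi(a) * Q(a). */
    static double value(double[] q, double beta) {
        double[] p = policy(q, beta);
        double v = 0.0;
        for (int a = 0; a < q.length; a++) v += p[a] * q[a];
        return v;
    }

    /**
     * Gradient of V with respect to the parameters, given
     * qGrad[a][i] = dQ(a)/dtheta_i (assumed to be supplied by the caller).
     */
    static double[] valueGradient(double[] q, double[][] qGrad, double beta) {
        double[] p = policy(q, beta);
        double v = value(q, beta);
        int dim = qGrad[0].length;
        double[] g = new double[dim];
        for (int a = 0; a < q.length; a++) {
            // Chain rule: dV/dQ(a) = pi(a) * (1 + beta * (Q(a) - V))
            double w = p[a] * (1.0 + beta * (q[a] - v));
            for (int i = 0; i < dim; i++) g[i] += w * qGrad[a][i];
        }
        return g;
    }
}
```

Under this assumed form, beta = 0 yields the plain average of the Q-values, while large beta approaches the usual max backup, which matches the "larger is more deterministic" behaviour described for boltzBeta.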
Fields inherited from class MDPSolver: actions, debugCode, domain, gamma, hashingFactory, mapToStateIndex, rf, tf

| Constructor and Description |
|---|
| DifferentiableSparseSampling(Domain domain, DifferentiableRF rf, TerminalFunction tf, double gamma, HashableStateFactory hashingFactory, int h, int c, double boltzBeta): Initializes. |
| Modifier and Type | Method and Description |
|---|---|
| protected java.util.Set<java.lang.Integer> | combinedNonZeroPDParameters(FunctionGradient... gradients) |
| java.util.List<QGradientTuple> | getAllQGradients(State s): Returns the list of Q-value gradients (returned as QGradientTuple objects) for each action permissible in the given state. |
| int | getC(): Returns the number of state transition samples. |
| protected int | getCAtHeight(int height): Returns the value of C for a node at the given height (height from a leaf node). |
| int | getDebugCode(): Returns the debug code used for logging plan results with DPrint. |
| int | getH(): Returns the height of the tree. |
| int | getNumberOfValueEsitmates(): Returns the total number of state value estimates performed since the last resetSolver() call. |
| QValue | getQ(State s, AbstractGroundedAction a): Returns the QValue for the given state-action pair. |
| QGradientTuple | getQGradient(State s, GroundedAction a): Returns the Q-value gradient (QGradientTuple) for the given state and action. |
| java.util.List<QValue> | getQs(State s): Returns a List of QValue objects for every permissible action for the given input state. |
| protected DifferentiableSparseSampling.DiffStateNode | getStateNode(State s, int height): Returns the state node for the given state at the given height in the tree, first creating and indexing it if it does not yet exist. |
| BoltzmannQPolicy | planFromState(State initialState): Plans from the input state and returns a BoltzmannQPolicy following the Boltzmann parameter used for Boltzmann value backups in this planner. |
| void | resetSolver(): Resets all solver results so that the solver can be restarted fresh, as if it had never solved the MDP. |
| void | setBoltzmannBetaParameter(double beta): Sets this valueFunction's Boltzmann beta parameter used to compute gradients. |
| void | setC(int c): Sets the number of state transition samples used. |
| void | setDebugCode(int debugCode): Sets the debug code used for logging plan results with DPrint. |
| void | setForgetPreviousPlanResults(boolean forgetPreviousPlanResults): Sets whether previous planning results should be forgotten or reused in subsequent planning. |
| void | setH(int h): Sets the height of the tree. |
| void | setUseVariableCSize(boolean useVariableC): Sets whether the number of state transition samples (C) should be variable with respect to the depth of the node. |
| void | setValueForLeafNodes(ValueFunctionInitialization vinit): Sets the ValueFunctionInitialization object to use for setting the value of leaf nodes. |
| double | value(State s): Returns the value function evaluation of the given state. |
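The nodesByHeight field and getStateNode together describe a get-or-create index keyed by state and height. The following is a minimal sketch of that pattern only. StateHeightKey, Node, and the use of a String as a stand-in for a hashed state are hypothetical names chosen for illustration, not the library's own types.

```java
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;

final class NodeIndexSketch {

    /** Hypothetical key: a state identifier paired with a tree height. */
    static final class StateHeightKey {
        final String stateId;
        final int height;

        StateHeightKey(String stateId, int height) {
            this.stateId = stateId;
            this.height = height;
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) return true;
            if (!(o instanceof StateHeightKey)) return false;
            StateHeightKey other = (StateHeightKey) o;
            return height == other.height && stateId.equals(other.stateId);
        }

        @Override
        public int hashCode() {
            return Objects.hash(stateId, height);
        }
    }

    /** Hypothetical tree node; a real node would also hold values and gradients. */
    static final class Node {
        final String stateId;
        final int height;

        Node(String stateId, int height) {
            this.stateId = stateId;
            this.height = height;
        }
    }

    private final Map<StateHeightKey, Node> nodesByHeight = new HashMap<>();

    /** Returns the indexed node for (stateId, height), creating and indexing it on first use. */
    Node getStateNode(String stateId, int height) {
        return nodesByHeight.computeIfAbsent(
                new StateHeightKey(stateId, height),
                k -> new Node(k.stateId, k.height));
    }

    /** Number of distinct (state, height) pairs indexed so far. */
    int size() {
        return nodesByHeight.size();
    }
}
```

Keying on both state and height means the same state reached at two different heights gets two nodes, since its estimated value depends on the remaining horizon.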
Methods inherited from class MDPSolver: addNonDomainReferencedAction, getActions, getAllGroundedActions, getDomain, getGamma, getHashingFactory, getRf, getRF, getTf, getTF, setActions, setDomain, setGamma, setHashingFactory, setRf, setTf, solverInit, stateHash, toggleDebugPrinting, translateAction
Methods inherited from class java.lang.Object: clone, equals, finalize, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait
Methods inherited from interface MDPSolverInterface: addNonDomainReferencedAction, getActions, getDomain, getGamma, getHashingFactory, getRf, getRF, getTf, getTF, setActions, setDomain, setGamma, setHashingFactory, setRf, setTf, solverInit, toggleDebugPrinting

protected int h
protected int c
protected boolean useVariableC
protected boolean forgetPreviousPlanResults
protected DifferentiableVInit vinit
protected java.util.Map<SparseSampling.HashedHeightState,DifferentiableSparseSampling.DiffStateNode> nodesByHeight
protected java.util.Map<HashableState,DifferentiableSparseSampling.QAndQGradient> rootLevelQValues
protected double boltzBeta
protected int rfDim
protected int numUpdates
public DifferentiableSparseSampling(Domain domain, DifferentiableRF rf, TerminalFunction tf, double gamma, HashableStateFactory hashingFactory, int h, int c, double boltzBeta)
Parameters:
domain - the problem domain
rf - the differentiable reward function
tf - the terminal function
gamma - the discount factor
hashingFactory - the hashing factory used to compare state equality
h - the planning horizon
c - how many samples from the transition dynamics to use. Set to -1 to use the full (unsampled) transition dynamics.
boltzBeta - the Boltzmann beta parameter for the differentiable Boltzmann (softmax) backup equation. Larger values make the backup more deterministic; smaller values (closer to 0) make it softer.

public void setUseVariableCSize(boolean useVariableC)
Parameters:
useVariableC - if true, then depth-variable C will be used; if false, all state nodes use the same number of samples.

public void setC(int c)
Parameters:
c - the number of state transition samples used.

public void setH(int h)
Parameters:
h - the height of the tree.

public int getC()

public int getH()
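The parameters c, h, and useVariableC interact: with depth-variable C enabled, the sample count depends on how far a node sits below the root. This page does not state the scaling rule, so the sketch below assumes one plausible choice for illustration: shrink C by gamma^(2d), where d = h - height is the node's depth, and never go below one sample. The class VariableCSketch and the method cAtHeight are hypothetical names.

```java
final class VariableCSketch {

    /**
     * Number of transition samples for a node at the given height.
     * The gamma^(2 * depth) scaling is an assumption for illustration only.
     */
    static int cAtHeight(int c, int h, int height, double gamma, boolean useVariableC) {
        // A negative c (for example -1) means "use the full transition dynamics",
        // and a fixed C ignores depth entirely.
        if (c < 0 || !useVariableC) {
            return c;
        }
        int depth = h - height; // the root sits at height h, so its depth is 0
        int scaled = (int) (c * Math.pow(gamma, 2.0 * depth));
        return Math.max(1, scaled); // always draw at least one sample
    }
}
```

Under this assumed rule the root keeps the full C, which is consistent with the field description of c as the sample count "for the root if depth-variable C is used".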
public void setForgetPreviousPlanResults(boolean forgetPreviousPlanResults)
Parameters:
forgetPreviousPlanResults - if true, then previous planning results will be forgotten; if false, they will be remembered and reused in subsequent planning.

public void setValueForLeafNodes(ValueFunctionInitialization vinit)
Sets the ValueFunctionInitialization object to use for setting the value of leaf nodes.
Parameters:
vinit - the ValueFunctionInitialization object to use for setting the value of leaf nodes.

public int getDebugCode()
Returns the debug code used for logging plan results with DPrint.
Specified by: getDebugCode in interface MDPSolverInterface
Overrides: getDebugCode in class MDPSolver
Returns: the debug code used for logging plan results with DPrint.

public void setDebugCode(int debugCode)
Sets the debug code used for logging plan results with DPrint.
Specified by: setDebugCode in interface MDPSolverInterface
Overrides: setDebugCode in class MDPSolver
Parameters:
debugCode - the debugCode to use.

public int getNumberOfValueEsitmates()
Returns the total number of state value estimates performed since the last resetSolver() call.
Returns: the total number of state value estimates performed since the last resetSolver() call.

public void setBoltzmannBetaParameter(double beta)
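Description copied from interface: QGradientPlanner
Sets this valueFunction's Boltzmann beta parameter used to compute gradients.
Specified by: setBoltzmannBetaParameter in interface QGradientPlanner
Parameters:
beta - the value to which this valueFunction's Boltzmann beta parameter will be set

public java.util.List<QValue> getQs(State s)
Description copied from interface: QFunction
Returns a List of QValue objects for every permissible action for the given input state.

public QValue getQ(State s, AbstractGroundedAction a)
Description copied from interface: QFunction
Returns the QValue for the given state-action pair.

public double value(State s)
Description copied from interface: ValueFunction
Returns the value function evaluation of the given state.
Specified by: value in interface ValueFunction
Parameters:
s - the state to evaluate.

public java.util.List<QGradientTuple> getAllQGradients(State s)
Description copied from interface: QGradientPlanner
Returns the list of Q-value gradients (returned as QGradientTuple objects) for each action permissible in the given state.
Specified by: getAllQGradients in interface QGradientPlanner
Parameters:
s - the state for which Q-value gradients are to be returned.

public QGradientTuple getQGradient(State s, GroundedAction a)
Description copied from interface: QGradientPlanner
Returns the Q-value gradient (QGradientTuple) for the given state and action.
Specified by: getQGradient in interface QGradientPlanner
Parameters:
s - the state for which the Q-value gradient is to be returned
a - the action for which the Q-value gradient is to be returned.

public BoltzmannQPolicy planFromState(State initialState)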
Plans from the input state and returns a BoltzmannQPolicy following the Boltzmann parameter used for Boltzmann value backups in this planner.
Specified by: planFromState in interface Planner
Parameters:
initialState - the initial state of the planning problem
Returns: a BoltzmannQPolicy

public void resetSolver()
Description copied from interface: MDPSolverInterface
Resets all solver results so that the solver can be restarted fresh, as if it had never solved the MDP.
Specified by: resetSolver in interface MDPSolverInterface
Specified by: resetSolver in class MDPSolver

protected int getCAtHeight(int height)
Parameters:
height - the height from a leaf node.

protected DifferentiableSparseSampling.DiffStateNode getStateNode(State s, int height)
Parameters:
s - the state
height - the height (distance from leaf node) of the node.

protected java.util.Set<java.lang.Integer> combinedNonZeroPDParameters(FunctionGradient... gradients)
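
This page gives no description for combinedNonZeroPDParameters. Reading its name and signature as "the union of parameter indices that have a non-zero partial derivative in any of the supplied gradients" is an assumption, and the sketch below illustrates only that reading. It represents each sparse gradient as a Map from parameter index to partial derivative, which is a stand-in for FunctionGradient; NonZeroUnionSketch and combinedNonZero are hypothetical names.

```java
import java.util.Map;
import java.util.Set;
import java.util.TreeSet;

final class NonZeroUnionSketch {

    /**
     * Collects every parameter index whose partial derivative is non-zero
     * in at least one of the given sparse gradients.
     */
    @SafeVarargs
    static Set<Integer> combinedNonZero(Map<Integer, Double>... gradients) {
        Set<Integer> ids = new TreeSet<>();
        for (Map<Integer, Double> gradient : gradients) {
            for (Map.Entry<Integer, Double> entry : gradient.entrySet()) {
                if (entry.getValue() != 0.0) { // skip entries stored explicitly as zero
                    ids.add(entry.getKey());
                }
            }
        }
        return ids;
    }
}
```

Working from such a union lets later sums over several gradients visit only the parameters that can contribute, rather than all rfDim dimensions.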