public class DifferentiableSparseSampling extends MDPSolver implements DifferentiableQFunction, QProvider, Planner

MLIRL [2]. Additionally, the value of the leaf nodes of this value function may also be parametrized using a DifferentiableVInit object and learned with MLIRL, enabling a nice separation of shaping features/rewards and the learned (or known) reward function.
1. MacGlashan, J. and Littman, M., "Between Imitation and Intention Learning," Proceedings of IJCAI-15, 2015.
2. Babes, M., Marivate, V., Subramanian, K., and Littman, M., "Apprenticeship Learning About Multiple Intentions," Proceedings of the 28th International Conference on Machine Learning (ICML-11), 2011.
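As background for the differentiable Bellman equation this planner is built on, the sketch below (illustrative, self-contained Java; not BURLAP code, and all names are hypothetical) shows a Boltzmann (softmax-weighted) backup over Q-values and how the beta parameter interpolates between averaging and maximization:

```java
import java.util.Arrays;

public class BoltzmannBackupDemo {
    // Softmax-weighted backup: V = sum_a p(a) * Q(a), with p = softmax(beta * Q).
    static double boltzmannBackup(double[] q, double beta) {
        double max = Arrays.stream(q).max().getAsDouble(); // subtract max for numerical stability
        double z = 0.0;
        double[] w = new double[q.length];
        for (int a = 0; a < q.length; a++) {
            w[a] = Math.exp(beta * (q[a] - max));
            z += w[a];
        }
        double v = 0.0;
        for (int a = 0; a < q.length; a++) {
            v += (w[a] / z) * q[a]; // weight each Q-value by its softmax probability
        }
        return v;
    }

    public static void main(String[] args) {
        double[] q = {1.0, 2.0, 4.0};
        // Small beta: backup is close to the mean of the Q-values.
        System.out.println(boltzmannBackup(q, 0.01));
        // Large beta: backup is close to the max Q-value (approaching the standard Bellman backup).
        System.out.println(boltzmannBackup(q, 50.0));
    }
}
```

Unlike the max operator, this backup is smooth in the Q-values, which is what makes gradients with respect to reward-function parameters well defined.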
Modifier and Type | Class and Description |
---|---|
class | DifferentiableSparseSampling.DiffStateNode: A class for value-differentiable state nodes. |
protected static class | DifferentiableSparseSampling.QAndQGradient: A tuple for storing Q-values and their gradients. |
protected static class | DifferentiableSparseSampling.VAndVGradient: A tuple for storing a state value and its gradient. |

Nested classes/interfaces inherited from interface QProvider: QProvider.Helper
Modifier and Type | Field and Description |
---|---|
protected double | boltzBeta: The Boltzmann beta parameter that defines the differentiable Bellman equation. |
protected int | c: The number of transition-dynamics samples (for the root node if depth-variable C is used). |
protected boolean | forgetPreviousPlanResults: Whether previous planning results should be forgotten or reused; default is reused (false). |
protected int | h: The height of the tree. |
protected java.util.Map<SparseSampling.HashedHeightState,DifferentiableSparseSampling.DiffStateNode> | nodesByHeight: The tree nodes indexed by state and height. |
protected int | numUpdates: The total number of pseudo-Bellman updates. |
protected DifferentiableDPOperator | operator |
protected DifferentiableRF | rf |
protected int | rfDim: The dimensionality of the differentiable reward function. |
protected java.util.Map<HashableState,DifferentiableSparseSampling.QAndQGradient> | rootLevelQValues: The root state node Q-values that have been estimated by previous planning calls. |
protected boolean | useVariableC: Whether the number of transition-dynamics samples should scale with the depth of the node. |
protected DifferentiableValueFunction | vinit: The state value used for leaf nodes; default is zero. |

Fields inherited from class MDPSolver: actionTypes, debugCode, domain, gamma, hashingFactory, model, usingOptionModel
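The useVariableC flag lets deeper parts of the tree use fewer transition samples than the root. The exact schedule BURLAP uses is not shown on this page, so the sketch below uses a hypothetical halving schedule purely to illustrate the idea of a depth-dependent C:

```java
public class VariableCDemo {
    // Number of transition samples for a node at the given height (distance from a leaf).
    // Illustrative schedule only: the root (height h) uses the full c, and each level
    // closer to the leaves halves the sample count (with a minimum of 1 sample).
    static int cAtHeight(int c, int h, int height) {
        int levelsBelowRoot = h - height;
        return Math.max(1, c >> levelsBelowRoot);
    }

    public static void main(String[] args) {
        int c = 8, h = 3;
        for (int height = h; height >= 1; height--) {
            System.out.println("height " + height + ": " + cAtHeight(c, h, height) + " samples");
        }
    }
}
```

Shrinking C with depth reduces the exponential cost of expanding the tree while spending the sampling budget where it matters most, near the root.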
Constructor and Description |
---|
DifferentiableSparseSampling(SADomain domain, DifferentiableRF rf, double gamma, HashableStateFactory hashingFactory, int h, int c, double boltzBeta): Initializes. |
Modifier and Type | Method and Description |
---|---|
protected java.util.Set<java.lang.Integer> | combinedNonZeroPDParameters(FunctionGradient... gradients) |
int | getC(): Returns the number of state transition samples. |
protected int | getCAtHeight(int height): Returns the value of C for a node at the given height (height from a leaf node). |
int | getDebugCode(): Returns the debug code used for logging plan results with DPrint. |
int | getH(): Returns the height of the tree. |
SampleModel | getModel(): Returns the model being used by this solver. |
int | getNumberOfValueEsitmates(): Returns the total number of state value estimates performed since the resetSolver() call. |
DifferentiableDPOperator | getOperator() |
protected DifferentiableSparseSampling.DiffStateNode | getStateNode(State s, int height): Either returns, or creates, indexes, and returns, the state node for the given state at the given height in the tree. |
BoltzmannQPolicy | planFromState(State initialState): Plans from the input state and returns a BoltzmannQPolicy that follows the Boltzmann parameter used for value backups in this planner. |
FunctionGradient | qGradient(State s, Action a): Returns the Q-value gradient (QGradientTuple) for the given state and action. |
double | qValue(State s, Action a): Returns the QValue for the given state-action pair. |
java.util.List<QValue> | qValues(State s): Returns a List of QValue objects for every permissible action for the given input state. |
void | resetSolver(): Resets all solver results so that the solver can be restarted fresh, as if it had never solved the MDP. |
void | setC(int c): Sets the number of state transition samples used. |
void | setDebugCode(int debugCode): Sets the debug code used for logging plan results with DPrint. |
void | setForgetPreviousPlanResults(boolean forgetPreviousPlanResults): Sets whether previous planning results should be forgotten or reused in subsequent planning. |
void | setH(int h): Sets the height of the tree. |
void | setOperator(DifferentiableDPOperator operator) |
void | setUseVariableCSize(boolean useVariableC): Sets whether the number of state transition samples (C) should vary with the depth of the node. |
void | setValueForLeafNodes(ValueFunction vinit): Sets the ValueFunction object used to set the value of leaf nodes. |
double | value(State s): Returns the value function evaluation of the given state. |
Methods inherited from class MDPSolver: addActionType, applicableActions, getActionTypes, getDomain, getGamma, getHashingFactory, setActionTypes, setDomain, setGamma, setHashingFactory, setModel, solverInit, stateHash, toggleDebugPrinting

Methods inherited from class java.lang.Object: clone, equals, finalize, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait

Methods inherited from interface MDPSolverInterface: addActionType, getActionTypes, getDomain, getGamma, getHashingFactory, setActionTypes, setDomain, setGamma, setHashingFactory, setModel, solverInit, toggleDebugPrinting
protected int h
protected int c
protected boolean useVariableC
protected boolean forgetPreviousPlanResults
protected DifferentiableValueFunction vinit
protected DifferentiableRF rf
protected java.util.Map<SparseSampling.HashedHeightState,DifferentiableSparseSampling.DiffStateNode> nodesByHeight
protected java.util.Map<HashableState,DifferentiableSparseSampling.QAndQGradient> rootLevelQValues
protected double boltzBeta
protected int rfDim
protected int numUpdates
protected DifferentiableDPOperator operator
public DifferentiableSparseSampling(SADomain domain, DifferentiableRF rf, double gamma, HashableStateFactory hashingFactory, int h, int c, double boltzBeta)

Initializes. The model of this planner is set to a CustomRewardModel using the provided reward function.

Parameters:
domain - the problem domain
rf - the differentiable reward function
gamma - the discount factor
hashingFactory - the hashing factory used to compare state equality
h - the planning horizon
c - how many samples from the transition dynamics to use. Set to -1 to use the full (unsampled) transition dynamics.
boltzBeta - the Boltzmann beta parameter for the differentiable Boltzmann (softmax) backup equation. The larger the value, the more deterministic the backup; smaller values produce a softer (more stochastic) backup.

public void setUseVariableCSize(boolean useVariableC)

Sets whether the number of state transition samples (C) should be variable with respect to the depth of the node.

Parameters:
useVariableC - if true, then depth-variable C will be used; if false, all state nodes use the same number of samples.

public void setC(int c)

Sets the number of state transition samples used.

Parameters:
c - the number of state transition samples used.

public void setH(int h)

Sets the height of the tree.

Parameters:
h - the height of the tree.

public int getC()

Returns the number of state transition samples.
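To make concrete why a Boltzmann backup supports gradient-based learning, the following self-contained sketch (illustrative, not BURLAP's DifferentiableDPOperator; the closed-form gradient assumes the backup is the softmax-weighted mean of the Q-values, one common choice of Boltzmann operator) computes the backup's gradient with respect to each Q-value and checks it against finite differences:

```java
public class BoltzmannGradientDemo {
    // Boltzmann (softmax-weighted) backup: V = sum_a p_a * Q_a with p = softmax(beta * Q).
    static double backup(double[] q, double beta) {
        double z = 0.0, v = 0.0;
        double[] w = new double[q.length];
        for (int a = 0; a < q.length; a++) { w[a] = Math.exp(beta * q[a]); z += w[a]; }
        for (int a = 0; a < q.length; a++) { v += (w[a] / z) * q[a]; }
        return v;
    }

    // Analytic gradient, derived by the chain rule: dV/dQ_k = p_k * (1 + beta * (Q_k - V)).
    static double[] backupGradient(double[] q, double beta) {
        double z = 0.0, v = backup(q, beta);
        double[] p = new double[q.length], g = new double[q.length];
        for (int a = 0; a < q.length; a++) { p[a] = Math.exp(beta * q[a]); z += p[a]; }
        for (int a = 0; a < q.length; a++) {
            p[a] /= z;
            g[a] = p[a] * (1.0 + beta * (q[a] - v));
        }
        return g;
    }

    public static void main(String[] args) {
        double[] q = {0.5, 1.5, 1.0};
        double beta = 2.0, eps = 1e-6;
        double[] g = backupGradient(q, beta);
        // Verify each analytic partial derivative against a central finite difference.
        for (int k = 0; k < q.length; k++) {
            double[] qp = q.clone(), qm = q.clone();
            qp[k] += eps; qm[k] -= eps;
            double fd = (backup(qp, beta) - backup(qm, beta)) / (2 * eps);
            System.out.printf("dV/dQ_%d analytic=%.6f fd=%.6f%n", k, g[k], fd);
        }
    }
}
```

Because this gradient exists everywhere, gradients of the reward parameters can be propagated through every backup in the tree, which is what qGradient ultimately relies on.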
public int getH()

Returns the height of the tree.

public SampleModel getModel()

Description copied from interface: MDPSolverInterface
Returns the model being used by this solver.

Specified by: getModel in interface MDPSolverInterface
Overrides: getModel in class MDPSolver
Returns: the SampleModel being used by this solver
public DifferentiableDPOperator getOperator()

public void setOperator(DifferentiableDPOperator operator)

public void setForgetPreviousPlanResults(boolean forgetPreviousPlanResults)

Sets whether previous planning results should be forgotten or reused in subsequent planning.

Parameters:
forgetPreviousPlanResults - if true, then previous planning results will be forgotten; if false, they will be remembered and reused in subsequent planning.

public void setValueForLeafNodes(ValueFunction vinit)

Sets the ValueFunction object to use for setting the value of leaf nodes.

Parameters:
vinit - the ValueFunction object to use for setting the value of leaf nodes.

public int getDebugCode()
Returns the debug code used for logging plan results with DPrint.

Specified by: getDebugCode in interface MDPSolverInterface
Overrides: getDebugCode in class MDPSolver
Returns: the debug code used with DPrint

public void setDebugCode(int debugCode)

Sets the debug code used for logging plan results with DPrint.

Specified by: setDebugCode in interface MDPSolverInterface
Overrides: setDebugCode in class MDPSolver

Parameters:
debugCode - the debug code to use.

public int getNumberOfValueEsitmates()

Returns the total number of state value estimates performed since the resetSolver() call.

Returns: the total number of state value estimates performed since the resetSolver() call.

public java.util.List<QValue> qValues(State s)
Description copied from interface: QProvider
Returns a List of QValue objects for every permissible action for the given input state.

public double qValue(State s, Action a)

Description copied from interface: QFunction
Returns the QValue for the given state-action pair.

public double value(State s)

Description copied from interface: ValueFunction
Returns the value function evaluation of the given state.

Specified by: value in interface ValueFunction

Parameters:
s - the state to evaluate.

public FunctionGradient qGradient(State s, Action a)
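The Q-values and state values above are sparse-sampling estimates: each node averages over a fixed number of sampled transitions and recurses to a fixed depth. A minimal sketch of that recursion on a toy two-state chain MDP (illustrative, self-contained; all names and the MDP itself are hypothetical, and leaf nodes use a value of zero as in the vinit default):

```java
import java.util.Random;

public class SparseSamplingDemo {
    static final double GAMMA = 0.9;
    static final Random RNG = new Random(0);

    // Toy MDP: states 0 and 1; the single action moves to state 1 with probability 0.8.
    static int sampleNext(int s) { return RNG.nextDouble() < 0.8 ? 1 : 0; }

    // Reward of 1.0 whenever the agent lands in state 1.
    static double reward(int s) { return s == 1 ? 1.0 : 0.0; }

    // Sparse-sampling value estimate: average over c sampled transitions, recurse to depth h.
    static double estimateV(int s, int h, int c) {
        if (h == 0) return 0.0; // leaf node: value initialized to zero
        double total = 0.0;
        for (int i = 0; i < c; i++) {
            int sNext = sampleNext(s);
            total += reward(sNext) + GAMMA * estimateV(sNext, h - 1, c);
        }
        return total / c; // with a single action, Q(s, a) and V(s) coincide
    }

    public static void main(String[] args) {
        System.out.println(estimateV(0, 4, 20));
    }
}
```

The cost of this recursion grows as roughly c^h, which is why the class caches nodes by state and height and offers a depth-variable C.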
Description copied from interface: DifferentiableQFunction
Returns the Q-value gradient (QGradientTuple) for the given state and action.

Specified by: qGradient in interface DifferentiableQFunction

Parameters:
s - the state for which the Q-value gradient is to be returned
a - the action for which the Q-value gradient is to be returned.

public BoltzmannQPolicy planFromState(State initialState)

Plans from the input state and returns a BoltzmannQPolicy that follows the Boltzmann parameter used for value backups in this planner.

Specified by: planFromState in interface Planner

Parameters:
initialState - the initial state of the planning problem

Returns: a BoltzmannQPolicy
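A Boltzmann Q-policy selects actions with probabilities proportional to the exponentiated Q-values. A self-contained sketch of that action distribution (illustrative only, not BoltzmannQPolicy's actual implementation):

```java
public class BoltzmannPolicyDemo {
    // Action probabilities of a Boltzmann (softmax) policy over Q-values.
    static double[] actionProbabilities(double[] q, double beta) {
        double max = Double.NEGATIVE_INFINITY;
        for (double v : q) max = Math.max(max, v); // subtract max for numerical stability
        double z = 0.0;
        double[] p = new double[q.length];
        for (int a = 0; a < q.length; a++) { p[a] = Math.exp(beta * (q[a] - max)); z += p[a]; }
        for (int a = 0; a < q.length; a++) p[a] /= z;
        return p;
    }

    public static void main(String[] args) {
        double[] p = actionProbabilities(new double[]{1.0, 2.0, 3.0}, 1.0);
        for (double v : p) System.out.printf("%.4f ", v); // higher Q-value, higher probability
    }
}
```

Using the same beta for the policy as for the value backups keeps the returned policy consistent with the values the planner computed.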
public void resetSolver()

Description copied from interface: MDPSolverInterface
This method resets all solver results so that a solver can be restarted fresh, as if it had never solved the MDP.

Specified by: resetSolver in interface MDPSolverInterface
Specified by: resetSolver in class MDPSolver

protected int getCAtHeight(int height)

Returns the value of C for a node at the given height (height from a leaf node).

Parameters:
height - the height from a leaf node.

protected DifferentiableSparseSampling.DiffStateNode getStateNode(State s, int height)

Either returns, or creates, indexes, and returns, the state node for the given state at the given height in the tree.

Parameters:
s - the state
height - the height (distance from leaf node) of the node.

protected java.util.Set<java.lang.Integer> combinedNonZeroPDParameters(FunctionGradient... gradients)
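The caching behavior of getStateNode, where nodes are indexed by a (state, height) pair as in the nodesByHeight map, can be sketched with a small self-contained example (illustrative only; the key type mirrors the idea behind SparseSampling.HashedHeightState, with the state reduced to an int):

```java
import java.util.HashMap;
import java.util.Map;

public class HeightIndexedCacheDemo {
    // Key pairing a (hashed) state with a tree height; records give hashCode/equals for free.
    record HeightState(int state, int height) {}

    private final Map<HeightState, String> nodes = new HashMap<>();

    // Either returns the cached node for (s, height), or creates and indexes a new one.
    String getStateNode(int s, int height) {
        return nodes.computeIfAbsent(new HeightState(s, height),
                k -> "node(" + k.state() + "," + k.height() + ")");
    }

    public static void main(String[] args) {
        HeightIndexedCacheDemo cache = new HeightIndexedCacheDemo();
        String a = cache.getStateNode(7, 2);
        String b = cache.getStateNode(7, 2); // same state and height: the node is reused
        System.out.println(a == b);          // reference equality demonstrates the cache hit
    }
}
```

Indexing by height as well as state matters because the same state can appear at several depths in the tree with different remaining horizons, and therefore different values.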