public class FourierBasisLearningRateWrapper extends java.lang.Object implements LearningRate
This LearningRate implementation provides a wrapper around a source LearningRate and should be used whenever FourierBasis features are used with an algorithm like GradientDescentSarsaLam. It queries the source LearningRate implementation for its VFA feature-wise learning rate value and then scales that value by the inverse of the L2 norm of the coefficient vector associated with the Fourier basis function for that feature id. That is, if alpha(j) is the learning rate returned by the source LearningRate implementation for basis function (feature id) j, then this implementation returns alpha(j) / ||c_j||, where c_j is the coefficient vector associated with Fourier basis function j.
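The scaling rule above can be illustrated with a small standalone sketch. It assumes the standard full Fourier basis of order n over d normalized state variables, in which each coefficient vector c_j ranges over {0, ..., n}^d. The class name FourierScalingSketch, the enumeration order, and the choice to leave the rate unscaled when ||c_j|| = 0 (the constant basis function) are assumptions for illustration and may not match how FourierBasis indexes or handles its features.

```java
import java.util.ArrayList;
import java.util.List;

/**
 * Hypothetical sketch of the alpha(j) / ||c_j|| scaling for a full Fourier basis.
 * Assumes coefficient vectors range over {0,...,order}^dims; the enumeration order
 * is arbitrary and need not match FourierBasis.
 */
class FourierScalingSketch {

    /** Enumerates every coefficient vector in {0,...,order}^dims, odometer style. */
    static List<int[]> coefficientVectors(int dims, int order) {
        List<int[]> result = new ArrayList<>();
        int[] c = new int[dims];
        while (true) {
            result.add(c.clone());
            // Increment c like an odometer in base (order + 1).
            int i = 0;
            while (i < dims && c[i] == order) {
                c[i] = 0;
                i++;
            }
            if (i == dims) {
                return result; // every digit rolled over: all vectors emitted
            }
            c[i]++;
        }
    }

    /** L2 norm of an integer coefficient vector. */
    static double l2Norm(int[] c) {
        double sumSq = 0.0;
        for (int v : c) {
            sumSq += (double) v * v;
        }
        return Math.sqrt(sumSq);
    }

    /** Returns alpha / ||c||, or alpha unchanged when ||c|| == 0 (assumed convention). */
    static double scaledRate(double alpha, int[] c) {
        double norm = l2Norm(c);
        return norm == 0.0 ? alpha : alpha / norm;
    }

    public static void main(String[] args) {
        double alpha = 0.1; // hypothetical source rate, identical for every feature
        for (int[] c : coefficientVectors(2, 1)) {
            System.out.printf("c = [%d, %d]  ||c|| = %.4f  rate = %.4f%n",
                    c[0], c[1], l2Norm(c), scaledRate(alpha, c));
        }
    }
}
```

Higher-frequency basis functions (larger ||c_j||) receive proportionally smaller step sizes, while the low-frequency ones keep rates close to the source value.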
Since this wrapper operates on state-action features, it throws a runtime exception if it is queried through the OO-MDP State-wise learning rate peek and poll methods (peekAtLearningRate(State, Action) and pollLearningRate(int, State, Action), respectively). Clients should instead call only the peekAtLearningRate(int) and pollLearningRate(int, int) methods.
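The delegation pattern described above might be sketched as follows. SimpleLearningRate, ConstantRate, and ScalingWrapperSketch are hypothetical names; Object stands in for State and Action so the code compiles without BURLAP; the norm lookup is passed in as a function rather than read from a FourierBasis; and UnsupportedOperationException is one possible choice of runtime exception. None of this should be read as the library's actual implementation.

```java
import java.util.function.IntToDoubleFunction;

/** Hypothetical stand-in for LearningRate; Object replaces State and Action. */
interface SimpleLearningRate {
    double peekAtLearningRate(Object s, Object a);
    double pollLearningRate(int agentTime, Object s, Object a);
    double peekAtLearningRate(int featureId);
    double pollLearningRate(int agentTime, int featureId);
    void resetDecay();
}

/** Hypothetical source that returns the same rate for every query. */
class ConstantRate implements SimpleLearningRate {
    private final double rate;

    ConstantRate(double rate) {
        this.rate = rate;
    }

    @Override public double peekAtLearningRate(Object s, Object a) { return rate; }
    @Override public double pollLearningRate(int agentTime, Object s, Object a) { return rate; }
    @Override public double peekAtLearningRate(int featureId) { return rate; }
    @Override public double pollLearningRate(int agentTime, int featureId) { return rate; }
    @Override public void resetDecay() { }
}

/** Sketch of the wrapper: feature-wise queries are scaled, state-wise queries are rejected. */
class ScalingWrapperSketch implements SimpleLearningRate {
    private final SimpleLearningRate source;
    // Hypothetical lookup from feature id j to ||c_j||.
    private final IntToDoubleFunction coefficientNorm;

    ScalingWrapperSketch(SimpleLearningRate source, IntToDoubleFunction coefficientNorm) {
        this.source = source;
        this.coefficientNorm = coefficientNorm;
    }

    @Override
    public double peekAtLearningRate(Object s, Object a) {
        throw new UnsupportedOperationException("use peekAtLearningRate(int featureId)");
    }

    @Override
    public double pollLearningRate(int agentTime, Object s, Object a) {
        throw new UnsupportedOperationException("use pollLearningRate(int agentTime, int featureId)");
    }

    @Override
    public double peekAtLearningRate(int featureId) {
        return scale(source.peekAtLearningRate(featureId), featureId);
    }

    @Override
    public double pollLearningRate(int agentTime, int featureId) {
        // Polling the source lets it apply its own decay; only the returned value is scaled.
        return scale(source.pollLearningRate(agentTime, featureId), featureId);
    }

    @Override
    public void resetDecay() {
        source.resetDecay(); // assumed: all decay state lives in the source
    }

    // alpha(j) / ||c_j||, leaving alpha unscaled when ||c_j|| is 0 (assumed convention).
    private double scale(double alpha, int featureId) {
        double norm = coefficientNorm.applyAsDouble(featureId);
        return norm == 0.0 ? alpha : alpha / norm;
    }

    public static void main(String[] args) {
        // Hypothetical norms: feature 0 is the constant term, feature j > 0 has ||c_j|| = sqrt(j).
        SimpleLearningRate lr = new ScalingWrapperSketch(new ConstantRate(0.1),
                id -> id == 0 ? 0.0 : Math.sqrt(id));
        System.out.println(lr.peekAtLearningRate(0));
        System.out.println(lr.pollLearningRate(0, 4));
        try {
            lr.peekAtLearningRate(new Object(), new Object());
        } catch (UnsupportedOperationException e) {
            System.out.println("state-wise query rejected: " + e.getMessage());
        }
    }
}
```

In this sketch the wrapper keeps no decay state of its own: any schedule belongs to the source, and the wrapper only rescales the value the source returns.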
Modifier and Type | Field and Description |
---|---|
protected FourierBasis | fouierBasisFunctions - The Fourier basis functions that are used. |
protected LearningRate | sourceLearningRateFunction - The source LearningRate function that is queried. |
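Constructor and Description |
---|
FourierBasisLearningRateWrapper(LearningRate sourceLearningRateFunction, FourierBasis fouierBasisFunctions) - Initializes the wrapper with the source LearningRate to scale and the FourierBasis that supplies the coefficient vectors. |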
Modifier and Type | Method and Description |
---|---|
double | peekAtLearningRate(int featureId) - Returns the current learning rate for a state (or state-action) feature without altering it. |
double | peekAtLearningRate(State s, Action ga) - Returns the current learning rate for a state-action pair without altering it. Not supported by this wrapper; throws a runtime exception. |
double | pollLearningRate(int agentTime, int featureId) - Returns the learning rate for a given state (or state-action) feature and then decays the learning rate as defined by this class. |
double | pollLearningRate(int agentTime, State s, Action ga) - Returns the learning rate for a given state-action pair and then decays the learning rate as defined by this class. Not supported by this wrapper; throws a runtime exception. |
void | resetDecay() - Causes any learning rate decay to reset to where it started. |
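To make the peek/poll distinction in the table concrete, the following hypothetical per-feature learning rate decays multiplicatively each time it is polled. DecayingFeatureRate and its decay rule (multiply by a constant factor, floored at a minimum rate) are assumptions, not a BURLAP class; main divides by a hypothetical ||c_j|| = sqrt(2) to mimic what the wrapper would return for that feature.

```java
import java.util.HashMap;
import java.util.Map;

/**
 * Hypothetical feature-wise learning rate illustrating the peek/poll/reset contract.
 * Assumed decay rule: after each poll, rate <- max(minRate, rate * decay).
 */
class DecayingFeatureRate {
    private final double initialRate;
    private final double decay;
    private final double minRate;
    // Current rate per feature id; absent entries are still at the initial rate.
    private final Map<Integer, Double> current = new HashMap<>();

    DecayingFeatureRate(double initialRate, double decay, double minRate) {
        this.initialRate = initialRate;
        this.decay = decay;
        this.minRate = minRate;
    }

    /** Returns the current rate for featureId without changing it. */
    double peekAtLearningRate(int featureId) {
        return current.getOrDefault(featureId, initialRate);
    }

    /** Returns the current rate for featureId, then decays it. agentTime is unused here. */
    double pollLearningRate(int agentTime, int featureId) {
        double rate = peekAtLearningRate(featureId);
        current.put(featureId, Math.max(minRate, rate * decay));
        return rate;
    }

    /** Restores every feature to the initial rate. */
    void resetDecay() {
        current.clear();
    }

    public static void main(String[] args) {
        DecayingFeatureRate lr = new DecayingFeatureRate(0.5, 0.5, 0.01);
        double norm = Math.sqrt(2.0); // hypothetical ||c_j|| for c_j = [1, 1]
        System.out.println(lr.peekAtLearningRate(3) / norm);  // 0.5 / sqrt(2); no decay
        System.out.println(lr.pollLearningRate(0, 3) / norm); // 0.5 / sqrt(2); rate then decays to 0.25
        System.out.println(lr.peekAtLearningRate(3) / norm);  // 0.25 / sqrt(2)
        lr.resetDecay();
        System.out.println(lr.peekAtLearningRate(3) / norm);  // back to 0.5 / sqrt(2)
    }
}
```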
protected LearningRate sourceLearningRateFunction
The source LearningRate function that is queried.

protected FourierBasis fouierBasisFunctions
The Fourier basis functions that are used.
public FourierBasisLearningRateWrapper(LearningRate sourceLearningRateFunction, FourierBasis fouierBasisFunctions)
Initializes.
Parameters:
sourceLearningRateFunction - the source LearningRate function that will be scaled.
fouierBasisFunctions - the FourierBasis SparseStateFeatures that define the Fourier basis functions and their coefficient vectors.

public double peekAtLearningRate(State s, Action ga)
Not supported by this wrapper: throws a runtime exception. Use peekAtLearningRate(int) instead.
Specified by: peekAtLearningRate in interface LearningRate
Parameters:
s - the state for which the learning rate should be returned
ga - the action for which the learning rate should be returned

public double pollLearningRate(int agentTime, State s, Action ga)
Not supported by this wrapper: throws a runtime exception. Use pollLearningRate(int, int) instead.
Specified by: pollLearningRate in interface LearningRate
Parameters:
agentTime - the time index of the agent when polling.
s - the state for which the learning rate should be returned
ga - the action for which the learning rate should be returned

public double peekAtLearningRate(int featureId)
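Returns the source LearningRate's current learning rate for featureId, scaled by 1 / ||c_j|| as described in the class description, without altering it.
Specified by: peekAtLearningRate in interface LearningRate
Parameters:
featureId - the state feature for which the learning rate should be returned

public double pollLearningRate(int agentTime, int featureId)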
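Returns the source LearningRate's learning rate for featureId, scaled by 1 / ||c_j|| as described in the class description, and then decays the learning rate.
Specified by: pollLearningRate in interface LearningRate
Parameters:
agentTime - the time index of the agent when polling.
featureId - the state feature for which the learning rate should be returned

public void resetDecay()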
Causes any learning rate decay to reset to where it started.
Specified by: resetDecay in interface LearningRate