public class LinearStateActionDifferentiableRF extends java.lang.Object implements DifferentiableRF
A class for defining a state-action linear DifferentiableRF.
The class takes as input a DenseStateFeatures and the set of possible grounded actions that can be applied in the world. The dimensionality of this reward function is equal to |A|*|f|, where A is the set of possible grounded actions, and |f| is the state feature vector dimensionality.
The reward function is defined as R(s, a, s') = w(a) * f(s), where w(a) is the set of weights (the parameters) of this reward function associated with action a, * is the dot product operator, and f(s) is the feature vector for state s.
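The definition above can be sketched in plain Java without any BURLAP types. This is a minimal illustration, not the class's actual implementation: the class name LinearRewardSketch, the integer actionIndex, and the flat weight layout (|A| consecutive blocks of |f| weights, with the block for action k starting at k * |f|) are all assumptions made for the example.

```java
/** Hypothetical stand-in illustrating R(s, a, s') = w(a) * f(s); not part of BURLAP. */
public final class LinearRewardSketch {

    private LinearRewardSketch() {}

    /**
     * Computes the dot product of the weight block for one action with the state features.
     * Assumed layout: parameters holds |A| consecutive blocks of |f| weights,
     * and the block for action k starts at index k * |f|.
     */
    public static double reward(double[] parameters, int actionIndex, double[] stateFeatures) {
        int offset = actionIndex * stateFeatures.length; // start of w(a) in the flat array
        double sum = 0.0;
        for (int i = 0; i < stateFeatures.length; i++) {
            sum += parameters[offset + i] * stateFeatures[i];
        }
        return sum;
    }
}
```

With two actions, three features, weights {1.0, 2.0, 3.0, 0.5, -1.0, 4.0} and f(s) = {1.0, 0.0, 2.0}, the sketch gives 7.0 for action index 0 and 8.5 for action index 1. Note that s' does not appear in the computation, which matches the definition: the reward depends only on s and a.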
Note that the gradient is a vector of size |A||f|, since the feature vector is replicated for each action. Every entry associated with an action other than the one taken in the (s, a, s') query has a gradient value of zero.
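Because the reward is linear in the weights, the partial derivative with respect to each weight of the taken action is simply the matching state feature, and all other entries are zero. The following sketch shows that structure with a dense array. It assumes the same hypothetical flat layout as a block of |f| entries per action; the class name LinearGradientSketch and the integer actionIndex are illustrative and do not come from BURLAP, whose gradient method returns a FunctionGradient rather than a raw array.

```java
/** Hypothetical stand-in illustrating the gradient structure; not part of BURLAP. */
public final class LinearGradientSketch {

    private LinearGradientSketch() {}

    /**
     * Builds a dense gradient of length |A| * |f|.
     * The block belonging to the taken action is a copy of f(s);
     * every other entry keeps Java's default value of 0.0.
     */
    public static double[] gradient(int numActions, int actionIndex, double[] stateFeatures) {
        int numFeatures = stateFeatures.length;
        double[] grad = new double[numActions * numFeatures];
        // Copy f(s) into the block for the taken action (assumed offset: actionIndex * |f|).
        System.arraycopy(stateFeatures, 0, grad, actionIndex * numFeatures, numFeatures);
        return grad;
    }
}
```

For two actions and f(s) = {1.0, 0.0, 2.0}, taking action index 1 yields {0.0, 0.0, 0.0, 1.0, 0.0, 2.0}.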
The set of possible grounded actions must be defined either in the LinearStateActionDifferentiableRF(DenseStateFeatures, int, Action...) constructor, or added iteratively with the addAction(burlap.mdp.core.action.Action) method.
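The two ways of registering actions can be illustrated with a small stand-in class. This is only a sketch of one plausible bookkeeping scheme, not BURLAP's code: the class ActionOrderingSketch, the use of String names in place of Action objects, the indexOf helper, and the choice to grow the parameter array on every new action are all assumptions made for the example.

```java
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.Map;

/** Hypothetical stand-in for action registration; not part of BURLAP. */
public final class ActionOrderingSketch {

    // Maps each action (a String here, an assumption) to its position in the ordering.
    private final Map<String, Integer> actionIndex = new LinkedHashMap<>();
    private final int numStateFeatures;
    private double[] parameters = new double[0];

    /** Registers any actions known up front; more can be added later with addAction. */
    public ActionOrderingSketch(int numStateFeatures, String... allPossibleActions) {
        this.numStateFeatures = numStateFeatures;
        for (String action : allPossibleActions) {
            addAction(action);
        }
    }

    /** Adds an action if it is new and extends the weights by |f| zero-initialized entries. */
    public void addAction(String action) {
        if (actionIndex.containsKey(action)) {
            return; // already registered, keep its existing index
        }
        actionIndex.put(action, actionIndex.size());
        parameters = Arrays.copyOf(parameters, actionIndex.size() * numStateFeatures);
    }

    /** Total number of weights, |A| * |f|. */
    public int numParameters() {
        return parameters.length;
    }

    /** Position of a registered action in the ordering. */
    public int indexOf(String action) {
        return actionIndex.get(action);
    }
}
```

In this sketch, constructing with three state features and two actions gives six parameters, and adding a third action afterwards raises the count to nine, consistent with the |A|*|f| dimensionality described above.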
Nested classes/interfaces inherited from interface ParametricFunction: ParametricFunction.ParametricStateActionFunction, ParametricFunction.ParametricStateFunction
Modifier and Type | Field and Description |
---|---|
protected java.util.Map<Action,java.lang.Integer> | actionMap - An ordering of grounded actions |
protected int | dim - The dimension of this reward function |
protected DenseStateFeatures | fvGen - The state feature vector generator to use |
protected int | numStateFeatures - The number of state features |
protected double[] | parameters - The parameters of this reward function |
Constructor and Description |
---|
LinearStateActionDifferentiableRF(DenseStateFeatures stateFeatures, int numStateFeatures, Action... allPossibleActions) - Initializes. |
Modifier and Type | Method and Description |
---|---|
void | addAction(Action ga) - Adds a possible grounded action. |
ParametricFunction | copy() - Returns a copy of this ParametricFunction. |
double | getParameter(int i) - Returns the value of the ith parameter. |
FunctionGradient | gradient(State s, Action a, State sprime) |
int | numParameters() - Returns the number of parameters defining this function. |
void | resetParameters() - Resets the parameters of this function to default values. |
double | reward(State s, Action a, State sprime) - Returns the reward received when action a is executed in state s and the agent transitions to state sprime. |
void | setParameter(int i, double p) - Sets the value of the ith parameter to the given value. |
java.lang.String | toString() |
protected java.util.Map<Action,java.lang.Integer> actionMap
protected double[] parameters
protected int dim
protected DenseStateFeatures fvGen
protected int numStateFeatures
public LinearStateActionDifferentiableRF(DenseStateFeatures stateFeatures, int numStateFeatures, Action... allPossibleActions)
Initializes. The set of possible grounded actions must be defined either here or added iteratively with the addAction(Action) method.
Parameters:
stateFeatures - the state feature vector generator
numStateFeatures - the dimensionality of the state feature vector
allPossibleActions - the set of possible grounded actions.
public void addAction(Action ga)
Adds a possible grounded action.
Parameters:
ga - the possible grounded action to add to this reward function's definition.
public double reward(State s, Action a, State sprime)
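Description copied from interface: RewardFunction
Returns the reward received when action a is executed in state s and the agent transitions to state sprime.
Specified by:
reward in interface RewardFunction
Parameters:
s - the state in which the action was executed
a - the action executed
sprime - the state to which the agent transitioned
public FunctionGradient gradient(State s, Action a, State sprime)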
Specified by:
gradient in interface DifferentiableRF
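public int numParameters()
Description copied from interface: ParametricFunction
Returns the number of parameters defining this function.
Specified by:
numParameters in interface ParametricFunction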
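public double getParameter(int i)
Description copied from interface: ParametricFunction
Returns the value of the ith parameter.
Specified by:
getParameter in interface ParametricFunction
Parameters:
i - the parameter index
public void setParameter(int i, double p)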
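Description copied from interface: ParametricFunction
Sets the value of the ith parameter to the given value.
Specified by:
setParameter in interface ParametricFunction
Parameters:
i - the index of the parameter to set
p - the parameter value to which it should be set
public void resetParameters()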
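Description copied from interface: ParametricFunction
Resets the parameters of this function to default values.
Specified by:
resetParameters in interface ParametricFunction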
public ParametricFunction copy()
Description copied from interface: ParametricFunction
Returns a copy of this ParametricFunction.
Specified by:
copy in interface ParametricFunction
Returns:
a copy of this ParametricFunction.
public java.lang.String toString()
Overrides:
toString in class java.lang.Object