public class MLIRLRequest extends IRLRequest

A request object for Maximum-Likelihood Inverse Reinforcement Learning (MLIRL).
This request adds a set of optionally specified weights on the expert trajectories, the DifferentiableRF to use, and the Boltzmann beta parameter used for differentiable planning. The larger the beta value, the more deterministic the expert trajectories are assumed to be.
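The role of beta can be illustrated outside of BURLAP with a standalone Boltzmann (softmax) computation, where P(a | s) is proportional to exp(beta * Q(s, a)). This is a minimal sketch, not BURLAP code; the two Q-values are made up for illustration:

```java
// Minimal sketch of a Boltzmann (softmax) action distribution:
// P(a | s) ∝ exp(beta * Q(s, a)). Not BURLAP code; Q-values are invented.
public class BoltzmannBetaDemo {

    // Probability of the first of two actions, given their Q-values and beta.
    public static double firstActionProb(double q0, double q1, double beta) {
        double e0 = Math.exp(beta * q0);
        double e1 = Math.exp(beta * q1);
        return e0 / (e0 + e1);
    }

    public static void main(String[] args) {
        double q0 = 1.0, q1 = 0.0;
        // Small beta: the expert looks noisy (probability near uniform, ≈ 0.52).
        System.out.println(firstActionProb(q0, q1, 0.1));
        // Large beta: the expert looks nearly deterministic (probability ≈ 1).
        System.out.println(firstActionProb(q0, q1, 10.0));
    }
}
```

As beta grows, the probability mass concentrates on the higher-valued action, which is why a large beta encodes the assumption that the expert trajectories are close to optimal.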
If no expert trajectory weights are provided, they will all be assumed to have a weight of 1. Calls to the getEpisodeWeights() method when weights have not been specified will return a newly created double array with the value 1.0 everywhere, so changes to the returned array will not change the weights actually used. Instead, modify the weights using the setEpisodeWeights(double[]) method.
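The getEpisodeWeights() contract described above can be illustrated outside of BURLAP with a small stand-in class. This is not the actual MLIRLRequest implementation; FakeRequest is a made-up name showing the same defensive pattern:

```java
// Stand-in illustrating the getEpisodeWeights() contract described above.
// NOT the BURLAP implementation; just the same defensive-copy pattern.
public class WeightsDemo {

    public static class FakeRequest {
        private double[] episodeWeights; // null until explicitly set
        private final int numEpisodes;

        public FakeRequest(int numEpisodes) { this.numEpisodes = numEpisodes; }

        // When no weights were set, a fresh array of 1.0s is returned,
        // so callers mutating it cannot change the weights actually used.
        public double[] getEpisodeWeights() {
            if (episodeWeights == null) {
                double[] ones = new double[numEpisodes];
                java.util.Arrays.fill(ones, 1.0);
                return ones;
            }
            return episodeWeights;
        }

        public void setEpisodeWeights(double[] w) { this.episodeWeights = w; }
    }

    public static void main(String[] args) {
        FakeRequest req = new FakeRequest(3);
        double[] w = req.getEpisodeWeights();
        w[0] = 5.0; // mutating the returned default array...
        System.out.println(req.getEpisodeWeights()[0]); // ...has no effect: prints 1.0
        req.setEpisodeWeights(new double[]{5.0, 1.0, 1.0}); // the supported way
        System.out.println(req.getEpisodeWeights()[0]); // prints 5.0
    }
}
```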
Modifier and Type | Field and Description |
---|---|
protected double | boltzmannBeta: The parameter used in the Boltzmann policy that affects how noisy the expert is assumed to be. |
protected double[] | episodeWeights: The weight assigned to each episode. |
protected DifferentiableRF | rf: The differentiable reward function model that will be estimated by MLIRL. |
Fields inherited from class IRLRequest: domain, expertEpisodes, gamma, planner
Constructor and Description |
---|
MLIRLRequest(SADomain domain, java.util.List<Episode> expertEpisodes, DifferentiableRF rf, HashableStateFactory hashingFactory): Initializes without any expert trajectory weights (which will be assumed to have a value of 1) and requests a default DifferentiableQFunction instance to be created using the HashableStateFactory provided. |
MLIRLRequest(SADomain domain, Planner planner, java.util.List<Episode> expertEpisodes, DifferentiableRF rf): Initializes the request without any expert trajectory weights (which will be assumed to have a value of 1). |
Modifier and Type | Method and Description |
---|---|
double | getBoltzmannBeta() |
double[] | getEpisodeWeights(): Returns the expert episode weights. |
DifferentiableRF | getRf() |
boolean | isValid(): Returns true if this request object has valid data members set; false otherwise. |
void | setBoltzmannBeta(double boltzmannBeta) |
void | setEpisodeWeights(double[] episodeWeights) |
void | setPlanner(Planner p) |
void | setRf(DifferentiableRF rf) |

Methods inherited from class IRLRequest: getDomain, getExpertEpisodes, getGamma, getPlanner, setDomain, setExpertEpisodes, setGamma
Field Detail

protected double[] episodeWeights
The weight assigned to each episode.

protected double boltzmannBeta
The parameter used in the Boltzmann policy that affects how noisy the expert is assumed to be.

protected DifferentiableRF rf
The differentiable reward function model that will be estimated by MLIRL.
Constructor Detail

public MLIRLRequest(SADomain domain, Planner planner, java.util.List<Episode> expertEpisodes, DifferentiableRF rf)
Initializes the request without any expert trajectory weights (which will be assumed to have a value of 1). If the provided planner does not implement the DifferentiableQFunction interface, an exception will be thrown.
Parameters:
domain - the domain in which trajectories are provided.
planner - a valueFunction that implements the DifferentiableQFunction interface.
expertEpisodes - the expert episodes/trajectories to use for training.
rf - the DifferentiableRF model to use.

public MLIRLRequest(SADomain domain, java.util.List<Episode> expertEpisodes, DifferentiableRF rf, HashableStateFactory hashingFactory)
Initializes without any expert trajectory weights (which will be assumed to have a value of 1) and requests a default DifferentiableQFunction instance to be created using the HashableStateFactory provided. The DifferentiableQFunction instance will be a DifferentiableVI that plans either until the maximum change in the value function is no greater than 0.01 or until 500 iterations have been performed. A default gamma (discount) value of 0.99 will be used for the valueFunction and no terminal states will be used.
Parameters:
domain - the domain in which trajectories are provided.
expertEpisodes - the expert episodes/trajectories to use for training.
rf - the DifferentiableRF model to use.
hashingFactory - the state hashing factory to use for the created valueFunction.

Method Detail

public boolean isValid()
Returns true if this request object has valid data members set; false otherwise.
Overrides:
isValid in class IRLRequest
public void setPlanner(Planner p)
Overrides:
setPlanner in class IRLRequest
public double[] getEpisodeWeights()
public double getBoltzmannBeta()
public DifferentiableRF getRf()
public void setEpisodeWeights(double[] episodeWeights)
public void setBoltzmannBeta(double boltzmannBeta)
public void setRf(DifferentiableRF rf)