public class MLIRLRequest extends IRLRequest

A request object for Maximum-Likelihood Inverse Reinforcement Learning (MLIRL).
This request adds a set of optionally specified weights on the expert trajectories, the DifferentiableRF
to use, and the Boltzmann beta parameter used for differentiable planning. The larger the beta value, the more
deterministic the expert trajectories are assumed to be.

If no expert trajectory weights are provided, they will all be assumed to have a weight of 1. When weights have
not been specified, calls to the getEpisodeWeights() method will create and return a new double array with the
value 1.0 everywhere, so changes to the returned array will not change the weights actually used. Instead, modify
the weights using the setEpisodeWeights(double[]) method (see the sketch after the field summary below).

| Modifier and Type | Field and Description |
|---|---|
| protected double | boltzmannBeta — The parameter used in the Boltzmann policy that affects how noisy the expert is assumed to be. |
| protected double[] | episodeWeights — The weight assigned to each episode. |
| protected DifferentiableRF | rf — The differentiable reward function model that will be estimated by MLIRL. |
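The weight semantics above are easy to trip over, so here is a minimal Java sketch, assuming an already-constructed request (see the constructors below) and a caller-supplied episode count; only the MLIRLRequest members used are the ones documented on this page.

```java
// Minimal sketch of the episode-weight semantics; `request` and
// `numEpisodes` are assumed to be supplied by the caller.
static void reweightEpisodes(MLIRLRequest request, int numEpisodes) {
    // With no weights set, getEpisodeWeights() creates and returns a NEW
    // array of 1.0s, so mutating the returned array has no effect.
    double[] copy = request.getEpisodeWeights();
    copy[0] = 5.0; // does NOT change the weights the request will use

    // To actually re-weight episodes, pass a full array back in.
    double[] weights = new double[numEpisodes];
    java.util.Arrays.fill(weights, 1.0);
    weights[0] = 5.0; // e.g., emphasize the first demonstration
    request.setEpisodeWeights(weights);

    // A larger beta assumes a more deterministic (near-optimal) expert.
    request.setBoltzmannBeta(10.0);
}
```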
Fields inherited from class IRLRequest: domain, expertEpisodes, gamma, planner

| Constructor and Description |
|---|
| MLIRLRequest(Domain domain, java.util.List<EpisodeAnalysis> expertEpisodes, DifferentiableRF rf, StateHashFactory hashingFactory) — Initializes without any expert trajectory weights (which will all be assumed to have a weight of 1) and requests a default QGradientPlanner instance to be created using the StateHashFactory provided. |
| MLIRLRequest(Domain domain, OOMDPPlanner planner, java.util.List<EpisodeAnalysis> expertEpisodes, DifferentiableRF rf) — Initializes the request without any expert trajectory weights (which will all be assumed to have a weight of 1). |
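To make the two construction paths concrete, here is a hedged sketch; every argument is assumed to be built elsewhere, and the helper name buildRequest is illustrative rather than part of the library.

```java
// Illustrative helper showing the two constructors listed above.
static MLIRLRequest buildRequest(Domain domain,
                                 java.util.List<EpisodeAnalysis> expertEpisodes,
                                 DifferentiableRF rf,
                                 OOMDPPlanner gradientPlanner,  // must implement QGradientPlanner
                                 StateHashFactory hashingFactory,
                                 boolean useOwnPlanner) {
    if (useOwnPlanner) {
        // Throws an exception if gradientPlanner does not implement
        // the QGradientPlanner interface.
        return new MLIRLRequest(domain, gradientPlanner, expertEpisodes, rf);
    }
    // Otherwise a default QGradientPlanner (a DifferentiableVI with
    // gamma 0.99, value-function tolerance 0.01, and at most 500
    // iterations) is created from the hashing factory.
    return new MLIRLRequest(domain, expertEpisodes, rf, hashingFactory);
}
```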
| Modifier and Type | Method and Description |
|---|---|
| double | getBoltzmannBeta() |
| double[] | getEpisodeWeights() — Returns the expert episode weights. |
| DifferentiableRF | getRf() |
| boolean | isValid() — Returns true if this request object has valid data members set; false otherwise. |
| void | setBoltzmannBeta(double boltzmannBeta) |
| void | setEpisodeWeights(double[] episodeWeights) |
| void | setPlanner(OOMDPPlanner p) |
| void | setRf(DifferentiableRF rf) |

Methods inherited from class IRLRequest: getDomain, getExpertEpisodes, getGamma, getPlanner, setDomain, setExpertEpisodes, setGamma

protected double[] episodeWeights

The weight assigned to each episode.
protected double boltzmannBeta

The parameter used in the Boltzmann policy that affects how noisy the expert is assumed to be.

protected DifferentiableRF rf

The differentiable reward function model that will be estimated by MLIRL.
public MLIRLRequest(Domain domain, OOMDPPlanner planner, java.util.List<EpisodeAnalysis> expertEpisodes, DifferentiableRF rf)

Initializes the request without any expert trajectory weights (which will all be assumed to have a weight of 1). If the provided planner does not implement the QGradientPlanner interface, an exception will be thrown.

Parameters:
domain - the domain in which trajectories are provided.
planner - a planner that implements the QGradientPlanner interface.
expertEpisodes - the expert episodes/trajectories to use for training.
rf - the DifferentiableRF model to use.

public MLIRLRequest(Domain domain, java.util.List<EpisodeAnalysis> expertEpisodes, DifferentiableRF rf, StateHashFactory hashingFactory)

Initializes without any expert trajectory weights (which will all be assumed to have a weight of 1) and requests a default QGradientPlanner instance to be created using the StateHashFactory provided. The QGradientPlanner instance will be a DifferentiableVI that plans either until the maximum change in the value function is no greater than 0.01 or until 500 iterations have been performed. A default gamma (discount) value of 0.99 will be used for the planner and no terminal states will be used.

Parameters:
domain - the domain in which trajectories are provided.
expertEpisodes - the expert episodes/trajectories to use for training.
rf - the DifferentiableRF model to use.
hashingFactory - the state hashing factory to use for the created planner.

public boolean isValid()
Returns true if this request object has valid data members set; false otherwise.

Overrides:
isValid in class IRLRequest

public void setPlanner(OOMDPPlanner p)

Overrides:
setPlanner in class IRLRequest

public double[] getEpisodeWeights()

Returns the expert episode weights. When weights have not been specified, a new array with the value 1.0 everywhere is created and returned.

public double getBoltzmannBeta()

public DifferentiableRF getRf()

public void setEpisodeWeights(double[] episodeWeights)

public void setBoltzmannBeta(double boltzmannBeta)

public void setRf(DifferentiableRF rf)
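As a closing usage note, a minimal pre-flight check before handing the request off for learning might look like the sketch below, using only the accessors documented above; the helper name and the error message are illustrative.

```java
// Illustrative validity check using only members documented on this page.
static void requireValid(MLIRLRequest request) {
    if (!request.isValid()) {
        throw new IllegalStateException(
                "MLIRLRequest is missing required data members"
                + " (e.g., domain, expert episodes, planner, or reward function).");
    }
    System.out.println("beta = " + request.getBoltzmannBeta()
            + ", rf = " + request.getRf()
            + ", episodes weighted = " + request.getEpisodeWeights().length);
}
```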