public class MLIRLRequest extends IRLRequest

A request object for Maximum-Likelihood Inverse Reinforcement Learning (MLIRL). This request adds a set of optionally specified weights on the expert trajectories, the DifferentiableRF to use, and the Boltzmann beta parameter used for differentiable planning. The larger the beta value, the more deterministic the expert trajectories are assumed to be. If no expert trajectory weights are provided, they will all be assumed to have a weight of 1. Calling the getEpisodeWeights() method when weights have not been specified will create and return a new double array with the value 1.0 everywhere, so changes to the returned array will not change the weights actually used. Instead, modify the weights using the setEpisodeWeights(double[]) method.

Modifier and Type | Field and Description
---|---
`protected double` | `boltzmannBeta` The parameter used in the Boltzmann policy that affects how noisy the expert is assumed to be.
`protected double[]` | `episodeWeights` The weight assigned to each episode.
`protected DifferentiableRF` | `rf` The differentiable reward function model that will be estimated by MLIRL.

Fields inherited from class IRLRequest: domain, expertEpisodes, gamma, planner
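The role of the Boltzmann beta parameter can be illustrated with a short sketch. The class below is a hypothetical stand-in, not part of the BURLAP API: it computes the Boltzmann (softmax) action distribution pi(a|s) proportional to exp(beta * Q(s,a)) that MLIRL-style methods assume the expert follows, showing that small beta models a noisy expert and large beta a near-deterministic one.

```java
// Hypothetical demo of Boltzmann action selection; not a BURLAP class.
public class BoltzmannBetaDemo {

    /** Returns Boltzmann (softmax) action probabilities for the given Q-values. */
    static double[] boltzmann(double[] qValues, double beta) {
        double max = Double.NEGATIVE_INFINITY;
        for (double q : qValues) max = Math.max(max, q); // subtract max for numerical stability
        double[] p = new double[qValues.length];
        double z = 0.;
        for (int i = 0; i < qValues.length; i++) {
            p[i] = Math.exp(beta * (qValues[i] - max));
            z += p[i];
        }
        for (int i = 0; i < p.length; i++) p[i] /= z;
        return p;
    }

    public static void main(String[] args) {
        double[] q = {1.0, 0.0}; // action 0 has the higher Q-value
        // Small beta: near-uniform choice (noisy expert), roughly 0.52 for action 0.
        System.out.println(boltzmann(q, 0.1)[0]);
        // Large beta: near-deterministic choice, close to 1.0 for action 0.
        System.out.println(boltzmann(q, 10.)[0]);
    }
}
```

With beta near zero every action is roughly equally likely, so expert demonstrations carry little information; as beta grows, the assumed expert almost always picks the highest-value action.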
Constructor and Description |
---|
MLIRLRequest(Domain domain,
java.util.List<EpisodeAnalysis> expertEpisodes,
DifferentiableRF rf,
StateHashFactory hashingFactory)
Initializes without any expert trajectory weights (which will be assumed to have a value 1) and requests
a default
QGradientPlanner instance to be created using
the StateHashFactory provided. |
MLIRLRequest(Domain domain,
OOMDPPlanner planner,
java.util.List<EpisodeAnalysis> expertEpisodes,
DifferentiableRF rf)
Initializes the request without any expert trajectory weights (which will be assumed to have a value 1).
|
Modifier and Type | Method and Description
---|---
`double` | `getBoltzmannBeta()`
`double[]` | `getEpisodeWeights()` Returns the expert episode weights.
`DifferentiableRF` | `getRf()`
`boolean` | `isValid()` Returns true if this request object has valid data members set; false otherwise.
`void` | `setBoltzmannBeta(double boltzmannBeta)`
`void` | `setEpisodeWeights(double[] episodeWeights)`
`void` | `setPlanner(OOMDPPlanner p)`
`void` | `setRf(DifferentiableRF rf)`

Methods inherited from class IRLRequest: getDomain, getExpertEpisodes, getGamma, getPlanner, setDomain, setExpertEpisodes, setGamma
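The interaction between getEpisodeWeights() and setEpisodeWeights(double[]) described above can be sketched with a minimal stand-in class (hypothetical, not the BURLAP class itself): when no weights have been set, the getter returns a freshly created array of 1.0s, so mutating that array has no effect on the request; only the setter changes the stored weights.

```java
import java.util.Arrays;

// Hypothetical stand-in illustrating the documented episode-weight contract.
public class EpisodeWeightsDemo {
    private double[] episodeWeights;   // null until explicitly set
    private final int numEpisodes;

    EpisodeWeightsDemo(int numEpisodes) { this.numEpisodes = numEpisodes; }

    double[] getEpisodeWeights() {
        if (this.episodeWeights == null) {
            double[] w = new double[this.numEpisodes]; // a new array on every call
            Arrays.fill(w, 1.0);
            return w;
        }
        return this.episodeWeights;
    }

    void setEpisodeWeights(double[] w) { this.episodeWeights = w; }

    public static void main(String[] args) {
        EpisodeWeightsDemo req = new EpisodeWeightsDemo(3);
        double[] w = req.getEpisodeWeights();
        w[0] = 5.0;                                   // mutating the returned array...
        System.out.println(req.getEpisodeWeights()[0]); // ...has no effect: prints 1.0
        req.setEpisodeWeights(new double[]{2.0, 1.0, 1.0}); // the supported way
        System.out.println(req.getEpisodeWeights()[0]);     // prints 2.0
    }
}
```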
protected double[] episodeWeights
protected double boltzmannBeta
protected DifferentiableRF rf
public MLIRLRequest(Domain domain, OOMDPPlanner planner, java.util.List<EpisodeAnalysis> expertEpisodes, DifferentiableRF rf)

Initializes the request without any expert trajectory weights (which will be assumed to have a value of 1). If the provided planner does not implement the QGradientPlanner interface, an exception will be thrown.

Parameters:
domain - the domain in which trajectories are provided.
planner - a planner that implements the QGradientPlanner interface.
expertEpisodes - the expert episodes/trajectories to use for training.
rf - the DifferentiableRF model to use.

public MLIRLRequest(Domain domain, java.util.List<EpisodeAnalysis> expertEpisodes, DifferentiableRF rf, StateHashFactory hashingFactory)
Initializes without any expert trajectory weights (which will be assumed to have a value of 1) and requests a default QGradientPlanner instance to be created using the StateHashFactory provided. The QGradientPlanner instance will be a DifferentiableVI that plans either until the maximum change in the value function is no greater than 0.01 or until 500 iterations have been performed. A default gamma (discount) value of 0.99 will be used for the planner, and no terminal states will be used.

Parameters:
domain - the domain in which trajectories are provided.
expertEpisodes - the expert episodes/trajectories to use for training.
rf - the DifferentiableRF model to use.
hashingFactory - the state hashing factory to use for the created planner.

public boolean isValid()
Returns true if this request object has valid data members set; false otherwise.

Overrides:
isValid in class IRLRequest

public void setPlanner(OOMDPPlanner p)

Overrides:
setPlanner in class IRLRequest
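The planner requirement described for the constructor can be sketched as a simple runtime type check. The class and interfaces below are hypothetical stand-ins, not the BURLAP types: the sketch assumes the check is a plain instanceof test that throws when a non-gradient planner is supplied.

```java
// Hypothetical sketch of the QGradientPlanner requirement; not BURLAP code.
public class PlannerCheckDemo {
    public interface OOMDPPlanner { }
    public interface QGradientPlanner extends OOMDPPlanner { }

    /** Accepts only planners that support Q-value gradients. */
    public static void setPlanner(OOMDPPlanner p) {
        if (p != null && !(p instanceof QGradientPlanner)) {
            throw new IllegalArgumentException(
                "MLIRLRequest requires a planner that implements QGradientPlanner");
        }
        // ...store the planner for later use...
    }

    public static void main(String[] args) {
        setPlanner(new QGradientPlanner() { });   // accepted: gradient-capable planner
        try {
            setPlanner(new OOMDPPlanner() { });   // rejected: no gradient support
        } catch (IllegalArgumentException e) {
            System.out.println("rejected: " + e.getMessage());
        }
    }
}
```

The check matters because MLIRL's likelihood gradient is computed from Q-value gradients, which an ordinary planner does not expose.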
public double[] getEpisodeWeights()
public double getBoltzmannBeta()
public DifferentiableRF getRf()
public void setEpisodeWeights(double[] episodeWeights)
public void setBoltzmannBeta(double boltzmannBeta)
public void setRf(DifferentiableRF rf)