public class LSPI extends OOMDPPlanner implements QComputablePlanner, LearningAgent
Rather than using the planFromState(State) or runLearningEpisodeFrom(State) methods directly, you should instead use a SARSCollector object to gather a set of example state-action-reward-state (SARS) tuples that are then used for policy iteration. You can set the dataset with the setDataset(SARSData) method and then run LSPI on it with the runPolicyIteration(int, double) method. LSPI requires initializing a matrix to an identity matrix multiplied by some large positive constant (see the reference for more information). By default this constant is 100, but you can change it with the setIdentityScalar(double) method.
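The offline workflow described above can be sketched as follows. This is an illustrative fragment, not a complete program: `domain`, `rf`, `tf`, `fd`, and `sg` (a state generator) are assumed to be already constructed for your problem, imports are omitted, and the exact data-collection call may differ across BURLAP versions.

```java
// Assumed pre-built: domain, rf, tf, fd (FeatureDatabase), sg (StateGenerator).
LSPI lspi = new LSPI(domain, rf, tf, 0.99, fd);

// Gather example SARS tuples with a collector rather than calling
// planFromState directly. The collection-method signature shown here
// is an assumption; check your BURLAP version's SARSCollector Javadoc.
SARSCollector collector = new SARSCollector.UniformRandomSARSCollector(domain);
SARSData data = collector.collectNInstances(sg, rf, 5000, 20, tf, null);

lspi.setDataset(data);
lspi.setIdentityScalar(100.);      // the default; shown for clarity
lspi.runPolicyIteration(30, 1e-6); // max iterations, max weight change
```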
If you do use the planFromState(State) method, it will work by creating a SARSCollector.UniformRandomSARSCollector, collecting SARS data from the input state, and then calling the runPolicyIteration(int, double) method. You can change the SARSCollector this method uses, the number of samples it acquires, the maximum weight change for PI termination, and the maximum number of policy iterations with the setPlanningCollector(SARSCollector), setNumSamplesForPlanning(int), setMaxChange(double), and setMaxNumPlanningIterations(int) methods, respectively.
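A minimal sketch of configuring planFromState with the setters just listed; `lspi`, `customCollector`, and `initialState` are illustrative pre-built objects, and the particular parameter values are arbitrary examples, not recommendations.

```java
// Configure how planFromState gathers data and terminates.
lspi.setPlanningCollector(customCollector); // replaces the default uniform-random collector
lspi.setNumSamplesForPlanning(10000);       // SARS tuples to gather from the input state
lspi.setMaxChange(1e-6);                    // weight-change threshold for PI termination
lspi.setMaxNumPlanningIterations(30);       // cap on policy iterations
lspi.planFromState(initialState);
```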
If you do use the runLearningEpisodeFrom(State) method (or the runLearningEpisodeFrom(State, int) method), it will work by following a learning policy for the episode and adding its observations to its dataset for its policy iteration. After enough new data has been acquired, policy iteration will be rerun. You can adjust the learning policy, the maximum number of allowed learning steps in an episode, and the minimum number of new observations before LSPI is rerun with the setLearningPolicy(Policy), setMaxLearningSteps(int), and setMinNewStepsForLearningPI(int) methods, respectively. The LSPI termination parameters are set with the same methods used to adjust the results of the planFromState(State) method discussed above.
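The online learning loop described above might look like this. Again a hedged fragment: `lspi` and `initialState` are assumed pre-built, the episode count and parameter values are illustrative, and the EpsilonGreedy constructor shown is the one I believe BURLAP provides (verify against your version's Javadoc).

```java
// 0.1 epsilon-greedy is already the default learning policy; set
// explicitly here only to show the hook.
lspi.setLearningPolicy(new EpsilonGreedy(lspi, 0.1));
lspi.setMaxLearningSteps(1000);        // cap steps per episode
lspi.setMinNewStepsForLearningPI(500); // rerun LSPI after 500 new observations

for (int i = 0; i < 100; i++) {
    // Each episode follows the learning policy and appends its
    // observations to the dataset; LSPI reruns once enough new
    // data has accumulated.
    EpisodeAnalysis ea = lspi.runLearningEpisodeFrom(initialState);
}
```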
This data-gathering and replanning behavior from learning episodes is not expected to be an especially good choice. Therefore, if you want better online data acquisition, you should consider subclassing this class and overriding the methods updateDatasetWithLearningEpisode(EpisodeAnalysis) and shouldRereunPolicyIteration(EpisodeAnalysis), or the runLearningEpisodeFrom(State, int) method itself.
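A subclassing sketch along the lines suggested above. The overridden behavior here (rerunning LSPI after every episode) is purely illustrative; note that the overridden method name retains the spelling used by the actual API, shouldRereunPolicyIteration.

```java
// Illustrative subclass customizing online data handling.
public class EagerLSPI extends LSPI {

    public EagerLSPI(Domain domain, RewardFunction rf, TerminalFunction tf,
                     double gamma, FeatureDatabase fd) {
        super(domain, rf, tf, gamma, fd);
    }

    @Override
    protected void updateDatasetWithLearningEpisode(EpisodeAnalysis ea) {
        // Keep the default behavior of appending the episode's SARS
        // tuples; a real subclass might prune old data here instead.
        super.updateDatasetWithLearningEpisode(ea);
    }

    @Override
    protected boolean shouldRereunPolicyIteration(EpisodeAnalysis ea) {
        // Rerun LSPI after every episode instead of waiting for the
        // minNewStepsForLearningPI threshold.
        return true;
    }
}
```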
1. Lagoudakis, Michail G., and Ronald Parr. "Least-Squares Policy Iteration." Journal of Machine Learning Research 4 (2003): 1107-1149.

Modifier and Type | Class and Description |
---|---|
protected class |
LSPI.SSFeatures
Pair of the state-action features and the next state-action features.
|
QComputablePlanner.QComputablePlannerHelper
LearningAgent.LearningAgentBookKeeping
Modifier and Type | Field and Description |
---|---|
protected SARSData |
dataset
The SARS dataset on which LSPI is performed
|
protected java.util.LinkedList<EpisodeAnalysis> |
episodeHistory
the saved previous learning episodes
|
protected FeatureDatabase |
featureDatabase
The state feature database on which the linear VFA is performed
|
protected double |
identityScalar
The initial LSPI identity matrix scalar; default is 100.
|
protected org.ejml.simple.SimpleMatrix |
lastWeights
The last weight values set from LSTDQ
|
protected Policy |
learningPolicy
The learning policy followed in
runLearningEpisodeFrom(State) method calls. |
protected double |
maxChange
The maximum change in weights permitted to terminate LSPI.
|
protected int |
maxLearningSteps
The maximum number of learning steps in an episode when the
runLearningEpisodeFrom(State) method is called. |
protected int |
maxNumPlanningIterations
The maximum number of policy iterations permitted when LSPI is run from the
planFromState(State) or runLearningEpisodeFrom(State) methods. |
protected int |
minNewStepsForLearningPI
The minimum number of new observations received from learning episodes before LSPI will be run again.
|
protected int |
numEpisodesToStore
The number of the most recent learning episodes to store.
|
protected int |
numSamplesForPlanning
the number of samples that are acquired for this object's dataset when the
planFromState(State) method is called. |
protected int |
numStepsSinceLastLearningPI
Number of new observations received from learning episodes since LSPI was run
|
protected SARSCollector |
planningCollector
The data collector used by the
planFromState(State) method. |
protected ValueFunctionApproximation |
vfa
The object that performs value function approximation given the weights that are estimated
|
actions, containsParameterizedActions, debugCode, domain, gamma, hashingFactory, mapToStateIndex, rf, tf
Constructor and Description |
---|
LSPI(Domain domain,
RewardFunction rf,
TerminalFunction tf,
double gamma,
FeatureDatabase fd)
Initializes for the given domain, reward function, terminal state function, discount factor and the feature database that provides the state features used by LSPI.
|
Modifier and Type | Method and Description |
---|---|
protected java.util.List<GroundedAction> |
gaListWrapper(AbstractGroundedAction ga)
Wraps a
GroundedAction in a list of size 1. |
java.util.List<EpisodeAnalysis> |
getAllStoredLearningEpisodes()
Returns all saved
EpisodeAnalysis objects of which the agent has kept track. |
SARSData |
getDataset()
Returns the dataset this object uses for LSPI
|
FeatureDatabase |
getFeatureDatabase()
Returns the feature database defining state features
|
double |
getIdentityScalar()
Returns the initial LSPI identity matrix scalar used
|
EpisodeAnalysis |
getLastLearningEpisode()
Returns the last learning episode of the agent.
|
Policy |
getLearningPolicy()
The learning policy followed by the
runLearningEpisodeFrom(State) and runLearningEpisodeFrom(State, int) methods. |
double |
getMaxChange()
The maximum change in weights required to terminate policy iteration when called from the
planFromState(State) , runLearningEpisodeFrom(State) or runLearningEpisodeFrom(State, int) methods. |
int |
getMaxLearningSteps()
The maximum number of learning steps permitted by the
runLearningEpisodeFrom(State) method. |
int |
getMaxNumPlanningIterations()
The maximum number of policy iterations that will be used by the
planFromState(State) method. |
int |
getMinNewStepsForLearningPI()
The minimum number of new learning observations before policy iteration is run again.
|
int |
getNumSamplesForPlanning()
Gets the number of SARS samples that will be gathered by the
planFromState(State) method. |
SARSCollector |
getPlanningCollector()
Gets the
SARSCollector used by the planFromState(State) method for collecting data. |
QValue |
getQ(State s,
AbstractGroundedAction a)
Returns the
QValue for the given state-action pair. |
protected QValue |
getQFromFeaturesFor(java.util.List<ActionApproximationResult> results,
State s,
GroundedAction ga)
Creates a Q-value object in which the Q-value is determined from VFA.
|
java.util.List<QValue> |
getQs(State s)
Returns a
List of QValue objects for every permissible action for the given input state. |
org.ejml.simple.SimpleMatrix |
LSTDQ()
Runs LSTDQ on this object's current
SARSData dataset. |
protected org.ejml.simple.SimpleMatrix |
phiConstructor(java.util.List<ActionFeaturesQuery> features,
int nf)
Constructs the state-action feature vector as a
SimpleMatrix . |
void |
planFromState(State initialState)
This method will cause the planner to begin planning from the specified initial state
|
void |
resetPlannerResults()
Use this method to reset all planner results so that planning can be started fresh with a call to
OOMDPPlanner.planFromState(State)
as if no planning had ever been performed before. |
EpisodeAnalysis |
runLearningEpisodeFrom(State initialState)
Causes the agent to perform a learning episode starting in the given initial state.
|
EpisodeAnalysis |
runLearningEpisodeFrom(State initialState,
int maxSteps)
Causes the agent to perform a learning episode starting in the given initial state.
|
void |
runPolicyIteration(int numIterations,
double maxChange)
Runs LSPI for either numIterations or until the change in the weight matrix is no greater than maxChange.
|
void |
setDataset(SARSData dataset)
Sets the SARS dataset this object will use for LSPI
|
void |
setFeatureDatabase(FeatureDatabase featureDatabase)
Sets the feature database defining state features
|
void |
setIdentityScalar(double identityScalar)
Sets the initial LSPI identity matrix scalar used.
|
void |
setLearningPolicy(Policy learningPolicy)
Sets the learning policy followed by the
runLearningEpisodeFrom(State) and runLearningEpisodeFrom(State, int) methods. |
void |
setMaxChange(double maxChange)
Sets the maximum change in weights required to terminate policy iteration when called from the
planFromState(State) , runLearningEpisodeFrom(State) or runLearningEpisodeFrom(State, int) methods. |
void |
setMaxLearningSteps(int maxLearningSteps)
Sets the maximum number of learning steps permitted by the
runLearningEpisodeFrom(State) method. |
void |
setMaxNumPlanningIterations(int maxNumPlanningIterations)
Sets the maximum number of policy iterations that will be used by the
planFromState(State) method. |
void |
setMinNewStepsForLearningPI(int minNewStepsForLearningPI)
Sets the minimum number of new learning observations before policy iteration is run again.
|
void |
setNumEpisodesToStore(int numEps)
Tells the agent how many
EpisodeAnalysis objects representing learning episodes to internally store. |
void |
setNumSamplesForPlanning(int numSamplesForPlanning)
Sets the number of SARS samples that will be gathered by the
planFromState(State) method. |
void |
setPlanningCollector(SARSCollector planningCollector)
Sets the
SARSCollector used by the planFromState(State) method for collecting data. |
protected boolean |
shouldRereunPolicyIteration(EpisodeAnalysis ea)
Returns whether LSPI should be rerun given the latest learning episode results.
|
protected void |
updateDatasetWithLearningEpisode(EpisodeAnalysis ea)
Updates this object's
SARSData to include the results of a learning episode. |
addNonDomainReferencedAction, getActions, getAllGroundedActions, getDebugCode, getDomain, getGamma, getHashingFactory, getRf, getRF, getTf, getTF, plannerInit, setActions, setDebugCode, setDomain, setGamma, setRf, setTf, stateHash, toggleDebugPrinting, translateAction
protected ValueFunctionApproximation vfa
The object that performs value function approximation given the weights that are estimated

protected SARSData dataset
The SARS dataset on which LSPI is performed

protected FeatureDatabase featureDatabase
The state feature database on which the linear VFA is performed

protected double identityScalar
The initial LSPI identity matrix scalar; default is 100.

protected org.ejml.simple.SimpleMatrix lastWeights
The last weight values set from LSTDQ

protected int numSamplesForPlanning
The number of samples that are acquired for this object's dataset when the planFromState(State) method is called.

protected double maxChange
The maximum change in weights permitted to terminate LSPI.

protected SARSCollector planningCollector
The data collector used by the planFromState(State) method.

protected int maxNumPlanningIterations
The maximum number of policy iterations permitted when LSPI is run from the planFromState(State) or runLearningEpisodeFrom(State) methods.

protected Policy learningPolicy
The learning policy followed in runLearningEpisodeFrom(State) method calls. Default is 0.1 epsilon greedy.

protected int maxLearningSteps
The maximum number of learning steps in an episode when the runLearningEpisodeFrom(State) method is called. Default is INT_MAX.

protected int numStepsSinceLastLearningPI
Number of new observations received from learning episodes since LSPI was run

protected int minNewStepsForLearningPI
The minimum number of new observations received from learning episodes before LSPI will be run again.

protected java.util.LinkedList<EpisodeAnalysis> episodeHistory
the saved previous learning episodes

protected int numEpisodesToStore
The number of the most recent learning episodes to store.
public LSPI(Domain domain, RewardFunction rf, TerminalFunction tf, double gamma, FeatureDatabase fd)
domain - the problem domain
rf - the reward function
tf - the terminal state function
gamma - the discount factor
fd - the feature database defining state features on which LSPI will run.

public void setDataset(SARSData dataset)
dataset - the SARS dataset

public SARSData getDataset()

public FeatureDatabase getFeatureDatabase()

public void setFeatureDatabase(FeatureDatabase featureDatabase)
featureDatabase - the feature database defining state features

public double getIdentityScalar()

public void setIdentityScalar(double identityScalar)
identityScalar - the initial LSPI identity matrix scalar used.

public int getNumSamplesForPlanning()
Gets the number of SARS samples that will be gathered by the planFromState(State) method.

public void setNumSamplesForPlanning(int numSamplesForPlanning)
numSamplesForPlanning - the number of SARS samples that will be gathered by the planFromState(State) method.

public SARSCollector getPlanningCollector()
Gets the SARSCollector used by the planFromState(State) method for collecting data.

public void setPlanningCollector(SARSCollector planningCollector)
planningCollector - the SARSCollector used by the planFromState(State) method for collecting data.

public int getMaxNumPlanningIterations()
Gets the maximum number of policy iterations that will be used by the planFromState(State) method.

public void setMaxNumPlanningIterations(int maxNumPlanningIterations)
maxNumPlanningIterations - the maximum number of policy iterations that will be used by the planFromState(State) method.

public Policy getLearningPolicy()
Gets the learning policy followed by the runLearningEpisodeFrom(State) and runLearningEpisodeFrom(State, int) methods.

public void setLearningPolicy(Policy learningPolicy)
learningPolicy - the learning policy followed by the runLearningEpisodeFrom(State) and runLearningEpisodeFrom(State, int) methods.

public int getMaxLearningSteps()
Gets the maximum number of learning steps permitted by the runLearningEpisodeFrom(State) method.

public void setMaxLearningSteps(int maxLearningSteps)
maxLearningSteps - the maximum number of learning steps permitted by the runLearningEpisodeFrom(State) method.

public int getMinNewStepsForLearningPI()

public void setMinNewStepsForLearningPI(int minNewStepsForLearningPI)
minNewStepsForLearningPI - the minimum number of new learning observations before policy iteration is run again.

public double getMaxChange()
Gets the maximum change in weights required to terminate policy iteration when called from the planFromState(State), runLearningEpisodeFrom(State), or runLearningEpisodeFrom(State, int) methods.

public void setMaxChange(double maxChange)
maxChange - the maximum change in weights required to terminate policy iteration when called from the planFromState(State), runLearningEpisodeFrom(State), or runLearningEpisodeFrom(State, int) methods.

public org.ejml.simple.SimpleMatrix LSTDQ()
Runs LSTDQ on this object's current SARSData dataset and returns the resulting weights as a SimpleMatrix object.

public void runPolicyIteration(int numIterations, double maxChange)
numIterations - the maximum number of policy iterations.
maxChange - when the weight change is smaller than this value, LSPI terminates.

protected org.ejml.simple.SimpleMatrix phiConstructor(java.util.List<ActionFeaturesQuery> features, int nf)
Constructs the state-action feature vector as a SimpleMatrix.
features - the state-action features that have non-zero values
nf - the total number of state-action features.

protected java.util.List<GroundedAction> gaListWrapper(AbstractGroundedAction ga)
Wraps a GroundedAction in a list of size 1 and returns a List consisting of just the input GroundedAction object.
ga - the GroundedAction to wrap.

public java.util.List<QValue> getQs(State s)
Returns a List of QValue objects for every permissible action for the given input state.
Specified by: getQs in interface QComputablePlanner
s - the state for which Q-values are to be returned.

public QValue getQ(State s, AbstractGroundedAction a)
Returns the QValue for the given state-action pair.
Specified by: getQ in interface QComputablePlanner
s - the input state
a - the input action

protected QValue getQFromFeaturesFor(java.util.List<ActionApproximationResult> results, State s, GroundedAction ga)
results - the VFA prediction results for each action.
s - the state of the Q-value
ga - the action taken

public void planFromState(State initialState)
Specified by: planFromState in class OOMDPPlanner
initialState - the initial state of the planning problem

public void resetPlannerResults()
Use this method to reset all planner results so that planning can be started fresh with a call to OOMDPPlanner.planFromState(State) as if no planning had ever been performed before. Specifically, data produced from calls to OOMDPPlanner.planFromState(State) will be cleared, but all other planner settings should remain the same. This is useful if the reward function or transition dynamics have changed, thereby requiring new results to be computed. If there were other objects this planner was provided that may have changed and need to be reset, you will need to reset them yourself. For instance, if you told a planner to follow a policy that had a temperature parameter decrease with time, you will need to reset the policy's temperature yourself.
Specified by: resetPlannerResults in class OOMDPPlanner

public EpisodeAnalysis runLearningEpisodeFrom(State initialState)
Causes the agent to perform a learning episode starting in the given initial state, returned as an EpisodeAnalysis object.
Specified by: runLearningEpisodeFrom in interface LearningAgent
initialState - The initial state in which the agent will start the episode.

public EpisodeAnalysis runLearningEpisodeFrom(State initialState, int maxSteps)
Causes the agent to perform a learning episode starting in the given initial state, returned as an EpisodeAnalysis object.
Specified by: runLearningEpisodeFrom in interface LearningAgent
initialState - The initial state in which the agent will start the episode.
maxSteps - the maximum number of steps in the episode

protected void updateDatasetWithLearningEpisode(EpisodeAnalysis ea)
Updates this object's SARSData to include the results of a learning episode.
ea - the learning episode as an EpisodeAnalysis object.

protected boolean shouldRereunPolicyIteration(EpisodeAnalysis ea)
Returns whether LSPI should be rerun given the latest learning episode results, based on the numStepsSinceLastLearningPI threshold.
ea - the most recent learning episode

public EpisodeAnalysis getLastLearningEpisode()
Specified by: getLastLearningEpisode in interface LearningAgent

public void setNumEpisodesToStore(int numEps)
Tells the agent how many EpisodeAnalysis objects representing learning episodes to internally store. For instance, if the number is set to 5, then the agent will remember the last 5 learning episodes. Note that this number has nothing to do with how learning is performed; it is purely for performance gathering.
Specified by: setNumEpisodesToStore in interface LearningAgent
numEps - the number of learning episodes to remember.

public java.util.List<EpisodeAnalysis> getAllStoredLearningEpisodes()
Returns all saved EpisodeAnalysis objects of which the agent has kept track.
Specified by: getAllStoredLearningEpisodes in interface LearningAgent