public class LSPI extends OOMDPPlanner implements QComputablePlanner, LearningAgent
Rather than calling the planFromState(State) or runLearningEpisodeFrom(State) methods, you should instead use a SARSCollector object to gather a set of example state-action-reward-state (SARS) tuples that are then used for policy iteration. You can set the dataset with the setDataset(SARSData) method and then run LSPI on it with the runPolicyIteration(int, double) method. LSPI requires initializing a matrix to an identity matrix multiplied by some large positive constant (see the reference for more information). By default this constant is 100, but you can change it with the setIdentityScalar(double) method.
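For example, a minimal sketch of this recommended workflow is shown below. The variable names, the discount factor, and the iteration and termination values are illustrative placeholders, BURLAP import statements are omitted, and only methods documented on this page are used.

```java
// Assumes a Domain (domain), RewardFunction (rf), TerminalFunction (tf),
// FeatureDatabase (fd), and an already collected SARSData instance
// (collectedData) exist; all names here are placeholders.
LSPI lspi = new LSPI(domain, rf, tf, 0.99, fd);

// Optionally change the identity matrix scalar from its default of 100.
lspi.setIdentityScalar(100.0);

// Provide the SARS dataset and run at most 30 policy iterations,
// terminating early once the weight change drops below 1e-6.
lspi.setDataset(collectedData);
lspi.runPolicyIteration(30, 1e-6);
```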
If you do use the planFromState(State) method, it works by creating a SARSCollector.UniformRandomSARSCollector, collecting SARS data from the input state, and then calling the runPolicyIteration(int, double) method. You can change the SARSCollector this method uses, the number of samples it acquires, the maximum weight change for policy-iteration termination, and the maximum number of policy iterations with the setPlanningCollector(SARSCollector), setNumSamplesForPlanning(int), setMaxChange(double), and setMaxNumPlanningIterations(int) methods, respectively.
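A sketch of configuring that planning pathway is shown below; the parameter values are illustrative, and the single-argument Domain constructor assumed for SARSCollector.UniformRandomSARSCollector should be checked against the SARSCollector documentation.

```java
// Assumes lspi was constructed as above and initialState is some State.
lspi.setNumSamplesForPlanning(5000);   // number of SARS tuples to collect
lspi.setMaxNumPlanningIterations(30);  // cap on policy iterations
lspi.setMaxChange(1e-6);               // weight-change threshold for termination
// Optionally swap in a different collector; the constructor below is an assumption.
lspi.setPlanningCollector(new SARSCollector.UniformRandomSARSCollector(domain));

// Collect data from initialState and run policy iteration on it.
lspi.planFromState(initialState);

// After planning, Q-values can be queried for any state.
java.util.List<QValue> qs = lspi.getQs(initialState);
```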
If you do use the runLearningEpisodeFrom(State) method (or the runLearningEpisodeFrom(State, int) method), it works by following a learning policy for the episode and adding the observations to its dataset for policy iteration. After enough new data has been acquired, policy iteration is rerun. You can adjust the learning policy, the maximum number of learning steps allowed in an episode, and the minimum number of new observations before LSPI is rerun with the setLearningPolicy(Policy), setMaxLearningSteps(int), and setMinNewStepsForLearningPI(int) methods, respectively. The LSPI termination parameters are set with the same methods used to adjust the results of the planFromState(State) method discussed above.
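A sketch of that learning pathway under explicit settings is shown below; the EpsilonGreedy policy class and its (planner, epsilon) constructor are assumptions, as are the numeric values.

```java
// Assumes lspi and an initial State (initialState) already exist.
// EpsilonGreedy and its (planner, epsilon) constructor are assumed here;
// by default LSPI already follows a 0.1 epsilon-greedy learning policy.
lspi.setLearningPolicy(new EpsilonGreedy(lspi, 0.1));
lspi.setMaxLearningSteps(1000);          // cap each learning episode at 1000 steps
lspi.setMinNewStepsForLearningPI(200);   // rerun LSPI after 200 new observations

for (int i = 0; i < 50; i++) {
    EpisodeAnalysis ea = lspi.runLearningEpisodeFrom(initialState);
    // ea holds the episode; recent episodes are also stored internally
    // (see setNumEpisodesToStore(int) and getAllStoredLearningEpisodes()).
}
```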
This data-gathering and replanning behavior from learning episodes is not expected to be an especially good approach. Therefore, if you want better online data acquisition, you should consider subclassing this class and overriding the updateDatasetWithLearningEpisode(EpisodeAnalysis) and shouldRereunPolicyIteration(EpisodeAnalysis) methods, or the runLearningEpisodeFrom(State, int) method itself.
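For instance, a skeleton of such a subclass might look like the following; the retention and rerun logic shown is purely illustrative.

```java
// Hypothetical subclass that customizes when and how the dataset is updated
// from learning episodes. Only methods documented on this page are overridden.
public class MyOnlineLSPI extends LSPI {

    public MyOnlineLSPI(Domain domain, RewardFunction rf, TerminalFunction tf,
                        double gamma, FeatureDatabase fd) {
        super(domain, rf, tf, gamma, fd);
    }

    @Override
    protected void updateDatasetWithLearningEpisode(EpisodeAnalysis ea) {
        // Custom data-retention logic could go here (e.g., a sliding window of
        // recent experience); the superclass adds the episode's observations
        // to the dataset.
        super.updateDatasetWithLearningEpisode(ea);
    }

    @Override
    protected boolean shouldRereunPolicyIteration(EpisodeAnalysis ea) {
        // Illustrative choice: rerun LSPI after every learning episode.
        return true;
    }
}
```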
1. Lagoudakis, Michail G., and Ronald Parr. "Least-squares policy iteration." The Journal of Machine Learning Research 4 (2003): 1107-1149.

| Modifier and Type | Class and Description |
|---|---|
| protected class | LSPI.SSFeatures: Pair of the state-action features and the next state-action features. |
Nested classes/interfaces inherited from interface QComputablePlanner: QComputablePlanner.QComputablePlannerHelper
Nested classes/interfaces inherited from interface LearningAgent: LearningAgent.LearningAgentBookKeeping

| Modifier and Type | Field and Description |
|---|---|
| protected SARSData | dataset: The SARS dataset on which LSPI is performed. |
| protected java.util.LinkedList<EpisodeAnalysis> | episodeHistory: The saved previous learning episodes. |
| protected FeatureDatabase | featureDatabase: The state feature database on which the linear VFA is performed. |
| protected double | identityScalar: The initial LSPI identity matrix scalar; default is 100. |
| protected org.ejml.simple.SimpleMatrix | lastWeights: The last weight values set from LSTDQ. |
| protected Policy | learningPolicy: The learning policy followed in runLearningEpisodeFrom(State) method calls. |
| protected double | maxChange: The maximum change in weights permitted to terminate LSPI. |
| protected int | maxLearningSteps: The maximum number of learning steps in an episode when the runLearningEpisodeFrom(State) method is called. |
| protected int | maxNumPlanningIterations: The maximum number of policy iterations permitted when LSPI is run from the planFromState(State) or runLearningEpisodeFrom(State) methods. |
| protected int | minNewStepsForLearningPI: The minimum number of new observations received from learning episodes before LSPI will be run again. |
| protected int | numEpisodesToStore: The number of the most recent learning episodes to store. |
| protected int | numSamplesForPlanning: The number of samples that are acquired for this object's dataset when the planFromState(State) method is called. |
| protected int | numStepsSinceLastLearningPI: Number of new observations received from learning episodes since LSPI was run. |
| protected SARSCollector | planningCollector: The data collector used by the planFromState(State) method. |
| protected ValueFunctionApproximation | vfa: The object that performs value function approximation given the weights that are estimated. |

Fields inherited from class OOMDPPlanner: actions, containsParameterizedActions, debugCode, domain, gamma, hashingFactory, mapToStateIndex, rf, tf

| Constructor and Description |
|---|
| LSPI(Domain domain, RewardFunction rf, TerminalFunction tf, double gamma, FeatureDatabase fd): Initializes for the given domain, reward function, terminal state function, discount factor, and the feature database that provides the state features used by LSPI. |
| Modifier and Type | Method and Description |
|---|---|
| protected java.util.List<GroundedAction> | gaListWrapper(AbstractGroundedAction ga): Wraps a GroundedAction in a list of size 1. |
| java.util.List<EpisodeAnalysis> | getAllStoredLearningEpisodes(): Returns all saved EpisodeAnalysis objects of which the agent has kept track. |
| SARSData | getDataset(): Returns the dataset this object uses for LSPI. |
| FeatureDatabase | getFeatureDatabase(): Returns the feature database defining state features. |
| double | getIdentityScalar(): Returns the initial LSPI identity matrix scalar used. |
| EpisodeAnalysis | getLastLearningEpisode(): Returns the last learning episode of the agent. |
| Policy | getLearningPolicy(): The learning policy followed by the runLearningEpisodeFrom(State) and runLearningEpisodeFrom(State, int) methods. |
| double | getMaxChange(): The maximum change in weights required to terminate policy iteration when called from the planFromState(State), runLearningEpisodeFrom(State), or runLearningEpisodeFrom(State, int) methods. |
| int | getMaxLearningSteps(): The maximum number of learning steps permitted by the runLearningEpisodeFrom(State) method. |
| int | getMaxNumPlanningIterations(): The maximum number of policy iterations that will be used by the planFromState(State) method. |
| int | getMinNewStepsForLearningPI(): The minimum number of new learning observations before policy iteration is run again. |
| int | getNumSamplesForPlanning(): Gets the number of SARS samples that will be gathered by the planFromState(State) method. |
| SARSCollector | getPlanningCollector(): Gets the SARSCollector used by the planFromState(State) method for collecting data. |
| QValue | getQ(State s, AbstractGroundedAction a): Returns the QValue for the given state-action pair. |
| protected QValue | getQFromFeaturesFor(java.util.List<ActionApproximationResult> results, State s, GroundedAction ga): Creates a Q-value object in which the Q-value is determined from VFA. |
| java.util.List<QValue> | getQs(State s): Returns a List of QValue objects for every permissible action for the given input state. |
| org.ejml.simple.SimpleMatrix | LSTDQ(): Runs LSTDQ on this object's current SARSData dataset. |
| protected org.ejml.simple.SimpleMatrix | phiConstructor(java.util.List<ActionFeaturesQuery> features, int nf): Constructs the state-action feature vector as a SimpleMatrix. |
| void | planFromState(State initialState): This method will cause the planner to begin planning from the specified initial state. |
| void | resetPlannerResults(): Use this method to reset all planner results so that planning can be started fresh with a call to OOMDPPlanner.planFromState(State) as if no planning had ever been performed before. |
| EpisodeAnalysis | runLearningEpisodeFrom(State initialState): Causes the agent to perform a learning episode starting in the given initial state. |
| EpisodeAnalysis | runLearningEpisodeFrom(State initialState, int maxSteps): Causes the agent to perform a learning episode starting in the given initial state. |
| void | runPolicyIteration(int numIterations, double maxChange): Runs LSPI for either numIterations or until the change in the weight matrix is no greater than maxChange. |
| void | setDataset(SARSData dataset): Sets the SARS dataset this object will use for LSPI. |
| void | setFeatureDatabase(FeatureDatabase featureDatabase): Sets the feature database defining state features. |
| void | setIdentityScalar(double identityScalar): Sets the initial LSPI identity matrix scalar used. |
| void | setLearningPolicy(Policy learningPolicy): Sets the learning policy followed by the runLearningEpisodeFrom(State) and runLearningEpisodeFrom(State, int) methods. |
| void | setMaxChange(double maxChange): Sets the maximum change in weights required to terminate policy iteration when called from the planFromState(State), runLearningEpisodeFrom(State), or runLearningEpisodeFrom(State, int) methods. |
| void | setMaxLearningSteps(int maxLearningSteps): Sets the maximum number of learning steps permitted by the runLearningEpisodeFrom(State) method. |
| void | setMaxNumPlanningIterations(int maxNumPlanningIterations): Sets the maximum number of policy iterations that will be used by the planFromState(State) method. |
| void | setMinNewStepsForLearningPI(int minNewStepsForLearningPI): Sets the minimum number of new learning observations before policy iteration is run again. |
| void | setNumEpisodesToStore(int numEps): Tells the agent how many EpisodeAnalysis objects representing learning episodes to internally store. |
| void | setNumSamplesForPlanning(int numSamplesForPlanning): Sets the number of SARS samples that will be gathered by the planFromState(State) method. |
| void | setPlanningCollector(SARSCollector planningCollector): Sets the SARSCollector used by the planFromState(State) method for collecting data. |
| protected boolean | shouldRereunPolicyIteration(EpisodeAnalysis ea): Returns whether LSPI should be rerun given the latest learning episode results. |
| protected void | updateDatasetWithLearningEpisode(EpisodeAnalysis ea): Updates this object's SARSData to include the results of a learning episode. |
Methods inherited from class OOMDPPlanner: addNonDomainReferencedAction, getActions, getAllGroundedActions, getDebugCode, getDomain, getGamma, getHashingFactory, getRf, getRF, getTf, getTF, plannerInit, setActions, setDebugCode, setDomain, setGamma, setRf, setTf, stateHash, toggleDebugPrinting, translateAction

Field Detail

protected ValueFunctionApproximation vfa
The object that performs value function approximation given the weights that are estimated.

protected SARSData dataset
The SARS dataset on which LSPI is performed.

protected FeatureDatabase featureDatabase
The state feature database on which the linear VFA is performed.

protected double identityScalar
The initial LSPI identity matrix scalar; default is 100.

protected org.ejml.simple.SimpleMatrix lastWeights
The last weight values set from LSTDQ.

protected int numSamplesForPlanning
The number of samples that are acquired for this object's dataset when the planFromState(State) method is called.

protected double maxChange
The maximum change in weights permitted to terminate LSPI.

protected SARSCollector planningCollector
The data collector used by the planFromState(State) method.

protected int maxNumPlanningIterations
The maximum number of policy iterations permitted when LSPI is run from the planFromState(State) or runLearningEpisodeFrom(State) methods.

protected Policy learningPolicy
The learning policy followed in runLearningEpisodeFrom(State) method calls. Default is 0.1 epsilon greedy.

protected int maxLearningSteps
The maximum number of learning steps in an episode when the runLearningEpisodeFrom(State) method is called. Default is INT_MAX.

protected int numStepsSinceLastLearningPI
Number of new observations received from learning episodes since LSPI was run.

protected int minNewStepsForLearningPI
The minimum number of new observations received from learning episodes before LSPI will be run again.

protected java.util.LinkedList<EpisodeAnalysis> episodeHistory
The saved previous learning episodes.

protected int numEpisodesToStore
The number of the most recent learning episodes to store.

Constructor Detail
public LSPI(Domain domain, RewardFunction rf, TerminalFunction tf, double gamma, FeatureDatabase fd)
Initializes for the given domain, reward function, terminal state function, discount factor, and the feature database that provides the state features used by LSPI.
Parameters: domain - the problem domain; rf - the reward function; tf - the terminal state function; gamma - the discount factor; fd - the feature database defining state features on which LSPI will run

Method Detail

public void setDataset(SARSData dataset)
Sets the SARS dataset this object will use for LSPI.
Parameters: dataset - the SARS dataset

public SARSData getDataset()
Returns the dataset this object uses for LSPI.

public FeatureDatabase getFeatureDatabase()
Returns the feature database defining state features.

public void setFeatureDatabase(FeatureDatabase featureDatabase)
Sets the feature database defining state features.
Parameters: featureDatabase - the feature database defining state features

public double getIdentityScalar()
Returns the initial LSPI identity matrix scalar used.

public void setIdentityScalar(double identityScalar)
Sets the initial LSPI identity matrix scalar used.
Parameters: identityScalar - the initial LSPI identity matrix scalar

public int getNumSamplesForPlanning()
Gets the number of SARS samples that will be gathered by the planFromState(State) method.

public void setNumSamplesForPlanning(int numSamplesForPlanning)
Sets the number of SARS samples that will be gathered by the planFromState(State) method.
Parameters: numSamplesForPlanning - the number of SARS samples that will be gathered by the planFromState(State) method

public SARSCollector getPlanningCollector()
Gets the SARSCollector used by the planFromState(State) method for collecting data.

public void setPlanningCollector(SARSCollector planningCollector)
Sets the SARSCollector used by the planFromState(State) method for collecting data.
Parameters: planningCollector - the SARSCollector used by the planFromState(State) method for collecting data

public int getMaxNumPlanningIterations()
Gets the maximum number of policy iterations that will be used by the planFromState(State) method.

public void setMaxNumPlanningIterations(int maxNumPlanningIterations)
Sets the maximum number of policy iterations that will be used by the planFromState(State) method.
Parameters: maxNumPlanningIterations - the maximum number of policy iterations that will be used by the planFromState(State) method

public Policy getLearningPolicy()
Gets the learning policy followed by the runLearningEpisodeFrom(State) and runLearningEpisodeFrom(State, int) methods.

public void setLearningPolicy(Policy learningPolicy)
Sets the learning policy followed by the runLearningEpisodeFrom(State) and runLearningEpisodeFrom(State, int) methods.
Parameters: learningPolicy - the learning policy followed by the runLearningEpisodeFrom(State) and runLearningEpisodeFrom(State, int) methods

public int getMaxLearningSteps()
Gets the maximum number of learning steps permitted by the runLearningEpisodeFrom(State) method.

public void setMaxLearningSteps(int maxLearningSteps)
Sets the maximum number of learning steps permitted by the runLearningEpisodeFrom(State) method.
Parameters: maxLearningSteps - the maximum number of learning steps permitted by the runLearningEpisodeFrom(State) method

public int getMinNewStepsForLearningPI()
Gets the minimum number of new learning observations before policy iteration is run again.

public void setMinNewStepsForLearningPI(int minNewStepsForLearningPI)
Sets the minimum number of new learning observations before policy iteration is run again.
Parameters: minNewStepsForLearningPI - the minimum number of new learning observations before policy iteration is run again

public double getMaxChange()
Gets the maximum change in weights required to terminate policy iteration when called from the planFromState(State), runLearningEpisodeFrom(State), or runLearningEpisodeFrom(State, int) methods.

public void setMaxChange(double maxChange)
Sets the maximum change in weights required to terminate policy iteration when called from the planFromState(State), runLearningEpisodeFrom(State), or runLearningEpisodeFrom(State, int) methods.
Parameters: maxChange - the maximum change in weights required to terminate policy iteration

public org.ejml.simple.SimpleMatrix LSTDQ()
Runs LSTDQ on this object's current SARSData dataset.
Returns: the computed weight values as a SimpleMatrix object

public void runPolicyIteration(int numIterations, double maxChange)
Runs LSPI for either numIterations or until the change in the weight matrix is no greater than maxChange.
Parameters: numIterations - the maximum number of policy iterations; maxChange - when the weight change is smaller than this value, LSPI terminates

protected org.ejml.simple.SimpleMatrix phiConstructor(java.util.List<ActionFeaturesQuery> features, int nf)
Constructs the state-action feature vector as a SimpleMatrix.
Parameters: features - the state-action features that have non-zero values; nf - the total number of state-action features
Returns: the state-action feature vector as a SimpleMatrix

protected java.util.List<GroundedAction> gaListWrapper(AbstractGroundedAction ga)
Wraps a GroundedAction in a list of size 1.
Parameters: ga - the GroundedAction to wrap
Returns: a List consisting of just the input GroundedAction object

public java.util.List<QValue> getQs(State s)
Returns a List of QValue objects for every permissible action for the given input state.
Specified by: getQs in interface QComputablePlanner
Parameters: s - the state for which Q-values are to be returned
Returns: a List of QValue objects for every permissible action for the given input state

public QValue getQ(State s, AbstractGroundedAction a)
Returns the QValue for the given state-action pair.
Specified by: getQ in interface QComputablePlanner
Parameters: s - the input state; a - the input action
Returns: the QValue for the given state-action pair

protected QValue getQFromFeaturesFor(java.util.List<ActionApproximationResult> results, State s, GroundedAction ga)
Creates a Q-value object in which the Q-value is determined from VFA.
Parameters: results - the VFA prediction results for each action; s - the state of the Q-value; ga - the action taken

public void planFromState(State initialState)
This method will cause the planner to begin planning from the specified initial state.
Specified by: planFromState in class OOMDPPlanner
Parameters: initialState - the initial state of the planning problem

public void resetPlannerResults()
Use this method to reset all planner results so that planning can be started fresh with a call to OOMDPPlanner.planFromState(State) as if no planning had ever been performed before. Specifically, data produced from calls to OOMDPPlanner.planFromState(State) will be cleared, but all other planner settings should remain the same. This is useful if the reward function or transition dynamics have changed, thereby requiring new results to be computed. If there were other objects this planner was provided that may have changed and need to be reset, you will need to reset them yourself. For instance, if you told a planner to follow a policy that had a temperature parameter decrease with time, you will need to reset the policy's temperature yourself.
Specified by: resetPlannerResults in class OOMDPPlanner

public EpisodeAnalysis runLearningEpisodeFrom(State initialState)
Causes the agent to perform a learning episode starting in the given initial state.
Specified by: runLearningEpisodeFrom in interface LearningAgent
Parameters: initialState - the initial state in which the agent will start the episode
Returns: the learning episode as an EpisodeAnalysis object

public EpisodeAnalysis runLearningEpisodeFrom(State initialState, int maxSteps)
Causes the agent to perform a learning episode starting in the given initial state.
Specified by: runLearningEpisodeFrom in interface LearningAgent
Parameters: initialState - the initial state in which the agent will start the episode; maxSteps - the maximum number of steps in the episode
Returns: the learning episode as an EpisodeAnalysis object

protected void updateDatasetWithLearningEpisode(EpisodeAnalysis ea)
Updates this object's SARSData to include the results of a learning episode.
Parameters: ea - the learning episode as an EpisodeAnalysis object

protected boolean shouldRereunPolicyIteration(EpisodeAnalysis ea)
Returns whether LSPI should be rerun given the latest learning episode results, based on whether the number of new observations since the last learning policy iteration has reached the configured threshold.
Parameters: ea - the most recent learning episode

public EpisodeAnalysis getLastLearningEpisode()
Returns the last learning episode of the agent.
Specified by: getLastLearningEpisode in interface LearningAgent

public void setNumEpisodesToStore(int numEps)
Tells the agent how many EpisodeAnalysis objects representing learning episodes to internally store. For instance, if the number is set to 5, then the agent should remember the last 5 learning episodes. Note that this number has nothing to do with how learning is performed; it is purely for performance gathering.
Specified by: setNumEpisodesToStore in interface LearningAgent
Parameters: numEps - the number of learning episodes to remember

public java.util.List<EpisodeAnalysis> getAllStoredLearningEpisodes()
Returns all saved EpisodeAnalysis objects of which the agent has kept track.
Specified by: getAllStoredLearningEpisodes in interface LearningAgent
Returns: all saved EpisodeAnalysis objects of which the agent has kept track