public class LearningAlgorithmExperimenter
extends java.lang.Object
The experimenter takes as input an Environment in which to perform the experiments, a number of trials, the length of the trials, and an array of learning agent factories used to generate agent instances and compare their performance.
The Environment may optionally implement the ExperimentalEnvironment interface, which lets this class tell the Environment whenever experiments with a new agent class (defined by a LearningAgentFactory) begin.
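The opt-in notification described above is a common "optional capability" pattern: the experimenter checks whether the environment implements the extra interface and only then calls it. The sketch below illustrates that pattern with simplified, hypothetical interfaces (`SimpleEnvironment`, `SimpleExperimentalEnvironment`); they are not BURLAP's declarations, and the method name `startNewExperiment()` is an assumption.

```java
// Hypothetical, simplified stand-in for BURLAP's Environment interface.
interface SimpleEnvironment {
    void resetEnvironment();
}

// Hypothetical stand-in for the optional ExperimentalEnvironment interface;
// the method name startNewExperiment() is an assumption for this sketch.
interface SimpleExperimentalEnvironment {
    void startNewExperiment();
}

// An environment that opts in and counts the notifications it receives.
class CountingEnvironment implements SimpleEnvironment, SimpleExperimentalEnvironment {
    int experimentsStarted = 0;

    @Override
    public void resetEnvironment() { }

    @Override
    public void startNewExperiment() { experimentsStarted++; }
}

class ExperimentNotifier {
    // Called once before the trials of each agent factory begin. Environments
    // that do not implement the optional interface are left untouched.
    static void beginAgentExperiments(SimpleEnvironment env) {
        if (env instanceof SimpleExperimentalEnvironment) {
            ((SimpleExperimentalEnvironment) env).startNewExperiment();
        }
    }

    public static void main(String[] args) {
        CountingEnvironment env = new CountingEnvironment();
        String[] factoryNames = { "Q-learning", "SARSA" }; // hypothetical agent types
        for (String name : factoryNames) {
            beginAgentExperiments(env);
            // ... run all trials for the agent generated by this factory ...
        }
        System.out.println(env.experimentsStarted); // prints 2
    }
}
```

Because the check is a plain `instanceof`, environments written without experiments in mind need no changes at all.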
By default, the trial length is interpreted as a number of episodes, but it can be changed to mean a total number of steps with the toggleTrialLengthInterpretation(boolean) method.
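The two interpretations of trial length lead to two different loops. The sketch below assumes a hypothetical `EpisodicAgent` whose `runEpisode(maxSteps)` returns the number of steps it took; it is an illustration of the idea behind runEpisodeBoundTrial and runStepBoundTrial, not their actual implementation.

```java
// Hypothetical agent abstraction: runs one learning episode and returns the
// number of steps taken. maxSteps < 0 means "run until the episode ends";
// otherwise the episode is cut off after maxSteps steps. Each episode is
// assumed to take at least one step, or the step-bound loop would not end.
interface EpisodicAgent {
    int runEpisode(int maxSteps);
}

class TrialRunner {
    // Episode-bound: trialLength counts episodes, each run to completion.
    // Returns the total number of steps taken over the trial.
    static int runEpisodeBoundTrial(EpisodicAgent agent, int trialLength) {
        int totalSteps = 0;
        for (int episode = 0; episode < trialLength; episode++) {
            totalSteps += agent.runEpisode(-1);
        }
        return totalSteps;
    }

    // Step-bound: trialLength is a total step budget. Each episode may use at
    // most the remaining budget, so the final episode can be cut short.
    // Returns the number of episodes started.
    static int runStepBoundTrial(EpisodicAgent agent, int trialLength) {
        int stepsRemaining = trialLength;
        int episodes = 0;
        while (stepsRemaining > 0) {
            stepsRemaining -= agent.runEpisode(stepsRemaining);
            episodes++;
        }
        return episodes;
    }

    public static void main(String[] args) {
        // Every episode of this toy agent lasts 10 steps unless cut off earlier.
        EpisodicAgent fixed = maxSteps -> maxSteps < 0 ? 10 : Math.min(10, maxSteps);
        System.out.println(runEpisodeBoundTrial(fixed, 5)); // prints 50
        System.out.println(runStepBoundTrial(fixed, 25));   // prints 3 (10 + 10 + 5 steps)
    }
}
```

A step budget keeps every trial the same length in steps, which makes step-indexed plots across agents directly comparable; an episode budget does the same for episode-indexed plots.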
Performance results are displayed in plots using the PerformancePlotter class, but visualization can be disabled with the toggleVisualPlots(boolean) method. Results can be saved to CSV files after the experiment is complete.
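Saving results amounts to serializing per-step and per-episode records. The following sketch shows one plausible way to derive two file names from a single base name (as writeStepAndEpisodeDataToCSV does conceptually) and to write episode-wise rows. The `_Steps.csv`/`_Episodes.csv` suffixes, the `CsvSketch` class, and the column layout are all assumptions, not the file format PerformancePlotter actually produces.

```java
import java.io.IOException;
import java.io.StringWriter;
import java.io.Writer;
import java.util.List;

class CsvSketch {
    // Assumed naming scheme for deriving both output files from one base name.
    static String stepFileName(String base) { return base + "_Steps.csv"; }
    static String episodeFileName(String base) { return base + "_Episodes.csv"; }

    // Writes one row per (trial, episode) with the reward earned in that episode.
    // The header and column order are illustrative assumptions; agentName is
    // assumed not to contain commas or quotes, so no CSV escaping is done.
    static void writeEpisodeRows(Writer out, String agentName, List<double[]> rewardsPerTrial)
            throws IOException {
        out.write("agent,trial,episode,reward\n");
        for (int trial = 0; trial < rewardsPerTrial.size(); trial++) {
            double[] rewards = rewardsPerTrial.get(trial);
            for (int episode = 0; episode < rewards.length; episode++) {
                out.write(agentName + "," + trial + "," + episode + "," + rewards[episode] + "\n");
            }
        }
    }

    public static void main(String[] args) throws IOException {
        StringWriter sw = new StringWriter(); // a FileWriter would target a real file
        writeEpisodeRows(sw, "Q-learning", List.of(new double[] { 1.0, 2.5 }));
        System.out.print(sw);
        System.out.println(episodeFileName("results/run1")); // prints results/run1_Episodes.csv
    }
}
```

Writing through a `Writer` rather than a path keeps the formatting logic testable in memory.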
The experimenter tests each agent for a specified number of trials. At the beginning of each trial, a new agent is generated using the designated LearningAgentFactory and is run for the specified trial length. After all trials for an agent are complete, the next agent is tested. Note that immediately before an agent is generated from an agent factory, the performance plotter is frozen and stops collecting data until the new agent is returned. This allows agent factories to perform offline learning in the same domain before returning a new agent without affecting the experimenter's results.
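The outer loop described above, including the freeze around agent generation, can be sketched as follows. `InteractionRecorder`, `SketchAgent`, and `SketchAgentFactory` are hypothetical stand-ins for the PerformancePlotter, a learning agent, and a LearningAgentFactory; only the control flow is meant to mirror the description.

```java
// Hypothetical stand-in for the performance plotter: counts interactions
// unless it is frozen.
class InteractionRecorder {
    private boolean frozen = false;
    int recorded = 0;

    void setFrozen(boolean frozen) { this.frozen = frozen; }

    void record() {
        if (!frozen) recorded++;
    }
}

// Hypothetical agent: each call to act() is one interaction with the environment.
interface SketchAgent {
    void act(InteractionRecorder recorder);
}

// Hypothetical factory: may interact with the environment (offline learning)
// before it returns a fresh agent.
interface SketchAgentFactory {
    SketchAgent generateAgent(InteractionRecorder recorder);
}

class ExperimentLoop {
    static void run(InteractionRecorder recorder, int nTrials, int stepsPerTrial,
                    SketchAgentFactory... factories) {
        for (SketchAgentFactory factory : factories) {    // one agent type at a time
            for (int trial = 0; trial < nTrials; trial++) {
                SketchAgent agent;
                recorder.setFrozen(true);                 // ignore factory-time interactions
                try {
                    agent = factory.generateAgent(recorder);
                } finally {
                    recorder.setFrozen(false);            // resume before the trial runs
                }
                for (int step = 0; step < stepsPerTrial; step++) {
                    agent.act(recorder);
                }
            }
        }
    }

    public static void main(String[] args) {
        InteractionRecorder recorder = new InteractionRecorder();
        SketchAgentFactory pretrained = r -> {
            for (int i = 0; i < 100; i++) r.record();     // offline learning: not counted
            return InteractionRecorder::record;           // agent records one interaction per act()
        };
        run(recorder, 3, 5, pretrained);
        System.out.println(recorder.recorded);            // prints 15
    }
}
```

The `try`/`finally` ensures the recorder is unfrozen even if a factory throws, so a failing factory cannot silently suppress data for later agents.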
By default, the cumulative reward per step is plotted, and if more than one trial is specified, both the most recent trial plot and the trial average plot are shown.
If only one trial is specified, only the most recent trial plot is shown. To control the kinds of plots displayed, use the setUpPlottingConfiguration(int, int, int, int, TrialMode, PerformanceMetric...) method.
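The default metric and the trial-average view can be illustrated numerically. The sketch below assumes the trial-average plot shows the per-step mean across trials with a symmetric confidence band; the class and method names are hypothetical, and the critical value for the band (for example a Student-t quantile matching the configured significance) is passed in rather than computed.

```java
import java.util.Arrays;

class CumulativeRewardSketch {
    // Running sum of per-step rewards: entry t is the total reward over steps 0..t.
    static double[] cumulative(double[] rewards) {
        double[] out = new double[rewards.length];
        double sum = 0.0;
        for (int t = 0; t < rewards.length; t++) {
            sum += rewards[t];
            out[t] = sum;
        }
        return out;
    }

    // Element-wise mean over trials; all trials are assumed to have equal length.
    static double[] mean(double[][] trials) {
        int len = trials[0].length;
        double[] m = new double[len];
        for (double[] trial : trials) {
            for (int t = 0; t < len; t++) m[t] += trial[t];
        }
        for (int t = 0; t < len; t++) m[t] /= trials.length;
        return m;
    }

    // Confidence-interval half-width at each step: critical * s / sqrt(n), where s is
    // the sample standard deviation across n >= 2 trials. The critical value is
    // supplied by the caller.
    static double[] halfWidth(double[][] trials, double critical) {
        int n = trials.length;
        double[] m = mean(trials);
        double[] h = new double[m.length];
        for (int t = 0; t < m.length; t++) {
            double ss = 0.0;
            for (double[] trial : trials) {
                double d = trial[t] - m[t];
                ss += d * d;
            }
            h[t] = critical * Math.sqrt(ss / (n - 1)) / Math.sqrt(n);
        }
        return h;
    }

    public static void main(String[] args) {
        double[][] trials = {
            cumulative(new double[] { 1, 0, 1 }),   // 1, 1, 2
            cumulative(new double[] { 0, 0, 1 }),   // 0, 0, 1
        };
        System.out.println(Arrays.toString(mean(trials))); // prints [0.5, 0.5, 1.5]
    }
}
```

With a single trial the sample standard deviation is undefined, which is consistent with only the most recent trial plot being shown in that case.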
Modifier and Type | Field and Description |
---|---|
protected LearningAgentFactory[] | agentFactories: The array of agent factories for the agents to be compared. |
protected boolean | completedExperiment: Whether the experiment has completed. |
int | debugCode: The debug code used for debug printing. |
protected boolean | displayPlots: Whether performance results should be plotted visually (true by default). |
protected EnvironmentServer | environmentSever: The EnvironmentServer that wraps the test Environment and tells a PerformancePlotter about the individual interactions. |
protected int | nTrials: The number of trials for which each agent is evaluated. |
protected double | plotCISignificance: The significance value for the confidence intervals in the plots. |
protected int | plotRefresh: The delay in milliseconds between automatic refreshes of the plots. |
protected PerformancePlotter | plotter: The PerformancePlotter used to collect and plot results. |
protected Environment | testEnvironment: The test Environment in which experiments will be performed. |
protected int | trialLength: The length of each trial. |
protected boolean | trialLengthIsInEpisodes: Whether the trial length specifies a number of episodes (the default) or a total number of steps. |
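The environmentSever field combines two patterns: it wraps (decorates) the test environment and notifies an observer of each interaction. The sketch below shows that combination with hypothetical, simplified types (`StepEnvironment`, `InteractionObserver`, `ObservedEnvironment`); BURLAP's EnvironmentServer has a richer interface than this.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical minimal environment: executes an action and returns a reward.
interface StepEnvironment {
    double executeAction(String action);
}

// Hypothetical observer notified of every interaction (the plotter's role).
interface InteractionObserver {
    void observe(String action, double reward);
}

// Wraps an environment, forwards every call, and tells observers what happened.
class ObservedEnvironment implements StepEnvironment {
    private final StepEnvironment delegate;
    private final List<InteractionObserver> observers = new ArrayList<>();

    ObservedEnvironment(StepEnvironment delegate) { this.delegate = delegate; }

    void addObserver(InteractionObserver observer) { observers.add(observer); }

    @Override
    public double executeAction(String action) {
        double reward = delegate.executeAction(action);      // real interaction
        for (InteractionObserver o : observers) o.observe(action, reward);
        return reward;                                        // agent sees the same result
    }

    public static void main(String[] args) {
        ObservedEnvironment env = new ObservedEnvironment(a -> a.equals("goal") ? 1.0 : -0.1);
        env.addObserver((a, r) -> System.out.println(a + " -> " + r));
        env.executeAction("left");   // prints left -> -0.1
        env.executeAction("goal");   // prints goal -> 1.0
    }
}
```

Because the wrapper implements the same interface as the environment, agents interact with it unchanged and never need to know they are being measured.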
Constructor and Description |
---|
LearningAlgorithmExperimenter(Environment testEnvironment, int nTrials, int trialLength, LearningAgentFactory... agentFactories): Initializes. |
Modifier and Type | Method and Description |
---|---|
protected void | runEpisodeBoundTrial(LearningAgentFactory agentFactory): Runs a trial for an agent generated by the given factory, interpreting the trial length as a number of episodes. |
protected void | runStepBoundTrial(LearningAgentFactory agentFactory): Runs a trial for an agent generated by the given factory, interpreting the trial length as a total number of steps. |
void | setPlotCISignificance(double significance): Sets the significance used for confidence intervals. |
void | setPlotRefreshDelay(int delayInMS): Sets the delay in milliseconds between automatic plot refreshes. |
void | setUpPlottingConfiguration(int chartWidth, int chartHeight, int columns, int maxWindowHeight, TrialMode trialMode, PerformanceMetric... metrics): Sets up the plotting configuration. |
void | startExperiment(): Starts the experiment and runs all trials for all agents. |
void | toggleTrialLengthInterpretation(boolean lengthRepresentsEpisodes): Changes whether the trial length provided in the constructor is interpreted as the number of episodes or the total number of steps. |
void | toggleVisualPlots(boolean shouldPlotResults): Toggles whether plots are displayed. |
void | writeEpisodeDataToCSV(java.lang.String filePath): Writes episode-wise data to a CSV file. |
void | writeStepAndEpisodeDataToCSV(java.lang.String pathAndBaseNameToUse): Writes the step-wise and episode-wise data to CSV files. |
void | writeStepDataToCSV(java.lang.String filePath): Writes step-wise data to a CSV file. |
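setUpPlottingConfiguration arranges charts in a grid of a given number of columns and switches to a scroll view when the window would exceed maxWindowHeight. The arithmetic behind such a layout can be sketched as below; `PlotGridSketch` is hypothetical, the scroll rule is an assumed reading of the parameter description, and the chart count in `main` assumes each metric contributes one chart per plot type.

```java
class PlotGridSketch {
    // Zero-based {row, column} of the index-th chart: charts fill the columns of
    // one row before moving down to the next row.
    static int[] cell(int index, int columns) {
        return new int[] { index / columns, index % columns };
    }

    // Total height of the chart grid (ceiling division gives the number of rows).
    static int contentHeight(int nCharts, int columns, int chartHeight) {
        int rows = (nCharts + columns - 1) / columns;
        return rows * chartHeight;
    }

    // Assumed rule: a scroll view is needed once the grid is taller than maxWindowHeight.
    static boolean needsScroll(int nCharts, int columns, int chartHeight, int maxWindowHeight) {
        return contentHeight(nCharts, columns, chartHeight) > maxWindowHeight;
    }

    public static void main(String[] args) {
        // e.g. three metrics, each shown as a most-recent-trial and a trial-average chart
        int nCharts = 3 * 2;
        for (int i = 0; i < nCharts; i++) {
            int[] rc = cell(i, 2);
            System.out.println("chart " + i + " -> row " + rc[0] + ", column " + rc[1]);
        }
        System.out.println(needsScroll(nCharts, 2, 250, 700)); // prints true (3 rows * 250 = 750)
    }
}
```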
protected Environment testEnvironment
The test Environment in which experiments will be performed.
protected EnvironmentServer environmentSever
The EnvironmentServer that wraps the test Environment and tells a PerformancePlotter about the individual interactions.
protected LearningAgentFactory[] agentFactories
protected int nTrials
protected int trialLength
protected boolean trialLengthIsInEpisodes
protected PerformancePlotter plotter
protected boolean displayPlots
protected int plotRefresh
protected double plotCISignificance
protected boolean completedExperiment
public int debugCode
public LearningAlgorithmExperimenter(Environment testEnvironment, int nTrials, int trialLength, LearningAgentFactory... agentFactories)
Initializes. The trial length is interpreted as a number of episodes by default; it can be reinterpreted as a total number of steps with toggleTrialLengthInterpretation(boolean).
testEnvironment - the test Environment in which experiments will be performed.
nTrials - the number of trials.
trialLength - the length of the trials (by default in episodes, but it can be interpreted as a total number of steps).
agentFactories - factories used to generate the agents to be tested.
public void setUpPlottingConfiguration(int chartWidth, int chartHeight, int columns, int maxWindowHeight, TrialMode trialMode, PerformanceMetric... metrics)
chartWidth - the width of each chart/plot.
chartHeight - the height of each chart/plot.
columns - the number of columns of plots displayed. Plots fill the columns of a row first, then move down to the next row.
maxWindowHeight - the maximum window height allowed before a scroll view is used.
trialMode - which plots to use: most recent trial, average over all trials, or both. If both, the most recent trial plot is inserted into the window first, then the average.
metrics - the metrics that should be plotted. The metrics appear in the window in the order in which they are specified (columns first).
public void setPlotRefreshDelay(int delayInMS)
delayInMS - the delay in milliseconds.
public void setPlotCISignificance(double significance)
significance - the significance to use for confidence intervals.
public void toggleVisualPlots(boolean shouldPlotResults)
shouldPlotResults - if true, plots will be displayed; if false, plots will not be displayed.
public void toggleTrialLengthInterpretation(boolean lengthRepresentsEpisodes)
lengthRepresentsEpisodes - if true, interpret the length as a number of episodes; if false, interpret it as a total number of steps.
public void startExperiment()
public void writeStepAndEpisodeDataToCSV(java.lang.String pathAndBaseNameToUse)
pathAndBaseNameToUse - the base path and file name for the episode-wise and step-wise CSV files.
public void writeStepDataToCSV(java.lang.String filePath)
filePath - the path to the CSV file to write to.
public void writeEpisodeDataToCSV(java.lang.String filePath)
filePath - the path to the CSV file to write to.
protected void runEpisodeBoundTrial(LearningAgentFactory agentFactory)
agentFactory - the agent factory used to generate the agent to test.
protected void runStepBoundTrial(LearningAgentFactory agentFactory)
agentFactory - the agent factory used to generate the agent to test.