package burlap.behavior.singleagent.auxiliary.performance;

import burlap.behavior.singleagent.learning.LearningAgent;
import burlap.behavior.singleagent.learning.LearningAgentFactory;
import burlap.debugtools.DPrint;
import burlap.oomdp.auxiliary.StateGenerator;
import burlap.oomdp.singleagent.RewardFunction;
import burlap.oomdp.singleagent.SADomain;
import org.apache.commons.math3.random.EmpiricalDistribution;

/* loaded from: input_file:burlap/behavior/singleagent/auxiliary/performance/LearningAlgorithmExperimenter.class */
/**
 * Runs a learning experiment over one or more agent factories and records the
 * results through a {@link PerformancePlotter}.
 *
 * <p>For every factory, {@code nTrials} trials are run. Each trial generates a
 * fresh agent from the factory and then runs learning episodes from states
 * produced by the {@link StateGenerator}. A trial is bounded either by a number
 * of episodes (the default) or by a number of steps; see
 * {@link #toggleTrialLengthInterpretation(boolean)}.
 *
 * <p>An experiment can be run only once per instance: after
 * {@link #startExperiment()} has finished, further calls print a message and
 * return without doing anything.
 */
public class LearningAlgorithmExperimenter {
    /**
     * Initial value of {@link #plotRefresh}. Kept as a local constant so the
     * default does not depend on an unrelated library constant that merely
     * happens to share the value.
     */
    private static final int DEFAULT_PLOT_REFRESH_DELAY = 1000;

    /** Domain to which the plotter is attached as an action observer. */
    protected SADomain domain;
    /** Reward function handed to the plotter. */
    protected RewardFunction rf;
    /** Source of the initial state of every learning episode. */
    protected StateGenerator sg;
    /** Factories of the agents to test; never empty. */
    protected LearningAgentFactory[] agentFactories;
    /** Number of trials run for each agent factory. */
    protected int nTrials;
    /** Length of one trial, in episodes or steps depending on {@link #trialLengthIsInEpisodes}. */
    protected int trialLength;
    /** When true, {@link #trialLength} counts episodes; otherwise it is a step budget. */
    protected boolean trialLengthIsInEpisodes = true;
    /** Plotter that collects the data; created lazily by {@link #startExperiment()} if not configured. */
    protected PerformancePlotter plotter = null;
    /** Whether {@link #startExperiment()} starts the plotter GUI. */
    protected boolean displayPlots = true;
    /** Value passed to {@link PerformancePlotter#setRefreshDelay}; presumably milliseconds - TODO confirm. */
    protected int plotRefresh = DEFAULT_PLOT_REFRESH_DELAY;
    /** Value passed to {@link PerformancePlotter#setSignificanceForCI}. */
    protected double plotCISignificance = 0.05d;
    /** Set once {@link #startExperiment()} has run to completion. */
    protected boolean completedExperiment = false;
    /** Debug code used for the per-trial progress messages printed through {@link DPrint}. */
    public int debugCode = 63634013;

    /**
     * Creates an experiment.
     *
     * @param domain         domain the plotter will observe
     * @param rf             reward function handed to the plotter
     * @param sg             generator of the initial state of each episode
     * @param nTrials        number of trials to run per agent factory
     * @param trialLength    length of each trial (episodes by default, see
     *                       {@link #toggleTrialLengthInterpretation(boolean)})
     * @param agentFactories factories of the agents to test; at least one is required
     * @throws RuntimeException if no agent factory is provided
     */
    public LearningAlgorithmExperimenter(SADomain domain, RewardFunction rf, StateGenerator sg, int nTrials, int trialLength, LearningAgentFactory... agentFactories) {
        // A null varargs array is treated like an empty one instead of surfacing as a bare NPE.
        if (agentFactories == null || agentFactories.length == 0) {
            throw new RuntimeException("Zero agent factories provided. At least one must be given for an experiment");
        }
        this.domain = domain;
        this.rf = rf;
        this.sg = sg;
        this.nTrials = nTrials;
        this.trialLength = trialLength;
        this.agentFactories = agentFactories;
    }

    /**
     * Creates the plotter with an explicit layout and re-enables plot display.
     *
     * <p>The four integer arguments are forwarded, in order, to the
     * {@link PerformancePlotter} constructor; the names used here follow the
     * BURLAP API (confirm against the plotter's documentation). If the requested
     * trial mode has averages enabled but only one trial will be run, the mode
     * is replaced by {@link TrialMode#MOSTRECENTTTRIALONLY}.
     *
     * @param chartWidth      first layout argument of the plotter constructor
     * @param chartHeight     second layout argument of the plotter constructor
     * @param columns         third layout argument of the plotter constructor
     * @param maxWindowHeight fourth layout argument of the plotter constructor
     * @param trialMode       requested trial mode
     * @param metrics         performance metrics to plot
     */
    public void setUpPlottingConfiguration(int chartWidth, int chartHeight, int columns, int maxWindowHeight, TrialMode trialMode, PerformanceMetric... metrics) {
        TrialMode effectiveMode = trialMode;
        if (trialMode.averagesEnabled() && this.nTrials == 1) {
            effectiveMode = TrialMode.MOSTRECENTTTRIALONLY;
        }
        this.displayPlots = true;
        this.plotter = new PerformancePlotter(this.agentFactories[0].getAgentName(), this.rf, chartWidth, chartHeight, columns, maxWindowHeight, effectiveMode, metrics);
        this.plotter.setRefreshDelay(this.plotRefresh);
        this.plotter.setSignificanceForCI(this.plotCISignificance);
    }

    /**
     * Stores the plot refresh delay and applies it to the plotter if one exists.
     *
     * @param delay value passed to {@link PerformancePlotter#setRefreshDelay}
     */
    public void setPlotRefreshDelay(int delay) {
        this.plotRefresh = delay;
        if (this.plotter != null) {
            this.plotter.setRefreshDelay(delay);
        }
    }

    /**
     * Stores the confidence-interval significance and applies it to the plotter
     * if one exists.
     *
     * @param significance value passed to {@link PerformancePlotter#setSignificanceForCI}
     */
    public void setPlotCISignificance(double significance) {
        this.plotCISignificance = significance;
        if (this.plotter != null) {
            this.plotter.setSignificanceForCI(significance);
        }
    }

    /**
     * Sets whether {@link #startExperiment()} starts the plotter GUI. Note that
     * {@link #setUpPlottingConfiguration} switches this back on.
     *
     * @param shouldPlotResults true to start the GUI when the experiment begins
     */
    public void toggleVisualPlots(boolean shouldPlotResults) {
        this.displayPlots = shouldPlotResults;
    }

    /**
     * Chooses how the trial length is interpreted.
     *
     * @param lengthRepresentsEpisodes true if the trial length is a number of
     *                                 episodes, false if it is a number of steps
     */
    public void toggleTrialLengthInterpretation(boolean lengthRepresentsEpisodes) {
        this.trialLengthIsInEpisodes = lengthRepresentsEpisodes;
    }

    /**
     * Runs the whole experiment: every trial of every agent factory, in order.
     *
     * <p>If no plotter was configured, a default one is created. The plotter is
     * registered as an action observer on the domain, and its GUI is started
     * when plot display is enabled. Does nothing except print a message if the
     * experiment has already completed.
     */
    public void startExperiment() {
        if (this.completedExperiment) {
            System.out.println("Experiment was already run and has completed. If you want to run a new experiment create a new Experiment object.");
            return;
        }
        if (this.plotter == null) {
            // Averages are only meaningful with more than one trial.
            TrialMode defaultMode = TrialMode.MOSTRECENTANDAVERAGE;
            if (this.nTrials == 1) {
                defaultMode = TrialMode.MOSTRECENTTTRIALONLY;
            }
            this.plotter = new PerformancePlotter(this.agentFactories[0].getAgentName(), this.rf, 500, 250, 2, 500, defaultMode, new PerformanceMetric[0]);
        }
        this.domain.addActionObserverForAllAction(this.plotter);
        if (this.displayPlots) {
            this.plotter.startGUI();
        }
        for (int agentIndex = 0; agentIndex < this.agentFactories.length; agentIndex++) {
            LearningAgentFactory factory = this.agentFactories[agentIndex];
            // The plotter is constructed with the first agent's name, so only later agents are announced.
            if (agentIndex > 0) {
                this.plotter.startNewAgent(factory.getAgentName());
            }
            for (int trial = 0; trial < this.nTrials; trial++) {
                DPrint.cl(this.debugCode, "Beginning " + factory.getAgentName() + " trial " + (trial + 1) + "/" + this.nTrials);
                if (this.trialLengthIsInEpisodes) {
                    runEpisodeBoundTrial(factory);
                } else {
                    runStepBoundTrial(factory);
                }
            }
        }
        this.plotter.endAllAgents();
        this.completedExperiment = true;
    }

    /**
     * Asks the plotter to write both step and episode data to CSV. Prints a
     * message and writes nothing if the experiment has not completed.
     *
     * @param pathAndBaseNameToUse path argument forwarded to the plotter
     */
    public void writeStepAndEpisodeDataToCSV(String pathAndBaseNameToUse) {
        if (canWriteData()) {
            this.plotter.writeStepAndEpisodeDataToCSV(pathAndBaseNameToUse);
        }
    }

    /**
     * Asks the plotter to write the step data to CSV. Prints a message and
     * writes nothing if the experiment has not completed.
     *
     * @param filePath path argument forwarded to the plotter
     */
    public void writeStepDataToCSV(String filePath) {
        if (canWriteData()) {
            this.plotter.writeStepDataToCSV(filePath);
        }
    }

    /**
     * Asks the plotter to write the episode data to CSV. Prints a message and
     * writes nothing if the experiment has not completed.
     *
     * @param filePath path argument forwarded to the plotter
     */
    public void writeEpisodeDataToCSV(String filePath) {
        if (canWriteData()) {
            this.plotter.writeEpisodeDataToCSV(filePath);
        }
    }

    /**
     * Runs one trial consisting of exactly {@link #trialLength} learning
     * episodes with a freshly generated agent.
     *
     * @param agentFactory factory used to generate the agent for this trial
     */
    protected void runEpisodeBoundTrial(LearningAgentFactory agentFactory) {
        LearningAgent agent = beginTrial(agentFactory);
        for (int episode = 0; episode < this.trialLength; episode++) {
            agent.runLearningEpisodeFrom(this.sg.generateState());
            this.plotter.endEpisode();
        }
        this.plotter.endTrial();
    }

    /**
     * Runs one trial with a freshly generated agent, starting new learning
     * episodes until the step budget of {@link #trialLength} is used up. Each
     * episode is capped at the number of steps still remaining.
     *
     * @param agentFactory factory used to generate the agent for this trial
     */
    protected void runStepBoundTrial(LearningAgentFactory agentFactory) {
        LearningAgent agent = beginTrial(agentFactory);
        int stepsRemaining = this.trialLength;
        while (stepsRemaining > 0) {
            // Presumably numTimeSteps() counts the initial state as well, so -1 yields the
            // number of actions taken - TODO confirm.
            // NOTE(review): an episode that reports a single time step leaves the budget
            // unchanged; if the state generator keeps producing such states this loop
            // will not terminate.
            stepsRemaining -= agent.runLearningEpisodeFrom(this.sg.generateState(), stepsRemaining).numTimeSteps() - 1;
            this.plotter.endEpisode();
        }
        this.plotter.endTrial();
    }

    /**
     * Generates the agent for a trial and tells the plotter a new trial starts.
     * Data collection is switched off while the agent is generated, presumably
     * so that nothing the factory does is recorded as part of the trial.
     */
    private LearningAgent beginTrial(LearningAgentFactory agentFactory) {
        this.plotter.toggleDataCollection(false);
        LearningAgent agent = agentFactory.generateAgent();
        this.plotter.toggleDataCollection(true);
        this.plotter.startNewTrial();
        return agent;
    }

    /**
     * Reports whether recorded data may be written, printing an explanation to
     * standard output when it may not.
     */
    private boolean canWriteData() {
        if (!this.completedExperiment) {
            System.out.println("Cannot write data until the experiment has been started with the startExperiment() method.");
        }
        return this.completedExperiment;
    }
}
