package burlap.behavior.singleagent.interfaces.rlglue.common;

import burlap.behavior.learningrate.ConstantLR;
import burlap.behavior.learningrate.ExponentialDecayLR;
import burlap.behavior.learningrate.LearningRate;
import burlap.behavior.singleagent.Policy;
import burlap.behavior.singleagent.interfaces.rlglue.RLGlueAgentShell;
import burlap.behavior.singleagent.interfaces.rlglue.RLGlueLearningAgentFactory;
import burlap.behavior.singleagent.learning.LearningAgent;
import burlap.behavior.singleagent.learning.tdmethods.vfa.GradientDescentSarsaLam;
import burlap.behavior.singleagent.planning.commonpolicies.BoltzmannQPolicy;
import burlap.behavior.singleagent.planning.commonpolicies.EpsilonGreedy;
import burlap.behavior.singleagent.vfa.cmac.CMACFeatureDatabase;
import burlap.datastructures.CommandLineOptions;
import burlap.oomdp.core.Attribute;
import burlap.oomdp.core.Domain;
import burlap.oomdp.core.TerminalFunction;
import burlap.oomdp.singleagent.RewardFunction;
import java.util.HashMap;
import java.util.Map;

/* loaded from: input_file:burlap/behavior/singleagent/interfaces/rlglue/common/RLGlueCMACSarsaLambdaFactory.class */
/**
 * {@link RLGlueLearningAgentFactory} that produces a {@link GradientDescentSarsaLam} agent whose
 * value function approximator is generated by a {@link CMACFeatureDatabase}.
 *
 * <p>The factory is configured through its setters (or through the command line options parsed in
 * {@link #main(String[])}) and then handed to an {@link RLGlueAgentShell}.
 */
public class RLGlueCMACSarsaLambdaFactory implements RLGlueLearningAgentFactory {
    /** First constructor argument of the {@link CMACFeatureDatabase}; set by {@code --nTilings}. */
    protected int nTiles = 5;

    /** Tile width used for every attribute that has no entry in {@link #tileWidths}. */
    protected double defaultTileWidth = 0.3d;

    /**
     * Per-attribute tile widths, keyed by the attribute's position in the iteration order of
     * {@code domain.getAttributes()}.
     */
    protected Map<Integer, Double> tileWidths = new HashMap<Integer, Double>();

    /** Learning rate installed on every generated agent. */
    protected LearningRate learningRate = new ConstantLR(Double.valueOf(0.02d));

    /** Learning policy installed on every generated agent. */
    protected Policy learningPolicy = new EpsilonGreedy(0.1d);

    /** Value passed to {@code CMACFeatureDatabase.generateVFA} when the approximator is created. */
    protected double initialFunctionWeight = 0.0d;

    /** Lambda value passed to the {@link GradientDescentSarsaLam} constructor. */
    protected double lambda = 0.5d;

    /**
     * @return the configured number of tilings
     */
    public int getnTiles() {
        return this.nTiles;
    }

    /**
     * @param nTiles the number of tilings to hand to the {@link CMACFeatureDatabase}
     */
    public void setnTiles(int nTiles) {
        this.nTiles = nTiles;
    }

    /**
     * @return the tile width used for attributes without a specific width
     */
    public double getDefaultTileWidth() {
        return this.defaultTileWidth;
    }

    /**
     * @param defaultTileWidth the tile width to use for attributes without a specific width
     */
    public void setDefaultTileWidth(double defaultTileWidth) {
        this.defaultTileWidth = defaultTileWidth;
    }

    /**
     * @return the learning rate that will be installed on generated agents
     */
    public LearningRate getLearningRate() {
        return this.learningRate;
    }

    /**
     * @param learningRate the learning rate to install on generated agents
     */
    public void setLearningRate(LearningRate learningRate) {
        this.learningRate = learningRate;
    }

    /**
     * @return the learning policy that will be installed on generated agents
     */
    public Policy getLearningPolicy() {
        return this.learningPolicy;
    }

    /**
     * @param learningPolicy the learning policy to install on generated agents
     */
    public void setLearningPolicy(Policy learningPolicy) {
        this.learningPolicy = learningPolicy;
    }

    /**
     * @return the initial weight passed to the generated value function approximator
     */
    public double getInitialFunctionWeight() {
        return this.initialFunctionWeight;
    }

    /**
     * @param initialFunctionWeight the initial weight to pass to the generated approximator
     */
    public void setInitialFunctionWeight(double initialFunctionWeight) {
        this.initialFunctionWeight = initialFunctionWeight;
    }

    /**
     * @return the configured lambda value
     */
    public double getLambda() {
        return this.lambda;
    }

    /**
     * @param lambda the lambda value to pass to generated agents
     */
    public void setLambda(double lambda) {
        this.lambda = lambda;
    }

    /**
     * Registers a specific tile width for one attribute, replacing any width previously registered
     * for the same index.
     *
     * @param attributeIndex position of the attribute in the iteration order of
     *     {@code domain.getAttributes()}
     * @param width the tile width to use for that attribute
     */
    public void addTileWidth(int attributeIndex, double width) {
        this.tileWidths.put(Integer.valueOf(attributeIndex), Double.valueOf(width));
    }

    /**
     * Builds a {@link GradientDescentSarsaLam} agent for the given domain using the current
     * factory settings.
     *
     * @param domain the domain whose attributes are tiled
     * @param discount forwarded unchanged as the fourth constructor argument of
     *     {@link GradientDescentSarsaLam} (presumably the discount factor; confirm against the
     *     {@link RLGlueLearningAgentFactory} contract)
     * @param rewardFunction forwarded to the agent
     * @param terminalFunction forwarded to the agent
     * @return the configured learning agent
     */
    @Override // burlap.behavior.singleagent.interfaces.rlglue.RLGlueLearningAgentFactory
    public LearningAgent generateAgentForRLDomain(Domain domain, double discount, RewardFunction rewardFunction, TerminalFunction terminalFunction) {
        CMACFeatureDatabase cmac = new CMACFeatureDatabase(this.nTiles, CMACFeatureDatabase.TilingArrangement.RANDOMJITTER);

        // Every domain attribute gets a specification under the "real" class name; the attribute's
        // position in the iteration order is the key used to look up a specific width.
        int attributeIndex = 0;
        for (Attribute attribute : domain.getAttributes()) {
            Double specificWidth = this.tileWidths.get(Integer.valueOf(attributeIndex));
            double width = specificWidth != null ? specificWidth.doubleValue() : this.defaultTileWidth;
            cmac.addSpecificationForAllTilings("real", attribute, width);
            attributeIndex++;
        }

        // The 0.1 literal is only the constructor's learning-rate argument; setLearningRate is
        // called immediately afterwards with the factory's configured learning rate.
        GradientDescentSarsaLam agent = new GradientDescentSarsaLam(domain, rewardFunction, terminalFunction, discount, cmac.generateVFA(this.initialFunctionWeight), 0.1d, this.lambda);
        agent.setLearningRate(this.learningRate);
        agent.setLearningPolicy(getLearningPolicy());
        return agent;
    }

    /**
     * Configures a factory from command line options, wraps it in an {@link RLGlueAgentShell} and
     * loads the agent. With {@code --help} the option summary is printed and the JVM exits.
     *
     * @param strArr the command line arguments, parsed by {@link CommandLineOptions}
     * @throws NumberFormatException if a numeric option value cannot be parsed
     */
    public static void main(String[] strArr) {
        CommandLineOptions options = new CommandLineOptions(strArr);
        if (options.containsOption("help")) {
            System.out.println("--help: print this message\n"
                    + "--lambda=v: sets the lambda value\n"
                    + "--qinit=v: sets the initial q-value to v everywhere\n"
                    + "--constant_lr=v: sets a constant learnign rate to v\n"
                    + "--exp_lr_base=v: sets the learning rate to an exponential decay with expoential base v\n"
                    + "--exp_lr_init=v: sets the learning rate to an exponential decay with initial learning rate v\n"
                    + "--exp_lr_min=v: sets the learning rate to an exponential decay with minimum learning rate v\n"
                    + "--egreedy=v: sets the learning policy to epsilon greedy with epsilon = v\n"
                    + "--boltzmann=v: sets the learning policy to boltzmann with temperature = v\n"
                    + "--defaultTileWidth=v: sets the tile width to v for attributes that do not have a specific width set\n"
                    + "--nTilings=v: sets the number of overlapping tilings to use\n"
                    + "--tileWidth_i=v: sets the tile width for continue attribute i to v\n\n"
                    // The stated default learning rate matches the ConstantLR(0.02) field initializer.
                    + "By default lambda = 0.5, learning rate is constant 0.02, q initialization is zero, and epsilon greedy policy with epsilon = 0.1");
            System.exit(0);
        }
        System.out.println("Use --help to see varaible settings.");

        RLGlueCMACSarsaLambdaFactory factory = new RLGlueCMACSarsaLambdaFactory();
        if (options.containsOption("lambda")) {
            // Read the same key that was just tested for (previously misspelled as "lamba").
            factory.setLambda(doubleOption(options, "lambda"));
        }
        if (options.containsOption("qinit")) {
            factory.setInitialFunctionWeight(doubleOption(options, "qinit"));
        }
        if (options.containsOption("egreedy")) {
            factory.setLearningPolicy(new EpsilonGreedy(doubleOption(options, "egreedy")));
        }
        // Checked after egreedy, so boltzmann wins when both options are supplied.
        if (options.containsOption("boltzmann")) {
            factory.setLearningPolicy(new BoltzmannQPolicy(doubleOption(options, "boltzmann")));
        }

        if (options.containsOption("constant_lr")) {
            // A constant learning rate takes precedence over every exp_lr_* option.
            factory.setLearningRate(new ConstantLR(Double.valueOf(doubleOption(options, "constant_lr"))));
        } else {
            // Exponential decay is used as soon as any exp_lr_* option is present; the options
            // that are missing keep the fallback values below.
            boolean useExponentialDecay = false;
            double initialRate = 0.1d;
            double decayBase = 0.99d;
            double minimumRate = Double.MIN_VALUE;
            if (options.containsOption("exp_lr_base")) {
                useExponentialDecay = true;
                decayBase = doubleOption(options, "exp_lr_base");
            }
            if (options.containsOption("exp_lr_init")) {
                useExponentialDecay = true;
                initialRate = doubleOption(options, "exp_lr_init");
            }
            if (options.containsOption("exp_lr_min")) {
                useExponentialDecay = true;
                minimumRate = doubleOption(options, "exp_lr_min");
            }
            if (useExponentialDecay) {
                factory.setLearningRate(new ExponentialDecayLR(initialRate, decayBase, minimumRate));
            }
        }

        if (options.containsOption("defaultTileWidth")) {
            factory.setDefaultTileWidth(doubleOption(options, "defaultTileWidth"));
        }
        if (options.containsOption("nTilings")) {
            factory.setnTiles(Integer.parseInt(options.optionValue("nTilings")));
        }
        for (String optionName : options.getOptionsStartingWithName("tileWidth_")) {
            // The text between the first and second underscore is the attribute index.
            int attributeIndex = Integer.parseInt(optionName.split("_")[1]);
            factory.addTileWidth(attributeIndex, doubleOption(options, optionName));
        }

        RLGlueAgentShell shell = new RLGlueAgentShell(factory);
        System.out.println("Loading agent into RLGlue...");
        shell.loadAgent();
    }

    /** Parses the value of the named command line option as a {@code double}. */
    private static double doubleOption(CommandLineOptions options, String name) {
        return Double.parseDouble(options.optionValue(name));
    }
}
