package burlap.behavior.singleagent.planning.vfa.fittedvi;

import burlap.behavior.singleagent.planning.ValueFunction;
import burlap.behavior.singleagent.planning.vfa.fittedvi.SupervisedVFA;
import burlap.behavior.singleagent.vfa.StateToFeatureVectorGenerator;
import burlap.datastructures.WekaInterfaces;
import burlap.oomdp.core.State;
import java.util.List;
import weka.classifiers.Classifier;
import weka.classifiers.lazy.IBk;
import weka.core.Instances;
import weka.core.SelectedTag;
import weka.core.neighboursearch.KDTree;

/* loaded from: input_file:burlap/behavior/singleagent/planning/vfa/fittedvi/WekaVFATrainer.class */
/**
 * A {@link SupervisedVFA} that fits state values with a Weka {@link Classifier}.
 * <p>
 * Each call to {@link #train(List)} asks the {@link WekaClassifierGenerator} for a new classifier,
 * converts the supervised (state, value) pairs into Weka {@link Instances} using the state feature
 * vector generator, builds the classifier on them, and returns a {@link WekaVFA} that wraps it.
 */
public class WekaVFATrainer implements SupervisedVFA {

    /** Produces the classifier that is built on each call to {@link #train(List)}. */
    protected WekaClassifierGenerator baseClassifier;

    /** Converts states into the feature vectors the classifier is trained and queried on. */
    protected StateToFeatureVectorGenerator fvGen;

    /**
     * Factory for the Weka classifier used by a {@link WekaVFATrainer}.
     */
    public interface WekaClassifierGenerator {

        /**
         * Creates the classifier that {@link WekaVFATrainer#train(List)} will build.
         *
         * @return a Weka classifier. NOTE(review): presumably a fresh, untrained instance on every
         *         call, since each training round builds the returned classifier; confirm for custom
         *         generators.
         */
        Classifier generateClassifier();
    }

    /**
     * A {@link ValueFunction} whose value for a state is the prediction of a Weka classifier on that
     * state's feature vector.
     */
    public static class WekaVFA implements ValueFunction {

        /** Converts query states into feature vectors for the classifier. */
        protected StateToFeatureVectorGenerator fvGen;

        /** The classifier whose prediction is returned as the state value. */
        protected Classifier classifier;

        /**
         * @param stateToFeatureVectorGenerator converts states into classifier feature vectors
         * @param classifier the classifier queried by {@link #value(State)}
         */
        public WekaVFA(StateToFeatureVectorGenerator stateToFeatureVectorGenerator, Classifier classifier) {
            this.fvGen = stateToFeatureVectorGenerator;
            this.classifier = classifier;
        }

        /**
         * Returns the classifier's prediction for the feature vector of {@code state}.
         *
         * @param state the state to evaluate
         * @return the predicted value
         * @throws RuntimeException if Weka fails to produce a prediction; the Weka exception is
         *         attached as the cause
         */
        @Override
        public double value(State state) {
            double[] features = this.fvGen.generateFeatureVectorFrom(state);
            try {
                // Wrap the features in a one-instance dataset shell so Weka has an attribute layout;
                // 0.0 is only a placeholder target value for the query instance.
                Instances shell = WekaInterfaces.getInstancesShell(features, 1);
                return this.classifier.classifyInstance(WekaInterfaces.getInstance(features, 0.0d, shell));
            } catch (Exception e) {
                throw new RuntimeException("WekaVFA could not produce prediction for instance. Returned message:\n" + e.getMessage(), e);
            }
        }
    }

    /**
     * @param wekaClassifierGenerator produces the classifier built on each call to {@link #train(List)}
     * @param stateToFeatureVectorGenerator converts states into classifier feature vectors
     */
    public WekaVFATrainer(WekaClassifierGenerator wekaClassifierGenerator, StateToFeatureVectorGenerator stateToFeatureVectorGenerator) {
        this.baseClassifier = wekaClassifierGenerator;
        this.fvGen = stateToFeatureVectorGenerator;
    }

    /**
     * Builds a new classifier on the given (state, value) pairs and returns it as a value function.
     *
     * @param list the supervised training instances; must not be empty
     * @return a {@link WekaVFA} wrapping the built classifier
     * @throws IllegalArgumentException if {@code list} is empty
     * @throws RuntimeException if Weka fails to build the classifier; the Weka exception is attached
     *         as the cause
     */
    @Override
    public ValueFunction train(List<SupervisedVFA.SupervisedVFAInstance> list) {
        if (list.isEmpty()) {
            throw new IllegalArgumentException("WekaVFATrainer cannot train on an empty list of instances.");
        }

        // The first state is used to define the dataset's attribute layout; list.size() is passed
        // to the shell (presumably as a capacity hint).
        Instances dataset = WekaInterfaces.getInstancesShell(list.get(0).s, this.fvGen, list.size());
        for (SupervisedVFA.SupervisedVFAInstance example : list) {
            dataset.add(WekaInterfaces.getInstance(example.s, this.fvGen, example.v, dataset));
        }

        Classifier classifier = this.baseClassifier.generateClassifier();
        try {
            classifier.buildClassifier(dataset);
        } catch (Exception e) {
            // Fail at training time rather than returning a value function around an unbuilt
            // classifier, which would only fail later at the first value() query.
            throw new RuntimeException("WekaVFATrainer could not build the classifier. Returned message:\n" + e.getMessage(), e);
        }
        return new WekaVFA(this.fvGen, classifier);
    }

    /**
     * Creates a trainer that fits a k-nearest-neighbour regressor (Weka {@link IBk}) using a
     * {@link KDTree} neighbour search, with neighbours weighted by similarity
     * ({@code IBk.WEIGHT_SIMILARITY}).
     *
     * @param stateToFeatureVectorGenerator converts states into classifier feature vectors
     * @param k the number of nearest neighbours; must be at least 1
     * @return a trainer producing k-NN value functions
     * @throws IllegalArgumentException if {@code k < 1}
     */
    public static WekaVFATrainer getKNNTrainer(StateToFeatureVectorGenerator stateToFeatureVectorGenerator, final int k) {
        if (k < 1) {
            throw new IllegalArgumentException("k must be at least 1, but was " + k);
        }
        return new WekaVFATrainer(new WekaClassifierGenerator() {
            @Override
            public Classifier generateClassifier() {
                IBk knn = new IBk();
                knn.setNearestNeighbourSearchAlgorithm(new KDTree());
                knn.setKNN(k);
                knn.setDistanceWeighting(new SelectedTag(IBk.WEIGHT_SIMILARITY, IBk.TAGS_WEIGHTING));
                return knn;
            }
        }, stateToFeatureVectorGenerator);
    }
}
