/*
 * Decompiled with CFR 0.152.
 */
package com.anji.polebalance;

import com.anji.integration.Activator;
import com.anji.integration.ActivatorTranscriber;
import com.anji.polebalance.PoleBalanceDisplay;
import com.anji.util.Arrays;
import com.anji.util.Configurable;
import com.anji.util.Properties;
import com.anji.util.Randomizer;
import java.nio.DoubleBuffer;
import java.util.List;
import java.util.Random;
import org.apache.log4j.Logger;
import org.jgap.BulkFitnessFunction;
import org.jgap.Chromosome;

/**
 * Bulk fitness function that scores each chromosome by how long the network it encodes keeps two
 * poles balanced on a cart moving along a finite track.
 *
 * <p>Each chromosome is converted into an {@link Activator} by the configured
 * {@link ActivatorTranscriber}. The activator is then run for {@code polebalance.trials} trials of
 * at most {@code polebalance.timesteps} simulated steps each, and the per-trial scores are summed
 * into the chromosome's fitness. A trial stops early as soon as the cart leaves the track or either
 * pole angle exceeds the configured threshold. The optional {@code penalize.for.energy.use} and
 * {@code penalize.oscillations} settings adjust the per-trial score (see {@code singleTrial}).
 */
public class DoublePoleBalanceFitnessFunction
implements BulkFitnessFunction,
Configurable {
    private static final String TRACK_LENGTH_KEY = "polebalance.track.length";
    private static final String TIMESTEPS_KEY = "polebalance.timesteps";
    private static final String NUM_TRIALS_KEY = "polebalance.trials";
    private static final String ANGLE_THRESHOLD_KEY = "polebalance.angle.threshold";
    private static final String INPUT_VELOCITY_KEY = "polebalance.input.velocities";
    private static final String POLE_1_LENGTH_KEY = "pole.1.length";
    private static final String POLE_2_LENGTH_KEY = "pole.2.length";
    private static final String START_POLE_ANGLE_1_KEY = "polebalance.pole.angle.start.1";
    private static final String START_POLE_ANGLE_2_KEY = "polebalance.pole.angle.start.2";
    private static final String START_POLE_ANGLE_RANDOM_KEY = "polebalance.pole.angle.start.random";
    private static final String PENALIZE_FOR_ENERGY_USE_KEY = "penalize.for.energy.use";
    private static final String PENALIZE_OSCILLATIONS_KEY = "penalize.oscillations";
    private static final double GRAVITY = -9.8;
    private static final double MASSCART = 1.0;
    private static final double FORCE_MAG = 10.0;
    private static final double TIME_DELTA = 0.01;
    private static final double FOURTHIRDS = 1.3333333333333333;
    private static final double MUP = 2.0E-6;
    private static final double ONE_DEGREE = Math.PI / 180;
    private static final double SIX_DEGREES = 0.10471975511965977;
    private static final double TWELVE_DEGREES = 0.20943951023931953;
    private static final double EIGHTEEN_DEGREES = 0.3141592653589793;
    private static final double TWENTYFOUR_DEGREES = 0.41887902047863906;
    private static final double THIRTYSIX_DEGREES = 0.6283185307179586;
    private static final double FIFTY_DEGREES = 0.8726646259971648;
    private static final double SEVENTYTWO_DEGREES = 1.2566370614359172;
    /** 3/4 (the reciprocal of FOURTHIRDS), as it appears in the equations of motion in step(). */
    private static final double THREE_QUARTERS = 0.75;
    /** TIME_DELTA / 2: the RK4 half step. */
    private static final double HALF_TIME_DELTA = 0.005;
    /** TIME_DELTA / 6: the RK4 weighting factor (literal kept so results are bit-identical). */
    private static final double SIXTH_TIME_DELTA = 0.0016666666666666668;
    /** Integrator selection for performAction(): RK4 when true, forward Euler when false. */
    private static final boolean USE_RK4 = true;
    /** Number of RK4 sub-steps per network action. */
    private static final int RK4_SUBSTEPS = 2;
    /** Euler step size and sub-step count (only used when USE_RK4 is false). */
    private static final double EULER_TAU = 0.0025;
    private static final int EULER_SUBSTEPS = 8;
    /** Maximum number of most-recent steps summed by the oscillation (damping) term. */
    private static final int OSCILLATION_WINDOW = 100;
    /** Period, in steps, whose remainder selects the oscillation window in singleTrial(). */
    private static final int OSCILLATION_PERIOD = 1000;
    /**
     * Simulation state layout: [0] cart position, [1] cart velocity, [2] pole 1 angle,
     * [3] pole 1 angular velocity, [4] pole 2 angle, [5] pole 2 angular velocity.
     */
    private static final int STATE_SIZE = 6;
    private PoleBalanceDisplay display = null;
    private static final double DEFAULT_TRACK_LENGTH = 4.8;
    private double trackLength = 4.8;
    // Kept consistent with trackLength even before init() runs.
    private double trackLengthHalfed = DEFAULT_TRACK_LENGTH / 2.0;
    private static final int DEFAULT_TIMESTEPS = 10000;
    private int maxTimesteps = 10000;
    private static final int DEFAULT_NUM_TRIALS = 10;
    private int numTrials = 10;
    private double poleAngleThreshold = 0.6283185307179586;
    private static final Logger logger = Logger.getLogger(DoublePoleBalanceFitnessFunction.class);
    private ActivatorTranscriber factory;
    private boolean doInputVelocities = true;
    private double poleLength1 = 0.5;
    private double poleMass1 = 0.1;
    private double poleLength2 = 0.05;
    private double poleMass2 = 0.01;
    private double startPoleAngle1 = Math.PI / 180;
    private double startPoleAngle2 = 0.0;
    private boolean startPoleAngleRandom = false;
    private boolean penalizeEnergyUse = false;
    private boolean penalizeOscillations = false;
    private Random rand;

    /** Sets the track length and caches half of it, which is the distance from centre to either end. */
    private void setTrackLength(double aTrackLength) {
        this.trackLength = aTrackLength;
        this.trackLengthHalfed = this.trackLength / 2.0;
    }

    /**
     * Reads the simulation and scoring configuration from {@code props}.
     *
     * @param props configuration source; also supplies the ActivatorTranscriber and Randomizer
     *     singletons
     * @throws IllegalArgumentException wrapping (as its cause) any failure while reading properties
     */
    @Override
    public void init(Properties props) throws Exception {
        try {
            this.factory = (ActivatorTranscriber)props.singletonObjectProperty(ActivatorTranscriber.class);
            this.setTrackLength(props.getDoubleProperty(TRACK_LENGTH_KEY, 4.8));
            this.maxTimesteps = props.getIntProperty(TIMESTEPS_KEY, 10000);
            this.numTrials = props.getIntProperty(NUM_TRIALS_KEY, 10);
            this.poleAngleThreshold = props.getDoubleProperty(ANGLE_THRESHOLD_KEY, 0.6283185307179586);
            this.doInputVelocities = props.getBooleanProperty(INPUT_VELOCITY_KEY, true);
            // NOTE(review): these defaults (0.5 / 0.05) equal the field defaults above, but they are
            // halved here. An unset property therefore yields 0.25 / 0.025 instead of 0.5 / 0.05.
            // Confirm whether the keys are meant to hold full pole lengths.
            this.poleLength1 = props.getDoubleProperty(POLE_1_LENGTH_KEY, 0.5) / 2.0;
            // Pole mass is derived from the stored (halved) length.
            this.poleMass1 = this.poleLength1 / 5.0;
            this.poleLength2 = props.getDoubleProperty(POLE_2_LENGTH_KEY, 0.05) / 2.0;
            this.poleMass2 = this.poleLength2 / 5.0;
            this.startPoleAngle1 = props.getDoubleProperty(START_POLE_ANGLE_1_KEY, Math.PI / 180);
            this.startPoleAngle2 = props.getDoubleProperty(START_POLE_ANGLE_2_KEY, 0.0);
            this.startPoleAngleRandom = props.getBooleanProperty(START_POLE_ANGLE_RANDOM_KEY, false);
            this.penalizeEnergyUse = props.getBooleanProperty(PENALIZE_FOR_ENERGY_USE_KEY, false);
            this.penalizeOscillations = props.getBooleanProperty(PENALIZE_OSCILLATIONS_KEY, false);
            Randomizer randomizer = (Randomizer)props.singletonObjectProperty(Randomizer.class);
            this.rand = randomizer.getRand();
        }
        catch (Exception e) {
            // Chain the original exception so its stack trace is not lost.
            throw new IllegalArgumentException("invalid properties: " + e.getClass().toString() + ": " + e.getMessage(), e);
        }
    }

    /**
     * Evaluates every chromosome in {@code genotypes}, setting each one's fitness value.
     *
     * @param genotypes list whose elements are {@link Chromosome}s (a raw List, to match the interface)
     */
    @Override
    public void evaluate(List genotypes) {
        // The interface uses a raw List, so the elements arrive as Object and must be cast.
        for (Object genotype : genotypes) {
            this.evaluate((Chromosome)genotype);
        }
    }

    /**
     * Sets {@code chromosome}'s fitness to the sum of numTrials trial scores.
     * Any Throwable raised while building or running the network is logged, and the fitness is set
     * to 0.
     */
    public void evaluate(Chromosome chromosome) {
        try {
            Activator activator = this.factory.newActivator(chromosome);
            int fitness = 0;
            for (int trial = 0; trial < this.numTrials; ++trial) {
                fitness += this.singleTrial(activator);
            }
            chromosome.setFitnessValue(fitness);
        }
        catch (Throwable e) {
            logger.warn("error evaluating chromosome " + chromosome.toString(), e);
            chromosome.setFitnessValue(0);
        }
    }

    /**
     * Builds the initial state: cart at rest at the origin, both poles with zero angular velocity.
     * The pole angles are either the configured start angles or, in random mode, Gaussian samples
     * scaled by them (pole 1 is sampled first).
     */
    private double[] newState() {
        double[] state = new double[STATE_SIZE];
        if (this.startPoleAngleRandom) {
            state[2] = this.rand.nextGaussian() * this.startPoleAngle1;
            state[4] = this.rand.nextGaussian() * this.startPoleAngle2;
        } else {
            state[2] = this.startPoleAngle1;
            state[4] = this.startPoleAngle2;
        }
        return state;
    }

    /**
     * Runs one balancing trial and returns its score.
     *
     * <p>The base score is the number of steps survived: the index of the failing step, or
     * maxTimesteps if the trial ran to completion. If energy use is penalised, (int) (sum of network
     * outputs / 10) is subtracted. If oscillations are penalised, the int-truncated value of
     * {@code 0.1 * (steps % 1000) + 0.9 * damping} is added, where {@code damping} is computed by
     * dampingTerm() from the most recent min(100, steps % 1000) per-step samples of
     * |x| + |x'| + |theta1| + |theta1'|.
     */
    private int singleTrial(Activator activator) {
        double[] state = this.newState();
        double energyUsed = 0.0;
        // Only the most recent OSCILLATION_WINDOW samples are ever read, so a small ring buffer
        // suffices. It cannot overflow regardless of maxTimesteps.
        double[] recentOscillation = new double[OSCILLATION_WINDOW];
        int samplesRecorded = 0;
        if (logger.isDebugEnabled()) {
            logger.debug("state = " + Arrays.toString(state));
        }
        int currentTimestep;
        for (currentTimestep = 0; currentTimestep < this.maxTimesteps; ++currentTimestep) {
            double[] networkInput = this.networkInput(state);
            // Sampled before acting, so the step on which the trial fails is also recorded.
            recentOscillation[samplesRecorded % OSCILLATION_WINDOW] =
                    Math.abs(state[0]) + Math.abs(state[1]) + Math.abs(state[2]) + Math.abs(state[3]);
            ++samplesRecorded;
            double networkOutput = activator.next(networkInput)[0];
            // NOTE(review): accumulates the raw output, not |force|; confirm the intended energy measure.
            energyUsed += networkOutput;
            this.performAction(networkOutput, state);
            if (this.display != null) {
                this.display.step(currentTimestep, state[0], new double[]{state[2], state[4]});
            }
            if (this.hasFailed(state)) break;
        }
        int fitness = currentTimestep;
        if (this.penalizeEnergyUse) {
            fitness -= (int)(energyUsed / 10.0);
        }
        if (this.penalizeOscillations) {
            int remainder = currentTimestep % OSCILLATION_PERIOD;
            double damping = dampingTerm(recentOscillation, samplesRecorded,
                    Math.min(OSCILLATION_WINDOW, remainder), currentTimestep);
            // NOTE(review): this bonus can push a trial's score above maxTimesteps (getMaxFitnessValue).
            fitness = (int)((double)fitness + (0.1 * (double)remainder + 0.9 * damping));
        }
        if (logger.isDebugEnabled()) {
            logger.debug("trial took " + currentTimestep + " steps");
        }
        return fitness;
    }

    /**
     * Returns 0.75 divided by the sum of the last {@code windowSize} recorded samples.
     * Returns 0 when the trial lasted 99 steps or fewer, or when that sum is not positive.
     * Larger oscillation therefore yields a smaller term.
     */
    private static double dampingTerm(double[] ring, int samplesRecorded, int windowSize, int steps) {
        int count = Math.min(windowSize, samplesRecorded);
        double sum = 0.0;
        for (int back = 1; back <= count; ++back) {
            sum += ring[(samplesRecorded - back) % ring.length];
        }
        if (steps > 99 && sum > 0.0) {
            return 0.75 / sum;
        }
        return 0.0;
    }

    /** Builds the normalised network input vector (with or without velocities) plus a bias of 1.0. */
    private double[] networkInput(double[] state) {
        if (this.doInputVelocities) {
            return new double[]{state[0] / this.trackLengthHalfed, state[1] / 0.75, state[2] / this.poleAngleThreshold, state[3], state[4] / this.poleAngleThreshold, state[5], 1.0};
        }
        return new double[]{state[0] / this.trackLengthHalfed, state[2] / this.poleAngleThreshold, state[4] / this.poleAngleThreshold, 1.0};
    }

    /** True when the cart has left the track or either pole angle is outside the threshold. */
    private boolean hasFailed(double[] state) {
        return outside(state[0], this.trackLengthHalfed)
                || outside(state[2], this.poleAngleThreshold)
                || outside(state[4], this.poleAngleThreshold);
    }

    /** True when {@code value} lies strictly outside [-limit, limit]. */
    private static boolean outside(double value, double limit) {
        return value < -limit || value > limit;
    }

    /** Copies the velocity components of {@code source} into the position slots of {@code derivs}. */
    private static void copyVelocities(double[] source, double[] derivs) {
        derivs[0] = source[1];
        derivs[2] = source[3];
        derivs[4] = source[5];
    }

    /** Advances {@code state} in place under the network's action, using RK4 or Euler sub-steps. */
    private void performAction(double output, double[] state) {
        double[] dydx = new double[STATE_SIZE];
        if (USE_RK4) {
            for (int i = 0; i < RK4_SUBSTEPS; ++i) {
                copyVelocities(state, dydx);
                this.step(output, state, dydx);
                this.rk4(output, state, dydx, state);
            }
        } else {
            for (int i = 0; i < EULER_SUBSTEPS; ++i) {
                // Position derivatives must be filled too, since step() only sets the accelerations.
                copyVelocities(state, dydx);
                this.step(output, state, dydx);
                for (int j = 0; j < STATE_SIZE; ++j) {
                    state[j] = state[j] + EULER_TAU * dydx[j];
                }
            }
        }
    }

    /**
     * Writes the cart acceleration (derivs[1]) and the two pole angular accelerations (derivs[3],
     * derivs[5]) for state {@code st} under {@code action}. The force is (action - 0.5) * 2 *
     * FORCE_MAG. derivs[0], derivs[2] and derivs[4] are left untouched.
     */
    private void step(double action, double[] st, double[] derivs) {
        double force = (action - 0.5) * FORCE_MAG * 2.0;
        double cosTheta1 = Math.cos(st[2]);
        double sinTheta1 = Math.sin(st[2]);
        double gSinTheta1 = GRAVITY * sinTheta1;
        double cosTheta2 = Math.cos(st[4]);
        double sinTheta2 = Math.sin(st[4]);
        double gSinTheta2 = GRAVITY * sinTheta2;
        double ml1 = this.poleLength1 * this.poleMass1;
        double ml2 = this.poleLength2 * this.poleMass2;
        double temp1 = MUP * st[3] / ml1;
        double temp2 = MUP * st[5] / ml2;
        double fi1 = ml1 * st[3] * st[3] * sinTheta1 + THREE_QUARTERS * this.poleMass1 * cosTheta1 * (temp1 + gSinTheta1);
        double fi2 = ml2 * st[5] * st[5] * sinTheta2 + THREE_QUARTERS * this.poleMass2 * cosTheta2 * (temp2 + gSinTheta2);
        double mi1 = this.poleMass1 * (1.0 - THREE_QUARTERS * cosTheta1 * cosTheta1);
        double mi2 = this.poleMass2 * (1.0 - THREE_QUARTERS * cosTheta2 * cosTheta2);
        derivs[1] = (force + fi1 + fi2) / (mi1 + mi2 + MASSCART);
        derivs[3] = -THREE_QUARTERS * (derivs[1] * cosTheta1 + gSinTheta1 + temp1) / this.poleLength1;
        derivs[5] = -THREE_QUARTERS * (derivs[1] * cosTheta2 + gSinTheta2 + temp2) / this.poleLength2;
    }

    /**
     * Takes one classical fourth-order Runge-Kutta step of size TIME_DELTA from {@code y}, whose
     * derivative is {@code dydx}, and writes the result to {@code yout}. {@code yout} may be the same
     * array as {@code y}: each yout[i] is written only after the last read of y[i].
     */
    private void rk4(double f, double[] y, double[] dydx, double[] yout) {
        double[] dym = new double[STATE_SIZE];
        double[] dyt = new double[STATE_SIZE];
        double[] yt = new double[STATE_SIZE];
        for (int i = 0; i < STATE_SIZE; ++i) {
            yt[i] = y[i] + HALF_TIME_DELTA * dydx[i];
        }
        this.step(f, yt, dyt);
        copyVelocities(yt, dyt);
        for (int i = 0; i < STATE_SIZE; ++i) {
            yt[i] = y[i] + HALF_TIME_DELTA * dyt[i];
        }
        this.step(f, yt, dym);
        copyVelocities(yt, dym);
        for (int i = 0; i < STATE_SIZE; ++i) {
            yt[i] = y[i] + TIME_DELTA * dym[i];
            dym[i] = dym[i] + dyt[i];
        }
        this.step(f, yt, dyt);
        copyVelocities(yt, dyt);
        for (int i = 0; i < STATE_SIZE; ++i) {
            yout[i] = y[i] + SIXTH_TIME_DELTA * (dydx[i] + dyt[i] + 2.0 * dym[i]);
        }
    }

    /** Returns numTrials * maxTimesteps, i.e. every trial surviving all steps with no adjustments. */
    @Override
    public int getMaxFitnessValue() {
        return this.numTrials * this.maxTimesteps;
    }

    /** Creates and shows a PoleBalanceDisplay that subsequent trials will update each step. */
    public void enableDisplay() {
        this.display = new PoleBalanceDisplay(this.trackLength, new double[]{this.poleLength1, this.poleLength2}, this.maxTimesteps);
        this.display.setVisible(true);
    }
}

