/*
 * Decompiled with CFR 0.152.
 */
package com.numericalmethod.suanshu.optimization.unconstrained.conjugatedirection;

import com.numericalmethod.suanshu.analysis.differentiation.multivariate.GradientFunction;
import com.numericalmethod.suanshu.analysis.differentiation.multivariate.HessianFunction;
import com.numericalmethod.suanshu.analysis.function.matrix.RntoMatrix;
import com.numericalmethod.suanshu.analysis.function.rn2r1.RealScalarFunction;
import com.numericalmethod.suanshu.analysis.function.rn2rm.RealVectorFunction;
import com.numericalmethod.suanshu.matrix.doubles.Matrix;
import com.numericalmethod.suanshu.matrix.doubles.matrixtype.dense.DenseMatrix;
import com.numericalmethod.suanshu.optimization.unconstrained.steepestdescent.SteepestDescent;
import com.numericalmethod.suanshu.vector.doubles.Vector;

/**
 * Conjugate-gradient minimizer built on the {@link SteepestDescent} framework.
 *
 * <p>Each iteration moves along the direction
 * {@code d_k = -g_k + beta_k * d_{k-1}} (with {@code d_0 = -g_0}) using the step size
 * {@code alpha_k = (g_k . g_k) / (d_k' H(x_k) d_k)}, where {@code g} is the gradient and
 * {@code H} the Hessian of the objective. For a quadratic objective this step is the exact
 * minimizer along {@code d_k}.
 *
 * <p>NOTE(review): the step size divides by {@code d' H d}; it assumes the Hessian is positive
 * definite along the search direction — confirm for the intended objectives.
 */
public class ConjugateGradient
extends SteepestDescent {
    /** Hessian of the objective, installed by the five-argument {@code solve}. */
    private RntoMatrix E;

    /**
     * Creates a fresh step rule, which holds the previous search direction for one run.
     *
     * @return a new {@link HestenesStiefel} bound to this solver
     */
    @Override
    public SteepestDescent.SteepestDescentImpl getImplementation() {
        return new HestenesStiefel();
    }

    /**
     * Minimizes {@code f2} using an explicitly supplied gradient and Hessian.
     *
     * @param f2            the objective function
     * @param g2            the gradient of {@code f2}
     * @param H             the Hessian of {@code f2}, used for the step-size computation
     * @param tol           convergence tolerance, forwarded to {@link SteepestDescent}
     * @param maxIterations iteration cap, forwarded to {@link SteepestDescent}
     */
    public void solve(RealScalarFunction f2, RealVectorFunction g2, RntoMatrix H, double tol, int maxIterations) {
        this.E = H;
        super.solve(f2, g2, tol, maxIterations);
    }

    /**
     * Minimizes {@code f2}, deriving the Hessian from {@code f2} via {@link HessianFunction}.
     *
     * <p>The caller's gradient {@code g2} is honoured. When it is {@code null}, a gradient is
     * derived from {@code f2} via {@link GradientFunction}. (Previously {@code g2} was always
     * discarded.)
     *
     * @param f2            the objective function
     * @param g2            the gradient of {@code f2}, or {@code null} to derive one from {@code f2}
     * @param tol           convergence tolerance
     * @param maxIterations iteration cap
     */
    @Override
    public void solve(RealScalarFunction f2, RealVectorFunction g2, double tol, int maxIterations) {
        RealVectorFunction gradient = (g2 != null) ? g2 : new GradientFunction(f2);
        this.solve(f2, gradient, new HessianFunction(f2), tol, maxIterations);
    }

    /**
     * Conjugate-direction step rule used by {@link ConjugateGradient}.
     *
     * <p>NOTE(review): despite the class name, {@link #getDirection} uses the Fletcher–Reeves
     * coefficient {@code beta = (g_new . g_new) / (g_old . g_old)}. The Hestenes–Stiefel
     * coefficient would be {@code g_new . (g_new - g_old) / (d_old . (g_new - g_old))}. The two
     * coincide for a quadratic objective with exact steps but differ in general — confirm
     * which is intended.
     */
    public class HestenesStiefel
    extends SteepestDescent.SteepestDescentImpl {
        /** The most recent search direction; {@code null} until the first {@link #getDirection}. */
        public Vector dk = null;

        /** Creates a step rule with no direction history. */
        public HestenesStiefel() {
        }

        /**
         * Computes the step size {@code (g . g) / (d' H(x) d)}.
         *
         * <p>{@code g} is the gradient cached by the most recent {@link #getDirection} call. This
         * presumably means the framework calls {@code getDirection(xk)} before
         * {@code linesearch(xk, ...)} — verify against {@link SteepestDescent}.
         *
         * @param xk the current point
         * @param dk the search direction at {@code xk}
         * @return the step size; {@code 0} when the gradient is already zero
         */
        @Override
        public double linesearch(Vector xk, Vector dk) {
            double gradNormSq = this.gk.innerProduct(this.gk);
            if (gradNormSq == 0.0) {
                // Stationary point: no step, instead of the 0/0 = NaN the formula would yield.
                return 0.0;
            }
            Matrix hessian = ConjugateGradient.this.E.evaluate(xk.toArray());
            DenseMatrix d = new DenseMatrix(dk);
            // d' H d is a 1x1 matrix; its single entry is read at (1, 1).
            double curvature = d.t().multiply(hessian).multiply(d).get(1, 1);
            return gradNormSq / curvature;
        }

        /**
         * Computes the next search direction {@code -g_new + beta * d_old}.
         *
         * <p>{@code beta} is {@code (g_new . g_new) / (g_old . g_old)}. On the first call, or
         * when the previous gradient was zero, the method falls back to steepest descent
         * ({@code -g_new}). It also caches the new gradient and direction for later calls.
         *
         * @param xk the current point
         * @return the new search direction (also stored in {@link #dk})
         */
        @Override
        public Vector getDirection(Vector xk) {
            Vector gNew = ConjugateGradient.this.g.evaluate(xk.toArray());
            Vector direction = gNew.scaled(-1.0);
            if (this.dk != null) {
                double gOldNormSq = this.gk.innerProduct(this.gk);
                // A zero previous gradient would make beta infinite/NaN; restart with -g instead.
                if (gOldNormSq != 0.0) {
                    double beta = gNew.innerProduct(gNew) / gOldNormSq;
                    direction = direction.add(this.dk.scaled(beta));
                }
            }
            this.gk = gNew;
            this.dk = direction;
            return this.dk;
        }
    }
}

