/**
 * Runs {@code considerCandidate} on a freshly configured mock whose region radius starts at 2,
 * then checks the returned convergence flag, the resulting mode, and the updated region radius.
 *
 * @param fx_candiate cost at the candidate point
 * @param fx cost at the current point
 * @param predictedReduction reduction in cost predicted for the step
 * @param stepLength length of the candidate step
 * @param converged expected return value of {@code considerCandidate}
 * @param accepted if false the radius is expected to be halved; if true the expected radius
 *                 depends on the actual/predicted reduction ratio
 */
protected void considerUpdate(double fx_candiate, double fx, double predictedReduction, double stepLength, boolean converged , boolean accepted) {
    final double initialRadius = 2;

    ConfigTrustRegion config = new ConfigTrustRegion();
    MockTrustRegionBase_F64 alg = new MockTrustRegionBase_F64(null);
    alg.configure(config);
    alg.regionRadius = initialRadius;

    assertEquals(converged, alg.considerCandidate(fx_candiate, fx, predictedReduction, stepLength));

    // a lower candidate cost moves on to derivative computation, otherwise another step is determined
    GaussNewtonBase_F64.Mode expectedMode = fx_candiate < fx
            ? GaussNewtonBase_F64.Mode.COMPUTE_DERIVATIVES
            : GaussNewtonBase_F64.Mode.DETERMINE_STEP;
    assertEquals(expectedMode, alg.mode());

    double actualReduction = fx - fx_candiate;
    double ratio = actualReduction / predictedReduction;

    // rejected steps and poorly predicted accepted steps both halve the radius
    double expectedRadius;
    if (!accepted || ratio <= 0.5) {
        expectedRadius = initialRadius * 0.5;
    } else {
        expectedRadius = Math.max(3 * stepLength, initialRadius);
    }
    assertEquals(expectedRadius, alg.regionRadius);
}
/**
 * Verifies that checkConvergenceFTest() throws an OptimizationException when the region
 * radius is invalid, checking both a zero radius and a NaN radius.
 */
@Test
public void checkConvergenceFTest_radius() {
    MockTrustRegionBase_F64 alg = createFixedCost(1,0.5);

    for (double badRadius : new double[]{0, Double.NaN}) {
        alg.regionRadius = badRadius;
        try {
            alg.checkConvergenceFTest(-1,-1);
            fail("Should have thrown an exception");
        } catch (OptimizationException ignored) {
            // expected: the invalid radius must be rejected
        }
    }
}
/**
 * Builds a mock optimizer with two fixed behaviors:
 * <ul>
 *   <li>{@code cost(x)} always returns {@code cost};</li>
 *   <li>the parameter update always reports {@code predictedReduction} as its predicted reduction.</li>
 * </ul>
 *
 * @param cost value returned by every cost evaluation
 * @param predictedReduction value returned by the update's {@code getPredictedReduction()}
 * @return the configured mock
 */
private MockTrustRegionBase_F64 createFixedCost( double cost , double predictedReduction ) {
    return new MockTrustRegionBase_F64(new MockParameterUpdate() {
        @Override
        public double getPredictedReduction() {
            return predictedReduction;
        }
    }) {
        @Override
        protected double cost(DMatrixRMaj x) {
            return cost;
        }
    };
}
/**
 * Checks two things after {@code initialize()}:
 * <ul>
 *   <li>the stored cost {@code fx} equals the mock's fixed cost of 1;</li>
 *   <li>the region radius equals {@code config.regionInitial}.</li>
 * </ul>
 */
@Test
public void initialize() {
    MockTrustRegionBase_F64 alg = createFixedCost(1,0.5);
    double[] x = new double[]{1,2};
    alg.initialize(x,2,0);

    assertEquals(1,alg.fx, UtilEjml.TEST_F64);
    // expected value first, so a failure reports the configured initial radius as "expected"
    assertEquals(alg.config.regionInitial, alg.regionRadius, UtilEjml.TEST_F64);
}
/**
 * Checks checkConvergenceFTest() with ftol = 1e-4 and a region radius of 1:
 * <ul>
 *   <li>identical costs are reported as converged;</li>
 *   <li>costs differing by a factor of (1+1e-5) are reported as converged;</li>
 *   <li>costs differing by a factor of (1+9e-3) are not.</li>
 * </ul>
 */
@Test
public void checkConvergenceFTest() {
    MockTrustRegionBase_F64 alg = createFixedCost(1,0.5);
    alg.regionRadius = 1;
    alg.config.ftol = 1e-4;

    final double fx = 2;
    assertTrue(alg.checkConvergenceFTest(fx, fx));
    assertTrue(alg.checkConvergenceFTest(fx, fx*(1+1e-5)));
    assertFalse(alg.checkConvergenceFTest(fx, fx*(1+9e-3)));
}