/**
 * Called by RLGlue when the episode terminates. Records the final reward, marks the
 * current state as terminal, and wakes any thread blocked on the next-state reference.
 */
@Override
public void agent_end(double v) {
    DPrint.cl(this.debugCode, "Got agent end message");
    synchronized (nextStateReference) {
        this.lastReward = v;
        this.curStateIsTerminal = true;
        nextStateReference.val = curState;
        nextStateReference.notifyAll();
    }
}
/**
 * Blocks the calling thread until a state is provided by the RLGlue server or the
 * RLGlue experiment has ended.
 */
public void blockUntilStateReceived(){
    synchronized(nextStateReference){
        while(this.nextStateReference.val == null && !this.rlGlueExperimentFinished){
            try{
                DPrint.cl(debugCode, "Waiting for state from RLGlue Server...");
                nextStateReference.wait();
            } catch(InterruptedException ex){
                ex.printStackTrace();
            }
        }
    }
}
/**
 * Sets the height and number of transition dynamics samples in a way that ensures
 * epsilon optimality.
 * @param rmax the maximum reward value of the MDP
 * @param epsilon the epsilon optimality (amount that the estimated value function may diverge from the true optimal)
 * @param numActions the maximum number of actions that could be applied from a state
 */
public void setHAndCByMDPError(double rmax, double epsilon, int numActions){
    double lambda = epsilon * (1. - this.gamma) * (1. - this.gamma) / 4.;
    double vmax = rmax / (1. - this.gamma);
    this.h = (int)logbase(this.gamma, lambda / vmax) + 1;
    //C = (Vmax^2 / lambda^2) * (2H ln(k H Vmax^2 / lambda^2) + ln(Rmax / lambda));
    //note the ln(Rmax/lambda) term is added to the product, not nested inside the first log
    this.c = (int)( (vmax*vmax / (lambda*lambda)) * (2 * this.h * Math.log(numActions*this.h * vmax * vmax / (lambda * lambda)) + Math.log(rmax/lambda)) );
    DPrint.cl(this.debugCode, "H = " + this.h);
    DPrint.cl(this.debugCode, "C = " + this.c);
}
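/*
 * A worked instance of the bound above (illustrative numbers, not from the source):
 * with gamma = 0.9, rmax = 1, epsilon = 0.1, and numActions = 4,
 *
 *   lambda = 0.1 * (1 - 0.9)^2 / 4 = 2.5e-4
 *   vmax   = 1 / (1 - 0.9)        = 10
 *   h      = (int) log_{0.9}(2.5e-5) + 1 = 101
 *
 * and c is far larger still, since the leading factor vmax^2 / lambda^2 alone is 1.6e9.
 * The theoretical bound is very loose, so in practice h and c are usually set directly
 * to much smaller values; this method mainly documents the guarantee.
 */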
/**
 * Runs value iteration. Note that if the state samples have not been set, it will
 * throw a runtime exception.
 */
public void runVI(){
    for(int i = 0; i < this.maxIterations || this.maxIterations == -1; i++){
        double change = this.runIteration();
        DPrint.cl(this.debugCode, "Finished iteration " + i + "; max change: " + change);
        if(change < this.maxDelta){
            break;
        }
    }
}
/**
 * Plans from the input state and then returns a {@link burlap.behavior.policy.GreedyQPolicy} that greedily
 * selects the action with the highest Q-value and breaks ties uniformly randomly.
 * @param initialState the initial state of the planning problem
 * @return a {@link burlap.behavior.policy.GreedyQPolicy}.
 */
@Override
public GreedyQPolicy planFromState(State initialState) {
    DPrint.cl(this.debugCode, "Beginning Planning.");
    int nr = 0;
    while(this.runRollout(initialState) > this.maxDiff && (nr < this.maxRollouts || this.maxRollouts == -1)){
        nr++;
    }
    DPrint.cl(this.debugCode, "Finished planning with a total of " + this.numBellmanUpdates + " backups.");
    return new GreedyQPolicy(this);
}
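/*
 * A minimal usage sketch (assumptions: `planner` is a configured instance of this
 * class and `initialState` is a state of its domain; neither name appears in the
 * snippet above, and the action query follows BURLAP's Policy interface):
 */
GreedyQPolicy policy = planner.planFromState(initialState);
Action a = policy.action(initialState); //greedy action; ties broken uniformly at random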
/**
 * Called by RLGlue at the start of an episode. Publishes the initial state to the
 * waiting learning thread, then blocks until that thread supplies the first action.
 */
@Override
public Action agent_start(Observation observation) {
    DPrint.cl(debugCode, "Got agent start message, launching agent.");
    synchronized (nextStateReference) {
        this.curStateIsTerminal = false;
        this.lastReward = 0.;
        final State s = RLGlueDomain.stateFromObservation(observation);
        this.curState = s;
        this.nextStateReference.val = s;
        nextStateReference.notifyAll();
    }

    Action toRet;
    synchronized (nextAction) {
        while(nextAction.val == null){
            try{
                DPrint.cl(debugCode, "Waiting for action...");
                nextAction.wait();
            } catch(InterruptedException ex){
                ex.printStackTrace();
            }
        }
        toRet = getRLGlueAction(nextAction.val);
        nextAction.val = null;
    }

    DPrint.cl(debugCode, "Returning first action.");
    return toRet;
}
DPrint.cl(this.debugCode, "Finished writing step csv file to: " + filePath);
/**
 * Called by RLGlue on each step. Publishes the new state and reward to the learning
 * thread, then blocks until that thread supplies the next action.
 */
@Override
public Action agent_step(double v, Observation observation) {
    DPrint.cl(this.debugCode, "Got agent step message");
    synchronized (nextStateReference) {
        nextStateReference.val = RLGlueDomain.stateFromObservation(observation);
        this.lastReward = v;
        this.curState = nextStateReference.val;
        nextStateReference.notifyAll();
    }

    Action toRet;
    synchronized (nextAction) {
        while(nextAction.val == null){
            try{
                DPrint.cl(debugCode, "Waiting for action...");
                nextAction.wait();
            } catch(InterruptedException ex){
                ex.printStackTrace();
            }
        }
        toRet = getRLGlueAction(nextAction.val);
        nextAction.val = null;
    }

    return toRet;
}
DPrint.cl(this.debugCode, "Passes: " + i);
/**
 * Called by RLGlue to initialize the agent. Parses the task specification, generates
 * the corresponding domain, and signals any thread waiting on the domain to be set.
 */
@Override
public void agent_init(String arg0) {
    DPrint.cl(debugCode, "Started init");
    DPrint.cl(debugCode, arg0);
    TaskSpec theTaskSpec = new TaskSpec(arg0);
    RLGlueDomain domainGenerator = new RLGlueDomain(theTaskSpec);
    this.discount = theTaskSpec.getDiscountFactor();
    this.domain = domainGenerator.generateDomain();

    synchronized(this.domainSet){
        this.domainSet.val = 1;
        this.domainSet.notifyAll();
    }
}
DPrint.cl(this.debugCode, "Iterations in inner VI for policy eval: " + i); this.totalValueIterations += i;
/**
 * Performs multiple-intentions inverse reinforcement learning via EM: each iteration
 * computes per-cluster trajectory weights (E-step) and then runs MLIRL for each of
 * the k clusters using those weights (M-step).
 */
public void performIRL(){
    int k = this.clusterPriors.length;
    for(int i = 0; i < this.numEMIterations; i++){
        DPrint.cl(this.debugCode, "Starting EM iteration " + (i+1) + "/" + this.numEMIterations);
        double [][] trajectoryPerClusterWeights = this.computePerClusterMLIRLWeights();
        for(int j = 0; j < k; j++){
            MLIRLRequest clusterRequest = this.clusterRequests.get(j);
            clusterRequest.setEpisodeWeights(trajectoryPerClusterWeights[j].clone());
            this.mlirlInstance.setRequest(clusterRequest);
            this.mlirlInstance.performIRL();
        }
    }
    DPrint.cl(this.debugCode, "Finished EM");
}
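/*
 * Usage sketch (assumptions: `miIRL` is an instance of this class, already constructed
 * with a request object holding the expert trajectories, the number of clusters, and
 * the EM iteration budget; construction details are not shown above):
 */
miIRL.performIRL(); //runs numEMIterations rounds of E (cluster weights) and M (per-cluster MLIRL)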
DPrint.cl(debugCode, "Pass: " + i + "; Num states: " + orderedStates.size() + " (total: " + totalStates + ")");
/**
 * Runs VI until the specified termination conditions are met. In general, this method
 * should only be called indirectly through the {@link #planFromState(State)} method.
 * The {@link #performReachabilityFrom(State)} method must have been called at least once
 * in the past or a runtime exception will be thrown. The {@link #planFromState(State)}
 * method will automatically call {@link #performReachabilityFrom(State)} first and then
 * this method if it hasn't been run.
 */
public void runVI(){

    if(!this.foundReachableStates){
        throw new RuntimeException("Cannot run VI until the reachable states have been found. Use the planFromState, performReachabilityFrom, addStateToStateSpace or addStatesToStateSpace methods at least once before calling runVI.");
    }

    Set<HashableState> states = valueFunction.keySet();

    int i;
    for(i = 0; i < this.maxIterations; i++){

        double delta = 0.;
        for(HashableState sh : states){
            double v = this.value(sh);
            double newV = this.performBellmanUpdateOn(sh);
            this.performDPValueGradientUpdateOn(sh);
            delta = Math.max(Math.abs(newV - v), delta);
        }

        if(delta < this.maxDelta){
            break; //approximated well enough; stop iterating
        }

    }

    DPrint.cl(this.debugCode, "Passes: " + i);

    this.hasRunVI = true;
}
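/*
 * Minimal usage sketch (assumptions: `vi` is an instance of this planner, constructed
 * with a domain, discount factor, hashing factory, maxDelta, and maxIterations, and
 * `initialState` belongs to that domain). planFromState wraps both steps, but they
 * can also be invoked directly:
 */
vi.performReachabilityFrom(initialState); //enumerate the reachable state space first
vi.runVI();                               //then sweep Bellman backups until delta < maxDelta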
/**
 * Plans from the input state and returns a {@link burlap.behavior.policy.BoltzmannQPolicy} following the
 * Boltzmann parameter used for Boltzmann value backups in this planner.
 * @param initialState the initial state of the planning problem
 * @return a {@link burlap.behavior.policy.BoltzmannQPolicy}
 */
@Override
public BoltzmannQPolicy planFromState(State initialState) {

    if(this.forgetPreviousPlanResults){
        this.rootLevelQValues.clear();
    }

    HashableState sh = this.hashingFactory.hashState(initialState);
    if(this.rootLevelQValues.containsKey(sh)){
        return new BoltzmannQPolicy(this, 1./this.boltzBeta); //already planned for this state
    }

    DPrint.cl(this.debugCode, "Beginning Planning.");
    int oldUpdates = this.numUpdates;

    DiffStateNode sn = this.getStateNode(initialState, this.h);
    rootLevelQValues.put(sh, sn.estimateQs());

    DPrint.cl(this.debugCode, "Finished Planning with " + (this.numUpdates - oldUpdates) + " value estimates; for a cumulative total of: " + this.numUpdates);

    if(this.forgetPreviousPlanResults){
        this.nodesByHeight.clear();
    }

    return new BoltzmannQPolicy(this, 1./this.boltzBeta);
}
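/*
 * Usage sketch (assumptions: `planner` is an instance of this class and `initialState`
 * belongs to its domain). The returned policy selects actions with probability
 * proportional to exp(Q / temperature), where temperature = 1/boltzBeta:
 */
BoltzmannQPolicy policy = planner.planFromState(initialState);
Action a = policy.action(initialState); //stochastic draw from the Boltzmann distribution over Q-values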
/**
 * Runs LSPI for either numIterations or until the change in the weight matrix is no greater than maxChange.
 * @param numIterations the maximum number of policy iterations.
 * @param maxChange when the weight change is smaller than this value, LSPI terminates.
 * @return a {@link burlap.behavior.policy.GreedyQPolicy} using this object as the {@link QProvider} source.
 */
public GreedyQPolicy runPolicyIteration(int numIterations, double maxChange){

    boolean converged = false;
    for(int i = 0; i < numIterations && !converged; i++){

        SimpleMatrix nw = this.LSTDQ();
        double change = Double.POSITIVE_INFINITY;
        if(this.lastWeights != null){
            change = this.lastWeights.minus(nw).normF();
            if(change <= maxChange){
                converged = true;
            }
        }
        this.lastWeights = nw;

        DPrint.cl(this.debugCode, "Finished iteration: " + i + ". Weight change: " + change);

    }

    DPrint.cl(this.debugCode, "Finished Policy Iteration.");
    return new GreedyQPolicy(this);
}
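/*
 * Usage sketch (assumptions: `lspi` is an instance of this class with its SARS dataset
 * and state-action features already provided; the argument values are illustrative):
 */
GreedyQPolicy policy = lspi.runPolicyIteration(30, 1e-6); //at most 30 LSTDQ sweeps; stop once ||w_new - w_old||_F <= 1e-6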
/**
 * Plans from the input state and then returns a {@link burlap.behavior.policy.GreedyQPolicy} that greedily
 * selects the action with the highest Q-value and breaks ties uniformly randomly.
 * @param initialState the initial state of the planning problem
 * @return a {@link burlap.behavior.policy.GreedyQPolicy}.
 */
@Override
public GreedyQPolicy planFromState(State initialState) {

    int iterations = 0;
    if(this.performReachabilityFrom(initialState) || !this.hasRunPlanning){

        double delta;
        do{
            delta = this.evaluatePolicy();
            iterations++;
            this.evaluativePolicy = new GreedyQPolicy(this.getCopyOfValueFunction());
        }while(delta > this.maxPIDelta && iterations < maxPolicyIterations);

        this.hasRunPlanning = true;

    }

    DPrint.cl(this.debugCode, "Total policy iterations: " + iterations);
    this.totalPolicyIterations += iterations;

    return (GreedyQPolicy)this.evaluativePolicy;
}
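/*
 * Usage sketch (assumptions: `pi` is an instance of this planner and `initialState`
 * belongs to its domain):
 */
GreedyQPolicy policy = pi.planFromState(initialState); //alternates evaluation and greedy improvement until delta <= maxPIDelta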
DPrint.cl(debugCode, "Num visted: " + numVisted);