@Override public List<StateTransitionProb> stateTransitions(State s, Action a) { double [] directionProbs = transitionDynamics[actionInd(a.actionName())]; List <StateTransitionProb> transitions = new ArrayList<StateTransitionProb>(); for(int i = 0; i < directionProbs.length; i++){ double p = directionProbs[i]; if(p == 0.){ continue; //cannot transition in this direction } State ns = s.copy(); int [] dcomps = movementDirectionFromIndex(i); ns = move(ns, dcomps[0], dcomps[1]); //make sure this direction doesn't actually stay in the same place and replicate another no-op boolean isNew = true; for(StateTransitionProb tp : transitions){ if(tp.s.equals(ns)){ isNew = false; tp.p += p; break; } } if(isNew){ StateTransitionProb tp = new StateTransitionProb(ns, p); transitions.add(tp); } } return transitions; }
@Override
public OOSADomain generateDomain() {
	OOSADomain domain = new OOSADomain();

	int[][] map = this.getMap();

	//register the OO-MDP object classes used by grid world states
	domain.addStateClass(CLASS_AGENT, GridAgent.class)
			.addStateClass(CLASS_LOCATION, GridLocation.class);

	GridWorldModel stateModel = new GridWorldModel(map, getTransitionDynamics());

	//fall back to default reward/terminal functions when none were configured
	RewardFunction rewardFn = this.rf != null ? this.rf : new UniformCostRF();
	TerminalFunction terminalFn = this.tf != null ? this.tf : new NullTermination();

	domain.setModel(new FactoredModel(stateModel, rewardFn, terminalFn));

	domain.addActionTypes(
			new UniversalActionType(ACTION_NORTH),
			new UniversalActionType(ACTION_SOUTH),
			new UniversalActionType(ACTION_EAST),
			new UniversalActionType(ACTION_WEST));

	OODomain.Helper.addPfsToDomain(domain, this.generatePfs());

	return domain;
}
@Override
public State sample(State s, Action a) {
	State next = s.copy();

	//probability of each outcome direction for the selected action
	double[] probs = transitionDynamics[actionInd(a.actionName())];

	//sample an outcome direction from the categorical distribution
	//via a cumulative-probability roll; defaults to direction 0 if the
	//roll exceeds the running sum
	double roll = rand.nextDouble();
	double cumulative = 0.;
	int sampledDir = 0;
	for(int i = 0; i < probs.length; i++){
		cumulative += probs[i];
		if(roll < cumulative){
			sampledDir = i;
			break;
		}
	}

	int[] delta = movementDirectionFromIndex(sampledDir);
	return move(next, delta[0], delta[1]);
}