/**
 * Verify that the named workspace exists and is currently active.
 *
 * @param ws       name of the workspace that must be open and active
 * @param errorMsg text used verbatim as the exception message when the check fails
 * @throws ND4JWorkspaceException if the workspace manager reports the workspace as missing or inactive
 */
public static void assertOpenAndActive(@NonNull String ws, @NonNull String errorMsg) throws ND4JWorkspaceException {
    boolean openAndActive = Nd4j.getWorkspaceManager().checkIfWorkspaceExistsAndActive(ws);
    if (openAndActive) {
        return;
    }
    throw new ND4JWorkspaceException(errorMsg);
}
/**
 * Ensure that the workspace configured for the given array type is open.
 * Array types configured to be scoped out of workspaces are exempt from the check.
 *
 * @param arrayType array type whose workspace must be open
 * @throws ND4JWorkspaceException if the configured workspace does not exist or is not active
 */
private void enforceExistsAndActive(@NonNull T arrayType) {
    validateConfig(arrayType);
    if (scopeOutOfWs.contains(arrayType)) {
        // Scoped-out array types do not use a workspace, so there is nothing to enforce.
        return;
    }
    String wsName = workspaceNames.get(arrayType);
    if (!Nd4j.getWorkspaceManager().checkIfWorkspaceExistsAndActive(wsName)) {
        throw new ND4JWorkspaceException("Workspace \"" + wsName + "\" for array type " + arrayType + " is not open");
    }
}
}
/**
 * Collect the IDs of every workspace of the current thread whose scope is active.
 *
 * @return a new mutable list of active workspace IDs, possibly empty
 */
private static List<String> allOpenWorkspaces() {
    List<MemoryWorkspace> threadWorkspaces = Nd4j.getWorkspaceManager().getAllWorkspacesForCurrentThread();
    List<String> activeIds = new ArrayList<>(threadWorkspaces.size());
    for (MemoryWorkspace candidate : threadWorkspaces) {
        if (!candidate.isScopeActive()) {
            continue; // skip workspaces that exist but are not currently open
        }
        activeIds.add(candidate.getId());
    }
    return activeIds;
}
/**
 * Gather the IDs of all current-thread workspaces that have an active scope.
 *
 * @return a freshly allocated list of IDs, empty if no workspace scope is active
 */
private static List<String> allOpenWorkspaces() {
    List<MemoryWorkspace> known = Nd4j.getWorkspaceManager().getAllWorkspacesForCurrentThread();
    List<String> result = new ArrayList<>(known.size());
    for (MemoryWorkspace each : known) {
        boolean active = each.isScopeActive();
        if (active) {
            result.add(each.getId());
        }
    }
    return result;
}
}
/**
 * Report whether the workspace for the given array type is open.
 * Array types configured to be scoped out of workspaces always report {@code true}.
 *
 * @param arrayType array type to check
 * @return {@code true} if the type is scoped out, or if its workspace exists and is active
 */
@Override
public boolean isWorkspaceOpen(@NonNull T arrayType) {
    validateConfig(arrayType);
    if (scopeOutOfWs.contains(arrayType)) {
        return true;
    }
    return Nd4j.getWorkspaceManager().checkIfWorkspaceExistsAndActive(getWorkspaceName(arrayType));
}
/**
 * Assert that no workspace is active for the calling thread.
 *
 * @param msg prefix for the exception message, if one is thrown
 * @throws ND4JWorkspaceException if any workspace is active; the message lists the IDs of all
 *                                workspaces whose scope is active
 */
public static void assertNoWorkspacesOpen(String msg) throws ND4JWorkspaceException {
    if (!Nd4j.getWorkspaceManager().anyWorkspaceActiveForCurrentThread()) {
        return;
    }
    List<MemoryWorkspace> threadWorkspaces = Nd4j.getWorkspaceManager().getAllWorkspacesForCurrentThread();
    List<String> activeIds = new ArrayList<>(threadWorkspaces.size());
    for (MemoryWorkspace candidate : threadWorkspaces) {
        if (candidate.isScopeActive()) {
            activeIds.add(candidate.getId());
        }
    }
    throw new ND4JWorkspaceException(msg + " - Open/active workspaces: " + activeIds);
}
/**
 * Borrow the scope of the workspace configured for the given array type.
 * The workspace must already be open (enforced before borrowing).
 *
 * @param arrayType array type whose workspace scope is borrowed
 * @return the borrowed workspace, or a scope-out-of-workspaces handle for scoped-out types
 */
@Override
public MemoryWorkspace notifyScopeBorrowed(@NonNull T arrayType) {
    validateConfig(arrayType);
    enforceExistsAndActive(arrayType);
    if (!scopeOutOfWs.contains(arrayType)) {
        MemoryWorkspace target = Nd4j.getWorkspaceManager()
                .getWorkspaceForCurrentThread(getConfiguration(arrayType), getWorkspaceName(arrayType));
        return target.notifyScopeBorrowed();
    }
    return Nd4j.getWorkspaceManager().scopeOutOfWorkspaces();
}
/**
 * Enter the scope of the workspace configured for the given array type.
 *
 * @param arrayType array type whose workspace scope is entered
 * @return the entered workspace, or a scope-out-of-workspaces handle for scoped-out types
 */
@Override
public MemoryWorkspace notifyScopeEntered(@NonNull T arrayType) {
    validateConfig(arrayType);
    if (!isScopedOut(arrayType)) {
        return Nd4j.getWorkspaceManager()
                .getWorkspaceForCurrentThread(getConfiguration(arrayType), getWorkspaceName(arrayType))
                .notifyScopeEntered();
    }
    return Nd4j.getWorkspaceManager().scopeOutOfWorkspaces();
}
/**
 * Assert that the specified workspace is open, active, and is the current workspace.
 *
 * @param ws       Name of the workspace to assert open/active/current
 * @param errorMsg Message to include in the exception, if required
 * @throws ND4JWorkspaceException if the workspace is not open and active, or if it is not the
 *                                memory manager's current workspace
 */
public static void assertOpenActiveAndCurrent(@NonNull String ws, @NonNull String errorMsg) throws ND4JWorkspaceException {
    if (!Nd4j.getWorkspaceManager().checkIfWorkspaceExistsAndActive(ws)) {
        throw new ND4JWorkspaceException(errorMsg + " - workspace is not open and active");
    }
    MemoryWorkspace currWs = Nd4j.getMemoryManager().getCurrentWorkspace();
    if (currWs == null || !ws.equals(currWs.getId())) {
        String currentId = (currWs == null ? null : currWs.getId());
        // Close the parenthesis opened in the message; it was previously left unbalanced.
        throw new ND4JWorkspaceException(errorMsg + " - not the current workspace (current workspace: " + currentId + ")");
    }
}
/**
 * Create a new workspace sized by {@code memoryForGraph()}, store it in {@code workspace},
 * and register it as the workspace for the current thread.
 */
private void initWorkspace() {
    WorkspaceConfiguration graphConfig = WorkspaceConfiguration.builder()
            .initialSize(memoryForGraph())
            .policyAllocation(AllocationPolicy.OVERALLOCATE)
            .policyLearning(LearningPolicy.FIRST_LOOP)
            .build();
    workspace = Nd4j.getWorkspaceManager().createNewWorkspace(graphConfig);
    Nd4j.getWorkspaceManager().setWorkspaceForCurrentThread(workspace);
}
/**
 * Detach this INDArray from its current workspace and leverage it to the workspace with the
 * given ID, provided that workspace is open and active.
 *
 * If the target workspace does not exist or is not active, the array is detached from all
 * workspaces instead. Arrays that are not attached to any workspace are returned unchanged.
 *
 * @param id ID of the workspace to leverage this array to
 * @return the array leveraged to the target workspace if it exists and is active, otherwise the detached array
 * @see #leverageTo(String)
 */
public INDArray leverageOrDetach(String id) {
    if (!isAttached()) {
        return this;
    }
    boolean targetAvailable = Nd4j.getWorkspaceManager().checkIfWorkspaceExistsAndActive(id);
    return targetAvailable ? leverageTo(id) : detach();
}
this.id = workspaceId; this.threadId = Thread.currentThread().getId(); this.guid = Nd4j.getWorkspaceManager().getUUID(); this.memoryManager = Nd4j.getMemoryManager(); this.deviceId = Nd4j.getAffinityManager().getDeviceForCurrentThread();
try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace(initialConfig, "SOME_ID")) { try (MemoryWorkspace ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace(learningConfig, "OTHER_ID")) { INDArray array = Nd4j.create(100); try(MemoryWorkspace ws1 = Nd4j.getWorkspaceManager().getAndActivateWorkspace(initialConfig, "SOME_ID")) { INDArray array = Nd4j.create(10, 10).assign(1.0f); INDArray sumRes; try(MemoryWorkspace ws2 = Nd4j.getWorkspaceManager().getAndActivateWorkspace(initialConfig, "THIRD_ID")) { try(MemoryWorkspace ws1 = Nd4j.getWorkspaceManager().getAndActivateWorkspace(initialConfig, "SOME_ID")) { INDArray array1 = Nd4j.create(10, 10).assign(1.0f); INDArray array2; try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) { try (MemoryWorkspace ws1 = Nd4j.getWorkspaceManager().getAndActivateWorkspace(circularConfig, "CIRCULAR_ID")) { INDArray array = Nd4j.create(100);
/**
 * Apply the randomized leaky activation.
 *
 * Training: samples a fresh {@code alpha} tensor of the input's shape from {@code Nd4j.rand}
 * with bounds {@code l}/{@code u}, allocated outside any workspace. It then overwrites every
 * negative entry of {@code in} with {@code in * alpha}. The input is modified in place and returned.
 *
 * Inference: clears {@code alpha} and runs {@code RectifedLinear} with the scalar
 * {@code 0.5 * (l + u)}.
 *
 * NOTE(review): confirm whether RectifedLinear's scalar argument is a leaky slope or a cutoff
 * threshold. If it is a cutoff, inference does not apply the mean slope.
 */
@Override
public INDArray getActivation(INDArray in, boolean training) {
    if (!training) {
        this.alpha = null;
        double meanSlope = 0.5 * (l + u);
        return Nd4j.getExecutioner().execAndReturn(new RectifedLinear(in, meanSlope));
    }
    // alpha is kept as a field (presumably for backprop), so it must not live in a workspace.
    try (MemoryWorkspace detachedScope = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
        this.alpha = Nd4j.rand(in.shape(), l, u, Nd4j.getRandom());
    }
    BooleanIndexing.replaceWhere(in, in.mul(alpha), Conditions.lessThan(0));
    return in;
}
return this; if (!Nd4j.getWorkspaceManager().checkIfWorkspaceExists(id)) { if(enforceExistence){ throw new Nd4jNoSuchWorkspaceException(id); MemoryWorkspace target = Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(id);
/**
 * Return the current thread's workspace registered under {@code id}.
 * If none is registered, create one from {@code configuration` and register it.
 *
 * @param configuration configuration used only when a new workspace must be created
 * @param id            workspace identifier
 * @return the existing or newly created workspace
 */
@Override
public MemoryWorkspace getWorkspaceForCurrentThread(@NonNull WorkspaceConfiguration configuration, @NonNull String id) {
    ensureThreadExistense();
    MemoryWorkspace existing = backingMap.get().get(id);
    if (existing != null) {
        return existing;
    }
    MemoryWorkspace created = newWorkspace(configuration, id);
    backingMap.get().put(id, created);
    // Reference tracking is skipped entirely in BYPASS_EVERYTHING debug mode.
    if (Nd4j.getWorkspaceManager().getDebugMode() != DebugMode.BYPASS_EVERYTHING) {
        pickReference(created);
    }
    return created;
}
/**
 * Create a workspace with the given configuration and ID, and register it for the current
 * thread under {@code id}. Any previous mapping for that ID is replaced.
 *
 * @param configuration workspace configuration
 * @param id            workspace identifier
 * @return the newly created workspace
 */
@Override
public MemoryWorkspace createNewWorkspace(WorkspaceConfiguration configuration, String id) {
    ensureThreadExistense();
    MemoryWorkspace created = newWorkspace(configuration, id);
    backingMap.get().put(id, created);
    boolean trackReference = Nd4j.getWorkspaceManager().getDebugMode() != DebugMode.BYPASS_EVERYTHING;
    if (trackReference) {
        pickReference(created);
    }
    return created;
}
/**
 * Create a workspace bound to the given device, and register it for the current thread under
 * {@code id}. Any previous mapping for that ID is replaced.
 *
 * @param configuration workspace configuration
 * @param id            workspace identifier
 * @param deviceId      device passed through to {@code newWorkspace}
 * @return the newly created workspace
 */
@Override
public MemoryWorkspace createNewWorkspace(WorkspaceConfiguration configuration, String id, Integer deviceId) {
    ensureThreadExistense();
    MemoryWorkspace deviceWorkspace = newWorkspace(configuration, id, deviceId);
    backingMap.get().put(id, deviceWorkspace);
    if (Nd4j.getWorkspaceManager().getDebugMode() == DebugMode.BYPASS_EVERYTHING) {
        return deviceWorkspace;
    }
    pickReference(deviceWorkspace);
    return deviceWorkspace;
}
/**
 * Create a workspace from the given configuration and register it for the current thread,
 * keyed by the ID the new workspace reports.
 *
 * @param configuration workspace configuration
 * @return the newly created workspace
 */
@Override
public MemoryWorkspace createNewWorkspace(@NonNull WorkspaceConfiguration configuration) {
    ensureThreadExistense();
    MemoryWorkspace fresh = newWorkspace(configuration);
    backingMap.get().put(fresh.getId(), fresh);
    boolean bypass = Nd4j.getWorkspaceManager().getDebugMode() == DebugMode.BYPASS_EVERYTHING;
    if (!bypass) {
        pickReference(fresh);
    }
    return fresh;
}
/**
 * Create a workspace from this manager's {@code defaultConfiguration} and register it for the
 * current thread, keyed by the ID the new workspace reports.
 *
 * @return the newly created workspace
 */
@Override
public MemoryWorkspace createNewWorkspace() {
    ensureThreadExistense();
    MemoryWorkspace fresh = newWorkspace(defaultConfiguration);
    backingMap.get().put(fresh.getId(), fresh);
    if (Nd4j.getWorkspaceManager().getDebugMode() == DebugMode.BYPASS_EVERYTHING) {
        return fresh;
    }
    pickReference(fresh);
    return fresh;
}