/**
 * Returns the node definitions contained in the given graph.
 *
 * @param graphDef the graph whose nodes are requested
 * @return the list reported by {@code graphDef.getNodeList()}, returned as-is (no copy is made)
 */
@Override
public List<NodeDef> getNodeList(GraphDef graphDef) {
    final List<NodeDef> nodes = graphDef.getNodeList();
    return nodes;
}
/**
 * Parses a binary {@code GraphDef} from {@code inputFile} and appends the {@code toString()} form of
 * every node to {@code outputFile}.
 *
 * <p>The input stream belongs to the caller and is not closed here. The output file is opened in
 * append mode, so any existing content is preserved. The graph is parsed before the output file is
 * opened, so a parse failure leaves the output untouched. I/O failures are printed via
 * {@link Throwable#printStackTrace()} and otherwise swallowed, because this method declares no
 * checked exceptions.
 *
 * @param inputFile  stream containing a serialized {@code GraphDef}
 * @param outputFile file to which the node dumps are appended
 */
@Override
public void dumpBinaryProtoAsText(InputStream inputFile, File outputFile) {
    try {
        GraphDef graphDef = GraphDef.parseFrom(inputFile);
        // try-with-resources closes (and therefore flushes) the writer even if a write fails;
        // previously the writer leaked on any exception thrown mid-loop.
        try (BufferedWriter bufferedWriter = new BufferedWriter(new FileWriter(outputFile, true))) {
            for (NodeDef node : graphDef.getNodeList()) {
                bufferedWriter.write(node.toString());
            }
        }
    } catch (IOException e) {
        e.printStackTrace();
    }
}
/**
 * {@inheritDoc}
 *
 * <p>Reads a binary {@code GraphDef} from {@code inputFile} and appends the {@code toString()} form
 * of every node to {@code outputFile`}. Both files are always closed, including when parsing or
 * writing fails. The input is fully parsed and closed before the output file is opened. I/O failures
 * are printed via {@link Throwable#printStackTrace()} and otherwise swallowed.
 *
 * @param inputFile  file containing a serialized {@code GraphDef}
 * @param outputFile file to which the node dumps are appended (opened in append mode)
 */
@Override
public void dumpBinaryProtoAsText(File inputFile, File outputFile) {
    try {
        GraphDef graphDef;
        // The original never closed this FileInputStream; close it as soon as parsing is done.
        try (BufferedInputStream in = new BufferedInputStream(new FileInputStream(inputFile))) {
            graphDef = GraphDef.parseFrom(in);
        }
        // Closing the writer also flushes it, and happens even if a write throws.
        try (BufferedWriter bufferedWriter = new BufferedWriter(new FileWriter(outputFile, true))) {
            for (NodeDef node : graphDef.getNodeList()) {
                bufferedWriter.write(node.toString());
            }
        }
    } catch (IOException e) {
        e.printStackTrace();
    }
}
/**
 * Indexes the graph's nodes by their translated SameDiff name.
 *
 * <p>Nodes whose TensorFlow name ends with {@code "/read"} are left out. These are presumably the
 * read ops TensorFlow attaches to variables; TODO confirm.
 *
 * @param graphDef the graph to scan
 * @return an insertion-ordered map from {@code translateToSameDiffName(name, node)} to the node. If
 *     two nodes translate to the same name, the later node replaces the earlier one's value.
 */
@Override
public Map<String, NodeDef> variablesForGraph(GraphDef graphDef) {
    Map<String, NodeDef> byName = new LinkedHashMap<>();
    for (NodeDef def : graphDef.getNodeList()) {
        String tfName = def.getName();
        if (!tfName.endsWith("/read")) {
            byName.put(translateToSameDiffName(tfName, def), def);
        }
    }
    return byName;
}
/**
 * Computes a hash over this message's fields and caches it in {@code memoizedHashCode}.
 *
 * <p>NOTE(review): this looks like protoc-generated code for the {@code GraphDef} message.
 * Regenerate it from the .proto instead of hand-editing; TODO confirm.
 *
 * <p>For each field, the field-number constant is mixed in first and then the field's value. The
 * node list is included only when {@code getNodeCount() > 0}. Versions and library are included
 * only when present. Version is always included. Unknown fields are always mixed in last.
 *
 * @return the (memoized) hash code
 */
@java.lang.Override
public int hashCode() {
  // 0 doubles as the "not yet computed" marker, so a hash that happens to equal 0 is recomputed on every call.
  if (memoizedHashCode != 0) {
    return memoizedHashCode;
  }
  int hash = 41;
  hash = (19 * hash) + getDescriptor().hashCode();
  // Node list: only contributes when non-empty.
  if (getNodeCount() > 0) {
    hash = (37 * hash) + NODE_FIELD_NUMBER;
    hash = (53 * hash) + getNodeList().hashCode();
  }
  // Versions: only contributes when set.
  if (hasVersions()) {
    hash = (37 * hash) + VERSIONS_FIELD_NUMBER;
    hash = (53 * hash) + getVersions().hashCode();
  }
  // Version: always contributes, even when it holds its default value.
  hash = (37 * hash) + VERSION_FIELD_NUMBER;
  hash = (53 * hash) + getVersion();
  // Library: only contributes when set.
  if (hasLibrary()) {
    hash = (37 * hash) + LIBRARY_FIELD_NUMBER;
    hash = (53 * hash) + getLibrary().hashCode();
  }
  hash = (29 * hash) + unknownFields.hashCode();
  memoizedHashCode = hash;
  return hash;
}
int currNodeIndex = graph.getNodeList().indexOf(from); val trueDefName = from.getInput(1); val falseDefName = from.getInput(0);
/**
 * Field-by-field equality against another {@code org.tensorflow.framework.GraphDef}.
 *
 * <p>NOTE(review): this looks like protoc-generated code. Regenerate it from the .proto instead of
 * hand-editing; TODO confirm.
 *
 * <p>Two messages are equal when all of the following match:
 * <ul>
 *   <li>the node lists;</li>
 *   <li>whether versions is set, and its value when set;</li>
 *   <li>the version value;</li>
 *   <li>whether library is set, and its value when set;</li>
 *   <li>the unknown fields.</li>
 * </ul>
 *
 * @param obj the object to compare with
 * @return {@code true} if {@code obj} is this instance or a {@code GraphDef} with equal fields.
 *     For non-{@code GraphDef} arguments, returns the result of {@code super.equals(obj)}.
 */
@java.lang.Override
public boolean equals(final java.lang.Object obj) {
  // Identity shortcut.
  if (obj == this) {
    return true;
  }
  // Not a GraphDef: defer to the superclass implementation.
  if (!(obj instanceof org.tensorflow.framework.GraphDef)) {
    return super.equals(obj);
  }
  org.tensorflow.framework.GraphDef other = (org.tensorflow.framework.GraphDef) obj;

  boolean result = true;
  result = result && getNodeList()
      .equals(other.getNodeList());
  // Presence must match before values are compared.
  result = result && (hasVersions() == other.hasVersions());
  if (hasVersions()) {
    result = result && getVersions()
        .equals(other.getVersions());
  }
  result = result && (getVersion() == other.getVersion());
  // Presence must match before values are compared.
  result = result && (hasLibrary() == other.hasLibrary());
  if (hasLibrary()) {
    result = result && getLibrary()
        .equals(other.getLibrary());
  }
  result = result && unknownFields.equals(other.unknownFields);
  return result;
}
val nodes = graph.getNodeList();
for(val node : graph.getNodeList()) { if(node.getName().equals(nodeDef.getInput(0))) { startNode = node;
@Procedure(value = "load.tensorflow", mode = Mode.WRITE) public Stream<LoadResult> loadTensorFlow(@Name("file") String url) throws IOException { GraphDef graphDef = GraphDef.parseFrom(new BufferedInputStream(new URL(url).openStream())); Map<String, Node> nodes = new HashMap<>(); // tod model node, layer nodes for (NodeDef nodeDef : graphDef.getNodeList()) { Node node = db.createNode(Types.Neuron); node.setProperty("name", nodeDef.getName()); if (nodeDef.getDevice() != null) node.setProperty("device", nodeDef.getDevice()); node.setProperty("op", nodeDef.getOp()); nodeDef.getAttrMap().forEach((k, v) -> { Object value = getValue(v); if (value != null) { node.setProperty(k, value); } }); nodes.put(nodeDef.getName(), node); } long rels = 0; for (NodeDef nodeDef : graphDef.getNodeList()) { Node target = nodes.get(nodeDef.getName()); nodeDef.getInputList().forEach(name -> nodes.get(name).createRelationshipTo(target, RelTypes.INPUT)); // todo weights rels += nodeDef.getInputCount(); } return Stream.of(new LoadResult(url,"tensorflow",nodes.size(), rels)); }
/**
 * Computes a hash over this message's fields and caches it in {@code memoizedHashCode}.
 *
 * <p>NOTE(review): this looks like protoc-generated code for the {@code GraphDef} message.
 * Regenerate it from the .proto instead of hand-editing; TODO confirm.
 *
 * <p>For each field, the field-number constant is mixed in first and then the field's value. The
 * node list is included only when {@code getNodeCount() > 0}. Versions and library are included
 * only when present. Version is always included. Unknown fields are always mixed in last.
 *
 * @return the (memoized) hash code
 */
@java.lang.Override
public int hashCode() {
  // 0 doubles as the "not yet computed" marker, so a hash that happens to equal 0 is recomputed on every call.
  if (memoizedHashCode != 0) {
    return memoizedHashCode;
  }
  int hash = 41;
  hash = (19 * hash) + getDescriptor().hashCode();
  // Node list: only contributes when non-empty.
  if (getNodeCount() > 0) {
    hash = (37 * hash) + NODE_FIELD_NUMBER;
    hash = (53 * hash) + getNodeList().hashCode();
  }
  // Versions: only contributes when set.
  if (hasVersions()) {
    hash = (37 * hash) + VERSIONS_FIELD_NUMBER;
    hash = (53 * hash) + getVersions().hashCode();
  }
  // Version: always contributes, even when it holds its default value.
  hash = (37 * hash) + VERSION_FIELD_NUMBER;
  hash = (53 * hash) + getVersion();
  // Library: only contributes when set.
  if (hasLibrary()) {
    hash = (37 * hash) + LIBRARY_FIELD_NUMBER;
    hash = (53 * hash) + getLibrary().hashCode();
  }
  hash = (29 * hash) + unknownFields.hashCode();
  memoizedHashCode = hash;
  return hash;
}
/**
 * Field-by-field equality against another {@code org.tensorflow.framework.GraphDef}.
 *
 * <p>NOTE(review): this looks like protoc-generated code. Regenerate it from the .proto instead of
 * hand-editing; TODO confirm.
 *
 * <p>Two messages are equal when all of the following match:
 * <ul>
 *   <li>the node lists;</li>
 *   <li>whether versions is set, and its value when set;</li>
 *   <li>the version value;</li>
 *   <li>whether library is set, and its value when set;</li>
 *   <li>the unknown fields.</li>
 * </ul>
 *
 * @param obj the object to compare with
 * @return {@code true} if {@code obj} is this instance or a {@code GraphDef} with equal fields.
 *     For non-{@code GraphDef} arguments, returns the result of {@code super.equals(obj)}.
 */
@java.lang.Override
public boolean equals(final java.lang.Object obj) {
  // Identity shortcut.
  if (obj == this) {
    return true;
  }
  // Not a GraphDef: defer to the superclass implementation.
  if (!(obj instanceof org.tensorflow.framework.GraphDef)) {
    return super.equals(obj);
  }
  org.tensorflow.framework.GraphDef other = (org.tensorflow.framework.GraphDef) obj;

  boolean result = true;
  result = result && getNodeList()
      .equals(other.getNodeList());
  // Presence must match before values are compared.
  result = result && (hasVersions() == other.hasVersions());
  if (hasVersions()) {
    result = result && getVersions()
        .equals(other.getVersions());
  }
  result = result && (getVersion() == other.getVersion());
  // Presence must match before values are compared.
  result = result && (hasLibrary() == other.hasLibrary());
  if (hasLibrary()) {
    result = result && getLibrary()
        .equals(other.getLibrary());
  }
  result = result && unknownFields.equals(other.unknownFields);
  return result;
}