/**
 * Decodes a binary-serialized {@code GraphDef} from the supplied stream.
 *
 * @param inputStream source of the serialized protobuf bytes
 * @return the decoded graph definition
 * @throws IOException if reading from the stream or decoding the protobuf fails
 */
@Override
public GraphDef parseGraphFrom(InputStream inputStream) throws IOException {
    final GraphDef decoded = GraphDef.parseFrom(inputStream);
    return decoded;
}
/**
 * Decodes a binary-serialized {@code GraphDef} held entirely in memory.
 *
 * @param inputStream the serialized protobuf bytes (despite the name, a byte array)
 * @return the decoded graph definition
 * @throws IOException if the bytes are not a valid {@code GraphDef}
 */
@Override
public GraphDef parseGraphFrom(byte[] inputStream) throws IOException {
    final GraphDef decoded = GraphDef.parseFrom(inputStream);
    return decoded;
}
/**
 * {@inheritDoc}
 *
 * <p>Reads a binary {@code GraphDef} from {@code inputFile} and appends the text form of each of
 * its nodes to {@code outputFile}. Any {@link IOException} is printed to stderr and not rethrown.
 */
@Override
public void dumpBinaryProtoAsText(File inputFile, File outputFile) {
    try {
        GraphDef graphDef;
        // Close the input file as soon as parsing finishes (or fails).
        try (BufferedInputStream in = new BufferedInputStream(new FileInputStream(inputFile))) {
            graphDef = GraphDef.parseFrom(in);
        }
        // The output file is only opened (and created) once parsing succeeded; closing the
        // writer flushes it, and try-with-resources closes it even if a write fails.
        try (BufferedWriter writer = new BufferedWriter(new FileWriter(outputFile, true))) {
            for (NodeDef node : graphDef.getNodeList()) {
                writer.write(node.toString());
            }
        }
    } catch (IOException e) {
        e.printStackTrace();
    }
}
/**
 * Reads a binary {@code GraphDef} from {@code inputFile} and appends the text form of each of its
 * nodes to {@code outputFile}. Any {@link IOException} is printed to stderr and not rethrown.
 *
 * @param inputFile stream holding the serialized graph; it is left open, closing it is up to the caller
 * @param outputFile file the node text is appended to (created if it does not exist)
 */
@Override
public void dumpBinaryProtoAsText(InputStream inputFile, File outputFile) {
    try {
        GraphDef graphDef = GraphDef.parseFrom(inputFile);
        // try-with-resources flushes and closes the writer even if a write fails.
        try (BufferedWriter writer = new BufferedWriter(new FileWriter(outputFile, true))) {
            for (NodeDef node : graphDef.getNodeList()) {
                writer.write(node.toString());
            }
        }
    } catch (IOException e) {
        e.printStackTrace();
    }
}
private void initSessionAndStatusIfNeeded(byte[] graphToUse) { try { //use the protobuf api to load the graph definition and load the node metadata org.tensorflow.framework.GraphDef graphDef1 = org.tensorflow.framework.GraphDef.parseFrom(graphToUse); initSessionAndStatusIfNeeded(graphDef1); } catch (InvalidProtocolBufferException e) { e.printStackTrace(); } }
@Procedure(value = "load.tensorflow", mode = Mode.WRITE) public Stream<LoadResult> loadTensorFlow(@Name("file") String url) throws IOException { GraphDef graphDef = GraphDef.parseFrom(new BufferedInputStream(new URL(url).openStream())); Map<String, Node> nodes = new HashMap<>(); // tod model node, layer nodes for (NodeDef nodeDef : graphDef.getNodeList()) { Node node = db.createNode(Types.Neuron); node.setProperty("name", nodeDef.getName()); if (nodeDef.getDevice() != null) node.setProperty("device", nodeDef.getDevice()); node.setProperty("op", nodeDef.getOp()); nodeDef.getAttrMap().forEach((k, v) -> { Object value = getValue(v); if (value != null) { node.setProperty(k, value); } }); nodes.put(nodeDef.getName(), node); } long rels = 0; for (NodeDef nodeDef : graphDef.getNodeList()) { Node target = nodes.get(nodeDef.getName()); nodeDef.getInputList().forEach(name -> nodes.get(name).createRelationshipTo(target, RelTypes.INPUT)); // todo weights rels += nodeDef.getInputCount(); } return Stream.of(new LoadResult(url,"tensorflow",nodes.size(), rels)); }