/**
 * Returns a hash code built from the message descriptor, the unknown fields, and each populated
 * field: the node list (only when it has at least one entry), versions and library (only when
 * set), and the version scalar (always mixed in, even when it is 0). Each field's contribution is
 * tagged with its *_FIELD_NUMBER constant before the field value is folded in.
 *
 * <p>The result is cached in {@code memoizedHashCode}. A cached value of 0 is treated as "not yet
 * computed", so a message whose hash is exactly 0 recomputes it on every call.
 */
@java.lang.Override public int hashCode() { if (memoizedHashCode != 0) { return memoizedHashCode; } int hash = 41; hash = (19 * hash) + getDescriptor().hashCode(); if (getNodeCount() > 0) { hash = (37 * hash) + NODE_FIELD_NUMBER; hash = (53 * hash) + getNodeList().hashCode(); } if (hasVersions()) { hash = (37 * hash) + VERSIONS_FIELD_NUMBER; hash = (53 * hash) + getVersions().hashCode(); } hash = (37 * hash) + VERSION_FIELD_NUMBER; hash = (53 * hash) + getVersion(); if (hasLibrary()) { hash = (37 * hash) + LIBRARY_FIELD_NUMBER; hash = (53 * hash) + getLibrary().hashCode(); } hash = (29 * hash) + unknownFields.hashCode(); memoizedHashCode = hash; return hash; }
/**
 * Parses a binary {@code GraphDef} from {@code inputFile} and appends the text form of each of
 * its nodes to {@code outputFile}.
 *
 * <p>The output file is opened in append mode, so existing content is kept. Nodes are written
 * back to back using {@code NodeDef#toString()} with no extra separator. {@code inputFile} is not
 * closed here; the caller keeps ownership of it.
 *
 * <p>I/O failures (parsing, opening, writing or closing) are reported via
 * {@code printStackTrace()} and otherwise swallowed, matching the previous behavior.
 *
 * @param inputFile  stream containing a serialized {@code GraphDef}
 * @param outputFile file the node text is appended to
 */
@Override
public void dumpBinaryProtoAsText(InputStream inputFile, File outputFile) {
  try {
    GraphDef graphDef = GraphDef.parseFrom(inputFile);
    // try-with-resources closes (and therefore flushes) the writer even when a write fails
    // part-way; previously an exception skipped close() and leaked the file handle.
    try (BufferedWriter bufferedWriter = new BufferedWriter(new FileWriter(outputFile, true))) {
      for (NodeDef node : graphDef.getNodeList()) {
        bufferedWriter.write(node.toString());
      }
    }
  } catch (IOException e) {
    e.printStackTrace();
  }
}
public Builder mergeFrom(org.tensorflow.framework.GraphDef other) { if (other == org.tensorflow.framework.GraphDef.getDefaultInstance()) return this; if (nodeBuilder_ == null) { if (!other.node_.isEmpty()) { if (other.hasVersions()) { mergeVersions(other.getVersions()); if (other.getVersion() != 0) { setVersion(other.getVersion()); if (other.hasLibrary()) { mergeLibrary(other.getLibrary());
/**
 * Finds a node in a graph by its exact name.
 *
 * @param graph graph whose nodes are searched in list order
 * @param name  node name to match with {@code String#equals}
 * @return the first node whose name equals {@code name}, or {@code null} when no node matches
 */
@Override
public NodeDef getNodeWithNameFromGraph(GraphDef graph, String name) {
  for (NodeDef candidate : graph.getNodeList()) {
    if (candidate.getName().equals(name)) {
      return candidate;
    }
  }
  return null;
}
/**
 * Compares this {@code GraphDef} with another object field by field: node list, versions
 * (presence and value), version, library (presence and value) and unknown fields. Objects that
 * are not a {@code GraphDef} are delegated to {@code super.equals}.
 */
@java.lang.Override
public boolean equals(final java.lang.Object obj) {
  if (obj == this) {
    return true;
  }
  if (!(obj instanceof org.tensorflow.framework.GraphDef)) {
    return super.equals(obj);
  }
  final org.tensorflow.framework.GraphDef that = (org.tensorflow.framework.GraphDef) obj;
  if (!getNodeList().equals(that.getNodeList())) {
    return false;
  }
  if (hasVersions() != that.hasVersions()) {
    return false;
  }
  if (hasVersions() && !getVersions().equals(that.getVersions())) {
    return false;
  }
  if (getVersion() != that.getVersion()) {
    return false;
  }
  if (hasLibrary() != that.hasLibrary()) {
    return false;
  }
  if (hasLibrary() && !getLibrary().equals(that.getLibrary())) {
    return false;
  }
  return unknownFields.equals(that.unknownFields);
}
/**
 * Serializes this message to {@code output} in field-number order: every entry of node_ as
 * field 1, library as field 2 (only when library_ is non-null), version as field 3 (only when
 * version_ is non-zero), versions as field 4 (only when versions_ is non-null), followed by the
 * unknown fields.
 *
 * @param output destination stream
 * @throws java.io.IOException if writing to {@code output} fails
 */
public void writeTo(com.github.os72.protobuf351.CodedOutputStream output) throws java.io.IOException { for (int i = 0; i < node_.size(); i++) { output.writeMessage(1, node_.get(i)); } if (library_ != null) { output.writeMessage(2, getLibrary()); } if (version_ != 0) { output.writeInt32(3, version_); } if (versions_ != null) { output.writeMessage(4, getVersions()); } unknownFields.writeTo(output); }
/**
 * <pre>
 * Definition of remote graph
 * </pre>
 *
 * <code>.tensorflow.GraphDef remote_graph = 1;</code>
 *
 * @return the stored {@code remoteGraph_} when it is non-null, otherwise
 *     {@code GraphDef.getDefaultInstance()}
 */
public org.tensorflow.framework.GraphDef getRemoteGraph() {
  if (remoteGraph_ != null) {
    return remoteGraph_;
  }
  return org.tensorflow.framework.GraphDef.getDefaultInstance();
}
/**
int currNodeIndex = graph.getNodeList().indexOf(from); val trueDefName = from.getInput(1); val falseDefName = from.getInput(0); if(graph.getNode(i).getName().equals(trueDefName)) { onFalseDefinition = false; onTrueDefinition = true; if(graph.getNode(i).getName().contains("pred_id")) { onTrueDefinition = false; if(onTrueDefinition && !graph.getNode(i).equals(from)) { trueBodyNodes.add(graph.getNode(i)); else if(onFalseDefinition && !graph.getNode(i).equals(from)) { falseBodyNodes.add(graph.getNode(i)); val currNode = graph.getNode(i); if(currNode.equals(from)) continue; if(!seenNames.contains(graph.getNode(i).getName()) && !graph.getNode(i).getName().contains("pred_id")) { break; seenNames.add(graph.getNode(i).getName()); conditionNodes.add(graph.getNode(i));
/**
 * Returns the nodes of a graph.
 *
 * @param graphDef graph to read the nodes from
 * @return the list returned by {@code graphDef.getNodeList()}, unchanged
 */
@Override
public List<NodeDef> getNodeList(GraphDef graphDef) {
  final List<NodeDef> nodes = graphDef.getNodeList();
  return nodes;
}
/**
 * Parses a binary-encoded {@code GraphDef} from a stream.
 *
 * @param inputStream stream positioned at the start of a serialized {@code GraphDef}
 * @return the parsed graph
 * @throws IOException if the stream cannot be read or does not contain a valid message
 */
@Override
public GraphDef parseGraphFrom(InputStream inputStream) throws IOException {
  final GraphDef parsed = GraphDef.parseFrom(inputStream);
  return parsed;
}
/**
 * Compares this {@code RemoteFusedGraphExecuteInfo} with another object field by field: remote
 * graph (presence and value), input/output node names, executor name, serialized executor
 * parameters, default input/output tensor shapes and unknown fields. Objects of any other type
 * are delegated to {@code super.equals}.
 */
@java.lang.Override
public boolean equals(final java.lang.Object obj) {
  if (obj == this) {
    return true;
  }
  if (!(obj instanceof org.tensorflow.framework.RemoteFusedGraphExecuteInfo)) {
    return super.equals(obj);
  }
  final org.tensorflow.framework.RemoteFusedGraphExecuteInfo that =
      (org.tensorflow.framework.RemoteFusedGraphExecuteInfo) obj;
  if (hasRemoteGraph() != that.hasRemoteGraph()) {
    return false;
  }
  if (hasRemoteGraph() && !getRemoteGraph().equals(that.getRemoteGraph())) {
    return false;
  }
  return getGraphInputNodeNameList().equals(that.getGraphInputNodeNameList())
      && getGraphOutputNodeNameList().equals(that.getGraphOutputNodeNameList())
      && getExecutorName().equals(that.getExecutorName())
      && getSerializedExecutorParameters().equals(that.getSerializedExecutorParameters())
      && getDefaultGraphInputTensorShapeList().equals(that.getDefaultGraphInputTensorShapeList())
      && getDefaultGraphOutputTensorShapeList().equals(that.getDefaultGraphOutputTensorShapeList())
      && unknownFields.equals(that.unknownFields);
}
/**
 * Assembles a GraphDef from this builder's current state. It passes {@code this} to the GraphDef
 * constructor, then copies node_, versions_, version_ and library_ into the result. For node,
 * versions and library, the nested builder is built when it exists; otherwise the plain field
 * value is shared with the result.
 *
 * <p>When bit 0x00000001 of bitField0_ is set, node_ is first wrapped with
 * Collections.unmodifiableList and that bit is cleared before the list is shared. Presumably this
 * makes the builder copy the list before mutating it again (TODO confirm). from_bitField0_ is
 * read but not used here, and result.bitField0_ is always set to 0. onBuilt() is invoked just
 * before returning.
 *
 * @return the newly assembled GraphDef
 */
public org.tensorflow.framework.GraphDef buildPartial() { org.tensorflow.framework.GraphDef result = new org.tensorflow.framework.GraphDef(this); int from_bitField0_ = bitField0_; int to_bitField0_ = 0; if (nodeBuilder_ == null) { if (((bitField0_ & 0x00000001) == 0x00000001)) { node_ = java.util.Collections.unmodifiableList(node_); bitField0_ = (bitField0_ & ~0x00000001); } result.node_ = node_; } else { result.node_ = nodeBuilder_.build(); } if (versionsBuilder_ == null) { result.versions_ = versions_; } else { result.versions_ = versionsBuilder_.build(); } result.version_ = version_; if (libraryBuilder_ == null) { result.library_ = library_; } else { result.library_ = libraryBuilder_.build(); } result.bitField0_ = to_bitField0_; onBuilt(); return result; }
return getLibrary();
/**
 * Compares this {@code GraphDef} with another object field by field: node list, versions
 * (presence and value), version, library (presence and value) and unknown fields. Objects that
 * are not a {@code GraphDef} are delegated to {@code super.equals}.
 */
@java.lang.Override
public boolean equals(final java.lang.Object obj) {
  if (obj == this) {
    return true;
  }
  if (!(obj instanceof org.tensorflow.framework.GraphDef)) {
    return super.equals(obj);
  }
  final org.tensorflow.framework.GraphDef that = (org.tensorflow.framework.GraphDef) obj;
  if (!getNodeList().equals(that.getNodeList())) {
    return false;
  }
  if (hasVersions() != that.hasVersions()) {
    return false;
  }
  if (hasVersions() && !getVersions().equals(that.getVersions())) {
    return false;
  }
  if (getVersion() != that.getVersion()) {
    return false;
  }
  if (hasLibrary() != that.hasLibrary()) {
    return false;
  }
  if (hasLibrary() && !getLibrary().equals(that.getLibrary())) {
    return false;
  }
  return unknownFields.equals(that.unknownFields);
}
/**
 * Initializes this op from a TensorFlow node.
 *
 * <p>The last input of {@code nodeDef} names another node in {@code graph}. That node's
 * {@code "value"} tensor is fetched through {@code TFGraphMapper}. When an array is returned, its
 * element 0 (read as an int) is added as an integer argument.
 *
 * @param nodeDef           node being imported; its last input must name a node in {@code graph}
 * @param initWith          SameDiff instance being populated (not used by this method)
 * @param attributesForNode attributes of {@code nodeDef} (not used by this method)
 * @param graph             graph that contains the referenced node
 */
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
  val idd = nodeDef.getInput(nodeDef.getInputCount() - 1);
  NodeDef iddNode = null;
  // Stop at the first node with the requested name. The previous loop kept scanning the whole
  // graph, so the last duplicate won; first-match agrees with getNodeWithNameFromGraph and
  // avoids the wasted scan (TensorFlow node names are unique within a GraphDef).
  for (NodeDef candidate : graph.getNodeList()) {
    if (candidate.getName().equals(idd)) {
      iddNode = candidate;
      break;
    }
  }
  // NOTE(review): iddNode stays null when no node matches idd, and getNDArrayFromTensor then
  // receives null. Confirm it tolerates that.
  val arr = TFGraphMapper.getInstance().getNDArrayFromTensor("value", iddNode, graph);
  if (arr != null) {
    int idx = arr.getInt(0);
    addIArgument(idx);
  }
}
/**
 * Returns the serialized size of this message in bytes. The value is computed on the first call
 * and cached in memoizedSize, where -1 means "not yet computed".
 *
 * <p>The size counts every node_ entry (field 1), library (field 2, only when library_ is
 * non-null), version (field 3, only when version_ is non-zero), versions (field 4, only when
 * versions_ is non-null) and the unknown fields.
 *
 * @return the number of bytes this message occupies when serialized
 */
public int getSerializedSize() { int size = memoizedSize; if (size != -1) return size; size = 0; for (int i = 0; i < node_.size(); i++) { size += com.github.os72.protobuf351.CodedOutputStream .computeMessageSize(1, node_.get(i)); } if (library_ != null) { size += com.github.os72.protobuf351.CodedOutputStream .computeMessageSize(2, getLibrary()); } if (version_ != 0) { size += com.github.os72.protobuf351.CodedOutputStream .computeInt32Size(3, version_); } if (versions_ != null) { size += com.github.os72.protobuf351.CodedOutputStream .computeMessageSize(4, getVersions()); } size += unknownFields.getSerializedSize(); memoizedSize = size; return size; }
/**
 * Returns the default {@code GraphDef} instance.
 *
 * @return the value of {@code GraphDef.getDefaultInstance()}
 */
public org.tensorflow.framework.GraphDef getDefaultInstanceForType() {
  final org.tensorflow.framework.GraphDef defaultInstance =
      org.tensorflow.framework.GraphDef.getDefaultInstance();
  return defaultInstance;
}
/**
 * Maps each graph node to its SameDiff name, skipping nodes whose name ends in {@code "/read"}.
 *
 * <p>The returned map iterates in the graph's node order. If two nodes translate to the same
 * name, the later node replaces the earlier value, and the key keeps its first position.
 *
 * @param graphDef graph whose nodes are collected
 * @return insertion-ordered map from translated name to node
 */
@Override
public Map<String, NodeDef> variablesForGraph(GraphDef graphDef) {
  final Map<String, NodeDef> byName = new LinkedHashMap<>();
  for (NodeDef node : graphDef.getNodeList()) {
    final String nodeName = node.getName();
    if (!nodeName.endsWith("/read")) {
      byName.put(translateToSameDiffName(nodeName, node), node);
    }
  }
  return byName;
}
/**
 * Parses a binary-encoded {@code GraphDef} from an in-memory buffer.
 *
 * @param inputStream serialized graph bytes (despite its name, this parameter is a byte array,
 *     not a stream)
 * @return the parsed graph
 * @throws IOException if the bytes do not form a valid {@code GraphDef}
 */
@Override
public GraphDef parseGraphFrom(byte[] inputStream) throws IOException {
  final GraphDef parsed = GraphDef.parseFrom(inputStream);
  return parsed;
}
/**
 * Compares this {@code RemoteFusedGraphExecuteInfo} with another object field by field: remote
 * graph (presence and value), input/output node names, executor name, serialized executor
 * parameters, default input/output tensor shapes and unknown fields. Objects of any other type
 * are delegated to {@code super.equals}.
 */
@java.lang.Override
public boolean equals(final java.lang.Object obj) {
  if (obj == this) {
    return true;
  }
  if (!(obj instanceof org.tensorflow.framework.RemoteFusedGraphExecuteInfo)) {
    return super.equals(obj);
  }
  final org.tensorflow.framework.RemoteFusedGraphExecuteInfo that =
      (org.tensorflow.framework.RemoteFusedGraphExecuteInfo) obj;
  if (hasRemoteGraph() != that.hasRemoteGraph()) {
    return false;
  }
  if (hasRemoteGraph() && !getRemoteGraph().equals(that.getRemoteGraph())) {
    return false;
  }
  return getGraphInputNodeNameList().equals(that.getGraphInputNodeNameList())
      && getGraphOutputNodeNameList().equals(that.getGraphOutputNodeNameList())
      && getExecutorName().equals(that.getExecutorName())
      && getSerializedExecutorParameters().equals(that.getSerializedExecutorParameters())
      && getDefaultGraphInputTensorShapeList().equals(that.getDefaultGraphInputTensorShapeList())
      && getDefaultGraphOutputTensorShapeList().equals(that.getDefaultGraphOutputTensorShapeList())
      && unknownFields.equals(that.unknownFields);
}