/**
 * Hashes this NodeDef consistently with {@link #equals}, caching the result in
 * {@code memoizedHashCode}.
 *
 * <p>0 doubles as the "not yet computed" marker, so a message whose hash happens to be 0 is
 * re-hashed on every call. The descriptor is mixed in first. Each field then contributes its
 * field-number constant followed by its value's hash. The repeated {@code input} field and the
 * {@code attr} map only contribute when non-empty. Unknown fields are folded in last.
 */
@java.lang.Override public int hashCode() { if (memoizedHashCode != 0) { return memoizedHashCode; } int hash = 41; hash = (19 * hash) + getDescriptor().hashCode(); hash = (37 * hash) + NAME_FIELD_NUMBER; hash = (53 * hash) + getName().hashCode(); hash = (37 * hash) + OP_FIELD_NUMBER; hash = (53 * hash) + getOp().hashCode(); if (getInputCount() > 0) { hash = (37 * hash) + INPUT_FIELD_NUMBER; hash = (53 * hash) + getInputList().hashCode(); } hash = (37 * hash) + DEVICE_FIELD_NUMBER; hash = (53 * hash) + getDevice().hashCode(); if (!internalGetAttr().getMap().isEmpty()) { hash = (37 * hash) + ATTR_FIELD_NUMBER; hash = (53 * hash) + internalGetAttr().hashCode(); } hash = (29 * hash) + unknownFields.hashCode(); memoizedHashCode = hash; return hash; }
/**
 * Exposes the attribute map carried by a TensorFlow node.
 *
 * @param nodeDef the node whose attributes are requested
 * @return the node's own attribute map, as returned by {@code NodeDef.getAttrMap()}
 */
@Override
public Map<String, AttrValue> getAttrMap(NodeDef nodeDef) {
    final Map<String, AttrValue> attributes = nodeDef.getAttrMap();
    return attributes;
}
/**
 * <pre>
 * The operation name. There may be custom parameters in attrs.
 * Op names starting with an underscore are reserved for internal use.
 * </pre>
 *
 * <p>Resets {@code op} to the value held by the default instance, calls
 * {@code onChanged()}, and returns this builder for chaining.
 *
 * <code>string op = 2;</code>
 */
public Builder clearOp() {
    final java.lang.String defaultOp = getDefaultInstance().getOp();
    op_ = defaultOp;
    onChanged();
    return this;
}
/**
/**
 * Reports whether any input of the node has a name containing {@code "reduction_indices"}.
 *
 * @param nodeDef the TensorFlow node to inspect
 * @return {@code true} on the first matching input name, {@code false} if none match
 */
protected boolean hasReductionIndices(NodeDef nodeDef) {
    for (String inputName : nodeDef.getInputList()) {
        if (inputName.contains("reduction_indices")) {
            return true;
        }
    }
    return false;
}
// Imports a TensorFlow node into SameDiff. It records node names in skipSet and creates
// scopeCondition / scopeLoop sub-graphs whose variables are zero-initialised.
// NOTE(review): this text appears to be a fragment with lines and closing braces lost:
//  - `tfNode` is referenced but never declared here (the parameter is `nodeDef`). TODO confirm
//    against the full source.
//  - Statements follow `break;` directly. If they share a block with it, Java rejects them as
//    unreachable, which suggests missing `}` / `else` structure.
//  - `uniqueId`, `attributesForNode` and `graph` are unused in the visible text.
private void doImport(NodeDef nodeDef,SameDiff initWith,Map<String,AttrValue> attributesForNode,GraphDef graph,Set<String> skipSet,AtomicInteger currIndex) { val uniqueId = java.util.UUID.randomUUID().toString(); skipSet.add(nodeDef.getName()); val scopeCondition = SameDiff.create(); val scopeLoop = SameDiff.create(); if (!tfNode.getOp().equalsIgnoreCase("enter")) { skipSet.add(tfNode.getName()); val vars = new SDVariable[tfNode.getInputCount()]; for (int e = 0; e < tfNode.getInputCount(); e++) { val input = TFGraphMapper.getInstance().getNodeName(tfNode.getInput(e)); vars[e] = initWith.getVariable(input) == null ? initWith.var(input,null,new ZeroInitScheme()) : initWith.getVariable(input); scopeCondition.var(vars[e]); if (!tfNode.getOp().equalsIgnoreCase("merge")) { scopeLoop.var(TFGraphMapper.getInstance().getNodeName(tfNode.getName()),null,new ZeroInitScheme()); break; skipSet.add(tfNode.getName()); val var = scopeLoop.var(TFGraphMapper.getInstance().getNodeName(tfNode.getName()),null,new ZeroInitScheme()); scopeCondition.var(var); initWith.var(var); if (tfNode.getOp().equalsIgnoreCase("LoopCond")) { skipSet.add(tfNode.getName()); currIndex.incrementAndGet(); break;
/**
 * Structural equality for NodeDef. Two messages are equal when their name, op, input list,
 * device, attribute map and unknown fields all match. These are compared in that order and
 * the comparison stops at the first mismatch. Objects of any other type defer to
 * {@code super.equals}.
 */
@java.lang.Override
public boolean equals(final java.lang.Object obj) {
    if (obj == this) {
        return true;
    }
    if (!(obj instanceof org.tensorflow.framework.NodeDef)) {
        return super.equals(obj);
    }
    final org.tensorflow.framework.NodeDef that = (org.tensorflow.framework.NodeDef) obj;
    return getName().equals(that.getName())
            && getOp().equals(that.getOp())
            && getInputList().equals(that.getInputList())
            && getDevice().equals(that.getDevice())
            && internalGetAttr().equals(that.internalGetAttr())
            && unknownFields.equals(that.unknownFields);
}
/**
 * Returns the TensorFlow op name of the node (its {@code op} field), used as the op type.
 */
@Override
public String getOpType(NodeDef nodeDef) {
    final String opType = nodeDef.getOp();
    return opType;
}
@Procedure(value = "load.tensorflow", mode = Mode.WRITE) public Stream<LoadResult> loadTensorFlow(@Name("file") String url) throws IOException { GraphDef graphDef = GraphDef.parseFrom(new BufferedInputStream(new URL(url).openStream())); Map<String, Node> nodes = new HashMap<>(); // tod model node, layer nodes for (NodeDef nodeDef : graphDef.getNodeList()) { Node node = db.createNode(Types.Neuron); node.setProperty("name", nodeDef.getName()); if (nodeDef.getDevice() != null) node.setProperty("device", nodeDef.getDevice()); node.setProperty("op", nodeDef.getOp()); nodeDef.getAttrMap().forEach((k, v) -> { Object value = getValue(v); if (value != null) { node.setProperty(k, value); } }); nodes.put(nodeDef.getName(), node); } long rels = 0; for (NodeDef nodeDef : graphDef.getNodeList()) { Node target = nodes.get(nodeDef.getName()); nodeDef.getInputList().forEach(name -> nodes.get(name).createRelationshipTo(target, RelTypes.INPUT)); // todo weights rels += nodeDef.getInputCount(); } return Stream.of(new LoadResult(url,"tensorflow",nodes.size(), rels)); }
/**
 * Merges {@code other} into this builder and returns this builder.
 *
 * <p>The merge proceeds as follows:
 * <ul>
 *   <li>Merging the default instance is a no-op.</li>
 *   <li>{@code name}, {@code op} and {@code device} are overwritten only when non-empty in
 *       {@code other}.</li>
 *   <li>{@code input} entries are appended. When this builder has no inputs yet, it adopts
 *       {@code other}'s list directly and clears bit 0x00000004. Presumably that bit marks the
 *       list as owned/mutable, so {@code ensureInputIsMutable()} copies before a later write —
 *       TODO confirm against the generated helper.</li>
 *   <li>The {@code attr} map and unknown fields are merged.</li>
 * </ul>
 * {@code onChanged()} may be invoked several times during one merge.
 */
public Builder mergeFrom(org.tensorflow.framework.NodeDef other) { if (other == org.tensorflow.framework.NodeDef.getDefaultInstance()) return this; if (!other.getName().isEmpty()) { name_ = other.name_; onChanged(); } if (!other.getOp().isEmpty()) { op_ = other.op_; onChanged(); } if (!other.input_.isEmpty()) { if (input_.isEmpty()) { input_ = other.input_; bitField0_ = (bitField0_ & ~0x00000004); } else { ensureInputIsMutable(); input_.addAll(other.input_); } onChanged(); } if (!other.getDevice().isEmpty()) { device_ = other.device_; onChanged(); } internalGetMutableAttr().mergeFrom( other.internalGetAttr()); this.mergeUnknownFields(other.unknownFields); onChanged(); return this; }
// Fragment of the serialized-size computation. It adds the encoded sizes of name (field 1),
// op (field 2), each input string, and device (field 4), then iterates the attr map entries.
// Each string field is counted only when non-empty.
// `size += 1 * getInputList().size()` presumably accounts for one tag byte per repeated
// `input` element (field 3) — TODO confirm.
// NOTE(review): braces and the loop headers (`for (int i ...)`, `for (Map.Entry<...> entry`)
// are missing from this visible text.
if (!getNameBytes().isEmpty()) { size += com.github.os72.protobuf351.GeneratedMessageV3.computeStringSize(1, name_); if (!getOpBytes().isEmpty()) { size += com.github.os72.protobuf351.GeneratedMessageV3.computeStringSize(2, op_); dataSize += computeStringSizeNoTag(input_.getRaw(i)); size += 1 * getInputList().size(); if (!getDeviceBytes().isEmpty()) { size += com.github.os72.protobuf351.GeneratedMessageV3.computeStringSize(4, device_); : internalGetAttr().getMap().entrySet()) { com.github.os72.protobuf351.MapEntry<java.lang.String, org.tensorflow.framework.AttrValue> attr__ = AttrDefaultEntryHolder.defaultEntry.newBuilderForType()
// Initializes the target shape from a TensorFlow node.
// - With no "TShape" attr and a single input, the shape is left empty.
// - With more than one input, the node named by input 1 is looked up in the graph.
// - Otherwise the shape is read from the "Tshape" attr.
// NOTE(review): the check uses "TShape" but the read uses "Tshape". Attr map keys are
// case-sensitive strings, so containsAttr("TShape") will not see a "Tshape" attr. Likely a
// bug — confirm which spelling the imported graphs use.
// NOTE(review): this text is a fragment. Loop/branch braces are missing, and
// `addIArgument(this.shape)` appears inside the node-search loop. TODO confirm against the
// full source.
@Override public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) { if (!nodeDef.containsAttr("TShape") && nodeDef.getInputCount() == 1) { this.shape = new long[]{}; return; } else if (nodeDef.getInputCount() > 1) { val shapeNode = nodeDef.getInput(1); NodeDef shapeNodeInGraph = null; for (int i = 0; i < graph.getNodeCount(); i++) { if (graph.getNode(i).getName().equals(shapeNode)) { shapeNodeInGraph = graph.getNode(i); addIArgument(this.shape); else { arrName = nodeDef.getName(); val shape = nodeDef.getAttrOrThrow("Tshape"); if (!shape.hasShape()) { val shapeRet = new long[2];
/**
 * Reads an integer argument from the constant feeding this node's last input.
 *
 * <p>The graph node whose name equals the last input is located. Its {@code "value"} tensor
 * is converted to an array, and the array's first element is added as an integer argument.
 * Nothing is added when no array is obtained.
 */
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    val indexInputName = nodeDef.getInput(nodeDef.getInputCount() - 1);

    // Scanning backwards and stopping at the first hit yields the same node that a full
    // forward scan keeping the last match would.
    NodeDef indexNode = null;
    for (int pos = graph.getNodeCount() - 1; pos >= 0 && indexNode == null; pos--) {
        if (graph.getNode(pos).getName().equals(indexInputName)) {
            indexNode = graph.getNode(pos);
        }
    }

    val indexArray = TFGraphMapper.getInstance().getNDArrayFromTensor("value", indexNode, graph);
    if (indexArray == null) {
        return;
    }
    addIArgument(indexArray.getInt(0));
}
// Fragment of conditional-import logic. It takes the true branch from input 1 and the false
// branch from input 0, builds a unique scope name, and partitions graph nodes into
// trueBodyNodes, falseBodyNodes and conditionNodes. Nodes whose names contain "pred_id" are
// excluded from the condition nodes.
// NOTE(review): trueDefName.indexOf("/") returns -1 when the name has no "/", and then
// substring(0, -1) throws StringIndexOutOfBoundsException.
// NOTE(review): the enclosing loop headers and braces are missing from this visible text.
val trueDefName = from.getInput(1); val falseDefName = from.getInput(0); val scopeId = UUID.randomUUID().toString(); val scopeName = scopeId + "-" + trueDefName.substring(0,trueDefName.indexOf("/")); if(graph.getNode(i).getName().equals(trueDefName)) { onFalseDefinition = false; onTrueDefinition = true; if(graph.getNode(i).getName().contains("pred_id")) { onTrueDefinition = false; if(onTrueDefinition && !graph.getNode(i).equals(from)) { trueBodyNodes.add(graph.getNode(i)); else if(onFalseDefinition && !graph.getNode(i).equals(from)) { falseBodyNodes.add(graph.getNode(i)); if(currNode.equals(from)) continue; if(!seenNames.contains(graph.getNode(i).getName()) && !graph.getNode(i).getName().contains("pred_id")) { break; for(int inputIdx = 0; inputIdx < currNode.getInputCount(); inputIdx++) { seenNames.add(currNode.getInput(inputIdx)); seenNames.add(graph.getNode(i).getName()); conditionNodes.add(graph.getNode(i));
/**
 * Returns the node's {@code name} field.
 */
@Override
public String getName(NodeDef nodeDef) {
    final String nodeName = nodeDef.getName();
    return nodeName;
}
/**
 * <pre>
 * The name given to this operator. Used for naming inputs,
 * logging, visualization, etc. Unique within a single GraphDef.
 * Must match the regexp "[A-Za-z0-9.][A-Za-z0-9_./]*".
 * </pre>
 *
 * <p>Resets {@code name} to the value held by the default instance, calls
 * {@code onChanged()}, and returns this builder for chaining.
 *
 * <code>string name = 1;</code>
 */
public Builder clearName() {
    final java.lang.String defaultName = getDefaultInstance().getName();
    name_ = defaultName;
    onChanged();
    return this;
}
/**
// Fragment of slice initialization from TensorFlow. It resolves the graph nodes named by
// inputs 1, 2 and 3 (begin, end, strides), then reads the begin/ellipsis/end/new_axis/
// shrink_axis mask attrs.
// NOTE(review): `i`, `beginNode`, `endNode` and `strides` are declared outside the visible
// text (the loop over graph nodes is not shown). getAttrOrThrow presumably throws when an attr
// is absent, so every mask attr is treated as mandatory — confirm.
@Override public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) { val inputBegin = nodeDef.getInput(1); val inputEnd = nodeDef.getInput(2); val inputStrides = nodeDef.getInput(3); if(graph.getNode(i).getName().equals(inputBegin)) { beginNode = graph.getNode(i); if(graph.getNode(i).getName().equals(inputEnd)) { endNode = graph.getNode(i); if(graph.getNode(i).getName().equals(inputStrides)) { strides = graph.getNode(i); val bm = nodeDef.getAttrOrThrow("begin_mask"); val xm = nodeDef.getAttrOrThrow("ellipsis_mask"); val em = nodeDef.getAttrOrThrow("end_mask"); val nm = nodeDef.getAttrOrThrow("new_axis_mask"); val sm = nodeDef.getAttrOrThrow("shrink_axis_mask");
/**
 * <pre>
 * A (possibly partial) specification for the device on which this
 * node should be placed.
 * The expected syntax for this string is as follows:
 * DEVICE_SPEC ::= PARTIAL_SPEC
 * PARTIAL_SPEC ::= ("/" CONSTRAINT) *
 * CONSTRAINT ::= ("job:" JOB_NAME)
 *              | ("replica:" [1-9][0-9]*)
 *              | ("task:" [1-9][0-9]*)
 *              | ("device:" [A-Za-z]* ":" ([1-9][0-9]* | "*") )
 * Valid values for this string include:
 * * "/job:worker/replica:0/task:1/device:GPU:3"  (full specification)
 * * "/job:worker/device:GPU:3"                   (partial specification)
 * * ""                                           (no specification)
 * If the constraints do not resolve to a single device (or if this
 * field is empty or not present), the runtime will attempt to
 * choose a device automatically.
 * </pre>
 *
 * <p>Resets {@code device} to the value held by the default instance, calls
 * {@code onChanged()}, and returns this builder for chaining.
 *
 * <code>string device = 4;</code>
 */
public Builder clearDevice() {
    final java.lang.String defaultDevice = getDefaultInstance().getDevice();
    device_ = defaultDevice;
    onChanged();
    return this;
}
/**
/**
 * Initializes this op from a TensorFlow node.
 *
 * <p>{@code is_static_maxlen} is set in either of two cases:
 * <ul>
 *   <li>the node has no second input;</li>
 *   <li>the graph node named by the second input yields no {@code "value"} array.</li>
 * </ul>
 * It is never reset to {@code false} here. The op's properties are then initialized from the
 * node's attributes. When {@code is_static_maxlen} is set, {@code this.maxLen} is added as an
 * integer argument. NOTE(review): presumably {@code maxLen} is populated by
 * {@code initFunctionFromProperties} — confirm.
 */
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    if (nodeDef.getInputCount() < 2) {
        // No 2nd input. getInput(1) would throw IndexOutOfBoundsException here, so check first.
        this.is_static_maxlen = true;
    } else {
        val targetNode = TFGraphMapper.getInstance().getNodeWithNameFromGraph(graph, nodeDef.getInput(1));
        val maxlen = TFGraphMapper.getInstance().getNDArrayFromTensor("value", targetNode, graph);
        if (maxlen == null) {
            // 2nd input exists but does not resolve to an array.
            this.is_static_maxlen = true;
        }
    }
    TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
    if (is_static_maxlen) {
        addIArgument(this.maxLen);
    }
}
@Override