result = result && getValueCase().equals( other.getValueCase()); if (!result) return false; switch (valueCase_) { case 2: result = result && getS() .equals(other.getS()); break; case 3: result = result && (getI() == other.getI()); break; case 4: result = result && ( java.lang.Float.floatToIntBits(getF()) == java.lang.Float.floatToIntBits( other.getF())); break; case 5: result = result && (getB() == other.getB()); break; case 6: result = result && getTypeValue() == other.getTypeValue(); break; case 7: result = result && getShape() .equals(other.getShape());
@Override public long[] getShapeFromTensor(NodeDef tensorProto) { if(tensorProto.containsAttr("shape")) { return shapeFromShapeProto(tensorProto.getAttrOrThrow("shape").getShape()); } //yet to be determined shape, or tied to an op where output shape is dynamic else if(!tensorProto.containsAttr("value")) { return null; } else return shapeFromShapeProto(tensorProto.getAttrOrThrow("value").getTensor().getTensorShape()); }
@Override public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) { val aStrides = nodeDef.getAttrOrThrow("strides"); val tfStrides = aStrides.getList().getIList(); val tfKernels = aKernels.getList().getIList(); val padding = aPadding.getList().getIList(); val paddingMode = aPadding.getS().toStringUtf8().replaceAll("\"",""); val attr = nodeDef.getAttrOrThrow("data_format"); data_format = attr.getS().toStringUtf8().toLowerCase();
switch (attr.getValueCase()) { case B: if (adapter != null) { adapter.mapAttributeFor(attr.getB(), currentField, on); case FUNC: break; case S: val setString = attr.getS().toStringUtf8(); if(adapter != null) { adapter.mapAttributeFor(setString,currentField,on); break; case I: val setInt = (int) attr.getI(); if(adapter != null) { adapter.mapAttributeFor(setInt,currentField,on); break; case SHAPE: val shape = attr.getShape().getDimList(); int[] dimsToSet = new int[shape.size()]; for(int i = 0; i < dimsToSet.length; i++) { case PLACEHOLDER: break; case LIST: val setList = attr.getList(); if(!setList.getIList().isEmpty()) { val intList = Ints.toArray(setList.getIList());
val type = attr.getType(); if(fields == null) { throw new ND4JIllegalStateException("No fields found for op " + mapping); switch(type) { case DT_BOOL: valueToSet = attr.getB(); break; case DT_INT8: valueToSet = attr.getI(); break; case DT_INT16: valueToSet = attr.getI(); break; case DT_INT32: valueToSet = attr.getI(); break; case DT_FLOAT: valueToSet = attr.getF(); break; case DT_DOUBLE: valueToSet = attr.getF(); break; case DT_STRING: valueToSet = attr.getS(); break; case DT_INT64: valueToSet = attr.getI(); break;
if (!shape.hasShape()) { val shapeRet = new long[2]; shapeRet[0] = 1; shapeRet[1] = shape.getValueCase().getNumber(); this.shape = shapeRet; } else { val shapeVals = shape.getShape().getDimList(); if (shapeVals.size() > 1) { this.shape = new long[shapeVals.size()];
/**
 * Initializes this op from a TensorFlow node by reading its integer {@code axis} attribute
 * (narrowed to {@code int}) and then registering the op's arguments.
 *
 * <p>The axis lookup uses {@code getAttrOrThrow}, so a node without an {@code axis} attribute
 * causes that call to throw.
 */
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    this.axis = (int) nodeDef.getAttrOrThrow("axis").getI();
    addArgs();
}
/**
 * Converts the {@code shape} member of a TensorFlow attribute into a shape array by delegating
 * to {@code shapeFromShapeProto}.
 *
 * @param attr the attribute whose shape member is converted
 * @return the converted shape
 */
@Override
public long[] getShapeFromAttr(AttrValue attr) {
    return this.shapeFromShapeProto(attr.getShape());
}
/**
 * Returns the string member of the named attribute on a node, decoded as UTF-8.
 *
 * <p>The lookup uses {@code getAttrOrThrow}, so a missing key causes that call to throw.
 *
 * @param nodeDef the node holding the attribute
 * @param key the attribute name
 * @return the attribute's string value
 */
@Override
public String getAttrValueFromNode(NodeDef nodeDef, String key) {
    final String decoded = nodeDef.getAttrOrThrow(key).getS().toStringUtf8();
    return decoded;
}
@Override public INDArray getNDArrayFromTensor(String tensorName, NodeDef node, GraphDef graph) { //placeholder of some kind if(!node.getAttrMap().containsKey("value")) { return null; } val tfTensor = node.getAttrOrThrow("value").getTensor(); return mapTensorProto(tfTensor); }
/**
 * Configures this local response normalization op from a TensorFlow node.
 *
 * <p>{@code alpha}, {@code beta} and {@code bias} are read as float attributes. {@code depth_radius}
 * is an integer attribute in TensorFlow's LRN op, so it is read with {@code getI()}. Reading it with
 * {@code getF()} returns the protobuf default {@code 0.0f}, because the attribute's oneof holds the
 * integer case, and that would silently give a depth of 0. If an exporter ever stored it as a float,
 * that value is still honoured.
 *
 * <p>All four attributes are fetched with {@code getAttrOrThrow}, so a missing one causes that call
 * to throw.
 */
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    val aAlpha = nodeDef.getAttrOrThrow("alpha");
    val aBeta = nodeDef.getAttrOrThrow("beta");
    val aBias = nodeDef.getAttrOrThrow("bias");
    val aDepth = nodeDef.getAttrOrThrow("depth_radius");

    val alpha = aAlpha.getF();
    val beta = aBeta.getF();
    val bias = aBias.getF();
    // depth_radius is declared as an int attr; only fall back to the float member if that is
    // what the attribute actually carries.
    final int depth = aDepth.getValueCase() == AttrValue.ValueCase.F
            ? (int) aDepth.getF()
            : (int) aDepth.getI();

    LocalResponseNormalizationConfig localResponseNormalizationConfig = LocalResponseNormalizationConfig.builder()
            .alpha(alpha)
            .beta(beta)
            .bias(bias)
            .depth(depth)
            .build();
    this.config = localResponseNormalizationConfig;
    addArgs();
}
public org.tensorflow.framework.AttrValue buildPartial() { org.tensorflow.framework.AttrValue result = new org.tensorflow.framework.AttrValue(this); if (valueCase_ == 2) { result.value_ = value_;
private Object getValue(AttrValue v) { switch (v.getValueCase()) { case S: return v.getS().toStringUtf8(); case I: return v.getI(); case F: return v.getF(); case B: return v.getB(); case TYPE: return v.getType().name(); // todo case SHAPE: return v.getShape().toString(); // tdo case TENSOR: return v.getTensor().toString(); // todo handle with prefxied properties case LIST: return v.getList().toString(); // todo getType/Count(idx) and then handle each type with prefixed property case FUNC: return v.getFunc().getAttrMap().toString(); // todo handle recursively case PLACEHOLDER: break; case VALUE_NOT_SET: return null; default: return null; } return null; }
@Override public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) { val aStrides = nodeDef.getAttrOrThrow("strides"); val tfStrides = aStrides.getList().getIList(); val sY = tfStrides.get(1); val sX = tfStrides.get(2); val tfKernels = aKernels.getList().getIList(); val padding = aPadding.getList().getIList(); val paddingMode = aPadding.getS().toStringUtf8().replaceAll("\"","");
/**
 * Configures this split op from a TensorFlow node.
 *
 * <p>Two integer arguments may be registered, in this order:
 * <ol>
 *   <li>The number of splits, taken from the {@code num_split} attribute. This is always
 *       registered.</li>
 *   <li>The split dimension, taken from the first element of the array backing the node's first
 *       input. This is registered only when {@code getArrayFrom} returns a non-null array.</li>
 * </ol>
 */
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    final int splitCount = (int) attributesForNode.get("num_split").getI();
    this.numSplit = splitCount;
    addIArgument(splitCount);

    val mapper = TFGraphMapper.getInstance();
    val dimSource = mapper.getNodeWithNameFromGraph(graph, nodeDef.getInput(0));
    val dimArray = mapper.getArrayFrom(dimSource, graph);
    if (dimArray == null) {
        // Split dimension not resolvable from the graph at import time.
        return;
    }
    final int dim = dimArray.getInt(0);
    this.splitDim = dim;
    addIArgument(dim);
}
/**
 * Configures this matrix-multiply op from a TensorFlow node.
 *
 * <p>The node's {@code transpose_a} / {@code transpose_b} boolean attributes become this op's
 * transpose config. Any argument that is a placeholder, or whose shape is still {@code null}, is
 * registered with SameDiff as a property to resolve later.
 */
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    super.initFromTensorFlow(nodeDef, initWith, attributesForNode, graph);

    final boolean transposeFirst = attributesForNode.get("transpose_a").getB();
    final boolean transposeSecond = attributesForNode.get("transpose_b").getB();
    this.mMulTranspose = MMulTranspose.builder()
            .transposeA(transposeFirst)
            .transposeB(transposeSecond)
            .build();

    for (val input : args()) {
        final String inputName = input.getVarName();
        final boolean shapeUnknown = sameDiff.isPlaceHolder(inputName) || input.getShape() == null;
        if (shapeUnknown) {
            sameDiff.addPropertyToResolve(this, inputName);
        }
    }
}
/**
 * Extracts the dimension sizes from the {@code shape} member of a TensorFlow attribute.
 *
 * @param attrValue the attribute whose shape member is read
 * @return one entry per dimension, in order, each holding that dimension's size exactly as stored
 *     in the proto; an empty array when the shape has no dimensions
 */
@Override
public long[] getShapeFromAttribute(AttrValue attrValue) {
    TensorShapeProto shape = attrValue.getShape();
    long[] ret = new long[shape.getDimCount()];
    for (int i = 0; i < ret.length; i++) {
        // Dim sizes are 64-bit in the proto and the result array is long[]: keep the full value.
        // The former (int) cast silently truncated sizes above Integer.MAX_VALUE.
        ret[i] = shape.getDim(i).getSize();
    }
    return ret;
}
public org.tensorflow.framework.AttrValue buildPartial() { org.tensorflow.framework.AttrValue result = new org.tensorflow.framework.AttrValue(this); if (valueCase_ == 2) { result.value_ = value_;
hash = (19 * hash) + getDescriptor().hashCode(); switch (valueCase_) { case 2: hash = (37 * hash) + S_FIELD_NUMBER; hash = (53 * hash) + getS().hashCode(); break; case 3: hash = (37 * hash) + I_FIELD_NUMBER; hash = (53 * hash) + com.github.os72.protobuf351.Internal.hashLong( getI()); break; case 4: hash = (37 * hash) + F_FIELD_NUMBER; hash = (53 * hash) + java.lang.Float.floatToIntBits( getF()); break; case 5: hash = (37 * hash) + B_FIELD_NUMBER; hash = (53 * hash) + com.github.os72.protobuf351.Internal.hashBoolean( getB()); break; case 6: hash = (37 * hash) + TYPE_FIELD_NUMBER; hash = (53 * hash) + getTypeValue(); break; case 7: hash = (37 * hash) + SHAPE_FIELD_NUMBER; hash = (53 * hash) + getShape().hashCode(); break;