/**
|
* @license
|
* Copyright 2022 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * Kinds of row-partition tensors understood by the ragged-to-dense helpers in
 * this module; the member names are exactly the strings accepted by
 * getRowPartitionTypesHelper.
 *
 * This is the compiled form of a numeric TypeScript enum: every member is
 * reachable both by name (name -> ordinal) and by ordinal (ordinal -> name).
 */
export var RowPartitionType;
(function (enumObject) {
  const memberNames = [
    'FIRST_DIM_SIZE',
    'VALUE_ROWIDS',
    'ROW_LENGTHS',
    'ROW_SPLITS',
    'ROW_LIMITS',
    'ROW_STARTS',
  ];
  memberNames.forEach((memberName, ordinal) => {
    // Forward entry first, then the reverse entry, matching tsc's emit order.
    enumObject[memberName] = ordinal;
    enumObject[ordinal] = memberName;
  });
})(RowPartitionType || (RowPartitionType = {}));
|
/**
 * Combines the requested dense output `shape` with the shape of the ragged
 * tensor's flat values. Mirrors TensorFlow's C++
 * CombineRaggedTensorToTensorShapes.
 *
 * Unknown dimensions are encoded as -1.
 * - When both inputs are null, returns [].
 * - When `shape` is null, the result has rank `raggedRank + valueShape.length`
 *   and starts out fully unknown.
 * - When `valueShape` is null, a copy of `shape` is returned unchanged.
 * - Otherwise, each known inner dimension of `valueShape` (every dimension
 *   except its first) is lined up with the trailing dimensions of the result.
 *   It fills the slot when that slot is unknown, and must equal it when known.
 *
 * @param raggedRank Number of ragged dimensions (as computed by getRaggedRank).
 * @param shape Requested output shape, or null if unspecified. Not mutated.
 * @param valueShape Shape of the flat values, or null if unknown.
 * @returns A new array holding the combined output shape.
 * @throws Error if the ranks disagree, or a known inner dimension of
 *     `valueShape` conflicts with a known dimension of `shape`.
 */
export function combineRaggedTensorToTensorShapes(raggedRank, shape, valueShape) {
  if (valueShape == null && shape == null) {
    return [];
  }

  let outputShape;
  if (shape == null) {
    // valueShape is non-null here (see the early return above), so the
    // output rank is known; every dimension starts out unknown.
    outputShape = Array.from({length: raggedRank + valueShape.length}, () => -1);
  } else {
    // Work on a copy so the caller's array is never mutated below.
    outputShape = shape.slice();
  }

  if (valueShape == null) {
    return outputShape;
  }

  // At this point, valueShape and outputShape have known ranks.
  if (raggedRank + valueShape.length !== outputShape.length) {
    throw new Error(`rt input.shape and shape=${shape} are incompatible: rt input.rank = ${raggedRank + valueShape.length}, but shape.rank = ${outputShape.length}`);
  }

  // valueShape[0] is the flattened ragged dimension and is skipped;
  // valueShape[i] (i >= 1) lines up with the trailing output dimensions.
  for (let i = 1; i < valueShape.length; ++i) {
    const valueDim = valueShape[i];
    // This is a *position* in outputShape, not the value stored there.
    const outputShapeDimIndex = outputShape.length - valueShape.length + i;
    const outputShapeDim = outputShape[outputShapeDimIndex];

    if (valueDim >= 0) {
      if (outputShapeDim >= 0) {
        if (outputShapeDim !== valueDim) {
          throw new Error(`rt input.shape and shape=${shape} are incompatible: rt input.shape[${i + raggedRank}] = ${valueDim} but shape[${i + raggedRank}] = ${outputShapeDim}`);
        }
      } else {
        // Unknown requested dimension: take it from the flat values.
        outputShape[outputShapeDimIndex] = valueDim;
      }
    }
  }
  return outputShape;
}
|
/**
 * Converts row-partition type names (e.g. 'ROW_SPLITS') into RowPartitionType
 * values.
 *
 * Conversion stops at the first unrecognized name, so the result is the
 * longest recognized prefix of the input.
 *
 * @param rowPartitionTypeStrings Iterable of row-partition type names.
 * @returns The converted RowPartitionType values, in input order.
 */
export function getRowPartitionTypesHelper(rowPartitionTypeStrings) {
  const stringToType = {
    'FIRST_DIM_SIZE': RowPartitionType.FIRST_DIM_SIZE,
    'VALUE_ROWIDS': RowPartitionType.VALUE_ROWIDS,
    'ROW_LENGTHS': RowPartitionType.ROW_LENGTHS,
    'ROW_SPLITS': RowPartitionType.ROW_SPLITS,
    'ROW_LIMITS': RowPartitionType.ROW_LIMITS,
    'ROW_STARTS': RowPartitionType.ROW_STARTS
  };

  const result = [];
  for (const typeStr of rowPartitionTypeStrings) {
    // Own-property check: a bare `in` also matches inherited names such as
    // 'toString', 'constructor' or '__proto__', which would push a function
    // (or Object.prototype) into the result instead of stopping.
    if (!Object.prototype.hasOwnProperty.call(stringToType, typeStr)) {
      break;
    }
    result.push(stringToType[typeStr]);
  }

  return result;
}
|
/**
 * Computes the ragged rank implied by a list of row-partition types.
 *
 * Every partition contributes one ragged dimension, except that a leading
 * FIRST_DIM_SIZE entry is not counted.
 *
 * @param rowPartitionTypes RowPartitionType values (for example, the output
 *     of getRowPartitionTypesHelper).
 * @returns The number of ragged dimensions; 0 for an empty list.
 */
export function getRaggedRank(rowPartitionTypes) {
  const partitionCount = rowPartitionTypes.length;
  if (partitionCount === 0) {
    return 0;
  }
  const leadsWithFirstDimSize =
      rowPartitionTypes[0] === RowPartitionType.FIRST_DIM_SIZE;
  return leadsWithFirstDimSize ? partitionCount - 1 : partitionCount;
}
|
/**
 * Checks that a default-value shape is compatible with the shape of the
 * ragged tensor's flat values.
 *
 * Nothing is checked when either shape is null. Otherwise the default value
 * must have strictly lower rank than the flat values. Each default dimension
 * `d` is compared with flat-values dimension `d + 1`. When both are known
 * (>= 0), the default dimension must be 1 or equal the flat-values dimension.
 *
 * @param defaultValueShape Shape of the default value, or null.
 * @param valueShape Shape of the flat values, or null.
 * @throws Error describing the first incompatibility found.
 */
export function validateDefaultValueShape(defaultValueShape, valueShape) {
  if (defaultValueShape == null || valueShape == null) {
    return;
  }

  const defaultNDims = defaultValueShape.length;
  const valuesNDims = valueShape.length;
  if (defaultNDims >= valuesNDims) {
    // NOTE(review): the trailing ")" in this message has no opening
    // parenthesis; kept as-is because callers/tests may match the text.
    throw new Error(`defaultValue.shape=${defaultValueShape} and ragged tensor flatValues.shape=${valueShape}, are incompatible: defaultValue.rank = ${defaultNDims} must be less than ragged tensor input flatValues.rank = ${valuesNDims})`);
  }

  const dimsToCompare = Math.min(defaultNDims, valuesNDims - 1);
  for (let d = 0; d < dimsToCompare; d++) {
    const defaultDim = defaultValueShape[d];
    const valueDim = valueShape[d + 1];
    const bothKnown = defaultDim >= 0 && valueDim >= 0;
    const broadcastable = defaultDim === 1 || defaultDim === valueDim;
    if (bothKnown && !broadcastable) {
      // Reported index is negative, i.e. counted from the end of the default
      // value's shape.
      const reportedIndex = d - defaultValueShape.length;
      throw new Error(`defaultValue.shape=${defaultValueShape}, and ragged tensor input flatValues.shape=${valueShape} are incompatible: defaultValue.shape[${reportedIndex}] = ${defaultDim} but ragged tensor input.flatValues.shape[${reportedIndex}] = ${valueDim}`);
    }
  }
}
|
//# sourceMappingURL=data:application/json;base64,{"version":3,"file":"ragged_to_dense_util.js","sourceRoot":"","sources":["../../../../../../tfjs-core/src/ops/ragged_to_dense_util.ts"],"names":[],"mappings":"AAAA;;;;;;;;;;;;;;;GAeG;AAEH,MAAM,CAAN,IAAY,gBAOX;AAPD,WAAY,gBAAgB;IAC1B,2EAAc,CAAA;IACd,uEAAY,CAAA;IACZ,qEAAW,CAAA;IACX,mEAAU,CAAA;IACV,mEAAU,CAAA;IACV,mEAAU,CAAA;AACZ,CAAC,EAPW,gBAAgB,KAAhB,gBAAgB,QAO3B;AAED,MAAM,UAAU,iCAAiC,CAC7C,UAAkB,EAAE,KAAe,EAAE,UAAoB;IAC3D,0DAA0D;IAC1D,iEAAiE;IACjE,+CAA+C;IAE/C,IAAI,WAAW,GAAa,IAAI,KAAK,EAAE,CAAC;IACxC,IAAI,UAAU,IAAI,IAAI,IAAI,KAAK,IAAI,IAAI,EAAE;QACvC,OAAO,WAAW,CAAC;KACpB;IAED,IAAI,KAAK,IAAI,IAAI,EAAE;QACjB,2CAA2C;QAC3C,OAAO,WAAW,CAAC,MAAM,GAAG,UAAU,GAAG,UAAU,CAAC,MAAM,EAAE;YAC1D,WAAW,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC,CAAC;SACtB;KACF;SAAM;QACL,WAAW,GAAG,KAAK,CAAC,KAAK,EAAE,CAAC;KAC7B;IACD,IAAI,UAAU,IAAI,IAAI,EAAE;QACtB,OAAO,WAAW,CAAC;KACpB;IACD,+DAA+D;IAC/D,IAAI,UAAU,GAAG,UAAU,CAAC,MAAM,KAAK,WAAW,CAAC,MAAM,EAAE;QACzD,MAAM,IAAI,KAAK,CACX,4BAA4B,KAAK,sCAC7B,UAAU;YACV,UAAU,CAAC,MAAM,sBAAsB,WAAW,CAAC,MAAM,EAAE,CAAC,CAAC;KACtE;IAED,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,UAAU,CAAC,MAAM,EAAE,EAAE,CAAC,EAAE;QAC1C,MAAM,QAAQ,GAAG,UAAU,CAAC,CAAC,CAAC,CAAC;QAC/B,MAAM,mBAAmB,GACrB,WAAW,CAAC,WAAW,CAAC,MAAM,GAAG,UAAU,CAAC,MAAM,GAAG,CAAC,CAAC,CAAC;QAC5D,MAAM,cAAc,GAAG,WAAW,CAAC,mBAAmB,CAAC,CAAC;QAExD,IAAI,QAAQ,IAAI,CAAC,EAAE;YACjB,IAAI,cAAc,IAAI,CAAC,EAAE;gBACvB,IAAI,cAAc,KAAK,QAAQ,EAAE;oBAC/B,MAAM,IAAI,KAAK,CAAC,4BACZ,KAAK,qCAAqC,CAAC,GAAG,UAAU,OACxD,QAAQ,cAAc,CAAC,GAAG,UAAU,OAAO,cAAc,EAAE,CAAC,CAAC;iBAClE;aACF;iBAAM;gBACL,WAAW,CAAC,mBAAmB,CAAC,GAAG,QAAQ,CAAC;aAC7C;SACF;KACF;IACD,OAAO,WAAW,CAAC;AACrB,CAAC;AAED,MAAM,UAAU,0BAA0B,CAAC,uBAAiC;IAC1E,MAAM,YAAY,GAAG;QACnB,gBAAgB,EAAE,gBAAgB,CAAC,cAAc;QACjD,cAAc,EAAE,gBAAgB,CAAC,YAAY;QAC7C,aAAa,EAAE,gBAAgB,CAAC,WAAW;QAC3C,YAAY,EAAE,gBAAgB,CAAC,UAAU;QACzC,YAAY,EAAE,gBAAgB,CAAC,UAAU;QACzC,YAAY,EAAE,gBAAgB,CAAC,UAAU;KAC1C,CAAC;IAEF,MAAM,MAAM,GAAuB,EAAE,CAAC;IACtC,KAAK,MAAM,OAAO,IAAI,uBAAuB,EAAE;QA
C7C,IAAI,OAAO,IAAI,YAAY,EAAE;YAC3B,MAAM,CAAC,IAAI,CAAC,YAAY,CAAC,OAAoC,CAAC,CAAC,CAAC;SACjE;aAAM;YACL,MAAM;SACP;KACF;IAED,OAAO,MAAM,CAAC;AAChB,CAAC;AAED,MAAM,UAAU,aAAa,CAAC,iBAAqC;IACjE,IAAI,iBAAiB,CAAC,MAAM,KAAK,CAAC,EAAE;QAClC,OAAO,CAAC,CAAC;KACV;IACD,IAAI,iBAAiB,CAAC,CAAC,CAAC,KAAK,gBAAgB,CAAC,cAAc,EAAE;QAC5D,OAAO,iBAAiB,CAAC,MAAM,GAAG,CAAC,CAAC;KACrC;IACD,OAAO,iBAAiB,CAAC,MAAM,CAAC;AAClC,CAAC;AAED,MAAM,UAAU,yBAAyB,CACrC,iBAA2B,EAAE,UAAoB;IACnD,IAAI,iBAAiB,IAAI,IAAI,IAAI,UAAU,IAAI,IAAI,EAAE;QACnD,OAAO;KACR;IAED,MAAM,YAAY,GAAG,iBAAiB,CAAC,MAAM,CAAC;IAC9C,MAAM,WAAW,GAAG,UAAU,CAAC,MAAM,CAAC;IACtC,IAAI,YAAY,IAAI,WAAW,EAAE;QAC/B,MAAM,IAAI,KAAK,CAAC,sBACZ,iBAAiB,uCACjB,UAAU,2CACV,YAAY,4DACZ,WAAW,GAAG,CAAC,CAAC;KACrB;IACD,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,IAAI,CAAC,GAAG,CAAC,YAAY,EAAE,WAAW,GAAG,CAAC,CAAC,EAAE,EAAE,CAAC,EAAE;QAChE,MAAM,UAAU,GAAG,iBAAiB,CAAC,CAAC,CAAC,CAAC;QACxC,MAAM,QAAQ,GAAG,UAAU,CAAC,CAAC,GAAG,CAAC,CAAC,CAAC;QACnC,IAAI,UAAU,IAAI,CAAC,IAAI,QAAQ,IAAI,CAAC,IAAI,UAAU,KAAK,CAAC;YACpD,UAAU,KAAK,QAAQ,EAAE;YAC3B,MAAM,IAAI,KAAK,CAAC,sBACZ,iBAAiB,8CACjB,UAAU,yCACV,CAAC,GAAG,iBAAiB,CAAC,MAAM,OAC5B,UAAU,6CACV,CAAC,GAAG,iBAAiB,CAAC,MAAM,OAAO,QAAQ,EAAE,CAAC,CAAC;SACpD;KACF;AACH,CAAC","sourcesContent":["/**\n * @license\n * Copyright 2022 Google LLC. 
All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nexport enum RowPartitionType {\n  FIRST_DIM_SIZE,\n  VALUE_ROWIDS,\n  ROW_LENGTHS,\n  ROW_SPLITS,\n  ROW_LIMITS,\n  ROW_STARTS\n}\n\nexport function combineRaggedTensorToTensorShapes(\n    raggedRank: number, shape: number[], valueShape: number[]) {\n  // Test for consistency of valueShape and shape specified.\n  // If shape is unspecified and valueShape is specified, then copy\n  // over the size from the valueShape dimension.\n\n  let outputShape: number[] = new Array();\n  if (valueShape == null && shape == null) {\n    return outputShape;\n  }\n\n  if (shape == null) {\n    // Here, value_shape must be of known size.\n    while (outputShape.length < raggedRank + valueShape.length) {\n      outputShape.push(-1);\n    }\n  } else {\n    outputShape = shape.slice();\n  }\n  if (valueShape == null) {\n    return outputShape;\n  }\n  // At this point, valueShape and output_shape have known ranks.\n  if (raggedRank + valueShape.length !== outputShape.length) {\n    throw new Error(\n        `rt input.shape and shape=${shape} are incompatible: rt input.rank = ${\n            raggedRank +\n            valueShape.length}, but shape.rank = ${outputShape.length}`);\n  }\n\n  for (let i = 1; i < valueShape.length; ++i) {\n    const valueDim = valueShape[i];\n    const outputShapeDimIndex =\n        
outputShape[outputShape.length - valueShape.length + i];\n    const outputShapeDim = outputShape[outputShapeDimIndex];\n\n    if (valueDim >= 0) {\n      if (outputShapeDim >= 0) {\n        if (outputShapeDim !== valueDim) {\n          throw new Error(`rt input.shape and shape=${\n              shape} are incompatible: rt input.shape[${i + raggedRank}] = ${\n              valueDim} but shape[${i + raggedRank}] = ${outputShapeDim}`);\n        }\n      } else {\n        outputShape[outputShapeDimIndex] = valueDim;\n      }\n    }\n  }\n  return outputShape;\n}\n\nexport function getRowPartitionTypesHelper(rowPartitionTypeStrings: string[]) {\n  const stringToType = {\n    'FIRST_DIM_SIZE': RowPartitionType.FIRST_DIM_SIZE,\n    'VALUE_ROWIDS': RowPartitionType.VALUE_ROWIDS,\n    'ROW_LENGTHS': RowPartitionType.ROW_LENGTHS,\n    'ROW_SPLITS': RowPartitionType.ROW_SPLITS,\n    'ROW_LIMITS': RowPartitionType.ROW_LIMITS,\n    'ROW_STARTS': RowPartitionType.ROW_STARTS\n  };\n\n  const result: RowPartitionType[] = [];\n  for (const typeStr of rowPartitionTypeStrings) {\n    if (typeStr in stringToType) {\n      result.push(stringToType[typeStr as keyof typeof stringToType]);\n    } else {\n      break;\n    }\n  }\n\n  return result;\n}\n\nexport function getRaggedRank(rowPartitionTypes: RowPartitionType[]) {\n  if (rowPartitionTypes.length === 0) {\n    return 0;\n  }\n  if (rowPartitionTypes[0] === RowPartitionType.FIRST_DIM_SIZE) {\n    return rowPartitionTypes.length - 1;\n  }\n  return rowPartitionTypes.length;\n}\n\nexport function validateDefaultValueShape(\n    defaultValueShape: number[], valueShape: number[]) {\n  if (defaultValueShape == null || valueShape == null) {\n    return;\n  }\n\n  const defaultNDims = defaultValueShape.length;\n  const valuesNDims = valueShape.length;\n  if (defaultNDims >= valuesNDims) {\n    throw new Error(`defaultValue.shape=${\n        defaultValueShape} and ragged tensor flatValues.shape=${\n        valueShape}, are incompatible: 
defaultValue.rank = ${\n        defaultNDims} must be less than ragged tensor input flatValues.rank = ${\n        valuesNDims})`);\n  }\n  for (let i = 0; i < Math.min(defaultNDims, valuesNDims - 1); ++i) {\n    const defaultDim = defaultValueShape[i];\n    const valueDim = valueShape[i + 1];\n    if (defaultDim >= 0 && valueDim >= 0 && defaultDim !== 1 &&\n        defaultDim !== valueDim) {\n      throw new Error(`defaultValue.shape=${\n          defaultValueShape}, and ragged tensor input flatValues.shape=${\n          valueShape} are incompatible: defaultValue.shape[${\n          i - defaultValueShape.length}] = ${\n          defaultDim} but ragged tensor input.flatValues.shape[${\n          i - defaultValueShape.length}] = ${valueDim}`);\n    }\n  }\n}\n"]}
|