gx
chenyc
2025-02-12 ea42ff3ebee1eeb3fb29423aa848a249441db81c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
/**
 * @license
 * Copyright 2022 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
/**
 * Kinds of row-partition tensors that can describe a ragged dimension.
 *
 * This is the compiled form of a numeric TypeScript enum. Each name maps to
 * its ordinal (FIRST_DIM_SIZE = 0 through ROW_STARTS = 5), and each ordinal
 * maps back to its name.
 */
export var RowPartitionType;
(function (partitionTypes) {
    const orderedNames = [
        'FIRST_DIM_SIZE',
        'VALUE_ROWIDS',
        'ROW_LENGTHS',
        'ROW_SPLITS',
        'ROW_LIMITS',
        'ROW_STARTS'
    ];
    orderedNames.forEach((name, ordinal) => {
        // Forward mapping (name -> ordinal) first, then reverse (ordinal -> name).
        partitionTypes[name] = ordinal;
        partitionTypes[ordinal] = name;
    });
})(RowPartitionType || (RowPartitionType = {}));
export function combineRaggedTensorToTensorShapes(raggedRank, shape, valueShape) {
    // Test for consistency of valueShape and shape specified.
    // If shape is unspecified and valueShape is specified, then copy
    // over the size from the valueShape dimension.
    let outputShape = new Array();
    if (valueShape == null && shape == null) {
        return outputShape;
    }
    if (shape == null) {
        // Here, value_shape must be of known size.
        while (outputShape.length < raggedRank + valueShape.length) {
            outputShape.push(-1);
        }
    }
    else {
        outputShape = shape.slice();
    }
    if (valueShape == null) {
        return outputShape;
    }
    // At this point, valueShape and output_shape have known ranks.
    if (raggedRank + valueShape.length !== outputShape.length) {
        throw new Error(`rt input.shape and shape=${shape} are incompatible: rt input.rank = ${raggedRank +
            valueShape.length}, but shape.rank = ${outputShape.length}`);
    }
    for (let i = 1; i < valueShape.length; ++i) {
        const valueDim = valueShape[i];
        const outputShapeDimIndex = outputShape[outputShape.length - valueShape.length + i];
        const outputShapeDim = outputShape[outputShapeDimIndex];
        if (valueDim >= 0) {
            if (outputShapeDim >= 0) {
                if (outputShapeDim !== valueDim) {
                    throw new Error(`rt input.shape and shape=${shape} are incompatible: rt input.shape[${i + raggedRank}] = ${valueDim} but shape[${i + raggedRank}] = ${outputShapeDim}`);
                }
            }
            else {
                outputShape[outputShapeDimIndex] = valueDim;
            }
        }
    }
    return outputShape;
}
/**
 * Converts row-partition type names into `RowPartitionType` values.
 *
 * Parsing stops at the first unrecognized name. Only the longest recognized
 * prefix is returned, and no error is raised.
 *
 * @param rowPartitionTypeStrings Names such as 'FIRST_DIM_SIZE' or
 *     'ROW_SPLITS'.
 * @returns The corresponding `RowPartitionType` values, in input order.
 */
export function getRowPartitionTypesHelper(rowPartitionTypeStrings) {
    const stringToType = {
        'FIRST_DIM_SIZE': RowPartitionType.FIRST_DIM_SIZE,
        'VALUE_ROWIDS': RowPartitionType.VALUE_ROWIDS,
        'ROW_LENGTHS': RowPartitionType.ROW_LENGTHS,
        'ROW_SPLITS': RowPartitionType.ROW_SPLITS,
        'ROW_LIMITS': RowPartitionType.ROW_LIMITS,
        'ROW_STARTS': RowPartitionType.ROW_STARTS
    };
    const result = [];
    for (const typeStr of rowPartitionTypeStrings) {
        // Own-property check: the `in` operator would also accept inherited
        // keys such as 'toString' or 'constructor' and push a non-enum value.
        if (!Object.prototype.hasOwnProperty.call(stringToType, typeStr)) {
            break;
        }
        result.push(stringToType[typeStr]);
    }
    return result;
}
/**
 * Returns the number of ragged dimensions described by `rowPartitionTypes`.
 *
 * A leading FIRST_DIM_SIZE entry does not count toward the ragged rank.
 * Every other entry does.
 *
 * @param rowPartitionTypes Partition types in order, as produced by
 *     getRowPartitionTypesHelper.
 * @returns The ragged rank (0 for an empty list).
 */
export function getRaggedRank(rowPartitionTypes) {
    const partitionCount = rowPartitionTypes.length;
    const leadsWithFirstDimSize = partitionCount > 0 &&
        rowPartitionTypes[0] === RowPartitionType.FIRST_DIM_SIZE;
    return leadsWithFirstDimSize ? partitionCount - 1 : partitionCount;
}
/**
 * Checks that a default value can be used to pad a ragged tensor's flat
 * values.
 *
 * Nothing is checked if either shape is `null`/`undefined`. Otherwise the
 * default value must have a lower rank than the flat values. Each
 * defaultValueShape[i] is then compared with valueShape[i + 1]: dimensions
 * that are negative (unknown), or a default dimension of 1, are accepted.
 * Any other difference is an error.
 *
 * @param defaultValueShape Shape of the default value.
 * @param valueShape Shape of the ragged tensor's flat values.
 * @throws Error when the ranks or the compared dimensions are incompatible.
 */
export function validateDefaultValueShape(defaultValueShape, valueShape) {
    if (defaultValueShape == null || valueShape == null) {
        return;
    }
    const defaultNDims = defaultValueShape.length;
    const valuesNDims = valueShape.length;
    if (valuesNDims <= defaultNDims) {
        throw new Error(`defaultValue.shape=${defaultValueShape} and ragged tensor flatValues.shape=${valueShape}, are incompatible: defaultValue.rank = ${defaultNDims} must be less than ragged tensor input flatValues.rank = ${valuesNDims})`);
    }
    const comparedDims = Math.min(defaultNDims, valuesNDims - 1);
    for (let dim = 0; dim < comparedDims; dim++) {
        const wanted = defaultValueShape[dim];
        const actual = valueShape[dim + 1];
        const bothKnown = wanted >= 0 && actual >= 0;
        const broadcastable = wanted === 1 || wanted === actual;
        if (bothKnown && !broadcastable) {
            // The message reports the position as a negative offset from the
            // end of defaultValueShape.
            const negativeOffset = dim - defaultValueShape.length;
            throw new Error(`defaultValue.shape=${defaultValueShape}, and ragged tensor input flatValues.shape=${valueShape} are incompatible: defaultValue.shape[${negativeOffset}] = ${wanted} but ragged tensor input.flatValues.shape[${negativeOffset}] = ${actual}`);
        }
    }
}
//# sourceMappingURL=data:application/json;base64,{"version":3,"file":"ragged_to_dense_util.js","sourceRoot":"","sources":["../../../../../../tfjs-core/src/ops/ragged_to_dense_util.ts"],"names":[],"mappings":"AAAA;;;;;;;;;;;;;;;GAeG;AAEH,MAAM,CAAN,IAAY,gBAOX;AAPD,WAAY,gBAAgB;IAC1B,2EAAc,CAAA;IACd,uEAAY,CAAA;IACZ,qEAAW,CAAA;IACX,mEAAU,CAAA;IACV,mEAAU,CAAA;IACV,mEAAU,CAAA;AACZ,CAAC,EAPW,gBAAgB,KAAhB,gBAAgB,QAO3B;AAED,MAAM,UAAU,iCAAiC,CAC7C,UAAkB,EAAE,KAAe,EAAE,UAAoB;IAC3D,0DAA0D;IAC1D,iEAAiE;IACjE,+CAA+C;IAE/C,IAAI,WAAW,GAAa,IAAI,KAAK,EAAE,CAAC;IACxC,IAAI,UAAU,IAAI,IAAI,IAAI,KAAK,IAAI,IAAI,EAAE;QACvC,OAAO,WAAW,CAAC;KACpB;IAED,IAAI,KAAK,IAAI,IAAI,EAAE;QACjB,2CAA2C;QAC3C,OAAO,WAAW,CAAC,MAAM,GAAG,UAAU,GAAG,UAAU,CAAC,MAAM,EAAE;YAC1D,WAAW,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC,CAAC;SACtB;KACF;SAAM;QACL,WAAW,GAAG,KAAK,CAAC,KAAK,EAAE,CAAC;KAC7B;IACD,IAAI,UAAU,IAAI,IAAI,EAAE;QACtB,OAAO,WAAW,CAAC;KACpB;IACD,+DAA+D;IAC/D,IAAI,UAAU,GAAG,UAAU,CAAC,MAAM,KAAK,WAAW,CAAC,MAAM,EAAE;QACzD,MAAM,IAAI,KAAK,CACX,4BAA4B,KAAK,sCAC7B,UAAU;YACV,UAAU,CAAC,MAAM,sBAAsB,WAAW,CAAC,MAAM,EAAE,CAAC,CAAC;KACtE;IAED,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,UAAU,CAAC,MAAM,EAAE,EAAE,CAAC,EAAE;QAC1C,MAAM,QAAQ,GAAG,UAAU,CAAC,CAAC,CAAC,CAAC;QAC/B,MAAM,mBAAmB,GACrB,WAAW,CAAC,WAAW,CAAC,MAAM,GAAG,UAAU,CAAC,MAAM,GAAG,CAAC,CAAC,CAAC;QAC5D,MAAM,cAAc,GAAG,WAAW,CAAC,mBAAmB,CAAC,CAAC;QAExD,IAAI,QAAQ,IAAI,CAAC,EAAE;YACjB,IAAI,cAAc,IAAI,CAAC,EAAE;gBACvB,IAAI,cAAc,KAAK,QAAQ,EAAE;oBAC/B,MAAM,IAAI,KAAK,CAAC,4BACZ,KAAK,qCAAqC,CAAC,GAAG,UAAU,OACxD,QAAQ,cAAc,CAAC,GAAG,UAAU,OAAO,cAAc,EAAE,CAAC,CAAC;iBAClE;aACF;iBAAM;gBACL,WAAW,CAAC,mBAAmB,CAAC,GAAG,QAAQ,CAAC;aAC7C;SACF;KACF;IACD,OAAO,WAAW,CAAC;AACrB,CAAC;AAED,MAAM,UAAU,0BAA0B,CAAC,uBAAiC;IAC1E,MAAM,YAAY,GAAG;QACnB,gBAAgB,EAAE,gBAAgB,CAAC,cAAc;QACjD,cAAc,EAAE,gBAAgB,CAAC,YAAY;QAC7C,aAAa,EAAE,gBAAgB,CAAC,WAAW;QAC3C,YAAY,EAAE,gBAAgB,CAAC,UAAU;QACzC,YAAY,EAAE,gBAAgB,CAAC,UAAU;QACzC,YAAY,EAAE,gBAAgB,CAAC,UAAU;KAC1C,CAAC;IAEF,MAAM,MAAM,GAAuB,EAAE,CAAC;IACtC,KAAK,MAAM,OAAO,IAAI,uBAAuB,EAAE;QA
C7C,IAAI,OAAO,IAAI,YAAY,EAAE;YAC3B,MAAM,CAAC,IAAI,CAAC,YAAY,CAAC,OAAoC,CAAC,CAAC,CAAC;SACjE;aAAM;YACL,MAAM;SACP;KACF;IAED,OAAO,MAAM,CAAC;AAChB,CAAC;AAED,MAAM,UAAU,aAAa,CAAC,iBAAqC;IACjE,IAAI,iBAAiB,CAAC,MAAM,KAAK,CAAC,EAAE;QAClC,OAAO,CAAC,CAAC;KACV;IACD,IAAI,iBAAiB,CAAC,CAAC,CAAC,KAAK,gBAAgB,CAAC,cAAc,EAAE;QAC5D,OAAO,iBAAiB,CAAC,MAAM,GAAG,CAAC,CAAC;KACrC;IACD,OAAO,iBAAiB,CAAC,MAAM,CAAC;AAClC,CAAC;AAED,MAAM,UAAU,yBAAyB,CACrC,iBAA2B,EAAE,UAAoB;IACnD,IAAI,iBAAiB,IAAI,IAAI,IAAI,UAAU,IAAI,IAAI,EAAE;QACnD,OAAO;KACR;IAED,MAAM,YAAY,GAAG,iBAAiB,CAAC,MAAM,CAAC;IAC9C,MAAM,WAAW,GAAG,UAAU,CAAC,MAAM,CAAC;IACtC,IAAI,YAAY,IAAI,WAAW,EAAE;QAC/B,MAAM,IAAI,KAAK,CAAC,sBACZ,iBAAiB,uCACjB,UAAU,2CACV,YAAY,4DACZ,WAAW,GAAG,CAAC,CAAC;KACrB;IACD,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,IAAI,CAAC,GAAG,CAAC,YAAY,EAAE,WAAW,GAAG,CAAC,CAAC,EAAE,EAAE,CAAC,EAAE;QAChE,MAAM,UAAU,GAAG,iBAAiB,CAAC,CAAC,CAAC,CAAC;QACxC,MAAM,QAAQ,GAAG,UAAU,CAAC,CAAC,GAAG,CAAC,CAAC,CAAC;QACnC,IAAI,UAAU,IAAI,CAAC,IAAI,QAAQ,IAAI,CAAC,IAAI,UAAU,KAAK,CAAC;YACpD,UAAU,KAAK,QAAQ,EAAE;YAC3B,MAAM,IAAI,KAAK,CAAC,sBACZ,iBAAiB,8CACjB,UAAU,yCACV,CAAC,GAAG,iBAAiB,CAAC,MAAM,OAC5B,UAAU,6CACV,CAAC,GAAG,iBAAiB,CAAC,MAAM,OAAO,QAAQ,EAAE,CAAC,CAAC;SACpD;KACF;AACH,CAAC","sourcesContent":["/**\n * @license\n * Copyright 2022 Google LLC. 
All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nexport enum RowPartitionType {\n  FIRST_DIM_SIZE,\n  VALUE_ROWIDS,\n  ROW_LENGTHS,\n  ROW_SPLITS,\n  ROW_LIMITS,\n  ROW_STARTS\n}\n\nexport function combineRaggedTensorToTensorShapes(\n    raggedRank: number, shape: number[], valueShape: number[]) {\n  // Test for consistency of valueShape and shape specified.\n  // If shape is unspecified and valueShape is specified, then copy\n  // over the size from the valueShape dimension.\n\n  let outputShape: number[] = new Array();\n  if (valueShape == null && shape == null) {\n    return outputShape;\n  }\n\n  if (shape == null) {\n    // Here, value_shape must be of known size.\n    while (outputShape.length < raggedRank + valueShape.length) {\n      outputShape.push(-1);\n    }\n  } else {\n    outputShape = shape.slice();\n  }\n  if (valueShape == null) {\n    return outputShape;\n  }\n  // At this point, valueShape and output_shape have known ranks.\n  if (raggedRank + valueShape.length !== outputShape.length) {\n    throw new Error(\n        `rt input.shape and shape=${shape} are incompatible: rt input.rank = ${\n            raggedRank +\n            valueShape.length}, but shape.rank = ${outputShape.length}`);\n  }\n\n  for (let i = 1; i < valueShape.length; ++i) {\n    const valueDim = valueShape[i];\n    const outputShapeDimIndex =\n        
outputShape[outputShape.length - valueShape.length + i];\n    const outputShapeDim = outputShape[outputShapeDimIndex];\n\n    if (valueDim >= 0) {\n      if (outputShapeDim >= 0) {\n        if (outputShapeDim !== valueDim) {\n          throw new Error(`rt input.shape and shape=${\n              shape} are incompatible: rt input.shape[${i + raggedRank}] = ${\n              valueDim} but shape[${i + raggedRank}] = ${outputShapeDim}`);\n        }\n      } else {\n        outputShape[outputShapeDimIndex] = valueDim;\n      }\n    }\n  }\n  return outputShape;\n}\n\nexport function getRowPartitionTypesHelper(rowPartitionTypeStrings: string[]) {\n  const stringToType = {\n    'FIRST_DIM_SIZE': RowPartitionType.FIRST_DIM_SIZE,\n    'VALUE_ROWIDS': RowPartitionType.VALUE_ROWIDS,\n    'ROW_LENGTHS': RowPartitionType.ROW_LENGTHS,\n    'ROW_SPLITS': RowPartitionType.ROW_SPLITS,\n    'ROW_LIMITS': RowPartitionType.ROW_LIMITS,\n    'ROW_STARTS': RowPartitionType.ROW_STARTS\n  };\n\n  const result: RowPartitionType[] = [];\n  for (const typeStr of rowPartitionTypeStrings) {\n    if (typeStr in stringToType) {\n      result.push(stringToType[typeStr as keyof typeof stringToType]);\n    } else {\n      break;\n    }\n  }\n\n  return result;\n}\n\nexport function getRaggedRank(rowPartitionTypes: RowPartitionType[]) {\n  if (rowPartitionTypes.length === 0) {\n    return 0;\n  }\n  if (rowPartitionTypes[0] === RowPartitionType.FIRST_DIM_SIZE) {\n    return rowPartitionTypes.length - 1;\n  }\n  return rowPartitionTypes.length;\n}\n\nexport function validateDefaultValueShape(\n    defaultValueShape: number[], valueShape: number[]) {\n  if (defaultValueShape == null || valueShape == null) {\n    return;\n  }\n\n  const defaultNDims = defaultValueShape.length;\n  const valuesNDims = valueShape.length;\n  if (defaultNDims >= valuesNDims) {\n    throw new Error(`defaultValue.shape=${\n        defaultValueShape} and ragged tensor flatValues.shape=${\n        valueShape}, are incompatible: 
defaultValue.rank = ${\n        defaultNDims} must be less than ragged tensor input flatValues.rank = ${\n        valuesNDims})`);\n  }\n  for (let i = 0; i < Math.min(defaultNDims, valuesNDims - 1); ++i) {\n    const defaultDim = defaultValueShape[i];\n    const valueDim = valueShape[i + 1];\n    if (defaultDim >= 0 && valueDim >= 0 && defaultDim !== 1 &&\n        defaultDim !== valueDim) {\n      throw new Error(`defaultValue.shape=${\n          defaultValueShape}, and ragged tensor input flatValues.shape=${\n          valueShape} are incompatible: defaultValue.shape[${\n          i - defaultValueShape.length}] = ${\n          defaultDim} but ragged tensor input.flatValues.shape[${\n          i - defaultValueShape.length}] = ${valueDim}`);\n    }\n  }\n}\n"]}