gx
chenyc
2025-06-12 7b72ac13a83764a662159d4a49b7fffb90476ecb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
/**
 * @license
 * Copyright 2021 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
import { backend_util, util } from '@tensorflow/tfjs-core';
export function sparseFillEmptyRowsImpl(indices, indicesShape, indicesDType, values, valuesDType, denseShape, defaultValue) {
    const indicesCount = indicesShape[0];
    const denseRows = denseShape[0];
    const emptyRowIndicator = new Array(denseRows);
    const reverseIndexMap = new Array(indicesCount);
    const rank = indicesShape[1];
    if (denseRows === 0) {
        if (indicesCount !== 0) {
            throw new Error(backend_util.getSparseFillEmptyRowsIndicesDenseShapeMismatch(indicesCount));
        }
        const outputIndices = util.getArrayFromDType(indicesDType, 0);
        const outputValues = util.getArrayFromDType(valuesDType, 0);
        return [
            outputIndices, [0, rank], outputValues, emptyRowIndicator, reverseIndexMap
        ];
    }
    let rowsAreOrdered = true;
    let lastIndicesRow = 0;
    const csrOffset = new Array(denseRows).fill(0);
    for (let i = 0; i < indicesCount; ++i) {
        // indices is a 2d tensor with shape of [N, rank]
        const row = indices[i * rank];
        if (row < 0) {
            throw new Error(backend_util.getSparseFillEmptyRowsNegativeIndexErrorMessage(i, row));
        }
        if (row >= denseRows) {
            throw new Error(backend_util.getSparseFillEmptyRowsOutOfRangeIndexErrorMessage(i, row, denseRows));
        }
        ++csrOffset[row];
        rowsAreOrdered = rowsAreOrdered && (row >= lastIndicesRow);
        lastIndicesRow = row;
    }
    let allRowsFull = true;
    for (let row = 0; row < denseRows; ++row) {
        // csrOffset here describes the number of elements in this dense row
        const rowEmpty = (csrOffset[row] === 0);
        emptyRowIndicator[row] = rowEmpty;
        allRowsFull = allRowsFull && !rowEmpty;
        // In filled version, each row has at least one element.
        csrOffset[row] = Math.max(csrOffset[row], 1);
        // Update csrOffset to represent the number of elements up to and
        // including denseRows + 1:
        //  csrOffset[0] == #{elements of row 0}
        //  csrOffset[1] == #{elements of row 1} + #{elements of row 0}
        //  ..
        //  csrOffset[i] == starting index for elements in row i + 1.
        if (row > 0) {
            csrOffset[row] += csrOffset[row - 1];
        }
    }
    if (allRowsFull && rowsAreOrdered) {
        const outputIndices = indices;
        const outputValues = values;
        for (let i = 0; i < indicesCount; ++i) {
            reverseIndexMap[i] = i;
        }
        return [
            outputIndices, [indicesCount, rank], outputValues, emptyRowIndicator,
            reverseIndexMap
        ];
    }
    else {
        const fullIndicesCount = csrOffset[denseRows - 1];
        const outputIndices = util.getArrayFromDType(indicesDType, fullIndicesCount * rank);
        const outputValues = util.getArrayFromDType(valuesDType, fullIndicesCount);
        const filledCount = new Array(denseRows).fill(0);
        // Fill in values for rows that are not missing
        for (let i = 0; i < indicesCount; ++i) {
            // indices is a 2d tensor with shape of [N, rank]
            const row = indices[i * rank];
            const offset = filledCount[row];
            const outputI = ((row === 0) ? 0 : csrOffset[row - 1]) + offset;
            filledCount[row]++; // Increment the filled count for this row.
            for (let j = 0; j < rank; ++j) {
                // indices and outputIndices are 2d tensors with shape of [N, rank]
                outputIndices[outputI * rank + j] = indices[i * rank + j];
            }
            outputValues[outputI] = values[i];
            // We'll need this reverse index map to backprop correctly.
            reverseIndexMap[i] = outputI;
        }
        // Fill in values for rows that are missing
        for (let row = 0; row < denseRows; ++row) {
            const rowCount = filledCount[row];
            if (rowCount === 0) { // We haven't filled this row
                const startingIndex = (row === 0) ? 0 : csrOffset[row - 1];
                // Remaining index values were set to zero already.
                // Just need to set the row index in the right location.
                // outputIndices is a 2d tensor with shape of [N, rank]
                outputIndices[startingIndex * rank + 0] = row;
                for (let col = 1; col < rank; ++col) {
                    outputIndices[startingIndex * rank + col] = 0;
                }
                outputValues[startingIndex] = defaultValue;
            }
        }
        return [
            outputIndices, [fullIndicesCount, rank], outputValues, emptyRowIndicator,
            reverseIndexMap
        ];
    }
}
//# sourceMappingURL=data:application/json;base64,{"version":3,"file":"SparseFillEmptyRows_impl.js","sourceRoot":"","sources":["../../../../../../tfjs-backend-cpu/src/kernels/SparseFillEmptyRows_impl.ts"],"names":[],"mappings":"AAAA;;;;;;;;;;;;;;;GAeG;AAEH,OAAO,EAAC,YAAY,EAAwB,IAAI,EAAC,MAAM,uBAAuB,CAAC;AAE/E,MAAM,UAAU,uBAAuB,CACnC,OAAmB,EAAE,YAAsB,EAAE,YAAsB,EACnE,MAAkB,EAAE,WAAqB,EAAE,UAAsB,EACjE,YAAoB;IAEtB,MAAM,YAAY,GAAG,YAAY,CAAC,CAAC,CAAC,CAAC;IACrC,MAAM,SAAS,GAAG,UAAU,CAAC,CAAC,CAAC,CAAC;IAEhC,MAAM,iBAAiB,GAAc,IAAI,KAAK,CAAC,SAAS,CAAC,CAAC;IAC1D,MAAM,eAAe,GAAa,IAAI,KAAK,CAAC,YAAY,CAAC,CAAC;IAE1D,MAAM,IAAI,GAAG,YAAY,CAAC,CAAC,CAAC,CAAC;IAE7B,IAAI,SAAS,KAAK,CAAC,EAAE;QACnB,IAAI,YAAY,KAAK,CAAC,EAAE;YACtB,MAAM,IAAI,KAAK,CACX,YAAY,CAAC,+CAA+C,CACxD,YAAY,CAAC,CAAC,CAAC;SACxB;QACD,MAAM,aAAa,GAAG,IAAI,CAAC,iBAAiB,CAAC,YAAY,EAAE,CAAC,CAAe,CAAC;QAC5E,MAAM,YAAY,GAAG,IAAI,CAAC,iBAAiB,CAAC,WAAW,EAAE,CAAC,CAAe,CAAC;QAC1E,OAAO;YACL,aAAa,EAAE,CAAC,CAAC,EAAE,IAAI,CAAC,EAAE,YAAY,EAAE,iBAAiB,EAAE,eAAe;SAC3E,CAAC;KACH;IAED,IAAI,cAAc,GAAG,IAAI,CAAC;IAC1B,IAAI,cAAc,GAAG,CAAC,CAAC;IACvB,MAAM,SAAS,GAAa,IAAI,KAAK,CAAC,SAAS,CAAC,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC;IAEzD,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,YAAY,EAAE,EAAE,CAAC,EAAE;QACrC,iDAAiD;QACjD,MAAM,GAAG,GAAG,OAAO,CAAC,CAAC,GAAG,IAAI,CAAC,CAAC;QAC9B,IAAI,GAAG,GAAG,CAAC,EAAE;YACX,MAAM,IAAI,KAAK,CACX,YAAY,CAAC,+CAA+C,CAAC,CAAC,EAAE,GAAG,CAAC,CAAC,CAAC;SAC3E;QACD,IAAI,GAAG,IAAI,SAAS,EAAE;YACpB,MAAM,IAAI,KAAK,CACX,YAAY,CAAC,iDAAiD,CAC1D,CAAC,EAAE,GAAG,EAAE,SAAS,CAAC,CAAC,CAAC;SAC7B;QACD,EAAE,SAAS,CAAC,GAAG,CAAC,CAAC;QACjB,cAAc,GAAG,cAAc,IAAI,CAAC,GAAG,IAAI,cAAc,CAAC,CAAC;QAC3D,cAAc,GAAG,GAAG,CAAC;KACtB;IAED,IAAI,WAAW,GAAG,IAAI,CAAC;IACvB,KAAK,IAAI,GAAG,GAAG,CAAC,EAAE,GAAG,GAAG,SAAS,EAAE,EAAE,GAAG,EAAE;QACxC,oEAAoE;QACpE,MAAM,QAAQ,GAAG,CAAC,SAAS,CAAC,GAAG,CAAC,KAAK,CAAC,CAAC,CAAC;QACxC,iBAAiB,CAAC,GAAG,CAAC,GAAG,QAAQ,CAAC;QAClC,WAAW,GAAG,WAAW,IAAI,CAAC,QAAQ,CAAC;QACvC,wDAAwD;QACxD,SAAS,CAAC,GAAG,CAAC,GAAG,IAAI,CAAC,GAAG,CAAC,SAAS,CAAC,GAAG,CAAC,EAAE,CAAC,CAAC,CAAC;QAC7C,iEAAiE;QACjE,2BAA2B;QAC3B,wCAAwC;QACxC,+DAA+D;QAC/D,MAAM;QACN,6DAA6D;QAC7D,IAAI,GAAG,GAAG,CAAC,EAAE;YACX,SAAS,CAAC,GAAG,CAAC,IAAI,SAAS,CAAC,GAAG,GAAG,CAAC,CAAC,CAAC;SACtC;KACF;IAED,IAAI,WAAW,IAAI,cAAc,EAAE;QACjC,MAAM,aAAa,GAAe,OAAO,CAAC;QAC1C,MAAM,YAAY,GAAe,MAAM,CAAC;QACxC,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,YAAY,EAAE,EAAE,CAAC,EAAE;YACrC,eAAe,CAAC,CAAC,CAAC,GAAG,CAAC,CAAC;SACxB;QACD,OAAO;YACL,aAAa,EAAE,CAAC,YAAY,EAAE,IAAI,CAAC,EAAE,YAAY,EAAE,iBAAiB;YACpE,eAAe;SAChB,CAAC;KACH;SAAM;QACL,MAAM,gBAAgB,GAAG,SAAS,CAAC,SAAS,GAAG,CAAC,CAAC,CAAC;QAClD,MAAM,aAAa,GACf,IAAI,CAAC,iBAAiB,CAAC,YAAY,EAAE,gBAAgB,GAAG,IAAI,CAClD,CAAC;QACf,MAAM,YAAY,GACd,IAAI,CAAC,iBAAiB,CAAC,WAAW,EAAE,gBAAgB,CAAe,CAAC;QACxE,MAAM,WAAW,GAAa,IAAI,KAAK,CAAC,SAAS,CAAC,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC;QAE3D,+CAA+C;QAC/C,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,YAAY,EAAE,EAAE,CAAC,EAAE;YACrC,iDAAiD;YACjD,MAAM,GAAG,GAAG,OAAO,CAAC,CAAC,GAAG,IAAI,CAAC,CAAC;YAC9B,MAAM,MAAM,GAAG,WAAW,CAAC,GAAG,CAAC,CAAC;YAChC,MAAM,OAAO,GAAG,CAAC,CAAC,GAAG,KAAK,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,SAAS,CAAC,GAAG,GAAG,CAAC,CAAC,CAAC,GAAG,MAAM,CAAC;YAChE,WAAW,CAAC,GAAG,CAAC,EAAE,CAAC,CAAE,2CAA2C;YAChE,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,IAAI,EAAE,EAAE,CAAC,EAAE;gBAC7B,mEAAmE;gBACnE,aAAa,CAAC,OAAO,GAAG,IAAI,GAAG,CAAC,CAAC,GAAG,OAAO,CAAC,CAAC,GAAG,IAAI,GAAG,CAAC,CAAC,CAAC;aAC3D;YACD,YAAY,CAAC,OAAO,CAAC,GAAG,MAAM,CAAC,CAAC,CAAC,CAAC;YAClC,2DAA2D;YAC3D,eAAe,CAAC,CAAC,CAAC,GAAG,OAAO,CAAC;SAC9B;QAED,2CAA2C;QAC3C,KAAK,IAAI,GAAG,GAAG,CAAC,EAAE,GAAG,GAAG,SAAS,EAAE,EAAE,GAAG,EAAE;YACxC,MAAM,QAAQ,GAAG,WAAW,CAAC,GAAG,CAAC,CAAC;YAClC,IAAI,QAAQ,KAAK,CAAC,EAAE,EAAG,6BAA6B;gBAClD,MAAM,aAAa,GAAG,CAAC,GAAG,KAAK,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,SAAS,CAAC,GAAG,GAAG,CAAC,CAAC,CAAC;gBAC3D,mDAAmD;gBACnD,wDAAwD;gBACxD,uDAAuD;gBACvD,aAAa,CAAC,aAAa,GAAG,IAAI,GAAG,CAAC,CAAC,GAAG,GAAG,CAAC;gBAC9C,KAAK,IAAI,GAAG,GAAG,CAAC,EAAE,GAAG,GAAG,IAAI,EAAE,EAAE,GAAG,EAAE;oBACnC,aAAa,CAAC,aAAa,GAAG,IAAI,GAAG,GAAG,CAAC,GAAG,CAAC,CAAC;iBAC/C;gBACD,YAAY,CAAC,aAAa,CAAC,GAAG,YAAY,CAAC;aAC5C;SACF;QACD,OAAO;YACL,aAAa,EAAE,CAAC,gBAAgB,EAAE,IAAI,CAAC,EAAE,YAAY,EAAE,iBAAiB;YACxE,eAAe;SAChB,CAAC;KACH;AACH,CAAC","sourcesContent":["/**\n * @license\n * Copyright 2021 Google LLC. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {backend_util, DataType, TypedArray, util} from '@tensorflow/tfjs-core';\n\nexport function sparseFillEmptyRowsImpl(\n    indices: TypedArray, indicesShape: number[], indicesDType: DataType,\n    values: TypedArray, valuesDType: DataType, denseShape: TypedArray,\n    defaultValue: number):\n    [TypedArray, number[], TypedArray, boolean[], number[]] {\n  const indicesCount = indicesShape[0];\n  const denseRows = denseShape[0];\n\n  const emptyRowIndicator: boolean[] = new Array(denseRows);\n  const reverseIndexMap: number[] = new Array(indicesCount);\n\n  const rank = indicesShape[1];\n\n  if (denseRows === 0) {\n    if (indicesCount !== 0) {\n      throw new Error(\n          backend_util.getSparseFillEmptyRowsIndicesDenseShapeMismatch(\n              indicesCount));\n    }\n    const outputIndices = util.getArrayFromDType(indicesDType, 0) as TypedArray;\n    const outputValues = util.getArrayFromDType(valuesDType, 0) as TypedArray;\n    return [\n      outputIndices, [0, rank], outputValues, emptyRowIndicator, reverseIndexMap\n    ];\n  }\n\n  let rowsAreOrdered = true;\n  let lastIndicesRow = 0;\n  const csrOffset: number[] = new Array(denseRows).fill(0);\n\n  for (let i = 0; i < indicesCount; ++i) {\n    // indices is a 2d tensor with shape of [N, rank]\n    const row = indices[i * rank];\n    if (row < 0) {\n      throw new Error(\n          backend_util.getSparseFillEmptyRowsNegativeIndexErrorMessage(i, row));\n    }\n    if (row >= denseRows) {\n      throw new Error(\n          backend_util.getSparseFillEmptyRowsOutOfRangeIndexErrorMessage(\n              i, row, denseRows));\n    }\n    ++csrOffset[row];\n    rowsAreOrdered = rowsAreOrdered && (row >= lastIndicesRow);\n    lastIndicesRow = row;\n  }\n\n  let allRowsFull = true;\n  for (let row = 0; row < denseRows; ++row) {\n    // csrOffset here describes the number of elements in this dense row\n    const rowEmpty = (csrOffset[row] === 0);\n    emptyRowIndicator[row] = rowEmpty;\n    allRowsFull = allRowsFull && !rowEmpty;\n    // In filled version, each row has at least one element.\n    csrOffset[row] = Math.max(csrOffset[row], 1);\n    // Update csrOffset to represent the number of elements up to and\n    // including denseRows + 1:\n    //  csrOffset[0] == #{elements of row 0}\n    //  csrOffset[1] == #{elements of row 1} + #{elements of row 0}\n    //  ..\n    //  csrOffset[i] == starting index for elements in row i + 1.\n    if (row > 0) {\n      csrOffset[row] += csrOffset[row - 1];\n    }\n  }\n\n  if (allRowsFull && rowsAreOrdered) {\n    const outputIndices: TypedArray = indices;\n    const outputValues: TypedArray = values;\n    for (let i = 0; i < indicesCount; ++i) {\n      reverseIndexMap[i] = i;\n    }\n    return [\n      outputIndices, [indicesCount, rank], outputValues, emptyRowIndicator,\n      reverseIndexMap\n    ];\n  } else {\n    const fullIndicesCount = csrOffset[denseRows - 1];\n    const outputIndices =\n        util.getArrayFromDType(indicesDType, fullIndicesCount * rank) as\n        TypedArray;\n    const outputValues =\n        util.getArrayFromDType(valuesDType, fullIndicesCount) as TypedArray;\n    const filledCount: number[] = new Array(denseRows).fill(0);\n\n    // Fill in values for rows that are not missing\n    for (let i = 0; i < indicesCount; ++i) {\n      // indices is a 2d tensor with shape of [N, rank]\n      const row = indices[i * rank];\n      const offset = filledCount[row];\n      const outputI = ((row === 0) ? 0 : csrOffset[row - 1]) + offset;\n      filledCount[row]++;  // Increment the filled count for this row.\n      for (let j = 0; j < rank; ++j) {\n        // indices and outputIndices are 2d tensors with shape of [N, rank]\n        outputIndices[outputI * rank + j] = indices[i * rank + j];\n      }\n      outputValues[outputI] = values[i];\n      // We'll need this reverse index map to backprop correctly.\n      reverseIndexMap[i] = outputI;\n    }\n\n    // Fill in values for rows that are missing\n    for (let row = 0; row < denseRows; ++row) {\n      const rowCount = filledCount[row];\n      if (rowCount === 0) {  // We haven't filled this row\n        const startingIndex = (row === 0) ? 0 : csrOffset[row - 1];\n        // Remaining index values were set to zero already.\n        // Just need to set the row index in the right location.\n        // outputIndices is a 2d tensor with shape of [N, rank]\n        outputIndices[startingIndex * rank + 0] = row;\n        for (let col = 1; col < rank; ++col) {\n          outputIndices[startingIndex * rank + col] = 0;\n        }\n        outputValues[startingIndex] = defaultValue;\n      }\n    }\n    return [\n      outputIndices, [fullIndicesCount, rank], outputValues, emptyRowIndicator,\n      reverseIndexMap\n    ];\n  }\n}\n"]}