gx
chenyc
2025-02-12 ea42ff3ebee1eeb3fb29423aa848a249441db81c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
/**
 * @license
 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
import { buffer, slice_util, StridedSlice, util } from '@tensorflow/tfjs-core';
import { stridedSliceImplCPU } from '../kernel_utils/shared';
import { StridedSliceProgram } from '../strided_slice_gpu';
import { reshape } from './Reshape';
import { slice } from './Slice';
/**
 * WebGL kernel implementation of StridedSlice.
 *
 * The begin/end/strides attributes and the five masks are normalized by
 * `slice_util.sliceInfo`. One of four strategies then produces an
 * intermediate tensor:
 *   1. identity slice      -> plain reshape of `x`;
 *   2. simple/dim-0 slice  -> the regular Slice kernel, then a reshape;
 *   3. CPU-eligible input  -> `stridedSliceImplCPU` on synchronously read data;
 *   4. otherwise           -> `StridedSliceProgram` on the GPU.
 * The intermediate is always reshaped to `finalShape` and then disposed.
 *
 * @param args `{inputs: {x}, backend, attrs}`, where `attrs` holds begin, end,
 *     strides, beginMask, endMask, ellipsisMask, newAxisMask and shrinkAxisMask.
 * @returns A new TensorInfo with shape `finalShape`.
 */
export function stridedSlice(args) {
    const { inputs: { x }, backend, attrs } = args;
    const {
        finalShapeSparse,
        finalShape,
        isIdentity,
        sliceDim0,
        isSimpleSlice,
        begin: normBegin,
        end: normEnd,
        strides: normStrides
    } = slice_util.sliceInfo(
        x.shape, attrs.begin, attrs.end, attrs.strides, attrs.beginMask,
        attrs.endMask, attrs.ellipsisMask, attrs.newAxisMask,
        attrs.shrinkAxisMask);

    // Reshapes `t` to the op's final output shape via the Reshape kernel.
    const reshapeToFinal = (t) =>
        reshape({ inputs: { x: t }, backend, attrs: { shape: finalShape } });

    // Computes the sliced values. The GPU path yields a `finalShapeSparse`
    // tensor; every other path already yields `finalShape`.
    const computeSliced = () => {
        if (isIdentity) {
            // The slice selects everything; only the shape changes.
            return reshapeToFinal(x);
        }
        if (sliceDim0 || isSimpleSlice) {
            // sliceInfo says this is expressible as a plain begin/size slice,
            // so delegate to the Slice kernel.
            util.assert(x.shape.length >= 1, () => `Input must have rank at least 1, got: ${x.shape.length}`);
            // NOTE(review): begin/end are not clamped here. Tolerating
            // begin > end (a zero-size slice) presumably relies on
            // sliceInfo/computeOutShape — confirm.
            const size = slice_util.computeOutShape(normBegin, normEnd, normStrides);
            const contiguous = slice({ inputs: { x }, backend, attrs: { begin: normBegin, size } });
            const reshaped = reshapeToFinal(contiguous);
            backend.disposeIntermediateTensorInfo(contiguous);
            return reshaped;
        }
        if (backend.shouldExecuteOnCPU([x])) {
            // Small-input path: read the data back and slice on the CPU.
            // tslint:disable-next-line: no-unnecessary-type-assertion
            const hostValues = backend.readSync(x.dataId);
            // tslint:disable-next-line: no-unnecessary-type-assertion
            const hostBuffer = buffer(x.shape, x.dtype, hostValues);
            const cpuResult = stridedSliceImplCPU(finalShapeSparse, hostBuffer, normStrides, normBegin);
            return backend.makeTensorInfo(finalShape, x.dtype, cpuResult.values);
        }
        // General strided case: run the dedicated shader program.
        const program = new StridedSliceProgram(normBegin, normStrides, finalShapeSparse);
        return backend.runWebGLProgram(program, [x], x.dtype);
    };

    const intermediate = computeSliced();
    // Normalize every path to finalShape. This is shape-preserving for all
    // paths except the GPU one, and releases the intermediate afterwards.
    const output = reshapeToFinal(intermediate);
    backend.disposeIntermediateTensorInfo(intermediate);
    return output;
}
/**
 * Kernel config binding the `StridedSlice` kernel name to the
 * `stridedSlice` implementation for the 'webgl' backend.
 */
export const stridedSliceConfig = { kernelName: StridedSlice, backendName: 'webgl', kernelFunc: stridedSlice };
//# sourceMappingURL=data:application/json;base64,{"version":3,"file":"StridedSlice.js","sourceRoot":"","sources":["../../../../../../tfjs-backend-webgl/src/kernels/StridedSlice.ts"],"names":[],"mappings":"AAAA;;;;;;;;;;;;;;;GAeG;AAEH,OAAO,EAAC,MAAM,EAAkC,UAAU,EAAE,YAAY,EAA+E,IAAI,EAAC,MAAM,uBAAuB,CAAC;AAG1L,OAAO,EAAC,mBAAmB,EAAC,MAAM,wBAAwB,CAAC;AAC3D,OAAO,EAAC,mBAAmB,EAAC,MAAM,sBAAsB,CAAC;AAEzD,OAAO,EAAC,OAAO,EAAC,MAAM,WAAW,CAAC;AAClC,OAAO,EAAC,KAAK,EAAC,MAAM,SAAS,CAAC;AAE9B,MAAM,UAAU,YAAY,CAAC,IAI5B;IACC,MAAM,EAAC,MAAM,EAAE,OAAO,EAAE,KAAK,EAAC,GAAG,IAAI,CAAC;IACtC,MAAM,EAAC,CAAC,EAAC,GAAG,MAAM,CAAC;IACnB,MAAM,EACJ,KAAK,EACL,GAAG,EACH,OAAO,EACP,SAAS,EACT,OAAO,EACP,YAAY,EACZ,WAAW,EACX,cAAc,EACf,GAAG,KAAK,CAAC;IAEV,MAAM,EACJ,gBAAgB,EAChB,UAAU,EACV,UAAU,EACV,SAAS,EACT,aAAa,EACb,KAAK,EAAE,MAAM,EACb,GAAG,EAAE,IAAI,EACT,OAAO,EAAE,QAAQ,EAClB,GACG,UAAU,CAAC,SAAS,CAChB,CAAC,CAAC,KAAK,EAAE,KAAK,EAAE,GAAG,EAAE,OAAO,EAAE,SAAS,EAAE,OAAO,EAAE,YAAY,EAC9D,WAAW,EAAE,cAAc,CAAC,CAAC;IAErC,IAAI,MAAM,CAAC;IAEX,IAAI,UAAU,EAAE;QACd,iDAAiD;QACjD,MAAM,GAAG,OAAO,CAAC,EAAC,MAAM,EAAE,EAAC,CAAC,EAAC,EAAE,OAAO,EAAE,KAAK,EAAE,EAAC,KAAK,EAAE,UAAU,EAAC,EAAC,CAAC,CAAC;KACtE;SAAM,IAAI,SAAS,IAAI,aAAa,EAAE;QACrC,qEAAqE;QACrE,IAAI,CAAC,MAAM,CACP,CAAC,CAAC,KAAK,CAAC,MAAM,IAAI,CAAC,EACnB,GAAG,EAAE,CAAC,yCAAyC,CAAC,CAAC,KAAK,CAAC,MAAM,EAAE,CAAC,CAAC;QAErE,MAAM,IAAI,GAAG,UAAU,CAAC,eAAe,CAAC,MAAM,EAAE,IAAI,EAAE,QAAQ,CAAC,CAAC;QAChE,wEAAwE;QACxE,MAAM,MAAM,GAAG,KAAK,CAAC,EAAC,MAAM,EAAE,EAAC,CAAC,EAAC,EAAE,OAAO,EAAE,KAAK,EAAE,EAAC,KAAK,EAAE,MAAM,EAAE,IAAI,EAAC,EAAC,CAAC,CAAC;QAC3E,MAAM;YACF,OAAO,CAAC,EAAC,MAAM,EAAE,EAAC,CAAC,EAAE,MAAM,EAAC,EAAE,OAAO,EAAE,KAAK,EAAE,EAAC,KAAK,EAAE,UAAU,EAAC,EAAC,CAAC,CAAC;QACxE,OAAO,CAAC,6BAA6B,CAAC,MAAM,CAAC,CAAC;KAC/C;SAAM;QACL,MAAM,kBAAkB,GAAG,OAAO,CAAC,kBAAkB,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;QAC3D,IAAI,kBAAkB,EAAE;YACtB,0DAA0D;YAC1D,MAAM,MAAM,GAAG,OAAO,CAAC,QAAQ,CAAC,CAAC,CAAC,MAAM,CAAe,CAAC;YACxD,0DAA0D;YAC1D,MAAM,IAAI,GAAG,MAAM,CAAC,CAAC,CAAC,KAAK,EAAE,CAAC,CAAC,KAAK,EAAE,MA
AM,CAAuB,CAAC;YACpE,MAAM,YAAY,GACd,mBAAmB,CAAC,gBAAgB,EAAE,IAAI,EAAE,QAAQ,EAAE,MAAM,CAAC,CAAC;YAClE,MAAM,GAAG,OAAO,CAAC,cAAc,CAAC,UAAU,EAAE,CAAC,CAAC,KAAK,EAAE,YAAY,CAAC,MAAM,CAAC,CAAC;SAC3E;aAAM;YACL,MAAM,OAAO,GACT,IAAI,mBAAmB,CAAC,MAAM,EAAE,QAAQ,EAAE,gBAAgB,CAAC,CAAC;YAChE,MAAM,GAAG,OAAO,CAAC,eAAe,CAAC,OAAO,EAAE,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,KAAK,CAAC,CAAC;SACzD;KACF;IAED,MAAM,cAAc,GAChB,OAAO,CAAC,EAAC,MAAM,EAAE,EAAC,CAAC,EAAE,MAAM,EAAC,EAAE,OAAO,EAAE,KAAK,EAAE,EAAC,KAAK,EAAE,UAAU,EAAC,EAAC,CAAC,CAAC;IAExE,OAAO,CAAC,6BAA6B,CAAC,MAAM,CAAC,CAAC;IAE9C,OAAO,cAAc,CAAC;AACxB,CAAC;AAED,MAAM,CAAC,MAAM,kBAAkB,GAAiB;IAC9C,UAAU,EAAE,YAAY;IACxB,WAAW,EAAE,OAAO;IACpB,UAAU,EAAE,YAAqC;CAClD,CAAC","sourcesContent":["/**\n * @license\n * Copyright 2020 Google LLC. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {buffer, KernelConfig, KernelFunc, Rank, slice_util, StridedSlice, StridedSliceAttrs, StridedSliceInputs, TensorBuffer, TensorInfo, TypedArray, util} from '@tensorflow/tfjs-core';\n\nimport {MathBackendWebGL} from '../backend_webgl';\nimport {stridedSliceImplCPU} from '../kernel_utils/shared';\nimport {StridedSliceProgram} from '../strided_slice_gpu';\n\nimport {reshape} from './Reshape';\nimport {slice} from './Slice';\n\nexport function stridedSlice(args: {\n  inputs: StridedSliceInputs,\n  backend: MathBackendWebGL,\n  attrs: 
StridedSliceAttrs\n}): TensorInfo {\n  const {inputs, backend, attrs} = args;\n  const {x} = inputs;\n  const {\n    begin,\n    end,\n    strides,\n    beginMask,\n    endMask,\n    ellipsisMask,\n    newAxisMask,\n    shrinkAxisMask\n  } = attrs;\n\n  const {\n    finalShapeSparse,\n    finalShape,\n    isIdentity,\n    sliceDim0,\n    isSimpleSlice,\n    begin: $begin,\n    end: $end,\n    strides: $strides\n  } =\n      slice_util.sliceInfo(\n          x.shape, begin, end, strides, beginMask, endMask, ellipsisMask,\n          newAxisMask, shrinkAxisMask);\n\n  let result;\n\n  if (isIdentity) {\n    // Optimization #1, slice is a no-op plus reshape\n    result = reshape({inputs: {x}, backend, attrs: {shape: finalShape}});\n  } else if (sliceDim0 || isSimpleSlice) {\n    // Optimization #2, slice is memory contiguous (only occurs in dim 0)\n    util.assert(\n        x.shape.length >= 1,\n        () => `Input must have rank at least 1, got: ${x.shape.length}`);\n\n    const size = slice_util.computeOutShape($begin, $end, $strides);\n    // To tolerate begin[0] > end[0] (a 0-output slice), we min(begin, end).\n    const sliced = slice({inputs: {x}, backend, attrs: {begin: $begin, size}});\n    result =\n        reshape({inputs: {x: sliced}, backend, attrs: {shape: finalShape}});\n    backend.disposeIntermediateTensorInfo(sliced);\n  } else {\n    const shouldExecuteOnCPU = backend.shouldExecuteOnCPU([x]);\n    if (shouldExecuteOnCPU) {\n      // tslint:disable-next-line: no-unnecessary-type-assertion\n      const values = backend.readSync(x.dataId) as TypedArray;\n      // tslint:disable-next-line: no-unnecessary-type-assertion\n      const xBuf = buffer(x.shape, x.dtype, values) as TensorBuffer<Rank>;\n      const resultValues =\n          stridedSliceImplCPU(finalShapeSparse, xBuf, $strides, $begin);\n      result = backend.makeTensorInfo(finalShape, x.dtype, resultValues.values);\n    } else {\n      const program =\n          new StridedSliceProgram($begin, 
$strides, finalShapeSparse);\n      result = backend.runWebGLProgram(program, [x], x.dtype);\n    }\n  }\n\n  const resultReshaped =\n      reshape({inputs: {x: result}, backend, attrs: {shape: finalShape}});\n\n  backend.disposeIntermediateTensorInfo(result);\n\n  return resultReshaped;\n}\n\nexport const stridedSliceConfig: KernelConfig = {\n  kernelName: StridedSlice,\n  backendName: 'webgl',\n  kernelFunc: stridedSlice as unknown as KernelFunc\n};\n"]}