gx
chenyc
2025-02-12 ea42ff3ebee1eeb3fb29423aa848a249441db81c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
/**
 * @license
 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
import { buffer, slice_util, StridedSlice, util } from '@tensorflow/tfjs-core';
import { stridedSliceImplCPU } from '../kernel_utils/shared';
import { StridedSliceProgram } from '../strided_slice_gpu';
import { reshape } from './Reshape';
import { slice } from './Slice';
/**
 * WebGL kernel function for the StridedSlice op.
 *
 * The slice description is normalised through `slice_util.sliceInfo`, and one
 * of four strategies is then chosen from the flags it reports:
 *   1. identity slice        -> reshape of the input only;
 *   2. dim-0 / simple slice  -> delegate to the `slice` kernel, then reshape;
 *   3. CPU forwarding        -> `stridedSliceImplCPU` on synchronously read data;
 *   4. otherwise             -> run a `StridedSliceProgram` on the GPU.
 * Whatever strategy ran, its output is reshaped to `finalShape` and the
 * intermediate tensor info is released before returning.
 *
 * @param {object} args
 * @param {object} args.inputs Kernel inputs; `inputs.x` is the tensor info to
 *     slice (its `shape`, `dtype` and `dataId` are read).
 * @param {object} args.backend Backend used to run programs, read data,
 *     create tensor infos and dispose intermediates.
 * @param {object} args.attrs `begin`, `end`, `strides`, `beginMask`,
 *     `endMask`, `ellipsisMask`, `newAxisMask` and `shrinkAxisMask`.
 * @returns {object} The tensor info produced by the final reshape to
 *     `finalShape`.
 */
export function stridedSlice(args) {
    const { inputs, backend, attrs } = args;
    const input = inputs.x;
    const info = slice_util.sliceInfo(input.shape, attrs.begin, attrs.end, attrs.strides, attrs.beginMask, attrs.endMask, attrs.ellipsisMask, attrs.newAxisMask, attrs.shrinkAxisMask);
    const { finalShapeSparse, finalShape } = info;
    // Every path ends with a reshape to the final shape, so share the call.
    const toFinalShape = (tensorInfo) => reshape({
        inputs: { x: tensorInfo },
        backend,
        attrs: { shape: finalShape }
    });
    let intermediate;
    if (info.isIdentity) {
        // Strategy 1: nothing is cut away, so only the shape has to change.
        intermediate = toFinalShape(input);
    }
    else if (info.sliceDim0 || info.isSimpleSlice) {
        // Strategy 2: sliceInfo flagged this as a dim-0 or simple slice, so
        // the plain slice kernel can do the work.
        util.assert(input.shape.length >= 1, () => `Input must have rank at least 1, got: ${input.shape.length}`);
        const size = slice_util.computeOutShape(info.begin, info.end, info.strides);
        const sliced = slice({
            inputs: { x: input },
            backend,
            attrs: { begin: info.begin, size }
        });
        intermediate = toFinalShape(sliced);
        backend.disposeIntermediateTensorInfo(sliced);
    }
    else if (backend.shouldExecuteOnCPU([input])) {
        // Strategy 3: read the values synchronously and slice them with the
        // shared CPU implementation.
        const values = backend.readSync(input.dataId);
        const inputBuffer = buffer(input.shape, input.dtype, values);
        const slicedBuffer = stridedSliceImplCPU(finalShapeSparse, inputBuffer, info.strides, info.begin);
        intermediate = backend.makeTensorInfo(finalShape, input.dtype, slicedBuffer.values);
    }
    else {
        // Strategy 4: general strided slice computed by a WebGL program.
        const program = new StridedSliceProgram(info.begin, info.strides, finalShapeSparse);
        intermediate = backend.runWebGLProgram(program, [input], input.dtype);
    }
    const output = toFinalShape(intermediate);
    backend.disposeIntermediateTensorInfo(intermediate);
    return output;
}
/**
 * Kernel configuration that pairs the `StridedSlice` kernel name with the
 * 'webgl' backend name and the `stridedSlice` kernel function defined above.
 * NOTE(review): presumably consumed by the backend's kernel registration code;
 * that call site is not visible in this file.
 */
export const stridedSliceConfig = {
    kernelName: StridedSlice,
    backendName: 'webgl',
    kernelFunc: stridedSlice
};
//# sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoiU3RyaWRlZFNsaWNlLmpzIiwic291cmNlUm9vdCI6IiIsInNvdXJjZXMiOlsiLi4vLi4vLi4vLi4vLi4vLi4vdGZqcy1iYWNrZW5kLXdlYmdsL3NyYy9rZXJuZWxzL1N0cmlkZWRTbGljZS50cyJdLCJuYW1lcyI6W10sIm1hcHBpbmdzIjoiQUFBQTs7Ozs7Ozs7Ozs7Ozs7O0dBZUc7QUFFSCxPQUFPLEVBQUMsTUFBTSxFQUFrQyxVQUFVLEVBQUUsWUFBWSxFQUErRSxJQUFJLEVBQUMsTUFBTSx1QkFBdUIsQ0FBQztBQUcxTCxPQUFPLEVBQUMsbUJBQW1CLEVBQUMsTUFBTSx3QkFBd0IsQ0FBQztBQUMzRCxPQUFPLEVBQUMsbUJBQW1CLEVBQUMsTUFBTSxzQkFBc0IsQ0FBQztBQUV6RCxPQUFPLEVBQUMsT0FBTyxFQUFDLE1BQU0sV0FBVyxDQUFDO0FBQ2xDLE9BQU8sRUFBQyxLQUFLLEVBQUMsTUFBTSxTQUFTLENBQUM7QUFFOUIsTUFBTSxVQUFVLFlBQVksQ0FBQyxJQUk1QjtJQUNDLE1BQU0sRUFBQyxNQUFNLEVBQUUsT0FBTyxFQUFFLEtBQUssRUFBQyxHQUFHLElBQUksQ0FBQztJQUN0QyxNQUFNLEVBQUMsQ0FBQyxFQUFDLEdBQUcsTUFBTSxDQUFDO0lBQ25CLE1BQU0sRUFDSixLQUFLLEVBQ0wsR0FBRyxFQUNILE9BQU8sRUFDUCxTQUFTLEVBQ1QsT0FBTyxFQUNQLFlBQVksRUFDWixXQUFXLEVBQ1gsY0FBYyxFQUNmLEdBQUcsS0FBSyxDQUFDO0lBRVYsTUFBTSxFQUNKLGdCQUFnQixFQUNoQixVQUFVLEVBQ1YsVUFBVSxFQUNWLFNBQVMsRUFDVCxhQUFhLEVBQ2IsS0FBSyxFQUFFLE1BQU0sRUFDYixHQUFHLEVBQUUsSUFBSSxFQUNULE9BQU8sRUFBRSxRQUFRLEVBQ2xCLEdBQ0csVUFBVSxDQUFDLFNBQVMsQ0FDaEIsQ0FBQyxDQUFDLEtBQUssRUFBRSxLQUFLLEVBQUUsR0FBRyxFQUFFLE9BQU8sRUFBRSxTQUFTLEVBQUUsT0FBTyxFQUFFLFlBQVksRUFDOUQsV0FBVyxFQUFFLGNBQWMsQ0FBQyxDQUFDO0lBRXJDLElBQUksTUFBTSxDQUFDO0lBRVgsSUFBSSxVQUFVLEVBQUU7UUFDZCxpREFBaUQ7UUFDakQsTUFBTSxHQUFHLE9BQU8sQ0FBQyxFQUFDLE1BQU0sRUFBRSxFQUFDLENBQUMsRUFBQyxFQUFFLE9BQU8sRUFBRSxLQUFLLEVBQUUsRUFBQyxLQUFLLEVBQUUsVUFBVSxFQUFDLEVBQUMsQ0FBQyxDQUFDO0tBQ3RFO1NBQU0sSUFBSSxTQUFTLElBQUksYUFBYSxFQUFFO1FBQ3JDLHFFQUFxRTtRQUNyRSxJQUFJLENBQUMsTUFBTSxDQUNQLENBQUMsQ0FBQyxLQUFLLENBQUMsTUFBTSxJQUFJLENBQUMsRUFDbkIsR0FBRyxFQUFFLENBQUMseUNBQXlDLENBQUMsQ0FBQyxLQUFLLENBQUMsTUFBTSxFQUFFLENBQUMsQ0FBQztRQUVyRSxNQUFNLElBQUksR0FBRyxVQUFVLENBQUMsZUFBZSxDQUFDLE1BQU0sRUFBRSxJQUFJLEVBQUUsUUFBUSxDQUFDLENBQUM7UUFDaEUsd0VBQXdFO1FBQ3hFLE1BQU0sTUFBTSxHQUFHLEtBQUssQ0FBQyxFQUFDLE1BQU0sRUFBRSxFQUFDLENBQUMsRUFBQyxFQUFFLE9BQU8sRUFBRSxLQUFLLEVBQUUsRUFBQy
xLQUFLLEVBQUUsTUFBTSxFQUFFLElBQUksRUFBQyxFQUFDLENBQUMsQ0FBQztRQUMzRSxNQUFNO1lBQ0YsT0FBTyxDQUFDLEVBQUMsTUFBTSxFQUFFLEVBQUMsQ0FBQyxFQUFFLE1BQU0sRUFBQyxFQUFFLE9BQU8sRUFBRSxLQUFLLEVBQUUsRUFBQyxLQUFLLEVBQUUsVUFBVSxFQUFDLEVBQUMsQ0FBQyxDQUFDO1FBQ3hFLE9BQU8sQ0FBQyw2QkFBNkIsQ0FBQyxNQUFNLENBQUMsQ0FBQztLQUMvQztTQUFNO1FBQ0wsTUFBTSxrQkFBa0IsR0FBRyxPQUFPLENBQUMsa0JBQWtCLENBQUMsQ0FBQyxDQUFDLENBQUMsQ0FBQyxDQUFDO1FBQzNELElBQUksa0JBQWtCLEVBQUU7WUFDdEIsMERBQTBEO1lBQzFELE1BQU0sTUFBTSxHQUFHLE9BQU8sQ0FBQyxRQUFRLENBQUMsQ0FBQyxDQUFDLE1BQU0sQ0FBZSxDQUFDO1lBQ3hELDBEQUEwRDtZQUMxRCxNQUFNLElBQUksR0FBRyxNQUFNLENBQUMsQ0FBQyxDQUFDLEtBQUssRUFBRSxDQUFDLENBQUMsS0FBSyxFQUFFLE1BQU0sQ0FBdUIsQ0FBQztZQUNwRSxNQUFNLFlBQVksR0FDZCxtQkFBbUIsQ0FBQyxnQkFBZ0IsRUFBRSxJQUFJLEVBQUUsUUFBUSxFQUFFLE1BQU0sQ0FBQyxDQUFDO1lBQ2xFLE1BQU0sR0FBRyxPQUFPLENBQUMsY0FBYyxDQUFDLFVBQVUsRUFBRSxDQUFDLENBQUMsS0FBSyxFQUFFLFlBQVksQ0FBQyxNQUFNLENBQUMsQ0FBQztTQUMzRTthQUFNO1lBQ0wsTUFBTSxPQUFPLEdBQ1QsSUFBSSxtQkFBbUIsQ0FBQyxNQUFNLEVBQUUsUUFBUSxFQUFFLGdCQUFnQixDQUFDLENBQUM7WUFDaEUsTUFBTSxHQUFHLE9BQU8sQ0FBQyxlQUFlLENBQUMsT0FBTyxFQUFFLENBQUMsQ0FBQyxDQUFDLEVBQUUsQ0FBQyxDQUFDLEtBQUssQ0FBQyxDQUFDO1NBQ3pEO0tBQ0Y7SUFFRCxNQUFNLGNBQWMsR0FDaEIsT0FBTyxDQUFDLEVBQUMsTUFBTSxFQUFFLEVBQUMsQ0FBQyxFQUFFLE1BQU0sRUFBQyxFQUFFLE9BQU8sRUFBRSxLQUFLLEVBQUUsRUFBQyxLQUFLLEVBQUUsVUFBVSxFQUFDLEVBQUMsQ0FBQyxDQUFDO0lBRXhFLE9BQU8sQ0FBQyw2QkFBNkIsQ0FBQyxNQUFNLENBQUMsQ0FBQztJQUU5QyxPQUFPLGNBQWMsQ0FBQztBQUN4QixDQUFDO0FBRUQsTUFBTSxDQUFDLE1BQU0sa0JBQWtCLEdBQWlCO0lBQzlDLFVBQVUsRUFBRSxZQUFZO0lBQ3hCLFdBQVcsRUFBRSxPQUFPO0lBQ3BCLFVBQVUsRUFBRSxZQUFxQztDQUNsRCxDQUFDIiwic291cmNlc0NvbnRlbnQiOlsiLyoqXG4gKiBAbGljZW5zZVxuICogQ29weXJpZ2h0IDIwMjAgR29vZ2xlIExMQy4gQWxsIFJpZ2h0cyBSZXNlcnZlZC5cbiAqIExpY2Vuc2VkIHVuZGVyIHRoZSBBcGFjaGUgTGljZW5zZSwgVmVyc2lvbiAyLjAgKHRoZSBcIkxpY2Vuc2VcIik7XG4gKiB5b3UgbWF5IG5vdCB1c2UgdGhpcyBmaWxlIGV4Y2VwdCBpbiBjb21wbGlhbmNlIHdpdGggdGhlIExpY2Vuc2UuXG4gKiBZb3UgbWF5IG9idGFpbiBhIGNvcHkgb2YgdGhlIExpY2Vuc2UgYXRcbiAqXG4gKiBodHRwOi8vd3d3LmFwYWNoZS5vcmcvbGljZW5zZXMvTElDRU5TRS0yLj
BcbiAqXG4gKiBVbmxlc3MgcmVxdWlyZWQgYnkgYXBwbGljYWJsZSBsYXcgb3IgYWdyZWVkIHRvIGluIHdyaXRpbmcsIHNvZnR3YXJlXG4gKiBkaXN0cmlidXRlZCB1bmRlciB0aGUgTGljZW5zZSBpcyBkaXN0cmlidXRlZCBvbiBhbiBcIkFTIElTXCIgQkFTSVMsXG4gKiBXSVRIT1VUIFdBUlJBTlRJRVMgT1IgQ09ORElUSU9OUyBPRiBBTlkgS0lORCwgZWl0aGVyIGV4cHJlc3Mgb3IgaW1wbGllZC5cbiAqIFNlZSB0aGUgTGljZW5zZSBmb3IgdGhlIHNwZWNpZmljIGxhbmd1YWdlIGdvdmVybmluZyBwZXJtaXNzaW9ucyBhbmRcbiAqIGxpbWl0YXRpb25zIHVuZGVyIHRoZSBMaWNlbnNlLlxuICogPT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT1cbiAqL1xuXG5pbXBvcnQge2J1ZmZlciwgS2VybmVsQ29uZmlnLCBLZXJuZWxGdW5jLCBSYW5rLCBzbGljZV91dGlsLCBTdHJpZGVkU2xpY2UsIFN0cmlkZWRTbGljZUF0dHJzLCBTdHJpZGVkU2xpY2VJbnB1dHMsIFRlbnNvckJ1ZmZlciwgVGVuc29ySW5mbywgVHlwZWRBcnJheSwgdXRpbH0gZnJvbSAnQHRlbnNvcmZsb3cvdGZqcy1jb3JlJztcblxuaW1wb3J0IHtNYXRoQmFja2VuZFdlYkdMfSBmcm9tICcuLi9iYWNrZW5kX3dlYmdsJztcbmltcG9ydCB7c3RyaWRlZFNsaWNlSW1wbENQVX0gZnJvbSAnLi4va2VybmVsX3V0aWxzL3NoYXJlZCc7XG5pbXBvcnQge1N0cmlkZWRTbGljZVByb2dyYW19IGZyb20gJy4uL3N0cmlkZWRfc2xpY2VfZ3B1JztcblxuaW1wb3J0IHtyZXNoYXBlfSBmcm9tICcuL1Jlc2hhcGUnO1xuaW1wb3J0IHtzbGljZX0gZnJvbSAnLi9TbGljZSc7XG5cbmV4cG9ydCBmdW5jdGlvbiBzdHJpZGVkU2xpY2UoYXJnczoge1xuICBpbnB1dHM6IFN0cmlkZWRTbGljZUlucHV0cyxcbiAgYmFja2VuZDogTWF0aEJhY2tlbmRXZWJHTCxcbiAgYXR0cnM6IFN0cmlkZWRTbGljZUF0dHJzXG59KTogVGVuc29ySW5mbyB7XG4gIGNvbnN0IHtpbnB1dHMsIGJhY2tlbmQsIGF0dHJzfSA9IGFyZ3M7XG4gIGNvbnN0IHt4fSA9IGlucHV0cztcbiAgY29uc3Qge1xuICAgIGJlZ2luLFxuICAgIGVuZCxcbiAgICBzdHJpZGVzLFxuICAgIGJlZ2luTWFzayxcbiAgICBlbmRNYXNrLFxuICAgIGVsbGlwc2lzTWFzayxcbiAgICBuZXdBeGlzTWFzayxcbiAgICBzaHJpbmtBeGlzTWFza1xuICB9ID0gYXR0cnM7XG5cbiAgY29uc3Qge1xuICAgIGZpbmFsU2hhcGVTcGFyc2UsXG4gICAgZmluYWxTaGFwZSxcbiAgICBpc0lkZW50aXR5LFxuICAgIHNsaWNlRGltMCxcbiAgICBpc1NpbXBsZVNsaWNlLFxuICAgIGJlZ2luOiAkYmVnaW4sXG4gICAgZW5kOiAkZW5kLFxuICAgIHN0cmlkZXM6ICRzdHJpZGVzXG4gIH0gPVxuICAgICAgc2xpY2VfdXRpbC5zbGljZUluZm8oXG4gICAgICAgICAgeC5zaGFwZSwgYmVnaW4sIGVuZCwgc3RyaWRlcywgYmVnaW5NYXNrLCBlbmRNYXNrLCBlbGxpcHNpc01hc2ssXG4gICAgICAgIC
AgbmV3QXhpc01hc2ssIHNocmlua0F4aXNNYXNrKTtcblxuICBsZXQgcmVzdWx0O1xuXG4gIGlmIChpc0lkZW50aXR5KSB7XG4gICAgLy8gT3B0aW1pemF0aW9uICMxLCBzbGljZSBpcyBhIG5vLW9wIHBsdXMgcmVzaGFwZVxuICAgIHJlc3VsdCA9IHJlc2hhcGUoe2lucHV0czoge3h9LCBiYWNrZW5kLCBhdHRyczoge3NoYXBlOiBmaW5hbFNoYXBlfX0pO1xuICB9IGVsc2UgaWYgKHNsaWNlRGltMCB8fCBpc1NpbXBsZVNsaWNlKSB7XG4gICAgLy8gT3B0aW1pemF0aW9uICMyLCBzbGljZSBpcyBtZW1vcnkgY29udGlndW91cyAob25seSBvY2N1cnMgaW4gZGltIDApXG4gICAgdXRpbC5hc3NlcnQoXG4gICAgICAgIHguc2hhcGUubGVuZ3RoID49IDEsXG4gICAgICAgICgpID0+IGBJbnB1dCBtdXN0IGhhdmUgcmFuayBhdCBsZWFzdCAxLCBnb3Q6ICR7eC5zaGFwZS5sZW5ndGh9YCk7XG5cbiAgICBjb25zdCBzaXplID0gc2xpY2VfdXRpbC5jb21wdXRlT3V0U2hhcGUoJGJlZ2luLCAkZW5kLCAkc3RyaWRlcyk7XG4gICAgLy8gVG8gdG9sZXJhdGUgYmVnaW5bMF0gPiBlbmRbMF0gKGEgMC1vdXRwdXQgc2xpY2UpLCB3ZSBtaW4oYmVnaW4sIGVuZCkuXG4gICAgY29uc3Qgc2xpY2VkID0gc2xpY2Uoe2lucHV0czoge3h9LCBiYWNrZW5kLCBhdHRyczoge2JlZ2luOiAkYmVnaW4sIHNpemV9fSk7XG4gICAgcmVzdWx0ID1cbiAgICAgICAgcmVzaGFwZSh7aW5wdXRzOiB7eDogc2xpY2VkfSwgYmFja2VuZCwgYXR0cnM6IHtzaGFwZTogZmluYWxTaGFwZX19KTtcbiAgICBiYWNrZW5kLmRpc3Bvc2VJbnRlcm1lZGlhdGVUZW5zb3JJbmZvKHNsaWNlZCk7XG4gIH0gZWxzZSB7XG4gICAgY29uc3Qgc2hvdWxkRXhlY3V0ZU9uQ1BVID0gYmFja2VuZC5zaG91bGRFeGVjdXRlT25DUFUoW3hdKTtcbiAgICBpZiAoc2hvdWxkRXhlY3V0ZU9uQ1BVKSB7XG4gICAgICAvLyB0c2xpbnQ6ZGlzYWJsZS1uZXh0LWxpbmU6IG5vLXVubmVjZXNzYXJ5LXR5cGUtYXNzZXJ0aW9uXG4gICAgICBjb25zdCB2YWx1ZXMgPSBiYWNrZW5kLnJlYWRTeW5jKHguZGF0YUlkKSBhcyBUeXBlZEFycmF5O1xuICAgICAgLy8gdHNsaW50OmRpc2FibGUtbmV4dC1saW5lOiBuby11bm5lY2Vzc2FyeS10eXBlLWFzc2VydGlvblxuICAgICAgY29uc3QgeEJ1ZiA9IGJ1ZmZlcih4LnNoYXBlLCB4LmR0eXBlLCB2YWx1ZXMpIGFzIFRlbnNvckJ1ZmZlcjxSYW5rPjtcbiAgICAgIGNvbnN0IHJlc3VsdFZhbHVlcyA9XG4gICAgICAgICAgc3RyaWRlZFNsaWNlSW1wbENQVShmaW5hbFNoYXBlU3BhcnNlLCB4QnVmLCAkc3RyaWRlcywgJGJlZ2luKTtcbiAgICAgIHJlc3VsdCA9IGJhY2tlbmQubWFrZVRlbnNvckluZm8oZmluYWxTaGFwZSwgeC5kdHlwZSwgcmVzdWx0VmFsdWVzLnZhbHVlcyk7XG4gICAgfSBlbHNlIHtcbiAgICAgIGNvbnN0IHByb2dyYW0gPVxuICAgICAgICAgIG5ldyBTdHJpZGVkU2xpY2VQcm9ncmFtKCRiZWdpbiwgJHN0cmlkZXMsIGZpbmFsU2hhcGVTcGFyc2UpO1xuICAgICAgcm
VzdWx0ID0gYmFja2VuZC5ydW5XZWJHTFByb2dyYW0ocHJvZ3JhbSwgW3hdLCB4LmR0eXBlKTtcbiAgICB9XG4gIH1cblxuICBjb25zdCByZXN1bHRSZXNoYXBlZCA9XG4gICAgICByZXNoYXBlKHtpbnB1dHM6IHt4OiByZXN1bHR9LCBiYWNrZW5kLCBhdHRyczoge3NoYXBlOiBmaW5hbFNoYXBlfX0pO1xuXG4gIGJhY2tlbmQuZGlzcG9zZUludGVybWVkaWF0ZVRlbnNvckluZm8ocmVzdWx0KTtcblxuICByZXR1cm4gcmVzdWx0UmVzaGFwZWQ7XG59XG5cbmV4cG9ydCBjb25zdCBzdHJpZGVkU2xpY2VDb25maWc6IEtlcm5lbENvbmZpZyA9IHtcbiAga2VybmVsTmFtZTogU3RyaWRlZFNsaWNlLFxuICBiYWNrZW5kTmFtZTogJ3dlYmdsJyxcbiAga2VybmVsRnVuYzogc3RyaWRlZFNsaWNlIGFzIHVua25vd24gYXMgS2VybmVsRnVuY1xufTtcbiJdfQ==