/**
 * @license
 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
|
import { slice_util, StridedSlice, util } from '@tensorflow/tfjs-core';
|
import { assertNotComplex } from '../cpu_util';
|
import { reshape } from './Reshape';
|
import { slice } from './Slice';
|
import { stridedSliceImpl } from './StridedSlice_impl';
|
/**
 * StridedSlice kernel function (registered for the 'cpu' backend via
 * `stridedSliceConfig`).
 *
 * The slicing attributes are normalised by `slice_util.sliceInfo`, and the
 * flags it returns select one of three strategies:
 *   1. `isIdentity`                 -> reshape the input to `finalShape`.
 *   2. `sliceDim0 || isSimpleSlice` -> `slice` followed by a reshape.
 *   3. otherwise                    -> `stridedSliceImpl` on the synced buffer.
 *
 * Reference implementation:
 * https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/kernels/strided_slice_op.cc
 *
 * @param {Object} args
 * @param {Object} args.inputs - Holds the input tensor info `x`; `x` is passed
 *     to `assertNotComplex` before any other work is done.
 * @param {Object} args.backend - Backend providing `bufferSync`,
 *     `makeTensorInfo` and `disposeIntermediateTensorInfo`.
 * @param {Object} args.attrs - `begin`, `end`, `strides`, `beginMask`,
 *     `endMask`, `ellipsisMask`, `newAxisMask` and `shrinkAxisMask`, forwarded
 *     unchanged to `slice_util.sliceInfo`.
 * @returns {Object} Tensor info produced by `reshape` (strategies 1 and 2) or
 *     by `backend.makeTensorInfo` (strategy 3), using `finalShape`.
 */
export function stridedSlice(args) {
  const { inputs: { x }, backend, attrs } = args;

  assertNotComplex(x, 'stridedSlice');

  const {
    finalShapeSparse: sparseOutShape,
    finalShape: outShape,
    isIdentity,
    sliceDim0,
    isSimpleSlice,
    begin: startIdx,
    end: stopIdx,
    strides: steps,
  } = slice_util.sliceInfo(
    x.shape,
    attrs.begin,
    attrs.end,
    attrs.strides,
    attrs.beginMask,
    attrs.endMask,
    attrs.ellipsisMask,
    attrs.newAxisMask,
    attrs.shrinkAxisMask,
  );

  // Strategy 1: nothing is actually sliced away, only the shape changes.
  if (isIdentity) {
    return reshape({ inputs: { x }, backend, attrs: { shape: outShape } });
  }

  // Strategy 2: sliceInfo reported a plain (non-strided) slice, so delegate to
  // the regular slice kernel and reshape its output.
  if (sliceDim0 || isSimpleSlice) {
    util.assert(x.shape.length >= 1, () => `Input must have rank at least 1, got: ${x.shape.length}`);

    // NOTE(review): an earlier comment here said begin/end are clamped with
    // min(begin, end) to tolerate begin[0] > end[0]; no clamping happens in
    // this function -- confirm slice_util covers that case.
    const sliceSize = slice_util.computeOutShape(startIdx, stopIdx, steps);
    const intermediate = slice({
      inputs: { x },
      backend,
      attrs: { begin: startIdx, size: sliceSize },
    });
    const reshaped = reshape({
      inputs: { x: intermediate },
      backend,
      attrs: { shape: outShape },
    });
    // The slice output is only a stepping stone towards the reshaped result.
    backend.disposeIntermediateTensorInfo(intermediate);
    return reshaped;
  }

  // Strategy 3: general case, computed element-wise on the input buffer.
  const inputBuffer = backend.bufferSync(x);
  const outputBuffer = stridedSliceImpl(sparseOutShape, inputBuffer, steps, startIdx);
  return backend.makeTensorInfo(outShape, outputBuffer.dtype, outputBuffer.values);
}
|
/**
 * Kernel registration record: binds the `StridedSlice` kernel name to the
 * `stridedSlice` implementation for the 'cpu' backend.
 */
export const stridedSliceConfig = {
  kernelName: StridedSlice,
  backendName: 'cpu',
  kernelFunc: stridedSlice,
};
|
//# sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoiU3RyaWRlZFNsaWNlLmpzIiwic291cmNlUm9vdCI6IiIsInNvdXJjZXMiOlsiLi4vLi4vLi4vLi4vLi4vLi4vdGZqcy1iYWNrZW5kLWNwdS9zcmMva2VybmVscy9TdHJpZGVkU2xpY2UudHMiXSwibmFtZXMiOltdLCJtYXBwaW5ncyI6IkFBQUE7Ozs7Ozs7Ozs7Ozs7OztHQWVHO0FBRUgsT0FBTyxFQUFpQyxVQUFVLEVBQUUsWUFBWSxFQUFxRCxJQUFJLEVBQUMsTUFBTSx1QkFBdUIsQ0FBQztBQUd4SixPQUFPLEVBQUMsZ0JBQWdCLEVBQUMsTUFBTSxhQUFhLENBQUM7QUFDN0MsT0FBTyxFQUFDLE9BQU8sRUFBQyxNQUFNLFdBQVcsQ0FBQztBQUNsQyxPQUFPLEVBQUMsS0FBSyxFQUFDLE1BQU0sU0FBUyxDQUFDO0FBQzlCLE9BQU8sRUFBQyxnQkFBZ0IsRUFBQyxNQUFNLHFCQUFxQixDQUFDO0FBRXJELE1BQU0sVUFBVSxZQUFZLENBQUMsSUFJNUI7SUFDQyxNQUFNLEVBQUMsTUFBTSxFQUFFLE9BQU8sRUFBRSxLQUFLLEVBQUMsR0FBRyxJQUFJLENBQUM7SUFDdEMsTUFBTSxFQUFDLENBQUMsRUFBQyxHQUFHLE1BQU0sQ0FBQztJQUNuQixNQUFNLEVBQ0osS0FBSyxFQUNMLEdBQUcsRUFDSCxPQUFPLEVBQ1AsU0FBUyxFQUNULE9BQU8sRUFDUCxZQUFZLEVBQ1osV0FBVyxFQUNYLGNBQWMsRUFDZixHQUFHLEtBQUssQ0FBQztJQUVWLGdCQUFnQixDQUFDLENBQUMsRUFBRSxjQUFjLENBQUMsQ0FBQztJQUVwQyxNQUFNLEVBQ0osZ0JBQWdCLEVBQ2hCLFVBQVUsRUFDVixVQUFVLEVBQ1YsU0FBUyxFQUNULGFBQWEsRUFDYixLQUFLLEVBQUUsTUFBTSxFQUNiLEdBQUcsRUFBRSxJQUFJLEVBQ1QsT0FBTyxFQUFFLFFBQVEsRUFDbEIsR0FDRyxVQUFVLENBQUMsU0FBUyxDQUNoQixDQUFDLENBQUMsS0FBSyxFQUFFLEtBQUssRUFBRSxHQUFHLEVBQUUsT0FBTyxFQUFFLFNBQVMsRUFBRSxPQUFPLEVBQUUsWUFBWSxFQUM5RCxXQUFXLEVBQUUsY0FBYyxDQUFDLENBQUM7SUFFckMsSUFBSSxNQUFNLENBQUM7SUFFWCxPQUFPO0lBQ1AsbUdBQW1HO0lBQ25HLElBQUksVUFBVSxFQUFFO1FBQ2QsaURBQWlEO1FBQ2pELE1BQU0sR0FBRyxPQUFPLENBQUMsRUFBQyxNQUFNLEVBQUUsRUFBQyxDQUFDLEVBQUMsRUFBRSxPQUFPLEVBQUUsS0FBSyxFQUFFLEVBQUMsS0FBSyxFQUFFLFVBQVUsRUFBQyxFQUFDLENBQUMsQ0FBQztLQUN0RTtTQUFNLElBQUksU0FBUyxJQUFJLGFBQWEsRUFBRTtRQUNyQyxxRUFBcUU7UUFDckUsSUFBSSxDQUFDLE1BQU0sQ0FDUCxDQUFDLENBQUMsS0FBSyxDQUFDLE1BQU0sSUFBSSxDQUFDLEVBQ25CLEdBQUcsRUFBRSxDQUFDLHlDQUF5QyxDQUFDLENBQUMsS0FBSyxDQUFDLE1BQU0sRUFBRSxDQUFDLENBQUM7UUFFckUsTUFBTSxJQUFJLEdBQUcsVUFBVSxDQUFDLGVBQWUsQ0FBQyxNQUFNLEVBQUUsSUFBSSxFQUFFLFFBQVEsQ0FBQyxDQUFDO1FBQ2hFLHdFQUF3RTtRQUN4RSxNQUFNLE1BQU0sR0FBRyxLQUFLLENBQUMsRUFBQyxNQU
FNLEVBQUUsRUFBQyxDQUFDLEVBQUMsRUFBRSxPQUFPLEVBQUUsS0FBSyxFQUFFLEVBQUMsS0FBSyxFQUFFLE1BQU0sRUFBRSxJQUFJLEVBQUMsRUFBQyxDQUFDLENBQUM7UUFDM0UsTUFBTTtZQUNGLE9BQU8sQ0FBQyxFQUFDLE1BQU0sRUFBRSxFQUFDLENBQUMsRUFBRSxNQUFNLEVBQUMsRUFBRSxPQUFPLEVBQUUsS0FBSyxFQUFFLEVBQUMsS0FBSyxFQUFFLFVBQVUsRUFBQyxFQUFDLENBQUMsQ0FBQztRQUN4RSxPQUFPLENBQUMsNkJBQTZCLENBQUMsTUFBTSxDQUFDLENBQUM7S0FDL0M7U0FBTTtRQUNMLE1BQU0sSUFBSSxHQUFHLE9BQU8sQ0FBQyxVQUFVLENBQWtCLENBQUMsQ0FBQyxDQUFDO1FBQ3BELE1BQU0sTUFBTSxHQUFHLGdCQUFnQixDQUFDLGdCQUFnQixFQUFFLElBQUksRUFBRSxRQUFRLEVBQUUsTUFBTSxDQUFDLENBQUM7UUFFMUUsTUFBTSxHQUFHLE9BQU8sQ0FBQyxjQUFjLENBQUMsVUFBVSxFQUFFLE1BQU0sQ0FBQyxLQUFLLEVBQUUsTUFBTSxDQUFDLE1BQU0sQ0FBQyxDQUFDO0tBQzFFO0lBRUQsT0FBTyxNQUFNLENBQUM7QUFDaEIsQ0FBQztBQUVELE1BQU0sQ0FBQyxNQUFNLGtCQUFrQixHQUFpQjtJQUM5QyxVQUFVLEVBQUUsWUFBWTtJQUN4QixXQUFXLEVBQUUsS0FBSztJQUNsQixVQUFVLEVBQUUsWUFBcUM7Q0FDbEQsQ0FBQyIsInNvdXJjZXNDb250ZW50IjpbIi8qKlxuICogQGxpY2Vuc2VcbiAqIENvcHlyaWdodCAyMDIwIEdvb2dsZSBMTEMuIEFsbCBSaWdodHMgUmVzZXJ2ZWQuXG4gKiBMaWNlbnNlZCB1bmRlciB0aGUgQXBhY2hlIExpY2Vuc2UsIFZlcnNpb24gMi4wICh0aGUgXCJMaWNlbnNlXCIpO1xuICogeW91IG1heSBub3QgdXNlIHRoaXMgZmlsZSBleGNlcHQgaW4gY29tcGxpYW5jZSB3aXRoIHRoZSBMaWNlbnNlLlxuICogWW91IG1heSBvYnRhaW4gYSBjb3B5IG9mIHRoZSBMaWNlbnNlIGF0XG4gKlxuICogaHR0cDovL3d3dy5hcGFjaGUub3JnL2xpY2Vuc2VzL0xJQ0VOU0UtMi4wXG4gKlxuICogVW5sZXNzIHJlcXVpcmVkIGJ5IGFwcGxpY2FibGUgbGF3IG9yIGFncmVlZCB0byBpbiB3cml0aW5nLCBzb2Z0d2FyZVxuICogZGlzdHJpYnV0ZWQgdW5kZXIgdGhlIExpY2Vuc2UgaXMgZGlzdHJpYnV0ZWQgb24gYW4gXCJBUyBJU1wiIEJBU0lTLFxuICogV0lUSE9VVCBXQVJSQU5USUVTIE9SIENPTkRJVElPTlMgT0YgQU5ZIEtJTkQsIGVpdGhlciBleHByZXNzIG9yIGltcGxpZWQuXG4gKiBTZWUgdGhlIExpY2Vuc2UgZm9yIHRoZSBzcGVjaWZpYyBsYW5ndWFnZSBnb3Zlcm5pbmcgcGVybWlzc2lvbnMgYW5kXG4gKiBsaW1pdGF0aW9ucyB1bmRlciB0aGUgTGljZW5zZS5cbiAqID09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09XG4gKi9cblxuaW1wb3J0IHtLZXJuZWxDb25maWcsIEtlcm5lbEZ1bmMsIFJhbmssIHNsaWNlX3V0aWwsIFN0cmlkZWRTbGljZSwgU3RyaWRlZFNsaWNlQXR0cnMsIFN0cmlkZWRTbGljZUlucH
V0cywgVGVuc29ySW5mbywgdXRpbH0gZnJvbSAnQHRlbnNvcmZsb3cvdGZqcy1jb3JlJztcblxuaW1wb3J0IHtNYXRoQmFja2VuZENQVX0gZnJvbSAnLi4vYmFja2VuZF9jcHUnO1xuaW1wb3J0IHthc3NlcnROb3RDb21wbGV4fSBmcm9tICcuLi9jcHVfdXRpbCc7XG5pbXBvcnQge3Jlc2hhcGV9IGZyb20gJy4vUmVzaGFwZSc7XG5pbXBvcnQge3NsaWNlfSBmcm9tICcuL1NsaWNlJztcbmltcG9ydCB7c3RyaWRlZFNsaWNlSW1wbH0gZnJvbSAnLi9TdHJpZGVkU2xpY2VfaW1wbCc7XG5cbmV4cG9ydCBmdW5jdGlvbiBzdHJpZGVkU2xpY2UoYXJnczoge1xuICBpbnB1dHM6IFN0cmlkZWRTbGljZUlucHV0cyxcbiAgYmFja2VuZDogTWF0aEJhY2tlbmRDUFUsXG4gIGF0dHJzOiBTdHJpZGVkU2xpY2VBdHRyc1xufSk6IFRlbnNvckluZm8ge1xuICBjb25zdCB7aW5wdXRzLCBiYWNrZW5kLCBhdHRyc30gPSBhcmdzO1xuICBjb25zdCB7eH0gPSBpbnB1dHM7XG4gIGNvbnN0IHtcbiAgICBiZWdpbixcbiAgICBlbmQsXG4gICAgc3RyaWRlcyxcbiAgICBiZWdpbk1hc2ssXG4gICAgZW5kTWFzayxcbiAgICBlbGxpcHNpc01hc2ssXG4gICAgbmV3QXhpc01hc2ssXG4gICAgc2hyaW5rQXhpc01hc2tcbiAgfSA9IGF0dHJzO1xuXG4gIGFzc2VydE5vdENvbXBsZXgoeCwgJ3N0cmlkZWRTbGljZScpO1xuXG4gIGNvbnN0IHtcbiAgICBmaW5hbFNoYXBlU3BhcnNlLFxuICAgIGZpbmFsU2hhcGUsXG4gICAgaXNJZGVudGl0eSxcbiAgICBzbGljZURpbTAsXG4gICAgaXNTaW1wbGVTbGljZSxcbiAgICBiZWdpbjogJGJlZ2luLFxuICAgIGVuZDogJGVuZCxcbiAgICBzdHJpZGVzOiAkc3RyaWRlc1xuICB9ID1cbiAgICAgIHNsaWNlX3V0aWwuc2xpY2VJbmZvKFxuICAgICAgICAgIHguc2hhcGUsIGJlZ2luLCBlbmQsIHN0cmlkZXMsIGJlZ2luTWFzaywgZW5kTWFzaywgZWxsaXBzaXNNYXNrLFxuICAgICAgICAgIG5ld0F4aXNNYXNrLCBzaHJpbmtBeGlzTWFzayk7XG5cbiAgbGV0IHJlc3VsdDtcblxuICAvLyByZWY6XG4gIC8vIGh0dHBzOi8vZ2l0aHViLmNvbS90ZW5zb3JmbG93L3RlbnNvcmZsb3cvYmxvYi9tYXN0ZXIvdGVuc29yZmxvdy9jb3JlL2tlcm5lbHMvc3RyaWRlZF9zbGljZV9vcC5jY1xuICBpZiAoaXNJZGVudGl0eSkge1xuICAgIC8vIE9wdGltaXphdGlvbiAjMSwgc2xpY2UgaXMgYSBuby1vcCBwbHVzIHJlc2hhcGVcbiAgICByZXN1bHQgPSByZXNoYXBlKHtpbnB1dHM6IHt4fSwgYmFja2VuZCwgYXR0cnM6IHtzaGFwZTogZmluYWxTaGFwZX19KTtcbiAgfSBlbHNlIGlmIChzbGljZURpbTAgfHwgaXNTaW1wbGVTbGljZSkge1xuICAgIC8vIE9wdGltaXphdGlvbiAjMiwgc2xpY2UgaXMgbWVtb3J5IGNvbnRpZ3VvdXMgKG9ubHkgb2NjdXJzIGluIGRpbSAwKVxuICAgIHV0aWwuYXNzZXJ0KFxuICAgICAgICB4LnNoYXBlLmxlbmd0aCA+PSAxLFxuICAgICAgICAoKSA9PiBgSW5wdXQgbXVzdCBoYXZlIHJhbmsgYXQgbGVhc3QgMSwgZ290OiAke3guc2
hhcGUubGVuZ3RofWApO1xuXG4gICAgY29uc3Qgc2l6ZSA9IHNsaWNlX3V0aWwuY29tcHV0ZU91dFNoYXBlKCRiZWdpbiwgJGVuZCwgJHN0cmlkZXMpO1xuICAgIC8vIFRvIHRvbGVyYXRlIGJlZ2luWzBdID4gZW5kWzBdIChhIDAtb3V0cHV0IHNsaWNlKSwgd2UgbWluKGJlZ2luLCBlbmQpLlxuICAgIGNvbnN0IHNsaWNlZCA9IHNsaWNlKHtpbnB1dHM6IHt4fSwgYmFja2VuZCwgYXR0cnM6IHtiZWdpbjogJGJlZ2luLCBzaXplfX0pO1xuICAgIHJlc3VsdCA9XG4gICAgICAgIHJlc2hhcGUoe2lucHV0czoge3g6IHNsaWNlZH0sIGJhY2tlbmQsIGF0dHJzOiB7c2hhcGU6IGZpbmFsU2hhcGV9fSk7XG4gICAgYmFja2VuZC5kaXNwb3NlSW50ZXJtZWRpYXRlVGVuc29ySW5mbyhzbGljZWQpO1xuICB9IGVsc2Uge1xuICAgIGNvbnN0IHhCdWYgPSBiYWNrZW5kLmJ1ZmZlclN5bmM8UmFuaywgJ2Zsb2F0MzInPih4KTtcbiAgICBjb25zdCBvdXRCdWYgPSBzdHJpZGVkU2xpY2VJbXBsKGZpbmFsU2hhcGVTcGFyc2UsIHhCdWYsICRzdHJpZGVzLCAkYmVnaW4pO1xuXG4gICAgcmVzdWx0ID0gYmFja2VuZC5tYWtlVGVuc29ySW5mbyhmaW5hbFNoYXBlLCBvdXRCdWYuZHR5cGUsIG91dEJ1Zi52YWx1ZXMpO1xuICB9XG5cbiAgcmV0dXJuIHJlc3VsdDtcbn1cblxuZXhwb3J0IGNvbnN0IHN0cmlkZWRTbGljZUNvbmZpZzogS2VybmVsQ29uZmlnID0ge1xuICBrZXJuZWxOYW1lOiBTdHJpZGVkU2xpY2UsXG4gIGJhY2tlbmROYW1lOiAnY3B1JyxcbiAga2VybmVsRnVuYzogc3RyaWRlZFNsaWNlIGFzIHVua25vd24gYXMgS2VybmVsRnVuY1xufTtcbiJdfQ==
|