gx
chenyc
2025-02-12 ea42ff3ebee1eeb3fb29423aa848a249441db81c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
/**
 * @license
 * Copyright 2018 Google Inc. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
 
import {ENGINE} from '../engine';
import {Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D} from '../tensor';
import {convertToTensor} from '../tensor_util_env';
import {Rank, TensorLike} from '../types';
import * as util from '../util';
import {op} from './operation';
import * as slice_util from './slice_util';
 
/**
 * Extracts a 1D slice from 1D array starting at coordinates `begin` and is
 * of length `size`. See `slice` for details.
 */
function slice1d_(
    x: Tensor1D|TensorLike, begin: number, size: number): Tensor1D {
  const $x = convertToTensor(x, 'x', 'slice1d');
  util.assert(
      $x.rank === 1,
      () =>
          `slice1d expects a rank-1 tensor, but got a rank-${$x.rank} tensor`);
  return slice($x, [begin], [size]);
}
 
/**
 * Extracts a 2D slice from a 2D array starting at coordinates `begin` and
 * is of size `size`. See `slice` for details.
 */
function slice2d_(
    x: Tensor2D|TensorLike, begin: [number, number],
    size: [number, number]): Tensor2D {
  const $x = convertToTensor(x, 'x', 'slice2d');
  util.assert(
      $x.rank === 2,
      () =>
          `slice2d expects a rank-2 tensor, but got a rank-${$x.rank} tensor`);
  return slice($x, begin, size);
}
 
/**
 * Extracts a 3D slice from a 3D array starting at coordinates `begin` and
 * is of size `size`. See `slice` for details.
 */
function slice3d_(
    x: Tensor3D|TensorLike, begin: [number, number, number],
    size: [number, number, number]): Tensor3D {
  const $x = convertToTensor(x, 'x', 'slice3d');
  util.assert(
      $x.rank === 3,
      () =>
          `slice3d expects a rank-3 tensor, but got a rank-${$x.rank} tensor`);
  return slice($x, begin, size);
}
 
/**
 * Extracts a 4D slice from a 4D array starting at coordinates `begin` and
 * is of size `size`. See `slice` for details.
 */
function slice4d_(
    x: Tensor4D|TensorLike, begin: [number, number, number, number],
    size: [number, number, number, number]): Tensor4D {
  const $x = convertToTensor(x, 'x', 'slice4d');
  util.assert(
      $x.rank === 4,
      () =>
          `slice4d expects a rank-4 tensor, but got a rank-${$x.rank} tensor`);
  return slice($x, begin, size);
}
 
/**
 * Extracts a slice from a `tf.Tensor` starting at coordinates `begin`
 * and is of size `size`.
 *
 * Also available are stricter rank-specific methods with the same signature
 * as this method that assert that `x` is of the given rank:
 *   - `tf.slice1d`
 *   - `tf.slice2d`
 *   - `tf.slice3d`
 *   - `tf.slice4d`
 *
 * ```js
 * const x = tf.tensor1d([1, 2, 3, 4]);
 *
 * x.slice([1], [2]).print();
 * ```
 *
 * ```js
 * const x = tf.tensor2d([1, 2, 3, 4], [2, 2]);
 *
 * x.slice([1, 0], [1, 2]).print();
 * ```
 * @param x The input `tf.Tensor` to slice from.
 * @param begin The coordinates to start the slice from. The length can be
 *     less than the rank of x - the rest of the axes will have implicit 0 as
 *     start. Can also be a single number, in which case it specifies the
 *     first axis.
 * @param size The size of the slice. The length can be less than the rank of
 *     x - the rest of the axes will have implicit -1. A value of -1 requests
 *     the rest of the dimensions in the axis. Can also be a single number,
 *     in which case it specifies the size of the first axis.
 */
/** @doc {heading: 'Tensors', subheading: 'Slicing and Joining'} */
function slice_<R extends Rank, T extends Tensor<R>>(
    x: T|TensorLike, begin: number|number[], size?: number|number[]): T {
  const $x = convertToTensor(x, 'x', 'slice');
 
  if ($x.rank === 0) {
    throw new Error('Slicing scalar is not possible');
  }
  // The following logic allows for more ergonomic calls.
  let begin_: number[];
  if (typeof begin === 'number') {
    begin_ = [begin, ...new Array($x.rank - 1).fill(0)];
  } else if (begin.length < $x.rank) {
    begin_ = begin.concat(new Array($x.rank - begin.length).fill(0));
  } else {
    begin_ = begin.slice();
  }
  begin_.forEach(d => {
    util.assert(
        d !== -1, () => 'slice() does not support negative begin indexing.');
  });
  let size_: number[];
  if (size == null) {
    size_ = new Array($x.rank).fill(-1);
  } else if (typeof size === 'number') {
    size_ = [size, ...new Array($x.rank - 1).fill(-1)];
  } else if (size.length < $x.rank) {
    size_ = size.concat(new Array($x.rank - size.length).fill(-1));
  } else {
    size_ = size;
  }
  size_ = size_.map((d, i) => {
    if (d >= 0) {
      return d;
    } else {
      util.assert(
          d === -1,
          () => `Negative size values should be exactly -1 but got ` +
              `${d} for the slice() size at index ${i}.`);
      return $x.shape[i] - begin_[i];
    }
  });
  slice_util.assertParamsValid($x, begin_, size_);
  const inputShape = $x.shape;
  const grad = (dy: T) => {
    // Create an Nx2 padding where the first column represents how many
    // zeros are prepended (at start) for each dimension, and the second
    // column indicates how many zeros are appended (at end).
 
    // The number of zeros to append is the shape of the input
    // elementwise-subtracted by both the begin vector and sizes vector.
    const paddings: Array<[number, number]> = [];
    for (let i = 0; i < dy.rank; i++) {
      paddings.push([begin_[i], inputShape[i] - begin_[i] - size_[i]]);
    }
    return {x: () => dy.pad(paddings)};
  };
  const attrs = {begin: begin_, size: size_};
  return ENGINE.runKernelFunc(
      backend => backend.slice($x, begin_, size_), {x: $x}, grad, 'Slice',
      attrs);
}
 
export const slice = op({slice_});
export const slice1d = op({slice1d_});
export const slice2d = op({slice2d_});
export const slice3d = op({slice3d_});
export const slice4d = op({slice4d_});