gx
chenyc
2025-06-12 7b72ac13a83764a662159d4a49b7fffb90476ecb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
/**
 * @license
 * Copyright 2018 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
import { concat, keep, reshape, scalar, slice, stack, tensor, tidy, unstack } from '@tensorflow/tfjs-core';
import { assertShapesMatchAllowUndefinedSize } from './tensor_utils';
/**
 * The TensorArray object keeps an array of Tensors.  It
 * allows reading from the array and writing to the array.
 */
export class TensorArray {
    /**
     * @param name Name used to identify this array in error messages.
     * @param dtype Data type every written tensor must have.
     * @param maxSize Capacity enforced on writes/scatter/split when
     *    `dynamicSize` is false.
     * @param elementShape Expected shape of each element. When null or empty
     *    it is taken from the first tensor written to an empty array.
     * @param identicalElementShapes Stored as-is; not consulted by this class.
     * @param dynamicSize When true, the `maxSize` bound is not enforced.
     * @param clearAfterRead When true, each element may be read only once.
     */
    constructor(name, dtype, maxSize, elementShape, identicalElementShapes, dynamicSize, clearAfterRead) {
        this.name = name;
        this.dtype = dtype;
        this.maxSize = maxSize;
        this.elementShape = elementShape;
        this.identicalElementShapes = identicalElementShapes;
        this.dynamicSize = dynamicSize;
        this.clearAfterRead = clearAfterRead;
        // Slots of the form {tensor, written, read, cleared}. The array may be
        // sparse because write() accepts any non-negative index.
        this.tensors = [];
        this.closed_ = false;
        // The id of this scalar serves as the id of the TensorArray (see `id`).
        // It is kept so that it stays alive until clearAndClose() disposes it.
        this.idTensor = scalar(0);
        keep(this.idTensor);
    }
    /** Identifier of this TensorArray (the id of its backing `idTensor`). */
    get id() {
        return this.idTensor.id;
    }
    /** True once clearAndClose() has been called. */
    get closed() {
        return this.closed_;
    }
    /**
     * Dispose the tensors and idTensor and mark the TensorArray as closed.
     * @param keepIds Optional set of tensor ids (anything with a `has`
     *    method); element tensors whose id is in the set are not disposed.
     */
    clearAndClose(keepIds) {
        this.tensors.forEach(tensor => {
            if (keepIds == null || !keepIds.has(tensor.tensor.id)) {
                tensor.tensor.dispose();
            }
        });
        this.tensors = [];
        this.closed_ = true;
        this.idTensor.dispose();
    }
    /**
     * Number of slots in the array, i.e. highest written index + 1. Slots
     * below that index are counted even if they were never written.
     */
    size() {
        return this.tensors.length;
    }
    /**
     * Read the value at location index in the TensorArray.
     * @param index Number the index to read from.
     * @returns The tensor stored at `index`.
     * @throws Error if the array is closed, `index` is out of range, the slot
     *    was never written, or the slot was cleared by a previous read.
     */
    read(index) {
        if (this.closed_) {
            throw new Error(`TensorArray ${this.name} has already been closed.`);
        }
        if (index < 0 || index >= this.size()) {
            throw new Error(`Tried to read from index ${index}, but array size is: ${this.size()}`);
        }
        const tensorWithState = this.tensors[index];
        // write() can leave holes (e.g. index 2 written before index 0), so a
        // slot inside [0, size()) is not guaranteed to exist. Fail with a clear
        // message instead of dereferencing undefined.
        if (tensorWithState == null) {
            throw new Error(`TensorArray ${this.name}: Could not read index ${index} because it has not been written.`);
        }
        if (tensorWithState.cleared) {
            throw new Error(`TensorArray ${this.name}: Could not read index ${index} twice because it was cleared after a previous read ` +
                `(perhaps try setting clear_after_read = false?).`);
        }
        if (this.clearAfterRead) {
            tensorWithState.cleared = true;
        }
        // A slot that has been read can no longer be written (see write()).
        tensorWithState.read = true;
        return tensorWithState.tensor;
    }
    /**
     * Helper method to read multiple tensors from the specified indices.
     * @param indices Indices to read, in order.
     * @returns The tensors at those indices; throws as read() does.
     */
    readMany(indices) {
        return indices.map(index => this.read(index));
    }
    /**
     * Write value into the index of the TensorArray.
     * @param index number the index to write to.
     * @param tensor The tensor to store; its dtype must equal the array dtype
     *    and its shape must be compatible with `elementShape`.
     * @throws Error if the array is closed, `index` is negative or beyond
     *    `maxSize` on a non-dynamic array, the dtype differs, or the slot has
     *    already been read or written. Shape mismatches are reported by
     *    assertShapesMatchAllowUndefinedSize.
     */
    write(index, tensor) {
        if (this.closed_) {
            throw new Error(`TensorArray ${this.name} has already been closed.`);
        }
        if (index < 0 || !this.dynamicSize && index >= this.maxSize) {
            throw new Error(`Tried to write to index ${index}, but array is not resizeable and size is: ${this.maxSize}`);
        }
        const t = this.tensors[index] || {};
        if (tensor.dtype !== this.dtype) {
            throw new Error(`TensorArray ${this.name}: Could not write to TensorArray index ${index},
          because the value dtype is ${tensor.dtype}, but TensorArray dtype is ${this.dtype}.`);
        }
        // Set the shape on the first write to a tensor array of unknown shape.
        if (this.size() === 0 &&
            (this.elementShape == null || this.elementShape.length === 0)) {
            this.elementShape = tensor.shape;
        }
        assertShapesMatchAllowUndefinedSize(this.elementShape, tensor.shape, `TensorArray ${this.name}: Could not write to TensorArray index ${index}.`);
        if (t.read) {
            throw new Error(`TensorArray ${this.name}: Could not write to TensorArray index ${index}, because it has already been read.`);
        }
        if (t.written) {
            throw new Error(`TensorArray ${this.name}: Could not write to TensorArray index ${index}, because it has already been written.`);
        }
        t.tensor = tensor;
        // Keep the tensor alive while it is owned by this array.
        keep(tensor);
        t.written = true;
        this.tensors[index] = t;
    }
    /**
     * Helper method to write multiple tensors to the specified indices.
     * @param indices Target indices.
     * @param tensors Tensors to write; must have the same length as `indices`.
     * @throws Error if the lengths differ, or as write() does.
     */
    writeMany(indices, tensors) {
        if (indices.length !== tensors.length) {
            throw new Error(`TensorArray ${this.name}: could not write multiple tensors,` +
                `because the index size: ${indices.length} is not the same as tensors size: ${tensors.length}.`);
        }
        indices.forEach((i, index) => this.write(i, tensors[index]));
    }
    /**
     * Return selected values in the TensorArray as a packed Tensor. All of
     * selected values must have been written and their shapes must all match.
     * @param [indices] number[] Optional. Taking values in [0, max_value). If the
     *    TensorArray is not dynamic, max_value=size(). If not specified returns
     *    all tensors in the original order. A provided list is truncated to its
     *    first size() entries.
     * @param [dtype] Optional; when given it must equal the array dtype.
     * @returns The selected tensors stacked along a new axis 0, or an empty
     *    tensor of shape [0, ...elementShape] when nothing is selected.
     * @throws Error on dtype mismatch, or as read() does.
     */
    gather(indices, dtype) {
        if (!!dtype && dtype !== this.dtype) {
            throw new Error(`TensorArray dtype is ${this.dtype} but gather requested dtype ${dtype}`);
        }
        if (!indices) {
            indices = [];
            for (let i = 0; i < this.size(); i++) {
                indices.push(i);
            }
        }
        else {
            indices = indices.slice(0, this.size());
        }
        if (indices.length === 0) {
            return tensor([], [0].concat(this.elementShape));
        }
        // Reading marks every selected slot as read (and cleared, when
        // clearAfterRead is set).
        const tensors = this.readMany(indices);
        assertShapesMatchAllowUndefinedSize(this.elementShape, tensors[0].shape, 'TensorArray shape mismatch: ');
        return stack(tensors, 0);
    }
    /**
     * Return the values in the TensorArray as a concatenated Tensor.
     * @param [dtype] Optional; when given it must equal the array dtype.
     * @returns All elements concatenated along axis 0, or an empty tensor of
     *    shape [0, ...elementShape] when the array is empty.
     * @throws Error on dtype mismatch, or as read() does.
     */
    concat(dtype) {
        if (!!dtype && dtype !== this.dtype) {
            throw new Error(`TensorArray dtype is ${this.dtype} but concat requested dtype ${dtype}`);
        }
        if (this.size() === 0) {
            return tensor([], [0].concat(this.elementShape));
        }
        const indices = [];
        for (let i = 0; i < this.size(); i++) {
            indices.push(i);
        }
        // Collect all the tensors from the tensors array.
        const tensors = this.readMany(indices);
        assertShapesMatchAllowUndefinedSize(this.elementShape, tensors[0].shape, `TensorArray shape mismatch: tensor array shape (${this.elementShape}) vs first tensor shape (${tensors[0].shape})`);
        return concat(tensors, 0);
    }
    /**
     * Scatter the values of a Tensor in specific indices of a TensorArray.
     * @param indices number[] values in [0, max_value). If the
     *    TensorArray is not dynamic, max_value=size().
     * @param tensor Tensor input tensor. It is unstacked along axis 0 and the
     *    i-th slice is written to `indices[i]`.
     * @throws Error on dtype mismatch, when `indices.length` differs from
     *    `tensor.shape[0]`, when the largest index is >= `maxSize` on a
     *    non-dynamic array, or as write() does.
     */
    scatter(indices, tensor) {
        if (tensor.dtype !== this.dtype) {
            throw new Error(`TensorArray dtype is ${this.dtype} but tensor has dtype ${tensor.dtype}`);
        }
        if (indices.length !== tensor.shape[0]) {
            throw new Error(`Expected len(indices) == tensor.shape[0], but saw: ${indices.length} vs. ${tensor.shape[0]}`);
        }
        const maxIndex = Math.max(...indices);
        if (!this.dynamicSize && maxIndex >= this.maxSize) {
            throw new Error(`Max index must be < array size (${maxIndex}  vs. ${this.maxSize})`);
        }
        this.writeMany(indices, unstack(tensor, 0));
    }
    /**
     * Split the values of a Tensor into the TensorArray.
     * @param length number[] with the lengths to use when splitting value along
     *    its first dimension.
     * @param tensor Tensor, the tensor to split. Piece i is written to index i.
     * @throws Error on dtype mismatch, when the lengths do not sum to
     *    `tensor.shape[0]`, when the number of lengths differs from `maxSize`
     *    on a non-dynamic array, or as write() does.
     */
    split(length, tensor) {
        if (tensor.dtype !== this.dtype) {
            throw new Error(`TensorArray dtype is ${this.dtype} but tensor has dtype ${tensor.dtype}`);
        }
        let totalLength = 0;
        // cumulativeLengths[i] is the row offset just past piece i.
        const cumulativeLengths = length.map(len => {
            totalLength += len;
            return totalLength;
        });
        if (totalLength !== tensor.shape[0]) {
            throw new Error(`Expected sum of lengths to be equal to
          tensor.shape[0], but sum of lengths is
        ${totalLength}, and tensor's shape is: ${tensor.shape}`);
        }
        if (!this.dynamicSize && length.length !== this.maxSize) {
            throw new Error(`TensorArray's size is not equal to the size of lengths (${this.maxSize} vs. ${length.length}), ` +
                'and the TensorArray is not marked as dynamically resizeable');
        }
        // Number of values in one row along the first dimension.
        const elementPerRow = totalLength === 0 ? 0 : tensor.size / totalLength;
        const tensors = [];
        tidy(() => {
            // View the input as [1, rows, rowSize] so each piece is one slice
            // along axis 1. A local is used rather than reassigning the
            // `tensor` parameter.
            const flattened = reshape(tensor, [1, totalLength, elementPerRow]);
            for (let i = 0; i < length.length; ++i) {
                const previousLength = (i === 0) ? 0 : cumulativeLengths[i - 1];
                const begin = [0, previousLength, 0];
                const sizes = [1, length[i], elementPerRow];
                tensors[i] = reshape(slice(flattened, begin, sizes), this.elementShape);
            }
            // Returning the pieces hands them back to the caller of tidy().
            return tensors;
        });
        const indices = [];
        for (let i = 0; i < length.length; i++) {
            indices[i] = i;
        }
        this.writeMany(indices, tensors);
    }
}
//# sourceMappingURL=data:application/json;base64,{"version":3,"file":"tensor_array.js","sourceRoot":"","sources":["../../../../../../tfjs-converter/src/executor/tensor_array.ts"],"names":[],"mappings":"AAAA;;;;;;;;;;;;;;;GAeG;AAEH,OAAO,EAAC,MAAM,EAAY,IAAI,EAAE,OAAO,EAAE,MAAM,EAAE,KAAK,EAAE,KAAK,EAAU,MAAM,EAAE,IAAI,EAAE,OAAO,EAAC,MAAM,uBAAuB,CAAC;AAE3H,OAAO,EAAC,mCAAmC,EAAC,MAAM,gBAAgB,CAAC;AAQnE;;;GAGG;AACH,MAAM,OAAO,WAAW;IAItB,YACa,IAAY,EAAW,KAAe,EAAU,OAAe,EAChE,YAAsB,EAAW,sBAA+B,EAC/D,WAAoB,EAAW,cAAuB;QAFtD,SAAI,GAAJ,IAAI,CAAQ;QAAW,UAAK,GAAL,KAAK,CAAU;QAAU,YAAO,GAAP,OAAO,CAAQ;QAChE,iBAAY,GAAZ,YAAY,CAAU;QAAW,2BAAsB,GAAtB,sBAAsB,CAAS;QAC/D,gBAAW,GAAX,WAAW,CAAS;QAAW,mBAAc,GAAd,cAAc,CAAS;QAN3D,YAAO,GAAsB,EAAE,CAAC;QAChC,YAAO,GAAG,KAAK,CAAC;QAMtB,IAAI,CAAC,QAAQ,GAAG,MAAM,CAAC,CAAC,CAAC,CAAC;QAC1B,IAAI,CAAC,IAAI,CAAC,QAAQ,CAAC,CAAC;IACtB,CAAC;IAED,IAAI,EAAE;QACJ,OAAO,IAAI,CAAC,QAAQ,CAAC,EAAE,CAAC;IAC1B,CAAC;IAED,IAAI,MAAM;QACR,OAAO,IAAI,CAAC,OAAO,CAAC;IACtB,CAAC;IAED;;OAEG;IACH,aAAa,CAAC,OAAqB;QACjC,IAAI,CAAC,OAAO,CAAC,OAAO,CAAC,MAAM,CAAC,EAAE;YAC5B,IAAI,OAAO,IAAI,IAAI,IAAI,CAAC,OAAO,CAAC,GAAG,CAAC,MAAM,CAAC,MAAM,CAAC,EAAE,CAAC,EAAE;gBACrD,MAAM,CAAC,MAAM,CAAC,OAAO,EAAE,CAAC;aACzB;QACH,CAAC,CAAC,CAAC;QACH,IAAI,CAAC,OAAO,GAAG,EAAE,CAAC;QAClB,IAAI,CAAC,OAAO,GAAG,IAAI,CAAC;QACpB,IAAI,CAAC,QAAQ,CAAC,OAAO,EAAE,CAAC;IAC1B,CAAC;IAED,IAAI;QACF,OAAO,IAAI,CAAC,OAAO,CAAC,MAAM,CAAC;IAC7B,CAAC;IAED;;;OAGG;IACH,IAAI,CAAC,KAAa;QAChB,IAAI,IAAI,CAAC,OAAO,EAAE;YAChB,MAAM,IAAI,KAAK,CAAC,eAAe,IAAI,CAAC,IAAI,2BAA2B,CAAC,CAAC;SACtE;QAED,IAAI,KAAK,GAAG,CAAC,IAAI,KAAK,IAAI,IAAI,CAAC,IAAI,EAAE,EAAE;YACrC,MAAM,IAAI,KAAK,CAAC,4BAA4B,KAAK,wBAC7C,IAAI,CAAC,IAAI,EAAE,EAAE,CAAC,CAAC;SACpB;QAED,MAAM,eAAe,GAAG,IAAI,CAAC,OAAO,CAAC,KAAK,CAAC,CAAC;QAC5C,IAAI,eAAe,CAAC,OAAO,EAAE;YAC3B,MAAM,IAAI,KAAK,CACX,eAAe,IAAI,CAAC,IAAI,0BACpB,KAAK,sDAAsD;gBAC/D,kDAAkD,CAAC,CAAC;SACzD;QAED,IAAI,IAAI,CAAC,cAAc,EAAE;YACvB,eAAe,CAAC,OAAO,GAAG,IAAI,CAAC;SAChC;QAED,eAAe,CAAC,IAAI,GAAG,IAAI,CAAC;QAC5B,OAAO,eAAe,CAAC,MAAM,CAA
C;IAChC,CAAC;IAED;;OAEG;IACH,QAAQ,CAAC,OAAiB;QACxB,OAAO,OAAO,CAAC,GAAG,CAAC,KAAK,CAAC,EAAE,CAAC,IAAI,CAAC,IAAI,CAAC,KAAK,CAAC,CAAC,CAAC;IAChD,CAAC;IAED;;;;OAIG;IACH,KAAK,CAAC,KAAa,EAAE,MAAc;QACjC,IAAI,IAAI,CAAC,OAAO,EAAE;YAChB,MAAM,IAAI,KAAK,CAAC,eAAe,IAAI,CAAC,IAAI,2BAA2B,CAAC,CAAC;SACtE;QAED,IAAI,KAAK,GAAG,CAAC,IAAI,CAAC,IAAI,CAAC,WAAW,IAAI,KAAK,IAAI,IAAI,CAAC,OAAO,EAAE;YAC3D,MAAM,IAAI,KAAK,CAAC,2BACZ,KAAK,8CAA8C,IAAI,CAAC,OAAO,EAAE,CAAC,CAAC;SACxE;QAED,MAAM,CAAC,GAAG,IAAI,CAAC,OAAO,CAAC,KAAK,CAAC,IAAI,EAAE,CAAC;QAEpC,IAAI,MAAM,CAAC,KAAK,KAAK,IAAI,CAAC,KAAK,EAAE;YAC/B,MAAM,IAAI,KAAK,CAAC,eACZ,IAAI,CAAC,IAAI,0CAA0C,KAAK;uCAExD,MAAM,CAAC,KAAK,8BAA8B,IAAI,CAAC,KAAK,GAAG,CAAC,CAAC;SAC9D;QAED,sEAAsE;QACtE,IAAI,IAAI,CAAC,IAAI,EAAE,KAAK,CAAC;YACjB,CAAC,IAAI,CAAC,YAAY,IAAI,IAAI,IAAI,IAAI,CAAC,YAAY,CAAC,MAAM,KAAK,CAAC,CAAC,EAAE;YACjE,IAAI,CAAC,YAAY,GAAG,MAAM,CAAC,KAAK,CAAC;SAClC;QAED,mCAAmC,CAC/B,IAAI,CAAC,YAAY,EAAE,MAAM,CAAC,KAAK,EAC/B,eAAe,IAAI,CAAC,IAAI,0CACpB,KAAK,GAAG,CAAC,CAAC;QAElB,IAAI,CAAC,CAAC,IAAI,EAAE;YACV,MAAM,IAAI,KAAK,CACX,eAAe,IAAI,CAAC,IAAI,0CACpB,KAAK,qCAAqC,CAAC,CAAC;SACrD;QAED,IAAI,CAAC,CAAC,OAAO,EAAE;YACb,MAAM,IAAI,KAAK,CACX,eAAe,IAAI,CAAC,IAAI,0CACpB,KAAK,wCAAwC,CAAC,CAAC;SACxD;QAED,CAAC,CAAC,MAAM,GAAG,MAAM,CAAC;QAClB,IAAI,CAAC,MAAM,CAAC,CAAC;QACb,CAAC,CAAC,OAAO,GAAG,IAAI,CAAC;QAEjB,IAAI,CAAC,OAAO,CAAC,KAAK,CAAC,GAAG,CAAC,CAAC;IAC1B,CAAC;IAED;;OAEG;IACH,SAAS,CAAC,OAAiB,EAAE,OAAiB;QAC5C,IAAI,OAAO,CAAC,MAAM,KAAK,OAAO,CAAC,MAAM,EAAE;YACrC,MAAM,IAAI,KAAK,CACX,eAAe,IAAI,CAAC,IAAI,qCAAqC;gBAC7D,2BACI,OAAO,CAAC,MAAM,qCACd,OAAO,CAAC,MAAM,GAAG,CAAC,CAAC;SAC5B;QAED,OAAO,CAAC,OAAO,CAAC,CAAC,CAAC,EAAE,KAAK,EAAE,EAAE,CAAC,IAAI,CAAC,KAAK,CAAC,CAAC,EAAE,OAAO,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC;IAC/D,CAAC;IAED;;;;;;;OAOG;IACH,MAAM,CAAC,OAAkB,EAAE,KAAgB;QACzC,IAAI,CAAC,CAAC,KAAK,IAAI,KAAK,KAAK,IAAI,CAAC,KAAK,EAAE;YACnC,MAAM,IAAI,KAAK,CAAC,wBACZ,IAAI,CAAC,KAAK,+BAA+B,KAAK,EAAE,CAAC,CAAC;SACvD;QAED,IAAI,CAAC,OAAO,EAAE;YACZ,OAAO,GAAG,EAAE,CAAC;YACb,KAAK,IAAI,CAAC,GAAG,CAAC,EAA
E,CAAC,GAAG,IAAI,CAAC,IAAI,EAAE,EAAE,CAAC,EAAE,EAAE;gBACpC,OAAO,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC;aACjB;SACF;aAAM;YACL,OAAO,GAAG,OAAO,CAAC,KAAK,CAAC,CAAC,EAAE,IAAI,CAAC,IAAI,EAAE,CAAC,CAAC;SACzC;QAED,IAAI,OAAO,CAAC,MAAM,KAAK,CAAC,EAAE;YACxB,OAAO,MAAM,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC,CAAC,MAAM,CAAC,IAAI,CAAC,YAAY,CAAC,CAAC,CAAC;SAClD;QAED,gEAAgE;QAChE,gBAAgB;QAChB,MAAM,OAAO,GAAG,IAAI,CAAC,QAAQ,CAAC,OAAO,CAAC,CAAC;QAEvC,mCAAmC,CAC/B,IAAI,CAAC,YAAY,EAAE,OAAO,CAAC,CAAC,CAAC,CAAC,KAAK,EAAE,8BAA8B,CAAC,CAAC;QAEzE,OAAO,KAAK,CAAC,OAAO,EAAE,CAAC,CAAC,CAAC;IAC3B,CAAC;IAED;;OAEG;IACH,MAAM,CAAC,KAAgB;QACrB,IAAI,CAAC,CAAC,KAAK,IAAI,KAAK,KAAK,IAAI,CAAC,KAAK,EAAE;YACnC,MAAM,IAAI,KAAK,CAAC,wBACZ,IAAI,CAAC,KAAK,+BAA+B,KAAK,EAAE,CAAC,CAAC;SACvD;QAED,IAAI,IAAI,CAAC,IAAI,EAAE,KAAK,CAAC,EAAE;YACrB,OAAO,MAAM,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC,CAAC,MAAM,CAAC,IAAI,CAAC,YAAY,CAAC,CAAC,CAAC;SAClD;QAED,MAAM,OAAO,GAAG,EAAE,CAAC;QACnB,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,IAAI,CAAC,IAAI,EAAE,EAAE,CAAC,EAAE,EAAE;YACpC,OAAO,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC;SACjB;QACD,kDAAkD;QAClD,MAAM,OAAO,GAAG,IAAI,CAAC,QAAQ,CAAC,OAAO,CAAC,CAAC;QAEvC,mCAAmC,CAC/B,IAAI,CAAC,YAAY,EAAE,OAAO,CAAC,CAAC,CAAC,CAAC,KAAK,EACnC,mDACI,IAAI,CAAC,YAAY,4BAA4B,OAAO,CAAC,CAAC,CAAC,CAAC,KAAK,GAAG,CAAC,CAAC;QAE1E,OAAO,MAAM,CAAC,OAAO,EAAE,CAAC,CAAC,CAAC;IAC5B,CAAC;IAED;;;;;OAKG;IACH,OAAO,CAAC,OAAiB,EAAE,MAAc;QACvC,IAAI,MAAM,CAAC,KAAK,KAAK,IAAI,CAAC,KAAK,EAAE;YAC/B,MAAM,IAAI,KAAK,CAAC,wBACZ,IAAI,CAAC,KAAK,yBAAyB,MAAM,CAAC,KAAK,EAAE,CAAC,CAAC;SACxD;QAED,IAAI,OAAO,CAAC,MAAM,KAAK,MAAM,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE;YACtC,MAAM,IAAI,KAAK,CAAC,sDACZ,OAAO,CAAC,MAAM,QAAQ,MAAM,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC;SAC9C;QAED,MAAM,QAAQ,GAAG,IAAI,CAAC,GAAG,CAAC,GAAG,OAAO,CAAC,CAAC;QAEtC,IAAI,CAAC,IAAI,CAAC,WAAW,IAAI,QAAQ,IAAI,IAAI,CAAC,OAAO,EAAE;YACjD,MAAM,IAAI,KAAK,CACX,mCAAmC,QAAQ,SAAS,IAAI,CAAC,OAAO,GAAG,CAAC,CAAC;SAC1E;QAED,IAAI,CAAC,SAAS,CAAC,OAAO,EAAE,OAAO,CAAC,MAAM,EAAE,CAAC,CAAC,CAAC,CAAC;IAC9C,CAAC;IAED;;;;;OAKG;IACH,KAAK,CAAC,MAAgB,EAAE,MAAc;QA
CpC,IAAI,MAAM,CAAC,KAAK,KAAK,IAAI,CAAC,KAAK,EAAE;YAC/B,MAAM,IAAI,KAAK,CAAC,wBACZ,IAAI,CAAC,KAAK,yBAAyB,MAAM,CAAC,KAAK,EAAE,CAAC,CAAC;SACxD;QACD,IAAI,WAAW,GAAG,CAAC,CAAC;QACpB,MAAM,iBAAiB,GAAG,MAAM,CAAC,GAAG,CAAC,GAAG,CAAC,EAAE;YACzC,WAAW,IAAI,GAAG,CAAC;YACnB,OAAO,WAAW,CAAC;QACrB,CAAC,CAAC,CAAC;QAEH,IAAI,WAAW,KAAK,MAAM,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE;YACnC,MAAM,IAAI,KAAK,CAAC;;UAEZ,WAAW,4BAA4B,MAAM,CAAC,KAAK,EAAE,CAAC,CAAC;SAC5D;QAED,IAAI,CAAC,IAAI,CAAC,WAAW,IAAI,MAAM,CAAC,MAAM,KAAK,IAAI,CAAC,OAAO,EAAE;YACvD,MAAM,IAAI,KAAK,CACX,2DACI,IAAI,CAAC,OAAO,QAAQ,MAAM,CAAC,MAAM,KAAK;gBAC1C,6DAA6D,CAAC,CAAC;SACpE;QAED,MAAM,aAAa,GAAG,WAAW,KAAK,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,MAAM,CAAC,IAAI,GAAG,WAAW,CAAC;QACxE,MAAM,OAAO,GAAa,EAAE,CAAC;QAC7B,IAAI,CAAC,GAAG,EAAE;YACR,MAAM,GAAG,OAAO,CAAC,MAAM,EAAE,CAAC,CAAC,EAAE,WAAW,EAAE,aAAa,CAAC,CAAC,CAAC;YAC1D,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,MAAM,CAAC,MAAM,EAAE,EAAE,CAAC,EAAE;gBACtC,MAAM,cAAc,GAAG,CAAC,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,iBAAiB,CAAC,CAAC,GAAG,CAAC,CAAC,CAAC;gBAChE,MAAM,OAAO,GAAG,CAAC,CAAC,EAAE,cAAc,EAAE,CAAC,CAAC,CAAC;gBACvC,MAAM,KAAK,GAAG,CAAC,CAAC,EAAE,MAAM,CAAC,CAAC,CAAC,EAAE,aAAa,CAAC,CAAC;gBAC5C,OAAO,CAAC,CAAC,CAAC,GAAG,OAAO,CAAC,KAAK,CAAC,MAAM,EAAE,OAAO,EAAE,KAAK,CAAC,EAAE,IAAI,CAAC,YAAY,CAAC,CAAC;aACxE;YACD,OAAO,OAAO,CAAC;QACjB,CAAC,CAAC,CAAC;QACH,MAAM,OAAO,GAAG,EAAE,CAAC;QACnB,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,MAAM,CAAC,MAAM,EAAE,CAAC,EAAE,EAAE;YACtC,OAAO,CAAC,CAAC,CAAC,GAAG,CAAC,CAAC;SAChB;QACD,IAAI,CAAC,SAAS,CAAC,OAAO,EAAE,OAAO,CAAC,CAAC;IACnC,CAAC;CACF","sourcesContent":["/**\n * @license\n * Copyright 2018 Google LLC. 
All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {concat, DataType, keep, reshape, scalar, slice, stack, Tensor, tensor, tidy, unstack} from '@tensorflow/tfjs-core';\n\nimport {assertShapesMatchAllowUndefinedSize} from './tensor_utils';\n\nexport interface TensorWithState {\n  tensor?: Tensor;\n  written?: boolean;\n  read?: boolean;\n  cleared?: boolean;\n}\n/**\n * The TensorArray object keeps an array of Tensors.  
It\n * allows reading from the array and writing to the array.\n */\nexport class TensorArray {\n  private tensors: TensorWithState[] = [];\n  private closed_ = false;\n  readonly idTensor: Tensor;\n  constructor(\n      readonly name: string, readonly dtype: DataType, private maxSize: number,\n      private elementShape: number[], readonly identicalElementShapes: boolean,\n      readonly dynamicSize: boolean, readonly clearAfterRead: boolean) {\n    this.idTensor = scalar(0);\n    keep(this.idTensor);\n  }\n\n  get id() {\n    return this.idTensor.id;\n  }\n\n  get closed() {\n    return this.closed_;\n  }\n\n  /**\n   * Dispose the tensors and idTensor and mark the TensoryArray as closed.\n   */\n  clearAndClose(keepIds?: Set<number>) {\n    this.tensors.forEach(tensor => {\n      if (keepIds == null || !keepIds.has(tensor.tensor.id)) {\n        tensor.tensor.dispose();\n      }\n    });\n    this.tensors = [];\n    this.closed_ = true;\n    this.idTensor.dispose();\n  }\n\n  size(): number {\n    return this.tensors.length;\n  }\n\n  /**\n   * Read the value at location index in the TensorArray.\n   * @param index Number the index to read from.\n   */\n  read(index: number): Tensor {\n    if (this.closed_) {\n      throw new Error(`TensorArray ${this.name} has already been closed.`);\n    }\n\n    if (index < 0 || index >= this.size()) {\n      throw new Error(`Tried to read from index ${index}, but array size is: ${\n          this.size()}`);\n    }\n\n    const tensorWithState = this.tensors[index];\n    if (tensorWithState.cleared) {\n      throw new Error(\n          `TensorArray ${this.name}: Could not read index ${\n              index} twice because it was cleared after a previous read ` +\n          `(perhaps try setting clear_after_read = false?).`);\n    }\n\n    if (this.clearAfterRead) {\n      tensorWithState.cleared = true;\n    }\n\n    tensorWithState.read = true;\n    return tensorWithState.tensor;\n  }\n\n  /**\n   * Helper method to read 
multiple tensors from the specified indices.\n   */\n  readMany(indices: number[]): Tensor[] {\n    return indices.map(index => this.read(index));\n  }\n\n  /**\n   * Write value into the index of the TensorArray.\n   * @param index number the index to write to.\n   * @param tensor\n   */\n  write(index: number, tensor: Tensor) {\n    if (this.closed_) {\n      throw new Error(`TensorArray ${this.name} has already been closed.`);\n    }\n\n    if (index < 0 || !this.dynamicSize && index >= this.maxSize) {\n      throw new Error(`Tried to write to index ${\n          index}, but array is not resizeable and size is: ${this.maxSize}`);\n    }\n\n    const t = this.tensors[index] || {};\n\n    if (tensor.dtype !== this.dtype) {\n      throw new Error(`TensorArray ${\n          this.name}: Could not write to TensorArray index ${index},\n          because the value dtype is ${\n          tensor.dtype}, but TensorArray dtype is ${this.dtype}.`);\n    }\n\n    // Set the shape for the first time write to unknow shape tensor array\n    if (this.size() === 0 &&\n        (this.elementShape == null || this.elementShape.length === 0)) {\n      this.elementShape = tensor.shape;\n    }\n\n    assertShapesMatchAllowUndefinedSize(\n        this.elementShape, tensor.shape,\n        `TensorArray ${this.name}: Could not write to TensorArray index ${\n            index}.`);\n\n    if (t.read) {\n      throw new Error(\n          `TensorArray ${this.name}: Could not write to TensorArray index ${\n              index}, because it has already been read.`);\n    }\n\n    if (t.written) {\n      throw new Error(\n          `TensorArray ${this.name}: Could not write to TensorArray index ${\n              index}, because it has already been written.`);\n    }\n\n    t.tensor = tensor;\n    keep(tensor);\n    t.written = true;\n\n    this.tensors[index] = t;\n  }\n\n  /**\n   * Helper method to write multiple tensors to the specified indices.\n   */\n  writeMany(indices: number[], tensors: 
Tensor[]) {\n    if (indices.length !== tensors.length) {\n      throw new Error(\n          `TensorArray ${this.name}: could not write multiple tensors,` +\n          `because the index size: ${\n              indices.length} is not the same as tensors size: ${\n              tensors.length}.`);\n    }\n\n    indices.forEach((i, index) => this.write(i, tensors[index]));\n  }\n\n  /**\n   * Return selected values in the TensorArray as a packed Tensor. All of\n   * selected values must have been written and their shapes must all match.\n   * @param [indices] number[] Optional. Taking values in [0, max_value). If the\n   *    TensorArray is not dynamic, max_value=size(). If not specified returns\n   *    all tensors in the original order.\n   * @param [dtype]\n   */\n  gather(indices?: number[], dtype?: DataType): Tensor {\n    if (!!dtype && dtype !== this.dtype) {\n      throw new Error(`TensorArray dtype is ${\n          this.dtype} but gather requested dtype ${dtype}`);\n    }\n\n    if (!indices) {\n      indices = [];\n      for (let i = 0; i < this.size(); i++) {\n        indices.push(i);\n      }\n    } else {\n      indices = indices.slice(0, this.size());\n    }\n\n    if (indices.length === 0) {\n      return tensor([], [0].concat(this.elementShape));\n    }\n\n    // Read all the PersistentTensors into a vector to keep track of\n    // their memory.\n    const tensors = this.readMany(indices);\n\n    assertShapesMatchAllowUndefinedSize(\n        this.elementShape, tensors[0].shape, 'TensorArray shape mismatch: ');\n\n    return stack(tensors, 0);\n  }\n\n  /**\n   * Return the values in the TensorArray as a concatenated Tensor.\n   */\n  concat(dtype?: DataType): Tensor {\n    if (!!dtype && dtype !== this.dtype) {\n      throw new Error(`TensorArray dtype is ${\n          this.dtype} but concat requested dtype ${dtype}`);\n    }\n\n    if (this.size() === 0) {\n      return tensor([], [0].concat(this.elementShape));\n    }\n\n    const indices = [];\n    
for (let i = 0; i < this.size(); i++) {\n      indices.push(i);\n    }\n    // Collect all the tensors from the tensors array.\n    const tensors = this.readMany(indices);\n\n    assertShapesMatchAllowUndefinedSize(\n        this.elementShape, tensors[0].shape,\n        `TensorArray shape mismatch: tensor array shape (${\n            this.elementShape}) vs first tensor shape (${tensors[0].shape})`);\n\n    return concat(tensors, 0);\n  }\n\n  /**\n   * Scatter the values of a Tensor in specific indices of a TensorArray.\n   * @param indices nummber[] values in [0, max_value). If the\n   *    TensorArray is not dynamic, max_value=size().\n   * @param tensor Tensor input tensor.\n   */\n  scatter(indices: number[], tensor: Tensor) {\n    if (tensor.dtype !== this.dtype) {\n      throw new Error(`TensorArray dtype is ${\n          this.dtype} but tensor has dtype ${tensor.dtype}`);\n    }\n\n    if (indices.length !== tensor.shape[0]) {\n      throw new Error(`Expected len(indices) == tensor.shape[0], but saw: ${\n          indices.length} vs. ${tensor.shape[0]}`);\n    }\n\n    const maxIndex = Math.max(...indices);\n\n    if (!this.dynamicSize && maxIndex >= this.maxSize) {\n      throw new Error(\n          `Max index must be < array size (${maxIndex}  vs. 
${this.maxSize})`);\n    }\n\n    this.writeMany(indices, unstack(tensor, 0));\n  }\n\n  /**\n   * Split the values of a Tensor into the TensorArray.\n   * @param length number[] with the lengths to use when splitting value along\n   *    its first dimension.\n   * @param tensor Tensor, the tensor to split.\n   */\n  split(length: number[], tensor: Tensor) {\n    if (tensor.dtype !== this.dtype) {\n      throw new Error(`TensorArray dtype is ${\n          this.dtype} but tensor has dtype ${tensor.dtype}`);\n    }\n    let totalLength = 0;\n    const cumulativeLengths = length.map(len => {\n      totalLength += len;\n      return totalLength;\n    });\n\n    if (totalLength !== tensor.shape[0]) {\n      throw new Error(`Expected sum of lengths to be equal to\n          tensor.shape[0], but sum of lengths is\n        ${totalLength}, and tensor's shape is: ${tensor.shape}`);\n    }\n\n    if (!this.dynamicSize && length.length !== this.maxSize) {\n      throw new Error(\n          `TensorArray's size is not equal to the size of lengths (${\n              this.maxSize} vs. ${length.length}), ` +\n          'and the TensorArray is not marked as dynamically resizeable');\n    }\n\n    const elementPerRow = totalLength === 0 ? 0 : tensor.size / totalLength;\n    const tensors: Tensor[] = [];\n    tidy(() => {\n      tensor = reshape(tensor, [1, totalLength, elementPerRow]);\n      for (let i = 0; i < length.length; ++i) {\n        const previousLength = (i === 0) ? 0 : cumulativeLengths[i - 1];\n        const indices = [0, previousLength, 0];\n        const sizes = [1, length[i], elementPerRow];\n        tensors[i] = reshape(slice(tensor, indices, sizes), this.elementShape);\n      }\n      return tensors;\n    });\n    const indices = [];\n    for (let i = 0; i < length.length; i++) {\n      indices[i] = i;\n    }\n    this.writeMany(indices, tensors);\n  }\n}\n"]}