/**
|
* @license
|
* Copyright 2023 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
import { util, kernel_impls, KernelBackend, DataStorage, engine, env, backend_util, buffer, Abs, Complex, Identity, Real, Cast, Add, BitwiseAnd, Ceil, Equal, Exp, Expm1, Floor, FloorDiv, Greater, GreaterEqual, Less, LessEqual, Log, Maximum, Minimum, Multiply, Neg, NotEqual, Transpose, upcastType, Prod, tidy, reshape as reshape$1, broadcastTo, Rsqrt, TensorBuffer, Sigmoid, slice_util, Slice, Sqrt, SquaredDifference, StaticRegexReplace, Sub, registerBackend, Elu, LeakyRelu, Prelu, Relu, Relu6, Reshape, BatchMatMul, broadcast_util, _FusedMatMul, Acos, Acosh, AddN, All, Any, ArgMax, ArgMin, Asin, Asinh, Atan, Atan2, Atanh, AvgPool, AvgPool3D, AvgPool3DGrad, AvgPoolGrad, FusedBatchNorm, BatchToSpaceND, Bincount, BroadcastArgs, ClipByValue, ComplexAbs, Imag, Concat, Conv2D, Conv2DBackpropFilter, Conv2DBackpropInput, Conv3D, Conv3DBackpropFilterV2, Conv3DBackpropInputV2, Cos, Cosh, CropAndResize, Cumprod, Cumsum, DenseBincount, DepthToSpace, DepthwiseConv2dNative, DepthwiseConv2dNativeBackpropFilter, DepthwiseConv2dNativeBackpropInput, Diag, Dilation2D, Dilation2DBackpropFilter, Dilation2DBackpropInput, Draw, Sum, Einsum, EluGrad, Erf, ExpandDims, RealDiv, FFT, Fill, FlipLeftRight, FusedConv2D, FusedDepthwiseConv2D, GatherNd, GatherV2, IFFT, IsFinite, IsInf, IsNan, LinSpace, Log1p, LogicalAnd, LogicalNot, LogicalOr, LRN, LRNGrad, Max, MaxPool, MaxPool3D, MaxPool3DGrad, MaxPoolGrad, MaxPoolWithArgmax, Mean, Min, MirrorPad, Mod, Softmax, Multinomial, NonMaxSuppressionV3, NonMaxSuppressionV4, NonMaxSuppressionV5, OneHot, ZerosLike, OnesLike, Pack, PadV2, Pow, RaggedGather, RaggedRange, RaggedTensorToTensor, Range, Reciprocal, ResizeBilinear, ResizeBilinearGrad, ResizeNearestNeighbor, ResizeNearestNeighborGrad, Reverse, RotateWithOffset, Round, ScatterNd, SearchSorted, Select, Selu, Sign, Sin, Sinh, Softplus, SpaceToBatchND, SparseFillEmptyRows, SparseReshape, SparseSegmentMean, SparseSegmentSum, SparseToDense, SplitV, Square, Step, StridedSlice, StringNGrams, StringSplit, 
StringToHashBucketFast, Tan, Tanh, TensorScatterUpdate, Tile, TopK, Transform, Unique, Unpack, UnsortedSegmentSum, registerKernel } from '@tensorflow/tfjs-core';
|
import * as seedrandom from 'seedrandom';
|
|
/**
|
* @license
|
* Copyright 2019 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * Asserts that none of the given tensors has dtype 'complex64'.
 *
 * @param tensor A single tensor-like value or an array of them. Entries that
 *     are null or undefined are skipped.
 * @param opName Name interpolated into the assertion message.
 */
function assertNotComplex(tensor, opName) {
    const candidates = Array.isArray(tensor) ? tensor : [tensor];
    for (const candidate of candidates) {
        if (candidate == null) {
            continue;
        }
        util.assert(candidate.dtype !== 'complex64', () => `${opName} does not support complex64 tensors in the CPU backend.`);
    }
}
|
|
/**
|
* @license
|
* Copyright 2021 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
// Local alias for the `whereImpl` implementation exposed by `kernel_impls`.
const { whereImpl } = kernel_impls;
|
/**
 * CPU backend: a `KernelBackend` whose tensor payloads are kept in a
 * `DataStorage` (`this.data`). Each stored entry holds `values`, `dtype` and
 * a `refCount`; complex64 entries additionally reference their real and
 * imaginary parts through `complexTensorInfos`.
 */
class MathBackendCPU extends KernelBackend {
    constructor() {
        super();
        this.blockSize = 48;
        // Cleared by the first write(); gates the one-time Node.js hint.
        this.firstUse = true;
        this.data = new DataStorage(this, engine());
    }

    /** Returns the current value of the class-wide id counter and advances it. */
    nextDataId() {
        const id = MathBackendCPU.nextDataId;
        MathBackendCPU.nextDataId = id + 1;
        return id;
    }

    /**
     * Stores `values` under a freshly created dataId with a refCount of 1.
     * @param values The values to store.
     * @param shape Not read by this method.
     * @param dtype DType recorded alongside the values.
     * @return The new dataId object.
     */
    write(values, shape, dtype) {
        if (this.firstUse) {
            this.firstUse = false;
            if (env().get('IS_NODE')) {
                const hint = [
                    '\n============================\n',
                    'Hi, looks like you are running TensorFlow.js in ',
                    'Node.js. To speed things up dramatically, install our node ',
                    'backend, visit https://github.com/tensorflow/tfjs-node for more details. ',
                    '\n============================'
                ].join('');
                backend_util.warn(hint);
            }
        }
        const dataId = { id: this.nextDataId() };
        this.data.set(dataId, { values, dtype, refCount: 1 });
        return dataId;
    }

    /**
     * Create a data bucket in cpu backend.
     * @param shape Shape of the `TensorInfo`.
     * @param dtype DType of the `TensorInfo`.
     * @param values The value of the `TensorInfo` stored as a flattened array.
     *     For dtype 'string', a non-empty array whose first element is a
     *     string is encoded element-wise before it is stored.
     */
    makeTensorInfo(shape, dtype, values) {
        const holdsRawStrings = dtype === 'string' && values != null &&
            values.length > 0 && util.isString(values[0]);
        const payload = holdsRawStrings ?
            values.map(d => util.encodeString(d)) :
            values;
        const dataId = this.write(payload, shape, dtype);
        return { dataId, shape, dtype };
    }

    /** Return refCount of a `TensorData`, or 0 when the dataId is not stored. */
    refCount(dataId) {
        return this.data.has(dataId) ? this.data.get(dataId).refCount : 0;
    }

    /** Increase refCount of a `TensorData`. */
    incRef(dataId) {
        this.data.get(dataId).refCount++;
    }

    /** Decrease refCount of a `TensorData`; does nothing for unknown dataIds. */
    decRef(dataId) {
        if (!this.data.has(dataId)) {
            return;
        }
        this.data.get(dataId).refCount--;
    }

    /** Stores `values`, `dtype` and `refCount` under `dataId`; `shape` is not read. */
    move(dataId, values, shape, dtype, refCount) {
        this.data.set(dataId, { values, dtype, refCount });
    }

    numDataIds() {
        return this.data.numDataIds();
    }

    /** Async wrapper around `readSync`. */
    async read(dataId) {
        return this.readSync(dataId);
    }

    /**
     * Returns the values stored for `dataId`. For complex64 entries the real
     * and imaginary parts are read recursively and merged.
     */
    readSync(dataId) {
        const { dtype, complexTensorInfos, values } = this.data.get(dataId);
        if (dtype !== 'complex64') {
            return util.convertBackendValuesAndArrayBuffer(values, dtype);
        }
        const { real, imag } = complexTensorInfos;
        const realValues = this.readSync(real.dataId);
        const imagValues = this.readSync(imag.dataId);
        return backend_util.mergeRealAndImagArrays(realValues, imagValues);
    }

    /**
     * Builds a buffer from the data of tensor `t`. String tensors are decoded
     * element-wise first; any failure while doing so is reported as a decode
     * error.
     */
    bufferSync(t) {
        const data = this.readSync(t.dataId);
        if (t.dtype !== 'string') {
            return buffer(t.shape, t.dtype, data);
        }
        try {
            // Decode the bytes into string.
            const decoded = data.map(d => util.decodeString(d));
            return buffer(t.shape, t.dtype, decoded);
        }
        catch (_err) {
            throw new Error('Failed to decode encoded string bytes into utf-8');
        }
    }

    /** Stores `values` as a new tensor info and wraps it in an engine tensor. */
    makeOutput(values, shape, dtype) {
        const info = this.makeTensorInfo(shape, dtype, values);
        return engine().makeTensorFromTensorInfo(info, this);
    }

    /**
     * Decrements the refCount of `dataId` and removes the entry once it is no
     * longer referenced. The real and imaginary parts of a complex entry are
     * force-disposed together with it.
     * @param dataId
     * @param force Optional, remove the data regardless of refCount
     * @return false if the entry is still referenced and `force` is not set;
     *     true if it was removed or is not stored in this backend.
     */
    disposeData(dataId, force = false) {
        if (!this.data.has(dataId)) {
            return true;
        }
        const entry = this.data.get(dataId);
        entry.refCount--;
        if (!force && entry.refCount > 0) {
            return false;
        }
        const { complexTensorInfos } = entry;
        if (complexTensorInfos != null) {
            this.disposeData(complexTensorInfos.real.dataId, true);
            this.disposeData(complexTensorInfos.imag.dataId, true);
        }
        this.data.delete(dataId);
        return true;
    }

    disposeIntermediateTensorInfo(tensorInfo) {
        this.disposeData(tensorInfo.dataId);
    }

    /** Runs `f` and reports the elapsed time as `kernelMs`; `f` is not awaited. */
    async time(f) {
        const startedAt = util.now();
        f();
        return { kernelMs: util.now() - startedAt };
    }

    memory() {
        // Always flagged unreliable; the reason string explains why.
        return {
            unreliable: true,
            reasons: ['The reported memory is an upper bound. Due to automatic garbage ' +
                    'collection, the true allocated memory may be less.']
        };
    }

    /** Runs `whereImpl` on the values of `condition`; complex64 is rejected. */
    where(condition) {
        assertNotComplex([condition], 'where');
        return whereImpl(condition.shape, this.readSync(condition.dataId));
    }

    dispose() { }

    floatPrecision() {
        return 32;
    }

    /** Delegates to the `KernelBackend` implementation of `epsilon()`. */
    epsilon() {
        return super.epsilon();
    }
}
// Class-wide counter backing nextDataId(); shared by every instance.
MathBackendCPU.nextDataId = 0;
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the License);
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an AS IS BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * Computes the element-wise absolute value of `vals`.
 *
 * @param vals Indexable collection of numbers with a `length`.
 * @return A new `Float32Array` of the same length.
 */
function simpleAbsImpl(vals) {
    const magnitudes = new Float32Array(vals.length);
    let idx = 0;
    while (idx < vals.length) {
        magnitudes[idx] = Math.abs(vals[idx]);
        idx += 1;
    }
    return magnitudes;
}
|
/**
 * Kernel function for `Abs`: applies `simpleAbsImpl` to the stored values of
 * `x` and returns the result via `makeOutput`, keeping the shape and dtype
 * of `x`. complex64 inputs are rejected.
 *
 * @param args Object with `inputs.x` (the input tensor info) and `backend`.
 */
const abs = (args) => {
    const { x } = args.inputs;
    const cpuBackend = args.backend;
    assertNotComplex(x, 'abs');
    const values = cpuBackend.data.get(x.dataId).values;
    // simpleAbsImpl allocates its own output buffer, so no scratch array is
    // created here.
    const resultValues = simpleAbsImpl(values);
    return cpuBackend.makeOutput(resultValues, x.shape, x.dtype);
};
|
/** Kernel config pairing the `Abs` kernel name with the `abs` kernel function on the 'cpu' backend. */
const absConfig = {
    kernelName: Abs,
    backendName: 'cpu',
    kernelFunc: abs,
};
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * Builds a broadcasting element-wise implementation from a scalar binary op.
 *
 * @param op Function `(a, b) => value` applied to each pair of elements.
 * @return A function `(aShape, bShape, aVals, bVals, dtype)` returning
 *     `[resultValues, resultShape]`. The result buffer is obtained from
 *     `util.getTypedArrayFromDType(dtype, size)`.
 */
function createSimpleBinaryKernelImpl(op) {
    return (aShape, bShape, aVals, bVals, dtype) => {
        const outShape = backend_util.assertAndGetBroadcastShape(aShape, bShape);
        const outRank = outShape.length;
        const outStrides = util.computeStrides(outShape);
        const out = util.getTypedArrayFromDType(dtype, util.sizeFromShape(outShape));
        const aStrides = util.computeStrides(aShape);
        const bStrides = util.computeStrides(bShape);
        const aDimsToZero = backend_util.getBroadcastDims(aShape, outShape);
        const bDimsToZero = backend_util.getBroadcastDims(bShape, outShape);

        if (aDimsToZero.length === 0 && bDimsToZero.length === 0) {
            // Neither operand reports broadcast dims: operands are read by
            // flat index, wrapping around when one is shorter than the output.
            for (let outIdx = 0; outIdx < out.length; ++outIdx) {
                out[outIdx] = op(aVals[outIdx % aVals.length], bVals[outIdx % bVals.length]);
            }
            return [out, outShape];
        }

        // Maps an output coordinate to the flat offset of the operand element
        // feeding it; broadcast dims always read from position 0.
        const operandOffset = (outLoc, rank, strides, dimsToZero) => {
            const operandLoc = outLoc.slice(-rank);
            for (const dim of dimsToZero) {
                operandLoc[dim] = 0;
            }
            return util.locToIndex(operandLoc, rank, strides);
        };

        for (let outIdx = 0; outIdx < out.length; ++outIdx) {
            const outLoc = util.indexToLoc(outIdx, outRank, outStrides);
            const aOffset = operandOffset(outLoc, aShape.length, aStrides, aDimsToZero);
            const bOffset = operandOffset(outLoc, bShape.length, bStrides, bDimsToZero);
            out[outIdx] = op(aVals[aOffset], bVals[bOffset]);
        }
        return [out, outShape];
    };
}
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * Kernel function for `Complex`: creates a 'complex64' tensor info whose
 * stored entry references newly created float32 tensor infos for the real
 * and imaginary parts. The new parts reuse the value arrays of the inputs.
 *
 * @param args Object with `inputs.real`, `inputs.imag` and `backend`.
 * @return The tensor info of the complex tensor (shape taken from `real`).
 */
function complex(args) {
    const { backend, inputs: { real, imag } } = args;
    const realVals = backend.data.get(real.dataId).values;
    const imagVals = backend.data.get(imag.dataId).values;
    const complexInfo = backend.makeTensorInfo(real.shape, 'complex64');
    const complexData = backend.data.get(complexInfo.dataId);
    // The complex tensor owns the underlying real and imag tensorInfos, only
    // the complex tensor tracks refCount, when complexData is disposed the
    // underlying tensorData will be disposed.
    const realPart = backend.makeTensorInfo(real.shape, 'float32', realVals);
    const imagPart = backend.makeTensorInfo(imag.shape, 'float32', imagVals);
    complexData.complexTensorInfos = { real: realPart, imag: imagPart };
    return complexInfo;
}
|
/** Kernel config pairing the `Complex` kernel name with the `complex` kernel function on the 'cpu' backend. */
const complexConfig = {
    kernelName: Complex,
    backendName: 'cpu',
    kernelFunc: complex
};
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * Generates a tensorInfo with all zeros value.
 *
 * For 'complex64' the result is built with the `complex` kernel from two
 * float32 zero tensors. Those two temporaries are released before returning,
 * because `complex` registers its own tensor infos for the real and
 * imaginary parts.
 *
 * @param backend cpu backend.
 * @param shape Shape for the zeros tensor.
 * @param dtype Optional. If set, the result has this dtype. Defaults to
 *     'float32'.
 * @return The tensor info of the new zeros tensor.
 */
function zeros(backend, shape, dtype = 'float32') {
    if (dtype === 'complex64') {
        const real = zeros(backend, shape, 'float32');
        const imag = zeros(backend, shape, 'float32');
        const result = complex({ inputs: { real, imag }, backend });
        // Without this the two temporaries would stay registered in the
        // backend's data storage with nothing left to dispose them.
        backend.disposeIntermediateTensorInfo(real);
        backend.disposeIntermediateTensorInfo(imag);
        return result;
    }
    const values = util.makeZerosTypedArray(util.sizeFromShape(shape), dtype);
    return backend.makeTensorInfo(shape, dtype, values);
}
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * Kernel function for `Identity`: increments the reference count of the
 * input's data and returns a new tensor info that points at the same dataId.
 *
 * @param args Object with `inputs.x` and `backend`.
 * @return `{ dataId, shape, dtype }` copied from `x`.
 */
function identity(args) {
    const { backend, inputs: { x } } = args;
    const { dataId, shape, dtype } = x;
    backend.incRef(dataId);
    return { dataId, shape, dtype };
}
|
/** Kernel config pairing the `Identity` kernel name with the `identity` kernel function on the 'cpu' backend. */
const identityConfig = {
    kernelName: Identity,
    backendName: 'cpu',
    kernelFunc: identity
};
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * Kernel function for `Real`: returns a new tensor info holding the values
 * of the real part of a complex input.
 *
 * @param args Object with `inputs.input` (a tensor whose stored entry has
 *     `complexTensorInfos`) and `backend`.
 */
function real(args) {
    const { backend, inputs: { input } } = args;
    const { real: realInfo } = backend.data.get(input.dataId).complexTensorInfos;
    const { values } = backend.data.get(realInfo.dataId);
    // A separate tensor info is created (rather than handing out the part
    // owned by the complex tensor) so the values stay reachable after the
    // complex tensor, and with it its parts, has been disposed.
    return backend.makeTensorInfo(realInfo.shape, realInfo.dtype, values);
}
|
/** Kernel config pairing the `Real` kernel name with the `real` kernel function on the 'cpu' backend. */
const realConfig = {
    kernelName: Real,
    backendName: 'cpu',
    kernelFunc: real
};
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * Converts `values` to 'int32' or 'bool'.
 *
 * @param values Source values.
 * @param shape Shape of the source values.
 * @param inputType DType of the source values.
 * @param dtype Target dtype; only 'int32' and 'bool' are handled.
 * @return `[resultShape, resultDtype, resultValues]`.
 * @throws Error for any other target dtype.
 */
function castImpl(values, shape, inputType, dtype) {
    switch (dtype) {
        case 'int32':
            return [shape, 'int32', Int32Array.from(values)];
        case 'bool': {
            // This is essentially the result of notEqual(x, 0). We avoid using
            // kernel notEqual to avoid circular dependency, i.e. binary_utils ->
            // cast -> notEqual -> binary_utils.
            const zero = util.toTypedArray([0], inputType);
            const differsFromZero = createSimpleBinaryKernelImpl((a, b) => (a !== b) ? 1 : 0);
            const [boolValues, boolShape] = differsFromZero(shape, [], values, zero, 'bool');
            return [boolShape, 'bool', boolValues];
        }
        default:
            throw new Error(`Error in Cast: failed to cast ${inputType} to ${dtype}`);
    }
}
|
/**
 * Kernel function for `Cast`.
 *
 * - complex64 -> complex64 is an identity.
 * - other -> complex64 casts the input to float32 for the real part and
 *   pairs it with a zeros tensor as the imaginary part.
 * - complex64 -> other casts the real part.
 * - When `util.hasEncodingLoss` reports no loss, the input's data is shared
 *   and only the dtype of the returned tensor info differs.
 * - Everything else goes through `castImpl`.
 *
 * @param args Object with `inputs.x`, `backend` and `attrs.dtype`.
 */
function cast(args) {
    const { backend } = args;
    const { x } = args.inputs;
    const { dtype } = args.attrs;
    const sourceIsComplex = x.dtype === 'complex64';

    // Casting to complex64.
    if (dtype === 'complex64') {
        if (sourceIsComplex) {
            return identity({ inputs: { x }, backend });
        }
        const imagZeros = zeros(backend, x.shape, x.dtype);
        const realAsFloat = cast({ inputs: { x }, backend, attrs: { dtype: 'float32' } });
        const complexResult = complex({ inputs: { real: realAsFloat, imag: imagZeros }, backend });
        backend.disposeIntermediateTensorInfo(imagZeros);
        backend.disposeIntermediateTensorInfo(realAsFloat);
        return complexResult;
    }

    // Casting from complex64
    if (sourceIsComplex) {
        const realPart = real({ inputs: { input: x }, backend });
        const converted = cast({ inputs: { x: realPart }, backend, attrs: { dtype } });
        backend.disposeIntermediateTensorInfo(realPart);
        return converted;
    }

    if (util.hasEncodingLoss(x.dtype, dtype)) {
        const { values } = backend.data.get(x.dataId);
        const [outShape, outDtype, outValues] = castImpl(values, x.shape, x.dtype, dtype);
        return backend.makeTensorInfo(outShape, outDtype, outValues);
    }

    // We don't change the underlying data, since we cast to higher
    // precision.
    const alias = identity({ inputs: { x }, backend });
    return { dataId: alias.dataId, shape: alias.shape, dtype };
}
|
/** Kernel config pairing the `Cast` kernel name with the `cast` kernel function on the 'cpu' backend. */
const castConfig = {
    kernelName: Cast,
    backendName: 'cpu',
    kernelFunc: cast
};
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * Template that creates a `KernelFunc` for binary ops.
 *
 * @param name Kernel name. Used in the assertion message that rejects
 *     complex64 inputs when no `complexImpl` is supplied.
 * @param simpleImpl A `SimpleBinaryKernelImpl` for the kernel. Called as
 *     `(aShape, bShape, aVals, bVals, dtype)` and expected to return
 *     `[resultData, resultShape]`.
 * @param complexImpl Optional. If exists, represents a
 *     `ComplexBinaryKernelImpl` for the kernel, will be used when either
 *     input dtype is `complex64`. Called as
 *     `(aShape, bShape, aRealVals, aImagVals, bRealVals, bImagVals)` and
 *     expected to return `[realData, imagData, resultShape]`.
 * @param dtype Optional. If set, the result has this dtype. Otherwise, the
 *     result has the same dtype as the first input. This is mainly used in
 *     comparison kernels, such as Equal, Less, Greater, etc.
 * @return A kernel function taking `{ inputs: { a, b }, backend }`.
 */
function binaryKernelFunc(name, simpleImpl, complexImpl, dtype) {
    if (complexImpl == null) {
        return ({ inputs, backend }) => {
            const { a, b } = inputs;
            const cpuBackend = backend;
            assertNotComplex([a, b], name);
            const aVals = cpuBackend.data.get(a.dataId).values;
            const bVals = cpuBackend.data.get(b.dataId).values;
            // Each operand is decoded according to its OWN dtype. Gating the
            // decoding of `b` on `a.dtype` would pass undecoded bytes (or
            // mis-decoded numbers) to the op when the dtypes differ.
            const decodedAVals = a.dtype === 'string' ?
                backend_util.fromUint8ToStringArray(aVals) :
                aVals;
            const decodedBVals = b.dtype === 'string' ?
                backend_util.fromUint8ToStringArray(bVals) :
                bVals;
            const $dtype = dtype || a.dtype;
            const [resultData, resultShape] = simpleImpl(a.shape, b.shape, decodedAVals, decodedBVals, $dtype);
            return cpuBackend.makeTensorInfo(resultShape, $dtype, resultData);
        };
    }
    return ({ inputs, backend }) => {
        const { a, b } = inputs;
        const cpuBackend = backend;
        if (a.dtype === 'complex64' || b.dtype === 'complex64') {
            // Bring both operands to complex64 so their real and imaginary
            // parts can be read uniformly.
            const $aComplex = cast({ inputs: { x: a }, backend: cpuBackend, attrs: { dtype: 'complex64' } });
            const $aComplexVals = cpuBackend.data.get($aComplex.dataId);
            const aReal = $aComplexVals.complexTensorInfos.real;
            const aImag = $aComplexVals.complexTensorInfos.imag;
            const aRealVals = cpuBackend.data.get(aReal.dataId).values;
            const aImagVals = cpuBackend.data.get(aImag.dataId).values;
            const $bComplex = cast({ inputs: { x: b }, backend: cpuBackend, attrs: { dtype: 'complex64' } });
            const $bComplexVals = cpuBackend.data.get($bComplex.dataId);
            const bReal = $bComplexVals.complexTensorInfos.real;
            const bImag = $bComplexVals.complexTensorInfos.imag;
            const bRealVals = cpuBackend.data.get(bReal.dataId).values;
            const bImagVals = cpuBackend.data.get(bImag.dataId).values;
            const [resultRealData, resultImagData, resultShape] = complexImpl(a.shape, b.shape, aRealVals, aImagVals, bRealVals, bImagVals);
            const resultReal = cpuBackend.makeTensorInfo(resultShape, 'float32', resultRealData);
            const resultImag = cpuBackend.makeTensorInfo(resultShape, 'float32', resultImagData);
            const result = complex({ inputs: { real: resultReal, imag: resultImag }, backend: cpuBackend });
            // Release every temporary created above; only `result` is
            // handed back to the caller.
            cpuBackend.disposeIntermediateTensorInfo($aComplex);
            cpuBackend.disposeIntermediateTensorInfo($bComplex);
            cpuBackend.disposeIntermediateTensorInfo(resultReal);
            cpuBackend.disposeIntermediateTensorInfo(resultImag);
            return result;
        }
        else {
            const aVals = cpuBackend.data.get(a.dataId).values;
            const bVals = cpuBackend.data.get(b.dataId).values;
            const $dtype = dtype || a.dtype;
            const [resultData, resultShape] = simpleImpl(a.shape, b.shape, aVals, bVals, $dtype);
            return cpuBackend.makeTensorInfo(resultShape, $dtype, resultData);
        }
    };
}
|
/**
 * Template that creates the complex type implementation for binary ops.
 * Supports broadcast.
 *
 * @param op Function `(aReal, aImag, bReal, bImag) => ({ real, imag })`
 *     applied to each pair of complex elements.
 * @return A function
 *     `(aShape, bShape, aRealVals, aImagVals, bRealVals, bImagVals)`
 *     returning `[resultRealVals, resultImagVals, resultShape]`, where both
 *     value arrays are float32.
 */
function createComplexBinaryKernelImpl(op) {
    return (aShape, bShape, aRealVals, aImagVals, bRealVals, bImagVals) => {
        const resultShape = backend_util.assertAndGetBroadcastShape(aShape, bShape);
        const resultSize = util.sizeFromShape(resultShape);
        const resultRank = resultShape.length;
        const resultStrides = util.computeStrides(resultShape);
        const resultRealVals = util.getTypedArrayFromDType('float32', resultSize);
        const resultImagVals = util.getTypedArrayFromDType('float32', resultSize);
        const aBroadcastDims = backend_util.getBroadcastDims(aShape, resultShape);
        const bBroadcastDims = backend_util.getBroadcastDims(bShape, resultShape);
        // Interleaved [real0, imag0, real1, imag1, ...]; element k lives at
        // offsets 2k and 2k + 1.
        const aVals = backend_util.mergeRealAndImagArrays(aRealVals, aImagVals);
        const bVals = backend_util.mergeRealAndImagArrays(bRealVals, bImagVals);
        const aRank = aShape.length;
        const aStrides = util.computeStrides(aShape);
        const bRank = bShape.length;
        const bStrides = util.computeStrides(bShape);
        if (aBroadcastDims.length + bBroadcastDims.length === 0) {
            // Wrap on the number of complex ELEMENTS. Wrapping on the length
            // of the interleaved arrays (twice as long) would index past the
            // end whenever an operand has fewer elements than the result.
            const aSize = aRealVals.length;
            const bSize = bRealVals.length;
            for (let i = 0; i < resultRealVals.length; i++) {
                const aIdx = i % aSize;
                const bIdx = i % bSize;
                const result = op(aVals[aIdx * 2], aVals[aIdx * 2 + 1], bVals[bIdx * 2], bVals[bIdx * 2 + 1]);
                resultRealVals[i] = result.real;
                resultImagVals[i] = result.imag;
            }
        }
        else {
            for (let i = 0; i < resultRealVals.length; i++) {
                const loc = util.indexToLoc(i, resultRank, resultStrides);
                // Broadcast dims of an operand always read from position 0.
                const aLoc = loc.slice(-aRank);
                aBroadcastDims.forEach(d => aLoc[d] = 0);
                const aIndex = util.locToIndex(aLoc, aRank, aStrides);
                const bLoc = loc.slice(-bRank);
                bBroadcastDims.forEach(d => bLoc[d] = 0);
                const bIndex = util.locToIndex(bLoc, bRank, bStrides);
                const opResult = op(aVals[aIndex * 2], aVals[aIndex * 2 + 1], bVals[bIndex * 2], bVals[bIndex * 2 + 1]);
                resultRealVals[i] = opResult.real;
                resultImagVals[i] = opResult.imag;
            }
        }
        return [resultRealVals, resultImagVals, resultShape];
    };
}
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/** Broadcasting element-wise `a + b` for non-complex value arrays. */
const addImpl = createSimpleBinaryKernelImpl((a, b) => a + b);
/** Broadcasting element-wise addition of complex values: real and imaginary parts are summed separately. */
const addComplexImpl = createComplexBinaryKernelImpl((aReal, aImag, bReal, bImag) => ({
    real: aReal + bReal,
    imag: aImag + bImag
}));
/** Kernel function for `Add`, built from the two implementations above. */
const add = binaryKernelFunc(Add, addImpl, addComplexImpl);
|
/** Kernel config pairing the `Add` kernel name with the `add` kernel function on the 'cpu' backend. */
const addConfig = {
    kernelName: Add,
    backendName: 'cpu',
    kernelFunc: add
};
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * Accumulates, per value of `xVals`, either a count or the matching weight.
 *
 * @param xVals Bin indices. Values `>= size` are ignored.
 * @param weightsVals Weights aligned with `xVals`; only read when
 *     `weightsShape` describes a non-empty tensor.
 * @param weightsDtype DType used to allocate the output.
 * @param weightsShape Shape of the weights.
 * @param size Number of output bins.
 * @return The output values, allocated by `util.makeZerosTypedArray`.
 * @throws Error if a value in `xVals` is negative.
 */
function bincountImpl(xVals, weightsVals, weightsDtype, weightsShape, size) {
    const hasWeights = util.sizeFromShape(weightsShape) > 0;
    const bins = util.makeZerosTypedArray(size, weightsDtype);
    for (let idx = 0; idx < xVals.length; idx++) {
        const bin = xVals[idx];
        if (bin < 0) {
            throw new Error('Input x must be non-negative!');
        }
        if (bin >= size) {
            continue;
        }
        bins[bin] += hasWeights ? weightsVals[idx] : 1;
    }
    return bins;
}
|
/**
 * Row-wise bincount over a 2-D buffer.
 *
 * @param xBuf Buffer of bin indices; `shape[0]` rows by `shape[1]` columns.
 *     Values `>= size` are ignored.
 * @param weightsBuf Buffer of weights read at the same `(row, col)` when its
 *     `size` is greater than 0. Its `dtype` is used for the output.
 * @param size Number of bins per row.
 * @param binaryOutput When true, a bin is set to 1 if it occurs at all
 *     instead of being accumulated.
 * @return An output buffer of shape `[rows, size]`.
 * @throws Error if a value in `xBuf` is negative.
 */
function bincountReduceImpl(xBuf, weightsBuf, size, binaryOutput = false) {
    const [rowCount, colCount] = xBuf.shape;
    const result = buffer([rowCount, size], weightsBuf.dtype);
    for (let row = 0; row < rowCount; row++) {
        for (let col = 0; col < colCount; col++) {
            const bin = xBuf.get(row, col);
            if (bin < 0) {
                throw new Error('Input x must be non-negative!');
            }
            if (bin >= size) {
                continue;
            }
            if (binaryOutput) {
                result.set(1, row, bin);
                continue;
            }
            result.set(result.get(row, bin) + (weightsBuf.size > 0 ? weightsBuf.get(row, col) : 1), row, bin);
        }
    }
    return result;
}
|
|
/**
|
* @license
|
* Copyright 2023 Google LLC.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/** Broadcasting element-wise `a & b`. */
const bitwiseAndImpl = createSimpleBinaryKernelImpl((a, b) => a & b);
/** Kernel function for `BitwiseAnd`; no complex implementation is supplied. */
const bitwiseAnd = binaryKernelFunc(BitwiseAnd, bitwiseAndImpl);
|
/** Kernel config pairing the `BitwiseAnd` kernel name with the `bitwiseAnd` kernel function on the 'cpu' backend. */
const bitwiseAndConfig = {
    kernelName: BitwiseAnd,
    backendName: 'cpu',
    kernelFunc: bitwiseAnd
};
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * Builds an element-wise implementation from a scalar unary op.
 *
 * @param op Function `(value, attrs) => result` applied to every element.
 * @return A function `(values, dtype, attrs)` that returns a new array,
 *     obtained from `util.getArrayFromDType(dtype, values.length)`, holding
 *     the per-element results.
 */
function createSimpleUnaryImpl(op) {
    return (values, dtype, attrs) => {
        const mapped = util.getArrayFromDType(dtype, values.length);
        let idx = 0;
        while (idx < values.length) {
            mapped[idx] = op(values[idx], attrs);
            idx += 1;
        }
        return mapped;
    };
}
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * Template that creates a `KernelFunc` for unary ops.
 *
 * The per-element `op` is first wrapped by `createSimpleUnaryImpl`; the
 * resulting implementation is then handed to `unaryKernelFuncFromImpl`.
 *
 * @param name Kernel name.
 * @param op A `SimpleUnaryOperation` for the kernel.
 * @param dtype Optional. If set, the result has this dtype. Otherwise, the
 *     result has the same dtype as the input. This is mainly used in certain
 *     kernels that return bool type, such as isFinite, isInf, etc.
 * @return Whatever `unaryKernelFuncFromImpl` returns for these arguments.
 */
function unaryKernelFunc(name, op, dtype) {
    return unaryKernelFuncFromImpl(name, createSimpleUnaryImpl(op), dtype);
}
|
/**
 * Template that creates a `KernelFunc` for unary ops from the given
 * `SimpleUnaryImpl`.
 *
 * The returned kernel reads the values stored for `inputs.x` in
 * `backend.data`, decodes them with `backend_util.fromUint8ToStringArray`
 * when the input dtype is 'string', runs `unaryImpl` on them and wraps the
 * result with `backend.makeTensorInfo` using the input's shape.
 *
 * @param name Kernel name.
 * @param unaryImpl A `SimpleUnaryImpl` that implements the op.
 * @param dtype Optional. If set, the result has this dtype. Otherwise, the
 *     result has the same dtype as the input. This is mainly used in certain
 *     kernels that return bool type, such as isFinite, isInf, etc.
 * @throws Error when the input dtype is 'string' but the stored values are
 *     not an `Array`.
 */
function unaryKernelFuncFromImpl(name, unaryImpl, dtype) {
    const kernelFunc = ({ inputs, attrs, backend }) => {
        const input = inputs.x;
        assertNotComplex(input, name);
        const rawValues = backend.data.get(input.dataId).values;
        const isString = input.dtype === 'string';
        if (isString && !Array.isArray(rawValues)) {
            throw new Error('String tensor\'s value was not an instance of Array');
        }
        // String tensors are stored encoded; every other dtype is used as-is.
        const decoded = isString ?
            backend_util.fromUint8ToStringArray(rawValues) :
            rawValues;
        const outDtype = dtype || input.dtype;
        const outValues = unaryImpl(decoded, outDtype, attrs);
        return backend.makeTensorInfo(input.shape, outDtype, outValues);
    };
    return kernelFunc;
}
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/** Per-element implementation that applies `Math.ceil` to every value. */
const ceilImpl = createSimpleUnaryImpl((value) => {
    return Math.ceil(value);
});
/** Kernel function for `Ceil`, built from `ceilImpl`. */
const ceil = unaryKernelFuncFromImpl(Ceil, ceilImpl);
/** Config object tying the `Ceil` kernel name to the 'cpu' backend. */
const ceilConfig = {
    kernelName: Ceil,
    backendName: 'cpu',
    kernelFunc: ceil,
};
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * Concatenates the values of several inputs into one output array.
 *
 * @param inputs Array of `{vals, shape}` records.
 * @param outShape Shape of the result; its size determines the length of the
 *     output array and `outShape[1]` is used as the output row width.
 * @param dtype Dtype of the output array. For 'string', each input's `vals`
 *     are decoded with `backend_util.fromUint8ToStringArray` first.
 * @param simplyConcat When truthy (and `dtype` is not 'string') the inputs
 *     are copied back to back; otherwise every input is read as a
 *     `[shape[0], shape[1]]` matrix and the matrices are placed side by side
 *     along the column dimension.
 * @return The filled output array.
 */
function concatImpl(inputs, outShape, dtype, simplyConcat) {
    const outVals = util.getArrayFromDType(dtype, util.sizeFromShape(outShape));
    const isString = dtype === 'string';
    if (simplyConcat && !isString) {
        // Fast path: bulk copy each input with the array's own set().
        let writePos = 0;
        for (const input of inputs) {
            const size = util.sizeFromShape(input.shape);
            outVals.set(input.vals, writePos);
            writePos += size;
        }
        return outVals;
    }
    let colStart = 0;
    for (const input of inputs) {
        const source = isString ?
            backend_util.fromUint8ToStringArray(input.vals) :
            input.vals;
        const numRows = input.shape[0];
        const numCols = input.shape[1];
        let readPos = 0;
        for (let r = 0; r < numRows; ++r) {
            const rowStart = r * outShape[1] + colStart;
            for (let c = 0; c < numCols; ++c) {
                outVals[rowStart + c] = source[readPos];
                readPos += 1;
            }
        }
        colStart += numCols;
    }
    return outVals;
}
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/** Element-wise comparison: 1 when the operands are strictly equal, else 0. */
const equalImpl = createSimpleBinaryKernelImpl((lhs, rhs) => {
    if (lhs === rhs) {
        return 1;
    }
    return 0;
});
/** Kernel function for `Equal`, built from `equalImpl`. */
const equal = binaryKernelFunc(Equal, equalImpl, null /* complexImpl */, 'bool');
/** Config object tying the `Equal` kernel name to the 'cpu' backend. */
const equalConfig = {
    kernelName: Equal,
    backendName: 'cpu',
    kernelFunc: equal
};
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/** Per-element implementation that applies `Math.exp` to every value. */
const expImpl = createSimpleUnaryImpl((value) => {
    return Math.exp(value);
});
/** Kernel function for `Exp`; 'float32' is passed as the result dtype. */
const exp = unaryKernelFuncFromImpl(Exp, expImpl, 'float32');
/** Config object tying the `Exp` kernel name to the 'cpu' backend. */
const expConfig = {
    kernelName: Exp,
    backendName: 'cpu',
    kernelFunc: exp,
};
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/** Per-element implementation that applies `Math.expm1` to every value. */
const expm1Impl = createSimpleUnaryImpl((value) => {
    return Math.expm1(value);
});
/** Kernel function for `Expm1`, built from `expm1Impl`. */
const expm1 = unaryKernelFuncFromImpl(Expm1, expm1Impl);
/** Config object tying the `Expm1` kernel name to the 'cpu' backend. */
const expm1Config = {
    kernelName: Expm1,
    backendName: 'cpu',
    kernelFunc: expm1,
};
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/** Per-element implementation that applies `Math.floor` to every value. */
const floorImpl = createSimpleUnaryImpl((value) => {
    return Math.floor(value);
});
/** Kernel function for `Floor`, built from `floorImpl`. */
const floor = unaryKernelFuncFromImpl(Floor, floorImpl);
/** Config object tying the `Floor` kernel name to the 'cpu' backend. */
const floorConfig = {
    kernelName: Floor,
    backendName: 'cpu',
    kernelFunc: floor,
};
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/** Element-wise operation: the quotient `a / b` rounded down by `Math.floor`. */
const floorDivImpl = createSimpleBinaryKernelImpl((dividend, divisor) => {
    const quotient = dividend / divisor;
    return Math.floor(quotient);
});
/** Kernel function for `FloorDiv`, built from `floorDivImpl`. */
const floorDiv = binaryKernelFunc(FloorDiv, floorDivImpl, null /* complexImpl */, 'int32');
/** Config object tying the `FloorDiv` kernel name to the 'cpu' backend. */
const floorDivConfig = {
    kernelName: FloorDiv,
    backendName: 'cpu',
    kernelFunc: floorDiv
};
|
|
/**
|
* @license
|
* Copyright 2021 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * Gathers `numSlices` slices of `sliceSize` elements each out of `paramsBuf`.
 *
 * For every slice, `sliceRank` consecutive entries of `indicesData` are
 * combined with `strides` into one flat slice index; the elements of that
 * slice are then read through `paramsBuf.get(...paramsBuf.indexToLoc(i))`.
 *
 * @param indicesData Flat array holding `sliceRank` index components per
 *     slice.
 * @param paramsBuf Buffer to read from (must offer `get` and `indexToLoc`).
 * @param dtype Dtype of the output buffer.
 * @param numSlices Number of slices to gather.
 * @param sliceRank Number of index components per slice.
 * @param sliceSize Number of elements per slice.
 * @param strides Multiplier for each index component.
 * @param paramsShape Only used in the error message.
 * @param paramsSize Total element count of the params; `paramsSize /
 *     sliceSize` is the exclusive upper bound for a flat slice index.
 * @return A buffer created by `buffer([numSlices, sliceSize], dtype)`.
 * @throws Error if a flat slice index is negative or not below the bound.
 */
function gatherNdImpl(indicesData, paramsBuf, dtype, numSlices, sliceRank, sliceSize, strides, paramsShape, paramsSize) {
    const outBuf = buffer([numSlices, sliceSize], dtype);
    const sliceCount = paramsSize / sliceSize;
    for (let slice = 0; slice < numSlices; slice++) {
        const indexStart = slice * sliceRank;
        const index = [];
        let flattenIndex = 0;
        for (let d = 0; d < sliceRank; d++) {
            const dim = indicesData[indexStart + d];
            index.push(dim);
            flattenIndex += dim * strides[d];
        }
        if (flattenIndex < 0 || flattenIndex >= sliceCount) {
            throw new Error(`Invalid indices: ${index} does not index into ${paramsShape}`);
        }
        const srcStart = flattenIndex * sliceSize;
        const dstStart = slice * sliceSize;
        for (let k = 0; k < sliceSize; k++) {
            const srcLoc = paramsBuf.indexToLoc(srcStart + k);
            outBuf.values[dstStart + k] = paramsBuf.get(...srcLoc);
        }
    }
    return outBuf;
}
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * Gathers elements of `xBuf` at positions given by `indicesBuf`.
 *
 * For each output element its location is taken from the output buffer;
 * component 0 is used as the batch index and component 2 as the position
 * inside `indicesBuf`. Component 2 is then replaced by the looked-up index
 * value and the resulting location is read from `xBuf`.
 *
 * @param xBuf Buffer to gather from.
 * @param indicesBuf Buffer addressed as `[batch, position]`.
 * @param flattenOutputShape Shape of the output buffer.
 * @return A buffer created by `buffer(flattenOutputShape, xBuf.dtype)`.
 */
function gatherV2Impl(xBuf, indicesBuf, flattenOutputShape) {
    const outBuf = buffer(flattenOutputShape, xBuf.dtype);
    for (let outIdx = 0; outIdx < outBuf.size; ++outIdx) {
        // Work on a copy of the output location.
        const srcLoc = outBuf.indexToLoc(outIdx).slice();
        const lookupIdx = indicesBuf.locToIndex([srcLoc[0], srcLoc[2]]);
        srcLoc[2] = indicesBuf.values[lookupIdx];
        const srcIdx = xBuf.locToIndex(srcLoc);
        const inBounds = srcIdx >= 0 && srcIdx < xBuf.values.length;
        if (inBounds) {
            outBuf.values[outIdx] = xBuf.values[srcIdx];
        }
        // Otherwise the index is out of bounds: keep the default value that
        // the freshly created output buffer already holds.
    }
    return outBuf;
}
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/** Element-wise comparison: 1 when `a > b`, else 0. */
const greaterImpl = createSimpleBinaryKernelImpl((lhs, rhs) => {
    if (lhs > rhs) {
        return 1;
    }
    return 0;
});
/** Kernel function for `Greater`, built from `greaterImpl`. */
const greater = binaryKernelFunc(Greater, greaterImpl, null /* complexImpl */, 'bool');
/** Config object tying the `Greater` kernel name to the 'cpu' backend. */
const greaterConfig = {
    kernelName: Greater,
    backendName: 'cpu',
    kernelFunc: greater
};
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/** Element-wise comparison: 1 when `a >= b`, else 0. */
const greaterEqualImpl = createSimpleBinaryKernelImpl((lhs, rhs) => {
    if (lhs >= rhs) {
        return 1;
    }
    return 0;
});
/** Kernel function for `GreaterEqual`, built from `greaterEqualImpl`. */
const greaterEqual = binaryKernelFunc(GreaterEqual, greaterEqualImpl, null /* complexImpl */, 'bool');
/** Config object tying the `GreaterEqual` kernel name to the 'cpu' backend. */
const greaterEqualConfig = {
    kernelName: GreaterEqual,
    backendName: 'cpu',
    kernelFunc: greaterEqual
};
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/** Element-wise comparison: 1 when `a < b`, else 0. */
const lessImpl = createSimpleBinaryKernelImpl((lhs, rhs) => {
    if (lhs < rhs) {
        return 1;
    }
    return 0;
});
/** Kernel function for `Less`, built from `lessImpl`. */
const less = binaryKernelFunc(Less, lessImpl, null /* complexImpl */, 'bool');
/** Config object tying the `Less` kernel name to the 'cpu' backend. */
const lessConfig = {
    kernelName: Less,
    backendName: 'cpu',
    kernelFunc: less
};
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/** Element-wise comparison: 1 when `a <= b`, else 0. */
const lessEqualImpl = createSimpleBinaryKernelImpl((lhs, rhs) => {
    if (lhs <= rhs) {
        return 1;
    }
    return 0;
});
/** Kernel function for `LessEqual`, built from `lessEqualImpl`. */
const lessEqual = binaryKernelFunc(LessEqual, lessEqualImpl, null /* complexImpl */, 'bool');
/** Config object tying the `LessEqual` kernel name to the 'cpu' backend. */
const lessEqualConfig = {
    kernelName: LessEqual,
    backendName: 'cpu',
    kernelFunc: lessEqual
};
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * Generates `num` evenly spaced float32 values starting at `start`, using a
 * step of `(stop - start) / (num - 1)`.
 *
 * Every element is computed directly as `start + i * step`. Deriving each
 * element from the previously stored (already float32-rounded) element lets
 * the rounding error grow with the index, so that e.g. the last of 11 values
 * between 0 and 1 is no longer exactly 1.
 *
 * @param start First value of the sequence.
 * @param stop Value the sequence is stepped towards.
 * @param num Number of values to generate.
 * @return The array created by `util.makeZerosTypedArray(num, 'float32')`,
 *     filled with the sequence.
 */
function linSpaceImpl(start, stop, num) {
    const step = (stop - start) / (num - 1);
    const values = util.makeZerosTypedArray(num, 'float32');
    // Written separately: for num === 1 the step is NaN or infinite, and the
    // single element must still be exactly `start`.
    values[0] = start;
    for (let i = 1; i < values.length; i++) {
        values[i] = start + i * step;
    }
    return values;
}
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/** Per-element implementation that applies `Math.log` to every value. */
const logImpl = createSimpleUnaryImpl((value) => {
    return Math.log(value);
});
/** Kernel function for `Log`, built from `logImpl`. */
const log = unaryKernelFuncFromImpl(Log, logImpl);
/** Config object tying the `Log` kernel name to the 'cpu' backend. */
const logConfig = {
    kernelName: Log,
    backendName: 'cpu',
    kernelFunc: log,
};
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * Computes the maximum of each consecutive run of `reduceSize` values.
 *
 * A NaN inside a run replaces the running maximum, and no later non-NaN
 * value can replace it again (comparisons with NaN are always false), so the
 * result for that run is NaN.
 *
 * @param aVals Flat input values; output element `i` covers
 *     `aVals[i * reduceSize]` up to `aVals[(i + 1) * reduceSize - 1]`.
 * @param reduceSize Number of input values per output element.
 * @param outShape Shape whose size gives the number of output elements.
 * @param dtype Dtype of the output array.
 * @return The array obtained from `util.getTypedArrayFromDType`, filled with
 *     the maxima.
 */
function maxImpl(aVals, reduceSize, outShape, dtype) {
    const vals = util.getTypedArrayFromDType(dtype, util.sizeFromShape(outShape));
    for (let out = 0; out < vals.length; ++out) {
        const start = out * reduceSize;
        const end = start + reduceSize;
        let best = aVals[start];
        for (let pos = start; pos < end; ++pos) {
            const candidate = aVals[pos];
            if (candidate > best || Number.isNaN(candidate)) {
                best = candidate;
            }
        }
        vals[out] = best;
    }
    return vals;
}
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/** Element-wise operation: `Math.max` of the two operand values. */
const maximumImpl = createSimpleBinaryKernelImpl((lhs, rhs) => {
    return Math.max(lhs, rhs);
});
/** Kernel function for `Maximum`, built from `maximumImpl`. */
const maximum = binaryKernelFunc(Maximum, maximumImpl);
/** Config object tying the `Maximum` kernel name to the 'cpu' backend. */
const maximumConfig = {
    kernelName: Maximum,
    backendName: 'cpu',
    kernelFunc: maximum
};
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/** Element-wise operation: `Math.min` of the two operand values. */
const minimumImpl = createSimpleBinaryKernelImpl((lhs, rhs) => {
    return Math.min(lhs, rhs);
});
/** Kernel function for `Minimum`, built from `minimumImpl`. */
const minimum = binaryKernelFunc(Minimum, minimumImpl);
/** Config object tying the `Minimum` kernel name to the 'cpu' backend. */
const minimumConfig = {
    kernelName: Minimum,
    backendName: 'cpu',
    kernelFunc: minimum
};
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/** Element-wise operation: product of the two operand values. */
const multiplyImpl = createSimpleBinaryKernelImpl((lhs, rhs) => {
    return lhs * rhs;
});
/**
 * Complex product: (aReal + i*aImag) * (bReal + i*bImag), returned as a
 * `{real, imag}` record.
 */
const multiplyComplexImpl = createComplexBinaryKernelImpl((aReal, aImag, bReal, bImag) => {
    const real = aReal * bReal - aImag * bImag;
    const imag = aReal * bImag + aImag * bReal;
    return { real, imag };
});
/** Kernel function for `Multiply`, built from both implementations above. */
const multiply = binaryKernelFunc(Multiply, multiplyImpl, multiplyComplexImpl);
/** Config object tying the `Multiply` kernel name to the 'cpu' backend. */
const multiplyConfig = {
    kernelName: Multiply,
    backendName: 'cpu',
    kernelFunc: multiply
};
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * Negates values by multiplying them with a scalar -1 through
 * `multiplyImpl`.
 *
 * @param xVals Values to negate.
 * @param xShape Shape belonging to `xVals`.
 * @param xDtype Dtype used both for the -1 scalar and for the multiplication.
 * @return Whatever `multiplyImpl` returns; the caller in this file
 *     destructures it as `[values, shape]`.
 */
function negImpl(xVals, xShape, xDtype) {
    // The scalar is the first operand and is given an empty shape.
    return multiplyImpl([], xShape, util.createScalarValue(-1, xDtype), xVals, xDtype);
}
|
/**
 * Kernel function for `Neg`.
 *
 * Reads the values stored for `inputs.x` in `backend.data`, negates them
 * with `negImpl` and wraps the result with `backend.makeTensorInfo`, keeping
 * the input dtype.
 *
 * @param args Record with `inputs` (holding `x`) and `backend`.
 * @return The value returned by `backend.makeTensorInfo`.
 */
function neg(args) {
    const { inputs: { x }, backend } = args;
    assertNotComplex(x, 'neg');
    const { values } = backend.data.get(x.dataId);
    const [negated, outShape] = negImpl(values, x.shape, x.dtype);
    return backend.makeTensorInfo(outShape, x.dtype, negated);
}
/** Config object tying the `Neg` kernel name to the 'cpu' backend. */
const negConfig = {
    kernelName: Neg,
    backendName: 'cpu',
    kernelFunc: neg
};
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/** Element-wise comparison: 1 when the operands are not strictly equal, else 0. */
const notEqualImpl = createSimpleBinaryKernelImpl((lhs, rhs) => {
    if (lhs !== rhs) {
        return 1;
    }
    return 0;
});
/** Kernel function for `NotEqual`, built from `notEqualImpl`. */
const notEqual = binaryKernelFunc(NotEqual, notEqualImpl, null /* complexOp */, 'bool');
/** Config object tying the `NotEqual` kernel name to the 'cpu' backend. */
const notEqualConfig = {
    kernelName: NotEqual,
    backendName: 'cpu',
    kernelFunc: notEqual
};
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * Permutes the dimensions of a flat value array.
 *
 * Each input element's location is computed from its flat index; output
 * dimension `d` takes the coordinate of input dimension `perm[d]`, and the
 * element is written to the flat position of that permuted location.
 *
 * @param xVals Flat input values.
 * @param xShape Shape of the input.
 * @param dtype Dtype of the result array.
 * @param perm Permutation: `perm[d]` is the input dimension used for output
 *     dimension `d`.
 * @param newShape Shape of the result.
 * @return The array obtained from `util.getTypedArrayFromDType`, filled with
 *     the permuted values.
 */
function transposeImpl(xVals, xShape, dtype, perm, newShape) {
    const rank = xShape.length;
    const numElements = util.sizeFromShape(xShape);
    const srcStrides = util.computeStrides(xShape);
    const dstStrides = util.computeStrides(newShape);
    const result = util.getTypedArrayFromDType(dtype, util.sizeFromShape(newShape));
    for (let srcIndex = 0; srcIndex < numElements; ++srcIndex) {
        const srcLoc = util.indexToLoc(srcIndex, rank, srcStrides);
        // Permute location.
        const dstLoc = Array.from({ length: srcLoc.length }, (_, axis) => srcLoc[perm[axis]]);
        const dstIndex = util.locToIndex(dstLoc, rank, dstStrides);
        result[dstIndex] = xVals[srcIndex];
    }
    return result;
}
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * Kernel function for `Transpose`.
 *
 * Derives the output shape by picking `x.shape[perm[d]]` for every output
 * dimension `d`, permutes the stored values with `transposeImpl` and writes
 * the result to the backend.
 *
 * @param args Record with `inputs` (holding `x`), `attrs` (holding `perm`)
 *     and `backend`.
 * @return `{dataId, shape, dtype}` describing the written result; the dtype
 *     is the input's dtype.
 */
function transpose(args) {
    const { inputs, attrs, backend } = args;
    const { x } = inputs;
    const { perm } = attrs;
    assertNotComplex(x, 'transpose');
    const newShape = Array.from({ length: x.shape.length }, (_, axis) => x.shape[perm[axis]]);
    const { values } = backend.data.get(x.dataId);
    const permuted = transposeImpl(values, x.shape, x.dtype, perm, newShape);
    const dataId = backend.write(permuted, newShape, x.dtype);
    return { dataId, shape: newShape, dtype: x.dtype };
}
/** Config object tying the `Transpose` kernel name to the 'cpu' backend. */
const transposeConfig = {
    kernelName: Transpose,
    backendName: 'cpu',
    kernelFunc: transpose
};
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * Multiplies together each consecutive run of values that is reduced away.
 *
 * The output and reduce shapes come from
 * `backend_util.computeOutAndReduceShapes(xShape, reductionAxes)`; each
 * output element is the product of `sizeFromShape(reduceShape)` consecutive
 * entries of `xVals` (an empty run yields 1).
 *
 * @param xShape Shape of the input.
 * @param xDtype Dtype of the input; the output dtype is
 *     `upcastType(xDtype, 'int32')`.
 * @param xVals Flat input values.
 * @param reductionAxes Axes to reduce over.
 * @return `{outVals, outShape, outDtype}`.
 */
function prodImpl(xShape, xDtype, xVals, reductionAxes) {
    const shapes = backend_util.computeOutAndReduceShapes(xShape, reductionAxes);
    const outShape = shapes[0];
    const reduceShape = shapes[1];
    const outDtype = upcastType(xDtype, 'int32');
    const outVals = util.makeZerosTypedArray(util.sizeFromShape(outShape), outDtype);
    const reduceSize = util.sizeFromShape(reduceShape);
    for (let outIdx = 0; outIdx < outVals.length; ++outIdx) {
        const start = outIdx * reduceSize;
        const end = start + reduceSize;
        let product = 1;
        for (let pos = start; pos < end; ++pos) {
            product *= xVals[pos];
        }
        outVals[outIdx] = product;
    }
    return { outVals, outShape, outDtype };
}
|
/**
 * Kernel function for `Prod`.
 *
 * When `backend_util.getAxesPermutation` returns a permutation, the input is
 * transposed first and the reduction runs over the axes returned by
 * `backend_util.getInnerMostAxes`; the transposed intermediate is disposed
 * before returning. With `keepDims` set, the result shape is expanded via
 * `backend_util.expandShapeToKeepDim`.
 *
 * @param args Record with `inputs` (holding `x`), `backend` and `attrs`
 *     (holding `axis` and `keepDims`).
 * @return The value returned by `backend.makeTensorInfo`.
 */
function prod(args) {
    const { inputs, backend, attrs } = args;
    const { x } = inputs;
    const { axis, keepDims } = attrs;
    assertNotComplex(x, 'prod');
    const xRank = x.shape.length;
    const axes = util.parseAxisParam(axis, x.shape);
    const permutation = backend_util.getAxesPermutation(axes, xRank);
    const needsTranspose = permutation != null;
    const permutedX = needsTranspose ?
        transpose({ inputs: { x }, backend, attrs: { perm: permutation } }) :
        x;
    const reductionAxes = needsTranspose ?
        backend_util.getInnerMostAxes(axes.length, xRank) :
        axes;
    const xVals = backend.data.get(permutedX.dataId).values;
    const { outVals, outShape, outDtype } = prodImpl(permutedX.shape, permutedX.dtype, xVals, reductionAxes);
    const resultShape = keepDims ?
        backend_util.expandShapeToKeepDim(outShape, axes) :
        outShape;
    if (needsTranspose) {
        // The transposed copy is the only intermediate created here.
        backend.disposeIntermediateTensorInfo(permutedX);
    }
    return backend.makeTensorInfo(resultShape, outDtype, outVals);
}
/** Config object tying the `Prod` kernel name to the 'cpu' backend. */
const prodConfig = {
    kernelName: Prod,
    backendName: 'cpu',
    kernelFunc: prod
};
|
|
/**
|
* @license
|
* Copyright 2022 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * Checks that every index lies in `[0, numParams)`.
 *
 * @param indices Flat list of index values.
 * @param indicesShape Shape of `indices`; only used to format the location
 *     of an offending index in the error message.
 * @param numParams Exclusive upper bound for a valid index.
 * @throws Error for the first index that is negative or not below
 *     `numParams`.
 */
function validateIndices(indices, indicesShape, numParams) {
    for (let i = 0; i < indices.length; ++i) {
        const index = indices[i];
        if (index < 0 || index >= numParams) {
            const strides = util.computeStrides(indicesShape);
            const loc = util.indexToLoc(i, indicesShape.length, strides);
            const locString = loc.join(',');
            throw new Error(`indices[${locString}] = ${index} is not in [0, ${numParams})`);
        }
    }
}
|
/**
 * Validates nested ragged row splits.
 *
 * For every splits array this checks that it is non-empty, that its first
 * entry is non-negative, that its last entry does not exceed the length of
 * the next splits array (or `numParamsDenseValues` for the innermost one),
 * and that its entries never decrease.
 *
 * @param paramsNestedSplits List of splits arrays, outermost first.
 * @param numParamsDenseValues Upper bound for the last entry of the
 *     innermost splits array.
 * @throws Error describing the first violated rule.
 */
function validateSplits(paramsNestedSplits, numParamsDenseValues) {
    const innermostDim = paramsNestedSplits.length - 1;
    for (let dim = 0; dim <= innermostDim; ++dim) {
        const splits = paramsNestedSplits[dim];
        let lastSplit;
        if (dim === innermostDim) {
            lastSplit = numParamsDenseValues;
        }
        else {
            lastSplit = paramsNestedSplits[dim + 1].length;
        }
        if (splits.length === 0) {
            throw new Error('Ragged splits may not be empty');
        }
        if (splits[0] < 0) {
            throw new Error('Ragged splits must be non-negative');
        }
        if (splits[splits.length - 1] > lastSplit) {
            throw new Error('Ragged splits must not point past values');
        }
        let previous = splits[0];
        for (let i = 1; i < splits.length; ++i) {
            const current = splits[i];
            if (previous > current) {
                throw new Error('Ragged splits must be sorted in ascending order');
            }
            previous = current;
        }
    }
}
|
/**
 * Builds the output `splits` vectors (as nested plain arrays) for a ragged
 * gather and records which slices of the dense values must be copied.
 *
 * @param indices Flat array of gathered row indices.
 * @param indicesShape Shape of the indices tensor.
 * @param paramsNestedSplits Row-splits vectors of the params, outermost first.
 * @param numParamsDenseValues Size of the outer dimension of the dense values;
 *     forwarded to `validateSplits`.
 * @returns `{outSplits, valueSlices, numValues}` where `valueSlices` is a list
 *     of `[start, limit)` pairs and `numValues` is the total number of values
 *     those slices cover.
 */
function makeSplits(indices, indicesShape, paramsNestedSplits, numParamsDenseValues) {
    const valueSlices = [];
    let numValues = 0;
    const denseDims = indicesShape.length - 1;
    const numSplits = denseDims + paramsNestedSplits.length;
    // Every output splits vector starts at 0.
    const outSplits = new Array(numSplits);
    for (let s = 0; s < numSplits; ++s) {
        outSplits[s] = [0];
    }
    validateSplits(paramsNestedSplits, numParamsDenseValues);
    // Splits contributed by all but the last dimension of the dense `indices`
    // tensor are uniform. E.g. for indices.shape = [2, 3, 4]:
    //   [0, 3, 6]                    # length = 2 + 1,     stride = 3
    //   [0, 4, 8, 12, 16, 20, 24]    # length = 2 * 3 + 1, stride = 4
    let nrows = 1;
    for (let dim = 0; dim < denseDims; ++dim) {
        nrows *= indicesShape[dim];
        const rowLength = indicesShape[dim + 1];
        const uniformSplits = outSplits[dim];
        for (let row = 1; row < nrows + 1; ++row) {
            uniformSplits.push(row * rowLength);
        }
    }
    // Splits contributed by `paramsNestedSplits`: for each gathered index walk
    // from the outermost ragged dimension inwards, narrowing [start, limit) to
    // the range of entries at the next level. Row *lengths* are preserved by
    // re-basing each copied split onto the last split already emitted.
    for (let i = 0; i < indices.length; ++i) {
        let start = indices[i];
        let limit = start + 1;
        for (const [dim, splits] of paramsNestedSplits.entries()) {
            const outDim = dim + denseDims;
            if (outDim >= 0) {
                const target = outSplits[outDim];
                const delta = target[target.length - 1] - splits[start];
                for (let j = start; j < limit; ++j) {
                    target.push(splits[j + 1] + delta);
                }
            }
            start = splits[start];
            limit = splits[limit];
        }
        if (limit !== start) {
            valueSlices.push([start, limit]);
            numValues += limit - start;
        }
    }
    return { outSplits, valueSlices, numValues };
}
|
/**
 * Converts each nested splits list into an 'int32' array obtained from
 * `util.getArrayFromDType`.
 *
 * @param outSplits Array of splits lists.
 * @returns A new array holding one 'int32' array per input list, in order.
 */
function getSplits(outSplits) {
    return Array.from(outSplits, (rowSplits) => {
        const typed = util.getArrayFromDType('int32', rowSplits.length);
        for (let j = 0; j < rowSplits.length; ++j) {
            typed[j] = rowSplits[j];
        }
        return typed;
    });
}
|
/**
 * Collapses a shape to `numOutDims` dimensions by keeping the leading
 * dimensions and folding all remaining ones into the last kept dimension.
 * Shapes shorter than `numOutDims` are padded with 1s.
 *
 * @param orig Original shape.
 * @param numOutDims Number of dimensions in the result.
 * @returns A new array; `orig` is not modified.
 */
function computeFlatOuterDims(orig, numOutDims) {
    const head = orig.slice(0, numOutDims);
    const padding = new Array(Math.max(0, numOutDims - head.length)).fill(1);
    const outDims = head.concat(padding);
    const lastOut = numOutDims - 1;
    for (let inDim = numOutDims; inDim < orig.length; ++inDim) {
        outDims[lastOut] *= orig[inDim];
    }
    return outDims;
}
|
/**
 * For each `[start, limit)` pair in `valueSlices`, appends rows
 * `paramsDenseValues[start .. limit - 1]` to `values`, in order.
 *
 * @param paramsDenseValues Flat source array.
 * @param paramsDenseValuesShape Shape of the source; flattened to two
 *     dimensions to obtain the source row stride.
 * @param valueSlices List of `[start, limit)` row ranges to copy.
 * @param valueSize Number of scalars copied per row.
 * @param values Flat destination array, written in place.
 * @param valuesShape Shape of the destination; flattened to two dimensions to
 *     obtain the destination row stride.
 */
function writeValueSlices(paramsDenseValues, paramsDenseValuesShape, valueSlices, valueSize, values, valuesShape) {
    const [, srcRowStride] = computeFlatOuterDims(paramsDenseValuesShape, 2);
    const [, dstRowStride] = computeFlatOuterDims(valuesShape, 2);
    let dstRow = 0;
    for (const [begin, end] of valueSlices) {
        for (let srcRow = begin; srcRow < end; ++srcRow, ++dstRow) {
            const srcBase = srcRow * srcRowStride;
            const dstBase = dstRow * dstRowStride;
            for (let k = 0; k < valueSize; ++k) {
                values[dstBase + k] = paramsDenseValues[srcBase + k];
            }
        }
    }
}
|
/**
 * Allocates the gathered dense-values buffer and fills it from the recorded
 * value slices.
 *
 * @param paramsDenseValues Flat source values.
 * @param paramsDenseValuesShape Shape of the source values.
 * @param paramsDenseValuesDType Dtype used to allocate the output buffer.
 * @param valueSlices List of `[start, limit)` row ranges to copy.
 * @param numValues Total number of rows covered by `valueSlices`; becomes the
 *     outer dimension of the output shape.
 * @returns `[valuesOut, valuesShape]`.
 */
function getValues(paramsDenseValues, paramsDenseValuesShape, paramsDenseValuesDType, valueSlices, numValues) {
    // Same shape as the source, except for the outer dimension.
    const valuesShape = [numValues, ...paramsDenseValuesShape.slice(1)];
    const outputLength = util.sizeFromShape(valuesShape);
    const valuesOut = util.getArrayFromDType(paramsDenseValuesDType, outputLength);
    // Scalars per source row; guarded so an empty source does not divide by
    // its outer dimension.
    const numElements = paramsDenseValues.length;
    let valueSize = 0;
    if (numElements !== 0) {
        valueSize = numElements / paramsDenseValuesShape[0];
    }
    writeValueSlices(paramsDenseValues, paramsDenseValuesShape, valueSlices, valueSize, valuesOut, valuesShape);
    return [valuesOut, valuesShape];
}
|
/**
 * Gathers rows from a ragged tensor described by nested row splits plus dense
 * values.
 *
 * @param paramsNestedSplits Row-splits vectors of the params, outermost first.
 * @param paramsNestedSplitsShapes Shapes of those splits vectors.
 * @param paramsDenseValues Flat dense values of the params.
 * @param paramsDenseValuesShape Shape of the dense values.
 * @param paramsDenseValuesDType Dtype of the dense values.
 * @param indices Flat array of row indices to gather.
 * @param indicesShape Shape of the indices tensor.
 * @param outputRaggedRank Not read by this implementation.
 * @returns `[outputNestedSplits, outputDenseValues, outputDenseValuesShape]`.
 * @throws Error when the splits list is empty, the first splits tensor is a
 *     scalar, an index is out of range, or the dense values are a scalar.
 */
function raggedGatherImpl(paramsNestedSplits, paramsNestedSplitsShapes, paramsDenseValues, paramsDenseValuesShape, paramsDenseValuesDType, indices, indicesShape, outputRaggedRank) {
    if (paramsNestedSplits.length === 0) {
        throw new Error('paramsNestedSplits must be non empty');
    }
    const [outerSplitsShape] = paramsNestedSplitsShapes;
    if (outerSplitsShape.length === 0) {
        throw new Error('Split tensors must not be scalars');
    }
    // A splits vector with N + 1 entries describes N rows.
    const numParams = outerSplitsShape[0] - 1;
    validateIndices(indices, indicesShape, numParams);
    if (paramsDenseValuesShape.length === 0) {
        throw new Error('params.rank must be nonzero');
    }
    const [numParamsDenseValues] = paramsDenseValuesShape;
    // Compute the output splits and remember which value slices to copy.
    const plan = makeSplits(indices, indicesShape, paramsNestedSplits, numParamsDenseValues);
    // Materialize the output buffers.
    const outputNestedSplits = getSplits(plan.outSplits);
    const [denseValues, denseValuesShape] = getValues(paramsDenseValues, paramsDenseValuesShape, paramsDenseValuesDType, plan.valueSlices, plan.numValues);
    return [outputNestedSplits, denseValues, denseValuesShape];
}
|
|
/**
|
* @license
|
* Copyright 2022 Google LLC.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/** Upper bound allowed for the size of a single generated row. */
const INT32_MAX = 2147483647;
/**
 * Builds a ragged tensor whose rows are the ranges
 * `start, start + delta, ...` up to (but excluding) `limit`.
 *
 * Each of `starts`, `limits` and `deltas` may be a scalar (empty shape), in
 * which case its single value is reused for every row, or a vector; all
 * vectors must have the same length.
 *
 * @param starts Flat array of range starts.
 * @param startsShape Shape of `starts` (rank 0 or 1).
 * @param startsDType Dtype used to allocate the dense values.
 * @param limits Flat array of range limits.
 * @param limitsShape Shape of `limits` (rank 0 or 1).
 * @param deltas Flat array of range increments.
 * @param deltasShape Shape of `deltas` (rank 0 or 1).
 * @returns `[rtNestedSplits, rtDenseValues]`.
 * @throws Error on invalid ranks, mismatched vector lengths, a zero delta, or
 *     a row longer than INT32_MAX.
 */
function raggedRangeImpl(starts, startsShape, startsDType, limits, limitsShape, deltas, deltasShape) {
    // Check input tensor shapes.
    if (startsShape.length > 1) {
        throw new Error('starts must be a scalar or vector');
    }
    if (limitsShape.length > 1) {
        throw new Error('limits must be a scalar or vector');
    }
    if (deltasShape.length > 1) {
        throw new Error('deltas must be a scalar or vector');
    }
    // Scalars are broadcast across all rows.
    const broadcastStarts = startsShape.length === 0;
    const broadcastLimits = limitsShape.length === 0;
    const broadcastDeltas = deltasShape.length === 0;
    const valueAt = (data, isBroadcast, row) => data[isBroadcast ? 0 : row];
    // The number of rows is the common length of the vector inputs, or 1 when
    // every input is a scalar.
    const vectorSizes = [startsShape, limitsShape, deltasShape]
        .filter((inputShape) => inputShape.length !== 0)
        .map((inputShape) => inputShape[0]);
    const sizesDisagree = vectorSizes.some((size, i) => i > 0 && size !== vectorSizes[i - 1]);
    if (sizesDisagree) {
        throw new Error('starts, limits, and deltas must have the same shape');
    }
    const nRows = vectorSizes.length === 0 ? 1 : vectorSizes[0];
    // First pass: row sizes, accumulated into the splits vector.
    const rtNestedSplits = util.getArrayFromDType('int32', nRows + 1);
    rtNestedSplits[0] = 0;
    for (let row = 0; row < nRows; ++row) {
        const start = valueAt(starts, broadcastStarts, row);
        const limit = valueAt(limits, broadcastLimits, row);
        const delta = valueAt(deltas, broadcastDeltas, row);
        if (delta === 0) {
            throw new Error('Requires delta != 0');
        }
        // A range that steps away from its limit is empty.
        const stepsAwayFromLimit = ((delta > 0) && (limit < start)) || ((delta < 0) && (limit > start));
        let size = 0;
        if (!stepsAwayFromLimit) {
            size = Math.ceil(Math.abs((limit - start) / delta));
            if (size > INT32_MAX) {
                throw new Error(`Requires ((limit - start) / delta) <= ${INT32_MAX}`);
            }
        }
        rtNestedSplits[row + 1] = rtNestedSplits[row] + size;
    }
    // Second pass: emit the values row by row.
    const nVals = rtNestedSplits[nRows];
    const rtDenseValues = util.getArrayFromDType(startsDType, nVals);
    let writePos = 0;
    for (let row = 0; row < nRows; ++row) {
        const rowSize = rtNestedSplits[row + 1] - rtNestedSplits[row];
        const delta = valueAt(deltas, broadcastDeltas, row);
        let value = valueAt(starts, broadcastStarts, row);
        for (let emitted = 0; emitted < rowSize; ++emitted) {
            rtDenseValues[writePos++] = value;
            value += delta;
        }
    }
    return [rtNestedSplits, rtDenseValues];
}
|
|
/**
|
* @license
|
* Copyright 2022 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
// Local alias for the row-partition type table exposed by backend_util.
var { RowPartitionType } = backend_util;
|
// Based on
|
// https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/kernels/ragged_tensor_to_tensor_op.cc
|
/**
 * Converts a ragged tensor (flat values plus row-partition tensors) into a
 * dense tensor, padding missing cells with a default value.
 */
class RaggedTensorToTensorOp {
    /**
     * @param shape Requested output shape values; negative entries are
     *     resolved in `calculateOutputSize`.
     * @param shapeShape Shape of the `shape` tensor itself.
     * @param values Flat values of the ragged tensor.
     * @param valuesShape Shape of `values`.
     * @param valuesDType Dtype used to allocate the output buffer.
     * @param defaultValue Flat default value used for padding.
     * @param defaultValueShape Shape of `defaultValue`.
     * @param rowPartitionValues One flat array per row partition.
     * @param rowPartitionValuesShapes Shapes of the row-partition tensors.
     * @param rowPartitionTypeStrings Partition type names, converted with
     *     `backend_util.getRowPartitionTypesHelper`.
     */
    constructor(shape, shapeShape, values, valuesShape, valuesDType, defaultValue, defaultValueShape, rowPartitionValues, rowPartitionValuesShapes, rowPartitionTypeStrings) {
        this.shape = shape;
        this.shapeShape = shapeShape;
        this.values = values;
        this.valuesShape = valuesShape;
        this.valuesDType = valuesDType;
        this.defaultValue = defaultValue;
        this.defaultValueShape = defaultValueShape;
        this.rowPartitionValues = rowPartitionValues;
        this.rowPartitionValuesShapes = rowPartitionValuesShapes;
        this.rowPartitionTypes =
            backend_util.getRowPartitionTypesHelper(rowPartitionTypeStrings);
        this.raggedRank = backend_util.getRaggedRank(this.rowPartitionTypes);
    }

    /**
     * Returns the partition type relating `dimension` to `dimension + 1`.
     * A leading FIRST_DIM_SIZE entry is skipped, shifting the lookup by one.
     */
    getRowPartitionTypeByDimension(dimension) {
        if (this.rowPartitionTypes[0] === RowPartitionType.FIRST_DIM_SIZE) {
            return this.rowPartitionTypes[dimension + 1];
        }
        return this.rowPartitionTypes[dimension];
    }

    /**
     * Returns the partition tensor relating `dimension` to `dimension + 1`,
     * using the same one-slot shift as `getRowPartitionTypeByDimension`.
     */
    getRowPartitionTensor(dimension) {
        if (this.rowPartitionTypes[0] === RowPartitionType.FIRST_DIM_SIZE) {
            return this.rowPartitionValues[dimension + 1];
        }
        return this.rowPartitionValues[dimension];
    }

    /**
     * Returns the longest row found in the partition that produces
     * `dimension`.
     *
     * @throws Error when the partition type is neither VALUE_ROWIDS nor
     *     ROW_SPLITS.
     */
    getMaxWidth(dimension) {
        const rowPartitionTensor = this.getRowPartitionTensor(dimension - 1);
        switch (this.getRowPartitionTypeByDimension(dimension - 1)) {
            case RowPartitionType.VALUE_ROWIDS:
                return RaggedTensorToTensorOp.getMaxWidthValueRowID(rowPartitionTensor);
            case RowPartitionType.ROW_SPLITS:
                return RaggedTensorToTensorOp.getMaxWidthRowSplit(rowPartitionTensor);
            default:
                throw new Error(`Cannot handle partition type ${RowPartitionType[this.getRowPartitionTypeByDimension(dimension - 1)]}`);
        }
    }

    /**
     * Returns the largest difference between consecutive entries of
     * `rowSplit`, or 0 when it has fewer than two entries.
     */
    static getMaxWidthRowSplit(rowSplit) {
        const tensorLength = rowSplit.length;
        if (tensorLength === 0 || tensorLength === 1) {
            return 0;
        }
        let maxWidth = 0;
        for (let i = 0; i < tensorLength - 1; ++i) {
            const currentWidth = rowSplit[i + 1] - rowSplit[i];
            if (currentWidth > maxWidth) {
                maxWidth = currentWidth;
            }
        }
        return maxWidth;
    }

    /**
     * Returns the length of the longest run of equal consecutive entries in
     * `valueRowIds`, or 0 when it is empty.
     */
    static getMaxWidthValueRowID(valueRowIds) {
        const indexLength = valueRowIds.length;
        if (indexLength === 0) {
            return 0;
        }
        let firstEqualIndex = 0;
        let firstEqualIndexValue = valueRowIds[0];
        let maxWidth = 0;
        for (let i = 1; i < indexLength; ++i) {
            const value = valueRowIds[i];
            if (value !== firstEqualIndexValue) {
                firstEqualIndexValue = value;
                maxWidth = Math.max(i - firstEqualIndex, maxWidth);
                firstEqualIndex = i;
            }
        }
        // Account for the final run, which is not closed inside the loop.
        return Math.max(indexLength - firstEqualIndex, maxWidth);
    }

    /**
     * Interprets a shape tensor. A scalar is only accepted when it is -1
     * (fully unknown shape, returned as `[]`); otherwise the entries are
     * validated by `makeShape`.
     *
     * @throws Error for any scalar shape tensor other than -1.
     */
    tensorShapeFromTensor(t, tShape, isPartial = true) {
        if (tShape.length === 0) {
            if (t[0] === -1) {
                return [];
            }
            throw new Error(`The only valid scalar shape tensor is the fully unknown shape specified as -1.`);
        }
        // MakePartialShape/MakeShapeHelper.
        return makeShape(t, isPartial);
    }

    /**
     * Resolves the output shape: combines the requested shape with the value
     * shape, then replaces a negative first dimension with `firstDim` and any
     * negative ragged dimension with that dimension's maximum row width.
     */
    calculateOutputSize(firstDim) {
        const valueShape = this.valuesShape;
        const defaultValueShape = this.defaultValueShape;
        backend_util.validateDefaultValueShape(defaultValueShape, valueShape);
        const shape = this.tensorShapeFromTensor(this.shape, this.shapeShape);
        const outputShape = backend_util.combineRaggedTensorToTensorShapes(this.raggedRank, shape, valueShape);
        const result = outputShape;
        if (result[0] < 0) {
            result[0] = firstDim;
        }
        for (let i = 1; i <= this.raggedRank; ++i) {
            if (result[i] < 0) {
                result[i] = this.getMaxWidth(i);
            }
        }
        return result;
    }

    /**
     * The outputIndex represents the index in the output tensor
     * where the first element of a particular dimension would be written.
     * If it is -1, it indicates that the index is out of scope.
     * Example, given firstDimension = 10, firstDimensionOutput = 6,
     * and outputIndexMultiplier = 100:
     * result = [0 100 200 300 400 500 -1 -1 -1 -1]
     * If firstDimensionOutput = 11 instead, then:
     * result = [0 100 200 300 400 500 600 700 800 900]
     */
    calculateFirstParentOutputIndex(firstDimension, outputIndexMultiplier, firstDimensionOutput) {
        const minDimension = Math.min(firstDimension, firstDimensionOutput);
        const result = [];
        let currentOutputIndex = 0;
        for (let i = 0; i < minDimension; ++i, currentOutputIndex += outputIndexMultiplier) {
            result.push(currentOutputIndex);
        }
        for (let i = minDimension; i < firstDimension; ++i) {
            result.push(-1);
        }
        util.assert(result.length === firstDimension, () => 'Final length of result must be equal to firstDimension.');
        return result;
    }

    /**
     * Computes output indices for a ROW_SPLITS partition. For each row, the
     * first `min(outputSize, rowLength)` elements receive consecutive output
     * indices starting at the row's parent index; the remaining elements, and
     * all elements of rows whose parent index is -1, receive -1.
     *
     * @throws Error when the number of produced indices differs from the last
     *     entry of `rowSplit`.
     */
    calculateOutputIndexRowSplit(rowSplit, parentOutputIndex, outputIndexMultiplier, outputSize) {
        const rowSplitSize = rowSplit.length;
        const result = [];
        for (let i = 0; i < rowSplitSize - 1; ++i) {
            const rowLength = rowSplit[i + 1] - rowSplit[i];
            let realLength = Math.min(outputSize, rowLength);
            let parentOutputIndexCurrent = parentOutputIndex[i];
            if (parentOutputIndexCurrent === -1) {
                realLength = 0;
            }
            for (let j = 0; j < realLength; ++j) {
                result.push(parentOutputIndexCurrent);
                parentOutputIndexCurrent += outputIndexMultiplier;
            }
            for (let j = 0; j < rowLength - realLength; ++j) {
                result.push(-1);
            }
        }
        if (rowSplitSize > 0 && result.length !== rowSplit[rowSplitSize - 1]) {
            throw new Error('Invalid row split size.');
        }
        return result;
    }

    /**
     * Computes output indices for a VALUE_ROWIDS partition.
     *
     * The parentOutputIndex is the same computation for the previous list;
     * -1 marks an element or list that is out of range. The
     * outputIndexMultiplier is the number of output indices one moves forward
     * for each column.
     * E.g., given:
     *   valueRowIds:           [0 1 2 2 2 3 4 4 5]
     *   parentOutputIndex:     [1000 1100 2000 2100 -1 3000 4000]
     *   outputIndexMultiplier: 10
     *   outputSize:            2
     * You get:
     *   result = [1000 1100 2000 2010 -1 2100 -1 -1 3000]
     *   result[0] = parentOutputIndex[0]
     *   result[1] = parentOutputIndex[1]
     *   result[2] = parentOutputIndex[2]
     *   result[3] = parentOutputIndex[2] + 10
     *   result[4] = -1 because it is the third element of a row and
     *               outputSize is 2
     *   result[5] = parentOutputIndex[3]
     *   result[6] = -1 because parentOutputIndex[4] is -1
     *   result[7] = -1 because parentOutputIndex[4] is -1
     *   result[8] = parentOutputIndex[5]
     *
     * @throws Error when a row id is not less than `parentOutputIndex.length`.
     */
    calculateOutputIndexValueRowID(valueRowIds, parentOutputIndex, outputIndexMultiplier, outputSize) {
        const indexSize = valueRowIds.length;
        const result = [];
        if (indexSize === 0) {
            return [];
        }
        let currentOutputColumn = 0;
        let currentValueRowId = valueRowIds[0];
        if (currentValueRowId >= parentOutputIndex.length) {
            throw new Error(`Got currentValueRowId=${currentValueRowId}, which is not less than ${parentOutputIndex.length}`);
        }
        let currentOutputIndex = parentOutputIndex[currentValueRowId];
        result.push(currentOutputIndex);
        for (let i = 1; i < indexSize; ++i) {
            const nextValueRowId = valueRowIds[i];
            if (nextValueRowId === currentValueRowId) {
                // Same row: advance one column, or drop the element once the
                // row exceeds the output width.
                if (currentOutputIndex >= 0) {
                    ++currentOutputColumn;
                    if (currentOutputColumn < outputSize) {
                        currentOutputIndex += outputIndexMultiplier;
                    }
                    else {
                        currentOutputIndex = -1;
                    }
                }
            }
            else {
                // New row: restart at its parent's output index.
                currentOutputColumn = 0;
                currentValueRowId = nextValueRowId;
                if (nextValueRowId >= parentOutputIndex.length) {
                    throw new Error(`Got nextValueRowId=${nextValueRowId} which is not less than ${parentOutputIndex.length}`);
                }
                currentOutputIndex = parentOutputIndex[nextValueRowId];
            }
            result.push(currentOutputIndex);
        }
        if (result.length !== valueRowIds.length) {
            throw new Error('Invalid row ids.');
        }
        return result;
    }

    /**
     * Dispatches to the output-index computation matching the partition type
     * of `dimension`.
     *
     * @throws Error for unsupported partition types, or when a ROW_SPLITS
     *     partition describes more rows than `parentOutputIndex` has entries.
     */
    calculateOutputIndex(dimension, parentOutputIndex, outputIndexMultiplier, outputSize) {
        const rowPartitionTensor = this.getRowPartitionTensor(dimension);
        const partitionType = this.getRowPartitionTypeByDimension(dimension);
        switch (partitionType) {
            case RowPartitionType.VALUE_ROWIDS:
                return this.calculateOutputIndexValueRowID(rowPartitionTensor, parentOutputIndex, outputIndexMultiplier, outputSize);
            case RowPartitionType.ROW_SPLITS:
                if (rowPartitionTensor.length - 1 > parentOutputIndex.length) {
                    throw new Error(`Row partition size is greater than output size: ${rowPartitionTensor.length - 1} > ${parentOutputIndex.length}`);
                }
                return this.calculateOutputIndexRowSplit(rowPartitionTensor, parentOutputIndex, outputIndexMultiplier, outputSize);
            default:
                throw new Error(`Unsupported partition type: ${RowPartitionType[partitionType]}`);
        }
    }

    /**
     * Returns the size of the outermost dimension as described by the first
     * row partition.
     *
     * @throws Error when no partition types were given or the first partition
     *     type is VALUE_ROWIDS or otherwise unsupported.
     */
    getFirstDimensionSize() {
        const firstPartitionTensor = this.rowPartitionValues[0];
        if (this.rowPartitionTypes.length === 0) {
            throw new Error('No row_partition_types given.');
        }
        const firstPartitionType = this.rowPartitionTypes[0];
        switch (firstPartitionType) {
            case RowPartitionType.FIRST_DIM_SIZE:
                return firstPartitionTensor[0];
            case RowPartitionType.VALUE_ROWIDS:
                throw new Error('Cannot handle VALUE_ROWIDS in first dimension.');
            case RowPartitionType.ROW_SPLITS:
                // N + 1 split points describe N rows.
                return this.rowPartitionValuesShapes[0][0] - 1;
            default:
                throw new Error(`Cannot handle type ${RowPartitionType[firstPartitionType]}`);
        }
    }

    /**
     * Runs the conversion.
     *
     * @returns `[outputShape, outputTensor]` where `outputTensor` is the flat
     *     dense buffer.
     * @throws Error when the first partition tensor is empty, plus any error
     *     raised by the helper methods.
     */
    compute() {
        const firstPartitionTensor = this.rowPartitionValues[0];
        if (firstPartitionTensor.length <= 0) {
            throw new Error('Invalid first partition input. ' +
                'Tensor requires at least one element.');
        }
        const firstDimension = this.getFirstDimensionSize();
        const outputSize = this.calculateOutputSize(firstDimension);
        // multiplier[i] is the number of output cells spanned by one step
        // along dimension i (a running product of the inner ragged dims).
        const multiplier = new Array(this.raggedRank + 1);
        multiplier[multiplier.length - 1] = 1;
        for (let i = multiplier.length - 2; i >= 0; --i) {
            multiplier[i] = multiplier[i + 1] * outputSize[i + 1];
        }
        // Full size of the tensor.
        const outputShape = makeShape(outputSize, false);
        const outputTensor = util.getArrayFromDType(this.valuesDType, util.sizeFromShape(outputShape));
        const fullSize = multiplier[0] * outputSize[0];
        if (fullSize > 0) {
            let outputIndex = this.calculateFirstParentOutputIndex(firstDimension, multiplier[0], outputSize[0]);
            for (let i = 1; i <= this.raggedRank; ++i) {
                const newOutputIndex = this.calculateOutputIndex(i - 1, outputIndex, multiplier[i], outputSize[i]);
                outputIndex = newOutputIndex;
            }
            this.setOutput(this.raggedRank, outputIndex, outputTensor, outputShape);
        }
        return [outputShape, outputTensor];
    }

    /**
     * Writes `this.values` into `outputTensor` at the positions given by
     * `outputIndex` and fills every other cell with the default value.
     *
     * @param raggedRank Number of ragged dimensions; the dimensions after
     *     `raggedRank + 1` in `outputShape` form one value element.
     * @param outputIndex For each value element, its destination element
     *     index in the output, or -1 when the element is dropped.
     * @param outputTensor Flat output buffer, written in place.
     * @param outputShape Fully resolved output shape.
     */
    setOutput(raggedRank, outputIndex, outputTensor, outputShape) {
        if (outputTensor.length === 0) {
            return;
        }
        const valuesBase = this.values;
        const outputBase = outputTensor;
        const elementShape = outputShape.slice(raggedRank + 1);
        const valueElementSize = util.sizeFromShape(elementShape);
        const outputIndexSize = outputIndex.length;
        // Broadcast the default value to value_element_size. (We can skip this
        // if defaultValueTensor.size == 1, since we use fill when that's true.)
        let defaultValue = this.defaultValue;
        if (defaultValue.length !== valueElementSize && defaultValue.length !== 1) {
            const srcShape = this.defaultValueShape;
            tidy(() => {
                const defaultValueTensor = reshape$1(defaultValue, srcShape);
                const bCastDefault = broadcastTo(defaultValueTensor, elementShape);
                defaultValue = bCastDefault.dataSync();
            });
        }
        // Loop through the outputIndex array, finding contiguous regions that
        // should be copied. Once we find the end of a contiguous region, copy it
        // and add any necessary padding (with defaultValue).
        let srcStart = 0; // Start of contiguous region (in values)
        let dstStart = 0; // Start of contiguous region (in output)
        let dstEnd = 0; // End (exclusive) of contiguous region (in output)
        for (let srcI = 0; srcI <= outputIndexSize; ++srcI) {
            // dstI is the destination where the value at srcI should be copied.
            let dstI = srcI < outputIndexSize ? outputIndex[srcI] : -1;
            // If we're still in a contiguous region, then update dstEnd and go
            // to the next srcI.
            if (dstI === dstEnd) {
                ++dstEnd;
                continue;
            }
            // We found the end of a contiguous region. This can be because we
            // found a gap (dstI > dstEnd), a source value that shouldn't be
            // copied because it's out-of-bounds (dstI === -1), or the end of
            // the values (srcI === outputIndexSize, which also yields -1).
            if (dstStart < dstEnd) {
                // Copy the contiguous region.
                const src = valuesBase.subarray(srcStart * valueElementSize);
                const dst = outputBase.subarray(dstStart * valueElementSize);
                const nVals = (dstEnd - dstStart) * valueElementSize;
                copyArray(dst, src, nVals);
            }
            // Add any necessary padding (w/ defaultValue).
            if (srcI >= outputIndexSize) {
                // We reached the end of values: pad to the end of output.
                const outputSize = outputTensor.length;
                dstI = Math.floor(outputSize / valueElementSize);
            }
            if (dstI > dstEnd) {
                if (this.defaultValue.length === 1) {
                    outputBase
                        .subarray(dstEnd * valueElementSize, dstI * valueElementSize)
                        .fill(this.defaultValue[0]);
                    dstEnd = dstI;
                }
                else {
                    while (dstI > dstEnd) {
                        // subarray returns a view onto the output buffer, so
                        // the padding actually lands in the output; slice
                        // would return a detached copy and the write would be
                        // lost.
                        const dst = outputBase.subarray(dstEnd * valueElementSize);
                        copyArray(dst, defaultValue, valueElementSize);
                        ++dstEnd;
                    }
                }
            }
            // Update indices.
            if (dstI < 0) {
                // srcI should be skipped -- leave it out of the contiguous region.
                srcStart = srcI + 1;
                dstStart = dstEnd;
            }
            else {
                // srcI should be copied -- include it in the contiguous region.
                srcStart = srcI;
                dstStart = dstEnd;
                dstEnd = dstStart + 1;
            }
        }
    }
}
|
/**
 * Copies the first `size` elements of `src` into the same positions of `dst`.
 *
 * @param dst Destination array, written in place.
 * @param src Source array.
 * @param size Number of leading elements to copy.
 */
function copyArray(dst, src, size) {
    let pos = 0;
    while (pos < size) {
        dst[pos] = src[pos];
        pos++;
    }
}
|
/**
 * Copies `shape` into a plain array while validating its entries.
 *
 * Negative entries are rejected unless `isPartial` is true; in a partial
 * shape any negative entry must not be below -1 and is stored as -1.
 *
 * @param shape Iterable of dimension sizes.
 * @param isPartial Whether unknown (-1) dimensions are allowed.
 * @returns A new array with the validated dimensions.
 * @throws Error for a disallowed negative dimension.
 */
function makeShape(shape, isPartial) {
    const out = [];
    for (const dim of shape) {
        if (!(dim < 0)) {
            out.push(dim);
            continue;
        }
        if (!isPartial) {
            throw new Error(`Dimension ${dim} must be >= 0`);
        }
        if (dim < -1) {
            throw new Error(`Dimension ${dim} must be >= -1`);
        }
        out.push(-1);
    }
    return out;
}
|
/**
 * Functional entry point for the ragged-to-dense conversion: constructs a
 * `RaggedTensorToTensorOp` from the arguments and returns the result of its
 * `compute()` method.
 */
function raggedTensorToTensorImpl(shape, shapesShape, values, valuesShape, valuesDType, defaultValue, defaultValueShape, rowPartitionValues, rowPartitionValuesShapes, rowPartitionTypes) {
    const op = new RaggedTensorToTensorOp(shape, shapesShape, values, valuesShape, valuesDType, defaultValue, defaultValueShape, rowPartitionValues, rowPartitionValuesShapes, rowPartitionTypes);
    return op.compute();
}
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * Produces the values `start, start + step, ...` up to (but excluding)
 * `stop` in an array obtained from `util.makeZerosTypedArray`.
 *
 * An empty array is returned when `start === stop`, when the range increases
 * but `step` is negative, or when the range decreases but `step` is greater
 * than 1. For a decreasing range with `step === 1` the step is treated as
 * unset and replaced by -1.
 *
 * @param start First value.
 * @param stop Exclusive end of the range.
 * @param step Increment between consecutive values.
 * @param dtype Dtype passed to `util.makeZerosTypedArray`.
 * @returns The generated values.
 */
function rangeImpl(start, stop, step, dtype) {
    const isEmpty = start === stop ||
        (start < stop && step < 0) ||
        (stop < start && step > 1);
    if (isEmpty) {
        return util.makeZerosTypedArray(0, dtype);
    }
    const numElements = Math.abs(Math.ceil((stop - start) / step));
    const values = util.makeZerosTypedArray(numElements, dtype);
    // Auto adjust the step's sign if it hasn't been set (or was set to 1).
    const effectiveStep = (stop < start && step === 1) ? -1 : step;
    values[0] = start;
    // Each value is derived from the previously stored one, so any rounding
    // applied by the typed array carries forward.
    for (let i = 1; i < values.length; i++) {
        values[i] = values[i - 1] + effectiveStep;
    }
    return values;
}
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the License);
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an AS IS BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/** Reciprocal square root implementation, built by wrapping the scalar op with `createSimpleUnaryImpl`. */
const rsqrtImpl = createSimpleUnaryImpl((value) => 1 / Math.sqrt(value));
/** Kernel function for `Rsqrt`, derived from `rsqrtImpl`. */
const rsqrt = unaryKernelFuncFromImpl(Rsqrt, rsqrtImpl);
/** Kernel registration entry that binds the `Rsqrt` kernel name to `rsqrt` for the 'cpu' backend. */
const rsqrtConfig = {
    kernelName: Rsqrt,
    backendName: 'cpu',
    kernelFunc: rsqrt,
};
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * Scatters slices of `updates` into an output buffer at the positions given
 * by `indices`.
 *
 * @param indices Object whose `values` hold `numUpdates * sliceRank` index
 *     components.
 * @param updates Object whose `values` hold the update data; its `dtype` is
 *     used for allocation and a `rank` of 0 means the single value is reused
 *     for every write.
 * @param shape Output shape; used for the empty-output result and in the
 *     error message.
 * @param outputSize Total number of output elements.
 * @param sliceSize Number of elements written per update.
 * @param numUpdates Number of updates to apply.
 * @param sliceRank Number of index components per update.
 * @param strides Multiplier applied to each index component when flattening.
 * @param defaultValue Either a `TensorBuffer` that is used (and mutated) as
 *     the output, or a string/number/boolean used to pre-fill a new buffer.
 * @param sumDupeIndices When truthy, updates are added to the existing
 *     content; otherwise they overwrite it.
 * @returns The output buffer. A freshly allocated one has shape
 *     `[outputSize / sliceSize, sliceSize]`.
 * @throws Error when a flattened index falls outside the output.
 */
function scatterImpl(indices, updates, shape, outputSize, sliceSize, numUpdates, sliceRank, strides, defaultValue, sumDupeIndices) {
    const numSlices = outputSize / sliceSize;
    const flattenShape = [numSlices, sliceSize];
    const indicesData = indices.values;
    const updatesData = updates.values;
    if (outputSize === 0) {
        return buffer(shape, updates.dtype);
    }
    const outBuf = (defaultValue instanceof TensorBuffer) ?
        defaultValue :
        buffer(flattenShape, updates.dtype);
    switch (typeof defaultValue) {
        case 'string':
        case 'number':
            outBuf.values.fill(defaultValue);
            break;
        case 'boolean':
            // Booleans are stored numerically.
            outBuf.values.fill(+defaultValue);
            break;
        default:
            break;
    }
    const outValues = outBuf.values;
    const isScalarUpdate = updates.rank === 0;
    for (let i = 0; i < numUpdates; i++) {
        // Flatten this update's multi-dimensional index into a slice number.
        const index = [];
        let flattenIndex = 0;
        for (let j = 0; j < sliceRank; j++) {
            const dim = indicesData[i * sliceRank + j];
            index.push(dim);
            flattenIndex += dim * strides[j];
        }
        if (flattenIndex < 0 || flattenIndex >= numSlices) {
            throw new Error(`Invalid indices: ${index} does not index into ${shape}`);
        }
        const dstBase = flattenIndex * sliceSize;
        const srcBase = i * sliceSize;
        for (let k = 0; k < sliceSize; k++) {
            if (sumDupeIndices) {
                outValues[dstBase + k] += updatesData[srcBase + k];
            }
            else {
                outValues[dstBase + k] = isScalarUpdate ? updatesData[0] : updatesData[srcBase + k];
            }
        }
    }
    return outBuf;
}
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the License);
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an AS IS BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/** Logistic sigmoid implementation, built by wrapping the scalar op with `createSimpleUnaryImpl`. */
const sigmoidImpl = createSimpleUnaryImpl((value) => 1 / (1 + Math.exp(-value)));
/** Kernel function for `Sigmoid`, built from the same scalar op with `unaryKernelFunc`. */
const sigmoid = unaryKernelFunc(Sigmoid, (value) => 1 / (1 + Math.exp(-value)));
/** Kernel registration entry that binds the `Sigmoid` kernel name to `sigmoid` for the 'cpu' backend. */
const sigmoidConfig = {
    kernelName: Sigmoid,
    backendName: 'cpu',
    kernelFunc: sigmoid,
};
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * Extracts the region of `size` elements starting at `begin` from a tensor's
 * flat values.
 *
 * When `slice_util.isSliceContinous` reports a contiguous region, the result
 * is taken directly from `vals`: via `slice` for the 'string' dtype and via
 * `subarray` otherwise. In all other cases the elements are copied one by one
 * through buffers, with string data decoded before and re-encoded after the
 * copy.
 *
 * @param vals Flat input values.
 * @param begin Start coordinate per dimension.
 * @param size Extent per dimension.
 * @param shape Shape of the input.
 * @param dtype Dtype of the input.
 * @returns The sliced values.
 */
function sliceImpl(vals, begin, size, shape, dtype) {
    const isString = dtype === 'string';
    if (slice_util.isSliceContinous(shape, begin, size)) {
        const xStrides = util.computeStrides(shape);
        const flatOffset = slice_util.computeFlatOffset(begin, xStrides);
        const flatEnd = flatOffset + util.sizeFromShape(size);
        return isString ?
            vals.slice(flatOffset, flatEnd) :
            vals.subarray(flatOffset, flatEnd);
    }
    const decodedData = isString ?
        backend_util.fromUint8ToStringArray(vals) :
        vals;
    const inBuf = buffer(shape, dtype, decodedData);
    const outBuf = buffer(size, dtype);
    for (let flat = 0; flat < outBuf.size; ++flat) {
        const outLoc = outBuf.indexToLoc(flat);
        // Shift the output coordinate by `begin` to find the source element.
        const inLoc = outLoc.map((coord, axis) => coord + begin[axis]);
        outBuf.set(inBuf.get(...inLoc), ...outLoc);
    }
    return isString ?
        backend_util.fromStringArrayToUint8(outBuf.values) :
        outBuf.values;
}
|
/**
 * Kernel function for `Slice`: validates the slice parameters, reads the
 * input's values from the backend and returns a new tensor info holding the
 * sliced data.
 *
 * @param args Kernel arguments: `inputs.x` is the input tensor info,
 *     `attrs.begin` / `attrs.size` describe the slice, and `backend` provides
 *     data access and tensor creation.
 * @returns Tensor info created by `backend.makeTensorInfo`, with the parsed
 *     size as its shape and the dtype of `x`.
 */
function slice(args) {
    const { inputs: { x }, backend, attrs: { begin, size } } = args;
    assertNotComplex(x, 'slice');
    const [$begin, $size] = slice_util.parseSliceParams(x, begin, size);
    slice_util.assertParamsValid(x, $begin, $size);
    const { values } = backend.data.get(x.dataId);
    const outVals = sliceImpl(values, $begin, $size, x.shape, x.dtype);
    return backend.makeTensorInfo($size, x.dtype, outVals);
}
|
const sliceConfig = {
|
kernelName: Slice,
|
backendName: 'cpu',
|
kernelFunc: slice
|
};
|
|
/**
|
* @license
|
* Copyright 2021 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * Ensures every row of a sparse tensor has at least one entry by inserting
 * `defaultValue` at column 0 (and 0 for all remaining coordinates) of each
 * empty row.
 *
 * @param indices Flattened [N, rank] coordinates of the sparse entries.
 * @param indicesShape Shape of `indices`, i.e. [N, rank].
 * @param indicesDType Dtype used to allocate the output indices.
 * @param values The N values of the sparse entries.
 * @param valuesDType Dtype used to allocate the output values.
 * @param denseShape Dense shape of the sparse tensor; only `denseShape[0]`
 *     (the number of rows) is read.
 * @param defaultValue Value stored for each inserted entry.
 * @returns `[outputIndices, outputIndicesShape, outputValues,
 *     emptyRowIndicator, reverseIndexMap]`. `emptyRowIndicator[row]` is true
 *     for rows that had no entry, and `reverseIndexMap[i]` is the position of
 *     input entry `i` in the output. When no row is empty and rows are
 *     already in non-decreasing order, the input `indices` and `values` are
 *     returned as-is (not copied).
 * @throws Error if there are entries but zero dense rows, or if a row index
 *     is negative or not smaller than the number of dense rows.
 */
function sparseFillEmptyRowsImpl(indices, indicesShape, indicesDType, values, valuesDType, denseShape, defaultValue) {
  const indicesCount = indicesShape[0];
  const rank = indicesShape[1];
  const denseRows = denseShape[0];
  const emptyRowIndicator = new Array(denseRows);
  const reverseIndexMap = new Array(indicesCount);

  if (denseRows === 0) {
    if (indicesCount !== 0) {
      throw new Error(backend_util.getSparseFillEmptyRowsIndicesDenseShapeMismatch(indicesCount));
    }
    const outputIndices = util.getArrayFromDType(indicesDType, 0);
    const outputValues = util.getArrayFromDType(valuesDType, 0);
    return [
      outputIndices, [0, rank], outputValues, emptyRowIndicator, reverseIndexMap
    ];
  }

  // Pass 1: validate row indices, count entries per row and detect whether
  // the rows already appear in non-decreasing order.
  const csrOffset = new Array(denseRows).fill(0);
  let rowsAreOrdered = true;
  let lastIndicesRow = 0;
  for (let i = 0; i < indicesCount; ++i) {
    // indices is a 2d tensor with shape of [N, rank]
    const row = indices[i * rank];
    if (row < 0) {
      throw new Error(backend_util.getSparseFillEmptyRowsNegativeIndexErrorMessage(i, row));
    }
    if (row >= denseRows) {
      throw new Error(backend_util.getSparseFillEmptyRowsOutOfRangeIndexErrorMessage(i, row, denseRows));
    }
    ++csrOffset[row];
    if (row < lastIndicesRow) {
      rowsAreOrdered = false;
    }
    lastIndicesRow = row;
  }

  // Pass 2: mark empty rows and turn the per-row counts into an inclusive
  // prefix sum, so that csrOffset[row] is the output position where row + 1
  // starts. Each empty row is counted as holding exactly one element.
  let allRowsFull = true;
  for (let row = 0; row < denseRows; ++row) {
    const rowEmpty = csrOffset[row] === 0;
    emptyRowIndicator[row] = rowEmpty;
    if (rowEmpty) {
      allRowsFull = false;
      csrOffset[row] = 1;
    }
    if (row > 0) {
      csrOffset[row] += csrOffset[row - 1];
    }
  }

  if (allRowsFull && rowsAreOrdered) {
    // Nothing to insert and nothing to reorder: the output is the input.
    for (let i = 0; i < indicesCount; ++i) {
      reverseIndexMap[i] = i;
    }
    return [
      indices, [indicesCount, rank], values, emptyRowIndicator,
      reverseIndexMap
    ];
  }

  // Output position of the first element of `row`.
  const rowStart = (row) => (row === 0 ? 0 : csrOffset[row - 1]);

  const fullIndicesCount = csrOffset[denseRows - 1];
  const outputIndices = util.getArrayFromDType(indicesDType, fullIndicesCount * rank);
  const outputValues = util.getArrayFromDType(valuesDType, fullIndicesCount);
  const filledCount = new Array(denseRows).fill(0);

  // Copy the existing entries into their row's slot range.
  for (let i = 0; i < indicesCount; ++i) {
    const row = indices[i * rank];
    const outputI = rowStart(row) + filledCount[row];
    filledCount[row]++;
    for (let j = 0; j < rank; ++j) {
      // indices and outputIndices are 2d tensors with shape of [N, rank]
      outputIndices[outputI * rank + j] = indices[i * rank + j];
    }
    outputValues[outputI] = values[i];
    // We'll need this reverse index map to backprop correctly.
    reverseIndexMap[i] = outputI;
  }

  // Insert one default entry for every row that received nothing.
  for (let row = 0; row < denseRows; ++row) {
    if (filledCount[row] !== 0) {
      continue;
    }
    const startingIndex = rowStart(row);
    outputIndices[startingIndex * rank] = row;
    for (let col = 1; col < rank; ++col) {
      outputIndices[startingIndex * rank + col] = 0;
    }
    outputValues[startingIndex] = defaultValue;
  }

  return [
    outputIndices, [fullIndicesCount, rank], outputValues, emptyRowIndicator,
    reverseIndexMap
  ];
}
|
|
/**
|
* @license
|
* Copyright 2021 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * Re-expresses the coordinates of a sparse tensor for a new dense shape with
 * the same number of elements.
 *
 * @param inputIndices Flattened [nnz, inputRank] coordinates.
 * @param inputIndicesShape Shape of `inputIndices`; only entry 0 (nnz) is
 *     read.
 * @param inputDType Dtype used to allocate the output indices.
 * @param inputShape Dense shape the input coordinates refer to.
 * @param targetShape Requested dense shape; at most one entry may be -1, in
 *     which case that dimension is inferred.
 * @returns `[newIndices, [nnz, outputRank], outputShape]`.
 * @throws Error if more than one target dimension is -1, a target dimension
 *     is otherwise negative, a dimension has to be inferred while the product
 *     of the known dimensions is not positive, or the element counts of input
 *     and output shapes do not match.
 */
function sparseReshapeImpl(inputIndices, inputIndicesShape, inputDType, inputShape, targetShape) {
  const denseSize = util.sizeFromShape(inputShape);
  const nnz = inputIndicesShape[0];
  const outputRank = targetShape.length;

  // Compute the output shape. Determine product of specified dimensions, and
  // find the index of the unspecified one.
  const outputShape = [];
  let product = 1;
  let unknownIndex = -1;
  for (let d = 0; d < outputRank; ++d) {
    const size = targetShape[d];
    if (size === -1) {
      if (unknownIndex !== -1) {
        throw new Error(backend_util
            .getSparseReshapeMultipleNegativeOneOutputDimErrorMessage(unknownIndex, d));
      }
      unknownIndex = d;
      // Placeholder; replaced by the inferred size below.
      outputShape.push(1);
      continue;
    }
    if (size < 0) {
      throw new Error(backend_util.getSparseReshapeNegativeOutputDimErrorMessage(d, size));
    }
    product *= size;
    outputShape.push(size);
  }

  if (unknownIndex !== -1) {
    if (product <= 0) {
      throw new Error(backend_util.getSparseReshapeEmptyTensorZeroOutputDimErrorMessage());
    }
    const missing = Math.trunc(denseSize / product);
    if (product * missing !== denseSize) {
      throw new Error(backend_util.getSparseReshapeInputOutputMultipleErrorMessage(inputShape, outputShape));
    }
    outputShape[unknownIndex] = missing;
  }

  const outputSize = util.sizeFromShape(outputShape);
  if (outputSize !== denseSize) {
    throw new Error(backend_util.getSparseReshapeInputOutputMismatchErrorMessage(inputShape, outputShape));
  }

  // Row-major strides, including the trailing stride of 1.
  const rowMajorStrides = (shape) => {
    const strides = [];
    const rank = shape.length;
    if (rank > 0) {
      strides[rank - 1] = 1;
      for (let d = rank - 2; d >= 0; --d) {
        strides[d] = strides[d + 1] * shape[d + 1];
      }
    }
    return strides;
  };

  const inputRank = inputShape.length;
  const inputStrides = rowMajorStrides(inputShape);
  const outputStrides = rowMajorStrides(outputShape);

  const newIndices = util.getArrayFromDType(inputDType, nnz * outputRank);
  for (let i = 0; i < nnz; ++i) {
    // Flatten the input coordinate into a linear id ...
    let id = 0;
    for (let j = 0; j < inputRank; ++j) {
      // inputIndices is a 2d tensor with shape of [nnz, inputRank]
      id += inputIndices[i * inputRank + j] * inputStrides[j];
    }
    // ... and unflatten it against the output strides.
    for (let j = 0; j < outputRank; ++j) {
      // newIndices is a 2d tensor with shape of [nnz, outputRank]
      newIndices[i * outputRank + j] = Math.trunc(id / outputStrides[j]);
      id %= outputStrides[j];
    }
  }
  return [newIndices, [nnz, outputRank], outputShape];
}
|
|
/**
|
* @license
|
* Copyright 2021 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * Sums (or averages) selected rows of `input` into output rows chosen by
 * `segmentIds`.
 *
 * `input` is treated as a 2d array of `inputShape[0]` rows. Entry `i` adds
 * input row `indices[i]` to output row `segmentIds[i]`. Output rows that
 * receive no input row are set to `defaultValue`.
 *
 * @param input Flat values of the input tensor.
 * @param inputShape Shape of the input tensor.
 * @param inputDType Dtype used to allocate the output.
 * @param indices Row of `input` selected by each entry.
 * @param segmentIds Output row of each entry; the implementation assumes
 *     these are sorted and rejects a decrease between consecutive segments.
 * @param isMean When true each output row is divided by the number of rows
 *     that were added into it.
 * @param defaultValue Value used for output rows without any entry.
 * @returns `[output, outputShape]`, where `outputShape` is `inputShape` with
 *     its first dimension replaced by the last segment id plus one (or 0 when
 *     there are no entries).
 * @throws Error on negative, non-increasing or out-of-range segment ids, and
 *     on out-of-range `indices`.
 */
function sparseSegmentReductionImpl(input, inputShape, inputDType, indices, segmentIds, isMean = false, defaultValue = 0) {
  const numIndices = indices.length;
  // Flatten the input to two dimensions: [numRows, numCol].
  const numRows = inputShape[0];
  const numCol = input.length / numRows;

  // Note that the current implementation assumes that segmentIds values are
  // sorted, so the last id determines the number of output rows.
  const outputRows = numIndices > 0 ? segmentIds[numIndices - 1] + 1 : 0;
  if (outputRows < 0) {
    throw new Error(backend_util.getSparseSegmentReductionNegativeSegmentIdsErrorMessage());
  }

  const outputShape = inputShape.slice();
  outputShape[0] = outputRows;
  const outputLength = outputShape.reduce((product, value) => product * value, 1);
  const output = util.getArrayFromDType(inputDType, outputLength);

  if (numIndices === 0) {
    // outputRows is 0 here, so there are no rows to fill with defaultValue.
    return [output, outputShape];
  }
  if (outputRows <= 0) {
    throw new Error(backend_util.getSparseSegmentReductionNegativeSegmentIdsErrorMessage());
  }

  // First output row that has not yet been written or default-filled.
  let uninitializedIndex = 0;
  let start = 0;
  while (start < numIndices) {
    const outIndex = segmentIds[start];

    // Extend [start, end) over the run of entries sharing this segment id.
    let end = start + 1;
    while (end < numIndices && segmentIds[end] === outIndex) {
      ++end;
    }
    // A new segment follows: verify that the segment ids are growing.
    if (end < numIndices && outIndex >= segmentIds[end]) {
      throw new Error(backend_util
          .getSparseSegmentReductionNonIncreasingSegmentIdsErrorMessage());
    }
    if (outIndex < 0 || outIndex >= outputRows) {
      throw new Error(backend_util.getSparseSegmentReductionSegmentIdOutOfRangeErrorMessage(outIndex, outputRows));
    }

    // If there is a gap between two segments, set the skipped rows to the
    // default value.
    if (outIndex > uninitializedIndex) {
      output.fill(defaultValue, uninitializedIndex * numCol, outIndex * numCol);
    }

    const outOffset = outIndex * numCol;
    for (let i = start; i < end; ++i) {
      const index = indices[i];
      if (index < 0 || index >= numRows) {
        throw new Error(backend_util.getSparseSegmentReductionIndicesOutOfRangeErrorMessage(i, index, numRows));
      }
      const inOffset = index * numCol;
      for (let j = 0; j < numCol; j++) {
        output[outOffset + j] += input[inOffset + j];
      }
    }
    if (isMean) {
      const count = end - start;
      for (let j = 0; j < numCol; j++) {
        output[outOffset + j] /= count;
      }
    }

    uninitializedIndex = outIndex + 1;
    start = end;
  }

  // Fill the gap at the end with the default value.
  if (uninitializedIndex < outputRows) {
    output.fill(defaultValue, uninitializedIndex * numCol, outputRows * numCol);
  }
  return [output, outputShape];
}
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the License);
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an AS IS BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
// Element-wise square root over raw values, built with createSimpleUnaryImpl.
const sqrtImpl = createSimpleUnaryImpl((xi) => Math.sqrt(xi));

// Kernel function for the Sqrt op, built with unaryKernelFunc from the same
// per-element operation.
const sqrt = unaryKernelFunc(Sqrt, (xi) => Math.sqrt(xi));

// Kernel registration record for Sqrt on the 'cpu' backend.
const sqrtConfig = {
  kernelName: Sqrt,
  backendName: 'cpu',
  kernelFunc: sqrt,
};
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
// Element-wise (a - b)^2 over raw values, built with
// createSimpleBinaryKernelImpl.
const squaredDifferenceImpl = createSimpleBinaryKernelImpl((a, b) => {
  const diff = a - b;
  return diff * diff;
});

// Kernel function for the SquaredDifference op.
const squaredDifference = binaryKernelFunc(SquaredDifference, squaredDifferenceImpl);

// Kernel registration record for SquaredDifference on the 'cpu' backend.
const squaredDifferenceConfig = {
  kernelName: SquaredDifference,
  backendName: 'cpu',
  kernelFunc: squaredDifference
};
|
|
/**
|
* @license
|
* Copyright 2023 Google LLC.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
// Element-wise regex replacement: each input string `x` is transformed with
// `x.replace(new RegExp(pattern, flags), rewrite)`, where the flags are 'g'
// when `attrs.replaceGlobal` is truthy and '' otherwise.
//
// The per-element callback receives the same `attrs` for every element, so
// the most recently compiled regex is memoized instead of being rebuilt for
// each string. Reusing one RegExp instance is safe here because
// String.prototype.replace resets `lastIndex` for global regexes and ignores
// it for non-global, non-sticky ones.
const staticRegexReplaceImpl = (() => {
  let cachedPattern;
  let cachedFlags;
  let cachedRegex;
  return createSimpleUnaryImpl((x, attrs) => {
    const { pattern, replaceGlobal, rewrite } = attrs;
    const flags = replaceGlobal ? 'g' : '';
    if (cachedRegex === undefined || cachedPattern !== pattern ||
        cachedFlags !== flags) {
      // Assign the regex first: if compilation throws, the cache is left
      // untouched and the next call retries.
      cachedRegex = new RegExp(pattern, flags);
      cachedPattern = pattern;
      cachedFlags = flags;
    }
    return x.replace(cachedRegex, rewrite);
  });
})();

// Kernel function for the StaticRegexReplace op.
const staticRegexReplace = unaryKernelFuncFromImpl(StaticRegexReplace, staticRegexReplaceImpl);

// Kernel registration record for StaticRegexReplace on the 'cpu' backend.
const staticRegexReplaceConfig = {
  kernelName: StaticRegexReplace,
  backendName: 'cpu',
  kernelFunc: staticRegexReplace,
};
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * Gathers a strided slice of `xBuf` into a new buffer of shape `outShape`.
 *
 * Output coordinate `loc` is read from input coordinate
 * `loc[j] * strides[j] + begin[j]` along every dimension `j`.
 *
 * @param outShape Shape of the result.
 * @param xBuf Input buffer; its `dtype` is used for the result.
 * @param strides Step between consecutive samples, per dimension.
 * @param begin Coordinate of the first sample, per dimension.
 * @returns The newly created output buffer.
 */
function stridedSliceImpl(outShape, xBuf, strides, begin) {
  const outBuf = buffer(outShape, xBuf.dtype);
  for (let i = 0; i < outBuf.size; i++) {
    const loc = outBuf.indexToLoc(i);
    const srcLoc = loc.map((coord, j) => coord * strides[j] + begin[j]);
    outBuf.set(xBuf.get(...srcLoc), ...loc);
  }
  return outBuf;
}
|
|
/**
|
* @license
|
* Copyright 2021 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * The StringNGramsOp class creates ngrams from ragged string data.
 * The constructor contains all attributes related to the operation such as
 * padding widths and strings, and the compute function can be used to
 * compute the ngrams for different ragged tensor inputs.
 */
class StringNGramsOp {
  /**
   * @param separator String placed between the parts of an ngram.
   * @param nGramWidths Widths of the ngrams to generate.
   * @param leftPad String used as padding on the left of a sequence.
   * @param rightPad String used as padding on the right of a sequence.
   * @param padWidth Number of padding elements on each side; a negative
   *     value selects the maximum useful padding (`nGramWidth - 1`).
   * @param preserveShortSequences When true, a non-empty sequence that is
   *     too short to produce any ngram still yields a single ngram.
   */
  constructor(separator, nGramWidths, leftPad, rightPad, padWidth, preserveShortSequences) {
    // Strings are stored in encoded form; ngrams are assembled as bytes.
    this.separator = util.encodeString(separator);
    this.nGramWidths = nGramWidths;
    this.leftPad = util.encodeString(leftPad);
    this.rightPad = util.encodeString(rightPad);
    this.padWidth = padWidth;
    this.preserveShort = preserveShortSequences;
  }

  /** Returns the effective pad width for ngrams of width `nGramWidth`. */
  getPadWidth(nGramWidth) {
    // Ngrams can be padded with either a fixed pad width or a dynamic pad
    // width depending on the 'padWidth' arg, but in no case should the padding
    // ever be wider than 'nGramWidth' - 1.
    const maxPadWidth = nGramWidth - 1;
    return Math.min(this.padWidth < 0 ? maxPadWidth : this.padWidth, maxPadWidth);
  }

  /**
   * Returns how many ngrams of width `nGramWidth` a sequence of `length`
   * tokens produces once padding is taken into account (never negative).
   */
  getNumNGrams(length, nGramWidth) {
    const padWidth = this.getPadWidth(nGramWidth);
    return Math.max(0, ((length + 2 * padWidth) - nGramWidth) + 1);
  }

  /**
   * Builds `numNGrams` ngrams of width `nGramWidth` for the sequence that
   * starts at `data[splitIndex]` and stores them, as Uint8Arrays, in
   * `output[outputStartIndex .. outputStartIndex + numNGrams)`.
   */
  createNGrams(data, splitIndex, output, outputStartIndex, numNGrams, nGramWidth) {
    // The pad width only depends on nGramWidth, so compute it once.
    const padWidth = this.getPadWidth(nGramWidth);
    for (let nGramIndex = 0; nGramIndex < numNGrams; ++nGramIndex) {
      const leftPadding = Math.max(0, padWidth - nGramIndex);
      const rightPadding = Math.max(0, padWidth - (numNGrams - (nGramIndex + 1)));
      const numTokens = nGramWidth - (leftPadding + rightPadding);
      const dataStartIndex = splitIndex + (leftPadding > 0 ? 0 : nGramIndex - padWidth);

      // Calculate the total size of the nGram up front so the byte array can
      // be allocated with exactly the right length.
      let nGramSize = leftPadding * this.leftPad.length;
      for (let n = 0; n < numTokens; ++n) {
        nGramSize += data[dataStartIndex + n].length;
      }
      nGramSize += rightPadding * this.rightPad.length;
      const numSeparators = leftPadding + rightPadding + numTokens - 1;
      nGramSize += numSeparators * this.separator.length;

      // Build the nGram.
      const nGram = new Uint8Array(nGramSize);
      output[outputStartIndex + nGramIndex] = nGram;
      let writeOffset = 0;
      const appendToNGram = (bytes) => {
        nGram.set(bytes, writeOffset);
        writeOffset += bytes.length;
      };

      for (let n = 0; n < leftPadding; ++n) {
        appendToNGram(this.leftPad);
        appendToNGram(this.separator);
      }
      // Only output first numTokens - 1 pairs of data and separator
      for (let n = 0; n < numTokens - 1; ++n) {
        appendToNGram(data[dataStartIndex + n]);
        appendToNGram(this.separator);
      }
      // Handle case when there are no tokens or no right padding as these
      // can result in consecutive separators.
      if (numTokens > 0) {
        // If we have tokens, then output last and then pair each separator
        // with the right padding that follows, to ensure nGram ends either
        // with the token or with the right pad.
        appendToNGram(data[dataStartIndex + numTokens - 1]);
        for (let n = 0; n < rightPadding; ++n) {
          appendToNGram(this.separator);
          appendToNGram(this.rightPad);
        }
      }
      else {
        // If we don't have tokens, then the last item inserted into the nGram
        // has been the separator from the left padding loop above. Hence,
        // output right pad and separator and make sure to finish with a
        // padding, not a separator.
        for (let n = 0; n < rightPadding - 1; ++n) {
          appendToNGram(this.rightPad);
          appendToNGram(this.separator);
        }
        appendToNGram(this.rightPad);
      }
    }
  }

  /**
   * Computes the ngrams of a ragged string tensor.
   *
   * Data and splits together form the definition of the ragged tensor,
   * where data is 1 dimensional and contains the values of the tensor
   * and splits denotes the indices at which each row starts.
   *
   * @returns `[nGrams, nGramsSplits]`: the generated ngrams and the row
   *     splits of the ragged result.
   * @throws Error if `splits` does not start at 0, decreases, exceeds the
   *     data size, or does not end at the data size.
   */
  compute(data, splits) {
    // Validate that the splits are valid indices into data, only if there are
    // splits specified.
    const inputDataSize = data.length;
    const splitsSize = splits.length;
    if (splitsSize > 0) {
      let prevSplit = splits[0];
      if (prevSplit !== 0) {
        throw new Error(`First split value must be 0, got ${prevSplit}`);
      }
      for (let i = 1; i < splitsSize; ++i) {
        const validSplits = splits[i] >= prevSplit && splits[i] <= inputDataSize;
        if (!validSplits) {
          throw new Error(`Invalid split value ${splits[i]}, must be in [${prevSplit}, ${inputDataSize}]`);
        }
        prevSplit = splits[i];
      }
      if (prevSplit !== inputDataSize) {
        throw new Error(`Last split value must be data size. Expected ${inputDataSize}, got ${prevSplit}`);
      }
    }

    const numBatchItems = splitsSize - 1;
    const nGramsSplits = util.getArrayFromDType('int32', splitsSize);
    // If there is no data or size, return an empty ragged tensor.
    if (inputDataSize === 0 || splitsSize === 0) {
      const empty = new Array(inputDataSize);
      for (let i = 0; i <= numBatchItems; ++i) {
        nGramsSplits[i] = 0;
      }
      return [empty, nGramsSplits];
    }

    // First pass: count the ngrams of every row to build the output splits.
    nGramsSplits[0] = 0;
    for (let i = 1; i <= numBatchItems; ++i) {
      const length = splits[i] - splits[i - 1];
      let numNGrams = 0;
      this.nGramWidths.forEach((nGramWidth) => {
        numNGrams += this.getNumNGrams(length, nGramWidth);
      });
      if (this.preserveShort && length > 0 && numNGrams === 0) {
        numNGrams = 1;
      }
      nGramsSplits[i] = nGramsSplits[i - 1] + numNGrams;
    }

    // Second pass: generate the ngrams of every row.
    const nGrams = new Array(nGramsSplits[numBatchItems]);
    for (let i = 0; i < numBatchItems; ++i) {
      const splitIndex = splits[i];
      const dataLength = splits[i + 1] - splits[i];
      let outputStartIdx = nGramsSplits[i];
      this.nGramWidths.forEach((nGramWidth) => {
        const numNGrams = this.getNumNGrams(dataLength, nGramWidth);
        this.createNGrams(data, splitIndex, nGrams, outputStartIdx, numNGrams, nGramWidth);
        outputStartIdx += numNGrams;
      });
      // If we're preserving short sequences, check to see if no sequence was
      // generated by comparing the current output start idx to the original
      // one (nGramsSplits[i]). If no ngrams were generated, then they will
      // be equal (since we increment outputStartIdx by numNGrams every
      // time we create a set of ngrams.)
      if (this.preserveShort && outputStartIdx === nGramsSplits[i]) {
        // One legitimate reason to not have any ngrams when this.preserveShort
        // is true is if the sequence itself is empty. In that case, move on.
        if (dataLength === 0) {
          continue;
        }
        // We don't have to worry about dynamic padding sizes here: if padding
        // was dynamic, every sequence would have had sufficient padding to
        // generate at least one nGram.
        const nGramWidth = dataLength + 2 * this.padWidth;
        const numNGrams = 1;
        this.createNGrams(data, splitIndex, nGrams, outputStartIdx, numNGrams, nGramWidth);
      }
    }
    return [nGrams, nGramsSplits];
  }
}
|
/**
 * Functional entry point for string ngram generation: builds a
 * StringNGramsOp from the op attributes and runs it on the ragged input
 * described by `data` and `dataSplits`.
 *
 * @returns Whatever `StringNGramsOp.compute(data, dataSplits)` returns.
 */
function stringNGramsImpl(data, dataSplits, separator, nGramWidths, leftPad, rightPad, padWidth, preserveShortSequences) {
  const op = new StringNGramsOp(separator, nGramWidths, leftPad, rightPad, padWidth, preserveShortSequences);
  return op.compute(data, dataSplits);
}
|
|
/**
|
* @license
|
* Copyright 2021 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * Splits the byte string `str` at every occurrence of any byte contained in
 * `delimiters` and appends the resulting tokens to `result`.
 *
 * Tokens are `subarray` views into `str`, not copies.
 *
 * @param str Byte array to split (must provide `subarray` and `indexOf`).
 * @param delimiters Byte values to split at. When empty, `str` is split into
 *     single-byte tokens.
 * @param skipEmpty When true, zero-length tokens are not emitted.
 * @param result Array that receives the tokens; it is appended to in place.
 *     Nothing is appended for an empty `str`.
 */
function split(str, delimiters, skipEmpty, result) {
  if (!str.length) {
    return;
  }
  // When the delimiter is empty, the input is split into individual characters.
  if (delimiters.length === 0) {
    for (let i = 0; i < str.length; ++i) {
      result.push(str.subarray(i, i + 1));
    }
    return;
  }

  // Emits str[start, end) unless it is empty and empty tokens are skipped.
  const emit = (start, end) => {
    if (!skipEmpty || end > start) {
      result.push(str.subarray(start, end));
    }
  };
  let tokenStart = 0;

  // When there is one delimiter, scan with indexOf, resuming after each hit.
  if (delimiters.length === 1) {
    const delimiter = delimiters[0];
    let hit = str.indexOf(delimiter);
    while (hit !== -1) {
      emit(tokenStart, hit);
      tokenStart = hit + 1;
      hit = str.indexOf(delimiter, tokenStart);
    }
    emit(tokenStart, str.length);
    return;
  }

  // When there are multiple delimiters, the input is split at every instance
  // one of the delimiters appears. A Set makes the per-byte membership test
  // independent of the number of delimiters.
  const delimiterSet = new Set(delimiters);
  for (let i = 0; i <= str.length; i++) {
    if (i === str.length || delimiterSet.has(str[i])) {
      emit(tokenStart, i);
      tokenStart = i + 1;
    }
  }
}
|
/**
 * Splits every element of `input` with `split` and describes the tokens as a
 * sparse 2d result.
 *
 * @param input Batch of strings to split.
 * @param delimiter Delimiters forwarded to `split`; an empty delimiter means
 *     the input is split character by character.
 * @param skipEmpty Forwarded to `split`.
 * @returns `[indices, values, shape]`: `indices` is a flattened
 *     [numTokens, 2] int32 array of (batch index, token index) pairs,
 *     `values` holds the tokens in the same order, and `shape` is
 *     `[input.length, maxTokensPerElement]`.
 */
function stringSplitImpl(input, delimiter, skipEmpty) {
  const batchSize = input.length;

  // All tokens of all batch elements, in order; `numIndices[i]` records how
  // many of them belong to batch element i.
  const tokens = [];
  const numIndices = new Array(batchSize);
  let maxNumEntries = 0;
  for (let i = 0; i < batchSize; ++i) {
    const prevTokensLength = tokens.length;
    split(input[i], delimiter, skipEmpty, tokens);
    const nEntries = tokens.length - prevTokensLength;
    numIndices[i] = nEntries;
    maxNumEntries = Math.max(maxNumEntries, nEntries);
  }

  const outputSize = tokens.length;
  const indices = util.getArrayFromDType('int32', outputSize * 2);
  let c = 0;
  for (let i = 0; i < batchSize; ++i) {
    for (let j = 0; j < numIndices[i]; ++j) {
      // indices is a 2d tensor with shape of [outputSize, 2]
      indices[c * 2] = i;
      indices[c * 2 + 1] = j;
      ++c;
    }
  }

  // `tokens` already lists the values in output order, so it is returned
  // directly as the values array.
  return [indices, tokens, [batchSize, maxNumEntries]];
}
|
|
/**
|
* @license
|
* Copyright 2021 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * Maps every element of `input` to a bucket id by taking its
 * `util.fingerPrint64` fingerprint modulo `numBuckets`.
 *
 * @param input Elements to hash; each one is passed to `util.fingerPrint64`.
 * @param numBuckets Modulus applied to each fingerprint.
 * @returns An int32 array with one bucket id per input element.
 */
function stringToHashBucketFastImpl(input, numBuckets) {
  const count = input.length;
  const output = util.getArrayFromDType('int32', count);
  for (let i = 0; i < count; ++i) {
    const fingerprint = util.fingerPrint64(input[i]);
    output[i] = fingerprint.modulo(numBuckets).getLowBitsUnsigned();
  }
  return output;
}
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
// Element-wise subtraction over raw values, built with
// createSimpleBinaryKernelImpl.
const subImpl = createSimpleBinaryKernelImpl((aValue, bValue) => aValue - bValue);

// Complex subtraction: real and imaginary parts are subtracted independently.
const subComplexImpl = createComplexBinaryKernelImpl((aReal, aImag, bReal, bImag) => {
  return { real: aReal - bReal, imag: aImag - bImag };
});

// Kernel function for the Sub op, covering both the plain and complex paths.
const sub = binaryKernelFunc(Sub, subImpl, subComplexImpl);

// Kernel registration record for Sub on the 'cpu' backend.
const subConfig = {
  kernelName: Sub,
  backendName: 'cpu',
  kernelFunc: sub
};
|
|
/**
|
* @license
|
* Copyright 2019 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * An implementation of the tile kernel shared between webgl and cpu for string
 * tensors only.
 *
 * Builds a buffer whose size along dimension `i` is `xBuf.shape[i] * reps[i]`
 * and fills it by wrapping every output coordinate back into `xBuf` (modulo
 * the input shape).
 *
 * @param xBuf Input buffer to repeat.
 * @param reps Number of repetitions per dimension.
 * @returns The newly created, tiled buffer.
 */
function tileImpl(xBuf, reps) {
  const rank = xBuf.rank;
  const newShape = Array.from({ length: rank }, (_, i) => xBuf.shape[i] * reps[i]);
  const result = buffer(newShape, xBuf.dtype);
  const outValues = result.values;
  for (let i = 0; i < outValues.length; ++i) {
    const newLoc = result.indexToLoc(i);
    // Wrap each output coordinate into the range of the input shape.
    const originalLoc = Array.from({ length: rank }, (_, j) => newLoc[j] % xBuf.shape[j]);
    outValues[i] = xBuf.values[xBuf.locToIndex(originalLoc)];
  }
  return result;
}
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * Comparator for `{value, index}` pairs: larger values sort first, and pairs
 * with equal values are ordered by ascending index.
 *
 * Values are compared for equality before subtracting, because subtracting
 * two equal infinite values yields NaN rather than 0 and would skip the
 * index tie-break.
 */
const comparePair = (a, b) => {
  if (a.value === b.value) {
    return a.index - b.index;
  }
  return b.value - a.value;
};
|
/**
 * Partitions `array` in place so that `array[k]` holds the element that
 * would be at position `k` if the interval were fully sorted with
 * `comparePair`; every element that `comparePair` orders before it ends up to
 * its left and every element ordered after it to its right.
 * Based on the Floyd-Rivest Algorithm, ref:
 * https://en.wikipedia.org/wiki/Floyd%E2%80%93Rivest_algorithm
 * @param array: Array to partition
 * @param k: Desired index value, where array[k] is the (k+1)th smallest element
 *     (in comparePair order) when left = 0
 * @param left: Left index for the interval (defaults to 0)
 * @param right: Right index for the interval (defaults to array.length - 1)
 */
function select$1(array, k, left = 0, right = array.length - 1) {
  while (right > left) {
    // Use select recursively to sample a smaller set of size s
    // the arbitrary constants 600 and 0.5 are used in the original
    // version to minimize execution time.
    if (right - left > 600) {
      const n = right - left + 1;
      const i = k - left + 1;
      const z = Math.log(n);
      const s = 0.5 * Math.exp(2 * z / 3);
      const sd = 0.5 * Math.sqrt(z * s * (n - s) / n) * Math.sign(i - n / 2);
      const newLeft = Math.max(left, Math.floor(k - i * s / n + sd));
      const newRight = Math.min(right, Math.floor(k + (n - i) * s / n + sd));
      select$1(array, k, newLeft, newRight);
    }
    // partition the elements between left and right around t
    const t = array[k];
    let i = left;
    let j = right;
    util.swap(array, left, k);
    if (comparePair(array[right], t) > 0) {
      util.swap(array, left, right);
    }
    while (i < j) {
      util.swap(array, i, j);
      i++;
      j--;
      while (comparePair(array[i], t) < 0) {
        i = i + 1;
      }
      while (comparePair(array[j], t) > 0) {
        j = j - 1;
      }
    }
    // Move the pivot t into its final position j.
    if (comparePair(array[left], t) === 0) {
      util.swap(array, left, j);
    }
    else {
      j = j + 1;
      util.swap(array, j, right);
    }
    // Adjust left and right towards the boundaries of the subset
    // containing the (k - left + 1)th smallest element.
    if (j <= k) {
      left = j + 1;
    }
    if (k <= j) {
      right = j - 1;
    }
  }
}
|
/**
 * Computes the top-k entries along the last dimension of `x`.
 *
 * The flat data in `x` is treated as [batch, lastDim] rows. For every row the
 * k first pairs in `comparePair` order (larger value first, lower index on
 * ties) are selected; they are additionally sorted when `sorted` is truthy.
 *
 * @param x Flat values; must support `subarray` (typed array).
 * @param xShape Shape of the input; only the last entry is used for batching.
 * @param xDtype Dtype used to allocate the output values.
 * @param k Number of entries to keep per row.
 * @param sorted Whether the kept entries must be sorted.
 * @returns A two-element array: [values buffer, int32 indices buffer], both
 *     shaped like `xShape` with the last dimension replaced by k.
 */
function topKImpl(x, xShape, xDtype, k, sorted) {
    // Treat the input as a 2d tensor [batch, size] and work row by row.
    const size = xShape[xShape.length - 1];
    const batch = x.length / size;
    const allTopKVals = util.getTypedArrayFromDType(xDtype, batch * k);
    const allTopKIndices = util.getTypedArrayFromDType('int32', batch * k);
    for (let row = 0; row < batch; row++) {
        const rowStart = row * size;
        const rowVals = x.subarray(rowStart, rowStart + size);
        let pairs = Array.from(rowVals, (value, index) => ({ value, index }));
        if (k < pairs.length) {
            // Partial selection is enough: only the first k pairs are kept.
            select$1(pairs, k);
            pairs = pairs.slice(0, k);
        }
        if (sorted) {
            pairs.sort(comparePair);
        }
        const outStart = row * k;
        for (let i = 0; i < k; i++) {
            const pair = pairs[i];
            allTopKVals[outStart + i] = pair.value;
            allTopKIndices[outStart + i] = pair.index;
        }
    }
    // Same shape as the input, except that the last dimension becomes k.
    const outputShape = xShape.slice();
    outputShape[outputShape.length - 1] = k;
    const valuesBuffer = buffer(outputShape, xDtype, allTopKVals);
    const indicesBuffer = buffer(outputShape, 'int32', allTopKIndices);
    return [valuesBuffer, indicesBuffer];
}
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * Finds the unique slices of a tensor along one axis.
 *
 * @param values Flat values of the input tensor.
 * @param axis Axis along which slices are compared.
 * @param shape Shape of the input tensor.
 * @param dtype Dtype used for the intermediate and output buffers.
 * @returns An object with:
 *     - outputValues: values of the unique slices, in order of first
 *       appearance along the axis;
 *     - outputShape: `shape` with the axis size replaced by the number of
 *       unique slices;
 *     - indices: Int32Array mapping every position along the axis to the
 *       position of its slice in the output.
 */
function uniqueImpl(values, axis, shape, dtype) {
    // Normalize and validate the axis; only the first resolved axis is used.
    const $axis = util.parseAxisParam(axis, shape)[0];
    // View the data through an intermediate rank-3 shape
    // [outerSize, axisSize, innerSize], where outerSize is the product of the
    // dimensions before the axis and innerSize the product of the dimensions
    // after it. E.g. shape=[2, 3, 5, 4] with axis=2 gives [2*3, 5, 4].
    //
    // This is not the output shape; it only makes it easy to read the slice at
    // position i along the axis by iterating (m, i, n) over every m in
    // [0, outerSize) and n in [0, innerSize).
    //
    // Example: input shape [1, 2, 3] = [[[1,2,3],[4,5,6]]], axis 2.
    // The rank-3 view is [2, 3, 1]: [[[1],[2],[3]], [[4],[5],[6]]], and the
    // slices read along the axis are [1,4], [2,5] and [3,6].
    let outerSize = 1;
    for (let d = 0; d < $axis; d++) {
        outerSize *= shape[d];
    }
    const axisSize = shape[$axis];
    let innerSize = 1;
    for (let d = $axis + 1; d < shape.length; d++) {
        innerSize *= shape[d];
    }
    const newShape = [outerSize, axisSize, innerSize];
    // Maps the string form of each distinct slice to its output position.
    const uniqueElements = new Map();
    // For each position along the axis, the output position of its slice.
    const indices = new Int32Array(axisSize);
    // Buffer used to read values at a given (m, i, n) location.
    const inputBuffer = new TensorBuffer(newShape, dtype, values);
    // Positions along the axis where a slice was seen for the first time.
    const uniqueIndices = [];
    const is1DTensor = outerSize === 1 && innerSize === 1;
    for (let i = 0; i < axisSize; i++) {
        let element;
        if (is1DTensor) {
            // Fast path: each slice is a single value.
            element = values[i].toString();
        }
        else {
            const axisValues = [];
            for (let m = 0; m < outerSize; m++) {
                for (let n = 0; n < innerSize; n++) {
                    axisValues.push(inputBuffer.get(m, i, n));
                }
            }
            element = axisValues.join(',');
        }
        const existingIndex = uniqueElements.get(element);
        if (existingIndex == null) {
            // First occurrence: assign the next output position.
            const uniqueIndex = uniqueElements.size;
            uniqueElements.set(element, uniqueIndex);
            indices[i] = uniqueIndex;
            uniqueIndices.push(i);
        }
        else {
            indices[i] = existingIndex;
        }
    }
    // Copy each first-seen slice from the input buffer into the output buffer.
    const uniqueCount = uniqueElements.size;
    const outputTmpShape = [outerSize, uniqueCount, innerSize];
    const outputBuffer = new TensorBuffer(outputTmpShape, dtype);
    for (let outPos = 0; outPos < uniqueIndices.length; outPos++) {
        const srcPos = uniqueIndices[outPos];
        for (let m = 0; m < outerSize; m++) {
            for (let n = 0; n < innerSize; n++) {
                outputBuffer.set(inputBuffer.get(m, srcPos, n), m, outPos, n);
            }
        }
    }
    // Final shape: the input shape with the axis size replaced by the number
    // of unique slices.
    const outputShape = shape.slice();
    outputShape[$axis] = uniqueCount;
    return {
        outputValues: outputBuffer.values,
        outputShape,
        indices,
    };
}
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
|
// Prototype-less namespace object collecting the *Impl helpers of this file
// under their own names.
var shared = {
    __proto__: null,
    addImpl,
    bincountImpl,
    bincountReduceImpl,
    bitwiseAndImpl,
    castImpl,
    ceilImpl,
    concatImpl,
    equalImpl,
    expImpl,
    expm1Impl,
    floorDivImpl,
    floorImpl,
    gatherNdImpl,
    gatherV2Impl,
    greaterEqualImpl,
    greaterImpl,
    lessEqualImpl,
    lessImpl,
    linSpaceImpl,
    logImpl,
    maxImpl,
    maximumImpl,
    minimumImpl,
    multiplyImpl,
    negImpl,
    notEqualImpl,
    prodImpl,
    raggedGatherImpl,
    raggedRangeImpl,
    raggedTensorToTensorImpl,
    rangeImpl,
    rsqrtImpl,
    scatterImpl,
    sigmoidImpl,
    simpleAbsImpl,
    sliceImpl,
    sparseFillEmptyRowsImpl,
    sparseReshapeImpl,
    sparseSegmentReductionImpl,
    sqrtImpl,
    squaredDifferenceImpl,
    staticRegexReplaceImpl,
    stridedSliceImpl,
    stringNGramsImpl,
    stringSplitImpl,
    stringToHashBucketFastImpl,
    subImpl,
    tileImpl,
    topKImpl,
    transposeImpl,
    uniqueImpl
};
|
|
/** @license See the LICENSE file. */
|
// This code is auto-generated, do not modify this file!
|
// Version string constant; presumably the release version of this package —
// confirm against the build script that generates it.
const version = '4.15.0';
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
// Side effects for default initialization of MathBackendCPU
|
// Registers a factory under the name 'cpu' that constructs a new
// MathBackendCPU; the trailing argument (1) is the backend's priority.
registerBackend('cpu', () => new MathBackendCPU(), 1 /* priority */);
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the License);
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an AS IS BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
// ELU kernel: non-negative inputs pass through, every other input maps to
// exp(x) - 1. `unaryKernelFunc` is defined elsewhere in this file; presumably
// it applies the scalar callback to each element — confirm there.
const elu = unaryKernelFunc(Elu, (value) => {
    if (value >= 0) {
        return value;
    }
    return Math.exp(value) - 1;
});
// Kernel config binding the Elu kernel name to `elu` for the 'cpu' backend.
const eluConfig = {
    kernelName: Elu,
    backendName: 'cpu',
    kernelFunc: elu,
};
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * Leaky ReLU kernel: negative inputs are multiplied by `attrs.alpha`, all
 * other inputs are copied unchanged. Throws (via `assertNotComplex`) for
 * complex input.
 *
 * @param args Kernel arguments: `inputs.x`, `backend` and `attrs.alpha`.
 * @returns A float32 tensor info with the shape of `x`.
 */
function leakyRelu(args) {
    const { inputs, backend, attrs } = args;
    const { x } = inputs;
    const { alpha } = attrs;
    assertNotComplex([x], 'leakyRelu');
    const elementCount = util.sizeFromShape(x.shape);
    const source = backend.data.get(x.dataId).values;
    const result = util.getTypedArrayFromDType('float32', elementCount);
    for (let i = 0; i < source.length; i++) {
        const value = source[i];
        result[i] = value < 0 ? alpha * value : value;
    }
    return backend.makeTensorInfo(x.shape, 'float32', result);
}
// Kernel config binding the LeakyRelu kernel name to `leakyRelu` for the
// 'cpu' backend.
const leakyReluConfig = {
    kernelName: LeakyRelu,
    backendName: 'cpu',
    kernelFunc: leakyRelu
};
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the License);
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an AS IS BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
// Scalar PReLU rule: a negative x is scaled by its matching alpha value, any
// other x is returned unchanged. `createSimpleBinaryKernelImpl` is defined
// elsewhere in this file; presumably it lifts this rule to a broadcasting
// binary implementation — confirm there.
const preluImpl = createSimpleBinaryKernelImpl((xValue, aValue) => {
    if (xValue < 0) {
        return aValue * xValue;
    }
    return xValue;
});
/**
 * PReLU kernel. Throws (via `assertNotComplex`) for complex input.
 *
 * @param args Kernel arguments: `inputs.x`, `inputs.alpha` and `backend`.
 * @returns A float32 tensor info with the shape reported by `preluImpl`.
 */
function prelu(args) {
    const { inputs, backend } = args;
    const { x, alpha } = inputs;
    assertNotComplex([x, alpha], 'prelu');
    const xValues = backend.data.get(x.dataId).values;
    const alphaValues = backend.data.get(alpha.dataId).values;
    const [outValues, outShape] = preluImpl(x.shape, alpha.shape, xValues, alphaValues, 'float32');
    return backend.makeTensorInfo(outShape, 'float32', outValues);
}
// Kernel config binding the Prelu kernel name to `prelu` for the 'cpu' backend.
const preluConfig = {
    kernelName: Prelu,
    backendName: 'cpu',
    kernelFunc: prelu,
};
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the License);
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an AS IS BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
// ReLU kernel: each value is replaced by max(0, value). `unaryKernelFunc` is
// defined elsewhere in this file; presumably it applies the scalar callback
// to each element — confirm there.
const relu = unaryKernelFunc(Relu, (value) => {
    return Math.max(0, value);
});
// Kernel config binding the Relu kernel name to `relu` for the 'cpu' backend.
const reluConfig = {
    kernelName: Relu,
    backendName: 'cpu',
    kernelFunc: relu,
};
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the License);
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an AS IS BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
// ReLU6 kernel: each value is clamped from below at 0 and then from above at
// 6. `unaryKernelFunc` is defined elsewhere in this file; presumably it
// applies the scalar callback to each element — confirm there.
const relu6 = unaryKernelFunc(Relu6, (value) => {
    const nonNegative = Math.max(0, value);
    return Math.min(nonNegative, 6);
});
// Kernel config binding the Relu6 kernel name to `relu6` for the 'cpu' backend.
const relu6Config = {
    kernelName: Relu6,
    backendName: 'cpu',
    kernelFunc: relu6,
};
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * Dispatches `x` to the CPU kernel function matching `activation`.
 *
 * @param backend Backend instance forwarded to the selected kernel.
 * @param x Input tensor info.
 * @param activation One of 'linear', 'relu', 'elu', 'relu6', 'prelu',
 *     'leakyrelu' or 'sigmoid'.
 * @param preluActivationWeights Passed as `alpha` when activation is 'prelu'.
 * @param leakyreluAlpha Passed as `alpha` when activation is 'leakyrelu'.
 * @returns The result of the selected kernel function.
 * @throws Error when `activation` is none of the names listed above.
 */
function applyActivation(backend, x, activation, preluActivationWeights, leakyreluAlpha) {
    switch (activation) {
        case 'linear':
            return identity({ inputs: { x }, backend });
        case 'relu':
            return relu({ inputs: { x }, backend });
        case 'elu':
            return elu({ inputs: { x }, backend });
        case 'relu6':
            return relu6({ inputs: { x }, backend });
        case 'prelu':
            return prelu({ inputs: { x, alpha: preluActivationWeights }, backend });
        case 'leakyrelu':
            return leakyRelu({ inputs: { x }, backend, attrs: { alpha: leakyreluAlpha } });
        case 'sigmoid':
            return sigmoid({ inputs: { x }, backend });
        default:
            throw new Error(`Activation ${activation} has not been implemented for the CPU backend.`);
    }
}
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * Reshape kernel. No data is copied: the returned tensor info reuses the
 * input's `dataId` (after `backend.incRef` is called on it) with the new
 * shape.
 *
 * @param args Kernel arguments: `inputs.x`, `backend` and `attrs.shape`.
 * @returns `{dataId, shape, dtype}` describing the reshaped tensor.
 * @throws Via `util.assert` when the element counts of the old and new
 *     shapes differ.
 */
function reshape(args) {
    const { inputs, backend, attrs } = args;
    const { x } = inputs;
    const { shape } = attrs;
    const oldSize = util.sizeFromShape(x.shape);
    const resolvedShape = util.inferFromImplicitShape(shape, oldSize);
    const newSize = util.sizeFromShape(resolvedShape);
    const sizeMismatchMessage = () => `The new shape (${resolvedShape}) has ${newSize} elements and the old ` +
        `shape (${x.shape}) has ${oldSize} elements. The new shape and old ` +
        `shape must have the same number of elements.`;
    util.assert(oldSize === newSize, sizeMismatchMessage);
    backend.incRef(x.dataId);
    const { complexTensorInfos } = backend.data.get(x.dataId);
    if (complexTensorInfos != null) {
        // Data with complex components: the real and imaginary tensor infos
        // get the new shape as well.
        const { real, imag } = complexTensorInfos;
        real.shape = resolvedShape;
        imag.shape = resolvedShape;
    }
    return { dataId: x.dataId, shape: resolvedShape, dtype: x.dtype };
}
// Kernel config binding the Reshape kernel name to `reshape` for the 'cpu'
// backend.
const reshapeConfig = {
    kernelName: Reshape,
    backendName: 'cpu',
    kernelFunc: reshape
};
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the License);
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an AS IS BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * Batched matrix multiplication kernel.
 *
 * The last two dimensions of `a` and `b` are the matrices (optionally
 * transposed through `attrs.transposeA` / `attrs.transposeB`); the leading
 * dimensions are batch dimensions combined through
 * `broadcast_util.assertAndGetBroadcastShape`.
 *
 * @param args Kernel arguments: `inputs.a`, `inputs.b`, `backend`,
 *     `attrs.transposeA` and `attrs.transposeB`.
 * @returns Tensor info of shape [...broadcast batch dims, rowsOfA, colsOfB].
 * @throws Via `assertNotComplex` for complex inputs, and via `util.assert`
 *     when the inner dimensions of the two operands differ.
 */
function batchMatMul(args) {
    const { inputs, backend, attrs } = args;
    const { a, b } = inputs;
    const { transposeA, transposeB } = attrs;
    assertNotComplex([a, b], 'matMul');
    const aRank = a.shape.length;
    const bRank = b.shape.length;
    const aLast = a.shape[aRank - 1];
    const aSecondLast = a.shape[aRank - 2];
    const bLast = b.shape[bRank - 1];
    const bSecondLast = b.shape[bRank - 2];
    const innerShapeA = transposeA ? aSecondLast : aLast;
    const innerShapeB = transposeB ? bLast : bSecondLast;
    const outerShapeA = transposeA ? aLast : aSecondLast;
    const outerShapeB = transposeB ? bSecondLast : bLast;
    const outerDimsA = a.shape.slice(0, -2);
    const outerDimsB = b.shape.slice(0, -2);
    const batchDimA = util.sizeFromShape(outerDimsA);
    const batchDimB = util.sizeFromShape(outerDimsB);
    const outShapeOuterDims = broadcast_util.assertAndGetBroadcastShape(outerDimsA, outerDimsB);
    const outShape = outShapeOuterDims.concat([outerShapeA, outerShapeB]);
    util.assert(innerShapeA === innerShapeB, () => `Error in matMul: inner shapes (${innerShapeA}) and (` +
        `${innerShapeB}) of Tensors with shapes ${a.shape} and ` +
        `${b.shape} and transposeA=${transposeA}` +
        ` and transposeB=${transposeB} must match.`);
    // The rest of the implementation operates on rank-3 views of the inputs.
    const a3dShape = transposeA ?
        [batchDimA, innerShapeA, outerShapeA] :
        [batchDimA, outerShapeA, innerShapeA];
    const b3dShape = transposeB ?
        [batchDimB, outerShapeB, innerShapeB] :
        [batchDimB, innerShapeB, outerShapeB];
    const a3d = reshape({ inputs: { x: a }, backend, attrs: { shape: a3dShape } });
    const b3d = reshape({ inputs: { x: b }, backend, attrs: { shape: b3dShape } });
    const sharedDim = transposeA ? a3d.shape[1] : a3d.shape[2];
    const leftDim = transposeA ? a3d.shape[2] : a3d.shape[1];
    const rightDim = transposeB ? b3d.shape[1] : b3d.shape[2];
    const batchDim = Math.max(batchDimA, batchDimB);
    const a3dValues = backend.data.get(a3d.dataId).values;
    const b3dValues = backend.data.get(b3d.dataId).values;
    const a3dStrides = util.computeStrides(a3d.shape);
    const b3dStrides = util.computeStrides(b3d.shape);
    // Element steps for walking each operand; a transpose swaps which of the
    // two matrix dimensions is contiguous.
    const aBatch = a3dStrides[0];
    const aOuterStep = transposeA ? 1 : a3dStrides[1];
    const aInnerStep = transposeA ? a3dStrides[1] : 1;
    const bBatch = b3dStrides[0];
    const bInnerStep = transposeB ? 1 : b3dStrides[1];
    const bOuterStep = transposeB ? b3dStrides[1] : 1;
    const size = leftDim * rightDim;
    const result = buffer([batchDim, leftDim, rightDim], a3d.dtype);
    const resVals = result.values;
    const blockSize = backend.blockSize;
    for (let bi = 0; bi < batchDim; bi++) {
        // A smaller batch dimension wraps around, which realizes the
        // broadcast over batches.
        const aBatchOffset = (bi % batchDimA) * aBatch;
        const bBatchOffset = (bi % batchDimB) * bBatch;
        const outBatchOffset = bi * size;
        for (let i0 = 0; i0 < leftDim; i0 += blockSize) {
            // Math.min handles a blockSize that doesn't evenly divide the dim.
            const iEnd = Math.min(i0 + blockSize, leftDim);
            for (let j0 = 0; j0 < rightDim; j0 += blockSize) {
                const jEnd = Math.min(j0 + blockSize, rightDim);
                for (let k0 = 0; k0 < sharedDim; k0 += blockSize) {
                    const kEnd = Math.min(k0 + blockSize, sharedDim);
                    for (let i = i0; i < iEnd; i++) {
                        const aRowOffset = aBatchOffset + i * aOuterStep;
                        for (let j = j0; j < jEnd; j++) {
                            const bColOffset = j * bOuterStep + bBatchOffset;
                            let sum = 0.0;
                            for (let k = k0; k < kEnd; k++) {
                                const aVal = a3dValues[aRowOffset + k * aInnerStep];
                                const bVal = b3dValues[k * bInnerStep + bColOffset];
                                sum += aVal * bVal;
                            }
                            // Partial sums of each k-block accumulate into
                            // the same output cell.
                            resVals[outBatchOffset + (i * rightDim + j)] += sum;
                        }
                    }
                }
            }
        }
    }
    backend.disposeIntermediateTensorInfo(a3d);
    backend.disposeIntermediateTensorInfo(b3d);
    // Set the correct (broadcast) shape on the output.
    return backend.makeTensorInfo(outShape, result.dtype, result.values);
}
// Kernel config binding the BatchMatMul kernel name to `batchMatMul` for the
// 'cpu' backend.
const batchMatMulConfig = {
    kernelName: BatchMatMul,
    backendName: 'cpu',
    kernelFunc: batchMatMul,
};
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the License);
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an AS IS BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * Fused matMul kernel: runs `batchMatMul`, then adds `bias` when one is
 * given, then applies `activation` when one is given. Every intermediate
 * result that is replaced by a later step is disposed before returning.
 *
 * @param args Kernel arguments: `inputs` ({a, b, bias,
 *     preluActivationWeights}), `backend` and `attrs` ({transposeA,
 *     transposeB, activation, leakyreluAlpha}).
 * @returns Tensor info of the last step that ran.
 */
function _fusedMatMul(args) {
    const { inputs, backend, attrs } = args;
    const { a, b, bias, preluActivationWeights } = inputs;
    const { transposeA, transposeB, activation, leakyreluAlpha } = attrs;
    const superseded = [];
    let current = batchMatMul({ inputs: { a, b }, attrs: { transposeA, transposeB }, backend });
    if (bias) {
        const withBias = add({ inputs: { a: current, b: bias }, backend });
        superseded.push(current);
        current = withBias;
    }
    if (activation) {
        const activated = applyActivation(backend, current, activation, preluActivationWeights, leakyreluAlpha);
        superseded.push(current);
        current = activated;
    }
    superseded.forEach((info) => {
        backend.disposeIntermediateTensorInfo(info);
    });
    return current;
}
// Kernel config binding the _FusedMatMul kernel name to `_fusedMatMul` for
// the 'cpu' backend.
const _fusedMatMulConfig = {
    kernelName: _FusedMatMul,
    backendName: 'cpu',
    kernelFunc: _fusedMatMul,
};
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the License);
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an AS IS BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
// Acos kernel: each value is mapped through Math.acos. `unaryKernelFunc` is
// defined elsewhere in this file; presumably it applies the scalar callback
// to each element — confirm there.
const acos = unaryKernelFunc(Acos, (value) => {
    return Math.acos(value);
});
// Kernel config binding the Acos kernel name to `acos` for the 'cpu' backend.
const acosConfig = {
    kernelName: Acos,
    backendName: 'cpu',
    kernelFunc: acos,
};
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the License);
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an AS IS BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
// Acosh kernel: each value is mapped through Math.acosh. `unaryKernelFunc` is
// defined elsewhere in this file; presumably it applies the scalar callback
// to each element — confirm there.
const acosh = unaryKernelFunc(Acosh, (value) => {
    return Math.acosh(value);
});
// Kernel config binding the Acosh kernel name to `acosh` for the 'cpu' backend.
const acoshConfig = {
    kernelName: Acosh,
    backendName: 'cpu',
    kernelFunc: acosh,
};
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * AddN kernel: element-wise sum of all input tensors.
 *
 * The output buffer takes its shape and dtype from the first input, and the
 * inputs are accumulated into it in order. Throws (via `assertNotComplex`)
 * for complex inputs.
 *
 * @param args Kernel arguments: `inputs` (array of tensor infos) and
 *     `backend`.
 * @returns Tensor info holding the element-wise sum.
 */
function addN(args) {
    const { inputs, backend } = args;
    const tensors = inputs;
    assertNotComplex(inputs, 'addN');
    const allValues = tensors.map((tensor) => backend.data.get(tensor.dataId).values);
    const [first] = [tensors[0]];
    const outBuf = buffer(first.shape, first.dtype);
    const accumulator = outBuf.values;
    for (const tensorValues of allValues) {
        for (let j = 0; j < accumulator.length; j++) {
            accumulator[j] += tensorValues[j];
        }
    }
    return backend.makeTensorInfo(outBuf.shape, outBuf.dtype, outBuf.values);
}
// Kernel config binding the AddN kernel name to `addN` for the 'cpu' backend.
const addNConfig = {
    kernelName: AddN,
    backendName: 'cpu',
    kernelFunc: addN
};
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * All kernel: logical-AND reduction of `x` over `attrs.axis`.
 *
 * When the reduced axes are not already the innermost ones, `x` is first
 * transposed so that they are; each output element is then the `&&` fold of
 * one contiguous run of `reduceSize` input values.
 *
 * @param args Kernel arguments: `inputs.x`, `backend`, `attrs.axis` and
 *     `attrs.keepDims`.
 * @returns Tensor info of the reduced tensor; when `keepDims` is truthy the
 *     result is reshaped via `backend_util.expandShapeToKeepDim`.
 */
function all(args) {
    const { inputs, backend, attrs } = args;
    const { x } = inputs;
    const { axis, keepDims } = attrs;
    assertNotComplex(x, 'all');
    const origAxes = util.parseAxisParam(axis, x.shape);
    const permutedAxes = backend_util.getAxesPermutation(origAxes, x.shape.length);
    const transposed = permutedAxes != null;
    let axes = origAxes;
    let $x = x;
    if (transposed) {
        $x = transpose({ inputs: { x }, backend, attrs: { perm: permutedAxes } });
        axes = backend_util.getInnerMostAxes(axes.length, x.shape.length);
    }
    backend_util.assertAxesAreInnerMostDims('all', axes, $x.shape.length);
    const [outShape, reduceShape] = backend_util.computeOutAndReduceShapes($x.shape, axes);
    const reduceSize = util.sizeFromShape(reduceShape);
    const outValues = util.makeZerosTypedArray(util.sizeFromShape(outShape), $x.dtype);
    const inValues = backend.data.get($x.dataId).values;
    for (let outIdx = 0; outIdx < outValues.length; ++outIdx) {
        const start = outIdx * reduceSize;
        // Seed with the first value of the run, then fold the whole run.
        let folded = inValues[start];
        for (let j = 0; j < reduceSize; ++j) {
            folded = folded && inValues[start + j];
        }
        outValues[outIdx] = folded;
    }
    if (transposed) {
        backend.disposeIntermediateTensorInfo($x);
    }
    const result = backend.makeTensorInfo(outShape, $x.dtype, outValues);
    if (!keepDims) {
        return result;
    }
    const expandedShape = backend_util.expandShapeToKeepDim(outShape, origAxes);
    const reshapedResult = reshape({ inputs: { x: result }, backend, attrs: { shape: expandedShape } });
    backend.disposeIntermediateTensorInfo(result);
    return reshapedResult;
}
// Kernel config binding the All kernel name to `all` for the 'cpu' backend.
const allConfig = {
    kernelName: All,
    backendName: 'cpu',
    kernelFunc: all
};
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * Any kernel: logical-OR reduction of `x` over `attrs.axis`.
 *
 * When the reduced axes are not already the innermost ones, `x` is first
 * transposed so that they are; each output element is then the `||` fold of
 * one contiguous run of `reduceSize` input values.
 *
 * @param args Kernel arguments: `inputs.x`, `backend`, `attrs.axis` and
 *     `attrs.keepDims`.
 * @returns Tensor info of the reduced tensor; when `keepDims` is truthy the
 *     result is reshaped via `backend_util.expandShapeToKeepDim`.
 */
function any(args) {
    const { inputs, backend, attrs } = args;
    const { x } = inputs;
    const { axis, keepDims } = attrs;
    assertNotComplex(x, 'any');
    const origAxes = util.parseAxisParam(axis, x.shape);
    const permutedAxes = backend_util.getAxesPermutation(origAxes, x.shape.length);
    const transposed = permutedAxes != null;
    let axes = origAxes;
    let $x = x;
    if (transposed) {
        $x = transpose({ inputs: { x }, backend, attrs: { perm: permutedAxes } });
        axes = backend_util.getInnerMostAxes(axes.length, x.shape.length);
    }
    backend_util.assertAxesAreInnerMostDims('any', axes, $x.shape.length);
    const [outShape, reduceShape] = backend_util.computeOutAndReduceShapes($x.shape, axes);
    const reduceSize = util.sizeFromShape(reduceShape);
    const outValues = util.makeZerosTypedArray(util.sizeFromShape(outShape), $x.dtype);
    const inValues = backend.data.get($x.dataId).values;
    for (let outIdx = 0; outIdx < outValues.length; ++outIdx) {
        const start = outIdx * reduceSize;
        // Seed with the first value of the run, then fold the whole run.
        let folded = inValues[start];
        for (let j = 0; j < reduceSize; ++j) {
            folded = folded || inValues[start + j];
        }
        outValues[outIdx] = folded;
    }
    if (transposed) {
        backend.disposeIntermediateTensorInfo($x);
    }
    const result = backend.makeTensorInfo(outShape, $x.dtype, outValues);
    if (!keepDims) {
        return result;
    }
    const expandedShape = backend_util.expandShapeToKeepDim(outShape, origAxes);
    const reshapedResult = reshape({ inputs: { x: result }, backend, attrs: { shape: expandedShape } });
    backend.disposeIntermediateTensorInfo(result);
    return reshapedResult;
}
// Kernel config binding the Any kernel name to `any` for the 'cpu' backend.
const anyConfig = {
    kernelName: Any,
    backendName: 'cpu',
    kernelFunc: any
};
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * `ArgMax` kernel: for every output element, finds the index (within the
 * reduced window) of the largest value. On ties the earliest index wins
 * because only a strictly greater value replaces the current maximum.
 *
 * @param args Object with `inputs.x`, `backend`, and `attrs.axis`.
 * @returns An 'int32' tensor info holding the winning indices.
 */
function argMax(args) {
    const { inputs, backend, attrs } = args;
    const { x } = inputs;
    const { axis } = attrs;
    assertNotComplex(x, 'argMax');

    const parsedAxes = util.parseAxisParam(axis, x.shape);
    const perm = backend_util.getAxesPermutation(parsedAxes, x.shape.length);
    const wasTransposed = perm != null;
    let source = x;
    let innerAxes = parsedAxes;
    if (wasTransposed) {
        source = transpose({ inputs: { x }, backend, attrs: { perm } });
        innerAxes = backend_util.getInnerMostAxes(parsedAxes.length, source.shape.length);
    }
    // Only the first axis is used for the reduction.
    const reduceAxes = [innerAxes[0]];
    backend_util.assertAxesAreInnerMostDims('argMax', reduceAxes, source.shape.length);

    const [outShape, reduceShape] = backend_util.computeOutAndReduceShapes(source.shape, reduceAxes);
    const indices = util.makeZerosTypedArray(util.sizeFromShape(outShape), 'int32');
    const windowSize = util.sizeFromShape(reduceShape);
    const data = backend.data.get(source.dataId).values;

    for (let outIdx = 0; outIdx < indices.length; ++outIdx) {
        const start = outIdx * windowSize;
        let best = data[start];
        let bestAt = 0;
        // Position 0 seeds `best`, so the scan can start at 1.
        for (let k = 1; k < windowSize; ++k) {
            const candidate = data[start + k];
            if (candidate > best) {
                best = candidate;
                bestAt = k;
            }
        }
        indices[outIdx] = bestAt;
    }

    if (wasTransposed) {
        backend.disposeIntermediateTensorInfo(source);
    }
    return backend.makeTensorInfo(outShape, 'int32', indices);
}

/** Kernel registration record for `ArgMax` under the 'cpu' backend name. */
const argMaxConfig = { kernelName: ArgMax, backendName: 'cpu', kernelFunc: argMax };
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * `ArgMin` kernel: for every output element, finds the index (within the
 * reduced window) of the smallest value. On ties the earliest index wins
 * because only a strictly smaller value replaces the current minimum.
 *
 * @param args Object with `inputs.x`, `backend`, and `attrs.axis`.
 * @returns An 'int32' tensor info holding the winning indices.
 */
function argMin(args) {
    const { inputs, backend, attrs } = args;
    const { x } = inputs;
    const { axis } = attrs;
    assertNotComplex(x, 'argMin');

    const parsedAxes = util.parseAxisParam(axis, x.shape);
    const perm = backend_util.getAxesPermutation(parsedAxes, x.shape.length);
    const wasTransposed = perm != null;
    let source = x;
    let innerAxes = parsedAxes;
    if (wasTransposed) {
        source = transpose({ inputs: { x }, backend, attrs: { perm } });
        innerAxes = backend_util.getInnerMostAxes(parsedAxes.length, source.shape.length);
    }
    // Only the first axis is used for the reduction.
    const reduceAxes = [innerAxes[0]];
    backend_util.assertAxesAreInnerMostDims('argMin', reduceAxes, source.shape.length);

    const [outShape, reduceShape] = backend_util.computeOutAndReduceShapes(source.shape, reduceAxes);
    const indices = util.makeZerosTypedArray(util.sizeFromShape(outShape), 'int32');
    const windowSize = util.sizeFromShape(reduceShape);
    const data = backend.data.get(source.dataId).values;

    for (let outIdx = 0; outIdx < indices.length; ++outIdx) {
        const start = outIdx * windowSize;
        let best = data[start];
        let bestAt = 0;
        // Position 0 seeds `best`, so the scan can start at 1.
        for (let k = 1; k < windowSize; ++k) {
            const candidate = data[start + k];
            if (candidate < best) {
                best = candidate;
                bestAt = k;
            }
        }
        indices[outIdx] = bestAt;
    }

    if (wasTransposed) {
        backend.disposeIntermediateTensorInfo(source);
    }
    return backend.makeTensorInfo(outShape, 'int32', indices);
}

/** Kernel registration record for `ArgMin` under the 'cpu' backend name. */
const argMinConfig = { kernelName: ArgMin, backendName: 'cpu', kernelFunc: argMin };
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the License);
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an AS IS BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * `Asin` kernel function, produced by `unaryKernelFunc` from a callback that
 * maps a single value through `Math.asin`.
 */
const asin = unaryKernelFunc(Asin, (value) => Math.asin(value));

/** Kernel registration record for `Asin` under the 'cpu' backend name. */
const asinConfig = { kernelName: Asin, backendName: 'cpu', kernelFunc: asin };
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the License);
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an AS IS BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * `Asinh` kernel function, produced by `unaryKernelFunc` from a callback that
 * maps a single value through `Math.asinh`.
 */
const asinh = unaryKernelFunc(Asinh, (value) => Math.asinh(value));

/** Kernel registration record for `Asinh` under the 'cpu' backend name. */
const asinhConfig = { kernelName: Asinh, backendName: 'cpu', kernelFunc: asinh };
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the License);
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an AS IS BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * `Atan` kernel function, produced by `unaryKernelFunc` from a callback that
 * maps a single value through `Math.atan`.
 */
const atan = unaryKernelFunc(Atan, (value) => Math.atan(value));

/** Kernel registration record for `Atan` under the 'cpu' backend name. */
const atanConfig = { kernelName: Atan, backendName: 'cpu', kernelFunc: atan };
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the License);
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an AS IS BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * Binary implementation built by `createSimpleBinaryKernelImpl` from a
 * callback that evaluates `Math.atan2(a, b)` for one pair of values.
 */
const atan2Impl = createSimpleBinaryKernelImpl((numerator, denominator) => Math.atan2(numerator, denominator));

/** `Atan2` kernel function, produced by `binaryKernelFunc` from `atan2Impl`. */
const atan2 = binaryKernelFunc(Atan2, atan2Impl);

/** Kernel registration record for `Atan2` under the 'cpu' backend name. */
const atan2Config = { kernelName: Atan2, backendName: 'cpu', kernelFunc: atan2 };
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the License);
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an AS IS BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * `Atanh` kernel function, produced by `unaryKernelFunc` from a callback that
 * maps a single value through `Math.atanh`.
 */
const atanh = unaryKernelFunc(Atanh, (value) => Math.atanh(value));

/** Kernel registration record for `Atanh` under the 'cpu' backend name. */
const atanhConfig = { kernelName: Atanh, backendName: 'cpu', kernelFunc: atanh };
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * 2D max / average pooling over a flat array of input values.
 *
 * For each batch, channel and output position, the function walks the
 * (dilated) pooling window, skipping taps that fall outside the input, and
 * writes either the largest value seen ('max') or the mean of the visited
 * values ('avg') into the output buffer.
 *
 * The first in-bounds tap of a window is found by stepping from the window
 * corner in dilation-sized increments, so a window that starts in the padding
 * stays on its dilated grid (the same approach `maxPoolPositions` and `pool3d`
 * use). Clamping the corner to 0 instead would shift the taps whenever the
 * dilation is greater than 1.
 *
 * @param xValues Flat input values, addressed through `strides`.
 * @param xShape Input shape (not read by this function).
 * @param dtype Dtype passed to `buffer` for the output.
 * @param strides Input strides: `strides[0]` per batch, `strides[1]` per row,
 *     `strides[2]` per column; the channel index is added directly.
 * @param convInfo Pooling geometry (strides, dilations, effective filter
 *     sizes, `padInfo.top` / `padInfo.left`, input and output extents,
 *     `outShape`).
 * @param poolType 'max' or 'avg'. For any other value no taps are
 *     accumulated and every output element is `Number.POSITIVE_INFINITY`.
 * @returns The buffer created by `buffer(convInfo.outShape, dtype)`, filled
 *     with the pooled values.
 */
function pool(xValues, xShape, dtype, strides, convInfo, poolType) {
    const strideHeight = convInfo.strideHeight;
    const strideWidth = convInfo.strideWidth;
    const dilationHeight = convInfo.dilationHeight;
    const dilationWidth = convInfo.dilationWidth;
    const effectiveFilterHeight = convInfo.effectiveFilterHeight;
    const effectiveFilterWidth = convInfo.effectiveFilterWidth;
    const padTop = convInfo.padInfo.top;
    const padLeft = convInfo.padInfo.left;
    const initialValue = (poolType === 'max' ? Number.NEGATIVE_INFINITY :
        Number.POSITIVE_INFINITY);
    const output = buffer(convInfo.outShape, dtype);
    const outputVals = output.values;
    const outputBatchStrides = convInfo.outShape[1] * convInfo.outShape[2] * convInfo.outShape[3];
    const outputRowStrides = convInfo.outShape[2] * convInfo.outShape[3];
    const outputColStrides = convInfo.outShape[3];
    for (let b = 0; b < convInfo.batchSize; ++b) {
        const outputBatchOffset = b * outputBatchStrides;
        const inputBatchOffset = b * strides[0];
        for (let d = 0; d < convInfo.inChannels; ++d) {
            for (let yR = 0; yR < convInfo.outHeight; ++yR) {
                const xRCorner = yR * strideHeight - padTop;
                // Step by the dilation so the first in-bounds row stays on
                // the window's dilated grid.
                let xRMin = xRCorner;
                while (xRMin < 0) {
                    xRMin += dilationHeight;
                }
                const xRMax = Math.min(convInfo.inHeight, effectiveFilterHeight + xRCorner);
                const outputRowOffset = outputBatchOffset + yR * outputRowStrides;
                for (let yC = 0; yC < convInfo.outWidth; ++yC) {
                    const xCCorner = yC * strideWidth - padLeft;
                    // Same alignment rule for columns.
                    let xCMin = xCCorner;
                    while (xCMin < 0) {
                        xCMin += dilationWidth;
                    }
                    const xCMax = Math.min(convInfo.inWidth, effectiveFilterWidth + xCCorner);
                    let minMaxValue = initialValue;
                    let avgValue = 0;
                    let count = 0;
                    for (let xR = xRMin; xR < xRMax; xR += dilationHeight) {
                        const xROffset = inputBatchOffset + xR * strides[1];
                        for (let xC = xCMin; xC < xCMax; xC += dilationWidth) {
                            const xCOffset = xROffset + xC * strides[2];
                            const pixel = xValues[xCOffset + d];
                            if ((poolType === 'max' && pixel > minMaxValue)) {
                                minMaxValue = pixel;
                            }
                            else if (poolType === 'avg') {
                                avgValue += pixel;
                                count++;
                            }
                        }
                        if (isNaN(minMaxValue)) {
                            break;
                        }
                    }
                    const outputOffset = outputRowOffset + yC * outputColStrides + d;
                    outputVals[outputOffset] =
                        poolType === 'avg' ? avgValue / count : minMaxValue;
                }
            }
        }
    }
    return output;
}
|
/**
 * Records, for every 2D pooling window, where its maximum value was found.
 *
 * @param xValues Flat input values; wrapped with `buffer(xShape, dtype, xValues)`
 *     and read through `get(batch, row, col, channel)`.
 * @param xShape Shape used to wrap `xValues`.
 * @param dtype Dtype used to wrap `xValues`.
 * @param convInfo Pooling geometry (strides, dilations, effective filter
 *     sizes, `padInfo.top` / `padInfo.left`, input and output extents,
 *     `outShape`).
 * @param flattenPositions When false, the stored position is the offset inside
 *     the window (`wR * effectiveFilterWidth + wC`). When true, it is a
 *     flattened index into the input.
 * @param includeBatchInIndex Only read when `flattenPositions` is true:
 *     whether the flattened index also accounts for the batch index.
 * @returns An 'int32' buffer of shape `convInfo.outShape`; a window in which
 *     no value exceeds `-Infinity` keeps the position -1.
 */
function maxPoolPositions(xValues, xShape, dtype, convInfo, flattenPositions = false, includeBatchInIndex = false) {
    const positions = buffer(convInfo.outShape, 'int32');
    const {
        strideHeight, strideWidth,
        dilationHeight, dilationWidth,
        effectiveFilterHeight, effectiveFilterWidth,
        batchSize, inChannels,
        inHeight, inWidth,
        outHeight, outWidth
    } = convInfo;
    const padTop = convInfo.padInfo.top;
    const padLeft = convInfo.padInfo.left;
    const input = buffer(xShape, dtype, xValues);

    // First non-negative coordinate reached from `corner` in `step` increments.
    const firstInBounds = (corner, step) => {
        let coord = corner;
        while (coord < 0) {
            coord += step;
        }
        return coord;
    };

    for (let b = 0; b < batchSize; ++b) {
        for (let d = 0; d < inChannels; ++d) {
            for (let yR = 0; yR < outHeight; ++yR) {
                const rowCorner = yR * strideHeight - padTop;
                const rowStart = firstInBounds(rowCorner, dilationHeight);
                const rowEnd = Math.min(inHeight, effectiveFilterHeight + rowCorner);
                for (let yC = 0; yC < outWidth; ++yC) {
                    const colCorner = yC * strideWidth - padLeft;
                    const colStart = firstInBounds(colCorner, dilationWidth);
                    const colEnd = Math.min(inWidth, effectiveFilterWidth + colCorner);
                    let best = Number.NEGATIVE_INFINITY;
                    let bestPosition = -1;
                    for (let xR = rowStart; xR < rowEnd; xR += dilationHeight) {
                        const wR = xR - rowCorner;
                        for (let xC = colStart; xC < colEnd; xC += dilationWidth) {
                            const wC = xC - colCorner;
                            const pixel = input.get(b, xR, xC, d);
                            if (!(pixel > best)) {
                                continue;
                            }
                            best = pixel;
                            if (!flattenPositions) {
                                bestPosition = wR * effectiveFilterWidth + wC;
                            }
                            else if (includeBatchInIndex) {
                                bestPosition = ((b * inHeight + xR) * inWidth + xC) * inChannels + d;
                            }
                            else {
                                bestPosition = (xR * inWidth + xC) * inChannels + d;
                            }
                        }
                    }
                    positions.set(bestPosition, b, yR, yC, d);
                }
            }
        }
    }
    return positions;
}
|
/**
 * 3D max / average pooling over a flat array of input values.
 *
 * For each batch, channel and output position, the function walks the
 * (dilated) window, starting each dimension at the first non-negative
 * coordinate reached from the window corner in dilation-sized steps, and
 * writes either the largest value seen ('max') or the mean of the visited
 * values ('avg'; the divisor is at least 1).
 *
 * @param xValues Flat input values, addressed through `strides`.
 * @param xShape Input shape (not read by this function).
 * @param dtype Dtype passed to `buffer` for the output.
 * @param strides Input strides: `strides[0]` per batch, `strides[1]` per
 *     depth, `strides[2]` per row, `strides[3]` per column; the channel index
 *     is added directly.
 * @param convInfo Pooling geometry (strides, dilations, effective filter
 *     sizes, `padInfo.front` / `top` / `left`, input and output extents,
 *     `outShape`).
 * @param poolType 'max' or 'avg'.
 * @returns The buffer created by `buffer(convInfo.outShape, dtype)`, filled
 *     with the pooled values.
 */
function pool3d(xValues, xShape, dtype, strides, convInfo, poolType) {
    const {
        strideDepth, strideHeight, strideWidth,
        dilationDepth, dilationHeight, dilationWidth,
        effectiveFilterDepth, effectiveFilterHeight, effectiveFilterWidth
    } = convInfo;
    const padFront = convInfo.padInfo.front;
    const padTop = convInfo.padInfo.top;
    const padLeft = convInfo.padInfo.left;
    const isMax = poolType === 'max';
    const isAvg = poolType === 'avg';
    const seed = isMax ? Number.NEGATIVE_INFINITY : Number.POSITIVE_INFINITY;
    const output = buffer(convInfo.outShape, dtype);
    const outVals = output.values;
    const outColStride = convInfo.outShape[4];
    const outRowStride = convInfo.outShape[3] * convInfo.outShape[4];
    const outDepthStride = convInfo.outShape[2] * convInfo.outShape[3] * convInfo.outShape[4];
    const outBatchStride = convInfo.outShape[1] * convInfo.outShape[2] *
        convInfo.outShape[3] * convInfo.outShape[4];

    // First non-negative coordinate reached from `corner` in `step` increments.
    const firstInBounds = (corner, step) => {
        let coord = corner;
        while (coord < 0) {
            coord += step;
        }
        return coord;
    };

    for (let batch = 0; batch < convInfo.batchSize; ++batch) {
        const outBatchBase = batch * outBatchStride;
        const inBatchBase = batch * strides[0];
        for (let channel = 0; channel < convInfo.inChannels; ++channel) {
            for (let yDepth = 0; yDepth < convInfo.outDepth; ++yDepth) {
                const depthCorner = yDepth * strideDepth - padFront;
                const depthStart = firstInBounds(depthCorner, dilationDepth);
                const depthEnd = Math.min(convInfo.inDepth, effectiveFilterDepth + depthCorner);
                const outDepthBase = outBatchBase + yDepth * outDepthStride;
                for (let yRow = 0; yRow < convInfo.outHeight; ++yRow) {
                    const rowCorner = yRow * strideHeight - padTop;
                    const rowStart = firstInBounds(rowCorner, dilationHeight);
                    const rowEnd = Math.min(convInfo.inHeight, effectiveFilterHeight + rowCorner);
                    const outRowBase = outDepthBase + yRow * outRowStride;
                    for (let yCol = 0; yCol < convInfo.outWidth; ++yCol) {
                        const colCorner = yCol * strideWidth - padLeft;
                        const colStart = firstInBounds(colCorner, dilationWidth);
                        const colEnd = Math.min(convInfo.inWidth, effectiveFilterWidth + colCorner);
                        const outColBase = outRowBase + yCol * outColStride;
                        let extreme = seed;
                        let total = 0;
                        let visited = 0;
                        // A NaN extreme ends the scan of the whole window at once.
                        scan: for (let xDepth = depthStart; xDepth < depthEnd; xDepth += dilationDepth) {
                            const inDepthBase = inBatchBase + xDepth * strides[1];
                            for (let xRow = rowStart; xRow < rowEnd; xRow += dilationHeight) {
                                const inRowBase = inDepthBase + xRow * strides[2];
                                for (let xCol = colStart; xCol < colEnd; xCol += dilationWidth) {
                                    const pixel = xValues[inRowBase + xCol * strides[3] + channel];
                                    if (isMax) {
                                        if (pixel > extreme) {
                                            extreme = pixel;
                                        }
                                    }
                                    else if (isAvg) {
                                        total += pixel;
                                        visited++;
                                    }
                                    if (isNaN(extreme)) {
                                        break scan;
                                    }
                                }
                            }
                        }
                        outVals[outColBase + channel] = isAvg ? total / Math.max(visited, 1) : extreme;
                    }
                }
            }
        }
    }
    return output;
}
|
/**
 * Records, for every 3D pooling window, the position inside the window at
 * which the maximum value was found.
 *
 * The position is the row-major index of the tap within the effective
 * (depth, height, width) window:
 * `wDepth * effectiveFilterHeight * effectiveFilterWidth +
 *  wRow * effectiveFilterWidth + wCol`.
 * The row offset is scaled by the window WIDTH. Scaling it by the height
 * makes different taps share one index whenever height and width differ.
 *
 * Each dimension starts at the first non-negative coordinate reached from the
 * window corner in dilation-sized steps. Because the comparison is `>=`, the
 * last tap visited wins among equal values.
 *
 * @param xBuf Input buffer read through
 *     `get(batch, depth, row, col, channel)`.
 * @param convInfo Pooling geometry (strides, dilations, effective filter
 *     sizes, `padInfo.front` / `top` / `left`, input and output extents,
 *     `outShape`).
 * @returns An 'int32' buffer of shape `convInfo.outShape`; a window with no
 *     value `>= -Infinity` keeps the position -1.
 */
function maxPool3dPositions(xBuf, convInfo) {
    const maxPositions = buffer(convInfo.outShape, 'int32');
    const strideDepth = convInfo.strideDepth;
    const strideHeight = convInfo.strideHeight;
    const strideWidth = convInfo.strideWidth;
    const dilationDepth = convInfo.dilationDepth;
    const dilationHeight = convInfo.dilationHeight;
    const dilationWidth = convInfo.dilationWidth;
    const effectiveFilterDepth = convInfo.effectiveFilterDepth;
    const effectiveFilterHeight = convInfo.effectiveFilterHeight;
    const effectiveFilterWidth = convInfo.effectiveFilterWidth;
    const padFront = convInfo.padInfo.front;
    const padTop = convInfo.padInfo.top;
    const padLeft = convInfo.padInfo.left;
    for (let batch = 0; batch < convInfo.batchSize; ++batch) {
        for (let channel = 0; channel < convInfo.inChannels; ++channel) {
            for (let yDepth = 0; yDepth < convInfo.outDepth; ++yDepth) {
                const xDepthCorner = yDepth * strideDepth - padFront;
                let xDepthMin = xDepthCorner;
                while (xDepthMin < 0) {
                    xDepthMin += dilationDepth;
                }
                const xDepthMax = Math.min(convInfo.inDepth, effectiveFilterDepth + xDepthCorner);
                for (let yRow = 0; yRow < convInfo.outHeight; ++yRow) {
                    const xRowCorner = yRow * strideHeight - padTop;
                    let xRowMin = xRowCorner;
                    while (xRowMin < 0) {
                        xRowMin += dilationHeight;
                    }
                    const xRowMax = Math.min(convInfo.inHeight, effectiveFilterHeight + xRowCorner);
                    for (let yCol = 0; yCol < convInfo.outWidth; ++yCol) {
                        const xColCorner = yCol * strideWidth - padLeft;
                        let xColMin = xColCorner;
                        while (xColMin < 0) {
                            xColMin += dilationWidth;
                        }
                        const xColMax = Math.min(convInfo.inWidth, effectiveFilterWidth + xColCorner);
                        let maxValue = Number.NEGATIVE_INFINITY;
                        let maxPosition = -1;
                        for (let xDepth = xDepthMin; xDepth < xDepthMax; xDepth += dilationDepth) {
                            const wDepth = xDepth - xDepthCorner;
                            for (let xRow = xRowMin; xRow < xRowMax; xRow += dilationHeight) {
                                const wRow = xRow - xRowCorner;
                                for (let xCol = xColMin; xCol < xColMax; xCol += dilationWidth) {
                                    const wCol = xCol - xColCorner;
                                    const pixel = xBuf.get(batch, xDepth, xRow, xCol, channel);
                                    if (pixel >= maxValue) {
                                        maxValue = pixel;
                                        // Row-major index: one row spans
                                        // `effectiveFilterWidth` columns.
                                        maxPosition =
                                            wDepth * effectiveFilterHeight * effectiveFilterWidth +
                                                wRow * effectiveFilterWidth + wCol;
                                    }
                                }
                            }
                        }
                        maxPositions.set(maxPosition, batch, yDepth, yRow, yCol, channel);
                    }
                }
            }
        }
    }
    return maxPositions;
}
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * `AvgPool` kernel. Dilations are fixed to 1.
 *
 * When the filter is 1x1 and the input and output shapes are equal, the input
 * is passed through `identity`; otherwise the result is computed by `pool`
 * with pool type 'avg'.
 *
 * @param args Object with `inputs.x`, `backend`, and `attrs.filterSize`,
 *     `attrs.strides`, `attrs.pad`, `attrs.dimRoundingMode`.
 * @returns Tensor info of shape `convInfo.outShape` with the input's dtype
 *     (or whatever `identity` returns on the pass-through path).
 */
function avgPool(args) {
    const { inputs, backend, attrs } = args;
    const { x } = inputs;
    assertNotComplex(x, 'avgPool');
    const { filterSize, strides, pad, dimRoundingMode } = attrs;
    const dilations = 1;
    util.assert(backend_util.eitherStridesOrDilationsAreOne(strides, dilations), () => 'Error in avgPool: Either strides or dilations must be 1. ' +
        `Got strides ${strides} and dilations '${dilations}'`);
    const convInfo = backend_util.computePool2DInfo(x.shape, filterSize, strides, dilations, pad, dimRoundingMode);

    const isPassThrough = convInfo.filterWidth === 1 && convInfo.filterHeight === 1 &&
        util.arraysEqual(convInfo.inShape, convInfo.outShape);
    if (isPassThrough) {
        return identity({ inputs: { x }, backend });
    }

    const xValues = backend.data.get(x.dataId).values;
    const inputStrides = util.computeStrides(x.shape);
    const pooled = pool(xValues, x.shape, x.dtype, inputStrides, convInfo, 'avg');
    return backend.makeTensorInfo(convInfo.outShape, x.dtype, pooled.values);
}

/** Kernel registration record for `AvgPool` under the 'cpu' backend name. */
const avgPoolConfig = { kernelName: AvgPool, backendName: 'cpu', kernelFunc: avgPool };
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * `AvgPool3D` kernel: computes the pooling geometry with dilations fixed to 1
 * and delegates the arithmetic to `pool3d` with pool type 'avg'.
 *
 * @param args Object with `inputs.x`, `backend`, and `attrs.filterSize`,
 *     `attrs.strides`, `attrs.pad`, `attrs.dimRoundingMode`,
 *     `attrs.dataFormat`.
 * @returns Tensor info declared as 'float32', using the shape and values of
 *     the buffer returned by `pool3d`.
 */
function avgPool3D(args) {
    const { inputs, backend, attrs } = args;
    const { x } = inputs;
    const { filterSize, strides, pad, dimRoundingMode, dataFormat } = attrs;
    assertNotComplex(x, 'avgPool3d');
    const dilations = 1;
    const convInfo = backend_util.computePool3DInfo(x.shape, filterSize, strides, dilations, pad, dimRoundingMode, dataFormat);
    const xValues = backend.data.get(x.dataId).values;
    const inputStrides = util.computeStrides(x.shape);
    const pooled = pool3d(xValues, x.shape, x.dtype, inputStrides, convInfo, 'avg');
    return backend.makeTensorInfo(pooled.shape, 'float32', pooled.values);
}

/** Kernel registration record for `AvgPool3D` under the 'cpu' backend name. */
const avgPool3DConfig = { kernelName: AvgPool3D, backendName: 'cpu', kernelFunc: avgPool3D };
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * `AvgPool3DGrad` kernel: computes the gradient with respect to the input of
 * a 3D average pool.
 *
 * For each input position it sums the `dy` entries of every output position
 * whose window contains it, then scales the sum by
 * `1 / (filterDepth * filterHeight * filterWidth)`. An output position counts
 * only when the division by the stride yields an in-range integer.
 *
 * @param args Object with `inputs.dy`, `inputs.input`, `backend`, and
 *     `attrs.filterSize`, `attrs.strides`, `attrs.pad`,
 *     `attrs.dimRoundingMode`.
 * @returns A 'float32' tensor info with the shape of `input`.
 */
function avgPool3DGrad(args) {
    const { inputs, backend, attrs } = args;
    const { dy, input } = inputs;
    const { filterSize, strides, pad, dimRoundingMode } = attrs;
    assertNotComplex([dy, input], 'avgPool3DGrad');
    const convInfo = backend_util.computePool3DInfo(input.shape, filterSize, strides, 1 /* dilations */, pad, dimRoundingMode);
    const {
        strideDepth, strideHeight, strideWidth,
        filterDepth, filterHeight, filterWidth,
        dilationDepth, dilationHeight, dilationWidth,
        effectiveFilterDepth, effectiveFilterHeight, effectiveFilterWidth
    } = convInfo;
    // Offsets from an input coordinate back to the first contributing window.
    const backFront = effectiveFilterDepth - 1 - convInfo.padInfo.front;
    const backLeft = effectiveFilterWidth - 1 - convInfo.padInfo.left;
    const backTop = effectiveFilterHeight - 1 - convInfo.padInfo.top;
    const dx = buffer(input.shape, 'float32');
    const scale = 1 / (filterDepth * filterHeight * filterWidth);
    const dyBuf = backend.bufferSync(dy);

    // True when `coord` is a whole number inside [0, limit).
    const isOutputIndex = (coord, limit) => coord >= 0 && coord < limit && Math.floor(coord) === coord;

    for (let batch = 0; batch < convInfo.batchSize; ++batch) {
        for (let channel = 0; channel < convInfo.inChannels; ++channel) {
            for (let dxDepth = 0; dxDepth < convInfo.inDepth; ++dxDepth) {
                const depthCorner = dxDepth - backFront;
                for (let dxRow = 0; dxRow < convInfo.inHeight; ++dxRow) {
                    const rowCorner = dxRow - backTop;
                    for (let dxCol = 0; dxCol < convInfo.inWidth; ++dxCol) {
                        const colCorner = dxCol - backLeft;
                        let sum = 0;
                        for (let wDepth = 0; wDepth < effectiveFilterDepth; wDepth += dilationDepth) {
                            const dyDepth = (depthCorner + wDepth) / strideDepth;
                            if (!isOutputIndex(dyDepth, convInfo.outDepth)) {
                                continue;
                            }
                            for (let wRow = 0; wRow < effectiveFilterHeight; wRow += dilationHeight) {
                                const dyRow = (rowCorner + wRow) / strideHeight;
                                if (!isOutputIndex(dyRow, convInfo.outHeight)) {
                                    continue;
                                }
                                for (let wCol = 0; wCol < effectiveFilterWidth; wCol += dilationWidth) {
                                    const dyCol = (colCorner + wCol) / strideWidth;
                                    if (!isOutputIndex(dyCol, convInfo.outWidth)) {
                                        continue;
                                    }
                                    sum += dyBuf.get(batch, dyDepth, dyRow, dyCol, channel);
                                }
                            }
                        }
                        dx.set(sum * scale, batch, dxDepth, dxRow, dxCol, channel);
                    }
                }
            }
        }
    }
    return backend.makeTensorInfo(dx.shape, dx.dtype, dx.values);
}

/** Kernel registration record for `AvgPool3DGrad` under the 'cpu' backend name. */
const avgPool3DGradConfig = { kernelName: AvgPool3DGrad, backendName: 'cpu', kernelFunc: avgPool3DGrad };
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * `AvgPoolGrad` kernel: computes the gradient with respect to the input of a
 * 2D average pool.
 *
 * For each input position it sums the `dy` entries of every output position
 * whose window contains it, then scales the sum by
 * `1 / (filterHeight * filterWidth)`. An output position counts only when the
 * division by the stride yields an in-range integer.
 *
 * @param args Object with `inputs.dy`, `inputs.input`, `backend`, and
 *     `attrs.filterSize`, `attrs.strides`, `attrs.pad`.
 * @returns A 'float32' tensor info with the shape of `input`.
 */
function avgPoolGrad(args) {
    const { inputs, backend, attrs } = args;
    const { dy, input } = inputs;
    assertNotComplex([dy, input], 'avgPoolGrad');
    const { filterSize, strides, pad } = attrs;
    const convInfo = backend_util.computePool2DInfo(input.shape, filterSize, strides, 1 /* dilations */, pad);
    const {
        strideHeight, strideWidth,
        filterHeight, filterWidth,
        dilationHeight, dilationWidth,
        effectiveFilterHeight, effectiveFilterWidth
    } = convInfo;
    // Offsets from an input coordinate back to the first contributing window.
    const backLeft = effectiveFilterWidth - 1 - convInfo.padInfo.left;
    const backTop = effectiveFilterHeight - 1 - convInfo.padInfo.top;
    const dx = buffer(input.shape, 'float32');
    const scale = 1 / (filterHeight * filterWidth);
    const dyValues = backend.data.get(dy.dataId).values;
    const dyBuf = buffer(dy.shape, 'float32', dyValues);

    // True when `coord` is a whole number inside [0, limit).
    const isOutputIndex = (coord, limit) => coord >= 0 && coord < limit && Math.floor(coord) === coord;

    for (let b = 0; b < convInfo.batchSize; ++b) {
        for (let d = 0; d < convInfo.inChannels; ++d) {
            for (let dxR = 0; dxR < convInfo.inHeight; ++dxR) {
                const rowCorner = dxR - backTop;
                for (let dxC = 0; dxC < convInfo.inWidth; ++dxC) {
                    const colCorner = dxC - backLeft;
                    let sum = 0;
                    for (let wR = 0; wR < effectiveFilterHeight; wR += dilationHeight) {
                        const dyR = (rowCorner + wR) / strideHeight;
                        if (!isOutputIndex(dyR, convInfo.outHeight)) {
                            continue;
                        }
                        for (let wC = 0; wC < effectiveFilterWidth; wC += dilationWidth) {
                            const dyC = (colCorner + wC) / strideWidth;
                            if (!isOutputIndex(dyC, convInfo.outWidth)) {
                                continue;
                            }
                            sum += dyBuf.get(b, dyR, dyC, d);
                        }
                    }
                    dx.set(sum * scale, b, dxR, dxC, d);
                }
            }
        }
    }
    return backend.makeTensorInfo(dx.shape, dx.dtype, dx.values);
}

/** Kernel registration record for `AvgPoolGrad` under the 'cpu' backend name. */
const avgPoolGradConfig = { kernelName: AvgPoolGrad, backendName: 'cpu', kernelFunc: avgPoolGrad };
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * CPU kernel function registered for FusedBatchNorm.
 *
 * For every element of `x` it computes
 *   offset + (x - mean) * scale / sqrt(variance + varianceEpsilon)
 * The mean, variance, scale and offset buffers are each read cyclically: a
 * cursor walks the buffer and wraps back to index 0 once it reaches the end.
 *
 * @param args Object with `inputs` ({x, scale, offset, mean, variance}; `scale`
 *     and `offset` may be absent, in which case 1 and 0 are used), `backend`,
 *     and `attrs` ({varianceEpsilon}, replaced by 0.001 when null/undefined).
 * @return The result of `backend.makeTensorInfo` with the shape and dtype of
 *     `x` and a freshly allocated Float32Array of values.
 */
function batchNorm(args) {
  const { inputs, backend, attrs } = args;
  const { x, scale, offset, mean, variance } = inputs;

  const meanRank = mean.shape.length;
  util.assert(meanRank === variance.shape.length, () => 'Batch normalization gradient requires mean and variance to have ' +
      'equal ranks.');
  util.assert(offset == null || meanRank === offset.shape.length, () => 'Batch normalization gradient requires mean and offset to have ' +
      'equal ranks.');
  util.assert(scale == null || meanRank === scale.shape.length, () => 'Batch normalization gradient requires mean and scale to have ' +
      'equal ranks.');
  assertNotComplex([x, mean, variance, scale, offset], 'batchNorm');

  const epsilon = attrs.varianceEpsilon == null ? 0.001 : attrs.varianceEpsilon;

  const readValues = (info) => backend.data.get(info.dataId).values;
  const xVals = readValues(x);
  const meanVals = readValues(mean);
  const varianceVals = readValues(variance);
  const scaleVals = scale ? readValues(scale) : new Float32Array([1]);
  const offsetVals = offset ? readValues(offset) : new Float32Array([0]);
  const outVals = new Float32Array(xVals.length);

  // Steps a cyclic cursor forward, wrapping to 0 once the end is reached.
  const advance = (cursor, length) => (cursor + 1 >= length ? 0 : cursor + 1);

  let offsetIdx = 0;
  let meanIdx = 0;
  let scaleIdx = 0;
  let varianceIdx = 0;
  for (let i = 0; i < xVals.length; ++i) {
    outVals[i] = offsetVals[offsetIdx] +
        (xVals[i] - meanVals[meanIdx]) * scaleVals[scaleIdx] /
            Math.sqrt(varianceVals[varianceIdx] + epsilon);
    offsetIdx = advance(offsetIdx, offsetVals.length);
    meanIdx = advance(meanIdx, meanVals.length);
    scaleIdx = advance(scaleIdx, scaleVals.length);
    varianceIdx = advance(varianceIdx, varianceVals.length);
  }
  return backend.makeTensorInfo(x.shape, x.dtype, outVals);
}
|
/** Kernel config pairing the FusedBatchNorm kernel name with `batchNorm` for the 'cpu' backend. */
const batchNormConfig = {
  kernelName: FusedBatchNorm,
  backendName: 'cpu',
  kernelFunc: batchNorm,
};
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * CPU kernel function registered for BatchToSpaceND.
 *
 * Implemented as a chain of existing kernels: reshape -> transpose -> reshape
 * -> slice. The target shapes, permutation and slice window all come from
 * `backend_util` helpers fed with `x.shape`, `blockShape` and `crops`.
 * The three intermediate tensors are disposed once the slice has been taken.
 *
 * @param args Object with `inputs` ({x}), `backend`, and `attrs`
 *     ({blockShape, crops}).
 * @return The tensor info produced by the final `slice` call.
 */
function batchToSpaceND(args) {
  const { inputs, backend, attrs } = args;
  const { x } = inputs;
  const { blockShape, crops } = attrs;
  assertNotComplex([x], 'batchToSpaceND');

  const blockRank = blockShape.length;
  const blockVolume = blockShape.reduce((acc, dim) => acc * dim);

  const reshapedShape = backend_util.getReshaped(x.shape, blockShape, blockVolume);
  const perm = backend_util.getPermuted(reshapedShape.length, blockRank);
  const mergedShape = backend_util.getReshapedPermuted(x.shape, blockShape, blockVolume);
  const begin = backend_util.getSliceBeginCoords(crops, blockRank);
  const size = backend_util.getSliceSize(mergedShape, crops, blockRank);

  // Intermediates are remembered in creation order so they can be released
  // after the final result exists.
  const temporaries = [];
  const track = (info) => {
    temporaries.push(info);
    return info;
  };

  const reshaped = track(reshape({ inputs: { x }, backend, attrs: { shape: reshapedShape } }));
  const transposed = track(transpose({ inputs: { x: reshaped }, backend, attrs: { perm } }));
  const merged = track(reshape({ inputs: { x: transposed }, backend, attrs: { shape: mergedShape } }));
  const result = slice({
    inputs: { x: merged },
    backend,
    attrs: { begin, size },
  });

  temporaries.forEach((info) => backend.disposeIntermediateTensorInfo(info));
  return result;
}
|
/** Kernel config pairing the BatchToSpaceND kernel name with `batchToSpaceND` for the 'cpu' backend. */
const batchToSpaceNDConfig = {
  kernelName: BatchToSpaceND,
  backendName: 'cpu',
  kernelFunc: batchToSpaceND,
};
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * CPU kernel function registered for Bincount.
 *
 * Fetches the raw values of `x` and `weights` from the backend and hands them
 * to `bincountImpl` together with the weights' dtype/shape and `size`.
 *
 * @param args Object with `inputs` ({x, weights}), `backend`, and `attrs`
 *     ({size}).
 * @return Tensor info of shape `[size]` carrying the dtype of `weights`.
 */
function bincount(args) {
  const { inputs, backend, attrs } = args;
  const { x, weights } = inputs;
  const { size } = attrs;

  const valuesOf = (info) => backend.data.get(info.dataId).values;
  const counts = bincountImpl(valuesOf(x), valuesOf(weights), weights.dtype, weights.shape, size);

  return backend.makeTensorInfo([size], weights.dtype, counts);
}
|
/** Kernel config pairing the Bincount kernel name with `bincount` for the 'cpu' backend. */
const bincountConfig = {
  kernelName: Bincount,
  backendName: 'cpu',
  kernelFunc: bincount,
};
|
|
/**
|
* @license
|
* Copyright 2021 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * CPU kernel function registered for BroadcastArgs.
 *
 * Treats the values of `s0` and `s1` as two shapes, asks
 * `backend_util.assertAndGetBroadcastShape` for their broadcast shape, and
 * returns that shape as a 1-D int32 tensor.
 *
 * @param args Object with `inputs` ({s0, s1}) and `backend`.
 * @return Tensor info of shape `[broadcastShape.length]` and dtype 'int32'.
 */
function broadcastArgs(args) {
  const { inputs, backend } = args;
  const { s0, s1 } = inputs;

  const asShape = (info) => Array.from(backend.data.get(info.dataId).values);
  const broadcastShape = backend_util.assertAndGetBroadcastShape(asShape(s0), asShape(s1));

  return backend.makeTensorInfo([broadcastShape.length], 'int32', Int32Array.from(broadcastShape));
}
|
/** Kernel config pairing the BroadcastArgs kernel name with `broadcastArgs` for the 'cpu' backend. */
const broadcastArgsConfig = {
  kernelName: BroadcastArgs,
  backendName: 'cpu',
  kernelFunc: broadcastArgs,
};
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the License);
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an AS IS BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * ClipByValue kernel function, built with `unaryKernelFunc` from a per-element
 * callback: values above `attrs.clipValueMax` become `clipValueMax`, values
 * below `attrs.clipValueMin` become `clipValueMin`, and everything else
 * (including values for which neither comparison holds) is passed through.
 */
const clipByValue = unaryKernelFunc(ClipByValue, (xi, attrs) => {
  const { clipValueMin, clipValueMax } = attrs;
  // The upper bound is checked first, so it wins if both comparisons hold.
  if (xi > clipValueMax) {
    return clipValueMax;
  }
  if (xi < clipValueMin) {
    return clipValueMin;
  }
  return xi;
});
|
/** Kernel config pairing the ClipByValue kernel name with `clipByValue` for the 'cpu' backend. */
const clipByValueConfig = {
  kernelName: ClipByValue,
  backendName: 'cpu',
  kernelFunc: clipByValue,
};
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the License);
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an AS IS BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * CPU kernel function registered for ComplexAbs.
 *
 * Looks up the real and imaginary component tensors stored alongside `x` in
 * the backend and writes `Math.hypot(real, imag)` for each element into a new
 * Float32Array sized from `x.shape`.
 *
 * @param args Object with `inputs` ({x}) and `backend`.
 * @return The result of `backend.makeOutput` for the magnitudes, using the
 *     shape of `x` and dtype 'float32'.
 */
const complexAbs = (args) => {
  const { x } = args.inputs;
  const cpuBackend = args.backend;
  const magnitudes = new Float32Array(util.sizeFromShape(x.shape));

  const { complexTensorInfos } = cpuBackend.data.get(x.dataId);
  const realInfo = complexTensorInfos.real;
  const imagInfo = complexTensorInfos.imag;
  const realVals = cpuBackend.data.get(realInfo.dataId).values;
  const imagVals = cpuBackend.data.get(imagInfo.dataId).values;

  for (let idx = 0; idx < realVals.length; idx++) {
    magnitudes[idx] = Math.hypot(realVals[idx], imagVals[idx]);
  }
  return cpuBackend.makeOutput(magnitudes, x.shape, 'float32');
};
|
/** Kernel config pairing the ComplexAbs kernel name with `complexAbs` for the 'cpu' backend. */
const complexAbsConfig = {
  kernelName: ComplexAbs,
  backendName: 'cpu',
  kernelFunc: complexAbs,
};
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * CPU kernel function registered for Imag.
 *
 * Finds the imaginary component tensor stored with the complex `input` and
 * wraps its values in a new tensor info.
 *
 * @param args Object with `inputs` ({input}) and `backend`.
 * @return Tensor info with the imaginary component's shape, dtype and values.
 */
function imag(args) {
  const { inputs, backend } = args;
  const { input } = inputs;

  const { imag: imagInfo } = backend.data.get(input.dataId).complexTensorInfos;
  const { values } = backend.data.get(imagInfo.dataId);

  // Disposing a complex tensor also disposes its component tensors, so the
  // imaginary values are exposed through a new tensor info; that keeps them
  // reachable even after the complex tensor is gone.
  return backend.makeTensorInfo(imagInfo.shape, imagInfo.dtype, values);
}
|
/** Kernel config pairing the Imag kernel name with `imag` for the 'cpu' backend. */
const imagConfig = {
  kernelName: Imag,
  backendName: 'cpu',
  kernelFunc: imag,
};
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * CPU kernel function registered for Concat.
 *
 * Steps, in order:
 *  1. Resolve the axis and validate the input shapes against it.
 *  2. If the output would hold zero elements, return an empty tensor info.
 *  3. Drop inputs with zero elements; if exactly one remains, return
 *     `identity` of it.
 *  4. For complex64 inputs, concatenate the real and imaginary parts
 *     separately (recursively) and recombine them with `complex`.
 *  5. Otherwise reshape every input to 2-D, concatenate along axis 1 with
 *     `concatImpl`, and label the values with the final N-D output shape.
 *
 * @param args Object with `inputs` (an array of tensor infos), `backend`, and
 *     `attrs` ({axis}).
 * @return Tensor info for the concatenated result.
 */
function concat(args) {
  const { inputs, backend, attrs } = args;
  const { axis } = attrs;

  const shapesOf = (tensors) => tensors.map((t) => t.shape);
  const release = (info) => backend.disposeIntermediateTensorInfo(info);

  const $axis = util.parseAxisParam(axis, inputs[0].shape)[0];
  backend_util.assertParamsConsistent(shapesOf(inputs), $axis);

  const fullOutShape = backend_util.computeOutShape(shapesOf(inputs), $axis);
  if (util.sizeFromShape(fullOutShape) === 0) {
    return backend.makeTensorInfo(fullOutShape, inputs[0].dtype, []);
  }

  // Only tensors that actually hold elements take part from here on.
  const nonEmpty = inputs.filter((t) => util.sizeFromShape(t.shape) > 0);
  if (nonEmpty.length === 1) {
    return identity({ inputs: { x: nonEmpty[0] }, backend });
  }

  if (nonEmpty[0].dtype === 'complex64') {
    const realParts = nonEmpty.map((t) => real({ inputs: { input: t }, backend }));
    const imagParts = nonEmpty.map((t) => imag({ inputs: { input: t }, backend }));
    const realJoined = concat({ inputs: realParts, backend, attrs: { axis: $axis } });
    const imagJoined = concat({ inputs: imagParts, backend, attrs: { axis: $axis } });
    const joined = complex({ inputs: { real: realJoined, imag: imagJoined }, backend });

    realParts.forEach((part) => release(part));
    imagParts.forEach((part) => release(part));
    release(realJoined);
    release(imagJoined);
    return joined;
  }

  // An N-D concat along any axis is equivalent to a 2-D concat along axis 1:
  // collapse the axes before the concat axis into rows and the concat axis
  // plus everything after it into columns, join the matrices column-wise,
  // then interpret the result with the real N-D output shape.
  const matrices = nonEmpty.map((t) => {
    const innerSize = util.sizeFromShape(t.shape.slice($axis));
    return reshape({ inputs: { x: t }, backend, attrs: { shape: [-1, innerSize] } });
  });
  const matrixData = matrices.map((m) => {
    return { vals: backend.data.get(m.dataId).values, shape: m.shape };
  });

  // Shape of the 2-D concatenation along axis=1.
  const outShape2D = backend_util.computeOutShape(shapesOf(matrices), 1 /* axis */);
  const simplyConcat = matrices[0].shape[0] === 1;
  const outVals = concatImpl(matrixData, outShape2D, inputs[0].dtype, simplyConcat);

  const finalOutShape = backend_util.computeOutShape(shapesOf(nonEmpty), $axis);
  const outInfo = backend.makeTensorInfo(finalOutShape, inputs[0].dtype, outVals);

  matrices.forEach((m) => release(m));
  return outInfo;
}
|
/** Kernel config pairing the Concat kernel name with `concat` for the 'cpu' backend. */
const concatConfig = {
  kernelName: Concat,
  backendName: 'cpu',
  kernelFunc: concat,
};
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * CPU kernel function registered for Conv2D.
 *
 * Geometry (sizes, strides, dilations, padding, data format) comes from
 * `backend_util.computeConv2DInfo`. For every output position the kernel
 * walks the filter window, skips taps that land outside the input, and
 * accumulates `x * w` into the output buffer for each (input channel, output
 * channel) pair. Both 'channelsLast' and the alternative layout are handled
 * by choosing the row/column/channel strides accordingly.
 *
 * @param args Object with `inputs` ({x, filter}), `backend`, and `attrs`
 *     ({strides, pad, dataFormat, dilations, dimRoundingMode}).
 * @return Tensor info with `convInfo.outShape` and the dtype of `x`.
 */
function conv2D(args) {
  const { inputs, backend, attrs } = args;
  const { x, filter } = inputs;
  const { strides, pad, dataFormat, dilations, dimRoundingMode } = attrs;
  assertNotComplex([x, filter], 'conv2d');

  const $dataFormat = backend_util.convertConv2DDataFormat(dataFormat);
  const convInfo = backend_util.computeConv2DInfo(x.shape, filter.shape, strides, dilations, pad, dimRoundingMode, false /* depthwise */, $dataFormat);
  const {
    batchSize, inHeight, inWidth, inChannels,
    outHeight, outWidth, outChannels,
    strideHeight, strideWidth,
    filterHeight, filterWidth,
    dilationHeight, dilationWidth,
  } = convInfo;
  const { left: padLeft, top: padTop } = convInfo.padInfo;
  const channelsLast = convInfo.dataFormat === 'channelsLast';

  const out = new TensorBuffer(convInfo.outShape, x.dtype);
  const inStrides = util.computeStrides(x.shape);
  const wStrides = util.computeStrides(filter.shape);

  // Maps a tensor's strides to [batch, row, column, channel] element strides
  // for the active data layout.
  const layoutStrides = (s) => channelsLast ? [s[0], s[1], s[2], 1] : [s[0], s[2], 1, s[1]];
  const [inBatchStride, inRowStride, inColStride, inChanStride] = layoutStrides(inStrides);
  const [outBatchStride, outRowStride, outColStride, outChanStride] = layoutStrides(out.strides);

  const xVals = backend.data.get(x.dataId).values;
  const wVals = backend.data.get(filter.dataId).values;
  const outVals = out.values;

  for (let b = 0; b < batchSize; ++b) {
    const inBatchOff = b * inBatchStride;
    const outBatchOff = b * outBatchStride;
    for (let yR = 0; yR < outHeight; ++yR) {
      const outRowOff = outBatchOff + yR * outRowStride;
      const rowOrigin = yR * strideHeight - padTop;
      for (let wR = 0; wR < filterHeight; ++wR) {
        const xR = rowOrigin + wR * dilationHeight;
        // Filter row falls into the padding: nothing to accumulate.
        if (xR < 0 || xR >= inHeight) {
          continue;
        }
        const wRowOff = wR * wStrides[0];
        const inRowOff = inBatchOff + xR * inRowStride;
        for (let yC = 0; yC < outWidth; ++yC) {
          const outColOff = outRowOff + yC * outColStride;
          const colOrigin = yC * strideWidth - padLeft;
          for (let wC = 0; wC < filterWidth; ++wC) {
            const xC = colOrigin + wC * dilationWidth;
            // Filter column falls into the padding.
            if (xC < 0 || xC >= inWidth) {
              continue;
            }
            const inColOff = inRowOff + xC * inColStride;
            // Walks the filter's input-channel axis; each step skips one run
            // of `outChannels` weights.
            let wOff = wRowOff + wC * wStrides[1];
            for (let d1 = 0; d1 < inChannels; ++d1) {
              const xVal = xVals[inColOff + d1 * inChanStride];
              for (let d2 = 0; d2 < outChannels; ++d2) {
                outVals[outColOff + d2 * outChanStride] +=
                    xVal * wVals[wOff + d2];
              }
              wOff += outChannels;
            }
          }
        }
      }
    }
  }
  return backend.makeTensorInfo(out.shape, out.dtype, outVals);
}
|
/** Kernel config pairing the Conv2D kernel name with `conv2D` for the 'cpu' backend. */
const conv2DConfig = {
  kernelName: Conv2D,
  backendName: 'cpu',
  kernelFunc: conv2D,
};
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * CPU kernel function registered for Conv2DBackpropFilter.
 *
 * For each filter entry (wR, wC, d1, d2) it sums `x * dy` over the batch and
 * over every output position whose matching input position lies inside the
 * input, and stores the sum in a float32 buffer shaped like the filter.
 * Dilation is fixed at 1 when computing the convolution info.
 *
 * @param args Object with `inputs` ({x, dy}), `backend`, and `attrs`
 *     ({strides, pad, dataFormat, dimRoundingMode, filterShape}).
 * @return Tensor info with `convInfo.filterShape` and dtype 'float32'.
 */
function conv2DBackpropFilter(args) {
  const { inputs, backend, attrs } = args;
  const { x, dy } = inputs;
  const { strides, pad, dataFormat, dimRoundingMode, filterShape } = attrs;
  assertNotComplex([x, dy], 'conv2dBackpropFilter');

  const $dataFormat = backend_util.convertConv2DDataFormat(dataFormat);
  const convInfo = backend_util.computeConv2DInfo(x.shape, filterShape, strides, 1 /* dilations */, pad, dimRoundingMode, false /* depthwise */, $dataFormat);
  const {
    batchSize, inHeight, inWidth, inChannels,
    outHeight, outWidth, outChannels,
    strideHeight, strideWidth,
    filterHeight, filterWidth,
  } = convInfo;
  const channelsLast = convInfo.dataFormat === 'channelsLast';

  const filterGrad = new TensorBuffer(convInfo.filterShape, 'float32');
  const { left: padLeft, top: padTop } = convInfo.padInfo;

  const xData = backend.data.get(x.dataId).values;
  const dyData = backend.data.get(dy.dataId).values;
  const xBuf = new TensorBuffer(x.shape, x.dtype, xData);
  const dyBuf = new TensorBuffer(dy.shape, dy.dtype, dyData);

  for (let wR = 0; wR < filterHeight; ++wR) {
    // Output rows whose input row (wR + yR * strideHeight - padTop) is within
    // [0, inHeight). The upper bound may be fractional; the `<` comparison in
    // the loop handles that.
    const rowLo = Math.max(0, Math.ceil((padTop - wR) / strideHeight));
    const rowHi = Math.min(outHeight, (inHeight + padTop - wR) / strideHeight);
    for (let wC = 0; wC < filterWidth; ++wC) {
      // Same restriction for columns.
      const colLo = Math.max(0, Math.ceil((padLeft - wC) / strideWidth));
      const colHi = Math.min(outWidth, (inWidth + padLeft - wC) / strideWidth);
      for (let d1 = 0; d1 < inChannels; ++d1) {
        for (let d2 = 0; d2 < outChannels; ++d2) {
          let acc = 0;
          for (let b = 0; b < batchSize; ++b) {
            for (let yR = rowLo; yR < rowHi; ++yR) {
              const xR = wR + yR * strideHeight - padTop;
              for (let yC = colLo; yC < colHi; ++yC) {
                const xC = wC + yC * strideWidth - padLeft;
                // Index order depends on where the channel axis sits.
                acc += channelsLast ?
                    xBuf.get(b, xR, xC, d1) * dyBuf.get(b, yR, yC, d2) :
                    xBuf.get(b, d1, xR, xC) * dyBuf.get(b, d2, yR, yC);
              }
            }
          }
          filterGrad.set(acc, wR, wC, d1, d2);
        }
      }
    }
  }
  return backend.makeTensorInfo(filterGrad.shape, filterGrad.dtype, filterGrad.values);
}
|
/** Kernel config pairing the Conv2DBackpropFilter kernel name with `conv2DBackpropFilter` for the 'cpu' backend. */
const conv2DBackpropFilterConfig = {
  kernelName: Conv2DBackpropFilter,
  backendName: 'cpu',
  kernelFunc: conv2DBackpropFilter,
};
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * CPU kernel function registered for Conv2DBackpropInput.
 *
 * For every input position (b, xR, xC, d1) it sums `dy * filter` over the
 * output positions and output channels that touch that input position,
 * reading the filter with its height/width indices mirrored
 * (filterHeight - 1 - wR, filterWidth - 1 - wC). The result goes into a
 * float32 buffer shaped like the input. Dilation is fixed at 1 when computing
 * the convolution info.
 *
 * @param args Object with `inputs` ({dy, filter}), `backend`, and `attrs`
 *     ({inputShape, strides, pad, dataFormat, dimRoundingMode}).
 * @return Tensor info with `convInfo.inShape` and dtype 'float32'.
 */
function conv2DBackpropInput(args) {
  const { inputs, backend, attrs } = args;
  const { dy, filter } = inputs;
  const { inputShape, strides, pad, dataFormat, dimRoundingMode } = attrs;
  assertNotComplex([dy, filter], 'conv2dBackpropInput');

  const wStrides = util.computeStrides(filter.shape);
  const gradStrides = util.computeStrides(dy.shape);
  const requestedFormat = backend_util.convertConv2DDataFormat(dataFormat);
  const convInfo = backend_util.computeConv2DInfo(inputShape, filter.shape, strides, 1 /* dilations */, pad, dimRoundingMode, false, requestedFormat);

  const dx = new TensorBuffer(convInfo.inShape, 'float32');
  const dxValues = dx.values;
  const dyValues = backend.data.get(dy.dataId).values;
  const fltValues = backend.data.get(filter.dataId).values;
  const fltS0 = wStrides[0];
  const fltS1 = wStrides[1];
  const fltS2 = wStrides[2];

  const {
    batchSize, filterHeight, filterWidth,
    inChannels, inHeight, inWidth,
    outChannels, outHeight, outWidth,
    strideHeight, strideWidth,
  } = convInfo;
  // Padding as seen from the mirrored filter.
  const topPad = filterHeight - 1 - convInfo.padInfo.top;
  const leftPad = filterWidth - 1 - convInfo.padInfo.left;
  const channelsLast = convInfo.dataFormat === 'channelsLast';

  // Maps a tensor's strides to [batch, row, column, channel] element strides
  // for the active data layout.
  const layoutStrides = (s) => channelsLast ? [s[0], s[1], s[2], 1] : [s[0], s[2], 1, s[1]];
  const [dxBatchStride, dxRowStride, dxColStride, dxChanStride] = layoutStrides(dx.strides);
  const [dyBatchStride, dyRowStride, dyColStride, dyChanStride] = layoutStrides(gradStrides);

  for (let b = 0; b < batchSize; ++b) {
    for (let d1 = 0; d1 < inChannels; ++d1) {
      for (let xR = 0; xR < inHeight; ++xR) {
        const xRCorner = xR - topPad;
        // Range of dy rows contributing to this input row; the upper bound
        // may be fractional and is used with a strict `<`.
        const yRLo = Math.max(0, Math.ceil(xRCorner / strideHeight));
        const yRHi = Math.min(outHeight, (filterHeight + xRCorner) / strideHeight);
        for (let xC = 0; xC < inWidth; ++xC) {
          const xCCorner = xC - leftPad;
          const yCLo = Math.max(0, Math.ceil(xCCorner / strideWidth));
          const yCHi = Math.min(outWidth, (filterWidth + xCCorner) / strideWidth);

          let acc = 0;
          for (let yR = yRLo; yR < yRHi; ++yR) {
            const wR = yR * strideHeight - xRCorner;
            for (let yC = yCLo; yC < yCHi; ++yC) {
              const wC = yC * strideWidth - xCCorner;
              const dyBase = dyBatchStride * b + dyRowStride * yR + dyColStride * yC;
              const fltBase = fltS0 * (filterHeight - 1 - wR) +
                  fltS1 * (filterWidth - 1 - wC) + fltS2 * d1;
              for (let d2 = 0; d2 < outChannels; ++d2) {
                const pixel = dyValues[dyBase + dyChanStride * d2];
                const weight = fltValues[fltBase + d2];
                acc += pixel * weight;
              }
            }
          }

          const dxOffset = dxBatchStride * b + dxRowStride * xR +
              dxColStride * xC + dxChanStride * d1;
          dxValues[dxOffset] = acc;
        }
      }
    }
  }
  return backend.makeTensorInfo(dx.shape, dx.dtype, dx.values);
}
|
/** Kernel config pairing the Conv2DBackpropInput kernel name with `conv2DBackpropInput` for the 'cpu' backend. */
const conv2DBackpropInputConfig = {
  kernelName: Conv2DBackpropInput,
  backendName: 'cpu',
  kernelFunc: conv2DBackpropInput,
};
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * CPU kernel function registered for Conv3D.
 *
 * Geometry comes from `backend_util.computeConv3DInfo`. For every output
 * position (batch, depth, row, column) the kernel walks the filter window,
 * skips taps outside the input along any spatial axis, and accumulates
 * `x * w` into the output buffer for each (input channel, output channel)
 * pair. Channels are addressed as the innermost, contiguous axis of both the
 * input and the output.
 *
 * @param args Object with `inputs` ({x, filter}), `backend`, and `attrs`
 *     ({strides, pad, dilations}).
 * @return Tensor info with `convInfo.outShape` and the dtype of `x`.
 */
function conv3D(args) {
  const { inputs, backend, attrs } = args;
  const { x, filter } = inputs;
  const { strides, pad, dilations } = attrs;
  assertNotComplex([x, filter], 'conv3d');

  const convInfo = backend_util.computeConv3DInfo(x.shape, filter.shape, strides, dilations, pad);
  const {
    batchSize,
    inDepth, inHeight, inWidth, inChannels,
    outDepth, outHeight, outWidth, outChannels,
    strideDepth, strideHeight, strideWidth,
    filterDepth, filterHeight, filterWidth,
    dilationDepth, dilationHeight, dilationWidth,
  } = convInfo;
  const { front: padFront, left: padLeft, top: padTop } = convInfo.padInfo;

  const out = new TensorBuffer(convInfo.outShape, x.dtype);
  const xVals = backend.data.get(x.dataId).values;
  const wVals = backend.data.get(filter.dataId).values;
  const outVals = out.values;
  const inStrides = util.computeStrides(x.shape);
  const wStrides = util.computeStrides(filter.shape);
  const outStrides = out.strides;

  for (let b = 0; b < batchSize; ++b) {
    const inBatchOff = b * inStrides[0];
    const outBatchOff = b * outStrides[0];
    for (let yF = 0; yF < outDepth; ++yF) {
      const outFrameOff = outBatchOff + yF * outStrides[1];
      const frameOrigin = yF * strideDepth - padFront;
      for (let wF = 0; wF < filterDepth; ++wF) {
        const xF = frameOrigin + wF * dilationDepth;
        // Filter frame falls into the padding.
        if (xF < 0 || xF >= inDepth) {
          continue;
        }
        const wFrameOff = wF * wStrides[0];
        const inFrameOff = inBatchOff + xF * inStrides[1];
        for (let yR = 0; yR < outHeight; ++yR) {
          const outRowOff = outFrameOff + yR * outStrides[2];
          const rowOrigin = yR * strideHeight - padTop;
          for (let wR = 0; wR < filterHeight; ++wR) {
            const xR = rowOrigin + wR * dilationHeight;
            // Filter row falls into the padding.
            if (xR < 0 || xR >= inHeight) {
              continue;
            }
            const wRowOff = wFrameOff + wR * wStrides[1];
            const inRowOff = inFrameOff + xR * inStrides[2];
            for (let yC = 0; yC < outWidth; ++yC) {
              const outColOff = outRowOff + yC * outChannels;
              const colOrigin = yC * strideWidth - padLeft;
              for (let wC = 0; wC < filterWidth; ++wC) {
                const xC = colOrigin + wC * dilationWidth;
                // Filter column falls into the padding.
                if (xC < 0 || xC >= inWidth) {
                  continue;
                }
                const inColOff = inRowOff + xC * inChannels;
                // Walks the filter's input-channel axis; each step skips one
                // run of `outChannels` weights.
                let wOff = wRowOff + wC * wStrides[2];
                for (let d1 = 0; d1 < inChannels; ++d1) {
                  const xVal = xVals[inColOff + d1];
                  for (let d2 = 0; d2 < outChannels; ++d2) {
                    outVals[outColOff + d2] += xVal * wVals[wOff + d2];
                  }
                  wOff += outChannels;
                }
              }
            }
          }
        }
      }
    }
  }
  return backend.makeTensorInfo(out.shape, out.dtype, out.values);
}
|
/** Kernel config pairing the Conv3D kernel name with `conv3D` for the 'cpu' backend. */
const conv3DConfig = {
  kernelName: Conv3D,
  backendName: 'cpu',
  kernelFunc: conv3D,
};
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * CPU kernel function registered for Conv3DBackpropFilterV2.
 *
 * For each filter entry (wF, wR, wC, d1, d2) it sums `x * dy` over the batch
 * and over every output position whose matching input position lies inside
 * the input, and writes the sum into a float32 buffer shaped like the filter.
 * Dilation is fixed at 1 when computing the convolution info.
 *
 * @param args Object with `inputs` ({x, dy}), `backend`, and `attrs`
 *     ({strides, pad, filterShape}).
 * @return Tensor info with `convInfo.filterShape` and dtype 'float32'.
 */
function conv3DBackpropFilterV2(args) {
  const { inputs, backend, attrs } = args;
  const { x, dy } = inputs;
  const { strides, pad, filterShape } = attrs;
  assertNotComplex([x, dy], 'conv3dBackpropFilterV2');

  const inStrides = util.computeStrides(x.shape);
  const gradStrides = util.computeStrides(dy.shape);
  const convInfo = backend_util.computeConv3DInfo(x.shape, filterShape, strides, 1 /* dilations */, pad);
  const {
    batchSize,
    inDepth, inHeight, inWidth, inChannels,
    outDepth, outHeight, outWidth, outChannels,
    strideDepth, strideHeight, strideWidth,
    filterDepth, filterHeight, filterWidth,
  } = convInfo;

  const dw = new TensorBuffer(convInfo.filterShape, 'float32');
  const dwValues = dw.values;
  const [dwS0, dwS1, dwS2, dwS3] = dw.strides;
  const dyValues = backend.data.get(dy.dataId).values;
  const [dyS0, dyS1, dyS2, dyS3] = gradStrides;
  const xValues = backend.data.get(x.dataId).values;
  const [xS0, xS1, xS2, xS3] = inStrides;
  const { front: padFront, left: padLeft, top: padTop } = convInfo.padInfo;

  for (let wF = 0; wF < filterDepth; ++wF) {
    // Output frames whose input frame (wF + yF * strideDepth - padFront) is
    // inside [0, inDepth). Upper bounds may be fractional; loops use `<`.
    const frameLo = Math.max(0, Math.ceil((padFront - wF) / strideDepth));
    const frameHi = Math.min(outDepth, (inDepth + padFront - wF) / strideDepth);
    const dwFrameOff = wF * dwS0;
    for (let wR = 0; wR < filterHeight; ++wR) {
      const rowLo = Math.max(0, Math.ceil((padTop - wR) / strideHeight));
      const rowHi = Math.min(outHeight, (inHeight + padTop - wR) / strideHeight);
      const dwRowOff = wR * dwS1 + dwFrameOff;
      for (let wC = 0; wC < filterWidth; ++wC) {
        const colLo = Math.max(0, Math.ceil((padLeft - wC) / strideWidth));
        const colHi = Math.min(outWidth, (inWidth + padLeft - wC) / strideWidth);
        const dwColOff = wC * dwS2 + dwRowOff;
        for (let d1 = 0; d1 < inChannels; ++d1) {
          const dwChanOff = d1 * dwS3 + dwColOff;
          for (let d2 = 0; d2 < outChannels; ++d2) {
            let acc = 0;
            for (let b = 0; b < batchSize; ++b) {
              const xBatchOff = b * xS0;
              const dyBatchOff = b * dyS0;
              for (let yF = frameLo; yF < frameHi; ++yF) {
                const xF = wF + yF * strideDepth - padFront;
                const xFrameOff = xF * xS1 + xBatchOff;
                const dyFrameOff = yF * dyS1 + dyBatchOff;
                for (let yR = rowLo; yR < rowHi; ++yR) {
                  const xR = wR + yR * strideHeight - padTop;
                  const xRowOff = xR * xS2 + xFrameOff;
                  const dyRowOff = yR * dyS2 + dyFrameOff;
                  for (let yC = colLo; yC < colHi; ++yC) {
                    const xC = wC + yC * strideWidth - padLeft;
                    const xColOff = xC * xS3 + xRowOff;
                    const dyColOff = yC * dyS3 + dyRowOff;
                    acc += xValues[xColOff + d1] * dyValues[dyColOff + d2];
                  }
                }
              }
            }
            dwValues[dwChanOff + d2] = acc;
          }
        }
      }
    }
  }
  return backend.makeTensorInfo(dw.shape, dw.dtype, dw.values);
}
|
/** Kernel config pairing the Conv3DBackpropFilterV2 kernel name with `conv3DBackpropFilterV2` for the 'cpu' backend. */
const conv3DBackpropFilterV2Config = {
  kernelName: Conv3DBackpropFilterV2,
  backendName: 'cpu',
  kernelFunc: conv3DBackpropFilterV2,
};
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * CPU kernel function for `Conv3DBackpropInputV2`.
 *
 * Allocates a float32 buffer shaped like `convInfo.inShape`. For every
 * (batch, input channel, frame, row, column) position it stores the sum of
 * `dy * filter` over the output positions selected by the stride/padding
 * bounds below and over all output channels.
 *
 * @param args Kernel arguments: `inputs` ({dy, filter}), `backend`, and
 *     `attrs` ({pad, strides, inputShape}).
 * @return The result of `backend.makeTensorInfo` for the filled buffer.
 */
function conv3DBackpropInputV2(args) {
    const { inputs: { dy, filter }, backend, attrs: { pad, strides, inputShape } } = args;
    assertNotComplex([dy], 'conv3dBackpropInputV2');
    const [dyStrideB, dyStrideF, dyStrideR, dyStrideC] = util.computeStrides(dy.shape);
    const [fltStrideF, fltStrideR, fltStrideC, fltStrideIn] = util.computeStrides(filter.shape);
    const convInfo = backend_util.computeConv3DInfo(inputShape, filter.shape, strides, 1 /* dilations */, pad);
    const inputGrad = new TensorBuffer(convInfo.inShape, 'float32');
    const gradIn = inputGrad.values;
    const [dxStrideB, dxStrideF, dxStrideR, dxStrideC] = inputGrad.strides;
    const gradOut = backend.data.get(dy.dataId).values;
    const weights = backend.data.get(filter.dataId).values;
    const { batchSize, filterDepth, filterHeight, filterWidth, inChannels, inDepth, inHeight, inWidth, outChannels, outDepth, outHeight, outWidth, strideDepth, strideHeight, strideWidth } = convInfo;
    // Padding as seen from the gradient's side: filter extent minus one,
    // minus the forward padding.
    const frontPad = filterDepth - 1 - convInfo.padInfo.front;
    const topPad = filterHeight - 1 - convInfo.padInfo.top;
    const leftPad = filterWidth - 1 - convInfo.padInfo.left;
    for (let b = 0; b < batchSize; ++b) {
        for (let inCh = 0; inCh < inChannels; ++inCh) {
            // Frames of depth
            for (let xF = 0; xF < inDepth; ++xF) {
                const cornerF = xF - frontPad;
                const yFStart = Math.max(0, Math.ceil(cornerF / strideDepth));
                const yFEnd = Math.min(outDepth, (filterDepth + cornerF) / strideDepth);
                // Rows as per standard 2d matrix notation
                for (let xR = 0; xR < inHeight; ++xR) {
                    const cornerR = xR - topPad;
                    const yRStart = Math.max(0, Math.ceil(cornerR / strideHeight));
                    const yREnd = Math.min(outHeight, (filterHeight + cornerR) / strideHeight);
                    // Columns as per standard 2d matrix notation
                    for (let xC = 0; xC < inWidth; ++xC) {
                        const cornerC = xC - leftPad;
                        const yCStart = Math.max(0, Math.ceil(cornerC / strideWidth));
                        const yCEnd = Math.min(outWidth, (filterWidth + cornerC) / strideWidth);
                        let acc = 0;
                        for (let yF = yFStart; yF < yFEnd; ++yF) {
                            const wF = yF * strideDepth - cornerF;
                            for (let yR = yRStart; yR < yREnd; ++yR) {
                                const wR = yR * strideHeight - cornerR;
                                for (let yC = yCStart; yC < yCEnd; ++yC) {
                                    const wC = yC * strideWidth - cornerC;
                                    const dyBase = dyStrideB * b + dyStrideF * yF + dyStrideR * yR + dyStrideC * yC;
                                    const fltBase = fltStrideF * (filterDepth - 1 - wF) +
                                        fltStrideR * (filterHeight - 1 - wR) +
                                        fltStrideC * (filterWidth - 1 - wC) + fltStrideIn * inCh;
                                    for (let outCh = 0; outCh < outChannels; ++outCh) {
                                        acc += gradOut[dyBase + outCh] * weights[fltBase + outCh];
                                    }
                                }
                            }
                        }
                        gradIn[dxStrideB * b + dxStrideF * xF + dxStrideR * xR + dxStrideC * xC + inCh] = acc;
                    }
                }
            }
        }
    }
    return backend.makeTensorInfo(inputGrad.shape, inputGrad.dtype, inputGrad.values);
}
|
/**
 * Kernel config pairing the `Conv3DBackpropInputV2` kernel name with the
 * `conv3DBackpropInputV2` implementation for the 'cpu' backend.
 */
const conv3DBackpropInputV2Config = {
    kernelName: Conv3DBackpropInputV2,
    backendName: 'cpu',
    kernelFunc: conv3DBackpropInputV2,
};
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * CPU `Cos` kernel function, built by `unaryKernelFunc` from a callback that
 * returns `Math.cos` of its argument.
 */
const cos = unaryKernelFunc(Cos, (value) => {
    return Math.cos(value);
});
/** Kernel config pairing the `Cos` kernel name with `cos` for the 'cpu' backend. */
const cosConfig = {
    kernelName: Cos,
    backendName: 'cpu',
    kernelFunc: cos,
};
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * CPU `Cosh` kernel function, built by `unaryKernelFunc` from a callback that
 * returns `Math.cosh` of its argument.
 */
const cosh = unaryKernelFunc(Cosh, (value) => {
    return Math.cosh(value);
});
/** Kernel config pairing the `Cosh` kernel name with `cosh` for the 'cpu' backend. */
const coshConfig = {
    kernelName: Cosh,
    backendName: 'cpu',
    kernelFunc: cosh,
};
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * CPU kernel function for `CropAndResize`.
 *
 * For each box (four consecutive values `y1, x1, y2, x2` in `boxes`), samples
 * the image selected by `boxInd` onto a `cropHeight x cropWidth` grid, using
 * bilinear interpolation when `method === 'bilinear'` and rounding to the
 * closest pixel otherwise. Sample positions that fall outside the image are
 * filled with `extrapolationValue`.
 *
 * Boxes whose image index is outside `[0, batch)` are skipped and their
 * output stays at the buffer's initial value. The negative side of that check
 * prevents reads before the start of the image data.
 *
 * @param args Kernel arguments: `inputs` ({image, boxes, boxInd}), `backend`,
 *     and `attrs` ({cropSize, method, extrapolationValue}).
 * @return The result of `backend.makeTensorInfo` for a float32 buffer of shape
 *     `[numBoxes, cropHeight, cropWidth, numChannels]`.
 */
function cropAndResize(args) {
    const { inputs, backend, attrs } = args;
    const { image, boxes, boxInd } = inputs;
    const { cropSize, method, extrapolationValue } = attrs;
    const [batch, imageHeight, imageWidth, numChannels] = image.shape;
    const numBoxes = boxes.shape[0];
    const [cropHeight, cropWidth] = cropSize;
    const output = buffer([numBoxes, cropHeight, cropWidth, numChannels], 'float32');
    const outVals = output.values;
    const boxVals = backend.data.get(boxes.dataId).values;
    const boxIndVals = backend.data.get(boxInd.dataId).values;
    const imageVals = backend.data.get(image.dataId).values;
    const inStride = util.computeStrides(image.shape); // to calculate flat indexes into image
    const outStride = util.computeStrides(output.shape); // to calculate flat indexes into output
    const isBilinear = method === 'bilinear';
    // Writes the extrapolation value to every channel of one output pixel.
    const fillWithExtrapolation = (outOffset) => {
        for (let c = 0; c < numChannels; c++) {
            outVals[outOffset + c] = extrapolationValue;
        }
    };
    // Reference implementation
    // https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/kernels/crop_and_resize_op.cc
    for (let b = 0; b < numBoxes; b++) {
        const startInd = b * 4;
        const y1 = boxVals[startInd];
        const x1 = boxVals[startInd + 1];
        const y2 = boxVals[startInd + 2];
        const x2 = boxVals[startInd + 3];
        const bInd = boxIndVals[b];
        // Skip boxes that do not point at an image in the batch; a negative
        // index would otherwise read outside the image data.
        if (bInd < 0 || bInd >= batch) {
            continue;
        }
        const imageOffset = bInd * inStride[0];
        const heightScale = (cropHeight > 1) ? (y2 - y1) * (imageHeight - 1) / (cropHeight - 1) : 0;
        const widthScale = (cropWidth > 1) ? (x2 - x1) * (imageWidth - 1) / (cropWidth - 1) : 0;
        for (let y = 0; y < cropHeight; y++) {
            // With a single output row, sample the vertical center of the box.
            const yInd = (cropHeight > 1) ?
                y1 * (imageHeight - 1) + y * (heightScale) :
                0.5 * (y1 + y2) * (imageHeight - 1);
            const outRowOffset = b * outStride[0] + y * outStride[1];
            if (yInd < 0 || yInd > imageHeight - 1) {
                for (let x = 0; x < cropWidth; x++) {
                    fillWithExtrapolation(outRowOffset + x * outStride[2]);
                }
                continue;
            }
            const topInd = Math.floor(yInd);
            const bottomInd = Math.ceil(yInd);
            const yLerp = yInd - topInd;
            const closestY = Math.round(yInd);
            for (let x = 0; x < cropWidth; x++) {
                // With a single output column, sample the horizontal center.
                const xInd = (cropWidth > 1) ?
                    x1 * (imageWidth - 1) + x * widthScale :
                    0.5 * (x1 + x2) * (imageWidth - 1);
                const outOffset = outRowOffset + x * outStride[2];
                if (xInd < 0 || xInd > imageWidth - 1) {
                    fillWithExtrapolation(outOffset);
                    continue;
                }
                if (isBilinear) {
                    const leftInd = Math.floor(xInd);
                    const rightInd = Math.ceil(xInd);
                    const xLerp = xInd - leftInd;
                    const topRowOffset = imageOffset + topInd * inStride[1];
                    const bottomRowOffset = imageOffset + bottomInd * inStride[1];
                    const leftColOffset = leftInd * inStride[2];
                    const rightColOffset = rightInd * inStride[2];
                    for (let c = 0; c < numChannels; c++) {
                        const topLeft = imageVals[topRowOffset + leftColOffset + c];
                        const topRight = imageVals[topRowOffset + rightColOffset + c];
                        const bottomLeft = imageVals[bottomRowOffset + leftColOffset + c];
                        const bottomRight = imageVals[bottomRowOffset + rightColOffset + c];
                        const top = topLeft + (topRight - topLeft) * xLerp;
                        const bottom = bottomLeft + (bottomRight - bottomLeft) * xLerp;
                        outVals[outOffset + c] = top + ((bottom - top) * yLerp);
                    }
                }
                else { // method == "nearest"
                    const closestX = Math.round(xInd);
                    const inOffset = imageOffset + closestY * inStride[1] + closestX * inStride[2];
                    for (let c = 0; c < numChannels; c++) {
                        outVals[outOffset + c] = imageVals[inOffset + c];
                    }
                }
            }
        }
    }
    return backend.makeTensorInfo(output.shape, output.dtype, output.values);
}
|
/**
 * Kernel config pairing the `CropAndResize` kernel name with the
 * `cropAndResize` implementation for the 'cpu' backend.
 */
const cropAndResizeConfig = {
    kernelName: CropAndResize,
    backendName: 'cpu',
    kernelFunc: cropAndResize,
};
|
|
/**
|
* @license
|
* Copyright 2022 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * CPU kernel function for `Cumprod`: running product along `axis`.
 *
 * When `backend_util.getAxesPermutation` returns a permutation, the input is
 * transposed first, the product is accumulated along the last dimension, and
 * the result is transposed back with the undo permutation.
 *
 * @param args Kernel arguments: `inputs` ({x}), `backend`, and `attrs`
 *     ({axis, exclusive, reverse}). With `exclusive`, each position holds the
 *     product of the elements before it (the first holds 1); with `reverse`,
 *     accumulation runs from the end of each row.
 * @return Tensor info with the shape of `x` and dtype
 *     `upcastType(x.dtype, 'int32')`.
 */
function cumprod(args) {
    const { inputs: { x }, backend, attrs: { axis, exclusive, reverse } } = args;
    assertNotComplex(x, 'cumprod');
    const rank = x.shape.length;
    const permutation = backend_util.getAxesPermutation([axis], rank);
    const transposed = permutation != null;
    const $x = transposed ?
        transpose({ inputs: { x }, backend, attrs: { perm: permutation } }) :
        x;
    const innerAxis = backend_util.getInnerMostAxes(1, rank)[0];
    if (innerAxis !== $x.shape.length - 1) {
        throw new Error(`backend.cumprod in CPU expects an inner-most ` +
            `axis=${$x.shape.length - 1} but got axis=${innerAxis}`);
    }
    const resultDtype = upcastType($x.dtype, 'int32');
    const products = util.makeOnesTypedArray(util.sizeFromShape($x.shape), resultDtype);
    const source = backend.data.get($x.dataId).values;
    const rowLength = $x.shape[$x.shape.length - 1];
    // Maps (row start, step within row) to a flat index; walks the row
    // backwards when `reverse` is set.
    function flatIndex(rowStart, step) {
        if (reverse) {
            return rowStart + rowLength - step - 1;
        }
        return rowStart + step;
    }
    for (let rowStart = 0; rowStart < source.length; rowStart += rowLength) {
        for (let step = 0; step < rowLength; step++) {
            const current = flatIndex(rowStart, step);
            if (step === 0) {
                products[current] = exclusive ? 1 : source[current];
                continue;
            }
            const previous = flatIndex(rowStart, step - 1);
            const factor = exclusive ? source[previous] : source[current];
            products[current] = factor * products[previous];
        }
    }
    const result = backend.makeTensorInfo($x.shape, resultDtype, products);
    if (!transposed) {
        return result;
    }
    // Undo the transpose and release the two intermediates.
    const undoPermutation = backend_util.getUndoAxesPermutation(permutation);
    const restored = transpose({ inputs: { x: result }, backend, attrs: { perm: undoPermutation } });
    backend.disposeIntermediateTensorInfo(result);
    backend.disposeIntermediateTensorInfo($x);
    return restored;
}
|
/**
 * Kernel config pairing the `Cumprod` kernel name with the `cumprod`
 * implementation for the 'cpu' backend.
 */
const cumprodConfig = {
    kernelName: Cumprod,
    backendName: 'cpu',
    kernelFunc: cumprod,
};
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * CPU kernel function for `Cumsum`: running sum along `axis`.
 *
 * When `backend_util.getAxesPermutation` returns a permutation, the input is
 * transposed first, the sum is accumulated along the last dimension, and the
 * result is transposed back with the undo permutation.
 *
 * @param args Kernel arguments: `inputs` ({x}), `backend`, and `attrs`
 *     ({axis, exclusive, reverse}). With `exclusive`, each position holds the
 *     sum of the elements before it (the first holds 0); with `reverse`,
 *     accumulation runs from the end of each row.
 * @return Tensor info with the shape of `x` and dtype
 *     `upcastType(x.dtype, 'int32')`.
 */
function cumsum(args) {
    const { inputs: { x }, backend, attrs: { axis, exclusive, reverse } } = args;
    assertNotComplex(x, 'cumsum');
    const rank = x.shape.length;
    const permutation = backend_util.getAxesPermutation([axis], rank);
    const transposed = permutation != null;
    const $x = transposed ?
        transpose({ inputs: { x }, backend, attrs: { perm: permutation } }) :
        x;
    const innerAxis = backend_util.getInnerMostAxes(1, rank)[0];
    if (innerAxis !== $x.shape.length - 1) {
        throw new Error(`backend.cumsum in CPU expects an inner-most ` +
            `axis=${$x.shape.length - 1} but got axis=${innerAxis}`);
    }
    const resultDtype = upcastType($x.dtype, 'int32');
    const sums = util.makeZerosTypedArray(util.sizeFromShape($x.shape), resultDtype);
    const source = backend.data.get($x.dataId).values;
    const rowLength = $x.shape[$x.shape.length - 1];
    // Maps (row start, step within row) to a flat index; walks the row
    // backwards when `reverse` is set.
    function flatIndex(rowStart, step) {
        if (reverse) {
            return rowStart + rowLength - step - 1;
        }
        return rowStart + step;
    }
    for (let rowStart = 0; rowStart < source.length; rowStart += rowLength) {
        for (let step = 0; step < rowLength; step++) {
            const current = flatIndex(rowStart, step);
            if (step === 0) {
                sums[current] = exclusive ? 0 : source[current];
                continue;
            }
            const previous = flatIndex(rowStart, step - 1);
            const term = exclusive ? source[previous] : source[current];
            sums[current] = term + sums[previous];
        }
    }
    const result = backend.makeTensorInfo($x.shape, resultDtype, sums);
    if (!transposed) {
        return result;
    }
    // Undo the transpose and release the two intermediates.
    const undoPermutation = backend_util.getUndoAxesPermutation(permutation);
    const restored = transpose({ inputs: { x: result }, backend, attrs: { perm: undoPermutation } });
    backend.disposeIntermediateTensorInfo(result);
    backend.disposeIntermediateTensorInfo($x);
    return restored;
}
|
/**
 * Kernel config pairing the `Cumsum` kernel name with the `cumsum`
 * implementation for the 'cpu' backend.
 */
const cumsumConfig = {
    kernelName: Cumsum,
    backendName: 'cpu',
    kernelFunc: cumsum,
};
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * CPU kernel function for `DenseBincount`.
 *
 * - Rank-1 `x`: delegates to `bincountImpl` and returns a tensor of shape
 *   `[size]` (the `binaryOutput` attribute is not passed on this path).
 * - Rank-2 `x`: delegates to `bincountReduceImpl`, passing `binaryOutput`.
 *
 * @param args Kernel arguments: `inputs` ({x, weights}), `backend`, and
 *     `attrs` ({size, binaryOutput}).
 * @return Tensor info with the dtype of `weights`.
 * @throws {Error} If `x` is neither rank 1 nor rank 2.
 */
function denseBincount(args) {
    const { inputs, backend, attrs } = args;
    const { x, weights } = inputs;
    const { size, binaryOutput } = attrs;
    const rank = x.shape.length;
    if (rank === 1) {
        const xVals = backend.data.get(x.dataId).values;
        const weightsVals = backend.data.get(weights.dataId).values;
        const outVals = bincountImpl(xVals, weightsVals, weights.dtype, weights.shape, size);
        return backend.makeTensorInfo([size], weights.dtype, outVals);
    }
    if (rank === 2) {
        const xBuf = backend.bufferSync(x);
        const weightsBuf = backend.bufferSync(weights);
        const outBuf = bincountReduceImpl(xBuf, weightsBuf, size, binaryOutput);
        return backend.makeTensorInfo(outBuf.shape, weights.dtype, outBuf.values);
    }
    // The space before the rank keeps the message readable ("rank 3", not "rank3").
    throw new Error(`Error in denseBincount: input must be at most rank 2, but got rank ` +
        `${rank}.`);
}
|
/**
 * Kernel config pairing the `DenseBincount` kernel name with the
 * `denseBincount` implementation for the 'cpu' backend.
 */
const denseBincountConfig = {
    kernelName: DenseBincount,
    backendName: 'cpu',
    kernelFunc: denseBincount,
};
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * CPU kernel function for `DepthToSpace` (NHWC only).
 *
 * Moves values from the depth dimension into `blockSize x blockSize` spatial
 * blocks, producing shape
 * `[batch, height * blockSize, width * blockSize, depth / blockSize^2]`.
 *
 * The result array is allocated with `util.getArrayFromDType(x.dtype, ...)`
 * so its storage type matches the dtype the tensor is tagged with (for
 * example, int32 values are not squeezed through a Float32Array).
 *
 * @param args Kernel arguments: `inputs` ({x}), `backend`, and `attrs`
 *     ({blockSize, dataFormat}).
 * @return Tensor info with dtype `x.dtype`.
 * @throws Via `util.assert` when `dataFormat` is not 'NHWC'.
 */
function depthToSpace(args) {
    const { inputs, backend, attrs } = args;
    const { x } = inputs;
    const { blockSize, dataFormat } = attrs;
    util.assert(dataFormat === 'NHWC', () => `Only NHWC dataFormat supported on CPU for depthToSpace. Got ${dataFormat}`);
    const [batchSize, inputHeight, inputWidth, inputDepth] = x.shape;
    const outputHeight = inputHeight * blockSize;
    const outputWidth = inputWidth * blockSize;
    const outputDepth = inputDepth / (blockSize * blockSize);
    const xValues = backend.data.get(x.dataId).values;
    const result = util.getArrayFromDType(x.dtype, batchSize * outputHeight * outputWidth * outputDepth);
    // Output is written sequentially in NHWC order.
    let outputIdx = 0;
    for (let b = 0; b < batchSize; ++b) {
        for (let h = 0; h < outputHeight; ++h) {
            const inH = Math.floor(h / blockSize);
            const offsetH = (h % blockSize);
            for (let w = 0; w < outputWidth; ++w) {
                const inW = Math.floor(w / blockSize);
                const offsetW = (w % blockSize);
                // Position inside the block selects which depth slice to read.
                const offsetD = (offsetH * blockSize + offsetW) * outputDepth;
                for (let d = 0; d < outputDepth; ++d) {
                    const inD = d + offsetD;
                    const inputIdx = inD + inputDepth * (inW + inputWidth * (inH + inputHeight * b));
                    result[outputIdx++] = xValues[inputIdx];
                }
            }
        }
    }
    return backend.makeTensorInfo([batchSize, outputHeight, outputWidth, outputDepth], x.dtype, result);
}
|
/**
 * Kernel config pairing the `DepthToSpace` kernel name with the
 * `depthToSpace` implementation for the 'cpu' backend.
 */
const depthToSpaceConfig = {
    kernelName: DepthToSpace,
    backendName: 'cpu',
    kernelFunc: depthToSpace,
};
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * CPU kernel function for `DepthwiseConv2dNative`.
 *
 * Accumulates `x * filter` into an output buffer of shape `convInfo.outShape`
 * (dtype of `x`). Each input channel `d1` contributes to `chMul` consecutive
 * output channels, where `chMul = outChannels / inChannels`. Filter taps that
 * land outside the input are skipped.
 *
 * @param args Kernel arguments: `inputs` ({x, filter}), `backend`, and
 *     `attrs` ({strides, pad, dilations, dimRoundingMode}). A null/undefined
 *     `dilations` is treated as `[1, 1]`.
 * @return The result of `backend.makeTensorInfo` for the output buffer.
 */
function depthwiseConv2dNative(args) {
    const { inputs: { x, filter }, backend, attrs } = args;
    const { strides, pad, dilations, dimRoundingMode } = attrs;
    assertNotComplex([x, filter], 'depthwiseConv2DNative');
    const inStrides = util.computeStrides(x.shape);
    const kernelStrides = util.computeStrides(filter.shape);
    const $dilations = dilations == null ? [1, 1] : dilations;
    util.assert(backend_util.eitherStridesOrDilationsAreOne(strides, $dilations), () => 'Error in depthwiseConv2d: Either strides or dilations must be ' +
        `1. Got strides ${strides} and dilations '${$dilations}'`);
    const convInfo = backend_util.computeConv2DInfo(x.shape, filter.shape, strides, $dilations, pad, dimRoundingMode, true /* depthwise */);
    const { batchSize, inHeight, inWidth, inChannels, outHeight, outWidth, outChannels, strideHeight, strideWidth, filterHeight, filterWidth, dilationHeight, dilationWidth } = convInfo;
    const { left: padLeft, top: padTop } = convInfo.padInfo;
    const chMul = outChannels / inChannels;
    const out = new TensorBuffer(convInfo.outShape, x.dtype);
    const [outStrideB, outStrideR, outStrideC] = out.strides;
    const inVals = backend.data.get(x.dataId).values;
    const kernelVals = backend.data.get(filter.dataId).values;
    const outVals = out.values;
    for (let b = 0; b < batchSize; ++b) {
        const inBatchOffset = b * inStrides[0];
        const outBatchOffset = b * outStrideB;
        for (let yR = 0; yR < outHeight; ++yR) {
            const outRowOffset = outBatchOffset + yR * outStrideR;
            const rowCorner = yR * strideHeight - padTop;
            for (let wR = 0; wR < filterHeight; ++wR) {
                const xR = rowCorner + wR * dilationHeight;
                if (xR < 0 || xR >= inHeight) {
                    continue;
                }
                const kernelRowOffset = wR * kernelStrides[0];
                const inRowOffset = inBatchOffset + xR * inStrides[1];
                for (let yC = 0; yC < outWidth; ++yC) {
                    const outColOffset = outRowOffset + yC * outStrideC;
                    const colCorner = yC * strideWidth - padLeft;
                    for (let wC = 0; wC < filterWidth; ++wC) {
                        const xC = colCorner + wC * dilationWidth;
                        if (xC < 0 || xC >= inWidth) {
                            continue;
                        }
                        const inPixelOffset = inRowOffset + xC * inChannels;
                        // Both cursors advance by chMul per input channel.
                        let outCursor = outColOffset;
                        let kernelCursor = kernelRowOffset + wC * kernelStrides[1];
                        for (let d1 = 0; d1 < inChannels; ++d1) {
                            const inVal = inVals[inPixelOffset + d1];
                            for (let q = 0; q < chMul; ++q) {
                                outVals[outCursor + q] += inVal * kernelVals[kernelCursor + q];
                            }
                            outCursor += chMul;
                            kernelCursor += chMul;
                        }
                    }
                }
            }
        }
    }
    return backend.makeTensorInfo(out.shape, out.dtype, out.values);
}
|
/**
 * Kernel config pairing the `DepthwiseConv2dNative` kernel name with the
 * `depthwiseConv2dNative` implementation for the 'cpu' backend.
 */
const depthwiseConv2dNativeConfig = {
    kernelName: DepthwiseConv2dNative,
    backendName: 'cpu',
    kernelFunc: depthwiseConv2dNative,
};
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * CPU kernel function for `DepthwiseConv2dNativeBackpropFilter`.
 *
 * For each filter tap (wR, wC) and output channel d2, sums `x * dy` over the
 * batch and over the output positions whose corresponding input position lies
 * inside the input, then stores the sum at `(wR, wC, d1, dm)` with
 * `d1 = trunc(d2 / chMul)` and `dm = d2 % chMul`.
 *
 * @param args Kernel arguments: `inputs` ({x, dy}), `backend`, and `attrs`
 *     ({strides, dilations, pad, dimRoundingMode, filterShape}).
 * @return The result of `backend.makeTensorInfo` for a float32 buffer of shape
 *     `convInfo.filterShape`.
 */
function depthwiseConv2dNativeBackpropFilter(args) {
    const { inputs: { x, dy }, backend, attrs } = args;
    const { strides, dilations, pad, dimRoundingMode, filterShape } = attrs;
    assertNotComplex([x, dy], 'depthwiseConv2dNativeBackpropFilter');
    const convInfo = backend_util.computeConv2DInfo(x.shape, filterShape, strides, dilations, pad, dimRoundingMode, true /* depthwise */);
    const { batchSize, inHeight, inWidth, inChannels, outHeight, outWidth, outChannels, strideHeight, strideWidth, filterHeight, filterWidth } = convInfo;
    const filterGrad = new TensorBuffer(convInfo.filterShape, 'float32');
    const { left: leftPad, top: topPad } = convInfo.padInfo;
    const chMul = outChannels / inChannels;
    const xBuf = new TensorBuffer(x.shape, x.dtype, backend.data.get(x.dataId).values);
    const dyBuf = new TensorBuffer(dy.shape, dy.dtype, backend.data.get(dy.dataId).values);
    for (let wR = 0; wR < filterHeight; ++wR) {
        // Output rows for which this filter row touches a valid input row.
        const rowStart = Math.max(0, Math.ceil((topPad - wR) / strideHeight));
        const rowEnd = Math.min(outHeight, (inHeight + topPad - wR) / strideHeight);
        for (let wC = 0; wC < filterWidth; ++wC) {
            const colStart = Math.max(0, Math.ceil((leftPad - wC) / strideWidth));
            const colEnd = Math.min(outWidth, (inWidth + leftPad - wC) / strideWidth);
            for (let d2 = 0; d2 < outChannels; ++d2) {
                const d1 = Math.trunc(d2 / chMul);
                const dm = d2 % chMul;
                let acc = 0;
                for (let b = 0; b < batchSize; ++b) {
                    for (let yR = rowStart; yR < rowEnd; ++yR) {
                        const xR = wR + yR * strideHeight - topPad;
                        for (let yC = colStart; yC < colEnd; ++yC) {
                            const xC = wC + yC * strideWidth - leftPad;
                            acc += xBuf.get(b, xR, xC, d1) *
                                dyBuf.get(b, yR, yC, d2);
                        }
                    }
                }
                filterGrad.set(acc, wR, wC, d1, dm);
            }
        }
    }
    return backend.makeTensorInfo(filterGrad.shape, filterGrad.dtype, filterGrad.values);
}
|
/**
 * Kernel config pairing the `DepthwiseConv2dNativeBackpropFilter` kernel name
 * with the `depthwiseConv2dNativeBackpropFilter` implementation for the 'cpu'
 * backend.
 */
const depthwiseConv2dNativeBackpropFilterConfig = {
    kernelName: DepthwiseConv2dNativeBackpropFilter,
    backendName: 'cpu',
    kernelFunc: depthwiseConv2dNativeBackpropFilter,
};
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * CPU kernel function for `DepthwiseConv2dNativeBackpropInput`.
 *
 * Allocates a float32 buffer shaped like `convInfo.inShape`. For every
 * (batch, input channel, row, column) position it stores the sum of
 * `dy * filter` over the output positions selected by the stride/padding
 * bounds below and over the `chMul = outChannels / inChannels` output
 * channels that belong to that input channel.
 *
 * @param args Kernel arguments: `inputs` ({dy, filter}), `backend`, and
 *     `attrs` ({strides, dilations, pad, dimRoundingMode, inputShape}).
 * @return The result of `backend.makeTensorInfo` for the filled buffer.
 */
function depthwiseConv2dNativeBackpropInput(args) {
    const { inputs: { dy, filter }, backend, attrs } = args;
    const { strides, dilations, pad, dimRoundingMode, inputShape } = attrs;
    assertNotComplex([dy, filter], 'depthwiseConv2DNativeBackpropInput');
    const [dyStrideB, dyStrideR, dyStrideC] = util.computeStrides(dy.shape);
    const [fltStrideR, fltStrideC, fltStrideIn] = util.computeStrides(filter.shape);
    const convInfo = backend_util.computeConv2DInfo(inputShape, filter.shape, strides, dilations, pad, dimRoundingMode, true /* depthwise */);
    const inputGrad = new TensorBuffer(convInfo.inShape, 'float32');
    const gradIn = inputGrad.values;
    const [dxStrideB, dxStrideR, dxStrideC] = inputGrad.strides;
    const gradOut = backend.data.get(dy.dataId).values;
    const weights = backend.data.get(filter.dataId).values;
    const { batchSize, filterHeight, filterWidth, inChannels, inHeight, inWidth, outChannels, outHeight, outWidth, strideHeight, strideWidth } = convInfo;
    // Padding as seen from the gradient's side: filter extent minus one,
    // minus the forward padding.
    const topPad = filterHeight - 1 - convInfo.padInfo.top;
    const leftPad = filterWidth - 1 - convInfo.padInfo.left;
    const chMul = outChannels / inChannels;
    for (let b = 0; b < batchSize; ++b) {
        for (let inCh = 0; inCh < inChannels; ++inCh) {
            for (let xR = 0; xR < inHeight; ++xR) {
                const cornerR = xR - topPad;
                const yRStart = Math.max(0, Math.ceil(cornerR / strideHeight));
                const yREnd = Math.min(outHeight, (filterHeight + cornerR) / strideHeight);
                for (let xC = 0; xC < inWidth; ++xC) {
                    const cornerC = xC - leftPad;
                    const yCStart = Math.max(0, Math.ceil(cornerC / strideWidth));
                    const yCEnd = Math.min(outWidth, (filterWidth + cornerC) / strideWidth);
                    let acc = 0;
                    for (let yR = yRStart; yR < yREnd; ++yR) {
                        const wR = yR * strideHeight - cornerR;
                        for (let yC = yCStart; yC < yCEnd; ++yC) {
                            const wC = yC * strideWidth - cornerC;
                            const dyBase = dyStrideB * b + dyStrideR * yR + dyStrideC * yC;
                            const fltBase = fltStrideR * (filterHeight - 1 - wR) +
                                fltStrideC * (filterWidth - 1 - wC) + fltStrideIn * inCh;
                            for (let dm = 0; dm < chMul; ++dm) {
                                const outCh = inCh * chMul + dm;
                                acc += gradOut[dyBase + outCh] * weights[fltBase + dm];
                            }
                        }
                    }
                    gradIn[dxStrideB * b + dxStrideR * xR + dxStrideC * xC + inCh] = acc;
                }
            }
        }
    }
    return backend.makeTensorInfo(inputGrad.shape, inputGrad.dtype, inputGrad.values);
}
|
/**
 * Kernel config pairing the `DepthwiseConv2dNativeBackpropInput` kernel name
 * with the `depthwiseConv2dNativeBackpropInput` implementation for the 'cpu'
 * backend.
 */
const depthwiseConv2dNativeBackpropInputConfig = {
    kernelName: DepthwiseConv2dNativeBackpropInput,
    backendName: 'cpu',
    kernelFunc: depthwiseConv2dNativeBackpropInput,
};
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * CPU kernel function for `Diag`.
 *
 * Copies the values of `x` onto the main diagonal of an `n x n` buffer
 * (`n` = number of elements of `x`) and reports the result with shape
 * `[...x.shape, ...x.shape]`.
 *
 * @param args Kernel arguments: `inputs` ({x}) and `backend`.
 * @return Tensor info with the dtype of the buffer created for `x.dtype`.
 */
function diag(args) {
    const { inputs: { x }, backend } = args;
    const n = util.sizeFromShape(x.shape);
    const source = backend.data.get(x.dataId).values;
    const outBuf = buffer([n, n], x.dtype);
    const target = outBuf.values;
    // Consecutive diagonal entries are n + 1 apart in the flat buffer.
    const diagonalStep = n + 1;
    for (let i = 0, pos = 0; i < source.length; i++, pos += diagonalStep) {
        target[pos] = source[i];
    }
    const outShape = x.shape.concat(x.shape);
    return backend.makeTensorInfo(outShape, outBuf.dtype, outBuf.values);
}
|
/**
 * Kernel config pairing the `Diag` kernel name with the `diag`
 * implementation for the 'cpu' backend.
 */
const diagConfig = {
    kernelName: Diag,
    backendName: 'cpu',
    kernelFunc: diag,
};
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * Kernel config for `Dilation2D` on the 'cpu' backend.
 *
 * The kernel function computes, for every output position and channel, the
 * maximum of `x + filter` over the filter window, skipping taps that fall
 * outside the input. The result is written to the backend with the shape
 * from `computeDilation2DInfo` and the dtype of `x`.
 *
 * The stride tables for x, filter and output depend only on the shapes, so
 * they are computed once up front rather than for every element.
 */
const dilation2DConfig = {
    kernelName: Dilation2D,
    backendName: 'cpu',
    kernelFunc: ({ inputs, backend, attrs }) => {
        const { x, filter } = inputs;
        const { strides, pad, dilations } = attrs;
        const cpuBackend = backend;
        const xVals = cpuBackend.data.get(x.dataId).values;
        const xRank = x.shape.length;
        const filterVals = cpuBackend.data.get(filter.dataId).values;
        const filterRank = filter.shape.length;
        const { batchSize, inHeight, inWidth, inChannels, outHeight, outWidth, padInfo, strideHeight, strideWidth, filterHeight, filterWidth, dilationHeight, dilationWidth, outShape } = backend_util.computeDilation2DInfo(x.shape, filter.shape, strides, pad, 'NHWC' /* dataFormat */, dilations);
        const outSize = util.sizeFromShape(outShape);
        const outRank = outShape.length;
        const outputVals = util.getArrayFromDType(x.dtype, outSize);
        // Loop-invariant stride tables used for flat-index computation.
        const xStrides = util.computeStrides(x.shape);
        const filterStrides = util.computeStrides(filter.shape);
        const outStrides = util.computeStrides(outShape);
        // Upsampling the input by fill in `dilation size - 1` values between each
        // input value.
        // This implementation follows the TF c++ implementation:
        // https://github.com/tensorflow/tensorflow/blob/d9a3a849edc198e90172bc58eb293de457f9d986/tensorflow/core/kernels/dilation_ops.cc
        for (let b = 0; b < batchSize; ++b) {
            for (let hOut = 0; hOut < outHeight; ++hOut) {
                const hBeg = hOut * strideHeight - padInfo.top;
                for (let wOut = 0; wOut < outWidth; ++wOut) {
                    const wBeg = wOut * strideWidth - padInfo.left;
                    for (let d = 0; d < inChannels; ++d) {
                        // NOTE(review): sums below Number.MIN_SAFE_INTEGER can
                        // never replace this initial value; kept as-is to
                        // preserve existing results.
                        let curVal = Number.MIN_SAFE_INTEGER;
                        for (let h = 0; h < filterHeight; ++h) {
                            const hIn = hBeg + h * dilationHeight;
                            if (!(hIn >= 0 && hIn < inHeight)) {
                                continue;
                            }
                            for (let w = 0; w < filterWidth; ++w) {
                                const wIn = wBeg + w * dilationWidth;
                                if (!(wIn >= 0 && wIn < inWidth)) {
                                    continue;
                                }
                                const xIndex = util.locToIndex([b, hIn, wIn, d], xRank, xStrides);
                                const filterIndex = util.locToIndex([h, w, d], filterRank, filterStrides);
                                const val = xVals[xIndex] + filterVals[filterIndex];
                                if (val > curVal) {
                                    curVal = val;
                                }
                            }
                        }
                        const outputIndex = util.locToIndex([b, hOut, wOut, d], outRank, outStrides);
                        outputVals[outputIndex] = curVal;
                    }
                }
            }
        }
        const dataId = cpuBackend.write(util.toTypedArray(outputVals, x.dtype), outShape, x.dtype);
        return { dataId, shape: outShape, dtype: x.dtype };
    }
};
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * CPU kernel config for `Dilation2DBackpropFilter`.
 *
 * Re-runs the forward max search of the dilation and, for every output
 * location, adds the incoming gradient `dy[b, hOut, wOut, d]` to the filter tap
 * `(h, w, d)` that produced the maximum.
 */
const dilation2DBackpropFilterConfig = {
    kernelName: Dilation2DBackpropFilter,
    backendName: 'cpu',
    /**
     * @param inputs `{ x, filter, dy }`: `x` is indexed as
     *     `[batch, height, width, channel]`, `filter` as
     *     `[height, width, channel]`, and `dy` is read with the forward output
     *     shape.
     * @param backend The CPU backend holding the tensor data.
     * @param attrs `{ strides, pad, dilations }`, forwarded to
     *     `backend_util.computeDilation2DInfo`.
     * @returns A tensor info `{ dataId, shape, dtype }` with the shape and
     *     dtype of `filter`.
     * @throws If `dy.rank` differs from the rank of the forward output shape.
     */
    kernelFunc: ({ inputs, backend, attrs }) => {
        const { x, filter, dy } = inputs;
        const { strides, pad, dilations } = attrs;
        const cpuBackend = backend;
        const $x = util.toNestedArray(x.shape, cpuBackend.data.get(x.dataId).values);
        const $filter = util.toNestedArray(filter.shape, cpuBackend.data.get(filter.dataId).values);
        const { batchSize, inHeight, inWidth, inChannels, outHeight, outWidth, padInfo, strideHeight, strideWidth, filterHeight, filterWidth, dilationHeight, dilationWidth, outShape } = backend_util.computeDilation2DInfo(x.shape, filter.shape, strides, pad, 'NHWC' /* dataFormat */, dilations);
        util.assert(dy.rank === outShape.length, () => `Error in ${Dilation2DBackpropFilter}, dy ` +
            `must have the same rank as output ${outShape.length}, but got ` +
            `${dy.rank}`);
        const $dy = util.toNestedArray(outShape, cpuBackend.data.get(dy.dataId).values);
        // The computed filter gradients has the same dimensions as the filter:
        // [filterHeight, filterWidth, depth]
        const gradients = util.makeZerosNestedTypedArray(filter.shape, filter.dtype);
        // When several taps tie for the maximum, only one of them receives the
        // gradient. Because the comparison below is a strict `>`, that is the
        // first maximal tap in (h, w) scan order.
        // This implementation follows the TF c++ implementation:
        // https://github.com/tensorflow/tensorflow/blob/d9a3a849edc198e90172bc58eb293de457f9d986/tensorflow/core/kernels/dilation_ops.cc
        for (let b = 0; b < batchSize; ++b) {
            for (let hOut = 0; hOut < outHeight; ++hOut) {
                const hBeg = hOut * strideHeight - padInfo.top;
                for (let wOut = 0; wOut < outWidth; ++wOut) {
                    const wBeg = wOut * strideWidth - padInfo.left;
                    for (let d = 0; d < inChannels; ++d) {
                        let curVal = Number.MIN_SAFE_INTEGER;
                        // Defaults used when no tap is in bounds.
                        let hMax = 0;
                        let wMax = 0;
                        for (let h = 0; h < filterHeight; ++h) {
                            const hIn = hBeg + h * dilationHeight;
                            if (hIn >= 0 && hIn < inHeight) {
                                for (let w = 0; w < filterWidth; ++w) {
                                    const wIn = wBeg + w * dilationWidth;
                                    if (wIn >= 0 && wIn < inWidth) {
                                        const val = $x[b][hIn][wIn][d] + $filter[h][w][d];
                                        if (val > curVal) {
                                            curVal = val;
                                            hMax = h;
                                            wMax = w;
                                        }
                                    }
                                }
                            }
                        }
                        gradients[hMax][wMax][d] += $dy[b][hOut][wOut][d];
                    }
                }
            }
        }
        // The gradients are allocated, written and returned as `filter.dtype`,
        // so flatten them with that same dtype.
        const dataId = cpuBackend.write(util.toTypedArray(gradients, filter.dtype), filter.shape, filter.dtype);
        return { dataId, shape: filter.shape, dtype: filter.dtype };
    }
};
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * CPU kernel config for `Dilation2DBackpropInput`.
 *
 * Repeats the forward max search of the dilation and routes each incoming
 * gradient `dy[b, hOut, wOut, d]` to the input position that produced the
 * maximum for that output location.
 */
const dilation2DBackpropInputConfig = {
    kernelName: Dilation2DBackpropInput,
    backendName: 'cpu',
    /**
     * @param inputs `{ x, filter, dy }`: `x` is indexed as
     *     `[batch, height, width, channel]`, `filter` as
     *     `[height, width, channel]`, and `dy` is read with the forward output
     *     shape.
     * @param backend The CPU backend holding the tensor data.
     * @param attrs `{ strides, pad, dilations }`, forwarded to
     *     `backend_util.computeDilation2DInfo`.
     * @returns A tensor info `{ dataId, shape, dtype }` with the shape and
     *     dtype of `x`.
     * @throws If `dy.rank` differs from the rank of the forward output shape.
     */
    kernelFunc: ({ inputs, backend, attrs }) => {
        const { x, filter, dy } = inputs;
        const { strides, pad, dilations } = attrs;
        const cpuBackend = backend;
        const xNested = util.toNestedArray(x.shape, cpuBackend.data.get(x.dataId).values);
        const filterNested = util.toNestedArray(filter.shape, cpuBackend.data.get(filter.dataId).values);
        const { batchSize, inHeight, inWidth, inChannels, outHeight, outWidth, padInfo, strideHeight, strideWidth, filterHeight, filterWidth, dilationHeight, dilationWidth, outShape } = backend_util.computeDilation2DInfo(x.shape, filter.shape, strides, pad, 'NHWC' /* dataFormat */, dilations);
        util.assert(dy.rank === outShape.length, () => `Error in ${Dilation2DBackpropInput}, dy ` +
            `must have the same rank as output ${outShape.length}, but got ` +
            `${dy.rank}`);
        const dyNested = util.toNestedArray(outShape, cpuBackend.data.get(dy.dataId).values);
        // The gradients have the same dimensions as the input:
        // [batch, inputHeight, inputCols, inChannel]
        const gradients = util.makeZerosNestedTypedArray(x.shape, x.dtype);
        // Only one input position per output location receives the gradient.
        // The strict `>` below means the first maximal tap in (h, w) scan
        // order is the one that is kept.
        // This implementation follows the TF c++ implementation:
        // https://github.com/tensorflow/tensorflow/blob/d9a3a849edc198e90172bc58eb293de457f9d986/tensorflow/core/kernels/dilation_ops.cc
        for (let batch = 0; batch < batchSize; ++batch) {
            for (let outRow = 0; outRow < outHeight; ++outRow) {
                const rowStart = outRow * strideHeight - padInfo.top;
                for (let outCol = 0; outCol < outWidth; ++outCol) {
                    const colStart = outCol * strideWidth - padInfo.left;
                    for (let channel = 0; channel < inChannels; ++channel) {
                        let best = Number.MIN_SAFE_INTEGER;
                        // Fallback target when no tap is in bounds: the window
                        // origin clamped to be non-negative.
                        let bestRow = (rowStart < 0) ? 0 : rowStart;
                        let bestCol = (colStart < 0) ? 0 : colStart;
                        for (let tapRow = 0; tapRow < filterHeight; ++tapRow) {
                            const inRow = rowStart + tapRow * dilationHeight;
                            if (!(inRow >= 0 && inRow < inHeight)) {
                                continue;
                            }
                            for (let tapCol = 0; tapCol < filterWidth; ++tapCol) {
                                const inCol = colStart + tapCol * dilationWidth;
                                if (!(inCol >= 0 && inCol < inWidth)) {
                                    continue;
                                }
                                const candidate = xNested[batch][inRow][inCol][channel] +
                                    filterNested[tapRow][tapCol][channel];
                                if (candidate > best) {
                                    best = candidate;
                                    bestRow = inRow;
                                    bestCol = inCol;
                                }
                            }
                        }
                        gradients[batch][bestRow][bestCol][channel] +=
                            dyNested[batch][outRow][outCol][channel];
                    }
                }
            }
        }
        const dataId = cpuBackend.write(util.toTypedArray(gradients, x.dtype), x.shape, x.dtype);
        return { dataId, shape: x.shape, dtype: x.dtype };
    }
};
|
|
/**
|
* @license
|
* Copyright 2023 Google LLC.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * Renders an image tensor onto a canvas through the canvas' 2D context.
 *
 * The image shape is read as `[height, width]` or `[height, width, depth]`.
 * A depth of 1 (or a rank-2 image) is replicated into the R, G and B channels;
 * otherwise channel `d` of the tensor is written to channel `d` of the pixel,
 * so a 4-channel image supplies its own alpha. `float32` values are scaled by
 * 255, other dtypes are written as-is.
 *
 * @param args.inputs.image Tensor info of the image to draw.
 * @param args.backend CPU backend holding the image data.
 * @param args.attrs.canvas Target canvas; its width and height are overwritten
 *     with the image size.
 * @param args.attrs.options Optional `{ contextOptions, imageOptions }`.
 *     `imageOptions.alpha` is multiplied by 255 to form the default alpha
 *     channel; it defaults to 1 only when it is not provided.
 * @returns The input `image` tensor info.
 * @throws If a context type other than `'2d'` is requested, if the context
 *     cannot be obtained, or if a value is outside `[0, 1]` (float32) or
 *     `[0, 255]` (int32).
 */
function draw(args) {
    const { inputs, backend, attrs } = args;
    const { image } = inputs;
    const { canvas, options } = attrs;
    const { contextOptions, imageOptions } = options || {};
    // `alpha: 0` (fully transparent) is a legitimate value, so only fall back
    // to 1 when alpha is missing rather than whenever it is falsy.
    const alpha = (imageOptions == null || imageOptions.alpha == null) ? 1 : imageOptions.alpha;
    const contextType = (contextOptions === null || contextOptions === void 0 ? void 0 : contextOptions.contextType) || '2d';
    if (contextType !== '2d') {
        throw new Error(`Context type ${contextOptions.contextType} is not supported by the CPU backend.`);
    }
    const ctx = canvas.getContext(contextType, (contextOptions === null || contextOptions === void 0 ? void 0 : contextOptions.contextAttributes) || {});
    if (ctx == null) {
        throw new Error(`Could not get the context with ${contextType} type.`);
    }
    const [height, width] = image.shape.slice(0, 2);
    const depth = image.shape.length === 2 ? 1 : image.shape[2];
    const data = backend.data.get(image.dataId).values;
    const multiplier = image.dtype === 'float32' ? 255 : 1;
    const bytes = new Uint8ClampedArray(width * height * 4);
    for (let i = 0; i < height * width; ++i) {
        // Start from black with the configured alpha, then overwrite the
        // channels the tensor provides.
        const rgba = [0, 0, 0, 255 * alpha];
        for (let d = 0; d < depth; d++) {
            const value = data[i * depth + d];
            if (image.dtype === 'float32') {
                if (value < 0 || value > 1) {
                    throw new Error(`Tensor values for a float32 Tensor must be in the ` +
                        `range [0 - 1] but encountered ${value}.`);
                }
            }
            else if (image.dtype === 'int32') {
                if (value < 0 || value > 255) {
                    throw new Error(`Tensor values for a int32 Tensor must be in the ` +
                        `range [0 - 255] but encountered ${value}.`);
                }
            }
            if (depth === 1) {
                rgba[0] = value * multiplier;
                rgba[1] = value * multiplier;
                rgba[2] = value * multiplier;
            }
            else {
                rgba[d] = value * multiplier;
            }
        }
        const j = i * 4;
        bytes[j + 0] = Math.round(rgba[0]);
        bytes[j + 1] = Math.round(rgba[1]);
        bytes[j + 2] = Math.round(rgba[2]);
        bytes[j + 3] = Math.round(rgba[3]);
    }
    canvas.width = width;
    canvas.height = height;
    const imageData = new ImageData(bytes, width, height);
    ctx.putImageData(imageData, 0, 0);
    return image;
}
|
/** Kernel config pairing the `Draw` kernel name with the `draw` implementation for the 'cpu' backend. */
const drawConfig = {
|
kernelName: Draw,
|
backendName: 'cpu',
|
kernelFunc: draw
|
};
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * CPU implementation of the `Sum` kernel: reduces `x` by addition over the
 * axes given in `attrs.axis`.
 *
 * @param args.inputs.x Tensor info to reduce; complex tensors are rejected by
 *     `assertNotComplex`.
 * @param args.backend CPU backend.
 * @param args.attrs.axis Axis or axes to reduce, parsed with
 *     `util.parseAxisParam`.
 * @param args.attrs.keepDims When truthy, the result is reshaped with
 *     `backend_util.expandShapeToKeepDim`.
 * @returns Tensor info of the reduced values; its dtype is the input dtype
 *     upcast with 'int32'.
 */
function sum(args) {
    const { inputs, backend, attrs } = args;
    const { x } = inputs;
    const { axis, keepDims } = attrs;
    assertNotComplex(x, 'sum');
    // Bool inputs are cast to int32; anything else goes through `identity`.
    // Either way `source` is an intermediate that is disposed at the end.
    const source = x.dtype === 'bool' ?
        cast({ inputs: { x }, backend, attrs: { dtype: 'int32' } }) :
        identity({ inputs: { x }, backend });
    const rank = source.shape.length;
    const axes = util.parseAxisParam(axis, source.shape);
    const perm = backend_util.getAxesPermutation(axes, rank);
    const wasTransposed = perm != null;
    let innerAxes = axes;
    let reducible = source;
    if (wasTransposed) {
        // Move the reduced axes to the innermost positions so each output
        // element corresponds to one contiguous run of input values.
        reducible = transpose({ inputs: { x: source }, backend, attrs: { perm } });
        innerAxes = backend_util.getInnerMostAxes(innerAxes.length, rank);
    }
    backend_util.assertAxesAreInnerMostDims('sum', innerAxes, reducible.shape.length);
    const [outShape, reduceShape] = backend_util.computeOutAndReduceShapes(reducible.shape, innerAxes);
    const outDtype = backend_util.upcastType(reducible.dtype, 'int32');
    const reduced = zeros(backend, outShape, outDtype);
    const runLength = util.sizeFromShape(reduceShape);
    const outVals = backend.data.get(reduced.dataId).values;
    const inVals = backend.data.get(reducible.dataId).values;
    for (let outIdx = 0; outIdx < outVals.length; ++outIdx) {
        const start = outIdx * runLength;
        let total = 0;
        for (let k = 0; k < runLength; ++k) {
            total += inVals[start + k];
        }
        outVals[outIdx] = total;
    }
    let output = reduced;
    if (keepDims) {
        const keptShape = backend_util.expandShapeToKeepDim(reduced.shape, axes);
        output = reshape({ inputs: { x: reduced }, backend, attrs: { shape: keptShape } });
        backend.disposeIntermediateTensorInfo(reduced);
    }
    backend.disposeIntermediateTensorInfo(source);
    if (wasTransposed) {
        backend.disposeIntermediateTensorInfo(reducible);
    }
    return output;
}
|
/** Kernel config pairing the `Sum` kernel name with the `sum` implementation for the 'cpu' backend. */
const sumConfig = {
|
kernelName: Sum,
|
backendName: 'cpu',
|
kernelFunc: sum
|
};
|
|
/**
|
* @license
|
* Copyright 2021 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * CPU implementation of the `Einsum` kernel.
 *
 * The equation is decoded with `backend_util.decodeEinsumEquation` and a
 * compute path is obtained from `backend_util.getEinsumComputePath`. For each
 * step, the operands listed for that step are permuted/reshaped and multiplied
 * into a running product; between steps the running product is summed over
 * `path[i]` when that entry is non-negative.
 *
 * @param args.inputs Array of input tensor infos.
 * @param args.backend CPU backend.
 * @param args.attrs.equation The einsum equation string.
 * @returns Tensor info of the final running product.
 */
function einsum(args) {
    const { inputs, backend, attrs } = args;
    const { equation } = attrs;
    const tensors = inputs;
    const { allDims, summedDims, idDims } = backend_util.decodeEinsumEquation(equation, tensors.length);
    backend_util.checkEinsumDimSizes(allDims.length, idDims, tensors);
    const { path, steps } = backend_util.getEinsumComputePath(summedDims, idDims);
    const stepCount = steps.length;
    const intermediates = [];
    let product = null;
    let dimsLeft = allDims.length;
    // Brings one operand into the layout of the running product: transpose if
    // the permutation is not the identity, then insert size-1 dimensions.
    const alignOperand = (termId) => {
        const { permutationIndices, expandDims: insertPositions } = backend_util.getEinsumPermutation(dimsLeft, idDims[termId]);
        let operand = tensors[termId];
        if (!backend_util.isIdentityPermutation(permutationIndices)) {
            operand = transpose({ inputs: { x: tensors[termId] }, backend, attrs: { perm: permutationIndices } });
            intermediates.push(operand);
        }
        const expandedShape = operand.shape.slice();
        for (const position of insertPositions) {
            expandedShape.splice(position, 0, 1);
        }
        if (!util.arraysEqual(operand.shape, expandedShape)) {
            operand = reshape({ inputs: { x: operand }, backend, attrs: { shape: expandedShape } });
            intermediates.push(operand);
        }
        return operand;
    };
    for (let step = 0; step < stepCount; ++step) {
        for (const termId of steps[step]) {
            const operand = alignOperand(termId);
            if (product === null) {
                product = operand;
                continue;
            }
            product = multiply({ inputs: { a: operand, b: product }, backend });
            intermediates.push(product);
        }
        const isLastStep = !(step < stepCount - 1);
        if (isLastStep) {
            continue;
        }
        if (path[step] >= 0) {
            product = sum({
                inputs: { x: product },
                backend,
                attrs: {
                    // Shift the axis by the number of dims already removed.
                    axis: path[step] - (allDims.length - dimsLeft),
                    keepDims: false
                }
            });
            intermediates.push(product);
        }
        dimsLeft--;
    }
    // Release every intermediate except the tensor being returned.
    for (const tensorInfo of intermediates) {
        if (tensorInfo !== product) {
            backend.disposeIntermediateTensorInfo(tensorInfo);
        }
    }
    return product;
}
|
/** Kernel config pairing the `Einsum` kernel name with the `einsum` implementation for the 'cpu' backend. */
const einsumConfig = {
|
kernelName: Einsum,
|
backendName: 'cpu',
|
kernelFunc: einsum
|
};
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * CPU implementation of the `EluGrad` kernel.
 *
 * For each element, the result is `dy` where `y >= 0` and `dy * (y + 1)`
 * otherwise.
 *
 * @param args.inputs.dy Tensor info of the incoming gradient.
 * @param args.inputs.y Tensor info whose values select the branch above.
 * @param args.backend CPU backend.
 * @returns A float32 tensor info with the shape of `y`.
 */
function eluGrad(args) {
    const { inputs, backend } = args;
    const { dy, y } = inputs;
    assertNotComplex([dy, y], 'eluGrad');
    const yVals = backend.data.get(y.dataId).values;
    const dyVals = backend.data.get(dy.dataId).values;
    const grads = new Float32Array(util.sizeFromShape(y.shape));
    for (let idx = 0; idx < yVals.length; ++idx) {
        const yVal = yVals[idx];
        grads[idx] = yVal >= 0 ? dyVals[idx] : dyVals[idx] * (yVal + 1);
    }
    return backend.makeTensorInfo(y.shape, 'float32', grads);
}
|
/** Kernel config pairing the `EluGrad` kernel name with the `eluGrad` implementation for the 'cpu' backend. */
const eluGradConfig = {
|
kernelName: EluGrad,
|
backendName: 'cpu',
|
kernelFunc: eluGrad
|
};
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the License);
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an AS IS BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
// Coefficients of the polynomial approximation below, taken from backend_util.
const p = backend_util.ERF_P;
const a1 = backend_util.ERF_A1;
const a2 = backend_util.ERF_A2;
const a3 = backend_util.ERF_A3;
const a4 = backend_util.ERF_A4;
const a5 = backend_util.ERF_A5;
/**
 * Element-wise `Erf` kernel function for the CPU backend, built with
 * `unaryKernelFunc`. Each value is approximated as
 * `sign(x) * (1 - poly(t) * exp(-|x|^2))` with `t = 1 / (1 + p * |x|)`, where
 * `poly` is a degree-5 polynomial in `t` with coefficients `a1..a5` and no
 * constant term.
 */
const erf = unaryKernelFunc(Erf, (xi) => {
    const magnitude = Math.abs(xi);
    const t = 1.0 / (1.0 + p * magnitude);
    // Horner evaluation of a1*t + a2*t^2 + a3*t^3 + a4*t^4 + a5*t^5.
    const poly = ((((a5 * t + a4) * t + a3) * t + a2) * t + a1) * t;
    return Math.sign(xi) * (1.0 - poly * Math.exp(-magnitude * magnitude));
});
|
/** Kernel config pairing the `Erf` kernel name with the `erf` implementation for the 'cpu' backend. */
const erfConfig = {
|
kernelName: Erf,
|
backendName: 'cpu',
|
kernelFunc: erf,
|
};
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * CPU implementation of the `ExpandDims` kernel: inserts a dimension of size 1
 * into the input shape and delegates to `reshape`.
 *
 * @param args.inputs.input Tensor info to expand.
 * @param args.backend CPU backend.
 * @param args.attrs.dim Position of the new dimension. A negative value counts
 *     from the end and must not be smaller than `-(rank + 1)`.
 * @returns The tensor info returned by `reshape` for the expanded shape.
 * @throws If `dim` is negative and below `-(rank + 1)`.
 */
function expandDims(args) {
    const { inputs, backend, attrs } = args;
    const { input } = inputs;
    const { dim } = attrs;
    const rank = input.shape.length;
    let insertAt = dim;
    if (dim < 0) {
        // Translate a tail-relative position into an absolute one.
        util.assert(-(rank + 1) <= dim, () => `Axis must be in the interval [${-(rank + 1)}, ${rank}]`);
        insertAt = rank + dim + 1;
    }
    const expandedShape = input.shape.slice();
    expandedShape.splice(insertAt, 0, 1);
    return reshape({ inputs: { x: input }, backend, attrs: { shape: expandedShape } });
}
|
/** Kernel config pairing the `ExpandDims` kernel name with the `expandDims` implementation for the 'cpu' backend. */
const expandDimsConfig = {
|
kernelName: ExpandDims,
|
backendName: 'cpu',
|
kernelFunc: expandDims
|
};
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
// Element-wise `a / b`, wrapped by createSimpleBinaryKernelImpl.
const realDivImpl = createSimpleBinaryKernelImpl((a, b) => a / b);
|
// Kernel function for `RealDiv`, built from realDivImpl via binaryKernelFunc.
const div = binaryKernelFunc(RealDiv, realDivImpl);
|
/** Kernel config pairing the `RealDiv` kernel name with the `div` implementation for the 'cpu' backend. */
const realDivConfig = {
|
kernelName: RealDiv,
|
backendName: 'cpu',
|
kernelFunc: div
|
};
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * Calculates the (inverse) FFT of each row of a 2D complex tensor.
 *
 * @param input Complex tensor info whose shape is read as `[batch, innerDim]`.
 * @param inverse Forwarded to `fftImpl` to select the inverse transform.
 * @param cpuBackend CPU backend holding the tensor data.
 * @returns A new complex tensor info of shape `[batch, innerDim]`.
 */
function fftBatch(input, inverse, cpuBackend) {
    const [batch, innerDim] = input.shape;
    const { real: real2D, imag: imag2D } = cpuBackend.data.get(input.dataId).complexTensorInfos;
    // Real and imaginary results are collected in separate flat buffers.
    const resultShape = [batch, innerDim];
    const resultSize = util.sizeFromShape(resultShape);
    const outReal = util.getTypedArrayFromDType('float32', resultSize);
    const outImag = util.getTypedArrayFromDType('float32', resultSize);
    for (let row = 0; row < batch; row++) {
        // TODO: Support slice ops for complex type.
        const sliceRow = (plane) => slice({
            inputs: { x: plane },
            backend: cpuBackend,
            attrs: { begin: [row, 0], size: [1, innerDim] }
        });
        const rowReal = sliceRow(real2D);
        const rowImag = sliceRow(imag2D);
        const rowComplex = complex({ inputs: { real: rowReal, imag: rowImag }, backend: cpuBackend });
        // Run FFT by batch element.
        const { real: fftReal, imag: fftImag } = fftImpl(rowComplex, inverse, cpuBackend);
        const merged = backend_util.mergeRealAndImagArrays(fftReal, fftImag);
        const rowOffset = row * innerDim;
        for (let col = 0; col < innerDim; col++) {
            const entry = backend_util.getComplexWithIndex(merged, col);
            outReal[rowOffset + col] = entry.real;
            outImag[rowOffset + col] = entry.imag;
        }
        cpuBackend.disposeIntermediateTensorInfo(rowReal);
        cpuBackend.disposeIntermediateTensorInfo(rowImag);
        cpuBackend.disposeIntermediateTensorInfo(rowComplex);
    }
    const realInfo = cpuBackend.makeTensorInfo(resultShape, 'float32', outReal);
    const imagInfo = cpuBackend.makeTensorInfo(resultShape, 'float32', outImag);
    const combined = complex({ inputs: { real: realInfo, imag: imagInfo }, backend: cpuBackend });
    cpuBackend.disposeIntermediateTensorInfo(realInfo);
    cpuBackend.disposeIntermediateTensorInfo(imagInfo);
    return combined;
}
|
/**
 * Computes the (inverse) Fourier transform of one complex tensor.
 *
 * Inputs whose element count passes `isExponentOf2` go through `fftRadix2`;
 * for the inverse transform that result is additionally divided by the
 * element count using the `RealDiv` kernel. All other sizes fall back to
 * `fourierTransformByMatmul`.
 *
 * @param input Complex tensor info.
 * @param inverse Whether to compute the inverse transform.
 * @param cpuBackend CPU backend holding the tensor data.
 * @returns `{ real, imag }` arrays with the transformed values.
 */
function fftImpl(input, inverse, cpuBackend) {
    const inputSize = util.sizeFromShape(input.shape);
    const { complexTensorInfos } = cpuBackend.data.get(input.dataId);
    const realVals = cpuBackend.data.get(complexTensorInfos.real.dataId).values;
    const imagVals = cpuBackend.data.get(complexTensorInfos.imag.dataId).values;
    if (!isExponentOf2(inputSize)) {
        const interleaved = backend_util.mergeRealAndImagArrays(realVals, imagVals);
        const transformed = fourierTransformByMatmul(interleaved, inputSize, inverse);
        return backend_util.splitRealAndImagArrays(transformed);
    }
    const radixResult = fftRadix2(realVals, imagVals, inputSize, inverse, cpuBackend);
    if (!inverse) {
        return radixResult;
    }
    // Inverse transform: scale both components by 1 / inputSize.
    const resultShape = [input.shape[0], input.shape[1]];
    const realInfo = cpuBackend.makeTensorInfo(resultShape, 'float32', radixResult.real);
    const imagInfo = cpuBackend.makeTensorInfo(resultShape, 'float32', radixResult.imag);
    const sizeInfo = cpuBackend.makeTensorInfo([], 'float32', util.createScalarValue(inputSize, 'float32'));
    const sizeInfoCopy = identity({ inputs: { x: sizeInfo }, backend: cpuBackend });
    const scaledReal = realDivConfig.kernelFunc({ inputs: { a: realInfo, b: sizeInfo }, backend: cpuBackend });
    const scaledImag = realDivConfig.kernelFunc({ inputs: { a: imagInfo, b: sizeInfoCopy }, backend: cpuBackend });
    // Grab the value arrays before the owning tensor infos are disposed.
    const scaledRealVals = cpuBackend.data.get(scaledReal.dataId).values;
    const scaledImagVals = cpuBackend.data.get(scaledImag.dataId).values;
    const temporaries = [realInfo, imagInfo, sizeInfo, sizeInfoCopy, scaledReal, scaledImag];
    for (const info of temporaries) {
        cpuBackend.disposeIntermediateTensorInfo(info);
    }
    return { real: scaledRealVals, imag: scaledImagVals };
}
|
/**
 * Reports whether `size` is a positive power of two (1, 2, 4, 8, ...).
 *
 * @param size Number of elements to test.
 * @returns `true` when `size` is greater than zero and has exactly one bit
 *     set; `false` otherwise, including for 0.
 */
function isExponentOf2(size) {
    // `0 & -1` is also 0, so the bit trick alone would accept 0.
    return size > 0 && (size & (size - 1)) === 0;
}
|
/**
 * FFT using the Cooley-Tukey algorithm on radix-2 sized input.
 *
 * The input is split into its even- and odd-indexed samples, each half is
 * transformed recursively, and the halves are recombined as
 * `even + e * odd` followed by `even - e * odd`, where `e` comes from
 * `backend_util.exponents(size, inverse)`.
 *
 * @param realVals Real parts of the input samples.
 * @param imagVals Imaginary parts of the input samples.
 * @param size Number of complex samples; the recursion halves it until it
 *     reaches 1.
 * @param inverse Forwarded to `backend_util.exponents`.
 * @param cpuBackend CPU backend used for the complex arithmetic.
 * @returns `{ real, imag }` arrays holding the transformed samples.
 */
function fftRadix2(realVals, imagVals, size, inverse, cpuBackend) {
    if (size === 1) {
        return { real: realVals, imag: imagVals };
    }
    const data = backend_util.mergeRealAndImagArrays(realVals, imagVals);
    const half = size / 2;
    // Every tensor info created below is temporary and released before
    // returning.
    const intermediates = [];
    const track = (info) => {
        intermediates.push(info);
        return info;
    };
    // Wraps a pair of real/imag arrays into a 1D complex tensor info.
    const toComplexInfo = (realArr, imagArr) => {
        const shape = [realArr.length];
        const realInfo = track(cpuBackend.makeTensorInfo(shape, 'float32', realArr));
        const imagInfo = track(cpuBackend.makeTensorInfo(shape, 'float32', imagArr));
        return track(complex({ inputs: { real: realInfo, imag: imagInfo }, backend: cpuBackend }));
    };
    // The even/odd halves are consumed as plain arrays by the recursion, so
    // only the recursion results need tensor infos.
    const evenComplex = backend_util.complexWithEvenIndex(data);
    const oddComplex = backend_util.complexWithOddIndex(data);
    // Recursive call for half part of original input.
    const $evenComplex = fftRadix2(evenComplex.real, evenComplex.imag, half, inverse, cpuBackend);
    const $evenTensorInfo = toComplexInfo($evenComplex.real, $evenComplex.imag);
    const $oddComplex = fftRadix2(oddComplex.real, oddComplex.imag, half, inverse, cpuBackend);
    const $oddTensorInfo = toComplexInfo($oddComplex.real, $oddComplex.imag);
    const e = backend_util.exponents(size, inverse);
    const complexInfo = toComplexInfo(e.real, e.imag);
    const exponentInfo = track(multiply({ inputs: { a: complexInfo, b: $oddTensorInfo }, backend: cpuBackend }));
    const addPart = track(add({
        inputs: { a: $evenTensorInfo, b: exponentInfo },
        backend: cpuBackend
    }));
    const subPart = track(sub({
        inputs: { a: $evenTensorInfo, b: exponentInfo },
        backend: cpuBackend
    }));
    const addPartReal = track(real({ inputs: { input: addPart }, backend: cpuBackend }));
    const subPartReal = track(real({ inputs: { input: subPart }, backend: cpuBackend }));
    const addPartImag = track(imag({ inputs: { input: addPart }, backend: cpuBackend }));
    const subPartImag = track(imag({ inputs: { input: subPart }, backend: cpuBackend }));
    // The "add" half is concatenated ahead of the "sub" half.
    const $real = track(concat({
        inputs: [addPartReal, subPartReal],
        backend: cpuBackend,
        attrs: { axis: 0 }
    }));
    const $imag = track(concat({
        inputs: [addPartImag, subPartImag],
        backend: cpuBackend,
        attrs: { axis: 0 }
    }));
    // Read the value arrays before the tensor infos that own them are
    // disposed.
    const $realVals = cpuBackend.data.get($real.dataId).values;
    const $imagVals = cpuBackend.data.get($imag.dataId).values;
    for (const info of intermediates) {
        cpuBackend.disposeIntermediateTensorInfo(info);
    }
    return { real: $realVals, imag: $imagVals };
}
|
/**
 * Calculates the Fourier transform by multiplying with the sinusoid matrix,
 * one output row at a time.
 *
 * @param data Complex samples, read with `backend_util.getComplexWithIndex`.
 * @param size Number of complex samples.
 * @param inverse When truthy, each output entry is divided by `size`;
 *     also forwarded to `backend_util.exponent`.
 * @returns A Float32Array of length `size * 2` filled through
 *     `backend_util.assignToTypedArray`.
 */
function fourierTransformByMatmul(data, size, inverse) {
    const output = new Float32Array(size * 2);
    // TODO: Use matmul instead once it supports complex64 type.
    for (let row = 0; row < size; row++) {
        let accReal = 0.0;
        let accImag = 0.0;
        for (let col = 0; col < size; col++) {
            const factor = backend_util.exponent(row * col, size, inverse);
            const sample = backend_util.getComplexWithIndex(data, col);
            // Complex multiply-accumulate: acc += sample * factor.
            accReal += sample.real * factor.real - sample.imag * factor.imag;
            accImag += sample.real * factor.imag + sample.imag * factor.real;
        }
        if (inverse) {
            accReal /= size;
            accImag /= size;
        }
        backend_util.assignToTypedArray(output, accReal, accImag, row);
    }
    return output;
}
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * CPU implementation of the `FFT` kernel. The input is reshaped to
 * `[batch, innerDimensionSize]`, transformed row by row with `fftBatch`, and
 * reshaped back to the original shape.
 *
 * @param args.inputs.input Tensor info to transform.
 * @param args.backend CPU backend.
 * @returns Tensor info of the transform, with the shape of the input.
 */
function fft(args) {
    const { inputs, backend } = args;
    const { input } = inputs;
    const { shape } = input;
    // Collapse all outer dimensions to a single batch dimension.
    const innerDimensionSize = shape[shape.length - 1];
    const batch = util.sizeFromShape(shape) / innerDimensionSize;
    const flattened = reshape({
        inputs: { x: input },
        backend,
        attrs: { shape: [batch, innerDimensionSize] }
    });
    const transformed = fftBatch(flattened, false, backend);
    const restored = reshape({ inputs: { x: transformed }, backend, attrs: { shape: input.shape } });
    backend.disposeIntermediateTensorInfo(flattened);
    backend.disposeIntermediateTensorInfo(transformed);
    return restored;
}
|
const fftConfig = {
|
kernelName: FFT,
|
backendName: 'cpu',
|
kernelFunc: fft
|
};
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * CPU kernel function for `Fill`.
 *
 * Allocates an array sized from `attrs.shape` for the resolved dtype and
 * sets every element to `attrs.value`.
 *
 * @param args Kernel arguments; reads `args.backend` and
 *     `args.attrs.{shape, value, dtype}`.
 * @return Tensor info created via `backend.makeTensorInfo`.
 */
function fill(args) {
    const { backend, attrs: { shape, value, dtype } } = args;
    // Fall back to inferring the dtype from the value when none is supplied.
    const resolvedDtype = dtype || util.inferDtype(value);
    const numElements = util.sizeFromShape(shape);
    const filled = util.getArrayFromDType(resolvedDtype, numElements);
    fillValues(filled, value, resolvedDtype);
    return backend.makeTensorInfo(shape, resolvedDtype, filled);
}
/** Kernel config pairing the `Fill` kernel name with the 'cpu' implementation. */
const fillConfig = {
    kernelName: Fill,
    backendName: 'cpu',
    kernelFunc: fill
};
|
/**
 * Sets every element of `values` to `value`, in place.
 *
 * The original branched on `dtype === 'string'` but both branches performed
 * the same `values.fill(value)` call, so the branch is removed.
 *
 * @param values Array or typed array to overwrite.
 * @param value Value written to every slot.
 * @param dtype Unused; retained so existing callers keep working.
 */
function fillValues(values, value, dtype) {
    values.fill(value);
}
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * Kernel config for `FlipLeftRight` on the 'cpu' backend.
 *
 * The kernel function destructures `image.shape` as
 * [batch, height, width, channels] and writes, for every pixel, the value
 * found at the horizontally mirrored column (`width - col - 1`) of the same
 * row and batch entry.
 */
const flipLeftRightConfig = {
    kernelName: FlipLeftRight,
    backendName: 'cpu',
    kernelFunc: ({ inputs, attrs, backend }) => {
        const { image } = inputs;
        const cpuBackend = backend;
        const { shape, dtype } = image;
        const flipped = util.getTypedArrayFromDType(dtype, util.sizeFromShape(shape));
        const [batch, imageHeight, imageWidth, numChannels] = shape;
        const pixels = cpuBackend.data.get(image.dataId).values;
        const rowStride = imageWidth * numChannels;
        for (let b = 0; b < batch; b++) {
            const batchBase = b * imageWidth * imageHeight * numChannels;
            for (let y = 0; y < imageHeight; y++) {
                const rowBase = batchBase + y * rowStride;
                for (let x = 0; x < imageWidth; x++) {
                    const mirroredX = Math.round(imageWidth - x - 1);
                    const dstBase = rowBase + x * numChannels;
                    // When the mirrored column is outside the image the pixel
                    // is copied from its own position instead.
                    const inBounds = mirroredX >= 0 && mirroredX < imageWidth;
                    const srcBase = inBounds ? rowBase + mirroredX * numChannels : dstBase;
                    for (let c = 0; c < numChannels; c++) {
                        flipped[dstBase + c] = pixels[srcBase + c];
                    }
                }
            }
        }
        const dataId = cpuBackend.write(flipped, shape, dtype);
        return { dataId, shape, dtype };
    }
};
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * CPU kernel function for `FusedConv2D`: a `conv2D` call optionally followed
 * by a bias add and an activation.
 *
 * Each stage's input tensor is disposed as soon as the next stage has
 * produced its output, so only the final tensor info is returned.
 *
 * @param args Kernel arguments; reads
 *     `args.inputs.{x, filter, bias, preluActivationWeights}`, `args.backend`
 *     and `args.attrs.{strides, pad, dataFormat, dilations, dimRoundingMode,
 *     activation, leakyreluAlpha}`.
 * @return Tensor info of the last stage that ran.
 */
function fusedConv2D(args) {
    const { inputs, backend, attrs } = args;
    const { x, filter, bias, preluActivationWeights } = inputs;
    const { strides, pad, dataFormat, dilations, dimRoundingMode, activation, leakyreluAlpha } = attrs;
    const isChannelsFirst = dataFormat === 'NCHW';
    let output = conv2D({
        inputs: { x, filter },
        backend,
        attrs: { strides, pad, dataFormat, dilations, dimRoundingMode }
    });
    if (bias) {
        const convOutput = output;
        // In NCHW a 1-D, non-scalar bias must line up with the channel axis
        // of the conv result, so it is reshaped to [C, 1, 1] before adding.
        // NHWC and scalar-like biases are added directly.
        const biasNeedsReshape = isChannelsFirst && bias.shape.length === 1 &&
            bias.shape[0] !== 1;
        if (!biasNeedsReshape) {
            output = add({ inputs: { a: convOutput, b: bias }, backend });
        }
        else {
            const channelBias = reshape({
                inputs: { x: bias },
                backend,
                attrs: { shape: [bias.shape[0], 1, 1] }
            });
            output = add({ inputs: { a: convOutput, b: channelBias }, backend });
            backend.disposeIntermediateTensorInfo(channelBias);
        }
        backend.disposeIntermediateTensorInfo(convOutput);
    }
    if (activation) {
        const preActivation = output;
        // Same alignment concern as the bias: 1-D PReLU weights in NCHW are
        // reshaped to [C, 1, 1]; every other case is used as provided.
        const alphaNeedsReshape = isChannelsFirst && activation === 'prelu' &&
            preluActivationWeights.shape.length === 1 &&
            preluActivationWeights.shape[0] !== 1;
        if (!alphaNeedsReshape) {
            output = applyActivation(backend, preActivation, activation, preluActivationWeights, leakyreluAlpha);
        }
        else {
            const channelAlpha = reshape({
                inputs: { x: preluActivationWeights },
                backend,
                attrs: { shape: [preluActivationWeights.shape[0], 1, 1] }
            });
            output = applyActivation(backend, preActivation, activation, channelAlpha, leakyreluAlpha);
            backend.disposeIntermediateTensorInfo(channelAlpha);
        }
        backend.disposeIntermediateTensorInfo(preActivation);
    }
    return output;
}
/** Kernel config pairing the `FusedConv2D` kernel name with the 'cpu' implementation. */
const fusedConv2DConfig = {
    kernelName: FusedConv2D,
    backendName: 'cpu',
    kernelFunc: fusedConv2D
};
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * CPU kernel function for `FusedDepthwiseConv2D`: a `depthwiseConv2dNative`
 * call optionally followed by a bias add and an activation.
 *
 * @param args Kernel arguments; reads
 *     `args.inputs.{x, filter, bias, preluActivationWeights}`, `args.backend`
 *     and `args.attrs.{strides, pad, dataFormat, dilations, dimRoundingMode,
 *     activation, leakyreluAlpha}`.
 * @return Tensor info of the last stage that ran.
 */
function fusedDepthwiseConv2D(args) {
    const { inputs, backend, attrs } = args;
    const { x, filter, bias, preluActivationWeights } = inputs;
    const { strides, pad, dataFormat, dilations, dimRoundingMode, activation, leakyreluAlpha } = attrs;
    let output = depthwiseConv2dNative({
        inputs: { x, filter },
        backend,
        attrs: { strides, pad, dataFormat, dilations, dimRoundingMode }
    });
    // Runs one follow-up stage on the current output, then disposes the
    // tensor that the stage consumed.
    const applyStage = (stage) => {
        const consumed = output;
        output = stage(consumed);
        backend.disposeIntermediateTensorInfo(consumed);
    };
    if (bias) {
        applyStage((t) => add({ inputs: { a: t, b: bias }, backend }));
    }
    if (activation) {
        applyStage((t) => applyActivation(backend, t, activation, preluActivationWeights, leakyreluAlpha));
    }
    return output;
}
/** Kernel config pairing the `FusedDepthwiseConv2D` kernel name with the 'cpu' implementation. */
const fusedDepthwiseConv2DConfig = {
    kernelName: FusedDepthwiseConv2D,
    backendName: 'cpu',
    kernelFunc: fusedDepthwiseConv2D
};
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * CPU kernel function for `GatherNd`.
 *
 * Derives the result shape and slicing parameters through
 * `backend_util.prepareAndValidate`, then delegates the gather itself to
 * `gatherNdImpl`.
 *
 * @param args Kernel arguments; reads `args.inputs.{params, indices}` and
 *     `args.backend`.
 * @return Tensor info with dtype `params.dtype`; built from an empty value
 *     list when there are no slices to gather.
 */
function gatherNd(args) {
    const { backend } = args;
    const { params, indices } = args.inputs;
    const paramsSize = util.sizeFromShape(params.shape);
    // The size of the last axis of `indices` is passed on as the slice rank.
    const sliceRank = indices.shape[indices.shape.length - 1];
    const [resultShape, numSlices, sliceSize, strides] = backend_util.prepareAndValidate(params, indices);
    if (numSlices === 0) {
        return backend.makeTensorInfo(resultShape, params.dtype, []);
    }
    const indexValues = backend.data.get(indices.dataId).values;
    const paramsBuffer = backend.bufferSync(params);
    const gathered = gatherNdImpl(indexValues, paramsBuffer, params.dtype, numSlices, sliceRank, sliceSize, strides, params.shape, paramsSize);
    return backend.makeTensorInfo(resultShape, params.dtype, gathered.values);
}
/** Kernel config pairing the `GatherNd` kernel name with the 'cpu' implementation. */
const gatherNdConfig = {
    kernelName: GatherNd,
    backendName: 'cpu',
    kernelFunc: gatherNd
};
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * CPU kernel function for `GatherV2`.
 *
 * Validates every index against the size of the gather axis, reshapes `x`
 * and `indices` into the flattened layouts described by
 * `collectGatherOpShapeInfo`, and delegates to `gatherV2Impl`.
 *
 * @param args Kernel arguments; reads `args.inputs.{x, indices}`,
 *     `args.backend` and `args.attrs.{axis, batchDims}`.
 * @return Tensor info with shape `shapeInfo.outputShape`.
 */
function gatherV2(args) {
    const { inputs, backend, attrs } = args;
    const { x, indices } = inputs;
    const { axis, batchDims } = attrs;
    assertNotComplex([x, indices], 'gatherV2');
    // Fail fast if any index lies outside the gather axis.
    const parsedAxis = util.parseAxisParam(axis, x.shape)[0];
    const indicesVals = backend.data.get(indices.dataId).values;
    const axisDim = x.shape[parsedAxis];
    for (const index of indicesVals) {
        util.assert(index >= 0 && index <= axisDim - 1, () => `GatherV2: the index value ${index} is not in [0, ${axisDim - 1}]`);
    }
    // A null/undefined batchDims is treated as 0.
    const effectiveBatchDims = batchDims == null ? 0 : batchDims;
    const indicesSize = util.sizeFromShape(indices.shape);
    const shapeInfo = backend_util.segment_util.collectGatherOpShapeInfo(x, indices, parsedAxis, effectiveBatchDims);
    const { batchSize, outerSize, dimSize, sliceSize } = shapeInfo;
    const indicesPerBatch = indicesSize / batchSize;
    const flatX = reshape({
        inputs: { x },
        backend,
        attrs: { shape: [batchSize, outerSize, dimSize, sliceSize] }
    });
    const flatIndices = reshape({
        inputs: { x: indices },
        backend,
        attrs: { shape: [batchSize, indicesPerBatch] }
    });
    const flatOutputShape = [batchSize, outerSize, indicesPerBatch, sliceSize];
    const indicesBuffer = backend.bufferSync(flatIndices);
    const xBuffer = backend.bufferSync(flatX);
    const gathered = gatherV2Impl(xBuffer, indicesBuffer, flatOutputShape);
    backend.disposeIntermediateTensorInfo(flatX);
    backend.disposeIntermediateTensorInfo(flatIndices);
    return backend.makeTensorInfo(shapeInfo.outputShape, gathered.dtype, gathered.values);
}
/** Kernel config pairing the `GatherV2` kernel name with the 'cpu' implementation. */
const gatherV2Config = {
    kernelName: GatherV2,
    backendName: 'cpu',
    kernelFunc: gatherV2
};
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * CPU kernel function for `IFFT`.
 *
 * Collapses every leading dimension of `input` into a single batch axis,
 * hands the resulting 2-D view to `fftBatch` (second argument `true`;
 * presumably the "inverse" flag, since `fft` passes `false`), and reshapes
 * the result back to the shape of `input`.
 *
 * @param args Kernel arguments; reads `args.inputs.input` and `args.backend`.
 * @return Tensor info reshaped to `input.shape`.
 */
function ifft(args) {
    const { backend } = args;
    const tensor = args.inputs.input;
    const { shape } = tensor;
    // The last axis is kept as-is; all preceding axes fold into the batch.
    const lastDim = shape[shape.length - 1];
    const numRows = util.sizeFromShape(shape) / lastDim;
    const flattened = reshape({
        inputs: { x: tensor },
        backend,
        attrs: { shape: [numRows, lastDim] }
    });
    const transformed = fftBatch(flattened, true, backend);
    const output = reshape({
        inputs: { x: transformed },
        backend,
        attrs: { shape }
    });
    // Both intermediates are released once the reshaped output exists.
    for (const intermediate of [flattened, transformed]) {
        backend.disposeIntermediateTensorInfo(intermediate);
    }
    return output;
}
/** Kernel config pairing the `IFFT` kernel name with the 'cpu' implementation. */
const ifftConfig = {
    kernelName: IFFT,
    backendName: 'cpu',
    kernelFunc: ifft
};
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the License);
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an AS IS BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * `IsFinite` kernel function built with `unaryKernelFunc`: the per-element
 * callback yields 1 when `Number.isFinite` accepts the value and 0 otherwise.
 * 'bool' is passed as the trailing argument (presumably the output dtype).
 */
const isFinite = unaryKernelFunc(IsFinite, (xi) => {
    if (Number.isFinite(xi)) {
        return 1;
    }
    return 0;
}, 'bool');
/** Kernel config pairing the `IsFinite` kernel name with the 'cpu' implementation. */
const isFiniteConfig = {
    kernelName: IsFinite,
    backendName: 'cpu',
    kernelFunc: isFinite
};
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the License);
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an AS IS BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * `IsInf` kernel function built with `unaryKernelFunc`: the per-element
 * callback yields 1 when the value's magnitude is infinite (either sign) and
 * 0 otherwise. 'bool' is passed as the trailing argument (presumably the
 * output dtype).
 */
const isInf = unaryKernelFunc(IsInf, (xi) => {
    const magnitude = Math.abs(xi);
    return magnitude === Number.POSITIVE_INFINITY ? 1 : 0;
}, 'bool');
/** Kernel config pairing the `IsInf` kernel name with the 'cpu' implementation. */
const isInfConfig = {
    kernelName: IsInf,
    backendName: 'cpu',
    kernelFunc: isInf
};
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the License);
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an AS IS BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * `IsNan` kernel function built with `unaryKernelFunc`: the per-element
 * callback yields 1 when `Number.isNaN` accepts the value and 0 otherwise.
 * 'bool' is passed as the trailing argument (presumably the output dtype).
 */
const isNaN$1 = unaryKernelFunc(IsNan, (xi) => {
    if (Number.isNaN(xi)) {
        return 1;
    }
    return 0;
}, 'bool');
/** Kernel config pairing the `IsNan` kernel name with the 'cpu' implementation. */
const isNaNConfig = {
    kernelName: IsNan,
    backendName: 'cpu',
    kernelFunc: isNaN$1
};
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * CPU kernel function for `LinSpace`.
 *
 * Delegates value generation to `linSpaceImpl(start, stop, num)` and wraps
 * the returned values in a 1-D 'float32' tensor info whose length equals the
 * number of generated values.
 *
 * @param args Kernel arguments; reads `args.backend` and
 *     `args.attrs.{start, stop, num}`.
 * @return Tensor info created via `backend.makeTensorInfo`.
 */
function linSpace(args) {
    const { backend, attrs: { start, stop, num } } = args;
    const samples = linSpaceImpl(start, stop, num);
    const shape = [samples.length];
    return backend.makeTensorInfo(shape, 'float32', samples);
}
/** Kernel config pairing the `LinSpace` kernel name with the 'cpu' implementation. */
const linSpaceConfig = {
    kernelName: LinSpace,
    backendName: 'cpu',
    kernelFunc: linSpace
};
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the License);
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an AS IS BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * `Log1p` kernel function built with `unaryKernelFunc`: the per-element
 * callback returns `Math.log1p` of the value.
 */
const log1p = unaryKernelFunc(Log1p, (xi) => {
    return Math.log1p(xi);
});
/** Kernel config pairing the `Log1p` kernel name with the 'cpu' implementation. */
const log1pConfig = {
    kernelName: Log1p,
    backendName: 'cpu',
    kernelFunc: log1p
};
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * Element-wise implementation for `LogicalAnd`, built with
 * `createSimpleBinaryKernelImpl`; each output element is `a && b`.
 */
const logicalAndImpl = createSimpleBinaryKernelImpl((a, b) => {
    return a && b;
});
/**
 * `LogicalAnd` kernel function. No complex implementation is supplied
 * (`null`), and 'bool' is passed as the trailing argument (presumably the
 * output dtype).
 */
const logicalAnd = binaryKernelFunc(
    LogicalAnd,
    logicalAndImpl,
    null /* complexImpl */,
    'bool');
/** Kernel config pairing the `LogicalAnd` kernel name with the 'cpu' implementation. */
const logicalAndConfig = {
    kernelName: LogicalAnd,
    backendName: 'cpu',
    kernelFunc: logicalAnd
};
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the License);
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an AS IS BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * `LogicalNot` kernel function built with `unaryKernelFunc`: the per-element
 * callback yields 0 for truthy values and 1 for falsy ones. 'bool' is passed
 * as the trailing argument (presumably the output dtype).
 */
const logicalNot = unaryKernelFunc(LogicalNot, (xi) => {
    if (xi) {
        return 0;
    }
    return 1;
}, 'bool');
/** Kernel config pairing the `LogicalNot` kernel name with the 'cpu' implementation. */
const logicalNotConfig = {
    kernelName: LogicalNot,
    backendName: 'cpu',
    kernelFunc: logicalNot
};
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * Element-wise implementation for `LogicalOr`, built with
 * `createSimpleBinaryKernelImpl`; each output element is `a || b`.
 */
const logicalOrImpl = createSimpleBinaryKernelImpl((a, b) => {
    return a || b;
});
/**
 * `LogicalOr` kernel function. No complex implementation is supplied
 * (`null`), and 'bool' is passed as the trailing argument (presumably the
 * output dtype).
 */
const logicalOr = binaryKernelFunc(
    LogicalOr,
    logicalOrImpl,
    null /* complexImpl */,
    'bool');
/** Kernel config pairing the `LogicalOr` kernel name with the 'cpu' implementation. */
const logicalOrConfig = {
    kernelName: LogicalOr,
    backendName: 'cpu',
    kernelFunc: logicalOr
};
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * CPU kernel function for `LRN`.
 *
 * For each element, sums the squares of the values within `depthRadius`
 * channels on either side (clamped to the channel range of the same pixel)
 * and scales the element by `(bias + alpha * sum) ^ -beta`.
 *
 * The channel count is read from `x.shape[3]`, so a rank-4 input with
 * channels last is assumed.
 *
 * @param args Kernel arguments; reads `args.inputs.x`, `args.backend` and
 *     `args.attrs.{depthRadius, bias, alpha, beta}`.
 * @return Tensor info with the shape and dtype of `x`, backed by a
 *     `Float32Array`.
 */
function lRN(args) {
    const { inputs, backend, attrs } = args;
    const { x } = inputs;
    const { depthRadius, bias, alpha, beta } = attrs;
    assertNotComplex(x, 'LRN');
    const channels = x.shape[3];
    const lastChannel = channels - 1;
    const xValues = backend.data.get(x.dataId).values;
    const size = util.sizeFromShape(x.shape);
    const normalized = new Float32Array(size);
    for (let offset = 0; offset < size; offset++) {
        const channel = offset % channels;
        // Inclusive window [windowStart, windowEnd] of flat offsets covering
        // the neighbouring channels of this element.
        const windowStart = offset - channel + Math.max(0, channel - depthRadius);
        const windowEnd = offset - channel + Math.min(channel + depthRadius, lastChannel);
        let sumOfSquares = 0.0;
        for (let i = windowStart; i <= windowEnd; i++) {
            const v = xValues[i];
            sumOfSquares += v * v;
        }
        normalized[offset] = xValues[offset] * Math.pow(bias + alpha * sumOfSquares, -beta);
    }
    return backend.makeTensorInfo(x.shape, x.dtype, normalized);
}
// tslint:disable-next-line: variable-name
const LRNConfig = {
    kernelName: LRN,
    backendName: 'cpu',
    kernelFunc: lRN
};
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * CPU kernel function for `LRNGrad`.
 *
 * For every element of `dy`, computes the normalisation term over the
 * element's channel window of `x` and accumulates that element's gradient
 * contribution into each position of the window.
 *
 * The channel count is read from `dy.shape[3]`, so a rank-4 input with
 * channels last is assumed.
 *
 * @param args Kernel arguments; reads `args.inputs.{x, y, dy}`,
 *     `args.backend` and `args.attrs.{depthRadius, bias, alpha, beta}`.
 * @return Tensor info with the shape of `dy` and the dtype of `x`, backed by
 *     a `Float32Array`.
 */
function lRNGrad(args) {
    const { inputs, backend, attrs } = args;
    const { x, y, dy } = inputs;
    const { depthRadius, bias, alpha, beta } = attrs;
    assertNotComplex(dy, 'LRNGrad');
    const dySize = util.sizeFromShape(dy.shape);
    const channels = dy.shape[3];
    const dyValues = backend.data.get(dy.dataId).values;
    const xValues = backend.data.get(x.dataId).values;
    const yValues = backend.data.get(y.dataId).values;
    const grads = new Float32Array(dySize);
    for (let offset = 0; offset < dySize; offset++) {
        const channel = offset % channels;
        const pixelStart = offset - channel;
        // Half-open window [windowBegin, windowEnd) of flat offsets.
        const windowBegin = pixelStart + Math.max(0, channel - depthRadius);
        const windowEnd = pixelStart + Math.min(channels, channel + depthRadius + 1);
        let sumOfSquares = 0;
        for (let k = windowBegin; k < windowEnd; k++) {
            sumOfSquares += Math.pow(xValues[k], 2);
        }
        const norm = alpha * sumOfSquares + bias;
        // Extra term that only applies to the element's own position.
        const selfTerm = Math.pow(norm, -beta);
        const upstream = dyValues[offset];
        for (let k = windowBegin; k < windowEnd; k++) {
            let contribution = -2 * alpha * beta * xValues[k] * yValues[offset] / norm;
            if (k === offset) {
                contribution += selfTerm;
            }
            grads[k] += contribution * upstream;
        }
    }
    return backend.makeTensorInfo(dy.shape, x.dtype, grads);
}
// tslint:disable-next-line: variable-name
const LRNGradConfig = {
    kernelName: LRNGrad,
    backendName: 'cpu',
    kernelFunc: lRNGrad
};
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * CPU kernel function for `Max`.
 *
 * When `getAxesPermutation` reports that a permutation is required, the
 * values are transposed first and the reduction axes are replaced by the
 * innermost axes; the reduction itself is done by `maxImpl`.
 *
 * @param args Kernel arguments; reads `args.inputs.x`, `args.backend` and
 *     `args.attrs.{reductionIndices, keepDims}`.
 * @return `{dataId, shape, dtype}` for the reduced tensor. With `keepDims`
 *     the returned shape is the one produced by `expandShapeToKeepDim`.
 */
function max(args) {
    const { inputs, backend, attrs } = args;
    const { x } = inputs;
    const { reductionIndices, keepDims } = attrs;
    const cpuBackend = backend;
    const rank = x.shape.length;
    const requestedAxes = util.parseAxisParam(reductionIndices, x.shape);
    const permutation = backend_util.getAxesPermutation(requestedAxes, rank);
    let values = cpuBackend.data.get(x.dataId).values;
    let workingShape = x.shape;
    let reduceAxes = requestedAxes;
    if (permutation != null) {
        const transposedShape = Array.from({ length: rank }, (_, i) => x.shape[permutation[i]]);
        values = transposeImpl(values, x.shape, x.dtype, permutation, transposedShape);
        reduceAxes = backend_util.getInnerMostAxes(requestedAxes.length, rank);
        workingShape = transposedShape;
    }
    assertNotComplex(x, 'max');
    backend_util.assertAxesAreInnerMostDims('max', reduceAxes, rank);
    const [maxOutShape, reduceShape] = backend_util.computeOutAndReduceShapes(workingShape, reduceAxes);
    const reduceSize = util.sizeFromShape(reduceShape);
    const reduced = maxImpl(values, reduceSize, maxOutShape, x.dtype);
    const dataId = cpuBackend.write(reduced, maxOutShape, x.dtype);
    // The data is written with the reduced shape; only the reported shape
    // changes when the reduced dimensions are kept.
    const shape = keepDims ?
        backend_util.expandShapeToKeepDim(maxOutShape, requestedAxes) :
        maxOutShape;
    return { dataId, shape, dtype: x.dtype };
}
/** Kernel config pairing the `Max` kernel name with the 'cpu' implementation. */
const maxConfig = {
    kernelName: Max,
    backendName: 'cpu',
    kernelFunc: max
};
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * CPU kernel function for `MaxPool`.
 *
 * Dilations are fixed at 1. When the filter is 1x1 and the computed output
 * shape equals the input shape, the kernel returns `identity` of the input;
 * otherwise it runs `pool` in 'max' mode.
 *
 * @param args Kernel arguments; reads `args.inputs.x`, `args.backend` and
 *     `args.attrs.{filterSize, strides, pad, dimRoundingMode}`.
 * @return Tensor info with dtype `x.dtype`.
 */
function maxPool(args) {
    const { inputs, backend, attrs } = args;
    const { x } = inputs;
    assertNotComplex(x, 'maxPool');
    const { filterSize, strides, pad, dimRoundingMode } = attrs;
    const dilations = 1;
    util.assert(backend_util.eitherStridesOrDilationsAreOne(strides, dilations), () => 'Error in maxPool: Either strides or dilations must be 1. ' +
        `Got strides ${strides} and dilations '${dilations}'`);
    const convInfo = backend_util.computePool2DInfo(x.shape, filterSize, strides, dilations, pad, dimRoundingMode);
    const isPassThrough = convInfo.filterWidth === 1 &&
        convInfo.filterHeight === 1 &&
        util.arraysEqual(convInfo.inShape, convInfo.outShape);
    if (isPassThrough) {
        return identity({ inputs: { x }, backend });
    }
    const xValues = backend.data.get(x.dataId).values;
    // Strides of the input tensor's layout, distinct from the pooling strides.
    const xStrides = util.computeStrides(x.shape);
    const pooled = pool(xValues, x.shape, x.dtype, xStrides, convInfo, 'max');
    return backend.makeTensorInfo(convInfo.outShape, x.dtype, pooled.values);
}
/** Kernel config pairing the `MaxPool` kernel name with the 'cpu' implementation. */
const maxPoolConfig = {
    kernelName: MaxPool,
    backendName: 'cpu',
    kernelFunc: maxPool
};
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * CPU kernel function for `MaxPool3D`.
 *
 * Computes the pooling geometry with `computePool3DInfo` (dilations fixed at
 * 1) and runs `pool3d` in 'max' mode over the input values.
 *
 * @param args Kernel arguments; reads `args.inputs.x`, `args.backend` and
 *     `args.attrs.{filterSize, strides, pad, dimRoundingMode, dataFormat}`.
 * @return Tensor info with the shape of the pooled buffer and dtype
 *     'float32'.
 */
function maxPool3D(args) {
    const { inputs, backend, attrs } = args;
    const { x } = inputs;
    const { filterSize, strides, pad, dimRoundingMode, dataFormat } = attrs;
    assertNotComplex(x, 'maxPool3d');
    const dilations = 1;
    const poolInfo = backend_util.computePool3DInfo(x.shape, filterSize, strides, dilations, pad, dimRoundingMode, dataFormat);
    const xValues = backend.data.get(x.dataId).values;
    const xStrides = util.computeStrides(x.shape);
    const pooled = pool3d(xValues, x.shape, x.dtype, xStrides, poolInfo, 'max');
    // NOTE(review): the result dtype is hard-coded to 'float32' although
    // `x.dtype` is what gets passed to pool3d; confirm this is intended for
    // non-float inputs.
    return backend.makeTensorInfo(pooled.shape, 'float32', pooled.values);
}
/** Kernel config pairing the `MaxPool3D` kernel name with the 'cpu' implementation. */
const maxPool3DConfig = {
    kernelName: MaxPool3D,
    backendName: 'cpu',
    kernelFunc: maxPool3D
};
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * CPU kernel function for `MaxPool3DGrad`.
 *
 * For every element of `input`, walks the pooling window positions that could
 * have covered it and accumulates the matching `dy` entries whose recorded
 * max position (from `maxPool3dPositions`) equals the current window
 * position. The accumulated value is stored in a 'float32' buffer shaped like
 * `input`.
 *
 * @param args Kernel arguments: `inputs.dy`, `inputs.input`, `backend`, and
 *     `attrs` carrying `filterSize`, `strides`, `pad` and `dimRoundingMode`.
 * @returns The tensor info created by `backend.makeTensorInfo` for the
 *     gradient buffer.
 */
function maxPool3DGrad(args) {
    const { inputs, backend, attrs } = args;
    const { dy, input } = inputs;
    const { filterSize, strides, pad, dimRoundingMode } = attrs;
    assertNotComplex([dy, input], 'maxPool3DGrad');
    const poolInfo = backend_util.computePool3DInfo(input.shape, filterSize, strides, 1 /* dilations */, pad, dimRoundingMode);
    const argmaxBuf = maxPool3dPositions(backend.bufferSync(input), poolInfo);
    const {
        batchSize, inChannels, inDepth, inHeight, inWidth,
        outDepth, outHeight, outWidth,
        strideDepth, strideHeight, strideWidth,
        dilationDepth, dilationHeight, dilationWidth,
        effectiveFilterDepth: filterD,
        effectiveFilterHeight: filterH,
        effectiveFilterWidth: filterW,
        padInfo
    } = poolInfo;
    const frontOffset = filterD - 1 - padInfo.front;
    const leftOffset = filterW - 1 - padInfo.left;
    const topOffset = filterH - 1 - padInfo.top;
    // Index of the last position inside one (effective) pooling window.
    const lastWindowPos = filterD * filterH * filterW - 1;
    // A dy coordinate is usable only when it is a whole number inside
    // [0, limit); fractional values fall between stride steps.
    const onOutputGrid = (coord, limit) => !(coord < 0 || coord >= limit || Math.floor(coord) !== coord);
    const dx = buffer(input.shape, 'float32');
    const dyBuf = backend.bufferSync(dy);
    for (let b = 0; b < batchSize; ++b) {
        for (let ch = 0; ch < inChannels; ++ch) {
            for (let xD = 0; xD < inDepth; ++xD) {
                for (let xR = 0; xR < inHeight; ++xR) {
                    for (let xC = 0; xC < inWidth; ++xC) {
                        // Shader code begins
                        const cornerD = xD - frontOffset;
                        const cornerR = xR - topOffset;
                        const cornerC = xC - leftOffset;
                        let grad = 0;
                        for (let wD = 0; wD < filterD; wD += dilationDepth) {
                            const yD = (cornerD + wD) / strideDepth;
                            if (!onOutputGrid(yD, outDepth)) {
                                continue;
                            }
                            const depthPart = wD * filterH * filterW;
                            for (let wR = 0; wR < filterH; wR += dilationHeight) {
                                const yR = (cornerR + wR) / strideHeight;
                                if (!onOutputGrid(yR, outHeight)) {
                                    continue;
                                }
                                const rowPart = depthPart + wR * filterW;
                                for (let wC = 0; wC < filterW; wC += dilationWidth) {
                                    const yC = (cornerC + wC) / strideWidth;
                                    if (!onOutputGrid(yC, outWidth)) {
                                        continue;
                                    }
                                    const winnerPos = lastWindowPos - argmaxBuf.get(b, yD, yR, yC, ch);
                                    // Only the window position recorded as the
                                    // max receives gradient.
                                    if (winnerPos !== rowPart + wC) {
                                        continue;
                                    }
                                    grad += dyBuf.get(b, yD, yR, yC, ch);
                                }
                            }
                        }
                        dx.set(grad, b, xD, xR, xC, ch);
                    }
                }
            }
        }
    }
    return backend.makeTensorInfo(dx.shape, dx.dtype, dx.values);
}
/** Registration entry binding the `MaxPool3DGrad` kernel to the 'cpu' backend. */
const maxPool3DGradConfig = {
    kernelName: MaxPool3DGrad,
    backendName: 'cpu',
    kernelFunc: maxPool3DGrad
};
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * CPU kernel function for `MaxPoolGrad`.
 *
 * Recomputes the max positions of the forward pooling via
 * `maxPoolPositions`, then for every element of `input` accumulates the `dy`
 * entries of the windows in which that element was recorded as the maximum.
 * The result is a 'float32' buffer shaped like `input`.
 *
 * @param args Kernel arguments: `inputs.dy`, `inputs.input`,
 *     `inputs.output`, `backend`, and `attrs` carrying `filterSize`,
 *     `strides`, `pad` and `dimRoundingMode`.
 * @returns The tensor info created by `backend.makeTensorInfo` for the
 *     gradient buffer.
 */
function maxPoolGrad(args) {
    const { inputs, backend, attrs } = args;
    const { dy, input, output } = inputs;
    assertNotComplex([input, output], 'maxPoolGrad');
    const { filterSize, strides, pad, dimRoundingMode } = attrs;
    const poolInfo = backend_util.computePool2DInfo(input.shape, filterSize, strides, 1 /* dilations */, pad, dimRoundingMode);
    const inputVals = backend.data.get(input.dataId).values;
    const positions = maxPoolPositions(inputVals, input.shape, input.dtype, poolInfo);
    const argmaxBuf = buffer(poolInfo.outShape, input.dtype, positions.values);
    const {
        batchSize, inChannels, inHeight, inWidth,
        outHeight, outWidth,
        strideHeight, strideWidth,
        dilationHeight, dilationWidth,
        effectiveFilterHeight: filterH,
        effectiveFilterWidth: filterW,
        padInfo
    } = poolInfo;
    const leftOffset = filterW - 1 - padInfo.left;
    const topOffset = filterH - 1 - padInfo.top;
    // Index of the last position inside one (effective) pooling window.
    const lastWindowPos = filterH * filterW - 1;
    // A dy coordinate is usable only when it is a whole number inside
    // [0, limit); fractional values fall between stride steps.
    const onOutputGrid = (coord, limit) => !(coord < 0 || coord >= limit || Math.floor(coord) !== coord);
    const dx = buffer(input.shape, 'float32');
    const dyVals = backend.data.get(dy.dataId).values;
    const dyBuf = buffer(dy.shape, 'float32', dyVals);
    for (let batch = 0; batch < batchSize; ++batch) {
        for (let ch = 0; ch < inChannels; ++ch) {
            for (let xRow = 0; xRow < inHeight; ++xRow) {
                for (let xCol = 0; xCol < inWidth; ++xCol) {
                    // Shader code begins.
                    const cornerRow = xRow - topOffset;
                    const cornerCol = xCol - leftOffset;
                    let grad = 0;
                    for (let wRow = 0; wRow < filterH; wRow += dilationHeight) {
                        const yRow = (cornerRow + wRow) / strideHeight;
                        if (!onOutputGrid(yRow, outHeight)) {
                            continue;
                        }
                        const rowPart = wRow * filterW;
                        for (let wCol = 0; wCol < filterW; wCol += dilationWidth) {
                            const yCol = (cornerCol + wCol) / strideWidth;
                            if (!onOutputGrid(yCol, outWidth)) {
                                continue;
                            }
                            const winnerPos = lastWindowPos - argmaxBuf.get(batch, yRow, yCol, ch);
                            // Only the window position recorded as the max
                            // receives gradient.
                            if (winnerPos !== rowPart + wCol) {
                                continue;
                            }
                            grad += dyBuf.get(batch, yRow, yCol, ch);
                        }
                    }
                    dx.set(grad, batch, xRow, xCol, ch);
                }
            }
        }
    }
    return backend.makeTensorInfo(dx.shape, dx.dtype, dx.values);
}
/** Registration entry binding the `MaxPoolGrad` kernel to the 'cpu' backend. */
const maxPoolGradConfig = {
    kernelName: MaxPoolGrad,
    backendName: 'cpu',
    kernelFunc: maxPoolGrad
};
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * Runs max pooling and the matching max-position lookup over the same input.
 *
 * @param xValues Raw input values.
 * @param xShape Shape of the input.
 * @param dtype Dtype forwarded to `pool` and `maxPoolPositions`.
 * @param includeBatchInIndex Forwarded as the last argument of
 *     `maxPoolPositions`.
 * @param convInfo Pooling geometry forwarded to both helpers.
 * @returns A two-element array: the `values` of the pooled buffer followed by
 *     the `values` of the positions buffer.
 */
function maxPoolWithArgmaxImpl(xValues, xShape, dtype, includeBatchInIndex, convInfo) {
    const xStrides = util.computeStrides(xShape);
    const pooledBuf = pool(xValues, xShape, dtype, xStrides, convInfo, 'max');
    const positionsBuf = maxPoolPositions(xValues, xShape, dtype, convInfo, true, includeBatchInIndex);
    const { values: pooledValues } = pooledBuf;
    const { values: positionValues } = positionsBuf;
    return [pooledValues, positionValues];
}
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * Registration entry binding the `MaxPoolWithArgmax` kernel to the 'cpu'
 * backend.
 *
 * The kernel function computes the pooling geometry from `x.shape`,
 * `filterSize`, `strides` and `pad` (dilations fixed to `[1, 1]`), delegates
 * to `maxPoolWithArgmaxImpl`, and returns two tensor infos that share
 * `convInfo.outShape`: the pooled values (dtype of `x`) followed by the
 * argmax indices (dtype 'int32').
 */
const maxPoolWithArgmaxConfig = {
    kernelName: MaxPoolWithArgmax,
    backendName: 'cpu',
    kernelFunc: ({ inputs, attrs, backend }) => {
        const { x } = inputs;
        const { filterSize, strides, pad, includeBatchInIndex } = attrs;
        const cpuBackend = backend;
        assertNotComplex(x, 'MaxPoolWithArgmax');
        const values = cpuBackend.data.get(x.dataId).values;
        const convInfo = backend_util.computePool2DInfo(x.shape, filterSize, strides, [1, 1], pad);
        const [pooled, indexes] = maxPoolWithArgmaxImpl(values, x.shape, x.dtype, includeBatchInIndex, convInfo);
        const pooledDataId = cpuBackend.write(pooled, convInfo.outShape, x.dtype);
        // The indices tensor is advertised as 'int32' in the returned tensor
        // info, so it is written to the backend with that same dtype (not the
        // dtype of `x`) to keep the stored record and the tensor info in sync.
        const indexesDataId = cpuBackend.write(indexes, convInfo.outShape, 'int32');
        return [
            { dataId: pooledDataId, shape: convInfo.outShape, dtype: x.dtype },
            { dataId: indexesDataId, shape: convInfo.outShape, dtype: 'int32' }
        ];
    }
};
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * CPU kernel function for `Mean`.
 *
 * Casts `x` to 'float32', divides it by the number of elements being reduced
 * (as a scalar tensor), and then sums over `axis`. Dividing before summing
 * means the intermediate sum is of already-scaled values. All intermediates
 * are disposed before returning.
 *
 * @param args Kernel arguments: `inputs.x`, `backend`, and `attrs` carrying
 *     `axis` and `keepDims`.
 * @returns The tensor info returned by `sum`.
 */
function mean(args) {
    const { inputs: { x }, backend, attrs } = args;
    const { axis, keepDims } = attrs;
    const parsedAxes = util.parseAxisParam(axis, x.shape);
    const [, reduceShape] = backend_util.computeOutAndReduceShapes(x.shape, parsedAxes);
    const reduceCount = util.sizeFromShape(reduceShape);
    const divisor = backend.makeTensorInfo([], 'float32', new Float32Array([reduceCount]));
    const xAsFloat = cast({ inputs: { x }, backend, attrs: { dtype: 'float32' } });
    const scaled = div({ inputs: { a: xAsFloat, b: divisor }, backend });
    const result = sum({ inputs: { x: scaled }, backend, attrs: { axis, keepDims } });
    for (const intermediate of [divisor, xAsFloat, scaled]) {
        backend.disposeIntermediateTensorInfo(intermediate);
    }
    return result;
}
/** Registration entry binding the `Mean` kernel to the 'cpu' backend. */
const meanConfig = {
    kernelName: Mean,
    backendName: 'cpu',
    kernelFunc: mean
};
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * CPU kernel function for `Min`.
 *
 * When the reduction axes are not already innermost, `x` is transposed so
 * they are; each output element is then the minimum of one contiguous run of
 * `reduceSize` values. A NaN anywhere in a run makes that run's result NaN.
 *
 * @param args Kernel arguments: `inputs.x`, `backend`, and `attrs` carrying
 *     `axis` and `keepDims`.
 * @returns Tensor info holding the reduced values; reshaped through
 *     `reshape` to keep the reduced dimensions when `keepDims` is truthy.
 */
function min(args) {
    const { inputs, backend, attrs } = args;
    const { x } = inputs;
    const { axis, keepDims } = attrs;
    assertNotComplex(x, 'min');
    const origAxes = util.parseAxisParam(axis, x.shape);
    const perm = backend_util.getAxesPermutation(origAxes, x.shape.length);
    const needsTranspose = perm != null;
    const $x = needsTranspose ?
        transpose({ inputs: { x }, backend, attrs: { perm } }) :
        x;
    const axes = needsTranspose ?
        backend_util.getInnerMostAxes(origAxes.length, x.shape.length) :
        origAxes;
    backend_util.assertAxesAreInnerMostDims('min', axes, $x.shape.length);
    const [outShape, reduceShape] = backend_util.computeOutAndReduceShapes($x.shape, axes);
    const reduceSize = util.sizeFromShape(reduceShape);
    const outVals = util.makeZerosTypedArray(util.sizeFromShape(outShape), $x.dtype);
    const inVals = backend.data.get($x.dataId).values;
    for (let outIdx = 0; outIdx < outVals.length; ++outIdx) {
        const runStart = outIdx * reduceSize;
        const runEnd = runStart + reduceSize;
        let smallest = inVals[runStart];
        for (let k = runStart; k < runEnd; ++k) {
            const candidate = inVals[k];
            // `candidate < smallest` is always false once NaN is involved, so
            // NaN is adopted explicitly and then sticks.
            if (candidate < smallest || Number.isNaN(candidate)) {
                smallest = candidate;
            }
        }
        outVals[outIdx] = smallest;
    }
    if (needsTranspose) {
        backend.disposeIntermediateTensorInfo($x);
    }
    const reduced = backend.makeTensorInfo(outShape, $x.dtype, outVals);
    if (!keepDims) {
        return reduced;
    }
    const keptShape = backend_util.expandShapeToKeepDim(outShape, origAxes);
    const reshaped = reshape({ inputs: { x: reduced }, backend, attrs: { shape: keptShape } });
    backend.disposeIntermediateTensorInfo(reduced);
    return reshaped;
}
/** Registration entry binding the `Min` kernel to the 'cpu' backend. */
const minConfig = {
    kernelName: Min,
    backendName: 'cpu',
    kernelFunc: min
};
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * CPU kernel function for `MirrorPad`.
 *
 * For every output element, coordinates that fall before or after the
 * original data along a dimension are mirrored back into it, then shifted by
 * the leading padding to obtain the source coordinate in `x`. The mirror
 * excludes the border element when `mode` is 'reflect' (offset 0) and
 * includes it otherwise (offset 1).
 *
 * @param args Kernel arguments: `inputs.x`, `backend`, and `attrs` carrying
 *     `paddings` (one `[before, after]` pair per dimension) and `mode`.
 * @returns Tensor info `{ dataId, shape, dtype }` for the padded result.
 */
function mirrorPad(args) {
    const { inputs: { x }, backend, attrs } = args;
    const { paddings, mode } = attrs;
    assertNotComplex(x, 'mirrorPad');
    const outShape = paddings.map(([before, after], dim) => before + x.shape[dim] + after);
    const start = paddings.map(([before]) => before);
    const end = paddings.map(([before], dim) => before + x.shape[dim]);
    const offset = mode === 'reflect' ? 0 : 1;
    const xVals = backend.data.get(x.dataId).values;
    const xRank = x.shape.length;
    const xStrides = util.computeStrides(x.shape);
    const resultSize = util.sizeFromShape(outShape);
    const resultRank = outShape.length;
    const resultStrides = util.computeStrides(outShape);
    const resVals = util.getTypedArrayFromDType(x.dtype, resultSize);
    for (let outIdx = 0; outIdx < resultSize; outIdx++) {
        const outLoc = util.indexToLoc(outIdx, resultRank, resultStrides);
        // Mirror any out-of-range coordinate, then translate into x's frame.
        const srcLoc = outLoc.map((coord, dim) => {
            let mirrored = coord;
            if (coord < start[dim]) {
                mirrored = start[dim] * 2 - coord - offset;
            }
            else if (coord >= end[dim]) {
                mirrored = (end[dim] - 1) * 2 - coord + offset;
            }
            return mirrored - start[dim];
        });
        resVals[outIdx] = xVals[util.locToIndex(srcLoc, xRank, xStrides)];
    }
    const outId = backend.write(resVals, outShape, x.dtype);
    return { dataId: outId, shape: outShape, dtype: x.dtype };
}
/** Registration entry binding the `MirrorPad` kernel to the 'cpu' backend. */
const mirrorPadConfig = {
    kernelName: MirrorPad,
    backendName: 'cpu',
    kernelFunc: mirrorPad
};
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * Element-wise modulo implementation.
 *
 * When both operands are negative or both are non-negative, the native `%`
 * remainder is returned as is; otherwise the remainder is shifted by the
 * divisor and reduced again.
 */
const modImpl = createSimpleBinaryKernelImpl((dividend, divisor) => {
    const remainder = dividend % divisor;
    const bothNegative = dividend < 0 && divisor < 0;
    const bothNonNegative = dividend >= 0 && divisor >= 0;
    return (bothNegative || bothNonNegative) ?
        remainder :
        (remainder + divisor) % divisor;
});
/** Kernel function for `Mod`, built from `modImpl` via `binaryKernelFunc`. */
const mod = binaryKernelFunc(Mod, modImpl);
/** Registration entry binding the `Mod` kernel to the 'cpu' backend. */
const modConfig = {
    kernelName: Mod,
    backendName: 'cpu',
    kernelFunc: mod
};
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * CPU kernel function for `Softmax` along the last dimension.
 *
 * Subtracts the per-row maximum from the logits before exponentiating, then
 * divides by the per-row sum of the exponentials. All intermediates are
 * disposed before returning.
 *
 * @param args Kernel arguments: `inputs.logits`, `backend`, and `attrs`
 *     carrying `dim` (`-1` is treated as the last dimension).
 * @returns The tensor info returned by the final `div`.
 * @throws Error when `dim` does not refer to the last dimension.
 */
function softmax(args) {
    const { inputs: { logits }, backend, attrs } = args;
    const { dim } = attrs;
    const logitsRank = logits.shape.length;
    const lastDim = logitsRank - 1;
    const $dim = dim === -1 ? lastDim : dim;
    if ($dim !== lastDim) {
        throw Error('Softmax along a non-last dimension is not yet supported. ' +
            `Logits was rank ${logitsRank} and dim was ${$dim}`);
    }
    const axes = util.parseAxisParam([$dim], logits.shape);
    const rowMax = max({
        inputs: { x: logits },
        backend,
        attrs: { reductionIndices: axes, keepDims: false }
    });
    // Shape with the reduced axis restored, so the row statistics can be
    // combined with the full-rank tensors below.
    const keptShape = backend_util.expandShapeToKeepDim(rowMax.shape, axes);
    const rowMaxKept = reshape({ inputs: { x: rowMax }, backend, attrs: { shape: keptShape } });
    const shifted = sub({ inputs: { a: logits, b: rowMaxKept }, backend });
    const exps = exp({ inputs: { x: shifted }, backend });
    const rowSum = sum({ inputs: { x: exps }, backend, attrs: { axis: axes, keepDims: false } });
    const rowSumKept = reshape({ inputs: { x: rowSum }, backend, attrs: { shape: keptShape } });
    const result = div({ inputs: { a: exps, b: rowSumKept }, backend });
    for (const intermediate of [rowMax, rowMaxKept, shifted, exps, rowSum, rowSumKept]) {
        backend.disposeIntermediateTensorInfo(intermediate);
    }
    return result;
}
/** Registration entry binding the `Softmax` kernel to the 'cpu' backend. */
const softmaxConfig = {
    kernelName: Softmax,
    backendName: 'cpu',
    kernelFunc: softmax
};
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * CPU kernel function for `Multinomial`.
 *
 * Uses `logits` directly as probabilities when `normalized` is truthy,
 * otherwise runs them through `softmax` first. For each batch row it builds a
 * cumulative distribution over all events except the last, then draws
 * `numSamples` values from a generator seeded with `seed`; a sample maps to
 * the first event whose cumulative probability exceeds the draw, or to the
 * last event when none does. The generator is re-created from the same seed
 * for every batch row.
 *
 * @param args Kernel arguments: `inputs.logits`, `backend`, and `attrs`
 *     carrying `numSamples`, `seed` and `normalized`.
 * @returns 'int32' tensor info of shape `[batchSize, numSamples]`.
 */
function multinomial(args) {
    const { inputs: { logits }, backend, attrs } = args;
    const { numSamples, seed, normalized } = attrs;
    assertNotComplex(logits, 'multinomial');
    const probabilities = normalized ?
        logits :
        softmax({ inputs: { logits }, backend, attrs: { dim: -1 } });
    const [batchSize, numEvents] = probabilities.shape;
    const probVals = backend.data.get(probabilities.dataId).values;
    const resShape = [batchSize, numSamples];
    const resVals = util.makeZerosTypedArray(util.sizeFromShape(resShape), 'int32');
    for (let row = 0; row < batchSize; ++row) {
        const probStart = row * numEvents;
        // The cdf won't include the last event. It will be implicit if no
        // other event happened.
        const cdf = new Float32Array(numEvents - 1);
        cdf[0] = probVals[probStart];
        for (let k = 1; k < cdf.length; ++k) {
            cdf[k] = cdf[k - 1] + probVals[probStart + k];
        }
        const nextRandom = seedrandom.alea(seed.toString());
        const sampleStart = row * numSamples;
        for (let s = 0; s < numSamples; ++s) {
            const draw = nextRandom();
            const firstHit = cdf.findIndex(cumulative => draw < cumulative);
            // No hit means the implicit last event happened.
            resVals[sampleStart + s] = firstHit === -1 ? cdf.length : firstHit;
        }
    }
    if (!normalized) {
        backend.disposeIntermediateTensorInfo(probabilities);
    }
    return backend.makeTensorInfo(resShape, 'int32', resVals);
}
/** Registration entry binding the `Multinomial` kernel to the 'cpu' backend. */
const multinomialConfig = {
    kernelName: Multinomial,
    backendName: 'cpu',
    kernelFunc: multinomial
};
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/** Shared implementation taken from `kernel_impls`. */
const { nonMaxSuppressionV3Impl } = kernel_impls;
/**
 * CPU kernel function for `NonMaxSuppressionV3`.
 *
 * Reads the raw values of `boxes` and `scores`, hands them to
 * `nonMaxSuppressionV3Impl`, and wraps the selected indices in a 1-D 'int32'
 * tensor info.
 *
 * @param args Kernel arguments: `inputs.boxes`, `inputs.scores`, `backend`,
 *     and `attrs` carrying `maxOutputSize`, `iouThreshold` and
 *     `scoreThreshold`.
 * @returns 'int32' tensor info of shape `[selectedIndices.length]`.
 */
function nonMaxSuppressionV3(args) {
    const { inputs: { boxes, scores }, backend, attrs } = args;
    const { maxOutputSize, iouThreshold, scoreThreshold } = attrs;
    assertNotComplex(boxes, 'NonMaxSuppression');
    const { values: boxesVals } = backend.data.get(boxes.dataId);
    const { values: scoresVals } = backend.data.get(scores.dataId);
    const nmsResult = nonMaxSuppressionV3Impl(boxesVals, scoresVals, maxOutputSize, iouThreshold, scoreThreshold);
    const picked = nmsResult.selectedIndices;
    return backend.makeTensorInfo([picked.length], 'int32', new Int32Array(picked));
}
/** Registration entry binding the `NonMaxSuppressionV3` kernel to the 'cpu' backend. */
const nonMaxSuppressionV3Config = {
    kernelName: NonMaxSuppressionV3,
    backendName: 'cpu',
    kernelFunc: nonMaxSuppressionV3
};
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/** Shared implementation taken from `kernel_impls`. */
const { nonMaxSuppressionV4Impl } = kernel_impls;
/**
 * CPU kernel function for `NonMaxSuppressionV4`.
 *
 * Reads the raw values of `boxes` and `scores`, hands them to
 * `nonMaxSuppressionV4Impl` together with `padToMaxOutputSize`, and returns
 * the selected indices plus the count of valid outputs.
 *
 * @param args Kernel arguments: `inputs.boxes`, `inputs.scores`, `backend`,
 *     and `attrs` carrying `maxOutputSize`, `iouThreshold`, `scoreThreshold`
 *     and `padToMaxOutputSize`.
 * @returns A two-element array: 'int32' tensor info of shape
 *     `[selectedIndices.length]`, then a scalar 'int32' tensor info holding
 *     `validOutputs`.
 */
function nonMaxSuppressionV4(args) {
    const { inputs: { boxes, scores }, backend, attrs } = args;
    const { maxOutputSize, iouThreshold, scoreThreshold, padToMaxOutputSize } = attrs;
    assertNotComplex(boxes, 'NonMaxSuppressionPadded');
    const { values: boxesVals } = backend.data.get(boxes.dataId);
    const { values: scoresVals } = backend.data.get(scores.dataId);
    const nmsResult = nonMaxSuppressionV4Impl(boxesVals, scoresVals, maxOutputSize, iouThreshold, scoreThreshold, padToMaxOutputSize);
    const { selectedIndices, validOutputs } = nmsResult;
    const indicesInfo = backend.makeTensorInfo([selectedIndices.length], 'int32', new Int32Array(selectedIndices));
    const validCountInfo = backend.makeTensorInfo([], 'int32', new Int32Array([validOutputs]));
    return [indicesInfo, validCountInfo];
}
/** Registration entry binding the `NonMaxSuppressionV4` kernel to the 'cpu' backend. */
const nonMaxSuppressionV4Config = {
    kernelName: NonMaxSuppressionV4,
    backendName: 'cpu',
    kernelFunc: nonMaxSuppressionV4
};
|
|
/**
|
* @license
|
* Copyright 2019 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/** Shared implementation taken from `kernel_impls`. */
const { nonMaxSuppressionV5Impl } = kernel_impls;
/**
 * CPU kernel function for `NonMaxSuppressionV5`.
 *
 * Reads the raw values of `boxes` and `scores`, hands them to
 * `nonMaxSuppressionV5Impl` together with `softNmsSigma`, and returns the
 * selected indices and the selected scores.
 *
 * @param args Kernel arguments: `inputs.boxes`, `inputs.scores`, `backend`,
 *     and `attrs` carrying `maxOutputSize`, `iouThreshold`, `scoreThreshold`
 *     and `softNmsSigma`.
 * @returns A two-element array: 'int32' tensor info of shape
 *     `[selectedIndices.length]`, then 'float32' tensor info of shape
 *     `[selectedScores.length]`.
 */
function nonMaxSuppressionV5(args) {
    const { inputs: { boxes, scores }, backend, attrs } = args;
    const { maxOutputSize, iouThreshold, scoreThreshold, softNmsSigma } = attrs;
    assertNotComplex(boxes, 'NonMaxSuppressionWithScore');
    const { values: boxesVals } = backend.data.get(boxes.dataId);
    const { values: scoresVals } = backend.data.get(scores.dataId);
    const nmsResult = nonMaxSuppressionV5Impl(boxesVals, scoresVals, maxOutputSize, iouThreshold, scoreThreshold, softNmsSigma);
    const { selectedIndices, selectedScores } = nmsResult;
    const indicesInfo = backend.makeTensorInfo([selectedIndices.length], 'int32', new Int32Array(selectedIndices));
    const scoresInfo = backend.makeTensorInfo([selectedScores.length], 'float32', new Float32Array(selectedScores));
    return [indicesInfo, scoresInfo];
}
/** Registration entry binding the `NonMaxSuppressionV5` kernel to the 'cpu' backend. */
const nonMaxSuppressionV5Config = {
    kernelName: NonMaxSuppressionV5,
    backendName: 'cpu',
    kernelFunc: nonMaxSuppressionV5
};
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * CPU kernel function for `OneHot`.
 *
 * Allocates a Float32Array with `depth` slots per index, pre-filled with
 * `offValue`, and writes `onValue` into the slot selected by each index.
 * Indices outside `[0, depth)` leave their row entirely at `offValue`.
 *
 * @param args Kernel arguments: `inputs.indices`, `backend`, and `attrs`
 *     carrying `dtype`, `depth`, `onValue` and `offValue`.
 * @returns Tensor info of shape `[...indices.shape, depth]` tagged with
 *     `dtype`.
 */
function oneHot(args) {
    const { inputs: { indices }, backend, attrs } = args;
    const { dtype, depth, onValue, offValue } = attrs;
    assertNotComplex(indices, 'oneHot');
    const rowCount = util.sizeFromShape(indices.shape);
    const encoded = new Float32Array(rowCount * depth).fill(offValue);
    const { values: indexVals } = backend.data.get(indices.dataId);
    for (let row = 0; row < rowCount; ++row) {
        const hotSlot = indexVals[row];
        const inRange = hotSlot >= 0 && hotSlot < depth;
        if (inRange) {
            encoded[row * depth + hotSlot] = onValue;
        }
    }
    return backend.makeTensorInfo([...indices.shape, depth], dtype, encoded);
}
/** Registration entry binding the `OneHot` kernel to the 'cpu' backend. */
const oneHotConfig = {
    kernelName: OneHot,
    backendName: 'cpu',
    kernelFunc: oneHot
};
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * CPU kernel function for `ZerosLike`.
 *
 * For 'complex64' input, the real and imaginary parts are each replaced by
 * zeros (recursively) and recombined with `complex`; for every other
 * non-string dtype the result comes from `fill` with value 0.
 *
 * @param args Kernel arguments: `inputs.x` and `backend`.
 * @returns Tensor info with the shape and dtype of `x`.
 * @throws Error when `x.dtype` is 'string'.
 */
function zerosLike(args) {
    const { inputs: { x }, backend } = args;
    if (x.dtype === 'string') {
        throw new Error('zerosLike is not supported for string tensors');
    }
    if (x.dtype !== 'complex64') {
        return fill({ backend, attrs: { shape: x.shape, value: 0, dtype: x.dtype } });
    }
    const realPart = real({ inputs: { input: x }, backend });
    const zeroReal = zerosLike({ inputs: { x: realPart }, backend });
    const imagPart = imag({ inputs: { input: x }, backend });
    const zeroImag = zerosLike({ inputs: { x: imagPart }, backend });
    const result = complex({ inputs: { real: zeroReal, imag: zeroImag }, backend });
    for (const intermediate of [realPart, zeroReal, imagPart, zeroImag]) {
        backend.disposeIntermediateTensorInfo(intermediate);
    }
    return result;
}
/** Registration entry binding the `ZerosLike` kernel to the 'cpu' backend. */
const zerosLikeConfig = {
    kernelName: ZerosLike,
    backendName: 'cpu',
    kernelFunc: zerosLike
};
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * CPU kernel function for `OnesLike`.
 *
 * For 'complex64' input, the real part is replaced by ones (recursively) and
 * the imaginary part by zeros, then both are recombined with `complex`; for
 * every other non-string dtype the result comes from `fill` with value 1.
 *
 * @param args Kernel arguments: `inputs.x` and `backend`.
 * @returns Tensor info with the shape and dtype of `x`.
 * @throws Error when `x.dtype` is 'string'.
 */
function onesLike(args) {
    const { inputs: { x }, backend } = args;
    if (x.dtype === 'string') {
        throw new Error('onesLike is not supported for string tensors');
    }
    if (x.dtype !== 'complex64') {
        return fill({ backend, attrs: { shape: x.shape, value: 1, dtype: x.dtype } });
    }
    const realPart = real({ inputs: { input: x }, backend });
    const oneReal = onesLike({ inputs: { x: realPart }, backend });
    const imagPart = imag({ inputs: { input: x }, backend });
    const zeroImag = zerosLike({ inputs: { x: imagPart }, backend });
    const result = complex({ inputs: { real: oneReal, imag: zeroImag }, backend });
    for (const intermediate of [realPart, oneReal, imagPart, zeroImag]) {
        backend.disposeIntermediateTensorInfo(intermediate);
    }
    return result;
}
/** Registration entry binding the `OnesLike` kernel to the 'cpu' backend. */
const onesLikeConfig = {
    kernelName: OnesLike,
    backendName: 'cpu',
    kernelFunc: onesLike
};
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
function pack(args) {
|
const { inputs, backend, attrs } = args;
|
const { axis } = attrs;
|
if (inputs.length === 1) {
|
return expandDims({ inputs: { input: inputs[0] }, backend, attrs: { dim: axis } });
|
}
|
const shape = inputs[0].shape;
|
const dtype = inputs[0].dtype;
|
inputs.forEach(t => {
|
util.assertShapesMatch(shape, t.shape, 'All tensors passed to stack must have matching shapes');
|
util.assert(dtype === t.dtype, () => 'All tensors passed to stack must have matching dtypes');
|
});
|
const intermediateTensorInfos = [];
|
const expandedTensors = inputs.map(t => {
|
const expandedT = expandDims({ inputs: { input: t }, backend, attrs: { dim: axis } });
|
intermediateTensorInfos.push(expandedT);
|
return expandedT;
|
});
|
const result = concat({ inputs: expandedTensors, backend, attrs: { axis } });
|
intermediateTensorInfos.forEach(t => backend.disposeIntermediateTensorInfo(t));
|
return result;
|
}
|
/** Kernel config: names `pack` as the kernelFunc of the Pack kernel with backendName 'cpu'. */
const packConfig = {
    kernelName: Pack,
    backendName: 'cpu',
    kernelFunc: pack
};
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * Pads tensor `x` with a constant value.
 *
 * `attrs.paddings[i]` is a `[before, after]` pair for dimension `i`; the
 * output size along that dimension is `before + x.shape[i] + after`. The
 * output buffer is filled with `attrs.constantValue` (skipped when it is 0,
 * presumably because the freshly allocated buffer is already zeroed) and each
 * input element is then copied to its shifted location.
 *
 * @param args `{inputs: {x}, backend, attrs: {paddings, constantValue}}`.
 * @returns `{dataId, shape, dtype}` describing the padded tensor.
 */
function padV2(args) {
    const { inputs: { x }, backend, attrs: { paddings, constantValue } } = args;
    assertNotComplex(x, 'pad');
    const inShape = x.shape;
    const dtype = x.dtype;
    const outShape = paddings.map((pad, axis) => pad[0] /* beforePad */ + inShape[axis] + pad[1] /* afterPad */);
    const leading = paddings.map((pad) => pad[0]);
    const inVals = backend.data.get(x.dataId).values;
    const inSize = util.sizeFromShape(inShape);
    const inRank = inShape.length;
    const inStrides = util.computeStrides(inShape);
    const outSize = util.sizeFromShape(outShape);
    const outRank = outShape.length;
    const outStrides = util.computeStrides(outShape);
    const outVals = util.getTypedArrayFromDType(dtype, outSize);
    if (constantValue !== 0) {
        outVals.fill(constantValue);
    }
    for (let flat = 0; flat < inSize; flat++) {
        // Shift the source coordinate by the leading padding of every axis.
        const srcLoc = util.indexToLoc(flat, inRank, inStrides);
        const dstLoc = srcLoc.map((coord, axis) => coord + leading[axis]);
        const dstFlat = util.locToIndex(dstLoc, outRank, outStrides);
        outVals[dstFlat] = inVals[flat];
    }
    const dataId = backend.write(outVals, outShape, dtype);
    return { dataId, shape: outShape, dtype };
}
|
/** Kernel config: names `padV2` as the kernelFunc of the PadV2 kernel with backendName 'cpu'. */
const padV2Config = {
    kernelName: PadV2,
    backendName: 'cpu',
    kernelFunc: padV2
};
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/** Binary implementation built from the scalar op `Math.pow(base, exponent)`. */
const powImpl = createSimpleBinaryKernelImpl((base, exponent) => Math.pow(base, exponent));
/** Kernel function for Pow, produced by `binaryKernelFunc` from `powImpl`. */
const pow = binaryKernelFunc(Pow, powImpl);
/** Kernel config: names `pow` as the kernelFunc of the Pow kernel with backendName 'cpu'. */
const powConfig = {
    kernelName: Pow,
    backendName: 'cpu',
    kernelFunc: pow
};
|
|
/**
|
* @license
|
* Copyright 2022 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * CPU kernel wrapper for RaggedGather.
 *
 * Reads the raw values of every input from the backend, delegates to
 * `raggedGatherImpl`, and wraps the results as tensor infos.
 *
 * @param args `{inputs: {paramsNestedSplits, paramsDenseValues, indices},
 *     backend}`; `paramsNestedSplits` is an array of tensor infos.
 * @returns An array holding one int32 1-D tensor info per output nested
 *     split, followed by the dense-values tensor info (dtype of
 *     `paramsDenseValues`).
 */
function raggedGather(args) {
    const { inputs, backend } = args;
    const { paramsNestedSplits, paramsDenseValues, indices } = inputs;
    const valuesOf = (tensor) => backend.data.get(tensor.dataId).values;
    const splitsValues = paramsNestedSplits.map((tensor) => valuesOf(tensor));
    const splitsShapes = paramsNestedSplits.map((tensor) => tensor.shape);
    const denseValues = valuesOf(paramsDenseValues);
    const indicesValues = valuesOf(indices);
    const [outSplits, outDense, outDenseShape] = raggedGatherImpl(splitsValues, splitsShapes, denseValues, paramsDenseValues.shape, paramsDenseValues.dtype, indicesValues, indices.shape);
    const splitTensors = outSplits.map((splits) => backend.makeTensorInfo([splits.length], 'int32', splits));
    const denseTensor = backend.makeTensorInfo(outDenseShape, paramsDenseValues.dtype, outDense);
    return [...splitTensors, denseTensor];
}
|
/** Kernel config: names `raggedGather` as the kernelFunc of the RaggedGather kernel with backendName 'cpu'. */
const raggedGatherConfig = {
    kernelName: RaggedGather,
    backendName: 'cpu',
    kernelFunc: raggedGather,
};
|
|
/**
|
* @license
|
* Copyright 2022 Google LLC.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * CPU kernel wrapper for RaggedRange.
 *
 * Reads the raw values of `starts`, `limits` and `deltas` from the backend,
 * delegates to `raggedRangeImpl`, and wraps both results as 1-D tensor infos.
 *
 * @param args `{inputs: {starts, limits, deltas}, backend}`.
 * @returns `[nestedSplits, denseValues]`: an int32 tensor info followed by a
 *     tensor info with the dtype of `starts`.
 */
function raggedRange(args) {
    const { inputs: { starts, limits, deltas }, backend } = args;
    const valuesOf = (tensor) => backend.data.get(tensor.dataId).values;
    const startsValues = valuesOf(starts);
    const limitsValues = valuesOf(limits);
    const deltasValues = valuesOf(deltas);
    const [splitsData, denseData] = raggedRangeImpl(startsValues, starts.shape, starts.dtype, limitsValues, limits.shape, deltasValues, deltas.shape);
    return [
        backend.makeTensorInfo([splitsData.length], 'int32', splitsData),
        backend.makeTensorInfo([denseData.length], starts.dtype, denseData),
    ];
}
|
/** Kernel config: names `raggedRange` as the kernelFunc of the RaggedRange kernel with backendName 'cpu'. */
const raggedRangeConfig = {
    kernelName: RaggedRange,
    backendName: 'cpu',
    kernelFunc: raggedRange,
};
|
|
/**
|
* @license
|
* Copyright 2022 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * CPU kernel wrapper for RaggedTensorToTensor.
 *
 * Reads the raw values of all inputs from the backend, delegates to
 * `raggedTensorToTensorImpl`, and wraps the result as a tensor info.
 *
 * @param args `{inputs: {shape, values, defaultValue, rowPartitionTensors},
 *     backend, attrs: {rowPartitionTypes}}`; `rowPartitionTensors` is an
 *     array of tensor infos.
 * @returns A tensor info with the shape computed by the impl and the dtype of
 *     `values`.
 */
function raggedTensorToTensor(args) {
    const { inputs, backend, attrs: { rowPartitionTypes } } = args;
    const { shape, values, defaultValue, rowPartitionTensors } = inputs;
    const valuesOf = (tensor) => backend.data.get(tensor.dataId).values;
    const shapeData = valuesOf(shape);
    const valuesData = valuesOf(values);
    const defaultData = valuesOf(defaultValue);
    const partitionData = rowPartitionTensors.map((tensor) => valuesOf(tensor));
    const partitionShapes = rowPartitionTensors.map((tensor) => tensor.shape);
    const [outShape, outData] = raggedTensorToTensorImpl(shapeData, shape.shape, valuesData, values.shape, values.dtype, defaultData, defaultValue.shape, partitionData, partitionShapes, rowPartitionTypes);
    return backend.makeTensorInfo(outShape, values.dtype, outData);
}
|
/** Kernel config: names `raggedTensorToTensor` as the kernelFunc of the RaggedTensorToTensor kernel with backendName 'cpu'. */
const raggedTensorToTensorConfig = {
    kernelName: RaggedTensorToTensor,
    backendName: 'cpu',
    kernelFunc: raggedTensorToTensor,
};
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * CPU kernel wrapper for Range.
 *
 * Delegates to `rangeImpl(start, stop, step, dtype)` and wraps the returned
 * values as a 1-D tensor info whose length equals the number of values.
 *
 * @param args `{backend, attrs: {start, stop, dtype, step}}`.
 * @returns A tensor info of shape `[values.length]` and the requested dtype.
 */
function range(args) {
    const { backend, attrs: { start, stop, dtype, step } } = args;
    const sequence = rangeImpl(start, stop, step, dtype);
    const outShape = [sequence.length];
    return backend.makeTensorInfo(outShape, dtype, sequence);
}
|
/** Kernel config: names `range` as the kernelFunc of the Range kernel with backendName 'cpu'. */
const rangeConfig = {
    kernelName: Range,
    backendName: 'cpu',
    kernelFunc: range
};
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/** Kernel function for Reciprocal, built by `unaryKernelFunc` from the scalar op `1 / value`. */
const reciprocal = unaryKernelFunc(Reciprocal, (value) => 1 / value);
/** Kernel config: names `reciprocal` as the kernelFunc of the Reciprocal kernel with backendName 'cpu'. */
const reciprocalConfig = {
    kernelName: Reciprocal,
    backendName: 'cpu',
    kernelFunc: reciprocal,
};
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * Resizes a batch of images with bilinear interpolation.
 *
 * `images.shape` is destructured as `[batch, height, width, channels]` and
 * `attrs.size` as `[newHeight, newWidth]`. For each output pixel the
 * fractional source coordinate is computed, the four surrounding source
 * pixels are read (clamped to the image bounds) and blended first along the
 * columns and then along the rows.
 *
 * @param args `{inputs: {images}, backend, attrs: {alignCorners,
 *     halfPixelCenters, size}}`.
 * @returns A float32 tensor info of shape
 *     `[batch, newHeight, newWidth, channels]`.
 */
function resizeBilinear(args) {
    const { inputs: { images }, backend, attrs: { alignCorners, halfPixelCenters, size } } = args;
    assertNotComplex(images, 'resizeBilinear');
    const [batchStride, rowStride, colStride] = util.computeStrides(images.shape);
    const [newHeight, newWidth] = size;
    const [batch, oldHeight, oldWidth, numChannels] = images.shape;
    const xValues = backend.data.get(images.dataId).values;
    const outShape = [batch, newHeight, newWidth, numChannels];
    const result = new Float32Array(util.sizeFromShape(outShape));
    // With alignCorners the first and last pixels of input and output coincide,
    // so the scale is computed over (size - 1) whenever the output has more
    // than one pixel along that dimension.
    const alignRows = alignCorners && newHeight > 1;
    const alignCols = alignCorners && newWidth > 1;
    const rowScale = (alignRows ? oldHeight - 1 : oldHeight) / (alignRows ? newHeight - 1 : newHeight);
    const colScale = (alignCols ? oldWidth - 1 : oldWidth) / (alignCols ? newWidth - 1 : newWidth);
    // Maps an output index to a fractional source coordinate.
    const toSource = (scale, outIndex) => halfPixelCenters ?
        scale * (outIndex + 0.5) - 0.5 :
        scale * outIndex;
    let writeIdx = 0;
    for (let b = 0; b < batch; b++) {
        const batchOffset = b * batchStride;
        for (let r = 0; r < newHeight; r++) {
            const srcRow = toSource(rowScale, r);
            const rowLo = Math.max(0, Math.floor(srcRow));
            const rowLerp = srcRow - rowLo;
            const rowHi = Math.min(oldHeight - 1, Math.ceil(srcRow));
            const topOffset = batchOffset + rowLo * rowStride;
            const bottomOffset = batchOffset + rowHi * rowStride;
            for (let c = 0; c < newWidth; c++) {
                const srcCol = toSource(colScale, c);
                const colLo = Math.max(0, Math.floor(srcCol));
                const colLerp = srcCol - colLo;
                const colHi = Math.min(oldWidth - 1, Math.ceil(srcCol));
                const topLeftOffset = topOffset + colLo * colStride;
                const bottomLeftOffset = bottomOffset + colLo * colStride;
                const topRightOffset = topOffset + colHi * colStride;
                const bottomRightOffset = bottomOffset + colHi * colStride;
                for (let d = 0; d < numChannels; d++) {
                    const topLeft = xValues[topLeftOffset + d];
                    const bottomLeft = xValues[bottomLeftOffset + d];
                    const topRight = xValues[topRightOffset + d];
                    const bottomRight = xValues[bottomRightOffset + d];
                    // Blend horizontally on both rows, then vertically.
                    const top = topLeft + (topRight - topLeft) * colLerp;
                    const bottom = bottomLeft + (bottomRight - bottomLeft) * colLerp;
                    result[writeIdx++] = top + (bottom - top) * rowLerp;
                }
            }
        }
    }
    return backend.makeTensorInfo(outShape, 'float32', result);
}
|
/** Kernel config: names `resizeBilinear` as the kernelFunc of the ResizeBilinear kernel with backendName 'cpu'. */
const resizeBilinearConfig = {
    kernelName: ResizeBilinear,
    backendName: 'cpu',
    kernelFunc: resizeBilinear
};
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * Gradient of bilinear resize with respect to the input images.
 *
 * `images.shape` is destructured as `[batch, xHeight, xWidth, depth]` and
 * `dy.shape` as `[_, yHeight, yWidth]`. Every element of `dy` is distributed
 * to the (up to) four source pixels it was interpolated from, weighted by the
 * same interpolation coefficients as the forward pass.
 *
 * Note: `halfPixelCenters` is not read by this kernel.
 *
 * @param args `{inputs: {images, dy}, backend, attrs: {alignCorners}}`.
 * @returns A float32 tensor info of shape `[batch, xHeight, xWidth, depth]`,
 *     i.e. the same layout as `images`.
 */
function resizeBilinearGrad(args) {
    const { inputs, backend, attrs } = args;
    const { images, dy } = inputs;
    const { alignCorners } = attrs;
    assertNotComplex([dy, images], 'resizeBilinearGrad');
    const imagesStrides = util.computeStrides(images.shape);
    const [batch, xHeight, xWidth, depth] = images.shape;
    const [, yHeight, yWidth] = dy.shape;
    const output = new Float32Array(batch * xHeight * xWidth * depth);
    // In the backward pass, for each pixel of dy we find the input pixels it
    // was generated from in the forward pass and add the corresponding
    // coefficient from dy to their gradient (with some interpolation).
    const effectiveXSize = [
        (alignCorners && yHeight > 1) ? xHeight - 1 : xHeight,
        (alignCorners && yWidth > 1) ? xWidth - 1 : xWidth
    ];
    const effectiveYSize = [
        (alignCorners && yHeight > 1) ? yHeight - 1 : yHeight,
        (alignCorners && yWidth > 1) ? yWidth - 1 : yWidth
    ];
    const heightScale = effectiveXSize[0] / effectiveYSize[0];
    const widthScale = effectiveXSize[1] / effectiveYSize[1];
    // Reference implementation
    // tslint:disable-next-line:max-line-length
    // https://github.com/tensorflow/tensorflow/blob/3039375c86a5bbc9610c7725dcaa95d635f87ba2/tensorflow/core/kernels/resize_bilinear_op.cc#L275
    const dyValues = backend.data.get(dy.dataId).values;
    // dy is consumed sequentially in [batch, yHeight, yWidth, depth] order.
    let offset = 0;
    for (let b = 0; b < batch; b++) {
        const bOffset = b * imagesStrides[0];
        for (let r = 0; r < yHeight; r++) {
            const dxR = r * heightScale;
            const topDxRIndex = Math.floor(dxR);
            const bottomDxRIndex = Math.min(Math.ceil(dxR), xHeight - 1);
            const topDxROffset = bOffset + topDxRIndex * imagesStrides[1];
            const bottomDxROffset = bOffset + bottomDxRIndex * imagesStrides[1];
            const dxRLerp = dxR - topDxRIndex;
            const inverseDxRLerp = 1.0 - dxRLerp;
            for (let c = 0; c < yWidth; c++) {
                const dxC = c * widthScale;
                const leftDxCIndex = Math.floor(dxC);
                const rightDxCIndex = Math.min(Math.ceil(dxC), xWidth - 1);
                const dxCLerp = dxC - leftDxCIndex;
                const inverseDxCLerp = 1.0 - dxCLerp;
                const topLeftRCOffset = topDxROffset + leftDxCIndex * imagesStrides[2];
                const topRightRCOffset = topDxROffset + rightDxCIndex * imagesStrides[2];
                const bottomLeftRCOffset = bottomDxROffset + leftDxCIndex * imagesStrides[2];
                const bottomRightRCOffset = bottomDxROffset + rightDxCIndex * imagesStrides[2];
                const inverseDxRLerpTimesInverseDxCLerp = inverseDxRLerp * inverseDxCLerp;
                const inverseDxRLerpTimesDxCLerp = inverseDxRLerp * dxCLerp;
                const dxRLerpTimesInverseDxCLerp = dxRLerp * inverseDxCLerp;
                const dxRLerpTimesDxCLerp = dxRLerp * dxCLerp;
                for (let d = 0; d < depth; d++) {
                    const dyVal = dyValues[offset++];
                    output[topLeftRCOffset + d] +=
                        dyVal * inverseDxRLerpTimesInverseDxCLerp;
                    output[topRightRCOffset + d] += dyVal * inverseDxRLerpTimesDxCLerp;
                    output[bottomLeftRCOffset + d] += dyVal * dxRLerpTimesInverseDxCLerp;
                    output[bottomRightRCOffset + d] += dyVal * dxRLerpTimesDxCLerp;
                }
            }
        }
    }
    // The buffer was written with the strides of images.shape, so the result
    // must be reported as [batch, height, width, depth] (not width-first).
    return backend.makeTensorInfo([batch, xHeight, xWidth, depth], 'float32', output);
}
|
/** Kernel config: names `resizeBilinearGrad` as the kernelFunc of the ResizeBilinearGrad kernel with backendName 'cpu'. */
const resizeBilinearGradConfig = {
    kernelName: ResizeBilinearGrad,
    backendName: 'cpu',
    kernelFunc: resizeBilinearGrad
};
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * Resizes a batch of images with nearest-neighbor sampling.
 *
 * `images.shape` is destructured as `[batch, height, width, channels]` and
 * `attrs.size` as `[newHeight, newWidth]`. Each output pixel copies the
 * source pixel whose index is the rounded (alignCorners) or floored
 * fractional source coordinate, clamped to the image bounds.
 *
 * The output buffer is allocated for `images.dtype`, so the stored values
 * match the dtype reported on the returned tensor info (a Float32Array would
 * lose precision for large int32 values).
 *
 * @param args `{inputs: {images}, backend, attrs: {alignCorners,
 *     halfPixelCenters, size}}`.
 * @returns A tensor info of shape `[batch, newHeight, newWidth, channels]`
 *     with the dtype of `images`.
 */
function resizeNearestNeighbor(args) {
    const { inputs, backend, attrs } = args;
    const { images } = inputs;
    const { alignCorners, halfPixelCenters, size } = attrs;
    assertNotComplex(images, 'resizeNearestNeighbor');
    const imagesStrides = util.computeStrides(images.shape);
    const [newHeight, newWidth] = size;
    const [batch, oldHeight, oldWidth, numChannels] = images.shape;
    const xValues = backend.data.get(images.dataId).values;
    const output = util.getTypedArrayFromDType(images.dtype, batch * newHeight * newWidth * numChannels);
    // With alignCorners the first and last pixels of input and output coincide,
    // so the scale is computed over (size - 1) whenever the output has more
    // than one pixel along that dimension.
    const effectiveInputSize = [
        (alignCorners && newHeight > 1) ? oldHeight - 1 : oldHeight,
        (alignCorners && newWidth > 1) ? oldWidth - 1 : oldWidth
    ];
    const effectiveOutputSize = [
        (alignCorners && newHeight > 1) ? newHeight - 1 : newHeight,
        (alignCorners && newWidth > 1) ? newWidth - 1 : newWidth
    ];
    const effectiveRowSizeRatio = effectiveInputSize[0] / effectiveOutputSize[0];
    const effectiveColSizeRatio = effectiveInputSize[1] / effectiveOutputSize[1];
    let outputOffset = 0;
    for (let b = 0; b < batch; b++) {
        const batchOffset = b * imagesStrides[0];
        for (let r = 0; r < newHeight; r++) {
            const sourceFracRow = halfPixelCenters ?
                effectiveRowSizeRatio * (r + 0.5) :
                effectiveRowSizeRatio * r;
            let sourceNearestRow = Math.min(oldHeight - 1, alignCorners ? Math.round(sourceFracRow) : Math.floor(sourceFracRow));
            if (halfPixelCenters) {
                sourceNearestRow = Math.max(0, sourceNearestRow);
            }
            const rowOffset = batchOffset + sourceNearestRow * imagesStrides[1];
            for (let c = 0; c < newWidth; c++) {
                const sourceFracCol = halfPixelCenters ?
                    effectiveColSizeRatio * (c + 0.5) :
                    effectiveColSizeRatio * c;
                let sourceNearestCol = Math.min(oldWidth - 1, alignCorners ? Math.round(sourceFracCol) :
                    Math.floor(sourceFracCol));
                if (halfPixelCenters) {
                    sourceNearestCol = Math.max(0, sourceNearestCol);
                }
                const colOffset = rowOffset + sourceNearestCol * imagesStrides[2];
                // Copy every channel of the selected source pixel.
                for (let d = 0; d < numChannels; d++) {
                    output[outputOffset++] = xValues[colOffset + d];
                }
            }
        }
    }
    return backend.makeTensorInfo([batch, newHeight, newWidth, numChannels], images.dtype, output);
}
|
/** Kernel config: names `resizeNearestNeighbor` as the kernelFunc of the ResizeNearestNeighbor kernel with backendName 'cpu'. */
const resizeNearestNeighborConfig = {
    kernelName: ResizeNearestNeighbor,
    backendName: 'cpu',
    kernelFunc: resizeNearestNeighbor
};
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * Gradient of nearest-neighbor resize with respect to the input images.
 *
 * `images.shape` is destructured as `[batch, xHeight, xWidth, depth]` and
 * `dy.shape` as `[_, yHeight, yWidth]`. For every input pixel, a window of
 * `dy` around the corresponding position is scanned and the `dy` values
 * whose nearest source pixel is that input pixel are summed.
 *
 * `dy` is indexed with its own strides (including the batch stride), since
 * its spatial size generally differs from that of `images`.
 *
 * @param args `{inputs: {images, dy}, backend, attrs: {alignCorners}}`.
 * @returns A tensor info with the shape and dtype of `images`, backed by a
 *     Float32Array.
 */
function resizeNearestNeighborGrad(args) {
    const { inputs, backend, attrs } = args;
    const { images, dy } = inputs;
    const { alignCorners } = attrs;
    assertNotComplex([dy, images], 'resizeNearestNeighborGrad');
    const imagesStrides = util.computeStrides(images.shape);
    const dyStrides = util.computeStrides(dy.shape);
    const [batch, xHeight, xWidth, depth] = images.shape;
    const [, yHeight, yWidth] = dy.shape;
    const output = new Float32Array(batch * xHeight * xWidth * depth);
    const dyValues = backend.data.get(dy.dataId).values;
    // In the backward pass, for each pixel of the input image we want to find
    // the dy pixels that were generated from it in the forward pass.
    const effectiveXSize = [
        (alignCorners && yHeight > 1) ? xHeight - 1 : xHeight,
        (alignCorners && yWidth > 1) ? xWidth - 1 : xWidth
    ];
    const effectiveYSize = [
        (alignCorners && yHeight > 1) ? yHeight - 1 : yHeight,
        (alignCorners && yWidth > 1) ? yWidth - 1 : yWidth
    ];
    const heightScale = effectiveXSize[0] / effectiveYSize[0];
    const widthScale = effectiveXSize[1] / effectiveYSize[1];
    const invHeightScale = 1 / heightScale;
    const invWidthScale = 1 / widthScale;
    // This defines the size of the window of values around a particular
    // index in dy that we want to search for contributions to dx.
    const winHeight = (Math.ceil(invHeightScale) * 2) + 2;
    const winWidth = (Math.ceil(invWidthScale) * 2) + 2;
    // Loop over the output space.
    for (let b = 0; b < batch; b++) {
        const batchOffset = b * imagesStrides[0];
        // dy has its own batch stride; the images stride would address the
        // wrong elements whenever dy and images differ in spatial size.
        const dyBatchOffset = b * dyStrides[0];
        for (let r = 0; r < xHeight; r++) {
            const rowOffset = batchOffset + r * imagesStrides[1];
            // Compute bounds for where in dy we will look
            const startRLerp = Math.floor(r * invHeightScale);
            const startDyR = Math.floor(startRLerp - (winHeight / 2));
            for (let c = 0; c < xWidth; c++) {
                const colOffset = rowOffset + c * imagesStrides[2];
                // Compute bounds for where in dy we will look
                const startCLerp = Math.floor(c * invWidthScale);
                const startDyC = Math.floor(startCLerp - (winWidth / 2));
                for (let d = 0; d < depth; d++) {
                    let accum = 0;
                    // loop over dy
                    for (let dyRIndex = 0; dyRIndex < winHeight; dyRIndex++) {
                        const dyR = dyRIndex + startDyR;
                        // Guard against the window exceeding the bounds of dy
                        if (dyR < 0 || dyR >= yHeight) {
                            continue;
                        }
                        const dyROffset = dyBatchOffset + dyR * dyStrides[1];
                        const sourceFracRow = dyR * heightScale;
                        const sourceNearestRow = Math.min(xHeight - 1, alignCorners ? Math.round(sourceFracRow) :
                            Math.floor(sourceFracRow));
                        if (r !== sourceNearestRow) {
                            continue;
                        }
                        for (let dyCIndex = 0; dyCIndex < winWidth; dyCIndex++) {
                            const dyC = dyCIndex + startDyC;
                            // Guard against the window exceeding the bounds of dy
                            if (dyC < 0 || dyC >= yWidth) {
                                continue;
                            }
                            const dyCOffset = dyROffset + dyC * dyStrides[2];
                            const sourceFracCol = dyC * widthScale;
                            const sourceNearestCol = Math.min(xWidth - 1, alignCorners ? Math.round(sourceFracCol) :
                                Math.floor(sourceFracCol));
                            if (c === sourceNearestCol) {
                                accum += dyValues[dyCOffset + d];
                            }
                        }
                    }
                    output[colOffset + d] = accum;
                }
            }
        }
    }
    return backend.makeTensorInfo(images.shape, images.dtype, output);
}
|
/** Kernel config: names `resizeNearestNeighborGrad` as the kernelFunc of the ResizeNearestNeighborGrad kernel with backendName 'cpu'. */
const resizeNearestNeighborGradConfig = {
    kernelName: ResizeNearestNeighborGrad,
    backendName: 'cpu',
    kernelFunc: resizeNearestNeighborGrad
};
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * Reverses tensor `x` along the dimensions listed in `attrs.dims`.
 *
 * The axes are normalised with `util.parseAxisParam`. A rank-0 input is
 * forwarded to `identity`. Otherwise each output location reads the input
 * location whose coordinate is mirrored (`size - 1 - coord`) on every
 * requested axis.
 *
 * @param args `{inputs: {x}, backend, attrs: {dims}}`.
 * @returns A tensor info with the shape and dtype of `x`.
 */
function reverse(args) {
    const { inputs: { x }, backend, attrs: { dims } } = args;
    assertNotComplex(x, 'reverse');
    const rank = x.shape.length;
    const axes = util.parseAxisParam(dims, x.shape);
    if (rank === 0) {
        return identity({ inputs: { x }, backend });
    }
    const dst = new TensorBuffer(x.shape, x.dtype);
    const src = backend.bufferSync(x);
    for (let flat = 0; flat < dst.size; flat++) {
        const dstLoc = dst.indexToLoc(flat);
        const srcLoc = dstLoc.slice();
        // Mirror the coordinate on each reversed axis.
        for (const axis of axes) {
            srcLoc[axis] = x.shape[axis] - 1 - srcLoc[axis];
        }
        dst.set(src.get(...srcLoc), ...dstLoc);
    }
    return backend.makeTensorInfo(dst.shape, dst.dtype, dst.values);
}
|
/** Kernel config: names `reverse` as the kernelFunc of the Reverse kernel with backendName 'cpu'. */
const reverseConfig = {
    kernelName: Reverse,
    backendName: 'cpu',
    kernelFunc: reverse
};
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * Kernel config for RotateWithOffset with backendName 'cpu'.
 *
 * The kernel function rotates every image of the batch by `attrs.radians`
 * around the point returned by `backend_util.getImageCenter`. For each output
 * pixel the rotated source coordinate is rounded to the nearest integer; if
 * it falls inside the image the source pixel is copied, otherwise the fill
 * value is written. A numeric `fillValue` applies to every channel; otherwise
 * `fillValue[channel]` is used, except for channel 3, which is set to 255.
 * `image.shape` is destructured as `[batch, height, width, channels]`.
 */
const rotateWithOffsetConfig = {
    kernelName: RotateWithOffset,
    backendName: 'cpu',
    kernelFunc: ({ inputs, attrs, backend }) => {
        const { image } = inputs;
        const { radians, fillValue, center } = attrs;
        const output = util.getTypedArrayFromDType(image.dtype, util.sizeFromShape(image.shape));
        const [batch, imageHeight, imageWidth, numChannels] = image.shape;
        const [centerX, centerY] = backend_util.getImageCenter(center, imageHeight, imageWidth);
        const fullOpacityValue = 255;
        const sinFactor = Math.sin(radians);
        const cosFactor = Math.cos(radians);
        const imageVals = backend.data.get(image.dataId).values;
        const rowStride = imageWidth * numChannels;
        const fillIsNumber = typeof fillValue === 'number';
        for (let batchIdx = 0; batchIdx < batch; batchIdx++) {
            const batchOffset = batchIdx * imageWidth * imageHeight * numChannels;
            for (let row = 0; row < imageHeight; row++) {
                const relY = row - centerY;
                const dstRowBase = batchOffset + row * rowStride;
                for (let col = 0; col < imageWidth; col++) {
                    const relX = col - centerX;
                    // Rotate around the center, translate back, and snap to the
                    // nearest source pixel. This depends only on (row, col).
                    const srcX = Math.round(relX * cosFactor - relY * sinFactor + centerX);
                    const srcY = Math.round(relX * sinFactor + relY * cosFactor + centerY);
                    const insideImage = srcX >= 0 && srcX < imageWidth && srcY >= 0 &&
                        srcY < imageHeight;
                    const srcBase = batchOffset + srcY * rowStride + srcX * numChannels;
                    const dstBase = dstRowBase + col * numChannels;
                    for (let channel = 0; channel < numChannels; channel++) {
                        let outputValue = fillValue;
                        if (!fillIsNumber) {
                            outputValue = channel === 3 ? fullOpacityValue : fillValue[channel];
                        }
                        // If the coordinate position falls within the image
                        // boundaries, take the image value at that position.
                        if (insideImage) {
                            outputValue = imageVals[srcBase + channel];
                        }
                        output[dstBase + channel] = outputValue;
                    }
                }
            }
        }
        const dataId = backend.write(output, image.shape, image.dtype);
        return { dataId, shape: image.shape, dtype: image.dtype };
    }
};
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * Kernel function for Round, built by `unaryKernelFunc`.
 *
 * The scalar op rounds to the nearest integer; exact halves go to the even
 * neighbour (banker's rounding), e.g. 2.5 -> 2 and 3.5 -> 4.
 */
const round = unaryKernelFunc(Round, (xi) => {
    const lower = Math.floor(xi);
    const fraction = xi - lower;
    if (fraction < 0.5) {
        return Math.floor(xi);
    }
    if (fraction > 0.5) {
        return Math.ceil(xi);
    }
    // Exactly halfway (or not comparable): keep `lower` when it is even,
    // otherwise step up to the next integer.
    return lower % 2.0 === 0.0 ? lower : lower + 1.0;
});
|
/** Kernel config: names `round` as the kernelFunc of the Round kernel with backendName 'cpu'. */
const roundConfig = {
    kernelName: Round,
    backendName: 'cpu',
    kernelFunc: round,
};
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * CPU kernel wrapper for ScatterNd.
 *
 * Derives the slicing parameters with `backend_util.calculateShapes`, reads
 * `indices` and `updates` through `backend.bufferSync`, and delegates to
 * `scatterImpl` with a default value of 0 and the sum-duplicate-indices flag
 * set to true.
 *
 * @param args `{inputs: {indices, updates}, backend, attrs: {shape}}`.
 * @returns A tensor info of shape `attrs.shape` with the dtype and values of
 *     the buffer returned by `scatterImpl`.
 */
function scatterNd(args) {
    const { inputs: { indices, updates }, backend, attrs: { shape } } = args;
    const { sliceRank, numUpdates, sliceSize, strides, outputSize } = backend_util.calculateShapes(updates, indices, shape);
    const defaultValue = 0;
    const sumDupeIndices = true;
    const indicesBuf = backend.bufferSync(indices);
    const updatesBuf = backend.bufferSync(updates);
    const scattered = scatterImpl(indicesBuf, updatesBuf, shape, outputSize, sliceSize, numUpdates, sliceRank, strides, defaultValue, sumDupeIndices);
    return backend.makeTensorInfo(shape, scattered.dtype, scattered.values);
}
|
const scatterNdConfig = {
|
kernelName: ScatterNd,
|
backendName: 'cpu',
|
kernelFunc: scatterNd
|
};
|
|
/**
|
* @license
|
* Copyright 2022 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * Binary search over `array` (expected to be sorted ascending) for the first
 * position whose element is not less than `value`.
 *
 * @param array Indexable sequence with a `length`.
 * @param value Value to locate.
 * @return Index in [0, array.length]; array.length when every element is
 *     less than `value`.
 */
function lowerBound(array, value) {
    let lo = 0;
    let hi = array.length;
    while (hi > lo) {
        const probe = lo + Math.floor((hi - lo) / 2);
        if (value > array[probe]) {
            // Everything up to and including probe is too small.
            lo = probe + 1;
            continue;
        }
        hi = probe;
    }
    return hi;
}
|
/**
 * Binary search over `array` (expected to be sorted ascending) for the first
 * position whose element is strictly greater than `value`.
 *
 * @param array Indexable sequence with a `length`.
 * @param value Value to locate.
 * @return Index in [0, array.length]; array.length when no element exceeds
 *     `value`.
 */
function upperBound(array, value) {
    let lo = 0;
    let hi = array.length;
    while (hi > lo) {
        const probe = lo + Math.floor((hi - lo) / 2);
        if (value >= array[probe]) {
            // probe still holds a value that is not greater; move past it.
            lo = probe + 1;
            continue;
        }
        hi = probe;
    }
    return hi;
}
|
/**
 * Batched sorted search. For each of `batchSize` rows, looks up `numValues`
 * query values inside that row's `numInputs` sorted entries.
 *
 * @param sortedInputs Flat storage holding batchSize * numInputs entries.
 * @param values Flat storage holding batchSize * numValues queries.
 * @param batchSize Number of rows.
 * @param numInputs Sorted entries per row.
 * @param numValues Queries per row.
 * @param side 'left' selects lowerBound; any other value selects upperBound.
 * @return int32 array of batchSize * numValues positions.
 */
function searchSortedImpl(sortedInputs, values, batchSize, numInputs, numValues, side) {
    const positions = util.getArrayFromDType('int32', batchSize * numValues);
    // The comparison rule is the same for the whole call, so choose it once.
    const findBound = side === 'left' ? lowerBound : upperBound;
    for (let batch = 0; batch < batchSize; ++batch) {
        const rowStart = batch * numInputs;
        const sortedRow = sortedInputs.slice(rowStart, rowStart + numInputs);
        const base = batch * numValues;
        for (let k = 0; k < numValues; ++k) {
            positions[base + k] = findBound(sortedRow, values[base + k]);
        }
    }
    return positions;
}
|
|
/**
|
* @license
|
* Copyright 2022 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * SearchSorted kernel function for the 'cpu' backend.
 *
 * Reads the backing values of `sortedSequence` and `values` and forwards them
 * to searchSortedImpl, using sortedSequence.shape[0] as the batch size,
 * sortedSequence.shape[1] as the entries per row and values.shape[1] as the
 * queries per row.
 *
 * @param args Kernel arguments: `inputs` ({sortedSequence, values}),
 *     `backend` and `attrs` ({side}).
 * @return int32 tensor info with the shape of `values`.
 */
function searchSorted(args) {
    const { sortedSequence, values } = args.inputs;
    const cpu = args.backend;
    const sortedData = cpu.data.get(sortedSequence.dataId).values;
    const queryData = cpu.data.get(values.dataId).values;
    const batchSize = sortedSequence.shape[0];
    const entriesPerRow = sortedSequence.shape[1];
    const queriesPerRow = values.shape[1];
    const positions = searchSortedImpl(sortedData, queryData, batchSize, entriesPerRow, queriesPerRow, args.attrs.side);
    return cpu.makeTensorInfo(values.shape, 'int32', positions);
}
/** Kernel registration record for SearchSorted on the 'cpu' backend. */
const searchSortedConfig = {
    kernelName: SearchSorted,
    backendName: 'cpu',
    kernelFunc: searchSorted,
};
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * Select kernel function for the 'cpu' backend: builds a tensor shaped like
 * `t` whose elements come from `t` where `condition` holds 1 and from `e`
 * otherwise. The result dtype is upcastType(t.dtype, e.dtype).
 *
 * When `condition` is rank 1 and `t` has a higher rank, each condition entry
 * governs `offset` consecutive elements (one slice along t's first axis).
 * In every other case one condition entry governs exactly one element.
 *
 * @param args Kernel arguments: `inputs` ({condition, t, e}) and `backend`.
 * @return Tensor info with the shape of `t`.
 * @throws Whatever assertNotComplex raises for the given inputs.
 */
function select(args) {
    const { inputs, backend } = args;
    const { condition, t, e } = inputs;
    assertNotComplex([condition, t, e], 'select');
    const conditionRank = condition.shape.length;
    const values = backend.data.get(condition.dataId).values;
    const tValues = backend.data.get(t.dataId).values;
    const eValues = backend.data.get(e.dataId).values;
    const resultDtype = upcastType(t.dtype, e.dtype);
    const newValues = util.makeZerosTypedArray(util.sizeFromShape(t.shape), resultDtype);
    let index = 0;
    // Number of consecutive output elements controlled by one condition entry.
    const offset = conditionRank === 0 || conditionRank > 1 || t.shape.length === 1 ?
        1 :
        util.sizeFromShape(t.shape.slice(1));
    for (let i = 0; i < values.length; i++) {
        const source = values[i] === 1 ? tValues : eValues;
        for (let j = 0; j < offset; j++) {
            // Read at the output position rather than the condition position:
            // the two coincide when offset is 1, but for a slice-wise select
            // each element of the slice must come from its own location.
            newValues[index] = source[index];
            index++;
        }
    }
    return backend.makeTensorInfo(t.shape, resultDtype, newValues);
}
/** Kernel registration record for Select on the 'cpu' backend. */
const selectConfig = {
    kernelName: Select,
    backendName: 'cpu',
    kernelFunc: select
};
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the License);
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an AS IS BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
// SELU constants taken from backend_util.
const scaleAlpha = backend_util.SELU_SCALEALPHA;
const scale = backend_util.SELU_SCALE;
/**
 * Selu kernel function built through unaryKernelFunc. Per element:
 * scale * x when x >= 0, otherwise scaleAlpha * (exp(x) - 1).
 */
const selu = unaryKernelFunc(Selu, (xi) => (xi >= 0 ? scale * xi : scaleAlpha * (Math.exp(xi) - 1)));
/** Kernel registration record for Selu on the 'cpu' backend. */
const seluConfig = {
    kernelName: Selu,
    backendName: 'cpu',
    kernelFunc: selu,
};
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the License);
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an AS IS BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * Sign kernel function built through unaryKernelFunc. Per element: 1 for
 * values greater than zero, -1 for values less than zero, and 0 for anything
 * else.
 */
const sign = unaryKernelFunc(Sign, (xi) => {
    if (xi > 0) {
        return 1;
    }
    return xi < 0 ? -1 : 0;
});
/** Kernel registration record for Sign on the 'cpu' backend. */
const signConfig = {
    kernelName: Sign,
    backendName: 'cpu',
    kernelFunc: sign,
};
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the License);
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an AS IS BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/** Sin kernel function built through unaryKernelFunc; applies Math.sin per element. */
const sin = unaryKernelFunc(Sin, (xi) => {
    return Math.sin(xi);
});
/** Kernel registration record for Sin on the 'cpu' backend. */
const sinConfig = {
    kernelName: Sin,
    backendName: 'cpu',
    kernelFunc: sin,
};
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the License);
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an AS IS BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/** Sinh kernel function built through unaryKernelFunc; applies Math.sinh per element. */
const sinh = unaryKernelFunc(Sinh, (xi) => {
    return Math.sinh(xi);
});
/** Kernel registration record for Sinh on the 'cpu' backend. */
const sinhConfig = {
    kernelName: Sinh,
    backendName: 'cpu',
    kernelFunc: sinh,
};
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the License);
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an AS IS BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
// Follows the implementation of tf.nn.softplus: https://goo.gl/vkcvwX
// `epsilon` is the gap between 1.0 and the next representable float; for a
// single precision 32 bit float this should be 2^-23, see:
// https://math.byu.edu/~schow/work/IEEEFloatingPoint.htm
const epsilon = 1.1920928955078125e-7;
const threshold = Math.log(epsilon) + 2.0;
/**
 * Softplus kernel function built through unaryKernelFunc. Per element it
 * evaluates log(1 + exp(x)), short-circuiting at both ends of the range.
 */
const softplus = unaryKernelFunc(Softplus, (xi) => {
    if (xi < threshold) {
        // Below this value exp(x) may underflow, but softplus(x) == exp(x)
        // is within machine epsilon.
        return Math.exp(xi);
    }
    if (xi > -threshold) {
        // Above this value exp(x) may overflow, but softplus(x) == x
        // is within machine epsilon.
        return xi;
    }
    return Math.log(1.0 + Math.exp(xi));
});
/** Kernel registration record for Softplus on the 'cpu' backend. */
const softplusConfig = {
    kernelName: Softplus,
    backendName: 'cpu',
    kernelFunc: softplus,
};
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * SpaceToBatchND kernel function for the 'cpu' backend.
 *
 * Pads `x` with zeros (via the PadV2 kernel), then applies a
 * reshape -> transpose -> reshape sequence whose shapes and permutation come
 * from backend_util.getReshaped / getPermuted / getReshapedPermuted. The
 * three intermediate tensors are disposed before returning.
 *
 * @param args Kernel arguments: `inputs` ({x}), `backend` and `attrs`
 *     ({blockShape, paddings}).
 * @return Tensor info of the final reshape.
 * @throws Whatever assertNotComplex raises for `x`.
 */
function spaceToBatchND(args) {
    const { inputs: { x }, backend, attrs: { blockShape, paddings } } = args;
    assertNotComplex([x], 'spaceToBatchND');
    const blockSize = util.sizeFromShape(blockShape);
    // First dimension gets no padding, then the caller-supplied paddings,
    // then no padding for every dimension after the block dimensions.
    const fullPaddings = [[0, 0], ...paddings];
    const untouchedDims = x.shape.length - 1 - blockShape.length;
    for (let k = 0; k < untouchedDims; ++k) {
        fullPaddings.push([0, 0]);
    }
    const padded = padV2Config.kernelFunc({
        inputs: { x },
        backend,
        attrs: { paddings: fullPaddings, constantValue: 0 }
    });
    const blockedShape = backend_util.getReshaped(padded.shape, blockShape, blockSize, false);
    const permutation = backend_util.getPermuted(blockedShape.length, blockShape.length, false);
    const outputShape = backend_util.getReshapedPermuted(padded.shape, blockShape, blockSize, false);
    const blocked = reshape({ inputs: { x: padded }, backend, attrs: { shape: blockedShape } });
    const permuted = transpose({ inputs: { x: blocked }, backend, attrs: { perm: permutation } });
    const output = reshape({ inputs: { x: permuted }, backend, attrs: { shape: outputShape } });
    for (const intermediate of [padded, blocked, permuted]) {
        backend.disposeIntermediateTensorInfo(intermediate);
    }
    return output;
}
/** Kernel registration record for SpaceToBatchND on the 'cpu' backend. */
const spaceToBatchNDConfig = {
    kernelName: SpaceToBatchND,
    backendName: 'cpu',
    kernelFunc: spaceToBatchND,
};
|
|
/**
|
* @license
|
* Copyright 2021 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * SparseFillEmptyRows kernel function for the 'cpu' backend.
 *
 * Validates the ranks of the four inputs, reads their backing values and
 * delegates the computation to sparseFillEmptyRowsImpl.
 *
 * @param args Kernel arguments: `inputs` ({indices, values, denseShape,
 *     defaultValue}) and `backend`.
 * @return Array of four tensor infos, in order: output indices (dtype of
 *     `indices`), output values (dtype of `values`), the empty-row indicator
 *     ('bool') and the reverse index map (dtype of `indices`).
 * @throws {Error} If `denseShape` is not rank 1, `indices` is not rank 2,
 *     `values` is not rank 1 or `defaultValue` is not rank 0.
 */
function sparseFillEmptyRows(args) {
    const { inputs, backend } = args;
    const { indices, values, denseShape, defaultValue } = inputs;
    // Error messages are kept on one line so that no line break or source
    // indentation ends up inside the text shown to the user.
    if (denseShape.shape.length !== 1) {
        throw new Error(`Dense shape must be a vector, saw: ${denseShape.shape}`);
    }
    if (indices.shape.length !== 2) {
        throw new Error(`Indices must be a matrix, saw: ${indices.shape}`);
    }
    if (values.shape.length !== 1) {
        throw new Error(`Values must be a vector, saw: ${values.shape}`);
    }
    if (defaultValue.shape.length !== 0) {
        throw new Error(`Default value must be a scalar, saw: ${defaultValue.shape}`);
    }
    const $indices = backend.data.get(indices.dataId).values;
    const $values = backend.data.get(values.dataId).values;
    const $denseShape = backend.data.get(denseShape.dataId).values;
    const $defaultValue = backend.data.get(defaultValue.dataId).values[0];
    const [outputIndices, outputIndicesShape, outputValues, emptyRowIndicator, reverseIndexMap] = sparseFillEmptyRowsImpl($indices, indices.shape, indices.dtype, $values, values.dtype, $denseShape, $defaultValue);
    return [
        backend.makeTensorInfo(outputIndicesShape, indices.dtype, outputIndices),
        backend.makeTensorInfo([outputIndicesShape[0]], values.dtype, outputValues),
        // The indicator entries are converted to numbers for Uint8Array storage.
        backend.makeTensorInfo([emptyRowIndicator.length], 'bool', new Uint8Array(emptyRowIndicator.map((value) => Number(value)))),
        backend.makeTensorInfo([reverseIndexMap.length], indices.dtype, new Int32Array(reverseIndexMap)),
    ];
}
|
/** Kernel registration record for SparseFillEmptyRows on the 'cpu' backend. */
const sparseFillEmptyRowsConfig = {
    backendName: 'cpu',
    kernelName: SparseFillEmptyRows,
    kernelFunc: sparseFillEmptyRows
};
|
|
/**
|
* @license
|
* Copyright 2021 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * SparseReshape kernel function for the 'cpu' backend.
 *
 * Validates the ranks of the inputs, reads their backing values and delegates
 * the computation to sparseReshapeImpl.
 *
 * @param args Kernel arguments: `inputs` ({inputIndices, inputShape,
 *     newShape}) and `backend`.
 * @return Array of two tensor infos, in order: the new indices (dtype of
 *     `inputIndices`) and the output shape stored as an Int32Array (dtype of
 *     `newShape`).
 * @throws {Error} If `inputIndices` is not rank 2, or `inputShape` or
 *     `newShape` is not rank 1.
 */
function sparseReshape(args) {
    const { inputs, backend } = args;
    const { inputIndices, inputShape, newShape } = inputs;
    // All three messages are single-line so that no line break or source
    // indentation ends up inside the text shown to the user.
    if (inputIndices.shape.length !== 2) {
        throw new Error(`Input indices should be a matrix but received shape ${inputIndices.shape}`);
    }
    if (inputShape.shape.length !== 1) {
        throw new Error(`Input shape should be a vector but received shape ${inputShape.shape}`);
    }
    if (newShape.shape.length !== 1) {
        throw new Error(`Target shape should be a vector but received shape ${newShape.shape}`);
    }
    const $inputShape = Array.from(backend.data.get(inputShape.dataId).values);
    const $inputIndices = backend.data.get(inputIndices.dataId).values;
    const targetShape = Array.from(backend.data.get(newShape.dataId).values);
    const [newIndices, indicesShape, outputShape] = sparseReshapeImpl($inputIndices, inputIndices.shape, inputIndices.dtype, $inputShape, targetShape);
    return [
        backend.makeTensorInfo(indicesShape, inputIndices.dtype, newIndices),
        backend.makeTensorInfo([outputShape.length], newShape.dtype, new Int32Array(outputShape)),
    ];
}
|
/** Kernel registration record for SparseReshape on the 'cpu' backend. */
const sparseReshapeConfig = {
    backendName: 'cpu',
    kernelName: SparseReshape,
    kernelFunc: sparseReshape
};
|
|
/**
|
* @license
|
* Copyright 2021 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * SparseSegmentMean kernel function for the 'cpu' backend.
 *
 * Validates the input ranks and sizes, reads the backing values and delegates
 * to sparseSegmentReductionImpl with its final argument set to true.
 *
 * @param args Kernel arguments: `inputs` ({data, indices, segmentIds}) and
 *     `backend`.
 * @return Tensor info with the shape reported by the impl and the dtype of
 *     `data`.
 * @throws {Error} If `data` is a scalar, `indices` or `segmentIds` is not
 *     rank 1, or `indices` and `segmentIds` differ in length.
 */
function sparseSegmentMean(args) {
    const { inputs, backend } = args;
    const { data, indices, segmentIds } = inputs;
    if (data.shape.length < 1) {
        throw new Error(`Data should be at least 1 dimensional but received scalar`);
    }
    // Messages are kept on one line so that no line break or source
    // indentation ends up inside the text shown to the user.
    if (indices.shape.length !== 1) {
        throw new Error(`Indices should be a vector but received shape ${indices.shape}`);
    }
    if (segmentIds.shape.length !== 1) {
        throw new Error(`Segment ids should be a vector but received shape ${segmentIds.shape}`);
    }
    if (indices.shape[0] !== segmentIds.shape[0]) {
        throw new Error(`segmentIds and indices should have same size.`);
    }
    const $data = backend.data.get(data.dataId).values;
    const $indices = backend.data.get(indices.dataId).values;
    const $segmentIds = backend.data.get(segmentIds.dataId).values;
    const [outputData, outputDataShape] = sparseSegmentReductionImpl($data, data.shape, data.dtype, $indices, $segmentIds, true);
    return backend.makeTensorInfo(outputDataShape, data.dtype, outputData);
}
|
/** Kernel registration record for SparseSegmentMean on the 'cpu' backend. */
const sparseSegmentMeanConfig = {
    backendName: 'cpu',
    kernelName: SparseSegmentMean,
    kernelFunc: sparseSegmentMean
};
|
|
/**
|
* @license
|
* Copyright 2021 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * SparseSegmentSum kernel function for the 'cpu' backend.
 *
 * Validates the input ranks and sizes, reads the backing values and delegates
 * to sparseSegmentReductionImpl, leaving its optional trailing argument
 * unset.
 *
 * @param args Kernel arguments: `inputs` ({data, indices, segmentIds}) and
 *     `backend`.
 * @return Tensor info with the shape reported by the impl and the dtype of
 *     `data`.
 * @throws {Error} If `data` is a scalar, `indices` or `segmentIds` is not
 *     rank 1, or `indices` and `segmentIds` differ in length.
 */
function sparseSegmentSum(args) {
    const { inputs, backend } = args;
    const { data, indices, segmentIds } = inputs;
    if (data.shape.length < 1) {
        throw new Error(`Data should be at least 1 dimensional but received scalar`);
    }
    // Messages are kept on one line so that no line break or source
    // indentation ends up inside the text shown to the user.
    if (indices.shape.length !== 1) {
        throw new Error(`Indices should be a vector but received shape ${indices.shape}`);
    }
    if (segmentIds.shape.length !== 1) {
        throw new Error(`Segment ids should be a vector but received shape ${segmentIds.shape}`);
    }
    if (indices.shape[0] !== segmentIds.shape[0]) {
        throw new Error(`segmentIds and indices should have same size.`);
    }
    const $data = backend.data.get(data.dataId).values;
    const $indices = backend.data.get(indices.dataId).values;
    const $segmentIds = backend.data.get(segmentIds.dataId).values;
    const [outputData, outputDataShape] = sparseSegmentReductionImpl($data, data.shape, data.dtype, $indices, $segmentIds);
    return backend.makeTensorInfo(outputDataShape, data.dtype, outputData);
}
|
/** Kernel registration record for SparseSegmentSum on the 'cpu' backend. */
const sparseSegmentSumConfig = {
    backendName: 'cpu',
    kernelName: SparseSegmentSum,
    kernelFunc: sparseSegmentSum
};
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * SparseToDense kernel function for the 'cpu' backend.
 *
 * Computes the scatter geometry with backend_util.calculateShapes and hands
 * the work to scatterImpl with duplicate-index summing turned off. The first
 * element of `defaultValue` is used as the fill value, converted with
 * Boolean() for 'bool' data and util.decodeString() for 'string' data.
 *
 * @param args Kernel arguments: `inputs` ({sparseIndices, sparseValues,
 *     defaultValue}), `backend` and `attrs` ({outputShape}).
 * @return Tensor info with shape `outputShape`.
 * @throws {Error} If sparseValues.dtype is not one of 'bool', 'float32',
 *     'int32' or 'string'.
 */
function sparseToDense(args) {
    const { inputs: { sparseIndices, sparseValues, defaultValue }, backend, attrs: { outputShape } } = args;
    const { sliceRank, numUpdates, sliceSize, strides, outputSize } = backend_util.calculateShapes(sparseValues, sparseIndices, outputShape);
    const indicesBuf = backend.bufferSync(sparseIndices);
    const valueDtype = sparseValues.dtype;
    const isSupported = valueDtype === 'bool' || valueDtype === 'float32' ||
        valueDtype === 'int32' || valueDtype === 'string';
    if (!isSupported) {
        throw new Error(`Unsupported type ${sparseValues.dtype}`);
    }
    const updatesBuf = backend.bufferSync(sparseValues);
    const rawDefault = backend.data.get(defaultValue.dataId).values[0];
    // Only 'bool' and 'string' need their fill value converted.
    let fillValue = rawDefault;
    if (valueDtype === 'bool') {
        fillValue = Boolean(rawDefault);
    }
    else if (valueDtype === 'string') {
        fillValue = util.decodeString(rawDefault);
    }
    const outBuf = scatterImpl(indicesBuf, updatesBuf, outputShape, outputSize, sliceSize, numUpdates, sliceRank, strides, fillValue, false /* sumDupeIndices */);
    return backend.makeTensorInfo(outputShape, outBuf.dtype, outBuf.values);
}
/** Kernel registration record for SparseToDense on the 'cpu' backend. */
const sparseToDenseConfig = {
    kernelName: SparseToDense,
    backendName: 'cpu',
    kernelFunc: sparseToDense,
};
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * SplitV kernel function for the 'cpu' backend.
 *
 * Resolves the axis with util.parseAxisParam and the piece sizes with
 * backend_util.prepareSplitSize, then slices `x` into consecutive pieces
 * along that axis.
 *
 * @param args Kernel arguments: `inputs` ({x}), `backend` and `attrs`
 *     ({numOrSizeSplits, axis}).
 * @return Array of tensor infos, one per split, in order along the axis.
 */
function splitV(args) {
    const { inputs: { x }, backend, attrs: { numOrSizeSplits, axis } } = args;
    const $axis = util.parseAxisParam(axis, x.shape)[0];
    const splitSizes = backend_util.prepareSplitSize(x, numOrSizeSplits, $axis);
    // Running start coordinate; only the split axis ever advances.
    const begin = new Array(x.shape.length).fill(0);
    const fullSize = x.shape.slice();
    const pieces = [];
    for (const extent of splitSizes) {
        const pieceSize = fullSize.slice();
        pieceSize[$axis] = extent;
        pieces.push(slice({ inputs: { x }, backend, attrs: { begin, size: pieceSize } }));
        begin[$axis] += extent;
    }
    return pieces;
}
/** Kernel registration record for SplitV on the 'cpu' backend. */
const splitVConfig = {
    kernelName: SplitV,
    backendName: 'cpu',
    kernelFunc: splitV,
};
|
|
/**
|
* @license
|
* Copyright 2019 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * Kernel registration record for Square on the 'cpu' backend. The kernel
 * function squares every element of `x` into a Float32Array and writes it to
 * the backend with x's shape and dtype.
 */
const squareConfig = {
    kernelName: Square,
    backendName: 'cpu',
    kernelFunc: ({ inputs, backend }) => {
        const { x } = inputs;
        assertNotComplex(x, 'square');
        const source = backend.data.get(x.dataId).values;
        const squared = new Float32Array(source.length);
        source.forEach((value, position) => {
            squared[position] = value * value;
        });
        return {
            dataId: backend.write(squared, x.shape, x.dtype),
            shape: x.shape,
            dtype: x.dtype,
        };
    }
};
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the License);
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an AS IS BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * Step kernel function built through unaryKernelFunc. Per element: NaN stays
 * NaN, values greater than zero become 1, and everything else becomes
 * `attrs.alpha`.
 */
const step = unaryKernelFunc(Step, (xi, attrs) => {
    if (isNaN(xi)) {
        return NaN;
    }
    if (xi > 0) {
        return 1;
    }
    return attrs.alpha;
});
/** Kernel registration record for Step on the 'cpu' backend. */
const stepConfig = {
    kernelName: Step,
    backendName: 'cpu',
    kernelFunc: step,
};
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * StridedSlice kernel function for the 'cpu' backend.
 *
 * Asks slice_util.sliceInfo to normalise the request and then takes one of
 * three paths: a plain reshape when the slice is an identity, a slice plus
 * reshape when sliceInfo reports `sliceDim0` or `isSimpleSlice`, or the
 * general stridedSliceImpl otherwise.
 *
 * Reference:
 * https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/kernels/strided_slice_op.cc
 *
 * @param args Kernel arguments: `inputs` ({x}), `backend` and `attrs`
 *     ({begin, end, strides, beginMask, endMask, ellipsisMask, newAxisMask,
 *     shrinkAxisMask}).
 * @return Tensor info shaped `finalShape` as reported by sliceInfo.
 */
function stridedSlice(args) {
    const { inputs: { x }, backend, attrs } = args;
    const { begin, end, strides, beginMask, endMask, ellipsisMask, newAxisMask, shrinkAxisMask } = attrs;
    assertNotComplex(x, 'stridedSlice');
    const info = slice_util.sliceInfo(x.shape, begin, end, strides, beginMask, endMask, ellipsisMask, newAxisMask, shrinkAxisMask);
    const { finalShape } = info;
    if (info.isIdentity) {
        // Optimization #1, slice is a no-op plus reshape
        return reshape({ inputs: { x }, backend, attrs: { shape: finalShape } });
    }
    if (info.sliceDim0 || info.isSimpleSlice) {
        // Optimization #2, slice is memory contiguous (only occurs in dim 0)
        util.assert(x.shape.length >= 1, () => `Input must have rank at least 1, got: ${x.shape.length}`);
        const size = slice_util.computeOutShape(info.begin, info.end, info.strides);
        const sliced = slice({ inputs: { x }, backend, attrs: { begin: info.begin, size } });
        const reshaped = reshape({ inputs: { x: sliced }, backend, attrs: { shape: finalShape } });
        backend.disposeIntermediateTensorInfo(sliced);
        return reshaped;
    }
    // General case: gather element by element through the impl.
    const xBuf = backend.bufferSync(x);
    const outBuf = stridedSliceImpl(info.finalShapeSparse, xBuf, info.strides, info.begin);
    return backend.makeTensorInfo(finalShape, outBuf.dtype, outBuf.values);
}
/** Kernel registration record for StridedSlice on the 'cpu' backend. */
const stridedSliceConfig = {
    kernelName: StridedSlice,
    backendName: 'cpu',
    kernelFunc: stridedSlice,
};
|
|
/**
|
* @license
|
* Copyright 2021 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * StringNGrams kernel function for the 'cpu' backend.
 *
 * Reads the backing values of `data` and `dataSplits` and forwards them,
 * together with the n-gram attributes, to stringNGramsImpl.
 *
 * @param args Kernel arguments: `inputs` ({data, dataSplits}), `backend` and
 *     `attrs` ({separator, nGramWidths, leftPad, rightPad, padWidth,
 *     preserveShortSequences}).
 * @return Array of two tensor infos, in order: the n-grams ('string',
 *     shape [nGrams.length]) and their splits ('int32', shape of
 *     `dataSplits`).
 */
function stringNGrams(args) {
    const { inputs: { data, dataSplits }, backend, attrs } = args;
    const { separator, nGramWidths, leftPad, rightPad, padWidth, preserveShortSequences } = attrs;
    const tokens = backend.data.get(data.dataId).values;
    const splits = backend.data.get(dataSplits.dataId).values;
    const impl = stringNGramsImpl(tokens, splits, separator, nGramWidths, leftPad, rightPad, padWidth, preserveShortSequences);
    const nGrams = impl[0];
    const nGramsSplits = impl[1];
    const nGramsInfo = backend.makeTensorInfo([nGrams.length], 'string', nGrams);
    const splitsInfo = backend.makeTensorInfo(dataSplits.shape, 'int32', nGramsSplits);
    return [nGramsInfo, splitsInfo];
}
/** Kernel registration record for StringNGrams on the 'cpu' backend. */
const stringNGramsConfig = {
    kernelName: StringNGrams,
    backendName: 'cpu',
    kernelFunc: stringNGrams,
};
|
|
/**
|
* @license
|
* Copyright 2021 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * StringSplit kernel function for the 'cpu' backend.
 *
 * Validates the inputs, reads their backing values and delegates to
 * stringSplitImpl.
 *
 * @param args Kernel arguments: `inputs` ({input, delimiter}), `backend` and
 *     `attrs` ({skipEmpty}).
 * @return Array of three tensor infos, in order: indices ('int32', shape
 *     [values.length, 2]), values ('string', shape [values.length]) and the
 *     shape ('int32', shape [2]).
 * @throws {Error} If `input` is not of dtype 'string', `input` is not rank 1
 *     or `delimiter` is not rank 0.
 */
function stringSplit(args) {
    const { backend } = args;
    const { skipEmpty } = args.attrs;
    const { input, delimiter } = args.inputs;
    if (input.dtype !== 'string') {
        throw new Error('Input must be of datatype string');
    }
    if (input.shape.length !== 1) {
        throw new Error(`Input must be a vector, got shape: ${input.shape}`);
    }
    if (delimiter.shape.length !== 0) {
        throw new Error(`Delimiter must be a scalar, got shape: ${delimiter.shape}`);
    }
    const strings = backend.data.get(input.dataId).values;
    const separator = backend.data.get(delimiter.dataId).values[0];
    const [indices, values, shape] = stringSplitImpl(strings, separator, skipEmpty);
    const tokenCount = values.length;
    const indicesInfo = backend.makeTensorInfo([tokenCount, 2], 'int32', indices);
    const valuesInfo = backend.makeTensorInfo([tokenCount], 'string', values);
    const shapeInfo = backend.makeTensorInfo([2], 'int32', new Int32Array(shape));
    return [indicesInfo, valuesInfo, shapeInfo];
}
|
const stringSplitConfig = {
|
kernelName: StringSplit,
|
backendName: 'cpu',
|
kernelFunc: stringSplit,
|
};
|
|
/**
|
* @license
|
* Copyright 2021 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * CPU kernel for `StringToHashBucketFast`.
 *
 * Validates the arguments, reads the input's raw values from the backend and
 * delegates to `stringToHashBucketFastImpl`.
 *
 * @param args Kernel arguments: `inputs` ({input}), `backend`, and `attrs`
 *     ({numBuckets}).
 * @return An 'int32' tensor info with the same shape as `input`, holding the
 *     values returned by `stringToHashBucketFastImpl`.
 * @throws Error if `input` is not of dtype 'string' or if `numBuckets <= 0`.
 */
function stringToHashBucketFast(args) {
    const { inputs, backend, attrs } = args;
    const { numBuckets } = attrs;
    const { input } = inputs;
    if (input.dtype !== 'string') {
        throw new Error('Input must be of datatype string');
    }
    if (numBuckets <= 0) {
        throw new Error('Number of buckets must be at least 1');
    }
    const strings = backend.data.get(input.dataId).values;
    const buckets = stringToHashBucketFastImpl(strings, numBuckets);
    return backend.makeTensorInfo(input.shape, 'int32', buckets);
}
/** Registration entry binding the `StringToHashBucketFast` kernel to the 'cpu' backend. */
const stringToHashBucketFastConfig = {
    kernelName: StringToHashBucketFast,
    backendName: 'cpu',
    kernelFunc: stringToHashBucketFast
};
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * CPU kernel function for `Tan`, built with `unaryKernelFunc` from a
 * per-element callback that applies `Math.tan`.
 */
const tan = unaryKernelFunc(Tan, (value) => Math.tan(value));
/** Registration entry binding the `Tan` kernel to the 'cpu' backend. */
const tanConfig = {
    kernelName: Tan,
    backendName: 'cpu',
    kernelFunc: tan
};
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * CPU kernel function for `Tanh`, built with `unaryKernelFunc` from a
 * per-element callback that applies `Math.tanh`.
 */
const tanh = unaryKernelFunc(Tanh, (value) => Math.tanh(value));
/** Registration entry binding the `Tanh` kernel to the 'cpu' backend. */
const tanhConfig = {
    kernelName: Tanh,
    backendName: 'cpu',
    kernelFunc: tanh
};
|
|
/**
|
* @license
|
* Copyright 2022 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * CPU kernel for `TensorScatterUpdate`.
 *
 * Derives the scatter geometry with `backend_util.calculateShapes`, reads the
 * three inputs as buffers and delegates to `scatterImpl`, passing `tensor`'s
 * buffer and `false` for the "sum duplicate indices" flag.
 *
 * @param args Kernel arguments: `inputs` ({tensor, indices, updates}) and
 *     `backend`.
 * @return A tensor info shaped like `tensor`, with the dtype and values of the
 *     buffer returned by `scatterImpl`.
 */
function tensorScatterUpdate(args) {
    const { inputs, backend } = args;
    const { tensor, indices, updates } = inputs;
    const geometry = backend_util.calculateShapes(updates, indices, tensor.shape);
    const indicesBuf = backend.bufferSync(indices);
    const updatesBuf = backend.bufferSync(updates);
    const tensorBuf = backend.bufferSync(tensor);
    const outBuf = scatterImpl(indicesBuf, updatesBuf, tensor.shape, geometry.outputSize, geometry.sliceSize, geometry.numUpdates, geometry.sliceRank, geometry.strides, tensorBuf, 
    /* sumDupeIndices= */ false);
    return backend.makeTensorInfo(tensor.shape, outBuf.dtype, outBuf.values);
}
/** Registration entry binding the `TensorScatterUpdate` kernel to the 'cpu' backend. */
const tensorScatterUpdateConfig = {
    kernelName: TensorScatterUpdate,
    backendName: 'cpu',
    kernelFunc: tensorScatterUpdate
};
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * CPU kernel for `Tile`.
 *
 * Runs `assertNotComplex` on `x`, then delegates to `tileImpl` with the
 * input's buffer and the `reps` attribute.
 *
 * @param args Kernel arguments: `inputs` ({x}), `backend`, and `attrs`
 *     ({reps}).
 * @return A tensor info built from the shape, dtype and values of the buffer
 *     returned by `tileImpl`.
 */
function tile(args) {
    const { inputs, backend, attrs } = args;
    const { x } = inputs;
    const { reps } = attrs;
    assertNotComplex(x, 'tile');
    const { shape, dtype, values } = tileImpl(backend.bufferSync(x), reps);
    return backend.makeTensorInfo(shape, dtype, values);
}
/** Registration entry binding the `Tile` kernel to the 'cpu' backend. */
const tileConfig = {
    kernelName: Tile,
    backendName: 'cpu',
    kernelFunc: tile
};
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * CPU kernel for `TopK`.
 *
 * Runs `assertNotComplex` on `x`, reads its raw values from the backend and
 * delegates to `topKImpl`.
 *
 * @param args Kernel arguments: `inputs` ({x}), `backend`, and `attrs`
 *     ({k, sorted}).
 * @return Two tensor infos, in the order returned by `topKImpl` (values first,
 *     indices second), each built from the corresponding result's shape,
 *     dtype and values.
 */
function topK(args) {
    const { inputs, backend, attrs } = args;
    const { x } = inputs;
    const { k, sorted } = attrs;
    assertNotComplex(x, 'topk');
    const xVals = backend.data.get(x.dataId).values;
    const [topValues, topIndices] = topKImpl(xVals, x.shape, x.dtype, k, sorted);
    const toTensorInfo = (buf) => backend.makeTensorInfo(buf.shape, buf.dtype, buf.values);
    return [toTensorInfo(topValues), toTensorInfo(topIndices)];
}
/** Registration entry binding the `TopK` kernel to the 'cpu' backend. */
const topKConfig = {
    kernelName: TopK,
    backendName: 'cpu',
    kernelFunc: topK
};
|
|
/**
|
* @license
|
* Copyright 2021 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * CPU kernel for `Transform`: resamples every image of a 4-D batch through a
 * projective transform.
 *
 * `image.shape` is read as `[batch, height, width, channels]`. Each transform
 * is a run of 8 coefficients `[a0, a1, a2, b0, b1, b2, c0, c1]`; an output
 * pixel `(outX, outY)` samples the input at
 * `((a0*outX + a1*outY + a2) / p, (b0*outX + b1*outY + b2) / p)` with
 * `p = c0*outX + c1*outY + 1`. When `transforms.shape[0] === 1` the single
 * transform is shared by every image; otherwise image `b` uses coefficients
 * `[b*8, b*8 + 8)`.
 *
 * Ref TF implementation:
 * https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/kernels/image/image_ops.h
 *
 * @param args Kernel arguments: `inputs` ({image, transforms}), `attrs`
 *     ({interpolation, fillMode, fillValue, outputShape}) and `backend`.
 *     `outputShape` is `[outHeight, outWidth]`; when it is null/undefined the
 *     input height and width are used.
 * @return A tensor info of shape `[batch, outHeight, outWidth, channels]` with
 *     the dtype of `image`.
 * @throws Error if `interpolation` is neither 'nearest' nor 'bilinear' (raised
 *     when the first output value is computed).
 */
function transform(args) {
    const { inputs, attrs, backend } = args;
    const { image, transforms } = inputs;
    const { interpolation, fillMode, fillValue, outputShape } = attrs;
    const [batch, imageHeight, imageWidth, numChannels] = image.shape;
    const [outHeight, outWidth] = outputShape != null ? outputShape : [imageHeight, imageWidth];
    const outShape = [batch, outHeight, outWidth, numChannels];
    const [batchInStride, rowInStride, colInStride] = util.computeStrides(image.shape);
    const [batchOutStride, rowOutStride, colOutStride] = util.computeStrides(outShape);
    const outVals = util.getTypedArrayFromDType(image.dtype, util.sizeFromShape(outShape));
    // Pixels that are never written below (infinite coordinates) keep the
    // fill value.
    outVals.fill(fillValue);
    const imageVals = backend.data.get(image.dataId).values;
    const transformVals = backend.data.get(transforms.dataId).values;
    const hasSharedTransform = transforms.shape[0] === 1;
    for (let b = 0; b < batch; ++b) {
        const coeffs = hasSharedTransform ?
            transformVals :
            transformVals.subarray(b * 8, b * 8 + 8);
        for (let outY = 0; outY < outHeight; ++outY) {
            for (let outX = 0; outX < outWidth; ++outX) {
                // The sampling coordinates do not depend on the channel, so
                // they are computed once per output pixel.
                const projection = coeffs[6] * outX + coeffs[7] * outY + 1;
                if (projection === 0) {
                    // Return the fill value for infinite coordinates,
                    // which are outside the input image.
                    continue;
                }
                const inX = (coeffs[0] * outX + coeffs[1] * outY + coeffs[2]) /
                    projection;
                const inY = (coeffs[3] * outX + coeffs[4] * outY + coeffs[5]) /
                    projection;
                const x = mapCoord(inX, imageWidth, fillMode);
                const y = mapCoord(inY, imageHeight, fillMode);
                const pixelOffset = b * batchOutStride + outY * rowOutStride +
                    outX * colOutStride;
                for (let channel = 0; channel < numChannels; ++channel) {
                    let val;
                    switch (interpolation) {
                        case 'nearest':
                            val = nearestInterpolation(imageVals, imageHeight, imageWidth, batchInStride, rowInStride, colInStride, b, y, x, channel, fillValue);
                            break;
                        case 'bilinear':
                            val = bilinearInterpolation(imageVals, imageHeight, imageWidth, batchInStride, rowInStride, colInStride, b, y, x, channel, fillValue);
                            break;
                        default:
                            throw new Error(`Error in Transform: Expect 'nearest' or ` +
                                `'bilinear', but got ${interpolation}`);
                    }
                    outVals[pixelOffset + channel] = val;
                }
            }
        }
    }
    // Return only after every image has been processed. Returning from inside
    // the batch loop would leave images 1..batch-1 holding just the fill
    // value.
    return backend.makeTensorInfo(outShape, image.dtype, outVals);
}
/** Registration entry binding the `Transform` kernel to the 'cpu' backend. */
const transformConfig = {
    kernelName: Transform,
    backendName: 'cpu',
    kernelFunc: transform
};
|
/**
 * Dispatches a coordinate to the mapper that matches `mode`.
 *
 * @param outCoord Coordinate to map.
 * @param len Extent of the axis the coordinate belongs to.
 * @param mode 'reflect', 'wrap' or 'nearest'; 'constant' and any other value
 *     fall through to `mapCoordConstant`.
 * @return The value returned by the selected mapper.
 */
function mapCoord(outCoord, len, mode) {
    if (mode === 'reflect') {
        return mapCoordReflect(outCoord, len);
    }
    if (mode === 'wrap') {
        return mapCoordWrap(outCoord, len);
    }
    if (mode === 'nearest') {
        return mapCoordNearest(outCoord, len);
    }
    // 'constant' and every unrecognised mode.
    return mapCoordConstant(outCoord);
}
|
/**
 * Maps a coordinate into `[0, len - 1]` by reflecting about the edges:
 * [abcd] extends to [dcba|abcd|dcba].
 *
 * @param outCoord Coordinate to map (may be fractional).
 * @param len Extent of the axis.
 * @return The reflected coordinate, clamped to `[0, len - 1]`.
 */
function mapCoordReflect(outCoord, len) {
    const lastIndex = len - 1;
    const period = 2 * len;
    const isBelow = outCoord < 0;
    const isAbove = !isBelow && outCoord > lastIndex;
    let reflected = outCoord;
    if (isBelow || isAbove) {
        if (len <= 1) {
            // A degenerate axis only has index 0 to map to.
            reflected = 0;
        }
        else if (isBelow) {
            if (reflected < period) {
                reflected = period * Math.trunc(-reflected / period) + reflected;
            }
            reflected = reflected < -len ? reflected + period : -reflected - 1;
        }
        else {
            reflected -= period * Math.trunc(reflected / period);
            if (reflected >= len) {
                reflected = period - reflected - 1;
            }
        }
    }
    // clamp is necessary because when outCoord = 3.5 and len = 4,
    // the coordinate stays 3.5 and would be rounded to 4 in nearest
    // interpolation.
    return util.clamp(0, reflected, lastIndex);
}
|
/**
 * Maps a coordinate into `[0, len - 1]` by wrapping around:
 * [abcd] extends to [abcd|abcd|abcd].
 *
 * @param outCoord Coordinate to map (may be fractional).
 * @param len Extent of the axis.
 * @return The wrapped coordinate, clamped to `[0, len - 1]`.
 */
function mapCoordWrap(outCoord, len) {
    const lastIndex = len - 1;
    const isBelow = outCoord < 0;
    const isAbove = !isBelow && outCoord > lastIndex;
    let wrapped = outCoord;
    if (isBelow || isAbove) {
        if (len <= 1) {
            // A degenerate axis only has index 0 to map to.
            wrapped = 0;
        }
        else if (isBelow) {
            wrapped += len * (Math.trunc(-wrapped / lastIndex) + 1);
        }
        else {
            wrapped -= len * Math.trunc(wrapped / lastIndex);
        }
    }
    // clamp is necessary because when outCoord = -0.5 and len = 4,
    // the coordinate becomes 3.5 and would be rounded to 4 in nearest
    // interpolation.
    return util.clamp(0, wrapped, lastIndex);
}
|
/**
 * Coordinate mapper for the 'constant' fill mode: the coordinate is passed
 * through untouched.
 *
 * @param outCoord Coordinate to map.
 * @param len Unused; kept so the signature matches the other mappers.
 * @return `outCoord`, unchanged.
 */
function mapCoordConstant(outCoord, len) {
    return outCoord;
}
|
/**
 * Coordinate mapper for the 'nearest' fill mode: clamps the coordinate to the
 * range `[0, len - 1]`.
 *
 * @param outCoord Coordinate to map.
 * @param len Extent of the axis.
 * @return The result of `util.clamp(0, outCoord, len - 1)`.
 */
function mapCoordNearest(outCoord, len) {
    const lastIndex = len - 1;
    return util.clamp(0, outCoord, lastIndex);
}
|
/**
 * Reads one value from a flat image buffer, or returns `fillValue` when
 * `(y, x)` is outside the image.
 *
 * @param imageVals Flat array of image values.
 * @param imageHeight Number of rows; valid `y` is `0 <= y < imageHeight`.
 * @param imageWidth Number of columns; valid `x` is `0 <= x < imageWidth`.
 * @param batchStride Flat-index step between consecutive batch entries.
 * @param rowStride Flat-index step between consecutive rows.
 * @param colStride Flat-index step between consecutive columns.
 * @param batch Batch index.
 * @param y Row index.
 * @param x Column index.
 * @param channel Channel index, added to the flat index as-is.
 * @param fillValue Value returned for out-of-bounds coordinates.
 * @return `imageVals[batch*batchStride + y*rowStride + x*colStride + channel]`
 *     when in bounds, otherwise `fillValue`.
 */
function readWithFillValue(imageVals, imageHeight, imageWidth, batchStride, rowStride, colStride, batch, y, x, channel, fillValue) {
    // Written as a positive test so that NaN coordinates count as
    // out of bounds.
    const inBounds = 0 <= y && y < imageHeight && 0 <= x && x < imageWidth;
    if (!inBounds) {
        return fillValue;
    }
    return imageVals[batch * batchStride + y * rowStride + x * colStride + channel];
}
|
/**
 * Nearest-neighbour sampling: rounds `(y, x)` with `Math.round` and reads that
 * pixel through `readWithFillValue`.
 *
 * The image/stride/batch/channel/fillValue parameters are forwarded unchanged
 * to `readWithFillValue`.
 *
 * @param y Row coordinate (may be fractional).
 * @param x Column coordinate (may be fractional).
 * @return The value `readWithFillValue` returns for the rounded coordinates.
 */
function nearestInterpolation(imageVals, imageHeight, imageWidth, batchStride, rowStride, colStride, batch, y, x, channel, fillValue) {
    const nearestRow = Math.round(y);
    const nearestCol = Math.round(x);
    return readWithFillValue(imageVals, imageHeight, imageWidth, batchStride, rowStride, colStride, batch, nearestRow, nearestCol, channel, fillValue);
}
|
/**
 * Bilinear sampling at `(y, x)`: blends the four pixels at
 * `(floor(y), floor(x))`, `(floor(y), floor(x) + 1)`, `(floor(y) + 1, floor(x))`
 * and `(floor(y) + 1, floor(x) + 1)`, each read through `readWithFillValue`.
 *
 * The image/stride/batch/channel/fillValue parameters are forwarded unchanged
 * to `readWithFillValue`.
 *
 * @param y Row coordinate (may be fractional).
 * @param x Column coordinate (may be fractional).
 * @return The interpolated value.
 */
function bilinearInterpolation(imageVals, imageHeight, imageWidth, batchStride, rowStride, colStride, batch, y, x, channel, fillValue) {
    const yFloor = Math.floor(y);
    const xFloor = Math.floor(x);
    const yCeil = yFloor + 1;
    const xCeil = xFloor + 1;
    const pixelAt = (row, col) => readWithFillValue(imageVals, imageHeight, imageWidth, batchStride, rowStride, colStride, batch, row, col, channel, fillValue);
    // Horizontal blend along one row. Since xCeil - xFloor is 1:
    // f(x, row) = (xCeil - x) * f(xFloor, row) + (x - xFloor) * f(xCeil, row)
    const blendRow = (row) => (xCeil - x) * pixelAt(row, xFloor) + (x - xFloor) * pixelAt(row, xCeil);
    const valueYFloor = blendRow(yFloor);
    const valueYCeil = blendRow(yCeil);
    // Vertical blend of the two row results. Since yCeil - yFloor is 1:
    // f(x, y) = (yCeil - y) * f(x, yFloor) + (y - yFloor) * f(x, yCeil)
    return (yCeil - y) * valueYFloor + (y - yFloor) * valueYCeil;
}
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * CPU kernel for `Unique`.
 *
 * Runs `assertNotComplex` on `x`, reads its raw values from the backend and
 * delegates to `uniqueImpl`.
 *
 * @param args Kernel arguments: `inputs` ({x}), `attrs` ({axis}) and
 *     `backend`.
 * @return Two tensor infos: the `outputValues` from `uniqueImpl` with shape
 *     `outputShape` and the dtype of `x`, followed by a 1-D 'int32' tensor
 *     info holding `indices`.
 */
function unique(args) {
    const { inputs, attrs, backend } = args;
    const { axis } = attrs;
    const { x } = inputs;
    assertNotComplex(x, 'unique');
    const xVals = backend.data.get(x.dataId).values;
    const uniqueResult = uniqueImpl(xVals, axis, x.shape, x.dtype);
    const { outputValues, outputShape, indices } = uniqueResult;
    const valuesInfo = backend.makeTensorInfo(outputShape, x.dtype, outputValues);
    const indicesInfo = backend.makeTensorInfo([indices.length], 'int32', indices);
    return [valuesInfo, indicesInfo];
}
/** Registration entry binding the `Unique` kernel to the 'cpu' backend. */
const uniqueConfig = {
    kernelName: Unique,
    backendName: 'cpu',
    kernelFunc: unique
};
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * CPU kernel for `Unpack`: splits `value` along `axis` into
 * `value.shape[axis]` pieces, each with that axis removed.
 *
 * Each piece is produced by a size-1 `slice` along `axis` followed by a
 * `reshape` to the reduced shape; the slice result is disposed afterwards.
 *
 * @param args Kernel arguments: `inputs` ({value}), `backend`, and `attrs`
 *     ({axis}). A negative `axis` counts from the last dimension.
 * @return An array with one reshaped tensor info per index along `axis`.
 */
function unpack(args) {
    const { inputs, backend, attrs } = args;
    const { value } = inputs;
    let { axis } = attrs;
    if (axis < 0) {
        axis += value.shape.length;
    }
    const rank = value.shape.length;
    const pieceCount = value.shape[axis];
    // Shape of every piece: the input shape without the unpacked axis.
    const pieceShape = new Array(rank - 1);
    let nextDim = 0;
    for (let dim = 0; dim < rank; dim++) {
        if (dim === axis) {
            continue;
        }
        pieceShape[nextDim] = value.shape[dim];
        nextDim++;
    }
    // The slice window spans the whole tensor except for a single index
    // along `axis`; only begin[axis] changes between iterations.
    const begin = new Array(rank).fill(0);
    const size = value.shape.slice();
    size[axis] = 1;
    const pieces = new Array(pieceCount);
    for (let p = 0; p < pieces.length; p++) {
        begin[axis] = p;
        const window = slice({ inputs: { x: value }, backend, attrs: { begin, size } });
        pieces[p] = reshape({ inputs: { x: window }, backend, attrs: { shape: pieceShape } });
        backend.disposeIntermediateTensorInfo(window);
    }
    return pieces;
}
/** Registration entry binding the `Unpack` kernel to the 'cpu' backend. */
const unpackConfig = {
    kernelName: Unpack,
    backendName: 'cpu',
    kernelFunc: unpack
};
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
 * CPU kernel for `UnsortedSegmentSum`.
 *
 * For every segment id `0 .. numSegments - 1` this builds an equality mask of
 * `segmentIds` against that id, casts it to 'float32', multiplies it with `x`
 * and sums the product over axis 0. The per-segment sums are packed along
 * axis 0 and every intermediate tensor info is disposed.
 *
 * @param args Kernel arguments: `inputs` ({x, segmentIds}), `backend`, and
 *     `attrs` ({numSegments}).
 * @return The tensor info returned by `pack` for the per-segment sums.
 */
function unsortedSegmentSum(args) {
    const { inputs, backend, attrs } = args;
    const { x, segmentIds } = inputs;
    const { numSegments } = attrs;
    assertNotComplex(x, 'unsortedSegmentSum');
    const segmentSums = [];
    const intermediates = [];
    // Reshape the segment ids so that they can be broadcast with x: one
    // trailing dimension is appended for every rank x has beyond segmentIds,
    // giving the shape [segmentIds.shape, 1, ..., 1].
    const missingDims = x.shape.length - segmentIds.shape.length;
    let broadcastIds = segmentIds;
    for (let d = 0; d < missingDims; ++d) {
        broadcastIds = expandDims({ inputs: { input: broadcastIds }, backend, attrs: { dim: d + 1 } });
        intermediates.push(broadcastIds);
    }
    for (let segment = 0; segment < numSegments; ++segment) {
        const scalarValue = util.createScalarValue(segment, 'int32');
        const segmentId = backend.makeTensorInfo([], 'int32', scalarValue);
        const mask = equal({ inputs: { a: segmentId, b: broadcastIds }, backend });
        const maskCasted = cast({ inputs: { x: mask }, backend, attrs: { dtype: 'float32' } });
        const masked = multiply({ inputs: { a: maskCasted, b: x }, backend });
        const segmentSum = sum({ inputs: { x: masked }, backend, attrs: { axis: 0, keepDims: false } });
        segmentSums.push(segmentSum);
        // The per-segment sum is queued for disposal as well; the queue is
        // only flushed after `pack` has run.
        intermediates.push(segmentId, mask, maskCasted, masked, segmentSum);
    }
    const result = pack({ inputs: segmentSums, backend, attrs: { axis: 0 } });
    for (const tensorInfo of intermediates) {
        backend.disposeIntermediateTensorInfo(tensorInfo);
    }
    return result;
}
/** Registration entry binding the `UnsortedSegmentSum` kernel to the 'cpu' backend. */
const unsortedSegmentSumConfig = {
    kernelName: UnsortedSegmentSum,
    backendName: 'cpu',
    kernelFunc: unsortedSegmentSum
};
|
|
/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
// List all kernel configs here
|
const kernelConfigs = [
    _fusedMatMulConfig,
    absConfig,
    acosConfig,
    acoshConfig,
    addConfig,
    addNConfig,
    allConfig,
    anyConfig,
    argMaxConfig,
    argMinConfig,
    asinConfig,
    asinhConfig,
    atanConfig,
    atan2Config,
    atanhConfig,
    avgPoolConfig,
    avgPool3DConfig,
    avgPool3DGradConfig,
    avgPoolGradConfig,
    batchMatMulConfig,
    batchNormConfig,
    batchToSpaceNDConfig,
    bincountConfig,
    bitwiseAndConfig,
    broadcastArgsConfig,
    castConfig,
    ceilConfig,
    clipByValueConfig,
    complexConfig,
    complexAbsConfig,
    concatConfig,
    conv2DConfig,
    conv2DBackpropFilterConfig,
    conv2DBackpropInputConfig,
    conv3DConfig,
    conv3DBackpropFilterV2Config,
    conv3DBackpropInputV2Config,
    cosConfig,
    coshConfig,
    cropAndResizeConfig,
    cumprodConfig,
    cumsumConfig,
    denseBincountConfig,
    depthToSpaceConfig,
    depthwiseConv2dNativeConfig,
    depthwiseConv2dNativeBackpropFilterConfig,
    depthwiseConv2dNativeBackpropInputConfig,
    diagConfig,
    dilation2DConfig,
    dilation2DBackpropFilterConfig,
    dilation2DBackpropInputConfig,
    drawConfig,
    einsumConfig,
    eluConfig,
    eluGradConfig,
    equalConfig,
    erfConfig,
    expConfig,
    expandDimsConfig,
    expm1Config,
    fftConfig,
    fillConfig,
    flipLeftRightConfig,
    floorConfig,
    floorDivConfig,
    fusedConv2DConfig,
    fusedDepthwiseConv2DConfig,
    gatherNdConfig,
    gatherV2Config,
    greaterConfig,
    greaterEqualConfig,
    identityConfig,
    ifftConfig,
    imagConfig,
    isFiniteConfig,
    isInfConfig,
    isNaNConfig,
    leakyReluConfig,
    lessConfig,
    lessEqualConfig,
    linSpaceConfig,
    logConfig,
    log1pConfig,
    logicalAndConfig,
    logicalNotConfig,
    logicalOrConfig,
    LRNConfig,
    LRNGradConfig,
    maxConfig,
    maximumConfig,
    maxPoolConfig,
    maxPool3DConfig,
    maxPool3DGradConfig,
    maxPoolGradConfig,
    maxPoolWithArgmaxConfig,
    meanConfig,
    minConfig,
    minimumConfig,
    mirrorPadConfig,
    modConfig,
    multinomialConfig,
    multiplyConfig,
    negConfig,
    nonMaxSuppressionV3Config,
    nonMaxSuppressionV4Config,
    nonMaxSuppressionV5Config,
    notEqualConfig,
    oneHotConfig,
    onesLikeConfig,
    packConfig,
    padV2Config,
    powConfig,
    preluConfig,
    prodConfig,
    raggedGatherConfig,
    raggedRangeConfig,
    raggedTensorToTensorConfig,
    rangeConfig,
    realConfig,
    realDivConfig,
    reciprocalConfig,
    reluConfig,
    relu6Config,
    reshapeConfig,
    resizeBilinearConfig,
    resizeBilinearGradConfig,
    resizeNearestNeighborConfig,
    resizeNearestNeighborGradConfig,
    reverseConfig,
    rotateWithOffsetConfig,
    roundConfig,
    rsqrtConfig,
    scatterNdConfig,
    searchSortedConfig,
    selectConfig,
    seluConfig,
    sigmoidConfig,
    signConfig,
    sinConfig,
    sinhConfig,
    sliceConfig,
    softmaxConfig,
    softplusConfig,
    spaceToBatchNDConfig,
    sparseFillEmptyRowsConfig,
    sparseReshapeConfig,
    sparseSegmentMeanConfig,
    sparseSegmentSumConfig,
    sparseToDenseConfig,
    splitVConfig,
    sqrtConfig,
    squareConfig,
    squaredDifferenceConfig,
    staticRegexReplaceConfig,
    stepConfig,
    stridedSliceConfig,
    stringNGramsConfig,
    stringSplitConfig,
    stringToHashBucketFastConfig,
    subConfig,
    sumConfig,
    tanConfig,
    tanhConfig,
    tensorScatterUpdateConfig,
    tileConfig,
    topKConfig,
    transformConfig,
    transposeConfig,
    uniqueConfig,
    unpackConfig,
    unsortedSegmentSumConfig,
    zerosLikeConfig
];
// Module-load side effect: pass every config above to `registerKernel`, in
// list order.
kernelConfigs.forEach((kernelConfig) => {
    registerKernel(kernelConfig);
});
|
|
export { MathBackendCPU, shared, version as version_cpu };
|
//# sourceMappingURL=tf-backend-cpu.fesm.js.map
|