/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
import { ENGINE } from '../engine';
|
import { Conv3DBackpropInputV2 } from '../kernel_names';
|
import * as util from '../util';
|
import { op } from './operation';
|
import { reshape } from './reshape';
|
/**
 * Computes the derivative of the input of a 3D convolution.
 *
 * A rank-4 `dy` (with a length-4 `xShape`) is treated as a single example:
 * `dy` is reshaped to rank 5 with a leading batch dimension of 1, `xShape` is
 * prefixed with 1, and the kernel result is reshaped back to rank 4 (dropping
 * that leading dimension) before being returned.
 *
 * @param xShape The shape of the input: `[batch, depth, height, width,
 *     inChannels]`. If length 4, a batch of 1 is assumed. Its length must
 *     equal `dy.rank`, and its last entry must equal `filter.shape[3]`.
 * @param dy The derivative of the output, of rank 5 or rank 4, of shape
 *     `[batch, outDepth, outHeight, outWidth, outChannels]`. The last
 *     dimension must equal the filter's output depth (`filter.shape[4]`).
 *     If rank 4, a batch of 1 is assumed.
 * @param filter The filter, rank 5, of shape
 *     `[filterDepth, filterHeight, filterWidth, inDepth, outDepth]`.
 * @param strides The strides of the convolution: `[strideDepth, strideHeight,
 *     strideWidth]`. Forwarded unchanged to the kernel.
 * @param pad The type of padding algorithm used (forwarded unchanged):
 *    - `same` and stride 1: output will be of same size as input,
 *       regardless of filter size.
 *    - `valid`: output will be smaller than input if filter is larger
 *       than 1x1x1.
 * @returns The result of the `Conv3DBackpropInputV2` kernel, reshaped to rank 4
 *     when `dy` was rank 4.
 *
 * Any of the shape checks below failing is reported through `util.assert`
 * with the corresponding message.
 */
function conv3DBackpropInput_(xShape, dy, filter, strides, pad) {
  util.assert(xShape.length === dy.rank, () => `Length of inShape ` +
      `(${xShape.length}) and rank of dy (${dy.rank}) must match`);

  // Promote a single rank-4 example to a batch of one so the kernel only
  // receives rank-5 inputs; `reshapedTo5D` records that the result must be
  // squeezed back to rank 4 afterwards.
  let xShape5D = xShape;
  let dy5D = dy;
  let reshapedTo5D = false;
  if (dy.rank === 4) {
    reshapedTo5D = true;
    dy5D = reshape(dy, [1, dy.shape[0], dy.shape[1], dy.shape[2], dy.shape[3]]);
    xShape5D = [1, xShape[0], xShape[1], xShape[2], xShape[3]];
  }

  // Channel counts: the input's channels must match the filter's input depth,
  // and dy's channels must match the filter's output depth.
  const inDepth = xShape5D[4];
  const outDepth = dy5D.shape[4];
  util.assert(xShape5D.length === 5, () =>
      `Error in conv3dDerInput: inShape must be length 5, but got length ` +
      `${xShape5D.length}.`);
  util.assert(dy5D.rank === 5, () =>
      `Error in conv3dDerInput: dy must be rank 5, but got ` +
      `rank ${dy5D.rank}`);
  util.assert(filter.rank === 5, () =>
      `Error in conv3dDerInput: filter must be rank 5, but got ` +
      `rank ${filter.rank}`);
  util.assert(inDepth === filter.shape[3], () =>
      `Error in conv3dDerInput: depth of input (${inDepth}) must ` +
      `match input depth for filter ${filter.shape[3]}.`);
  util.assert(outDepth === filter.shape[4], () =>
      `Error in conv3dDerInput: depth of output (${outDepth}) must ` +
      `match output depth for filter ${filter.shape[4]}.`);

  const inputs = { dy: dy5D, filter };
  const attrs = { pad, strides, inputShape: xShape5D };
  const res = ENGINE.runKernel(Conv3DBackpropInputV2, inputs, attrs);

  if (reshapedTo5D) {
    // Drop the synthetic batch dimension added above.
    return reshape(res, [res.shape[1], res.shape[2], res.shape[3], res.shape[4]]);
  }
  return res;
}
|
// Public entry point: `conv3DBackpropInput_` wrapped by `op` from
// './operation'. The object key presumably supplies the op's name — confirm
// in ./operation. The `@__PURE__` annotation tells bundlers this call has no
// side effects, so it can be tree-shaken when the export is unused.
export const conv3DBackpropInput = /* @__PURE__ */ op({
  conv3DBackpropInput_: conv3DBackpropInput_,
});
|
//# sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoiY29udjNkX2JhY2twcm9wX2lucHV0LmpzIiwic291cmNlUm9vdCI6IiIsInNvdXJjZXMiOlsiLi4vLi4vLi4vLi4vLi4vLi4vdGZqcy1jb3JlL3NyYy9vcHMvY29udjNkX2JhY2twcm9wX2lucHV0LnRzIl0sIm5hbWVzIjpbXSwibWFwcGluZ3MiOiJBQUFBOzs7Ozs7Ozs7Ozs7Ozs7R0FlRztBQUNILE9BQU8sRUFBQyxNQUFNLEVBQUMsTUFBTSxXQUFXLENBQUM7QUFDakMsT0FBTyxFQUFDLHFCQUFxQixFQUEwRCxNQUFNLGlCQUFpQixDQUFDO0FBSS9HLE9BQU8sS0FBSyxJQUFJLE1BQU0sU0FBUyxDQUFDO0FBRWhDLE9BQU8sRUFBQyxFQUFFLEVBQUMsTUFBTSxhQUFhLENBQUM7QUFDL0IsT0FBTyxFQUFDLE9BQU8sRUFBQyxNQUFNLFdBQVcsQ0FBQztBQUVsQzs7Ozs7Ozs7Ozs7Ozs7Ozs7R0FpQkc7QUFDSCxTQUFTLG9CQUFvQixDQUN6QixNQUU2QyxFQUM3QyxFQUFLLEVBQUUsTUFBZ0IsRUFBRSxPQUF3QyxFQUNqRSxHQUFtQjtJQUNyQixJQUFJLENBQUMsTUFBTSxDQUNQLE1BQU0sQ0FBQyxNQUFNLEtBQUssRUFBRSxDQUFDLElBQUksRUFDekIsR0FBRyxFQUFFLENBQUMsb0JBQW9CO1FBQ3RCLElBQUksTUFBTSxDQUFDLE1BQU0scUJBQXFCLEVBQUUsQ0FBQyxJQUFJLGNBQWMsQ0FBQyxDQUFDO0lBRXJFLElBQUksUUFBUSxHQUFHLE1BQWtELENBQUM7SUFDbEUsSUFBSSxJQUFJLEdBQUcsRUFBYyxDQUFDO0lBQzFCLElBQUksWUFBWSxHQUFHLEtBQUssQ0FBQztJQUN6QixJQUFJLEVBQUUsQ0FBQyxJQUFJLEtBQUssQ0FBQyxFQUFFO1FBQ2pCLFlBQVksR0FBRyxJQUFJLENBQUM7UUFDcEIsSUFBSSxHQUFHLE9BQU8sQ0FBQyxFQUFFLEVBQUUsQ0FBQyxDQUFDLEVBQUUsRUFBRSxDQUFDLEtBQUssQ0FBQyxDQUFDLENBQUMsRUFBRSxFQUFFLENBQUMsS0FBSyxDQUFDLENBQUMsQ0FBQyxFQUFFLEVBQUUsQ0FBQyxLQUFLLENBQUMsQ0FBQyxDQUFDLEVBQUUsRUFBRSxDQUFDLEtBQUssQ0FBQyxDQUFDLENBQUMsQ0FBQyxDQUFDLENBQUM7UUFDNUUsUUFBUSxHQUFHLENBQUMsQ0FBQyxFQUFFLE1BQU0sQ0FBQyxDQUFDLENBQUMsRUFBRSxNQUFNLENBQUMsQ0FBQyxDQUFDLEVBQUUsTUFBTSxDQUFDLENBQUMsQ0FBQyxFQUFFLE1BQU0sQ0FBQyxDQUFDLENBQUMsQ0FBQyxDQUFDO0tBQzVEO0lBRUQsTUFBTSxPQUFPLEdBQUcsUUFBUSxDQUFDLENBQUMsQ0FBQyxDQUFDO0lBQzVCLE1BQU0sUUFBUSxHQUFHLElBQUksQ0FBQyxLQUFLLENBQUMsQ0FBQyxDQUFDLENBQUM7SUFDL0IsSUFBSSxDQUFDLE1BQU0sQ0FDUCxRQUFRLENBQUMsTUFBTSxLQUFLLENBQUMsRUFDckIsR0FBRyxFQUFFLENBQ0Qsb0VBQW9FO1FBQ3BFLEdBQUcsUUFBUSxDQUFDLE1BQU0sR0FBRyxDQUFDLENBQUM7SUFDL0IsSUFBSSxDQUFDLE1BQU0sQ0FDUCxJQUFJLENBQUMsSUFBSSxLQUFLLENBQUMsRUFDZixHQUFHLEVBQUUsQ0FBQyxzREFBc0Q7UUFDeEQsUUFBUSxJQUFJLE
NBQUMsSUFBSSxFQUFFLENBQUMsQ0FBQztJQUM3QixJQUFJLENBQUMsTUFBTSxDQUNQLE1BQU0sQ0FBQyxJQUFJLEtBQUssQ0FBQyxFQUNqQixHQUFHLEVBQUUsQ0FBQywwREFBMEQ7UUFDNUQsUUFBUSxNQUFNLENBQUMsSUFBSSxFQUFFLENBQUMsQ0FBQztJQUMvQixJQUFJLENBQUMsTUFBTSxDQUNQLE9BQU8sS0FBSyxNQUFNLENBQUMsS0FBSyxDQUFDLENBQUMsQ0FBQyxFQUMzQixHQUFHLEVBQUUsQ0FBQyw0Q0FBNEMsT0FBTyxTQUFTO1FBQzlELGdDQUFnQyxNQUFNLENBQUMsS0FBSyxDQUFDLENBQUMsQ0FBQyxHQUFHLENBQUMsQ0FBQztJQUM1RCxJQUFJLENBQUMsTUFBTSxDQUNQLFFBQVEsS0FBSyxNQUFNLENBQUMsS0FBSyxDQUFDLENBQUMsQ0FBQyxFQUM1QixHQUFHLEVBQUUsQ0FBQyw2Q0FBNkMsUUFBUSxTQUFTO1FBQ2hFLGlDQUFpQyxNQUFNLENBQUMsS0FBSyxDQUFDLENBQUMsQ0FBQyxHQUFHLENBQUMsQ0FBQztJQUU3RCxNQUFNLE1BQU0sR0FBZ0MsRUFBQyxFQUFFLEVBQUUsSUFBSSxFQUFFLE1BQU0sRUFBQyxDQUFDO0lBRS9ELE1BQU0sS0FBSyxHQUNzQixFQUFDLEdBQUcsRUFBRSxPQUFPLEVBQUUsVUFBVSxFQUFFLFFBQVEsRUFBQyxDQUFDO0lBRXRFLDBEQUEwRDtJQUMxRCxNQUFNLEdBQUcsR0FBRyxNQUFNLENBQUMsU0FBUyxDQUNaLHFCQUFxQixFQUFFLE1BQW1DLEVBQzFELEtBQWdDLENBQU0sQ0FBQztJQUV2RCxJQUFJLFlBQVksRUFBRTtRQUNoQixPQUFPLE9BQU8sQ0FDSCxHQUFHLEVBQUUsQ0FBQyxHQUFHLENBQUMsS0FBSyxDQUFDLENBQUMsQ0FBQyxFQUFFLEdBQUcsQ0FBQyxLQUFLLENBQUMsQ0FBQyxDQUFDLEVBQUUsR0FBRyxDQUFDLEtBQUssQ0FBQyxDQUFDLENBQUMsRUFBRSxHQUFHLENBQUMsS0FBSyxDQUFDLENBQUMsQ0FBQyxDQUFDLENBQ25FLENBQUM7S0FDUDtJQUNELE9BQU8sR0FBRyxDQUFDO0FBQ2IsQ0FBQztBQUVELE1BQU0sQ0FBQyxNQUFNLG1CQUFtQixHQUFHLGVBQWUsQ0FBQyxFQUFFLENBQUMsRUFBQyxvQkFBb0IsRUFBQyxDQUFDLENBQUMiLCJzb3VyY2VzQ29udGVudCI6WyIvKipcbiAqIEBsaWNlbnNlXG4gKiBDb3B5cmlnaHQgMjAyMCBHb29nbGUgTExDLiBBbGwgUmlnaHRzIFJlc2VydmVkLlxuICogTGljZW5zZWQgdW5kZXIgdGhlIEFwYWNoZSBMaWNlbnNlLCBWZXJzaW9uIDIuMCAodGhlIFwiTGljZW5zZVwiKTtcbiAqIHlvdSBtYXkgbm90IHVzZSB0aGlzIGZpbGUgZXhjZXB0IGluIGNvbXBsaWFuY2Ugd2l0aCB0aGUgTGljZW5zZS5cbiAqIFlvdSBtYXkgb2J0YWluIGEgY29weSBvZiB0aGUgTGljZW5zZSBhdFxuICpcbiAqIGh0dHA6Ly93d3cuYXBhY2hlLm9yZy9saWNlbnNlcy9MSUNFTlNFLTIuMFxuICpcbiAqIFVubGVzcyByZXF1aXJlZCBieSBhcHBsaWNhYmxlIGxhdyBvciBhZ3JlZWQgdG8gaW4gd3JpdGluZywgc29mdHdhcmVcbiAqIGRpc3RyaWJ1dGVkIHVuZGVyIHRoZSBMaWNlbnNlIGlzIGRpc3RyaWJ1dGVkIG9uIGFuIFwiQVMgSVNcIiBCQVNJUyxcbiAqIFdJVEhPVV
QgV0FSUkFOVElFUyBPUiBDT05ESVRJT05TIE9GIEFOWSBLSU5ELCBlaXRoZXIgZXhwcmVzcyBvciBpbXBsaWVkLlxuICogU2VlIHRoZSBMaWNlbnNlIGZvciB0aGUgc3BlY2lmaWMgbGFuZ3VhZ2UgZ292ZXJuaW5nIHBlcm1pc3Npb25zIGFuZFxuICogbGltaXRhdGlvbnMgdW5kZXIgdGhlIExpY2Vuc2UuXG4gKiA9PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PVxuICovXG5pbXBvcnQge0VOR0lORX0gZnJvbSAnLi4vZW5naW5lJztcbmltcG9ydCB7Q29udjNEQmFja3Byb3BJbnB1dFYyLCBDb252M0RCYWNrcHJvcElucHV0VjJBdHRycywgQ29udjNEQmFja3Byb3BJbnB1dFYySW5wdXRzfSBmcm9tICcuLi9rZXJuZWxfbmFtZXMnO1xuaW1wb3J0IHtOYW1lZEF0dHJNYXB9IGZyb20gJy4uL2tlcm5lbF9yZWdpc3RyeSc7XG5pbXBvcnQge1RlbnNvcjRELCBUZW5zb3I1RH0gZnJvbSAnLi4vdGVuc29yJztcbmltcG9ydCB7TmFtZWRUZW5zb3JNYXB9IGZyb20gJy4uL3RlbnNvcl90eXBlcyc7XG5pbXBvcnQgKiBhcyB1dGlsIGZyb20gJy4uL3V0aWwnO1xuXG5pbXBvcnQge29wfSBmcm9tICcuL29wZXJhdGlvbic7XG5pbXBvcnQge3Jlc2hhcGV9IGZyb20gJy4vcmVzaGFwZSc7XG5cbi8qKlxuICogQ29tcHV0ZXMgdGhlIGRlcml2YXRpdmUgb2YgdGhlIGlucHV0IG9mIGEgM0QgY29udm9sdXRpb24uXG4gKlxuICogQHBhcmFtIHhTaGFwZSBUaGUgc2hhcGUgb2YgdGhlIGlucHV0OiBbYmF0Y2gsIGRlcHRoLCBoZWlnaHQsIHdpZHRoLFxuICogaW5fY2hhbm5lbHNdLiBJZiBsZW5ndGggb2YgNCwgYmF0Y2ggb2YgMSBpcyBhc3N1bWVkLlxuICogQHBhcmFtIGR5IFRoZSBkZXJpdmF0aXZlIG9mIHRoZSBvdXRwdXQsIG9mIHJhbmsgNSBvciByYW5rIDQgb2Ygc2hhcGVcbiAqICAgYFtiYXRjaCwgb3V0RGVwdGgsIG91dEhlaWdodCwgb3V0V2lkdGgsIGluX2NoYW5uZWxzXWAuXG4gKiBJZiByYW5rIDQsIGJhdGNoIG9mIDEgaXMgYXNzdW1lZC5cbiAqIEBwYXJhbSBmaWx0ZXIgVGhlIGZpbHRlciwgcmFuayA1LCBvZiBzaGFwZVxuICogICAgIGBbZmlsdGVyRGVwdGgsIGZpbHRlckhlaWdodCwgZmlsdGVyV2lkdGgsIGluRGVwdGgsIG91dERlcHRoXWAuXG4gKiBAcGFyYW0gc3RyaWRlcyBUaGUgc3RyaWRlcyBvZiB0aGUgY29udm9sdXRpb246IGBbc3RyaWRlRGVwdGgsIHN0cmlkZUhlaWdodCxcbiAqIHN0cmlkZVdpZHRoXWAuXG4gKiBAcGFyYW0gcGFkIFRoZSB0eXBlIG9mIHBhZGRpbmcgYWxnb3JpdGhtIHVzZWQ6XG4gKiAgICAtIGBzYW1lYCBhbmQgc3RyaWRlIDE6IG91dHB1dCB3aWxsIGJlIG9mIHNhbWUgc2l6ZSBhcyBpbnB1dCxcbiAqICAgICAgIHJlZ2FyZGxlc3Mgb2YgZmlsdGVyIHNpemUuXG4gKiAgICAtIGB2YWxpZGA6IG91dHB1dCB3aWxsIGJlIHNtYWxsZXIgdGhhbiBpbnB1dCBpZiBmaWx0ZXIgaXMgbGFyZ2VyXG4gKiAgICAgICB0aGFuIDF4MS5cbi
AqL1xuZnVuY3Rpb24gY29udjNEQmFja3Byb3BJbnB1dF88VCBleHRlbmRzIFRlbnNvcjREfFRlbnNvcjVEPihcbiAgICB4U2hhcGU6XG4gICAgICAgIFtudW1iZXIsIG51bWJlciwgbnVtYmVyLCBudW1iZXIsXG4gICAgICAgICBudW1iZXJdfFtudW1iZXIsIG51bWJlciwgbnVtYmVyLCBudW1iZXJdLFxuICAgIGR5OiBULCBmaWx0ZXI6IFRlbnNvcjVELCBzdHJpZGVzOiBbbnVtYmVyLCBudW1iZXIsIG51bWJlcl18bnVtYmVyLFxuICAgIHBhZDogJ3ZhbGlkJ3wnc2FtZScpOiBUIHtcbiAgdXRpbC5hc3NlcnQoXG4gICAgICB4U2hhcGUubGVuZ3RoID09PSBkeS5yYW5rLFxuICAgICAgKCkgPT4gYExlbmd0aCBvZiBpblNoYXBlIGAgK1xuICAgICAgICAgIGAoJHt4U2hhcGUubGVuZ3RofSkgYW5kIHJhbmsgb2YgZHkgKCR7ZHkucmFua30pIG11c3QgbWF0Y2hgKTtcblxuICBsZXQgeFNoYXBlNUQgPSB4U2hhcGUgYXMgW251bWJlciwgbnVtYmVyLCBudW1iZXIsIG51bWJlciwgbnVtYmVyXTtcbiAgbGV0IGR5NUQgPSBkeSBhcyBUZW5zb3I1RDtcbiAgbGV0IHJlc2hhcGVkVG81RCA9IGZhbHNlO1xuICBpZiAoZHkucmFuayA9PT0gNCkge1xuICAgIHJlc2hhcGVkVG81RCA9IHRydWU7XG4gICAgZHk1RCA9IHJlc2hhcGUoZHksIFsxLCBkeS5zaGFwZVswXSwgZHkuc2hhcGVbMV0sIGR5LnNoYXBlWzJdLCBkeS5zaGFwZVszXV0pO1xuICAgIHhTaGFwZTVEID0gWzEsIHhTaGFwZVswXSwgeFNoYXBlWzFdLCB4U2hhcGVbMl0sIHhTaGFwZVszXV07XG4gIH1cblxuICBjb25zdCBpbkRlcHRoID0geFNoYXBlNURbNF07XG4gIGNvbnN0IG91dERlcHRoID0gZHk1RC5zaGFwZVs0XTtcbiAgdXRpbC5hc3NlcnQoXG4gICAgICB4U2hhcGU1RC5sZW5ndGggPT09IDUsXG4gICAgICAoKSA9PlxuICAgICAgICAgIGBFcnJvciBpbiBjb252M2REZXJJbnB1dDogaW5TaGFwZSBtdXN0IGJlIGxlbmd0aCA1LCBidXQgZ290IGxlbmd0aCBgICtcbiAgICAgICAgICBgJHt4U2hhcGU1RC5sZW5ndGh9LmApO1xuICB1dGlsLmFzc2VydChcbiAgICAgIGR5NUQucmFuayA9PT0gNSxcbiAgICAgICgpID0+IGBFcnJvciBpbiBjb252M2REZXJJbnB1dDogZHkgbXVzdCBiZSByYW5rIDUsIGJ1dCBnb3QgYCArXG4gICAgICAgICAgYHJhbmsgJHtkeTVELnJhbmt9YCk7XG4gIHV0aWwuYXNzZXJ0KFxuICAgICAgZmlsdGVyLnJhbmsgPT09IDUsXG4gICAgICAoKSA9PiBgRXJyb3IgaW4gY29udjNkRGVySW5wdXQ6IGZpbHRlciBtdXN0IGJlIHJhbmsgNSwgYnV0IGdvdCBgICtcbiAgICAgICAgICBgcmFuayAke2ZpbHRlci5yYW5rfWApO1xuICB1dGlsLmFzc2VydChcbiAgICAgIGluRGVwdGggPT09IGZpbHRlci5zaGFwZVszXSxcbiAgICAgICgpID0+IGBFcnJvciBpbiBjb252M2REZXJJbnB1dDogZGVwdGggb2YgaW5wdXQgKCR7aW5EZXB0aH0pIG11c3QgYCArXG4gICAgICAgICAgYG1hdGNoIGlucHV0IGRlcHRoIGZvciBmaWx0ZXIgJHtmaWx0ZXIuc2hhcGVbM119LmApO1xuIC
B1dGlsLmFzc2VydChcbiAgICAgIG91dERlcHRoID09PSBmaWx0ZXIuc2hhcGVbNF0sXG4gICAgICAoKSA9PiBgRXJyb3IgaW4gY29udjNkRGVySW5wdXQ6IGRlcHRoIG9mIG91dHB1dCAoJHtvdXREZXB0aH0pIG11c3QgYCArXG4gICAgICAgICAgYG1hdGNoIG91dHB1dCBkZXB0aCBmb3IgZmlsdGVyICR7ZmlsdGVyLnNoYXBlWzRdfS5gKTtcblxuICBjb25zdCBpbnB1dHM6IENvbnYzREJhY2twcm9wSW5wdXRWMklucHV0cyA9IHtkeTogZHk1RCwgZmlsdGVyfTtcblxuICBjb25zdCBhdHRyczpcbiAgICAgIENvbnYzREJhY2twcm9wSW5wdXRWMkF0dHJzID0ge3BhZCwgc3RyaWRlcywgaW5wdXRTaGFwZTogeFNoYXBlNUR9O1xuXG4gIC8vIHRzbGludDpkaXNhYmxlLW5leHQtbGluZTogbm8tdW5uZWNlc3NhcnktdHlwZS1hc3NlcnRpb25cbiAgY29uc3QgcmVzID0gRU5HSU5FLnJ1bktlcm5lbChcbiAgICAgICAgICAgICAgICAgIENvbnYzREJhY2twcm9wSW5wdXRWMiwgaW5wdXRzIGFzIHVua25vd24gYXMgTmFtZWRUZW5zb3JNYXAsXG4gICAgICAgICAgICAgICAgICBhdHRycyBhcyB1bmtub3duIGFzIE5hbWVkQXR0ck1hcCkgYXMgVDtcblxuICBpZiAocmVzaGFwZWRUbzVEKSB7XG4gICAgcmV0dXJuIHJlc2hhcGUoXG4gICAgICAgICAgICAgICByZXMsIFtyZXMuc2hhcGVbMV0sIHJlcy5zaGFwZVsyXSwgcmVzLnNoYXBlWzNdLCByZXMuc2hhcGVbNF1dKSBhc1xuICAgICAgICBUO1xuICB9XG4gIHJldHVybiByZXM7XG59XG5cbmV4cG9ydCBjb25zdCBjb252M0RCYWNrcHJvcElucHV0ID0gLyogQF9fUFVSRV9fICovIG9wKHtjb252M0RCYWNrcHJvcElucHV0X30pO1xuIl19
|