/**
 * @license
 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
import { backend_util, MaxPoolGrad } from '@tensorflow/tfjs-core';
import { MaxPool2DBackpropProgram } from '../max_pool_backprop_gpu';
import { Pool2DProgram } from '../pool_gpu';
import { assertNotComplex } from '../webgl_util';

/**
 * WebGL kernel implementation registered for the `MaxPoolGrad` op.
 *
 * Runs two GPU programs:
 *   1. a `Pool2DProgram` in 'max' mode with `getPositions = true` over the
 *      forward input `x` (presumably recording where each pooling window's
 *      maximum came from — see pool_gpu to confirm);
 *   2. a `MaxPool2DBackpropProgram` that consumes `dy` together with the
 *      intermediate produced by step 1.
 * The intermediate from step 1 is always released, even if step 2 throws.
 *
 * @param {{inputs: {dy: TensorInfo, input: TensorInfo, output: TensorInfo},
 *          backend: MathBackendWebGL,
 *          attrs: {filterSize, strides, pad, dimRoundingMode}}} args
 *   `output` is only used for the complex-dtype check below.
 * @returns {TensorInfo} The result of the backprop program, created with
 *   `input`'s dtype.
 * @throws If `input` or `output` is complex (via `assertNotComplex`).
 */
export function maxPoolGrad(args) {
  const { inputs, backend, attrs } = args;
  const { dy, input, output } = inputs;
  const x = input;
  assertNotComplex([input, output], 'maxPoolGrad');
  const { filterSize, strides, pad, dimRoundingMode } = attrs;

  // Pooling geometry derived from the forward input; dilation is fixed at 1.
  const convInfo = backend_util.computePool2DInfo(
      x.shape, filterSize, strides, 1 /* dilations */, pad, dimRoundingMode);

  const getPositions = true;
  const maxPoolPositionsProgram =
      new Pool2DProgram(convInfo, 'max', getPositions);
  const maxPoolPositions =
      backend.runWebGLProgram(maxPoolPositionsProgram, [x], x.dtype);

  // The intermediate is owned by this kernel: release it on both the success
  // and the error path so a failing backprop program does not leak a texture.
  try {
    const maxPoolBackPropProgram = new MaxPool2DBackpropProgram(convInfo);
    return backend.runWebGLProgram(
        maxPoolBackPropProgram, [dy, maxPoolPositions], x.dtype);
  } finally {
    backend.disposeIntermediateTensorInfo(maxPoolPositions);
  }
}

/** Kernel registration entry binding `maxPoolGrad` to the 'webgl' backend. */
export const maxPoolGradConfig = {
  kernelName: MaxPoolGrad,
  backendName: 'webgl',
  kernelFunc: maxPoolGrad
};
//# 
sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoiTWF4UG9vbEdyYWQuanMiLCJzb3VyY2VSb290IjoiIiwic291cmNlcyI6WyIuLi8uLi8uLi8uLi8uLi8uLi90ZmpzLWJhY2tlbmQtd2ViZ2wvc3JjL2tlcm5lbHMvTWF4UG9vbEdyYWQudHMiXSwibmFtZXMiOltdLCJtYXBwaW5ncyI6IkFBQUE7Ozs7Ozs7Ozs7Ozs7OztHQWVHO0FBQ0gsT0FBTyxFQUFDLFlBQVksRUFBNEIsV0FBVyxFQUFrRCxNQUFNLHVCQUF1QixDQUFDO0FBRzNJLE9BQU8sRUFBQyx3QkFBd0IsRUFBQyxNQUFNLDBCQUEwQixDQUFDO0FBQ2xFLE9BQU8sRUFBQyxhQUFhLEVBQUMsTUFBTSxhQUFhLENBQUM7QUFDMUMsT0FBTyxFQUFDLGdCQUFnQixFQUFDLE1BQU0sZUFBZSxDQUFDO0FBRS9DLE1BQU0sVUFBVSxXQUFXLENBQUMsSUFJM0I7SUFDQyxNQUFNLEVBQUMsTUFBTSxFQUFFLE9BQU8sRUFBRSxLQUFLLEVBQUMsR0FBRyxJQUFJLENBQUM7SUFDdEMsTUFBTSxFQUFDLEVBQUUsRUFBRSxLQUFLLEVBQUUsTUFBTSxFQUFDLEdBQUcsTUFBTSxDQUFDO0lBQ25DLE1BQU0sQ0FBQyxHQUFHLEtBQUssQ0FBQztJQUNoQixnQkFBZ0IsQ0FBQyxDQUFDLEtBQUssRUFBRSxNQUFNLENBQUMsRUFBRSxhQUFhLENBQUMsQ0FBQztJQUNqRCxNQUFNLEVBQUMsVUFBVSxFQUFFLE9BQU8sRUFBRSxHQUFHLEVBQUUsZUFBZSxFQUFDLEdBQUcsS0FBSyxDQUFDO0lBRTFELE1BQU0sUUFBUSxHQUFHLFlBQVksQ0FBQyxpQkFBaUIsQ0FDM0MsQ0FBQyxDQUFDLEtBQXlDLEVBQUUsVUFBVSxFQUFFLE9BQU8sRUFDaEUsQ0FBQyxDQUFDLGVBQWUsRUFBRSxHQUFHLEVBQUUsZUFBZSxDQUFDLENBQUM7SUFDN0MsTUFBTSxZQUFZLEdBQUcsSUFBSSxDQUFDO0lBQzFCLE1BQU0sdUJBQXVCLEdBQ3pCLElBQUksYUFBYSxDQUFDLFFBQVEsRUFBRSxLQUFLLEVBQUUsWUFBWSxDQUFDLENBQUM7SUFDckQsTUFBTSxnQkFBZ0IsR0FDbEIsT0FBTyxDQUFDLGVBQWUsQ0FBQyx1QkFBdUIsRUFBRSxDQUFDLENBQUMsQ0FBQyxFQUFFLENBQUMsQ0FBQyxLQUFLLENBQUMsQ0FBQztJQUVuRSxNQUFNLHNCQUFzQixHQUFHLElBQUksd0JBQXdCLENBQUMsUUFBUSxDQUFDLENBQUM7SUFDdEUsTUFBTSxNQUFNLEdBQUcsT0FBTyxDQUFDLGVBQWUsQ0FDbEMsc0JBQXNCLEVBQUUsQ0FBQyxFQUFFLEVBQUUsZ0JBQWdCLENBQUMsRUFBRSxDQUFDLENBQUMsS0FBSyxDQUFDLENBQUM7SUFDN0QsT0FBTyxDQUFDLDZCQUE2QixDQUFDLGdCQUFnQixDQUFDLENBQUM7SUFDeEQsT0FBTyxNQUFNLENBQUM7QUFDaEIsQ0FBQztBQUVELE1BQU0sQ0FBQyxNQUFNLGlCQUFpQixHQUFpQjtJQUM3QyxVQUFVLEVBQUUsV0FBVztJQUN2QixXQUFXLEVBQUUsT0FBTztJQUNwQixVQUFVLEVBQUUsV0FBb0M7Q0FDakQsQ0FBQyIsInNvdXJjZXNDb250ZW50IjpbIi8qKlxuICogQGxpY2Vuc2VcbiAqIENvcHlyaWdodCAyMDIwIEdvb2dsZSBMTEMuIEFsbCBSaWdodHMgUmVzZXJ2ZWQuXG4gKiBMaWNlbn
NlZCB1bmRlciB0aGUgQXBhY2hlIExpY2Vuc2UsIFZlcnNpb24gMi4wICh0aGUgXCJMaWNlbnNlXCIpO1xuICogeW91IG1heSBub3QgdXNlIHRoaXMgZmlsZSBleGNlcHQgaW4gY29tcGxpYW5jZSB3aXRoIHRoZSBMaWNlbnNlLlxuICogWW91IG1heSBvYnRhaW4gYSBjb3B5IG9mIHRoZSBMaWNlbnNlIGF0XG4gKlxuICogaHR0cDovL3d3dy5hcGFjaGUub3JnL2xpY2Vuc2VzL0xJQ0VOU0UtMi4wXG4gKlxuICogVW5sZXNzIHJlcXVpcmVkIGJ5IGFwcGxpY2FibGUgbGF3IG9yIGFncmVlZCB0byBpbiB3cml0aW5nLCBzb2Z0d2FyZVxuICogZGlzdHJpYnV0ZWQgdW5kZXIgdGhlIExpY2Vuc2UgaXMgZGlzdHJpYnV0ZWQgb24gYW4gXCJBUyBJU1wiIEJBU0lTLFxuICogV0lUSE9VVCBXQVJSQU5USUVTIE9SIENPTkRJVElPTlMgT0YgQU5ZIEtJTkQsIGVpdGhlciBleHByZXNzIG9yIGltcGxpZWQuXG4gKiBTZWUgdGhlIExpY2Vuc2UgZm9yIHRoZSBzcGVjaWZpYyBsYW5ndWFnZSBnb3Zlcm5pbmcgcGVybWlzc2lvbnMgYW5kXG4gKiBsaW1pdGF0aW9ucyB1bmRlciB0aGUgTGljZW5zZS5cbiAqID09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09XG4gKi9cbmltcG9ydCB7YmFja2VuZF91dGlsLCBLZXJuZWxDb25maWcsIEtlcm5lbEZ1bmMsIE1heFBvb2xHcmFkLCBNYXhQb29sR3JhZEF0dHJzLCBNYXhQb29sR3JhZElucHV0cywgVGVuc29ySW5mb30gZnJvbSAnQHRlbnNvcmZsb3cvdGZqcy1jb3JlJztcblxuaW1wb3J0IHtNYXRoQmFja2VuZFdlYkdMfSBmcm9tICcuLi9iYWNrZW5kX3dlYmdsJztcbmltcG9ydCB7TWF4UG9vbDJEQmFja3Byb3BQcm9ncmFtfSBmcm9tICcuLi9tYXhfcG9vbF9iYWNrcHJvcF9ncHUnO1xuaW1wb3J0IHtQb29sMkRQcm9ncmFtfSBmcm9tICcuLi9wb29sX2dwdSc7XG5pbXBvcnQge2Fzc2VydE5vdENvbXBsZXh9IGZyb20gJy4uL3dlYmdsX3V0aWwnO1xuXG5leHBvcnQgZnVuY3Rpb24gbWF4UG9vbEdyYWQoYXJnczoge1xuICBpbnB1dHM6IE1heFBvb2xHcmFkSW5wdXRzLFxuICBiYWNrZW5kOiBNYXRoQmFja2VuZFdlYkdMLFxuICBhdHRyczogTWF4UG9vbEdyYWRBdHRyc1xufSk6IFRlbnNvckluZm8ge1xuICBjb25zdCB7aW5wdXRzLCBiYWNrZW5kLCBhdHRyc30gPSBhcmdzO1xuICBjb25zdCB7ZHksIGlucHV0LCBvdXRwdXR9ID0gaW5wdXRzO1xuICBjb25zdCB4ID0gaW5wdXQ7XG4gIGFzc2VydE5vdENvbXBsZXgoW2lucHV0LCBvdXRwdXRdLCAnbWF4UG9vbEdyYWQnKTtcbiAgY29uc3Qge2ZpbHRlclNpemUsIHN0cmlkZXMsIHBhZCwgZGltUm91bmRpbmdNb2RlfSA9IGF0dHJzO1xuXG4gIGNvbnN0IGNvbnZJbmZvID0gYmFja2VuZF91dGlsLmNvbXB1dGVQb29sMkRJbmZvKFxuICAgICAgeC5zaGFwZSBhcyBbbnVtYmVyLCBudW1iZXIsIG51bWJlciwgbnVtYmVyXSwgZmlsdGVyU2l6ZSwgc3RyaWRlcyxcbiAgICAgIDEgLyogZG
lsYXRpb25zICovLCBwYWQsIGRpbVJvdW5kaW5nTW9kZSk7XG4gIGNvbnN0IGdldFBvc2l0aW9ucyA9IHRydWU7XG4gIGNvbnN0IG1heFBvb2xQb3NpdGlvbnNQcm9ncmFtID1cbiAgICAgIG5ldyBQb29sMkRQcm9ncmFtKGNvbnZJbmZvLCAnbWF4JywgZ2V0UG9zaXRpb25zKTtcbiAgY29uc3QgbWF4UG9vbFBvc2l0aW9uczogVGVuc29ySW5mbyA9XG4gICAgICBiYWNrZW5kLnJ1bldlYkdMUHJvZ3JhbShtYXhQb29sUG9zaXRpb25zUHJvZ3JhbSwgW3hdLCB4LmR0eXBlKTtcblxuICBjb25zdCBtYXhQb29sQmFja1Byb3BQcm9ncmFtID0gbmV3IE1heFBvb2wyREJhY2twcm9wUHJvZ3JhbShjb252SW5mbyk7XG4gIGNvbnN0IHJlc3VsdCA9IGJhY2tlbmQucnVuV2ViR0xQcm9ncmFtKFxuICAgICAgbWF4UG9vbEJhY2tQcm9wUHJvZ3JhbSwgW2R5LCBtYXhQb29sUG9zaXRpb25zXSwgeC5kdHlwZSk7XG4gIGJhY2tlbmQuZGlzcG9zZUludGVybWVkaWF0ZVRlbnNvckluZm8obWF4UG9vbFBvc2l0aW9ucyk7XG4gIHJldHVybiByZXN1bHQ7XG59XG5cbmV4cG9ydCBjb25zdCBtYXhQb29sR3JhZENvbmZpZzogS2VybmVsQ29uZmlnID0ge1xuICBrZXJuZWxOYW1lOiBNYXhQb29sR3JhZCxcbiAgYmFja2VuZE5hbWU6ICd3ZWJnbCcsXG4gIGtlcm5lbEZ1bmM6IG1heFBvb2xHcmFkIGFzIHVua25vd24gYXMgS2VybmVsRnVuY1xufTtcbiJdfQ==