/**
 * @license
 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
import { backend_util, MaxPool, util } from '@tensorflow/tfjs-core';
import { Pool2DProgram } from '../pool_gpu';
import { assertNotComplex } from '../webgl_util';
import { identity } from './Identity';

/**
 * Kernel function for max pooling, registered below for the 'webgl' backend.
 *
 * Steps, in order:
 *   1. Passes `x` to `assertNotComplex`.
 *   2. Derives pooling info via `backend_util.computePool2DInfo`, with
 *      dilations fixed to 1.
 *   3. If the filter is 1x1 and the computed input and output shapes are
 *      equal, delegates to the `identity` kernel instead of running a program.
 *   4. Otherwise runs a `Pool2DProgram` in 'max' mode on the backend.
 *
 * @param {object} args
 * @param {object} args.inputs - Holds the input tensor info as `x`.
 *     NOTE(review): presumably `x` is rank 4, since the 2D pooling helper is
 *     used; confirm against callers.
 * @param {object} args.backend - Backend exposing `runWebGLProgram`.
 * @param {object} args.attrs - Provides `filterSize`, `strides`, `pad` and
 *     `dimRoundingMode`.
 * @returns The result of `identity(...)` or of `backend.runWebGLProgram(...)`,
 *     requested with the dtype of `x`.
 */
export function maxPool(args) {
    const { inputs: { x }, backend, attrs } = args;
    assertNotComplex(x, 'maxPool');

    const { filterSize, strides, pad, dimRoundingMode } = attrs;
    // This kernel takes no dilation attribute; a fixed value of 1 is used.
    const dilations = 1;

    // The message is built lazily; it is handed to util.assert as a callback.
    const describeStrideError = () =>
        'Error in maxPool: Either strides or dilations must be 1. ' +
        `Got strides ${strides} and dilations '${dilations}'`;
    util.assert(
        backend_util.eitherStridesOrDilationsAreOne(strides, dilations),
        describeStrideError);

    const poolInfo = backend_util.computePool2DInfo(
        x.shape, filterSize, strides, dilations, pad, dimRoundingMode);

    // A 1x1 window that leaves the shape unchanged is handled by `identity`.
    const hasUnitFilter =
        poolInfo.filterWidth === 1 && poolInfo.filterHeight === 1;
    if (hasUnitFilter && util.arraysEqual(poolInfo.inShape, poolInfo.outShape)) {
        return identity({ inputs: { x }, backend });
    }

    const program = new Pool2DProgram(poolInfo, 'max', false);
    return backend.runWebGLProgram(program, [x], x.dtype);
}

/** Kernel registration entry binding `maxPool` to the MaxPool kernel name on 'webgl'. */
export const maxPoolConfig = {
    kernelName: MaxPool,
    backendName: 'webgl',
    kernelFunc: maxPool
};
//# sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoiTWF4UG9vbC5qcyIsInNvdXJjZVJvb3QiOiIiLCJzb3VyY2VzIjpbIi4uLy4uLy4uLy4uLy4uLy4uL3RmanMtYmFja2VuZC13ZWJnbC9zcmMva2VybmVscy9NYXhQb29sLnRzIl0sIm5hbWVzIjpbXSwibWFwcGluZ3MiOiJBQUFBOzs7Ozs7Ozs7Ozs7Ozs7R0FlRztBQUNILE9BQU8sRUFBQyxZQUFZLEVBQTRCLE9BQU8sRUFBMkMsSUFBSSxFQUFDLE1BQU0sdUJBQXVCLENBQUM7QUFHckksT0FBTyxFQUFDLGFBQWEsRUFBQyxNQUFNLGFBQWEsQ0FBQztBQUMxQyxPQUFPLEVBQUMsZ0JBQWdCLEVBQUMsTUFBTSxlQUFlLENBQUM7QUFDL0MsT0FBTyxFQUFDLFFBQVEsRUFBQyxNQUFNLFlBQVksQ0FBQztBQUVwQyxNQUFNLFVBQVUsT0FBTyxDQUFDLElBSXZCO0lBQ0MsTUFBTSxFQUFDLE1BQU0sRUFBRSxPQUFPLEVBQUUsS0FBSyxFQUFDLEdBQUcsSUFBSSxDQUFDO0lBQ3RDLE1BQU0sRUFBQyxDQUFDLEVBQUMsR0FBRyxNQUFNLENBQUM7SUFDbkIsZ0JBQWdCLENBQUMsQ0FBQyxFQUFFLFNBQVMsQ0FBQyxDQUFDO0lBQy9CLE1BQU0sRUFBQyxVQUFVLEVBQUUsT0FBTyxFQUFFLEdBQUcsRUFBRSxlQUFlLEVBQUMsR0FBRyxLQUFLLENBQUM7SUFDMUQsTUFBTSxTQUFTLEdBQUcsQ0FBQyxDQUFDO0lBRXBCLElBQUksQ0FBQyxNQUFNLENBQ1AsWUFBWSxDQUFDLDhCQUE4QixDQUFDLE9BQU8sRUFBRSxTQUFTLENBQUMsRUFDL0QsR0FBRyxFQUFFLENBQUMsMkRBQTJEO1FBQzdELGVBQWUsT0FBTyxtQkFBbUIsU0FBUyxHQUFHLENBQUMsQ0FBQztJQUUvRCxNQUFNLFFBQVEsR0FBRyxZQUFZLENBQUMsaUJBQWlCLENBQzNDLENBQUMsQ0FBQyxLQUF5QyxFQUFFLFVBQVUsRUFBRSxPQUFPLEVBQ2hFLFNBQVMsRUFBRSxHQUFHLEVBQUUsZUFBZSxDQUFDLENBQUM7SUFDckMsSUFBSSxRQUFRLENBQUMsV0FBVyxLQUFLLENBQUMsSUFBSSxRQUFRLENBQUMsWUFBWSxLQUFLLENBQUM7UUFDekQsSUFBSSxDQUFDLFdBQVcsQ0FBQyxRQUFRLENBQUMsT0FBTyxFQUFFLFFBQVEsQ0FBQyxRQUFRLENBQUMsRUFBRTtRQU
N6RCxPQUFPLFFBQVEsQ0FBQyxFQUFDLE1BQU0sRUFBRSxFQUFDLENBQUMsRUFBQyxFQUFFLE9BQU8sRUFBQyxDQUFDLENBQUM7S0FDekM7SUFDRCxNQUFNLGNBQWMsR0FBRyxJQUFJLGFBQWEsQ0FBQyxRQUFRLEVBQUUsS0FBSyxFQUFFLEtBQUssQ0FBQyxDQUFDO0lBQ2pFLE9BQU8sT0FBTyxDQUFDLGVBQWUsQ0FBQyxjQUFjLEVBQUUsQ0FBQyxDQUFDLENBQUMsRUFBRSxDQUFDLENBQUMsS0FBSyxDQUFDLENBQUM7QUFDL0QsQ0FBQztBQUVELE1BQU0sQ0FBQyxNQUFNLGFBQWEsR0FBaUI7SUFDekMsVUFBVSxFQUFFLE9BQU87SUFDbkIsV0FBVyxFQUFFLE9BQU87SUFDcEIsVUFBVSxFQUFFLE9BQWdDO0NBQzdDLENBQUMiLCJzb3VyY2VzQ29udGVudCI6WyIvKipcbiAqIEBsaWNlbnNlXG4gKiBDb3B5cmlnaHQgMjAyMCBHb29nbGUgTExDLiBBbGwgUmlnaHRzIFJlc2VydmVkLlxuICogTGljZW5zZWQgdW5kZXIgdGhlIEFwYWNoZSBMaWNlbnNlLCBWZXJzaW9uIDIuMCAodGhlIFwiTGljZW5zZVwiKTtcbiAqIHlvdSBtYXkgbm90IHVzZSB0aGlzIGZpbGUgZXhjZXB0IGluIGNvbXBsaWFuY2Ugd2l0aCB0aGUgTGljZW5zZS5cbiAqIFlvdSBtYXkgb2J0YWluIGEgY29weSBvZiB0aGUgTGljZW5zZSBhdFxuICpcbiAqIGh0dHA6Ly93d3cuYXBhY2hlLm9yZy9saWNlbnNlcy9MSUNFTlNFLTIuMFxuICpcbiAqIFVubGVzcyByZXF1aXJlZCBieSBhcHBsaWNhYmxlIGxhdyBvciBhZ3JlZWQgdG8gaW4gd3JpdGluZywgc29mdHdhcmVcbiAqIGRpc3RyaWJ1dGVkIHVuZGVyIHRoZSBMaWNlbnNlIGlzIGRpc3RyaWJ1dGVkIG9uIGFuIFwiQVMgSVNcIiBCQVNJUyxcbiAqIFdJVEhPVVQgV0FSUkFOVElFUyBPUiBDT05ESVRJT05TIE9GIEFOWSBLSU5ELCBlaXRoZXIgZXhwcmVzcyBvciBpbXBsaWVkLlxuICogU2VlIHRoZSBMaWNlbnNlIGZvciB0aGUgc3BlY2lmaWMgbGFuZ3VhZ2UgZ292ZXJuaW5nIHBlcm1pc3Npb25zIGFuZFxuICogbGltaXRhdGlvbnMgdW5kZXIgdGhlIExpY2Vuc2UuXG4gKiA9PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PVxuICovXG5pbXBvcnQge2JhY2tlbmRfdXRpbCwgS2VybmVsQ29uZmlnLCBLZXJuZWxGdW5jLCBNYXhQb29sLCBNYXhQb29sQXR0cnMsIE1heFBvb2xJbnB1dHMsIFRlbnNvckluZm8sIHV0aWx9IGZyb20gJ0B0ZW5zb3JmbG93L3RmanMtY29yZSc7XG5cbmltcG9ydCB7TWF0aEJhY2tlbmRXZWJHTH0gZnJvbSAnLi4vYmFja2VuZF93ZWJnbCc7XG5pbXBvcnQge1Bvb2wyRFByb2dyYW19IGZyb20gJy4uL3Bvb2xfZ3B1JztcbmltcG9ydCB7YXNzZXJ0Tm90Q29tcGxleH0gZnJvbSAnLi4vd2ViZ2xfdXRpbCc7XG5pbXBvcnQge2lkZW50aXR5fSBmcm9tICcuL0lkZW50aXR5JztcblxuZXhwb3J0IGZ1bmN0aW9uIG1heFBvb2woYXJnczoge1xuICBpbnB1dHM6IE1heFBvb2xJbnB1dHMsXG4gIGJhY2tlbmQ6IE1hdGhCYWNrZW5kV2
ViR0wsXG4gIGF0dHJzOiBNYXhQb29sQXR0cnNcbn0pOiBUZW5zb3JJbmZvIHtcbiAgY29uc3Qge2lucHV0cywgYmFja2VuZCwgYXR0cnN9ID0gYXJncztcbiAgY29uc3Qge3h9ID0gaW5wdXRzO1xuICBhc3NlcnROb3RDb21wbGV4KHgsICdtYXhQb29sJyk7XG4gIGNvbnN0IHtmaWx0ZXJTaXplLCBzdHJpZGVzLCBwYWQsIGRpbVJvdW5kaW5nTW9kZX0gPSBhdHRycztcbiAgY29uc3QgZGlsYXRpb25zID0gMTtcblxuICB1dGlsLmFzc2VydChcbiAgICAgIGJhY2tlbmRfdXRpbC5laXRoZXJTdHJpZGVzT3JEaWxhdGlvbnNBcmVPbmUoc3RyaWRlcywgZGlsYXRpb25zKSxcbiAgICAgICgpID0+ICdFcnJvciBpbiBtYXhQb29sOiBFaXRoZXIgc3RyaWRlcyBvciBkaWxhdGlvbnMgbXVzdCBiZSAxLiAnICtcbiAgICAgICAgICBgR290IHN0cmlkZXMgJHtzdHJpZGVzfSBhbmQgZGlsYXRpb25zICcke2RpbGF0aW9uc30nYCk7XG5cbiAgY29uc3QgY29udkluZm8gPSBiYWNrZW5kX3V0aWwuY29tcHV0ZVBvb2wyREluZm8oXG4gICAgICB4LnNoYXBlIGFzIFtudW1iZXIsIG51bWJlciwgbnVtYmVyLCBudW1iZXJdLCBmaWx0ZXJTaXplLCBzdHJpZGVzLFxuICAgICAgZGlsYXRpb25zLCBwYWQsIGRpbVJvdW5kaW5nTW9kZSk7XG4gIGlmIChjb252SW5mby5maWx0ZXJXaWR0aCA9PT0gMSAmJiBjb252SW5mby5maWx0ZXJIZWlnaHQgPT09IDEgJiZcbiAgICAgIHV0aWwuYXJyYXlzRXF1YWwoY29udkluZm8uaW5TaGFwZSwgY29udkluZm8ub3V0U2hhcGUpKSB7XG4gICAgcmV0dXJuIGlkZW50aXR5KHtpbnB1dHM6IHt4fSwgYmFja2VuZH0pO1xuICB9XG4gIGNvbnN0IG1heFBvb2xQcm9ncmFtID0gbmV3IFBvb2wyRFByb2dyYW0oY29udkluZm8sICdtYXgnLCBmYWxzZSk7XG4gIHJldHVybiBiYWNrZW5kLnJ1bldlYkdMUHJvZ3JhbShtYXhQb29sUHJvZ3JhbSwgW3hdLCB4LmR0eXBlKTtcbn1cblxuZXhwb3J0IGNvbnN0IG1heFBvb2xDb25maWc6IEtlcm5lbENvbmZpZyA9IHtcbiAga2VybmVsTmFtZTogTWF4UG9vbCxcbiAgYmFja2VuZE5hbWU6ICd3ZWJnbCcsXG4gIGtlcm5lbEZ1bmM6IG1heFBvb2wgYXMgdW5rbm93biBhcyBLZXJuZWxGdW5jXG59O1xuIl19