gx
chenyc
2025-02-12 ea42ff3ebee1eeb3fb29423aa848a249441db81c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
/**
 * @license
 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
import { backend_util, MaxPoolGrad } from '@tensorflow/tfjs-core';
import { MaxPool2DBackpropProgram } from '../max_pool_backprop_gpu';
import { Pool2DProgram } from '../pool_gpu';
import { assertNotComplex } from '../webgl_util';
/**
 * WebGL kernel function for `MaxPoolGrad`.
 *
 * Runs two GPU programs in sequence:
 *   1. a `Pool2DProgram` in 'max' mode with its `getPositions` flag enabled,
 *      executed over the forward-pass input;
 *   2. a `MaxPool2DBackpropProgram`, executed over `dy` and the output of
 *      step 1.
 * The output of step 1 is an intermediate and is always released before this
 * function exits, whether step 2 succeeds or throws.
 *
 * @param {Object} args Kernel arguments.
 * @param {Object} args.inputs `{dy, input, output}`. `input` and `output` are
 *     checked with `assertNotComplex`; `output` is not otherwise read here.
 * @param {Object} args.backend Backend providing `runWebGLProgram` and
 *     `disposeIntermediateTensorInfo`.
 * @param {Object} args.attrs `{filterSize, strides, pad, dimRoundingMode}`,
 *     forwarded to `backend_util.computePool2DInfo`.
 * @returns {Object} The tensor info returned by the backprop program, requested
 *     with the dtype of `input`.
 */
export function maxPoolGrad(args) {
    const { inputs, backend, attrs } = args;
    const { dy, input, output } = inputs;
    const x = input;
    assertNotComplex([input, output], 'maxPoolGrad');
    const { filterSize, strides, pad, dimRoundingMode } = attrs;
    const convInfo = backend_util.computePool2DInfo(x.shape, filterSize, strides, 1 /* dilations */, pad, dimRoundingMode);
    const getPositions = true;
    const maxPoolPositionsProgram = new Pool2DProgram(convInfo, 'max', getPositions);
    const maxPoolPositions = backend.runWebGLProgram(maxPoolPositionsProgram, [x], x.dtype);
    try {
        const maxPoolBackPropProgram = new MaxPool2DBackpropProgram(convInfo);
        return backend.runWebGLProgram(maxPoolBackPropProgram, [dy, maxPoolPositions], x.dtype);
    }
    finally {
        // Release the intermediate even when the backprop stage throws;
        // otherwise the positions tensor would never be disposed.
        backend.disposeIntermediateTensorInfo(maxPoolPositions);
    }
}
/**
 * Kernel registration record: associates the `MaxPoolGrad` kernel name with
 * the `maxPoolGrad` implementation for the 'webgl' backend.
 */
export const maxPoolGradConfig = {
    kernelName: MaxPoolGrad,
    backendName: 'webgl',
    kernelFunc: maxPoolGrad,
};
//# sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoiTWF4UG9vbEdyYWQuanMiLCJzb3VyY2VSb290IjoiIiwic291cmNlcyI6WyIuLi8uLi8uLi8uLi8uLi8uLi90ZmpzLWJhY2tlbmQtd2ViZ2wvc3JjL2tlcm5lbHMvTWF4UG9vbEdyYWQudHMiXSwibmFtZXMiOltdLCJtYXBwaW5ncyI6IkFBQUE7Ozs7Ozs7Ozs7Ozs7OztHQWVHO0FBQ0gsT0FBTyxFQUFDLFlBQVksRUFBNEIsV0FBVyxFQUFrRCxNQUFNLHVCQUF1QixDQUFDO0FBRzNJLE9BQU8sRUFBQyx3QkFBd0IsRUFBQyxNQUFNLDBCQUEwQixDQUFDO0FBQ2xFLE9BQU8sRUFBQyxhQUFhLEVBQUMsTUFBTSxhQUFhLENBQUM7QUFDMUMsT0FBTyxFQUFDLGdCQUFnQixFQUFDLE1BQU0sZUFBZSxDQUFDO0FBRS9DLE1BQU0sVUFBVSxXQUFXLENBQUMsSUFJM0I7SUFDQyxNQUFNLEVBQUMsTUFBTSxFQUFFLE9BQU8sRUFBRSxLQUFLLEVBQUMsR0FBRyxJQUFJLENBQUM7SUFDdEMsTUFBTSxFQUFDLEVBQUUsRUFBRSxLQUFLLEVBQUUsTUFBTSxFQUFDLEdBQUcsTUFBTSxDQUFDO0lBQ25DLE1BQU0sQ0FBQyxHQUFHLEtBQUssQ0FBQztJQUNoQixnQkFBZ0IsQ0FBQyxDQUFDLEtBQUssRUFBRSxNQUFNLENBQUMsRUFBRSxhQUFhLENBQUMsQ0FBQztJQUNqRCxNQUFNLEVBQUMsVUFBVSxFQUFFLE9BQU8sRUFBRSxHQUFHLEVBQUUsZUFBZSxFQUFDLEdBQUcsS0FBSyxDQUFDO0lBRTFELE1BQU0sUUFBUSxHQUFHLFlBQVksQ0FBQyxpQkFBaUIsQ0FDM0MsQ0FBQyxDQUFDLEtBQXlDLEVBQUUsVUFBVSxFQUFFLE9BQU8sRUFDaEUsQ0FBQyxDQUFDLGVBQWUsRUFBRSxHQUFHLEVBQUUsZUFBZSxDQUFDLENBQUM7SUFDN0MsTUFBTSxZQUFZLEdBQUcsSUFBSSxDQUFDO0lBQzFCLE1BQU0sdUJBQXVCLEdBQ3pCLElBQUksYUFBYSxDQUFDLFFBQVEsRUFBRSxLQUFLLEVBQUUsWUFBWSxDQUFDLENBQUM7SUFDckQsTUFBTSxnQkFBZ0IsR0FDbEIsT0FBTyxDQUFDLGVBQWUsQ0FBQyx1QkFBdUIsRUFBRSxDQUFDLENBQUMsQ0FBQyxFQUFFLENBQUMsQ0FBQyxLQUFLLENBQUMsQ0FBQztJQUVuRSxNQUFNLHNCQUFzQixHQUFHLElBQUksd0JBQXdCLENBQUMsUUFBUSxDQUFDLENBQUM7SUFDdEUsTUFBTSxNQUFNLEdBQUcsT0FBTyxDQUFDLGVBQWUsQ0FDbEMsc0JBQXNCLEVBQUUsQ0FBQyxFQUFFLEVBQUUsZ0JBQWdCLENBQUMsRUFBRSxDQUFDLENBQUMsS0FBSyxDQUFDLENBQUM7SUFDN0QsT0FBTyxDQUFDLDZCQUE2QixDQUFDLGdCQUFnQixDQUFDLENBQUM7SUFDeEQsT0FBTyxNQUFNLENBQUM7QUFDaEIsQ0FBQztBQUVELE1BQU0sQ0FBQyxNQUFNLGlCQUFpQixHQUFpQjtJQUM3QyxVQUFVLEVBQUUsV0FBVztJQUN2QixXQUFXLEVBQUUsT0FBTztJQUNwQixVQUFVLEVBQUUsV0FBb0M7Q0FDakQsQ0FBQyIsInNvdXJjZXNDb250ZW50IjpbIi8qKlxuICogQGxpY2Vuc2VcbiAqIENvcHlyaWdodCAyMDIwIEdvb2dsZSBMTEMuIEFsbCBSaWdodHMgUmVzZXJ2ZWQuXG4gKiBMaW
NlbnNlZCB1bmRlciB0aGUgQXBhY2hlIExpY2Vuc2UsIFZlcnNpb24gMi4wICh0aGUgXCJMaWNlbnNlXCIpO1xuICogeW91IG1heSBub3QgdXNlIHRoaXMgZmlsZSBleGNlcHQgaW4gY29tcGxpYW5jZSB3aXRoIHRoZSBMaWNlbnNlLlxuICogWW91IG1heSBvYnRhaW4gYSBjb3B5IG9mIHRoZSBMaWNlbnNlIGF0XG4gKlxuICogaHR0cDovL3d3dy5hcGFjaGUub3JnL2xpY2Vuc2VzL0xJQ0VOU0UtMi4wXG4gKlxuICogVW5sZXNzIHJlcXVpcmVkIGJ5IGFwcGxpY2FibGUgbGF3IG9yIGFncmVlZCB0byBpbiB3cml0aW5nLCBzb2Z0d2FyZVxuICogZGlzdHJpYnV0ZWQgdW5kZXIgdGhlIExpY2Vuc2UgaXMgZGlzdHJpYnV0ZWQgb24gYW4gXCJBUyBJU1wiIEJBU0lTLFxuICogV0lUSE9VVCBXQVJSQU5USUVTIE9SIENPTkRJVElPTlMgT0YgQU5ZIEtJTkQsIGVpdGhlciBleHByZXNzIG9yIGltcGxpZWQuXG4gKiBTZWUgdGhlIExpY2Vuc2UgZm9yIHRoZSBzcGVjaWZpYyBsYW5ndWFnZSBnb3Zlcm5pbmcgcGVybWlzc2lvbnMgYW5kXG4gKiBsaW1pdGF0aW9ucyB1bmRlciB0aGUgTGljZW5zZS5cbiAqID09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09XG4gKi9cbmltcG9ydCB7YmFja2VuZF91dGlsLCBLZXJuZWxDb25maWcsIEtlcm5lbEZ1bmMsIE1heFBvb2xHcmFkLCBNYXhQb29sR3JhZEF0dHJzLCBNYXhQb29sR3JhZElucHV0cywgVGVuc29ySW5mb30gZnJvbSAnQHRlbnNvcmZsb3cvdGZqcy1jb3JlJztcblxuaW1wb3J0IHtNYXRoQmFja2VuZFdlYkdMfSBmcm9tICcuLi9iYWNrZW5kX3dlYmdsJztcbmltcG9ydCB7TWF4UG9vbDJEQmFja3Byb3BQcm9ncmFtfSBmcm9tICcuLi9tYXhfcG9vbF9iYWNrcHJvcF9ncHUnO1xuaW1wb3J0IHtQb29sMkRQcm9ncmFtfSBmcm9tICcuLi9wb29sX2dwdSc7XG5pbXBvcnQge2Fzc2VydE5vdENvbXBsZXh9IGZyb20gJy4uL3dlYmdsX3V0aWwnO1xuXG5leHBvcnQgZnVuY3Rpb24gbWF4UG9vbEdyYWQoYXJnczoge1xuICBpbnB1dHM6IE1heFBvb2xHcmFkSW5wdXRzLFxuICBiYWNrZW5kOiBNYXRoQmFja2VuZFdlYkdMLFxuICBhdHRyczogTWF4UG9vbEdyYWRBdHRyc1xufSk6IFRlbnNvckluZm8ge1xuICBjb25zdCB7aW5wdXRzLCBiYWNrZW5kLCBhdHRyc30gPSBhcmdzO1xuICBjb25zdCB7ZHksIGlucHV0LCBvdXRwdXR9ID0gaW5wdXRzO1xuICBjb25zdCB4ID0gaW5wdXQ7XG4gIGFzc2VydE5vdENvbXBsZXgoW2lucHV0LCBvdXRwdXRdLCAnbWF4UG9vbEdyYWQnKTtcbiAgY29uc3Qge2ZpbHRlclNpemUsIHN0cmlkZXMsIHBhZCwgZGltUm91bmRpbmdNb2RlfSA9IGF0dHJzO1xuXG4gIGNvbnN0IGNvbnZJbmZvID0gYmFja2VuZF91dGlsLmNvbXB1dGVQb29sMkRJbmZvKFxuICAgICAgeC5zaGFwZSBhcyBbbnVtYmVyLCBudW1iZXIsIG51bWJlciwgbnVtYmVyXSwgZmlsdGVyU2l6ZSwgc3RyaWRlcyxcbiAgICAgIDEgLy
ogZGlsYXRpb25zICovLCBwYWQsIGRpbVJvdW5kaW5nTW9kZSk7XG4gIGNvbnN0IGdldFBvc2l0aW9ucyA9IHRydWU7XG4gIGNvbnN0IG1heFBvb2xQb3NpdGlvbnNQcm9ncmFtID1cbiAgICAgIG5ldyBQb29sMkRQcm9ncmFtKGNvbnZJbmZvLCAnbWF4JywgZ2V0UG9zaXRpb25zKTtcbiAgY29uc3QgbWF4UG9vbFBvc2l0aW9uczogVGVuc29ySW5mbyA9XG4gICAgICBiYWNrZW5kLnJ1bldlYkdMUHJvZ3JhbShtYXhQb29sUG9zaXRpb25zUHJvZ3JhbSwgW3hdLCB4LmR0eXBlKTtcblxuICBjb25zdCBtYXhQb29sQmFja1Byb3BQcm9ncmFtID0gbmV3IE1heFBvb2wyREJhY2twcm9wUHJvZ3JhbShjb252SW5mbyk7XG4gIGNvbnN0IHJlc3VsdCA9IGJhY2tlbmQucnVuV2ViR0xQcm9ncmFtKFxuICAgICAgbWF4UG9vbEJhY2tQcm9wUHJvZ3JhbSwgW2R5LCBtYXhQb29sUG9zaXRpb25zXSwgeC5kdHlwZSk7XG4gIGJhY2tlbmQuZGlzcG9zZUludGVybWVkaWF0ZVRlbnNvckluZm8obWF4UG9vbFBvc2l0aW9ucyk7XG4gIHJldHVybiByZXN1bHQ7XG59XG5cbmV4cG9ydCBjb25zdCBtYXhQb29sR3JhZENvbmZpZzogS2VybmVsQ29uZmlnID0ge1xuICBrZXJuZWxOYW1lOiBNYXhQb29sR3JhZCxcbiAgYmFja2VuZE5hbWU6ICd3ZWJnbCcsXG4gIGtlcm5lbEZ1bmM6IG1heFBvb2xHcmFkIGFzIHVua25vd24gYXMgS2VybmVsRnVuY1xufTtcbiJdfQ==