gx
chenyc
2025-06-12 7b72ac13a83764a662159d4a49b7fffb90476ecb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
/**
 * @license
 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
import { backend_util, Conv2DBackpropInput, env } from '@tensorflow/tfjs-core';
import { Conv2DDerInputProgram } from '../conv_backprop_gpu';
import { Conv2DDerInputPackedProgram } from '../conv_backprop_packed_gpu';
/**
 * WebGL kernel function for `Conv2DBackpropInput`.
 *
 * Builds the convolution info from the requested input shape and the filter
 * shape (dilation fixed to 1), then runs one of two shader programs over
 * `[dy, filter]` with a 'float32' output:
 *  - the packed program, when the `WEBGL_PACK_CONV2DTRANSPOSE` flag is on and
 *    the converted data format is 'channelsLast'; the strides are forwarded
 *    to it as custom values;
 *  - the unpacked program otherwise.
 *
 * @param {Object} args
 * @param {Object} args.inputs - Holds `dy` and `filter`.
 * @param {Object} args.backend - Backend exposing `runWebGLProgram`.
 * @param {Object} args.attrs - Holds `inputShape`, `strides`, `pad`,
 *     `dataFormat` and `dimRoundingMode`.
 * @returns Whatever `backend.runWebGLProgram` returns for the chosen program.
 */
export function conv2DBackpropInput(args) {
    const { dy, filter } = args.inputs;
    const webglBackend = args.backend;
    const { inputShape, strides, pad, dataFormat, dimRoundingMode } = args.attrs;

    const layout = backend_util.convertConv2DDataFormat(dataFormat);
    const info = backend_util.computeConv2DInfo(
        inputShape,
        filter.shape,
        strides,
        1 /* dilations */,
        pad,
        dimRoundingMode,
        false,
        layout);

    // The flag is queried first, exactly as a short-circuiting `&&` would.
    const packedEnabled = env().getBool('WEBGL_PACK_CONV2DTRANSPOSE');
    const usePacked = packedEnabled && layout === 'channelsLast';

    if (!usePacked) {
        return webglBackend.runWebGLProgram(new Conv2DDerInputProgram(info), [dy, filter], 'float32');
    }

    // The packed program receives the strides as custom values.
    const strideValues = [[info.strideHeight, info.strideWidth]];
    const packedProgram = new Conv2DDerInputPackedProgram(info);
    return webglBackend.runWebGLProgram(packedProgram, [dy, filter], 'float32', strideValues);
}
/**
 * Kernel configuration that associates the `Conv2DBackpropInput` kernel name
 * with the `conv2DBackpropInput` implementation for the 'webgl' backend.
 * NOTE(review): presumably consumed by the kernel registry elsewhere — confirm.
 */
export const conv2DBackpropInputConfig = {
    kernelName: Conv2DBackpropInput,
    backendName: 'webgl',
    kernelFunc: conv2DBackpropInput,
};
//# sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoiQ29udjJEQmFja3Byb3BJbnB1dC5qcyIsInNvdXJjZVJvb3QiOiIiLCJzb3VyY2VzIjpbIi4uLy4uLy4uLy4uLy4uLy4uL3RmanMtYmFja2VuZC13ZWJnbC9zcmMva2VybmVscy9Db252MkRCYWNrcHJvcElucHV0LnRzIl0sIm5hbWVzIjpbXSwibWFwcGluZ3MiOiJBQUFBOzs7Ozs7Ozs7Ozs7Ozs7R0FlRztBQUVILE9BQU8sRUFBQyxZQUFZLEVBQUUsbUJBQW1CLEVBQXVELEdBQUcsRUFBMkIsTUFBTSx1QkFBdUIsQ0FBQztBQUc1SixPQUFPLEVBQUMscUJBQXFCLEVBQUMsTUFBTSxzQkFBc0IsQ0FBQztBQUMzRCxPQUFPLEVBQUMsMkJBQTJCLEVBQUMsTUFBTSw2QkFBNkIsQ0FBQztBQUV4RSxNQUFNLFVBQVUsbUJBQW1CLENBQUMsSUFJbkM7SUFDQyxNQUFNLEVBQUMsTUFBTSxFQUFFLE9BQU8sRUFBRSxLQUFLLEVBQUMsR0FBRyxJQUFJLENBQUM7SUFDdEMsTUFBTSxFQUFDLEVBQUUsRUFBRSxNQUFNLEVBQUMsR0FBRyxNQUFNLENBQUM7SUFDNUIsTUFBTSxFQUFDLFVBQVUsRUFBRSxPQUFPLEVBQUUsR0FBRyxFQUFFLFVBQVUsRUFBRSxlQUFlLEVBQUMsR0FBRyxLQUFLLENBQUM7SUFFdEUsTUFBTSxXQUFXLEdBQUcsWUFBWSxDQUFDLHVCQUF1QixDQUFDLFVBQVUsQ0FBQyxDQUFDO0lBQ3JFLE1BQU0sUUFBUSxHQUFHLFlBQVksQ0FBQyxpQkFBaUIsQ0FDM0MsVUFBVSxFQUFFLE1BQU0sQ0FBQyxLQUF5QyxFQUFFLE9BQU8sRUFDckUsQ0FBQyxDQUFDLGVBQWUsRUFBRSxHQUFHLEVBQUUsZUFBZSxFQUFFLEtBQUssRUFBRSxXQUFXLENBQUMsQ0FBQztJQUVqRSxJQUFJLEdBQUcsRUFBRSxDQUFDLE9BQU8sQ0FBQyw0QkFBNEIsQ0FBQztRQUMzQyxXQUFXLEtBQUssY0FBYyxFQUFFO1FBQ2xDLE1BQU0sWUFBWSxHQUFHO1lBQ25CLENBQUMsUUFBUSxDQUFDLFlBQVksRUFBRSxRQUFRLENBQUMsV0FBVyxDQUFDO1NBQzlDLENBQUM7UUFDRixNQUFNLE9BQU8sR0FBRyxJQUFJLDJCQUEyQixDQUFDLFFBQVEsQ0FBQyxDQUFDO1FBQzFELE9BQU8sT0FBTyxDQUFDLGVBQWUsQ0FDMUIsT0FBTyxFQUFFLENBQUMsRUFBRSxFQUFFLE1BQU0sQ0FBQyxFQUFFLFNBQVMsRUFBRSxZQUFZLENBQUMsQ0FBQztLQUNyRDtTQUFNO1FBQ0wsTUFBTSxPQUFPLEdBQUcsSUFBSSxxQkFBcUIsQ0FBQyxRQUFRLENBQUMsQ0FBQztRQUNwRCxPQUFPLE9BQU8sQ0FBQyxlQUFlLENBQUMsT0FBTyxFQUFFLENBQUMsRUFBRSxFQUFFLE1BQU0sQ0FBQyxFQUFFLFNBQVMsQ0FBQyxDQUFDO0tBQ2xFO0FBQ0gsQ0FBQztBQUVELE1BQU0sQ0FBQyxNQUFNLHlCQUF5QixHQUFpQjtJQUNyRCxVQUFVLEVBQUUsbUJBQW1CO0lBQy9CLFdBQVcsRUFBRSxPQUFPO0lBQ3BCLFVBQVUsRUFBRSxtQkFBNEM7Q0FDekQsQ0FBQyIsInNvdXJjZXNDb250ZW50IjpbIi8qKlxuICogQGxpY2Vuc2VcbiAqIENvcHlyaWdodCAyMDIwIEdvb2dsZSBMTEMuIEFsbCBSaWdodHMgUmVzZXJ2ZWQuXG4gKiBMaW
NlbnNlZCB1bmRlciB0aGUgQXBhY2hlIExpY2Vuc2UsIFZlcnNpb24gMi4wICh0aGUgXCJMaWNlbnNlXCIpO1xuICogeW91IG1heSBub3QgdXNlIHRoaXMgZmlsZSBleGNlcHQgaW4gY29tcGxpYW5jZSB3aXRoIHRoZSBMaWNlbnNlLlxuICogWW91IG1heSBvYnRhaW4gYSBjb3B5IG9mIHRoZSBMaWNlbnNlIGF0XG4gKlxuICogaHR0cDovL3d3dy5hcGFjaGUub3JnL2xpY2Vuc2VzL0xJQ0VOU0UtMi4wXG4gKlxuICogVW5sZXNzIHJlcXVpcmVkIGJ5IGFwcGxpY2FibGUgbGF3IG9yIGFncmVlZCB0byBpbiB3cml0aW5nLCBzb2Z0d2FyZVxuICogZGlzdHJpYnV0ZWQgdW5kZXIgdGhlIExpY2Vuc2UgaXMgZGlzdHJpYnV0ZWQgb24gYW4gXCJBUyBJU1wiIEJBU0lTLFxuICogV0lUSE9VVCBXQVJSQU5USUVTIE9SIENPTkRJVElPTlMgT0YgQU5ZIEtJTkQsIGVpdGhlciBleHByZXNzIG9yIGltcGxpZWQuXG4gKiBTZWUgdGhlIExpY2Vuc2UgZm9yIHRoZSBzcGVjaWZpYyBsYW5ndWFnZSBnb3Zlcm5pbmcgcGVybWlzc2lvbnMgYW5kXG4gKiBsaW1pdGF0aW9ucyB1bmRlciB0aGUgTGljZW5zZS5cbiAqID09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09XG4gKi9cblxuaW1wb3J0IHtiYWNrZW5kX3V0aWwsIENvbnYyREJhY2twcm9wSW5wdXQsIENvbnYyREJhY2twcm9wSW5wdXRBdHRycywgQ29udjJEQmFja3Byb3BJbnB1dElucHV0cywgZW52LCBLZXJuZWxDb25maWcsIEtlcm5lbEZ1bmN9IGZyb20gJ0B0ZW5zb3JmbG93L3RmanMtY29yZSc7XG5cbmltcG9ydCB7TWF0aEJhY2tlbmRXZWJHTH0gZnJvbSAnLi4vYmFja2VuZF93ZWJnbCc7XG5pbXBvcnQge0NvbnYyRERlcklucHV0UHJvZ3JhbX0gZnJvbSAnLi4vY29udl9iYWNrcHJvcF9ncHUnO1xuaW1wb3J0IHtDb252MkREZXJJbnB1dFBhY2tlZFByb2dyYW19IGZyb20gJy4uL2NvbnZfYmFja3Byb3BfcGFja2VkX2dwdSc7XG5cbmV4cG9ydCBmdW5jdGlvbiBjb252MkRCYWNrcHJvcElucHV0KGFyZ3M6IHtcbiAgaW5wdXRzOiBDb252MkRCYWNrcHJvcElucHV0SW5wdXRzLFxuICBhdHRyczogQ29udjJEQmFja3Byb3BJbnB1dEF0dHJzLFxuICBiYWNrZW5kOiBNYXRoQmFja2VuZFdlYkdMXG59KSB7XG4gIGNvbnN0IHtpbnB1dHMsIGJhY2tlbmQsIGF0dHJzfSA9IGFyZ3M7XG4gIGNvbnN0IHtkeSwgZmlsdGVyfSA9IGlucHV0cztcbiAgY29uc3Qge2lucHV0U2hhcGUsIHN0cmlkZXMsIHBhZCwgZGF0YUZvcm1hdCwgZGltUm91bmRpbmdNb2RlfSA9IGF0dHJzO1xuXG4gIGNvbnN0ICRkYXRhRm9ybWF0ID0gYmFja2VuZF91dGlsLmNvbnZlcnRDb252MkREYXRhRm9ybWF0KGRhdGFGb3JtYXQpO1xuICBjb25zdCBjb252SW5mbyA9IGJhY2tlbmRfdXRpbC5jb21wdXRlQ29udjJESW5mbyhcbiAgICAgIGlucHV0U2hhcGUsIGZpbHRlci5zaGFwZSBhcyBbbnVtYmVyLCBudW1iZXIsIG51bWJlciwgbnVtYmVyXSwgc3RyaW
RlcyxcbiAgICAgIDEgLyogZGlsYXRpb25zICovLCBwYWQsIGRpbVJvdW5kaW5nTW9kZSwgZmFsc2UsICRkYXRhRm9ybWF0KTtcblxuICBpZiAoZW52KCkuZ2V0Qm9vbCgnV0VCR0xfUEFDS19DT05WMkRUUkFOU1BPU0UnKSAmJlxuICAgICAgJGRhdGFGb3JtYXQgPT09ICdjaGFubmVsc0xhc3QnKSB7XG4gICAgY29uc3QgY3VzdG9tVmFsdWVzID0gW1xuICAgICAgW2NvbnZJbmZvLnN0cmlkZUhlaWdodCwgY29udkluZm8uc3RyaWRlV2lkdGhdLFxuICAgIF07XG4gICAgY29uc3QgcHJvZ3JhbSA9IG5ldyBDb252MkREZXJJbnB1dFBhY2tlZFByb2dyYW0oY29udkluZm8pO1xuICAgIHJldHVybiBiYWNrZW5kLnJ1bldlYkdMUHJvZ3JhbShcbiAgICAgICAgcHJvZ3JhbSwgW2R5LCBmaWx0ZXJdLCAnZmxvYXQzMicsIGN1c3RvbVZhbHVlcyk7XG4gIH0gZWxzZSB7XG4gICAgY29uc3QgcHJvZ3JhbSA9IG5ldyBDb252MkREZXJJbnB1dFByb2dyYW0oY29udkluZm8pO1xuICAgIHJldHVybiBiYWNrZW5kLnJ1bldlYkdMUHJvZ3JhbShwcm9ncmFtLCBbZHksIGZpbHRlcl0sICdmbG9hdDMyJyk7XG4gIH1cbn1cblxuZXhwb3J0IGNvbnN0IGNvbnYyREJhY2twcm9wSW5wdXRDb25maWc6IEtlcm5lbENvbmZpZyA9IHtcbiAga2VybmVsTmFtZTogQ29udjJEQmFja3Byb3BJbnB1dCxcbiAgYmFja2VuZE5hbWU6ICd3ZWJnbCcsXG4gIGtlcm5lbEZ1bmM6IGNvbnYyREJhY2twcm9wSW5wdXQgYXMgdW5rbm93biBhcyBLZXJuZWxGdW5jLFxufTtcbiJdfQ==