/**
 * @license
 * Copyright 2018 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
import { backend_util, util } from '@tensorflow/tfjs-core';
import { useShapeUniforms } from './gpgpu_math';
import { getChannels } from './packing_util';
import { getCoordsDataType } from './shader_compiler';

// NOTE(review): this is compiled output; the inline source map at the end of
// the file names ../../../../../tfjs-backend-webgl/src/binaryop_packed_gpu.ts
// as its source, so lasting changes belong there.
//
// The GLSL text inside the template literals below is kept verbatim. Where a
// `\n` escape appears, the emitted text contains a line break at that point.

/**
 * NaN-propagation statements for a packed `result` texel: each lane of
 * `result` becomes NAN when the matching lane of `isNaN` is set.
 * `isNaN` and `NAN` are not declared in this snippet; presumably the op code
 * or shader prelude that splices it in provides them — confirm at call sites.
 */
export const CHECK_NAN_SNIPPET_PACKED = ` result.r = isNaN.r ? NAN : result.r; result.g = isNaN.g ? NAN : result.g; result.b = isNaN.b ? NAN : result.b; result.a = isNaN.a ?\nNAN : result.a; `;

/**
 * Body for `vec4 binaryOperation(vec4 a, vec4 b)` (ELU_DER): per lane, the
 * result is `a` where `b >= 0.` and `a * (b + 1.)` elsewhere.
 */
export const ELU_DER = ` vec4 bGTEZero = vec4(greaterThanEqual(b, vec4(0.))); return (bGTEZero * a) + ((vec4(1.0) - bGTEZero) * (a * (b + vec4(1.0)))); `;

/**
 * Body for `vec4 binaryOperation(vec4 a, vec4 b)`: per lane, the boolean
 * `a != b` converted to a float.
 */
export const NOT_EQUAL = ` return vec4(notEqual(a, b)); `;

/**
 * Produces the GLSL statements that zero the lanes of a packed `result`
 * texel lying beyond the edge of the output tensor.
 *
 * @param {number[]} outputShape The program's (broadcast) output shape.
 * @param {boolean} useUniforms When true, sizes are read from the `outShape`
 *     shader uniform; otherwise the concrete sizes are written into the text.
 * @returns {string} GLSL statements to run after `result` is computed.
 */
function buildOutOfBoundsGuard(outputShape, useUniforms) {
  const rank = outputShape.length;

  // Scalar or single-element output: keep lane x, clear y, z and w.
  if (rank === 0 || util.sizeFromShape(outputShape) === 1) {
    return ` result.y = 0.; result.z = 0.; result.w = 0.; `;
  }

  const coordsDecl = ` ${getCoordsDataType(rank)} coords = getOutputCoords(); `;

  // Rank 1: clear y when coords + 1 is past the end; z and w are always cleared.
  if (rank === 1) {
    const limit = useUniforms ? 'outShape' : outputShape[0];
    return coordsDecl +
      ` result.y = (coords + 1) >= ${limit} ? 0. : result.y; result.z = 0.; result.w = 0.; `;
  }

  // Rank >= 2: y is cleared when the next column is out of range, z when the
  // next row is, and w when either is.
  const channels = getChannels('coords', rank);
  const rowCoord = channels[rank - 2];
  const colCoord = channels[rank - 1];
  if (useUniforms) {
    return coordsDecl +
      ` bool nextRowOutOfBounds = (${rowCoord} + 1) >= outShape[${rank} - 2]; bool nextColOutOfBounds = (${colCoord} + 1) >= outShape[${rank} - 1]; result.y = nextColOutOfBounds ? 0. : result.y; result.z = nextRowOutOfBounds ? 0. : result.z; result.w = nextColOutOfBounds || nextRowOutOfBounds ? 0. : result.w; `;
  }
  return coordsDecl +
    ` bool nextRowOutOfBounds = (${rowCoord} + 1) >= ${outputShape[rank - 2]}; bool nextColOutOfBounds = (${colCoord} + 1) >= ${outputShape[rank - 1]}; result.y = nextColOutOfBounds ? 0. : result.y; result.z = nextRowOutOfBounds ? 0.\n: result.z; result.w = nextColOutOfBounds || nextRowOutOfBounds ? 0. : result.w; `;
}

/**
 * GPGPU program that evaluates a caller-supplied GLSL binary operation on
 * packed inputs `A` and `B`, broadcast to a common output shape, and writes
 * a packed output.
 */
export class BinaryOpPackedProgram {
  /**
   * @param {string} op GLSL body of `vec4 binaryOperation(vec4 a, vec4 b)`.
   * @param {number[]} aShape Shape of input `A`.
   * @param {number[]} bShape Shape of input `B`.
   * @param {boolean} [checkOutOfBounds=false] When true, lanes of the packed
   *     result that fall outside the output shape are set to 0.
   */
  constructor(op, aShape, bShape, checkOutOfBounds = false) {
    this.variableNames = ['A', 'B'];
    this.supportsBroadcasting = true;
    this.packedInputs = true;
    this.packedOutput = true;
    this.outputShape = backend_util.assertAndGetBroadcastShape(aShape, bShape);
    this.enableShapeUniforms = useShapeUniforms(this.outputShape.length);

    const boundsGuard = checkOutOfBounds
      ? buildOutOfBoundsGuard(this.outputShape, this.enableShapeUniforms)
      : '';

    this.userCode = ` vec4 binaryOperation(vec4 a, vec4 b) { ${op} } void main() { vec4 a = getAAtOutCoords(); vec4 b = getBAtOutCoords(); vec4 result = binaryOperation(a, b); ${boundsGuard} setOutput(result); } `;
  }
}
//# sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoiYmluYXJ5b3BfcGFja2VkX2dwdS5qcyIsInNvdXJjZVJvb3QiOiIiLCJzb3VyY2VzIjpbIi4uLy4uLy4uLy4uLy4uL3RmanMtYmFja2VuZC13ZWJnbC9zcmMvYmluYXJ5b3BfcGFja2VkX2dwdS50cyJdLCJuYW1lcyI6W10sIm1hcHBpbmdzIjoiQUFBQTs7Ozs7Ozs7Ozs7Ozs7O0dBZUc7QUFFSCxPQUFPLEVBQUMsWUFBWSxFQUFFLElBQUksRUFBQyxNQUFNLHVCQUF1QixDQUFDO0FBRXpELE9BQU8sRUFBZSxnQkFBZ0IsRUFBQyxNQUFNLGNBQWMsQ0FBQztBQUM1RCxPQUFPLEVBQUMsV0FBVyxFQUFDLE1BQU0sZ0JBQWdCLENBQUM7QUFDM0MsT0FBTyxFQUFDLGlCQUFpQixFQUFDLE1BQU0sbUJBQW1CLENBQUM7QUFFcEQsTUFBTSxDQUFDLE1BQU0sd0JBQXdCLEdBQUc7Ozs7O0NBS3ZDLENBQUM7QUFFRixNQUFNLENBQUMsTUFBTSxPQUFPLEdBQUc7OztDQUd0QixDQUFDO0FBRUYsTUFBTSxDQUFDLE1BQU0sU0FBUyxHQUFHOztDQUV4QixDQUFDO0FBRUYsTUFBTSxPQUFPLHFCQUFxQjtJQVNoQyxZQUNJLEVBQVUsRUFBRSxNQUFnQixFQUFFLE1BQWdCLEVBQzlDLGdCQUFnQixHQUFHLEtBQUs7UUFWNUIsa0JBQWEsR0FBRyxDQUFDLEdBQUcsRUFBRSxHQUFHLENBQUMsQ0FBQztRQUczQix5QkFBb0IsR0FBRyxJQUFJLENBQUM7UUFDNUIsaUJBQVksR0FBRyxJQUFJLENBQUM7UUFDcEIsaUJBQVksR0FBRyxJQUFJLENBQUM7UUFNbEIsSUFBSSxDQUFDLFdBQVcsR0FBRyxZQUFZLENBQUMsMEJBQTBCLENBQUMsTUFBTSxFQUFFLE1BQU0sQ0FBQyxDQUFDO1FBQzNFLE1BQU0sSUFBSSxHQUFHLElBQUksQ0FBQyxXQUFXLENBQUMsTUFBTSxDQUFDO1FBQ3JDLElBQUksQ0FBQyxtQkFBbUIsR0FBRyxnQkFBZ0IsQ0FBQyxJQUFJLENBQUMsQ0FBQztRQUNsRCxJQUFJLHNCQUFzQixHQUFHLEVBQUUsQ0FBQztRQUNoQyxJQUFJLGdCQUFnQixFQUFFO1lBQ3BCLElBQUksSUFBSSxLQUFLLENBQUMsSUFBSSxJQUFJLENBQUMsYUFBYSxDQUFDLElBQUksQ0FBQyxXQUFXLENBQUMsS0FBSyxDQUFDLEVBQUU7Z0JBQzVELHNCQUFzQixHQUFHOzs7O1NBSXhCLENBQUM7YUFDSDtpQkFBTTtnQkFDTCxNQUFNLEtBQUssR0FBRyxpQkFBaUIsQ0FBQyxJQUFJLENBQUMsQ0FBQztnQkFDdEMsc0JBQXNCLEdBQUc7WUFDckIsS0FBSztTQUNSLENBQUM7Z0JBQ0YsSUFBSSxJQUFJLEtBQUssQ0FBQyxFQUFFO29CQUNkLElBQUksSUFBSSxDQUFDLG1CQUFtQixF
QUFFO3dCQUM1QixzQkFBc0IsSUFBSTs7OztXQUkzQixDQUFDO3FCQUNEO3lCQUFNO3dCQUNMLHNCQUFzQixJQUFJO3lDQUNHLElBQUksQ0FBQyxXQUFXLENBQUMsQ0FBQyxDQUFDOzs7V0FHakQsQ0FBQztxQkFDRDtpQkFDRjtxQkFBTTtvQkFDTCxNQUFNLFFBQVEsR0FBRyxXQUFXLENBQUMsUUFBUSxFQUFFLElBQUksQ0FBQyxDQUFDO29CQUM3QyxJQUFJLElBQUksQ0FBQyxtQkFBbUIsRUFBRTt3QkFDNUIsc0JBQXNCLElBQUk7O2lCQUVyQixRQUFRLENBQUMsSUFBSSxHQUFHLENBQUMsQ0FBQyxxQkFBcUIsSUFBSTs7aUJBRTNDLFFBQVEsQ0FBQyxJQUFJLEdBQUcsQ0FBQyxDQUFDLHFCQUFxQixJQUFJOzs7O1dBSWpELENBQUM7cUJBQ0Q7eUJBQU07d0JBQ0wsc0JBQXNCLElBQUk7O2lCQUVyQixRQUFRLENBQUMsSUFBSSxHQUFHLENBQUMsQ0FBQyxZQUFZLElBQUksQ0FBQyxXQUFXLENBQUMsSUFBSSxHQUFHLENBQUMsQ0FBQzs7aUJBRXhELFFBQVEsQ0FBQyxJQUFJLEdBQUcsQ0FBQyxDQUFDLFlBQVksSUFBSSxDQUFDLFdBQVcsQ0FBQyxJQUFJLEdBQUcsQ0FBQyxDQUFDOzs7O1dBSTlELENBQUM7cUJBQ0Q7aUJBQ0Y7YUFDRjtTQUNGO1FBRUQsSUFBSSxDQUFDLFFBQVEsR0FBRzs7VUFFVixFQUFFOzs7Ozs7OztVQVFGLHNCQUFzQjs7OztLQUkzQixDQUFDO0lBQ0osQ0FBQztDQUNGIiwic291cmNlc0NvbnRlbnQiOlsiLyoqXG4gKiBAbGljZW5zZVxuICogQ29weXJpZ2h0IDIwMTggR29vZ2xlIExMQy4gQWxsIFJpZ2h0cyBSZXNlcnZlZC5cbiAqIExpY2Vuc2VkIHVuZGVyIHRoZSBBcGFjaGUgTGljZW5zZSwgVmVyc2lvbiAyLjAgKHRoZSBcIkxpY2Vuc2VcIik7XG4gKiB5b3UgbWF5IG5vdCB1c2UgdGhpcyBmaWxlIGV4Y2VwdCBpbiBjb21wbGlhbmNlIHdpdGggdGhlIExpY2Vuc2UuXG4gKiBZb3UgbWF5IG9idGFpbiBhIGNvcHkgb2YgdGhlIExpY2Vuc2UgYXRcbiAqXG4gKiBodHRwOi8vd3d3LmFwYWNoZS5vcmcvbGljZW5zZXMvTElDRU5TRS0yLjBcbiAqXG4gKiBVbmxlc3MgcmVxdWlyZWQgYnkgYXBwbGljYWJsZSBsYXcgb3IgYWdyZWVkIHRvIGluIHdyaXRpbmcsIHNvZnR3YXJlXG4gKiBkaXN0cmlidXRlZCB1bmRlciB0aGUgTGljZW5zZSBpcyBkaXN0cmlidXRlZCBvbiBhbiBcIkFTIElTXCIgQkFTSVMsXG4gKiBXSVRIT1VUIFdBUlJBTlRJRVMgT1IgQ09ORElUSU9OUyBPRiBBTlkgS0lORCwgZWl0aGVyIGV4cHJlc3Mgb3IgaW1wbGllZC5cbiAqIFNlZSB0aGUgTGljZW5zZSBmb3IgdGhlIHNwZWNpZmljIGxhbmd1YWdlIGdvdmVybmluZyBwZXJtaXNzaW9ucyBhbmRcbiAqIGxpbWl0YXRpb25zIHVuZGVyIHRoZSBMaWNlbnNlLlxuICogPT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT1cbiAqL1xuXG5pbXBvcnQge2JhY2tlbmRfdXRpbCwgdXRpbH0gZnJvbSAnQHRlbnNvcmZsb3cvdGZqcy1jb3JlJztcblxuaW1wb3J0IHtHUEdQVVByb2dyYW0s
IHVzZVNoYXBlVW5pZm9ybXN9IGZyb20gJy4vZ3BncHVfbWF0aCc7XG5pbXBvcnQge2dldENoYW5uZWxzfSBmcm9tICcuL3BhY2tpbmdfdXRpbCc7XG5pbXBvcnQge2dldENvb3Jkc0RhdGFUeXBlfSBmcm9tICcuL3NoYWRlcl9jb21waWxlcic7XG5cbmV4cG9ydCBjb25zdCBDSEVDS19OQU5fU05JUFBFVF9QQUNLRUQgPSBgXG4gIHJlc3VsdC5yID0gaXNOYU4uciA/IE5BTiA6IHJlc3VsdC5yO1xuICByZXN1bHQuZyA9IGlzTmFOLmcgPyBOQU4gOiByZXN1bHQuZztcbiAgcmVzdWx0LmIgPSBpc05hTi5iID8gTkFOIDogcmVzdWx0LmI7XG4gIHJlc3VsdC5hID0gaXNOYU4uYSA/IE5BTiA6IHJlc3VsdC5hO1xuYDtcblxuZXhwb3J0IGNvbnN0IEVMVV9ERVIgPSBgXG4gIHZlYzQgYkdURVplcm8gPSB2ZWM0KGdyZWF0ZXJUaGFuRXF1YWwoYiwgdmVjNCgwLikpKTtcbiAgcmV0dXJuIChiR1RFWmVybyAqIGEpICsgKCh2ZWM0KDEuMCkgLSBiR1RFWmVybykgKiAoYSAqIChiICsgdmVjNCgxLjApKSkpO1xuYDtcblxuZXhwb3J0IGNvbnN0IE5PVF9FUVVBTCA9IGBcbiAgcmV0dXJuIHZlYzQobm90RXF1YWwoYSwgYikpO1xuYDtcblxuZXhwb3J0IGNsYXNzIEJpbmFyeU9wUGFja2VkUHJvZ3JhbSBpbXBsZW1lbnRzIEdQR1BVUHJvZ3JhbSB7XG4gIHZhcmlhYmxlTmFtZXMgPSBbJ0EnLCAnQiddO1xuICBvdXRwdXRTaGFwZTogbnVtYmVyW107XG4gIHVzZXJDb2RlOiBzdHJpbmc7XG4gIHN1cHBvcnRzQnJvYWRjYXN0aW5nID0gdHJ1ZTtcbiAgcGFja2VkSW5wdXRzID0gdHJ1ZTtcbiAgcGFja2VkT3V0cHV0ID0gdHJ1ZTtcbiAgZW5hYmxlU2hhcGVVbmlmb3JtczogYm9vbGVhbjtcblxuICBjb25zdHJ1Y3RvcihcbiAgICAgIG9wOiBzdHJpbmcsIGFTaGFwZTogbnVtYmVyW10sIGJTaGFwZTogbnVtYmVyW10sXG4gICAgICBjaGVja091dE9mQm91bmRzID0gZmFsc2UpIHtcbiAgICB0aGlzLm91dHB1dFNoYXBlID0gYmFja2VuZF91dGlsLmFzc2VydEFuZEdldEJyb2FkY2FzdFNoYXBlKGFTaGFwZSwgYlNoYXBlKTtcbiAgICBjb25zdCByYW5rID0gdGhpcy5vdXRwdXRTaGFwZS5sZW5ndGg7XG4gICAgdGhpcy5lbmFibGVTaGFwZVVuaWZvcm1zID0gdXNlU2hhcGVVbmlmb3JtcyhyYW5rKTtcbiAgICBsZXQgY2hlY2tPdXRPZkJvdW5kc1N0cmluZyA9ICcnO1xuICAgIGlmIChjaGVja091dE9mQm91bmRzKSB7XG4gICAgICBpZiAocmFuayA9PT0gMCB8fCB1dGlsLnNpemVGcm9tU2hhcGUodGhpcy5vdXRwdXRTaGFwZSkgPT09IDEpIHtcbiAgICAgICAgY2hlY2tPdXRPZkJvdW5kc1N0cmluZyA9IGBcbiAgICAgICAgICByZXN1bHQueSA9IDAuO1xuICAgICAgICAgIHJlc3VsdC56ID0gMC47XG4gICAgICAgICAgcmVzdWx0LncgPSAwLjtcbiAgICAgICAgYDtcbiAgICAgIH0gZWxzZSB7XG4gICAgICAgIGNvbnN0IGR0eXBlID0gZ2V0Q29vcmRzRGF0YVR5cGUocmFuayk7XG4gICAgICAgIGNoZWNrT3V0T2ZCb3VuZHNTdHJpbmcgPSBgXG4gICAgICAgICAgJHtk
dHlwZX0gY29vcmRzID0gZ2V0T3V0cHV0Q29vcmRzKCk7XG4gICAgICAgIGA7XG4gICAgICAgIGlmIChyYW5rID09PSAxKSB7XG4gICAgICAgICAgaWYgKHRoaXMuZW5hYmxlU2hhcGVVbmlmb3Jtcykge1xuICAgICAgICAgICAgY2hlY2tPdXRPZkJvdW5kc1N0cmluZyArPSBgXG4gICAgICAgICAgICByZXN1bHQueSA9IChjb29yZHMgKyAxKSA+PSBvdXRTaGFwZSA/IDAuIDogcmVzdWx0Lnk7XG4gICAgICAgICAgICByZXN1bHQueiA9IDAuO1xuICAgICAgICAgICAgcmVzdWx0LncgPSAwLjtcbiAgICAgICAgICBgO1xuICAgICAgICAgIH0gZWxzZSB7XG4gICAgICAgICAgICBjaGVja091dE9mQm91bmRzU3RyaW5nICs9IGBcbiAgICAgICAgICAgIHJlc3VsdC55ID0gKGNvb3JkcyArIDEpID49ICR7dGhpcy5vdXRwdXRTaGFwZVswXX0gPyAwLiA6IHJlc3VsdC55O1xuICAgICAgICAgICAgcmVzdWx0LnogPSAwLjtcbiAgICAgICAgICAgIHJlc3VsdC53ID0gMC47XG4gICAgICAgICAgYDtcbiAgICAgICAgICB9XG4gICAgICAgIH0gZWxzZSB7XG4gICAgICAgICAgY29uc3QgY2hhbm5lbHMgPSBnZXRDaGFubmVscygnY29vcmRzJywgcmFuayk7XG4gICAgICAgICAgaWYgKHRoaXMuZW5hYmxlU2hhcGVVbmlmb3Jtcykge1xuICAgICAgICAgICAgY2hlY2tPdXRPZkJvdW5kc1N0cmluZyArPSBgXG4gICAgICAgICAgICBib29sIG5leHRSb3dPdXRPZkJvdW5kcyA9XG4gICAgICAgICAgICAgICgke2NoYW5uZWxzW3JhbmsgLSAyXX0gKyAxKSA+PSBvdXRTaGFwZVske3Jhbmt9IC0gMl07XG4gICAgICAgICAgICBib29sIG5leHRDb2xPdXRPZkJvdW5kcyA9XG4gICAgICAgICAgICAgICgke2NoYW5uZWxzW3JhbmsgLSAxXX0gKyAxKSA+PSBvdXRTaGFwZVske3Jhbmt9IC0gMV07XG4gICAgICAgICAgICByZXN1bHQueSA9IG5leHRDb2xPdXRPZkJvdW5kcyA/IDAuIDogcmVzdWx0Lnk7XG4gICAgICAgICAgICByZXN1bHQueiA9IG5leHRSb3dPdXRPZkJvdW5kcyA/IDAuIDogcmVzdWx0Lno7XG4gICAgICAgICAgICByZXN1bHQudyA9IG5leHRDb2xPdXRPZkJvdW5kcyB8fCBuZXh0Um93T3V0T2ZCb3VuZHMgPyAwLiA6IHJlc3VsdC53O1xuICAgICAgICAgIGA7XG4gICAgICAgICAgfSBlbHNlIHtcbiAgICAgICAgICAgIGNoZWNrT3V0T2ZCb3VuZHNTdHJpbmcgKz0gYFxuICAgICAgICAgICAgYm9vbCBuZXh0Um93T3V0T2ZCb3VuZHMgPVxuICAgICAgICAgICAgICAoJHtjaGFubmVsc1tyYW5rIC0gMl19ICsgMSkgPj0gJHt0aGlzLm91dHB1dFNoYXBlW3JhbmsgLSAyXX07XG4gICAgICAgICAgICBib29sIG5leHRDb2xPdXRPZkJvdW5kcyA9XG4gICAgICAgICAgICAgICgke2NoYW5uZWxzW3JhbmsgLSAxXX0gKyAxKSA+PSAke3RoaXMub3V0cHV0U2hhcGVbcmFuayAtIDFdfTtcbiAgICAgICAgICAgIHJlc3VsdC55ID0gbmV4dENvbE91dE9mQm91bmRzID8gMC4gOiByZXN1bHQueTtcbiAgICAgICAgICAgIHJlc3VsdC56ID0gbmV4dFJvd091dE9mQm91
bmRzID8gMC4gOiByZXN1bHQuejtcbiAgICAgICAgICAgIHJlc3VsdC53ID0gbmV4dENvbE91dE9mQm91bmRzIHx8IG5leHRSb3dPdXRPZkJvdW5kcyA/IDAuIDogcmVzdWx0Lnc7XG4gICAgICAgICAgYDtcbiAgICAgICAgICB9XG4gICAgICAgIH1cbiAgICAgIH1cbiAgICB9XG5cbiAgICB0aGlzLnVzZXJDb2RlID0gYFxuICAgICAgdmVjNCBiaW5hcnlPcGVyYXRpb24odmVjNCBhLCB2ZWM0IGIpIHtcbiAgICAgICAgJHtvcH1cbiAgICAgIH1cblxuICAgICAgdm9pZCBtYWluKCkge1xuICAgICAgICB2ZWM0IGEgPSBnZXRBQXRPdXRDb29yZHMoKTtcbiAgICAgICAgdmVjNCBiID0gZ2V0QkF0T3V0Q29vcmRzKCk7XG5cbiAgICAgICAgdmVjNCByZXN1bHQgPSBiaW5hcnlPcGVyYXRpb24oYSwgYik7XG4gICAgICAgICR7Y2hlY2tPdXRPZkJvdW5kc1N0cmluZ31cblxuICAgICAgICBzZXRPdXRwdXQocmVzdWx0KTtcbiAgICAgIH1cbiAgICBgO1xuICB9XG59XG4iXX0=