/**
 * @license
 * Copyright 2018 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
// NOTE(review): this file is compiled output of
// tfjs-backend-webgl/src/binaryop_packed_gpu.ts (named in the inline source
// map at the end of this file). Mirror any change in the TypeScript source.
// The source map's "mappings" describe the originally compiled layout, so they
// no longer line up once this file is edited by hand; regenerate the build.
import { backend_util, util } from '@tensorflow/tfjs-core';
import { useShapeUniforms } from './gpgpu_math';
import { getChannels } from './packing_util';
import { getCoordsDataType } from './shader_compiler';

/**
 * GLSL statements that replace each channel of a packed `vec4 result` with
 * `NAN` wherever the matching component of the boolean vector `isNaN` is set.
 * The including shader must declare `result`, `isNaN` and `NAN`.
 */
export const CHECK_NAN_SNIPPET_PACKED = `
  result.r = isNaN.r ? NAN : result.r;
  result.g = isNaN.g ? NAN : result.g;
  result.b = isNaN.b ? NAN : result.b;
  result.a = isNaN.a ? NAN : result.a;
`;

/**
 * Body for `vec4 binaryOperation(vec4 a, vec4 b)`: per channel it returns `a`
 * where `b >= 0` and `a * (b + 1)` where `b < 0`.
 * Presumably `a` is the upstream gradient and `b` the ELU output — confirm at
 * the call site.
 */
export const ELU_DER = `
  vec4 bGTEZero = vec4(greaterThanEqual(b, vec4(0.)));
  return (bGTEZero * a) + ((vec4(1.0) - bGTEZero) * (a * (b + vec4(1.0))));
`;

/**
 * Body for `vec4 binaryOperation(vec4 a, vec4 b)`: 1.0 in every channel where
 * `a` and `b` differ, 0.0 elsewhere.
 */
export const NOT_EQUAL = `
  return vec4(notEqual(a, b));
`;

/**
 * Returns GLSL statements that zero the channels of the packed `result` texel
 * which fall outside `outputShape`.
 *
 * Channel roles, as implied by the checks below: `y` is the next column, `z`
 * the next row, and `w` the next row and next column. A rank-1 output never
 * uses `z`/`w`, and a single-element output only uses `x`.
 *
 * @param {number[]} outputShape Output shape of the program.
 * @param {boolean} enableShapeUniforms When true, bounds are read from the
 *     `outShape` uniform; otherwise the sizes are baked in as literals.
 * @returns {string} GLSL statements to splice after `result` is computed.
 */
function getOutOfBoundsSnippet(outputShape, enableShapeUniforms) {
  const rank = outputShape.length;

  // Only the x channel can hold a real value.
  if (rank === 0 || util.sizeFromShape(outputShape) === 1) {
    return `
      result.y = 0.;
      result.z = 0.;
      result.w = 0.;
    `;
  }

  const coordsSnippet = `
    ${getCoordsDataType(rank)} coords = getOutputCoords();
  `;

  if (rank === 1) {
    // For rank 1, `outShape` is compared unindexed against the scalar coords.
    const size = enableShapeUniforms ? 'outShape' : `${outputShape[0]}`;
    return `${coordsSnippet}
      result.y = (coords + 1) >= ${size} ? 0. : result.y;
      result.z = 0.;
      result.w = 0.;
    `;
  }

  // Only the two innermost dimensions are packed, so only they need checks.
  const channels = getChannels('coords', rank);
  const rowSize = enableShapeUniforms ? `outShape[${rank} - 2]` :
                                        `${outputShape[rank - 2]}`;
  const colSize = enableShapeUniforms ? `outShape[${rank} - 1]` :
                                        `${outputShape[rank - 1]}`;
  return `${coordsSnippet}
    bool nextRowOutOfBounds =
      (${channels[rank - 2]} + 1) >= ${rowSize};
    bool nextColOutOfBounds =
      (${channels[rank - 1]} + 1) >= ${colSize};
    result.y = nextColOutOfBounds ? 0. : result.y;
    result.z = nextRowOutOfBounds ? 0. : result.z;
    result.w = nextColOutOfBounds || nextRowOutOfBounds ? 0. : result.w;
  `;
}

/**
 * WebGL program applying an element-wise binary operation to two packed
 * inputs `A` and `B`, broadcast to a common output shape, with packed output.
 */
export class BinaryOpPackedProgram {
  /**
   * @param {string} op GLSL body of `vec4 binaryOperation(vec4 a, vec4 b)`;
   *     it must `return` a vec4 (see ELU_DER / NOT_EQUAL).
   * @param {number[]} aShape Shape of input A.
   * @param {number[]} bShape Shape of input B.
   * @param {boolean} [checkOutOfBounds=false] When true, texel channels that
   *     lie beyond the output shape are forced to 0.
   *
   * The output shape comes from backend_util.assertAndGetBroadcastShape, which
   * presumably throws when the shapes cannot be broadcast together.
   */
  constructor(op, aShape, bShape, checkOutOfBounds = false) {
    this.variableNames = ['A', 'B'];
    this.supportsBroadcasting = true;
    this.packedInputs = true;
    this.packedOutput = true;
    this.outputShape = backend_util.assertAndGetBroadcastShape(aShape, bShape);
    const rank = this.outputShape.length;
    this.enableShapeUniforms = useShapeUniforms(rank);

    const checkOutOfBoundsString = checkOutOfBounds ?
        getOutOfBoundsSnippet(this.outputShape, this.enableShapeUniforms) :
        '';

    this.userCode = `
      vec4 binaryOperation(vec4 a, vec4 b) {
        ${op}
      }

      void main() {
        vec4 a = getAAtOutCoords();
        vec4 b = getBAtOutCoords();

        vec4 result = binaryOperation(a, b);
        ${checkOutOfBoundsString}

        setOutput(result);
      }
    `;
  }
}
//# sourceMappingURL=data:application/json;base64,{"version":3,"file":"binaryop_packed_gpu.js","sourceRoot":"","sources":["../../../../../tfjs-backend-webgl/src/binaryop_packed_gpu.ts"],"names":[],"mappings":"AAAA;;;;;;;;;;;;;;;GAeG;AAEH,OAAO,EAAC,YAAY,EAAE,IAAI,EAAC,MAAM,uBAAuB,CAAC;AAEzD,OAAO,EAAe,gBAAgB,EAAC,MAAM,cAAc,CAAC;AAC5D,OAAO,EAAC,WAAW,EAAC,MAAM,gBAAgB,CAAC;AAC3C,OAAO,EAAC,iBAAiB,EAAC,MAAM,mBAAmB,CAAC;AAEpD,MAAM,CAAC,MAAM,wBAAwB,GAAG;;;;;CAKvC,CAAC;AAEF,MAAM,CAAC,MAAM,OAAO,GAAG;;;CAGtB,CAAC;AAEF,MAAM,CAAC,MAAM,SAAS,GAAG;;CAExB,CAAC;AAEF,MAAM,OAAO,qBAAqB;IAShC,YACI,EAAU,EAAE,MAAgB,EAAE,MAAgB,EAC9C,gBAAgB,GAAG,KAAK;QAV5B,kBAAa,GAAG,CAAC,GAAG,EAAE,GAAG,CAAC,CAAC;QAG3B,yBAAoB,GAAG,IAAI,CAAC;QAC5B,iBAAY,GAAG,IAAI,CAAC;QACpB,iBAAY,GAAG,IAAI,CAAC;QAMlB,IAAI,CAAC,WAAW,GAAG,YAAY,CAAC,0BAA0B,CAAC,MAAM,EAAE,MAAM,CAAC,CAAC;QAC3E,MAAM,IAAI,GAAG,IAAI,CAAC,WAAW,CAAC,MAAM,CAAC;QACrC,IAAI,CAAC,mBAAmB,GAAG,gBAAgB,CAAC,IAAI,CAAC,CAAC;QAClD,IAAI,sBAAsB,GAAG,EAAE,CAAC;QAChC,IAAI,gBAAgB,EAAE;YACpB,IAAI,IAAI,KAAK,CAAC,IAAI,IAAI,CAAC,aAAa,CAAC,IAAI,CAAC,WAAW,CAAC,KAAK,CAAC,EAAE;gBAC5D,sBAAsB,GAAG;;;;SAIxB,CAAC;aACH;iBAAM;gBACL,MAAM,KAAK,GAAG,iBAAiB,CAAC,IAAI,CAAC,CAAC;gBACtC,sBAAsB,GAAG;YACrB,KAAK;SACR,CAAC;gBACF,IAAI,IAAI,KAAK,CAAC,EAAE;oBACd,IAAI,IAAI,CAAC,mBAAmB,EAAE;wBAC5B,sBAAsB,IAAI;;;;WAI3B,CAAC;qBACD;yBAAM;wBACL,sBAAsB,IAAI;yCACG,IAAI,CAAC,WAAW,CAAC,CAAC,CAAC;;;WAGjD,CAAC;qBACD;iBACF;qBAAM;oBACL,MAAM,QAAQ,GAAG,WAAW,CAAC,QAAQ,EAAE,IAAI,CAAC,CAAC;oBAC7C,IAAI,IAAI,CAAC,mBAAmB,EAAE;wBAC5B,sBAAsB,IAAI;;iBAErB,QAAQ,CAAC,IAAI,GAAG,CAAC,CAAC,qBAAqB,IAAI;;iBAE3C,QAAQ,CAAC,IAAI,GAAG,CAAC,CAAC,qBAAqB,IAAI;;;;WAIjD,CAAC;qBACD;yBAAM;wBACL,sBAAsB,IAAI;;iBAErB,QAAQ,CAAC,IAA
I,GAAG,CAAC,CAAC,YAAY,IAAI,CAAC,WAAW,CAAC,IAAI,GAAG,CAAC,CAAC;;iBAExD,QAAQ,CAAC,IAAI,GAAG,CAAC,CAAC,YAAY,IAAI,CAAC,WAAW,CAAC,IAAI,GAAG,CAAC,CAAC;;;;WAI9D,CAAC;qBACD;iBACF;aACF;SACF;QAED,IAAI,CAAC,QAAQ,GAAG;;UAEV,EAAE;;;;;;;;UAQF,sBAAsB;;;;KAI3B,CAAC;IACJ,CAAC;CACF","sourcesContent":["/**\n * @license\n * Copyright 2018 Google LLC. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {backend_util, util} from '@tensorflow/tfjs-core';\n\nimport {GPGPUProgram, useShapeUniforms} from './gpgpu_math';\nimport {getChannels} from './packing_util';\nimport {getCoordsDataType} from './shader_compiler';\n\nexport const CHECK_NAN_SNIPPET_PACKED = `\n  result.r = isNaN.r ? NAN : result.r;\n  result.g = isNaN.g ? NAN : result.g;\n  result.b = isNaN.b ? NAN : result.b;\n  result.a = isNaN.a ? 
NAN : result.a;\n`;\n\nexport const ELU_DER = `\n  vec4 bGTEZero = vec4(greaterThanEqual(b, vec4(0.)));\n  return (bGTEZero * a) + ((vec4(1.0) - bGTEZero) * (a * (b + vec4(1.0))));\n`;\n\nexport const NOT_EQUAL = `\n  return vec4(notEqual(a, b));\n`;\n\nexport class BinaryOpPackedProgram implements GPGPUProgram {\n  variableNames = ['A', 'B'];\n  outputShape: number[];\n  userCode: string;\n  supportsBroadcasting = true;\n  packedInputs = true;\n  packedOutput = true;\n  enableShapeUniforms: boolean;\n\n  constructor(\n      op: string, aShape: number[], bShape: number[],\n      checkOutOfBounds = false) {\n    this.outputShape = backend_util.assertAndGetBroadcastShape(aShape, bShape);\n    const rank = this.outputShape.length;\n    this.enableShapeUniforms = useShapeUniforms(rank);\n    let checkOutOfBoundsString = '';\n    if (checkOutOfBounds) {\n      if (rank === 0 || util.sizeFromShape(this.outputShape) === 1) {\n        checkOutOfBoundsString = `\n          result.y = 0.;\n          result.z = 0.;\n          result.w = 0.;\n        `;\n      } else {\n        const dtype = getCoordsDataType(rank);\n        checkOutOfBoundsString = `\n          ${dtype} coords = getOutputCoords();\n        `;\n        if (rank === 1) {\n          if (this.enableShapeUniforms) {\n            checkOutOfBoundsString += `\n            result.y = (coords + 1) >= outShape ? 0. : result.y;\n            result.z = 0.;\n            result.w = 0.;\n          `;\n          } else {\n            checkOutOfBoundsString += `\n            result.y = (coords + 1) >= ${this.outputShape[0]} ? 0. 
: result.y;\n            result.z = 0.;\n            result.w = 0.;\n          `;\n          }\n        } else {\n          const channels = getChannels('coords', rank);\n          if (this.enableShapeUniforms) {\n            checkOutOfBoundsString += `\n            bool nextRowOutOfBounds =\n              (${channels[rank - 2]} + 1) >= outShape[${rank} - 2];\n            bool nextColOutOfBounds =\n              (${channels[rank - 1]} + 1) >= outShape[${rank} - 1];\n            result.y = nextColOutOfBounds ? 0. : result.y;\n            result.z = nextRowOutOfBounds ? 0. : result.z;\n            result.w = nextColOutOfBounds || nextRowOutOfBounds ? 0. : result.w;\n          `;\n          } else {\n            checkOutOfBoundsString += `\n            bool nextRowOutOfBounds =\n              (${channels[rank - 2]} + 1) >= ${this.outputShape[rank - 2]};\n            bool nextColOutOfBounds =\n              (${channels[rank - 1]} + 1) >= ${this.outputShape[rank - 1]};\n            result.y = nextColOutOfBounds ? 0. : result.y;\n            result.z = nextRowOutOfBounds ? 0. : result.z;\n            result.w = nextColOutOfBounds || nextRowOutOfBounds ? 0. : result.w;\n          `;\n          }\n        }\n      }\n    }\n\n    this.userCode = `\n      vec4 binaryOperation(vec4 a, vec4 b) {\n        ${op}\n      }\n\n      void main() {\n        vec4 a = getAAtOutCoords();\n        vec4 b = getBAtOutCoords();\n\n        vec4 result = binaryOperation(a, b);\n        ${checkOutOfBoundsString}\n\n        setOutput(result);\n      }\n    `;\n  }\n}\n"]}