/**
|
* @license
|
* Copyright 2018 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
import { backend_util, util } from '@tensorflow/tfjs-core';
|
import { useShapeUniforms } from './gpgpu_math';
|
import { getChannels } from './packing_util';
|
import { getCoordsDataType } from './shader_compiler';
|
/**
 * GLSL snippet that restores NaN in each lane of a packed `vec4 result`.
 *
 * For every channel whose `isNaN` component is true, that channel of
 * `result` is overwritten with `NAN`; the other channels are left unchanged.
 * The snippet declares nothing itself: `result`, `isNaN` and `NAN` must
 * already be in scope where it is spliced in (presumably `isNaN` is a
 * per-channel boolean vector computed by the caller — confirm at call site).
 */
export const CHECK_NAN_SNIPPET_PACKED = `
  result.r = isNaN.r ? NAN : result.r;
  result.g = isNaN.g ? NAN : result.g;
  result.b = isNaN.b ? NAN : result.b;
  result.a = isNaN.a ? NAN : result.a;
`;
|
/**
 * Packed GLSL body for a binary op over `vec4 a, vec4 b` computing the ELU
 * gradient lane-wise: `a` where `b >= 0`, otherwise `a * (b + 1)`.
 *
 * The selection is branch-free: `bGTEZero` is a 0.0/1.0 mask used to blend
 * the two cases.
 *
 * NOTE(review): presumably `a` is the upstream gradient and `b` the ELU
 * forward output y, so that `b + 1 == exp(x)` for negative x — confirm at
 * the call sites.
 */
export const ELU_DER = `
  vec4 bGTEZero = vec4(greaterThanEqual(b, vec4(0.)));
  return (bGTEZero * a) + ((vec4(1.0) - bGTEZero) * (a * (b + vec4(1.0))));
`;
|
/**
 * Packed GLSL body for a binary op over `vec4 a, vec4 b`: lane-wise
 * `a != b`, converted from a boolean vector to 1.0 (not equal) / 0.0 (equal).
 */
export const NOT_EQUAL = `
  return vec4(notEqual(a, b));
`;
|
/**
 * Builds the GLSL statements that zero the lanes of a packed `result` texel
 * which fall past the end of the logical output shape.
 *
 * Lane mapping (as encoded below): `y` is the next column, `z` the next row,
 * and `w` the next row and column. For rank-1 outputs only `y` can be valid,
 * so `z`/`w` are always zeroed.
 *
 * @param {number[]} outputShape Logical (unpacked) output shape.
 * @param {boolean} enableShapeUniforms When true, bounds are read from the
 *     `outShape` shader uniform instead of being baked in as literals.
 * @returns {string} GLSL statements that expect a `vec4 result` in scope.
 */
function getOutOfBoundsSnippet(outputShape, enableShapeUniforms) {
  const rank = outputShape.length;

  // A single-element output keeps only its `x` lane.
  if (rank === 0 || util.sizeFromShape(outputShape) === 1) {
    return `
      result.y = 0.;
      result.z = 0.;
      result.w = 0.;
    `;
  }

  const coordsDecl = `
    ${getCoordsDataType(rank)} coords = getOutputCoords();
  `;

  if (rank === 1) {
    const size = enableShapeUniforms ? 'outShape' : `${outputShape[0]}`;
    return `${coordsDecl}
      result.y = (coords + 1) >= ${size} ? 0. : result.y;
      result.z = 0.;
      result.w = 0.;
    `;
  }

  // Rows and columns are the two innermost dimensions of the output.
  const channels = getChannels('coords', rank);
  const numRows = enableShapeUniforms ? `outShape[${rank} - 2]` :
                                        `${outputShape[rank - 2]}`;
  const numCols = enableShapeUniforms ? `outShape[${rank} - 1]` :
                                        `${outputShape[rank - 1]}`;
  return `${coordsDecl}
      bool nextRowOutOfBounds =
        (${channels[rank - 2]} + 1) >= ${numRows};
      bool nextColOutOfBounds =
        (${channels[rank - 1]} + 1) >= ${numCols};
      result.y = nextColOutOfBounds ? 0. : result.y;
      result.z = nextRowOutOfBounds ? 0. : result.z;
      result.w = nextColOutOfBounds || nextRowOutOfBounds ? 0. : result.w;
    `;
}

/**
 * Packed WebGL program that applies a binary GLSL operation to two
 * broadcastable inputs `A` and `B`, one RGBA texel (four values) at a time.
 *
 * @param {string} op GLSL body of `vec4 binaryOperation(vec4 a, vec4 b)`;
 *     it must `return` a vec4.
 * @param {number[]} aShape Shape of input `A`.
 * @param {number[]} bShape Shape of input `B`; broadcast against `aShape`.
 * @param {boolean} [checkOutOfBounds=false] When true, lanes of each output
 *     texel lying past the end of the output shape are forced to 0.
 */
export class BinaryOpPackedProgram {
  constructor(op, aShape, bShape, checkOutOfBounds = false) {
    this.variableNames = ['A', 'B'];
    this.supportsBroadcasting = true;
    this.packedInputs = true;
    this.packedOutput = true;
    this.outputShape = backend_util.assertAndGetBroadcastShape(aShape, bShape);
    const rank = this.outputShape.length;
    this.enableShapeUniforms = useShapeUniforms(rank);

    const checkOutOfBoundsString = checkOutOfBounds ?
        getOutOfBoundsSnippet(this.outputShape, this.enableShapeUniforms) :
        '';

    this.userCode = `
      vec4 binaryOperation(vec4 a, vec4 b) {
        ${op}
      }

      void main() {
        vec4 a = getAAtOutCoords();
        vec4 b = getBAtOutCoords();

        vec4 result = binaryOperation(a, b);
        ${checkOutOfBoundsString}

        setOutput(result);
      }
    `;
  }
}
|
//# sourceMappingURL=data:application/json;base64,{"version":3,"file":"binaryop_packed_gpu.js","sourceRoot":"","sources":["../../../../../tfjs-backend-webgl/src/binaryop_packed_gpu.ts"],"names":[],"mappings":"AAAA;;;;;;;;;;;;;;;GAeG;AAEH,OAAO,EAAC,YAAY,EAAE,IAAI,EAAC,MAAM,uBAAuB,CAAC;AAEzD,OAAO,EAAe,gBAAgB,EAAC,MAAM,cAAc,CAAC;AAC5D,OAAO,EAAC,WAAW,EAAC,MAAM,gBAAgB,CAAC;AAC3C,OAAO,EAAC,iBAAiB,EAAC,MAAM,mBAAmB,CAAC;AAEpD,MAAM,CAAC,MAAM,wBAAwB,GAAG;;;;;CAKvC,CAAC;AAEF,MAAM,CAAC,MAAM,OAAO,GAAG;;;CAGtB,CAAC;AAEF,MAAM,CAAC,MAAM,SAAS,GAAG;;CAExB,CAAC;AAEF,MAAM,OAAO,qBAAqB;IAShC,YACI,EAAU,EAAE,MAAgB,EAAE,MAAgB,EAC9C,gBAAgB,GAAG,KAAK;QAV5B,kBAAa,GAAG,CAAC,GAAG,EAAE,GAAG,CAAC,CAAC;QAG3B,yBAAoB,GAAG,IAAI,CAAC;QAC5B,iBAAY,GAAG,IAAI,CAAC;QACpB,iBAAY,GAAG,IAAI,CAAC;QAMlB,IAAI,CAAC,WAAW,GAAG,YAAY,CAAC,0BAA0B,CAAC,MAAM,EAAE,MAAM,CAAC,CAAC;QAC3E,MAAM,IAAI,GAAG,IAAI,CAAC,WAAW,CAAC,MAAM,CAAC;QACrC,IAAI,CAAC,mBAAmB,GAAG,gBAAgB,CAAC,IAAI,CAAC,CAAC;QAClD,IAAI,sBAAsB,GAAG,EAAE,CAAC;QAChC,IAAI,gBAAgB,EAAE;YACpB,IAAI,IAAI,KAAK,CAAC,IAAI,IAAI,CAAC,aAAa,CAAC,IAAI,CAAC,WAAW,CAAC,KAAK,CAAC,EAAE;gBAC5D,sBAAsB,GAAG;;;;SAIxB,CAAC;aACH;iBAAM;gBACL,MAAM,KAAK,GAAG,iBAAiB,CAAC,IAAI,CAAC,CAAC;gBACtC,sBAAsB,GAAG;YACrB,KAAK;SACR,CAAC;gBACF,IAAI,IAAI,KAAK,CAAC,EAAE;oBACd,IAAI,IAAI,CAAC,mBAAmB,EAAE;wBAC5B,sBAAsB,IAAI;;;;WAI3B,CAAC;qBACD;yBAAM;wBACL,sBAAsB,IAAI;yCACG,IAAI,CAAC,WAAW,CAAC,CAAC,CAAC;;;WAGjD,CAAC;qBACD;iBACF;qBAAM;oBACL,MAAM,QAAQ,GAAG,WAAW,CAAC,QAAQ,EAAE,IAAI,CAAC,CAAC;oBAC7C,IAAI,IAAI,CAAC,mBAAmB,EAAE;wBAC5B,sBAAsB,IAAI;;iBAErB,QAAQ,CAAC,IAAI,GAAG,CAAC,CAAC,qBAAqB,IAAI;;iBAE3C,QAAQ,CAAC,IAAI,GAAG,CAAC,CAAC,qBAAqB,IAAI;;;;WAIjD,CAAC;qBACD;yBAAM;wBACL,sBAAsB,IAAI;;iBAErB,QAAQ,CAAC,IAAI,GAAG,CAAC,CAAC,YAAY,IAAI,CAAC,WAAW,CAAC,IAAI,GAAG,CAAC,CAAC;;iBAExD,QAAQ,CAAC,IAAI,GAAG,CAAC,CAAC,YAAY,IAAI,CAAC,WAAW,CAAC,IAAI,GAAG,CAAC,CAAC;;;;WAI9D,CAAC;qBACD;iBACF;aACF;SACF;QAED,IAAI,CAAC,QAAQ,GAAG;;UAEV,EAAE;;;;;;;;UAQF,sBAAsB;;;;KAI3B,CAAC;IACJ,CAAC;CACF","sourcesContent":["/**\n * @license\n * Copyright 
2018 Google LLC. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {backend_util, util} from '@tensorflow/tfjs-core';\n\nimport {GPGPUProgram, useShapeUniforms} from './gpgpu_math';\nimport {getChannels} from './packing_util';\nimport {getCoordsDataType} from './shader_compiler';\n\nexport const CHECK_NAN_SNIPPET_PACKED = `\n  result.r = isNaN.r ? NAN : result.r;\n  result.g = isNaN.g ? NAN : result.g;\n  result.b = isNaN.b ? NAN : result.b;\n  result.a = isNaN.a ? 
NAN : result.a;\n`;\n\nexport const ELU_DER = `\n  vec4 bGTEZero = vec4(greaterThanEqual(b, vec4(0.)));\n  return (bGTEZero * a) + ((vec4(1.0) - bGTEZero) * (a * (b + vec4(1.0))));\n`;\n\nexport const NOT_EQUAL = `\n  return vec4(notEqual(a, b));\n`;\n\nexport class BinaryOpPackedProgram implements GPGPUProgram {\n  variableNames = ['A', 'B'];\n  outputShape: number[];\n  userCode: string;\n  supportsBroadcasting = true;\n  packedInputs = true;\n  packedOutput = true;\n  enableShapeUniforms: boolean;\n\n  constructor(\n      op: string, aShape: number[], bShape: number[],\n      checkOutOfBounds = false) {\n    this.outputShape = backend_util.assertAndGetBroadcastShape(aShape, bShape);\n    const rank = this.outputShape.length;\n    this.enableShapeUniforms = useShapeUniforms(rank);\n    let checkOutOfBoundsString = '';\n    if (checkOutOfBounds) {\n      if (rank === 0 || util.sizeFromShape(this.outputShape) === 1) {\n        checkOutOfBoundsString = `\n          result.y = 0.;\n          result.z = 0.;\n          result.w = 0.;\n        `;\n      } else {\n        const dtype = getCoordsDataType(rank);\n        checkOutOfBoundsString = `\n          ${dtype} coords = getOutputCoords();\n        `;\n        if (rank === 1) {\n          if (this.enableShapeUniforms) {\n            checkOutOfBoundsString += `\n            result.y = (coords + 1) >= outShape ? 0. : result.y;\n            result.z = 0.;\n            result.w = 0.;\n          `;\n          } else {\n            checkOutOfBoundsString += `\n            result.y = (coords + 1) >= ${this.outputShape[0]} ? 0. 
: result.y;\n            result.z = 0.;\n            result.w = 0.;\n          `;\n          }\n        } else {\n          const channels = getChannels('coords', rank);\n          if (this.enableShapeUniforms) {\n            checkOutOfBoundsString += `\n            bool nextRowOutOfBounds =\n              (${channels[rank - 2]} + 1) >= outShape[${rank} - 2];\n            bool nextColOutOfBounds =\n              (${channels[rank - 1]} + 1) >= outShape[${rank} - 1];\n            result.y = nextColOutOfBounds ? 0. : result.y;\n            result.z = nextRowOutOfBounds ? 0. : result.z;\n            result.w = nextColOutOfBounds || nextRowOutOfBounds ? 0. : result.w;\n          `;\n          } else {\n            checkOutOfBoundsString += `\n            bool nextRowOutOfBounds =\n              (${channels[rank - 2]} + 1) >= ${this.outputShape[rank - 2]};\n            bool nextColOutOfBounds =\n              (${channels[rank - 1]} + 1) >= ${this.outputShape[rank - 1]};\n            result.y = nextColOutOfBounds ? 0. : result.y;\n            result.z = nextRowOutOfBounds ? 0. : result.z;\n            result.w = nextColOutOfBounds || nextRowOutOfBounds ? 0. : result.w;\n          `;\n          }\n        }\n      }\n    }\n\n    this.userCode = `\n      vec4 binaryOperation(vec4 a, vec4 b) {\n        ${op}\n      }\n\n      void main() {\n        vec4 a = getAAtOutCoords();\n        vec4 b = getBAtOutCoords();\n\n        vec4 result = binaryOperation(a, b);\n        ${checkOutOfBoundsString}\n\n        setOutput(result);\n      }\n    `;\n  }\n}\n"]}
|