/** * @license * Copyright 2023 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ 'use strict'; var tf = require('@tensorflow/tfjs-core'); function _interopNamespaceDefault(e) { var n = Object.create(null); if (e) { Object.keys(e).forEach(function (k) { if (k !== 'default') { var d = Object.getOwnPropertyDescriptor(e, k); Object.defineProperty(n, k, d.get ? d : { enumerable: true, get: function () { return e[k]; } }); } }); } n.default = e; return n; } var tf__namespace = /*#__PURE__*/_interopNamespaceDefault(tf); /****************************************************************************** Copyright (c) Microsoft Corporation. Permission to use, copy, modify, and/or distribute this software for any purpose with or without fee is hereby granted. THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 
***************************************************************************** */
/* global Reflect, Promise */

// Sets the [[Prototype]] of `d` to `b`. On first call it replaces itself with
// Object.setPrototypeOf, a `__proto__` assignment, or (as a last resort) a
// copy of b's own properties onto d.
var extendStatics = function (d, b) {
    extendStatics = Object.setPrototypeOf ||
        ({ __proto__: [] } instanceof Array && function (d, b) { d.__proto__ = b; }) ||
        function (d, b) { for (var p in b) if (Object.prototype.hasOwnProperty.call(b, p)) d[p] = b[p]; };
    return extendStatics(d, b);
};

// Down-levelled `class d extends b`. Throws TypeError if `b` is neither a
// function nor null. Links statics via extendStatics and sets d.prototype to
// an object whose prototype is b.prototype (or Object.create(null) when b is
// null).
function __extends(d, b) {
    if (typeof b !== "function" && b !== null)
        throw new TypeError("Class extends value " + String(b) + " is not a constructor or null");
    extendStatics(d, b);
    function __() { this.constructor = d; }
    d.prototype = b === null ? Object.create(b) : (__.prototype = b.prototype, new __());
}

// Down-levelled async function. Runs `generator` (applied to thisArg /
// _arguments) and resolves each yielded value through promise type P (the
// global Promise when P is not given). The returned promise settles with the
// generator's return value or first uncaught error.
function __awaiter(thisArg, _arguments, P, generator) {
    function adopt(value) { return value instanceof P ? value : new P(function (resolve) { resolve(value); }); }
    return new (P || (P = Promise))(function (resolve, reject) {
        function fulfilled(value) { try { step(generator.next(value)); } catch (e) { reject(e); } }
        function rejected(value) { try { step(generator["throw"](value)); } catch (e) { reject(e); } }
        function step(result) { result.done ? resolve(result.value) : adopt(result.value).then(fulfilled, rejected); }
        step((generator = generator.apply(thisArg, _arguments || [])).next());
    });
}

// Down-levelled generator. `body` is a state machine that receives the state
// object `_` (current `label`, `sent()` accessor, try-block stack `trys`, and
// pending `ops`) and returns an op tuple [code, value]. Op codes visible in
// the switch below: 0/1/2 = next/throw/return (see verb()), 3 = jump to
// label, 4 = yield value, 5 = delegate to inner iterator (yield*),
// 6 = exception (produced by the catch clause), 7 = end of finally block.
function __generator(thisArg, body) {
    var _ = { label: 0, sent: function () { if (t[0] & 1) throw t[1]; return t[1]; }, trys: [], ops: [] }, f, y, t, g;
    return g = { next: verb(0), "throw": verb(1), "return": verb(2) }, typeof Symbol === "function" && (g[Symbol.iterator] = function () { return this; }), g;
    function verb(n) { return function (v) { return step([n, v]); }; }
    function step(op) {
        // `f` is set while the body runs; it guards against re-entrant calls.
        if (f) throw new TypeError("Generator is already executing.");
        while (_) try {
            // Forward next/throw/return to a delegated iterator `y`, if any.
            if (f = 1, y && (t = op[0] & 2 ? y["return"] : op[0] ? y["throw"] || ((t = y["return"]) && t.call(y), 0) : y.next) && !(t = t.call(y, op[1])).done) return t;
            if (y = 0, t) op = [op[0] & 2, t.value];
            switch (op[0]) {
                case 0: case 1: t = op; break;
                case 4: _.label++; return { value: op[1], done: false };
                case 5: _.label++; y = op[1]; op = [0]; continue;
                case 7: op = _.ops.pop(); _.trys.pop(); continue;
                default:
                    // Route throw/return/break through the innermost try entry
                    // (t = [tryLoc, catchLoc, finallyLoc, endLoc]).
                    if (!(t = _.trys, t = t.length > 0 && t[t.length - 1]) && (op[0] === 6 || op[0] === 2)) { _ = 0; continue; }
                    if (op[0] === 3 && (!t || (op[1] > t[0] && op[1] < t[3]))) { _.label = op[1]; break; }
                    if (op[0] === 6 && _.label < t[1]) { _.label = t[1]; t = op; break; }
                    if (t && _.label < t[2]) { _.label = t[2]; _.ops.push(op); break; }
                    if (t[2]) _.ops.pop();
                    _.trys.pop(); continue;
            }
            op = body.call(thisArg, _);
        } catch (e) { op = [6, e]; y = 0; } finally { f = t = 0; }
        // Codes 1 (throw) and 6 (exception) escaping the loop are rethrown.
        if (op[0] & 5) throw op[1]; return { value: op[0] ? op[1] : void 0, done: true };
    }
}

// Returns an iterator over `o`: its Symbol.iterator if present, otherwise an
// index-based iterator for array-likes. Throws TypeError when neither exists.
function __values(o) {
    var s = typeof Symbol === "function" && Symbol.iterator, m = s && o[s], i = 0;
    if (m) return m.call(o);
    if (o && typeof o.length === "number") return {
        next: function () {
            if (o && i >= o.length) o = void 0;
            return { value: o && o[i++], done: !o };
        }
    };
    throw new TypeError(s ? "Object is not iterable." : "Symbol.iterator is not defined.");
}

// Reads up to `n` values (all, when n is undefined) from iterable `o` into an
// array. Returns `o` unchanged when it has no Symbol.iterator. Closes the
// iterator via its `return` method if iteration stops early.
function __read(o, n) {
    var m = typeof Symbol === "function" && o[Symbol.iterator];
    if (!m) return o;
    var i = m.call(o), r, ar = [], e;
    try {
        while ((n === void 0 || n-- > 0) && !(r = i.next()).done) ar.push(r.value);
    }
    catch (error) { e = { error: error }; }
    finally {
        try {
            if (r && !r.done && (m = i["return"])) m.call(i);
        }
        finally { if (e) throw e.error; }
    }
    return ar;
}

// Returns `to` concatenated with `from`. When `pack` is truthy (or only two
// arguments are passed), holes in `from` are materialised as explicit
// entries in a copy before concatenating.
function __spreadArray(to, from, pack) {
    if (pack || arguments.length === 2) for (var i = 0, l = from.length, ar; i < l; i++) {
        if (ar || !(i in from)) {
            if (!ar) ar = Array.prototype.slice.call(from, 0, i);
            ar[i] = from[i];
        }
    }
    return to.concat(ar || Array.prototype.slice.call(from));
}

/**
 * @license
 * Copyright 2018 Google LLC.
All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ var contexts = {}; var WEBGL_ATTRIBUTES = { alpha: false, antialias: false, premultipliedAlpha: false, preserveDrawingBuffer: false, depth: false, stencil: false, failIfMajorPerformanceCaveat: true }; function setWebGLContext(webGLVersion, gl) { contexts[webGLVersion] = gl; } function getWebGLContext(webGLVersion, customCanvas) { if (!(webGLVersion in contexts) || customCanvas != null) { var newCtx = getWebGLRenderingContext(webGLVersion, customCanvas); if (newCtx !== null) { contexts[webGLVersion] = newCtx; } else { console.log('Could not get context for WebGL version', webGLVersion); return null; } } var gl = contexts[webGLVersion]; if (gl == null || gl.isContextLost()) { delete contexts[webGLVersion]; return getWebGLContext(webGLVersion); } gl.disable(gl.DEPTH_TEST); gl.disable(gl.STENCIL_TEST); gl.disable(gl.BLEND); gl.disable(gl.DITHER); gl.disable(gl.POLYGON_OFFSET_FILL); gl.disable(gl.SAMPLE_COVERAGE); gl.enable(gl.SCISSOR_TEST); gl.enable(gl.CULL_FACE); gl.cullFace(gl.BACK); return contexts[webGLVersion]; } function createCanvas(webGLVersion) { // Use canvas element for Safari, since its offscreen canvas does not support // fencing. 
if (!tf.env().getBool('IS_SAFARI') && typeof OffscreenCanvas !== 'undefined' && webGLVersion === 2) { return new OffscreenCanvas(300, 150); } else if (typeof document !== 'undefined') { return document.createElement('canvas'); } else { throw new Error('Cannot create a canvas in this context'); } } function getWebGLRenderingContext(webGLVersion, customCanvas) { if (webGLVersion !== 1 && webGLVersion !== 2) { throw new Error('Cannot get WebGL rendering context, WebGL is disabled.'); } var canvas = customCanvas == null ? createCanvas(webGLVersion) : customCanvas; canvas.addEventListener('webglcontextlost', function (ev) { ev.preventDefault(); delete contexts[webGLVersion]; }, false); if (tf.env().getBool('SOFTWARE_WEBGL_ENABLED')) { WEBGL_ATTRIBUTES.failIfMajorPerformanceCaveat = false; } if (webGLVersion === 1) { return ( // tslint:disable-next-line canvas.getContext('webgl', WEBGL_ATTRIBUTES) || canvas .getContext('experimental-webgl', WEBGL_ATTRIBUTES)); } return canvas.getContext('webgl2', WEBGL_ATTRIBUTES); } var PackingScheme; (function (PackingScheme) { /** * All values in a single texel are densely packed without any constraints. * * This is how the shader encodes a tensor with shape = [2, 3, 4] * (indices are [batch, row, col]). * * 000|001 010|011 020|021 * ------- ------- ------- * 002|003 012|013 022|023 * * 100|101 110|111 120|121 * ------- ------- ------- * 102|103 112|113 122|123 * */ PackingScheme[PackingScheme["DENSE"] = 0] = "DENSE"; /** * Single texels contain only values from the same batch, and from adjacent * rows and columns. * * This is how the shader encodes a tensor with shape = [2, 3, 5] * (indices are [batch, row, col]). 
* * 000|001 002|003 004|xxx 020|021 022|023 024|xxx * ------- ------- ------- ------- ------- ------- * 010|011 012|013 014|xxx xxx|xxx xxx|xxx xxx|xxx * * 100|101 102|103 104|xxx 120|121 122|123 124|xxx * ------- ------- ------- ------- ------- ------- * 110|111 112|113 114|xxx xxx|xxx xxx|xxx xxx|xxx * */ PackingScheme[PackingScheme["SHARED_BATCH"] = 1] = "SHARED_BATCH"; })(PackingScheme || (PackingScheme = {})); var TextureUsage; (function (TextureUsage) { TextureUsage[TextureUsage["RENDER"] = 0] = "RENDER"; TextureUsage[TextureUsage["UPLOAD"] = 1] = "UPLOAD"; TextureUsage[TextureUsage["PIXELS"] = 2] = "PIXELS"; TextureUsage[TextureUsage["DOWNLOAD"] = 3] = "DOWNLOAD"; })(TextureUsage || (TextureUsage = {})); var PhysicalTextureType; (function (PhysicalTextureType) { PhysicalTextureType[PhysicalTextureType["UNPACKED_FLOAT16"] = 0] = "UNPACKED_FLOAT16"; PhysicalTextureType[PhysicalTextureType["UNPACKED_FLOAT32"] = 1] = "UNPACKED_FLOAT32"; PhysicalTextureType[PhysicalTextureType["PACKED_4X1_UNSIGNED_BYTE"] = 2] = "PACKED_4X1_UNSIGNED_BYTE"; PhysicalTextureType[PhysicalTextureType["PACKED_2X2_FLOAT32"] = 3] = "PACKED_2X2_FLOAT32"; PhysicalTextureType[PhysicalTextureType["PACKED_2X2_FLOAT16"] = 4] = "PACKED_2X2_FLOAT16"; })(PhysicalTextureType || (PhysicalTextureType = {})); function getUnpackedMatrixTextureShapeWidthHeight(rows, columns) { return [columns, rows]; } function getUnpackedArraySizeFromMatrixSize(matrixSize, channelsPerTexture) { return matrixSize * channelsPerTexture; } /** * Get shape for densely packed RGBA texture. 
*/ function getDenseTexShape(shape) { var size = tf.util.sizeFromShape(shape); var texelsNeeded = Math.ceil(size / 4); return tf.util.sizeToSquarishShape(texelsNeeded); } function getPackedMatrixTextureShapeWidthHeight(rows, columns) { return [ Math.max(1, Math.ceil(columns / 2)), Math.max(1, Math.ceil(rows / 2)) ]; } function getPackedRGBAArraySizeFromMatrixShape(rows, columns) { var _a = __read(getPackedMatrixTextureShapeWidthHeight(rows, columns), 2), w = _a[0], h = _a[1]; return w * h * 4; } function getTextureConfig( // tslint:disable-next-line:no-any gl, textureHalfFloatExtension) { // tslint:disable-next-line:no-any var glany = gl; var internalFormatFloat; var internalFormatHalfFloat; var internalFormatPackedHalfFloat; var internalFormatPackedFloat; var textureFormatFloat; var downloadTextureFormat; var downloadUnpackNumChannels; var defaultNumChannels; var textureTypeHalfFloat; var textureTypeFloat; if (tf.env().getNumber('WEBGL_VERSION') === 2) { internalFormatFloat = glany.R32F; internalFormatHalfFloat = glany.R16F; internalFormatPackedHalfFloat = glany.RGBA16F; internalFormatPackedFloat = glany.RGBA32F; textureFormatFloat = glany.RED; downloadUnpackNumChannels = 4; defaultNumChannels = 1; textureTypeHalfFloat = glany.HALF_FLOAT; textureTypeFloat = glany.FLOAT; downloadTextureFormat = glany.RGBA8; } else { internalFormatFloat = gl.RGBA; internalFormatHalfFloat = gl.RGBA; internalFormatPackedHalfFloat = gl.RGBA; internalFormatPackedFloat = glany.RGBA; textureFormatFloat = gl.RGBA; downloadUnpackNumChannels = 4; defaultNumChannels = 4; textureTypeHalfFloat = textureHalfFloatExtension != null ? 
textureHalfFloatExtension.HALF_FLOAT_OES : null; textureTypeFloat = gl.FLOAT; downloadTextureFormat = gl.RGBA; } return { internalFormatFloat: internalFormatFloat, internalFormatHalfFloat: internalFormatHalfFloat, internalFormatPackedHalfFloat: internalFormatPackedHalfFloat, internalFormatPackedFloat: internalFormatPackedFloat, textureFormatFloat: textureFormatFloat, downloadTextureFormat: downloadTextureFormat, downloadUnpackNumChannels: downloadUnpackNumChannels, defaultNumChannels: defaultNumChannels, textureTypeHalfFloat: textureTypeHalfFloat, textureTypeFloat: textureTypeFloat }; } function callAndCheck(gl, func) { var returnValue = func(); if (tf.env().getBool('DEBUG')) { checkWebGLError(gl); } return returnValue; } function checkWebGLError(gl) { var error = gl.getError(); if (error !== gl.NO_ERROR) { throw new Error('WebGL Error: ' + getWebGLErrorMessage(gl, error)); } } // https://en.wikipedia.org/wiki/Half-precision_floating-point_format var MIN_FLOAT16 = 5.96e-8; var MAX_FLOAT16 = 65504; function canBeRepresented(num) { if (tf.env().getBool('WEBGL_RENDER_FLOAT32_ENABLED') || num === 0 || (MIN_FLOAT16 < Math.abs(num) && Math.abs(num) < MAX_FLOAT16)) { return true; } return false; } function getWebGLErrorMessage(gl, status) { switch (status) { case gl.NO_ERROR: return 'NO_ERROR'; case gl.INVALID_ENUM: return 'INVALID_ENUM'; case gl.INVALID_VALUE: return 'INVALID_VALUE'; case gl.INVALID_OPERATION: return 'INVALID_OPERATION'; case gl.INVALID_FRAMEBUFFER_OPERATION: return 'INVALID_FRAMEBUFFER_OPERATION'; case gl.OUT_OF_MEMORY: return 'OUT_OF_MEMORY'; case gl.CONTEXT_LOST_WEBGL: return 'CONTEXT_LOST_WEBGL'; default: return "Unknown error code ".concat(status); } } function getExtensionOrThrow(gl, extensionName) { return throwIfNull(gl, function () { return gl.getExtension(extensionName); }, 'Extension "' + extensionName + '" not supported on this browser.'); } function createVertexShader$1(gl, vertexShaderSource) { var vertexShader = throwIfNull(gl, function () 
{ return gl.createShader(gl.VERTEX_SHADER); }, 'Unable to create vertex WebGLShader.'); callAndCheck(gl, function () { return gl.shaderSource(vertexShader, vertexShaderSource); }); callAndCheck(gl, function () { return gl.compileShader(vertexShader); }); if (gl.getShaderParameter(vertexShader, gl.COMPILE_STATUS) === false) { console.log(gl.getShaderInfoLog(vertexShader)); throw new Error('Failed to compile vertex shader.'); } return vertexShader; } function createFragmentShader(gl, fragmentShaderSource) { var fragmentShader = throwIfNull(gl, function () { return gl.createShader(gl.FRAGMENT_SHADER); }, 'Unable to create fragment WebGLShader.'); callAndCheck(gl, function () { return gl.shaderSource(fragmentShader, fragmentShaderSource); }); callAndCheck(gl, function () { return gl.compileShader(fragmentShader); }); if (tf.env().get('ENGINE_COMPILE_ONLY')) { return fragmentShader; } if (gl.getShaderParameter(fragmentShader, gl.COMPILE_STATUS) === false) { logShaderSourceAndInfoLog(fragmentShaderSource, gl.getShaderInfoLog(fragmentShader)); throw new Error('Failed to compile fragment shader.'); } return fragmentShader; } var lineNumberRegex = /ERROR: [0-9]+:([0-9]+):/g; function logShaderSourceAndInfoLog(shaderSource, shaderInfoLog) { var lineNumberRegexResult = lineNumberRegex.exec(shaderInfoLog); if (lineNumberRegexResult == null) { console.log("Couldn't parse line number in error: ".concat(shaderInfoLog)); console.log(shaderSource); return; } var lineNumber = +lineNumberRegexResult[1]; var shaderLines = shaderSource.split('\n'); var pad = shaderLines.length.toString().length + 2; var linesWithLineNumbers = shaderLines.map(function (line, lineNumber) { return tf.util.rightPad((lineNumber + 1).toString(), pad) + line; }); var maxLineLength = 0; for (var i = 0; i < linesWithLineNumbers.length; i++) { maxLineLength = Math.max(linesWithLineNumbers[i].length, maxLineLength); } var beforeErrorLines = linesWithLineNumbers.slice(0, lineNumber - 1); var errorLine = 
linesWithLineNumbers.slice(lineNumber - 1, lineNumber); var afterErrorLines = linesWithLineNumbers.slice(lineNumber); console.log(beforeErrorLines.join('\n')); console.log(shaderInfoLog.split('\n')[0]); console.log("%c ".concat(tf.util.rightPad(errorLine[0], maxLineLength)), 'border:1px solid red; background-color:#e3d2d2; color:#a61717'); console.log(afterErrorLines.join('\n')); } function createProgram(gl) { return throwIfNull(gl, function () { return gl.createProgram(); }, 'Unable to create WebGLProgram.'); } function linkProgram(gl, program) { callAndCheck(gl, function () { return gl.linkProgram(program); }); if (tf.env().get('ENGINE_COMPILE_ONLY')) { return; } if (gl.getProgramParameter(program, gl.LINK_STATUS) === false) { console.log(gl.getProgramInfoLog(program)); throw new Error('Failed to link vertex and fragment shaders.'); } } /// validateProgram is effectively "If we `useProgram(program); drawArrays();`, /// give feedback in log about perf/correctness warnings or errors that would /// occur." /// So make sure we set up all vertex/texture/sampler/uniform data before /// calling validateProgram! 
/**
 * Validates `program`.
 *
 * @throws Error if VALIDATE_STATUS is false; the program info log is printed.
 */
function validateProgram(gl, program) {
    callAndCheck(gl, function () { return gl.validateProgram(program); });
    if (gl.getProgramParameter(program, gl.VALIDATE_STATUS) === false) {
        console.log(gl.getProgramInfoLog(program));
        throw new Error('Shader program validation failed.');
    }
}
/**
 * Creates an ARRAY_BUFFER, binds it, and uploads `data` with STATIC_DRAW.
 * The buffer is left bound.
 */
function createStaticVertexBuffer(gl, data) {
    var buffer = throwIfNull(gl, function () { return gl.createBuffer(); }, 'Unable to create WebGLBuffer');
    callAndCheck(gl, function () { return gl.bindBuffer(gl.ARRAY_BUFFER, buffer); });
    callAndCheck(gl, function () { return gl.bufferData(gl.ARRAY_BUFFER, data, gl.STATIC_DRAW); });
    return buffer;
}
/**
 * Creates an ELEMENT_ARRAY_BUFFER, binds it, and uploads `data` with
 * STATIC_DRAW. The buffer is left bound.
 */
function createStaticIndexBuffer(gl, data) {
    var buffer = throwIfNull(gl, function () { return gl.createBuffer(); }, 'Unable to create WebGLBuffer');
    callAndCheck(gl, function () { return gl.bindBuffer(gl.ELEMENT_ARRAY_BUFFER, buffer); });
    callAndCheck(gl, function () { return gl.bufferData(gl.ELEMENT_ARRAY_BUFFER, data, gl.STATIC_DRAW); });
    return buffer;
}
/**
 * Returns 1 when WEBGL_VERSION is 2, otherwise 4.
 */
function getNumChannels() {
    if (tf.env().getNumber('WEBGL_VERSION') === 2) {
        return 1;
    }
    return 4;
}
/**
 * Returns gl.createTexture(), throwing if it is null.
 */
function createTexture(gl) {
    return throwIfNull(gl, function () { return gl.createTexture(); }, 'Unable to create WebGLTexture.');
}
/**
 * Throws if either dimension is non-positive or exceeds
 * WEBGL_MAX_TEXTURE_SIZE.
 */
function validateTextureSize(width, height) {
    var maxTextureSize = tf.env().getNumber('WEBGL_MAX_TEXTURE_SIZE');
    var requested = "[".concat(width, "x").concat(height, "]");
    if ((width <= 0) || (height <= 0)) {
        throw new Error('Requested texture size ' + requested + ' is invalid.');
    }
    if ((width > maxTextureSize) || (height > maxTextureSize)) {
        var max = "[".concat(maxTextureSize, "x").concat(maxTextureSize, "]");
        throw new Error('Requested texture size ' + requested +
            ' greater than WebGL maximum on this browser / GPU ' + max + '.');
    }
}
/**
 * Returns gl.createFramebuffer(), throwing if it is null.
 */
function createFramebuffer(gl) {
    return throwIfNull(gl, function () { return gl.createFramebuffer(); }, 'Unable to create WebGLFramebuffer.');
}
/**
 * Points vertex attribute `attribute` of `program` at FLOAT data in `buffer`
 * and enables it.
 *
 * @returns false if the attribute is not present (location -1), else true.
 */
function bindVertexBufferToProgramAttribute(gl, program, attribute, buffer, arrayEntriesPerItem, itemStrideInBytes, itemOffsetInBytes) {
    var loc = gl.getAttribLocation(program, attribute);
    if (loc === -1) {
        // The GPU compiler decided to strip out this attribute because it's unused,
        // thus no need to bind.
        return false;
    }
    callAndCheck(gl, function () { return gl.bindBuffer(gl.ARRAY_BUFFER, buffer); });
    callAndCheck(gl, function () { return gl.vertexAttribPointer(loc, arrayEntriesPerItem, gl.FLOAT, false, itemStrideInBytes, itemOffsetInBytes); });
    callAndCheck(gl, function () { return gl.enableVertexAttribArray(loc); });
    return true;
}
/**
 * Activates texture unit `textureUnit` (0-based) and binds `texture` to its
 * TEXTURE_2D target.
 */
function bindTextureUnit(gl, texture, textureUnit) {
    validateTextureUnit(gl, textureUnit);
    callAndCheck(gl, function () { return gl.activeTexture(gl.TEXTURE0 + textureUnit); });
    callAndCheck(gl, function () { return gl.bindTexture(gl.TEXTURE_2D, texture); });
}
/**
 * Activates texture unit `textureUnit` and unbinds its TEXTURE_2D target.
 */
function unbindTextureUnit(gl, textureUnit) {
    validateTextureUnit(gl, textureUnit);
    callAndCheck(gl, function () { return gl.activeTexture(gl.TEXTURE0 + textureUnit); });
    callAndCheck(gl, function () { return gl.bindTexture(gl.TEXTURE_2D, null); });
}
/**
 * Returns the location of uniform `uniformName`, throwing if it is absent.
 */
function getProgramUniformLocationOrThrow(gl, program, uniformName) {
    return throwIfNull(gl, function () { return gl.getUniformLocation(program, uniformName); }, 'uniform "' + uniformName + '" not present in program.');
}
/**
 * Returns the location of uniform `uniformName` (null if absent).
 */
function getProgramUniformLocation(gl, program, uniformName) {
    return gl.getUniformLocation(program, uniformName);
}
/**
 * Binds `texture` to unit `textureUnit` and points the sampler uniform at it.
 */
function bindTextureToProgramUniformSampler(gl, texture, uniformSamplerLocation, textureUnit) {
    callAndCheck(gl, function () { return bindTextureUnit(gl, texture, textureUnit); });
    callAndCheck(gl, function () { return gl.uniform1i(uniformSamplerLocation, textureUnit); });
}
/**
 * Binds the default framebuffer and sets viewport and scissor to the full
 * canvas.
 */
function bindCanvasToFramebuffer(gl) {
    callAndCheck(gl, function () { return gl.bindFramebuffer(gl.FRAMEBUFFER, null); });
    callAndCheck(gl, function () { return gl.viewport(0, 0, gl.canvas.width, gl.canvas.height); });
    callAndCheck(gl, function () { return gl.scissor(0, 0, gl.canvas.width, gl.canvas.height); });
}
/**
 * Binds `framebuffer` and attaches `texture` as COLOR_ATTACHMENT0.
 */
function bindColorTextureToFramebuffer(gl, texture, framebuffer) {
    callAndCheck(gl, function () { return gl.bindFramebuffer(gl.FRAMEBUFFER, framebuffer); });
    callAndCheck(gl, function () { return gl.framebufferTexture2D(gl.FRAMEBUFFER, gl.COLOR_ATTACHMENT0, gl.TEXTURE_2D, texture, 0); });
}
/**
 * Binds `framebuffer` and detaches its COLOR_ATTACHMENT0 texture.
 */
function unbindColorTextureFromFramebuffer(gl, framebuffer) {
    callAndCheck(gl, function () { return gl.bindFramebuffer(gl.FRAMEBUFFER, framebuffer); });
    callAndCheck(gl, function () { return gl.framebufferTexture2D(gl.FRAMEBUFFER, gl.COLOR_ATTACHMENT0, gl.TEXTURE_2D, null, 0); });
}
/**
 * Throws if the bound FRAMEBUFFER is not FRAMEBUFFER_COMPLETE.
 */
function validateFramebuffer(gl) {
    var status = gl.checkFramebufferStatus(gl.FRAMEBUFFER);
    if (status !== gl.FRAMEBUFFER_COMPLETE) {
        throw new Error('Error binding framebuffer: ' + getFramebufferErrorMessage(gl, status));
    }
}
/**
 * Maps a checkFramebufferStatus() code to its constant name.
 */
function getFramebufferErrorMessage(gl, status) {
    switch (status) {
        case gl.FRAMEBUFFER_INCOMPLETE_ATTACHMENT:
            return 'FRAMEBUFFER_INCOMPLETE_ATTACHMENT';
        case gl.FRAMEBUFFER_INCOMPLETE_MISSING_ATTACHMENT:
            return 'FRAMEBUFFER_INCOMPLETE_MISSING_ATTACHMENT';
        case gl.FRAMEBUFFER_INCOMPLETE_DIMENSIONS:
            return 'FRAMEBUFFER_INCOMPLETE_DIMENSIONS';
        case gl.FRAMEBUFFER_UNSUPPORTED:
            return 'FRAMEBUFFER_UNSUPPORTED';
        default:
            return "unknown error ".concat(status);
    }
}
/**
 * Calls `returnTOrNull` (via callAndCheck) and returns its result.
 *
 * @throws Error(failureMessage) if the result is null or undefined.
 */
function throwIfNull(gl, returnTOrNull, failureMessage) {
    var tOrNull = callAndCheck(gl, function () { return returnTOrNull(); });
    if (tOrNull == null) {
        throw new Error(failureMessage);
    }
    return tOrNull;
}
/**
 * Throws unless 0 <= textureUnit < the context's combined texture-unit limit.
 *
 * `gl.MAX_COMBINED_TEXTURE_IMAGE_UNITS` is only the GLenum used to query the
 * limit. The previous code compared against that enum value itself, so it
 * accepted ~1676 units regardless of hardware.
 */
function validateTextureUnit(gl, textureUnit) {
    var maxTextureUnits = gl.getParameter(gl.MAX_COMBINED_TEXTURE_IMAGE_UNITS);
    if (typeof maxTextureUnits !== 'number') {
        // The limit could not be queried (e.g. getParameter returned null),
        // so only the lower bound is enforced.
        maxTextureUnits = Infinity;
    }
    if (textureUnit < 0 || textureUnit >= maxTextureUnits) {
        var textureUnitRange = "[gl.TEXTURE0, gl.TEXTURE".concat(maxTextureUnits - 1, "]");
        throw new Error("textureUnit must be in ".concat(textureUnitRange, "."));
    }
}
/**
 * Returns the product of all dimensions except the last `dimsToSkip`
 * (default 2); 1 when there are no such leading dimensions.
 */
function getBatchDim(shape, dimsToSkip) {
    if (dimsToSkip === void 0) { dimsToSkip = 2; }
    return tf.util.sizeFromShape(shape.slice(0, shape.length - dimsToSkip));
}
/**
 * Returns [rows, cols] = the last two dimensions of `shape` (rows is 1 for a
 * rank-1 shape).
 *
 * @throws Error for an empty shape.
 */
function getRowsCols(shape) {
    if (shape.length === 0) {
        throw Error('Cannot get rows and columns of an empty shape array.');
    }
    return [
        shape.length > 1 ? shape[shape.length - 2] : 1, shape[shape.length - 1]
    ];
}
/**
 * Returns [batch, rows, cols] for `shape`; [1, 1, 1] for a scalar or [1].
 */
function getShapeAs3D(shape) {
    var shapeAs3D = [1, 1, 1];
    var isScalar = shape.length === 0 || (shape.length === 1 && shape[0] === 1);
    if (!isScalar) {
        shapeAs3D = __spreadArray([getBatchDim(shape)], __read(getRowsCols(shape)), false);
    }
    return shapeAs3D;
}
/**
 * Chooses a 2D texture shape for a tensor of logical shape `logShape`.
 *
 * A direct 2D mapping is used when it fits WEBGL_MAX_TEXTURE_SIZE and is not
 * "long and narrow" (see WEBGL_MAX_SIZE_FOR_NARROW_TEXTURE). Otherwise the
 * shape is squarified with tf.util.sizeToSquarishShape. When `isPacked`, the
 * limits are doubled and the last two dimensions are rounded up to even.
 */
function getTextureShapeFromLogicalShape(logShape, isPacked) {
    var _a;
    if (isPacked === void 0) { isPacked = false; }
    var maxTexSize = tf.env().getNumber('WEBGL_MAX_TEXTURE_SIZE');
    var maxSizeForNarrowTex = tf.env().getNumber('WEBGL_MAX_SIZE_FOR_NARROW_TEXTURE');
    if (maxSizeForNarrowTex === Infinity &&
        tf.env().getBool('WEBGL_AUTO_SQUARIFY_NARROW_TEXTURE_SHAPE')) {
        maxSizeForNarrowTex = maxTexSize / 2;
    }
    if (isPacked) {
        maxTexSize = maxTexSize * 2;
        maxSizeForNarrowTex = maxSizeForNarrowTex * 2;
        // This logic ensures we accurately count the number of packed texels needed
        // to accommodate the tensor. We can only pack values in the same texel if
        // they are from adjacent pairs of rows/cols within the same batch. So if a
        // tensor has 3 rows, we pretend it has 4 rows in order to account for the
        // fact that the texels containing the third row are half empty.
        logShape = logShape.map(function (d, i) { return i >= logShape.length - 2 ?
            tf.util.nearestLargerEven(logShape[i]) :
            logShape[i]; });
        // Packed texture height is at least 2 (the channel height of a single
        // texel).
        if (logShape.length === 1) {
            logShape = [2, logShape[0]];
        }
    }
    // If logical shape is 2, we don't squeeze, since we want to match physical.
    if (logShape.length !== 2) {
        var squeezeResult = tf.util.squeezeShape(logShape);
        logShape = squeezeResult.newShape;
    }
    var size = tf.util.sizeFromShape(logShape);
    var textureShape = null;
    if (logShape.length <= 1 && size <= maxTexSize) {
        textureShape = [1, size];
    }
    else if (logShape.length === 2 && logShape[0] <= maxTexSize &&
        logShape[1] <= maxTexSize) {
        textureShape = logShape;
    }
    else if (logShape.length === 3 && logShape[0] * logShape[1] <= maxTexSize &&
        logShape[2] <= maxTexSize) {
        textureShape = [logShape[0] * logShape[1], logShape[2]];
    }
    else if (logShape.length === 3 && logShape[0] <= maxTexSize &&
        logShape[1] * logShape[2] <= maxTexSize) {
        textureShape = [logShape[0], logShape[1] * logShape[2]];
    }
    else if (logShape.length === 4 &&
        logShape[0] * logShape[1] * logShape[2] <= maxTexSize &&
        logShape[3] <= maxTexSize) {
        textureShape = [logShape[0] * logShape[1] * logShape[2], logShape[3]];
    }
    else if (logShape.length === 4 && logShape[0] <= maxTexSize &&
        logShape[1] * logShape[2] * logShape[3] <= maxTexSize) {
        textureShape = [logShape[0], logShape[1] * logShape[2] * logShape[3]];
    }
    // true if one edge length is 1 (1 or 2, if packed), while another edge
    // length exceeds maxSizeForNarrowTex.
    var isLongNarrowTex = textureShape != null &&
        Math.max.apply(Math, __spreadArray([], __read(textureShape), false)) > maxSizeForNarrowTex &&
        Math.min.apply(Math, __spreadArray([], __read(textureShape), false)) <= (isPacked ? 2 : 1) &&
        Math.min.apply(Math, __spreadArray([], __read(textureShape), false)) > 0;
    if (textureShape == null || isLongNarrowTex) {
        if (isPacked) {
            // For packed textures size equals the number of channels required to
            // accommodate the texture data. However in order to squarify such that
            // inner dimensions stay even, we rewrite size to equal the number of
            // texels. Then in the return statement we rehydrate the squarified
            // dimensions to channel units.
            var batchDim = getBatchDim(logShape);
            var rows = 2, cols = 2;
            if (logShape.length) {
                _a = __read(getRowsCols(logShape), 2), rows = _a[0], cols = _a[1];
            }
            size = batchDim * (rows / 2) * (cols / 2);
            textureShape =
                tf.util.sizeToSquarishShape(size).map(function (d) { return d * 2; });
        }
        else {
            textureShape = tf.util.sizeToSquarishShape(size);
        }
    }
    return textureShape;
}
/**
 * Returns true when `n` is divisible by 2.
 */
function isEven(n) {
    return n % 2 === 0;
}
/**
 * This determines whether reshaping a packed texture requires rearranging
 * the data within the texture, assuming 2x2 packing.
 */
function isReshapeFree(shape1, shape2) {
    shape1 = shape1.slice(-2);
    shape2 = shape2.slice(-2);
    if (tf.util.arraysEqual(shape1, shape2)) {
        return true;
    }
    if (!shape1.length || !shape2.length) { // One of the shapes is a scalar.
        return true;
    }
    if (shape1[0] === 0 || shape1[1] === 0 || shape2[0] === 0 ||
        shape2[1] === 0) {
        return true;
    }
    if (shape1.length !== shape2.length) { // One of the shapes is a vector.
        var shape1Cols = shape1[shape1.length - 1];
        var shape2Cols = shape2[shape2.length - 1];
        if (shape1Cols === shape2Cols) {
            return true;
        }
        if (isEven(shape1Cols) && isEven(shape2Cols) &&
            (shape1[0] === 1 || shape2[0] === 1)) {
            return true;
        }
    }
    return shape1[1] === shape2[1] && isEven(shape1[0]) && isEven(shape2[0]);
}
// We cache webgl params because the environment gets reset between
// unit tests and we don't want to constantly query the WebGLContext for
// MAX_TEXTURE_SIZE.
var MAX_TEXTURE_SIZE; var MAX_TEXTURES_IN_SHADER; function getWebGLMaxTextureSize(webGLVersion) { if (MAX_TEXTURE_SIZE == null) { var gl = getWebGLContext(webGLVersion); MAX_TEXTURE_SIZE = gl.getParameter(gl.MAX_TEXTURE_SIZE); } return MAX_TEXTURE_SIZE; } function resetMaxTextureSize() { MAX_TEXTURE_SIZE = null; } function resetMaxTexturesInShader() { MAX_TEXTURES_IN_SHADER = null; } function getMaxTexturesInShader(webGLVersion) { if (MAX_TEXTURES_IN_SHADER == null) { var gl = getWebGLContext(webGLVersion); MAX_TEXTURES_IN_SHADER = gl.getParameter(gl.MAX_TEXTURE_IMAGE_UNITS); } // We cap at 16 to avoid spurious runtime "memory exhausted" error. return Math.min(16, MAX_TEXTURES_IN_SHADER); } function getWebGLDisjointQueryTimerVersion(webGLVersion) { if (webGLVersion === 0) { return 0; } var queryTimerVersion; var gl = getWebGLContext(webGLVersion); if (hasExtension(gl, 'EXT_disjoint_timer_query_webgl2') && webGLVersion === 2) { queryTimerVersion = 2; } else if (hasExtension(gl, 'EXT_disjoint_timer_query')) { queryTimerVersion = 1; } else { queryTimerVersion = 0; } return queryTimerVersion; } function hasExtension(gl, extensionName) { var ext = gl.getExtension(extensionName); return ext != null; } function isWebGLVersionEnabled(webGLVersion) { try { var gl = getWebGLContext(webGLVersion); if (gl != null) { return true; } } catch (e) { console.log('Error when getting WebGL context: ', e); return false; } return false; } function isCapableOfRenderingToFloatTexture(webGLVersion) { if (webGLVersion === 0) { return false; } var gl = getWebGLContext(webGLVersion); if (webGLVersion === 1) { if (!hasExtension(gl, 'OES_texture_float')) { return false; } } else { if (!hasExtension(gl, 'EXT_color_buffer_float')) { return false; } } var isFrameBufferComplete = createFloatTextureAndBindToFramebuffer(gl); return isFrameBufferComplete; } /** * Check if we can download values from a float/half-float texture. 
* * Note that for performance reasons we use binding a texture to a framebuffer * as a proxy for ability to download float values later using readPixels. The * texture params of this texture will not match those in readPixels exactly * but if we are unable to bind some kind of float texture to the frameBuffer * then we definitely will not be able to read float values from it. */ function isDownloadFloatTextureEnabled(webGLVersion) { if (webGLVersion === 0) { return false; } var gl = getWebGLContext(webGLVersion); if (webGLVersion === 1) { if (!hasExtension(gl, 'OES_texture_float')) { return false; } if (!hasExtension(gl, 'WEBGL_color_buffer_float')) { return false; } } else { if (hasExtension(gl, 'EXT_color_buffer_float')) { return createFloatTextureAndBindToFramebuffer(gl); } var COLOR_BUFFER_HALF_FLOAT = 'EXT_color_buffer_half_float'; if (hasExtension(gl, COLOR_BUFFER_HALF_FLOAT)) { var textureHalfFloatExtension = gl.getExtension(COLOR_BUFFER_HALF_FLOAT); return createHalfFloatTextureAndBindToFramebuffer(gl, textureHalfFloatExtension); } return false; } var isFrameBufferComplete = createFloatTextureAndBindToFramebuffer(gl); return isFrameBufferComplete; } function createFloatTextureAndBindToFramebuffer(gl) { var texConfig = getTextureConfig(gl); var texture = gl.createTexture(); gl.bindTexture(gl.TEXTURE_2D, texture); var width = 1; var height = 1; gl.texImage2D(gl.TEXTURE_2D, 0, texConfig.internalFormatFloat, width, height, 0, texConfig.textureFormatFloat, texConfig.textureTypeFloat, null); var frameBuffer = gl.createFramebuffer(); gl.bindFramebuffer(gl.FRAMEBUFFER, frameBuffer); gl.framebufferTexture2D(gl.FRAMEBUFFER, gl.COLOR_ATTACHMENT0, gl.TEXTURE_2D, texture, 0); var isFrameBufferComplete = gl.checkFramebufferStatus(gl.FRAMEBUFFER) === gl.FRAMEBUFFER_COMPLETE; gl.bindTexture(gl.TEXTURE_2D, null); gl.bindFramebuffer(gl.FRAMEBUFFER, null); gl.deleteTexture(texture); gl.deleteFramebuffer(frameBuffer); return isFrameBufferComplete; } function 
createHalfFloatTextureAndBindToFramebuffer( // tslint:disable-next-line:no-any gl, textureHalfFloatExtension) { var texConfig = getTextureConfig(gl, textureHalfFloatExtension); var texture = gl.createTexture(); gl.bindTexture(gl.TEXTURE_2D, texture); var width = 1; var height = 1; gl.texImage2D(gl.TEXTURE_2D, 0, texConfig.internalFormatHalfFloat, width, height, 0, texConfig.textureFormatFloat, texConfig.textureTypeHalfFloat, null); var frameBuffer = gl.createFramebuffer(); gl.bindFramebuffer(gl.FRAMEBUFFER, frameBuffer); gl.framebufferTexture2D(gl.FRAMEBUFFER, gl.COLOR_ATTACHMENT0, gl.TEXTURE_2D, texture, 0); var isFrameBufferComplete = gl.checkFramebufferStatus(gl.FRAMEBUFFER) === gl.FRAMEBUFFER_COMPLETE; gl.bindTexture(gl.TEXTURE_2D, null); gl.bindFramebuffer(gl.FRAMEBUFFER, null); gl.deleteTexture(texture); gl.deleteFramebuffer(frameBuffer); return isFrameBufferComplete; } function isWebGLFenceEnabled(webGLVersion) { if (webGLVersion !== 2) { return false; } var gl = getWebGLContext(webGLVersion); // tslint:disable-next-line:no-any var isEnabled = gl.fenceSync != null; return isEnabled; } function assertNotComplex(tensor, opName) { if (!Array.isArray(tensor)) { tensor = [tensor]; } tensor.forEach(function (t) { if (t != null) { tf.util.assert(t.dtype !== 'complex64', function () { return "".concat(opName, " does not support complex64 tensors ") + 'in the WebGL backend.'; }); } }); } var webgl_util = { __proto__: null, assertNotComplex: assertNotComplex, bindCanvasToFramebuffer: bindCanvasToFramebuffer, bindColorTextureToFramebuffer: bindColorTextureToFramebuffer, bindTextureToProgramUniformSampler: bindTextureToProgramUniformSampler, bindTextureUnit: bindTextureUnit, bindVertexBufferToProgramAttribute: bindVertexBufferToProgramAttribute, callAndCheck: callAndCheck, canBeRepresented: canBeRepresented, createFragmentShader: createFragmentShader, createFramebuffer: createFramebuffer, createProgram: createProgram, createStaticIndexBuffer: createStaticIndexBuffer, 
createStaticVertexBuffer: createStaticVertexBuffer, createTexture: createTexture, createVertexShader: createVertexShader$1, getBatchDim: getBatchDim, getExtensionOrThrow: getExtensionOrThrow, getFramebufferErrorMessage: getFramebufferErrorMessage, getMaxTexturesInShader: getMaxTexturesInShader, getNumChannels: getNumChannels, getProgramUniformLocation: getProgramUniformLocation, getProgramUniformLocationOrThrow: getProgramUniformLocationOrThrow, getRowsCols: getRowsCols, getShapeAs3D: getShapeAs3D, getTextureShapeFromLogicalShape: getTextureShapeFromLogicalShape, getWebGLDisjointQueryTimerVersion: getWebGLDisjointQueryTimerVersion, getWebGLErrorMessage: getWebGLErrorMessage, getWebGLMaxTextureSize: getWebGLMaxTextureSize, hasExtension: hasExtension, isCapableOfRenderingToFloatTexture: isCapableOfRenderingToFloatTexture, isDownloadFloatTextureEnabled: isDownloadFloatTextureEnabled, isReshapeFree: isReshapeFree, isWebGLFenceEnabled: isWebGLFenceEnabled, isWebGLVersionEnabled: isWebGLVersionEnabled, linkProgram: linkProgram, logShaderSourceAndInfoLog: logShaderSourceAndInfoLog, resetMaxTextureSize: resetMaxTextureSize, resetMaxTexturesInShader: resetMaxTexturesInShader, unbindColorTextureFromFramebuffer: unbindColorTextureFromFramebuffer, unbindTextureUnit: unbindTextureUnit, validateFramebuffer: validateFramebuffer, validateProgram: validateProgram, validateTextureSize: validateTextureSize }; /** * @license * Copyright 2019 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
* See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ var ENV = tf.env(); /** * This file contains WebGL-specific flag registrations. */ /** * True if WebGL is supported. */ ENV.registerFlag('HAS_WEBGL', function () { return ENV.getNumber('WEBGL_VERSION') > 0; }); /** 0: No WebGL, 1: WebGL 1.0, 2: WebGL 2.0. */ ENV.registerFlag('WEBGL_VERSION', function () { if (isWebGLVersionEnabled(2)) { return 2; } else if (isWebGLVersionEnabled(1)) { return 1; } return 0; }); /** Whether to check for numerical representation problems. */ ENV.registerFlag('WEBGL_CHECK_NUMERICAL_PROBLEMS', function () { return false; }); ENV.registerFlag('WEBGL_BUFFER_SUPPORTED', function () { return ENV.get('WEBGL_VERSION') === 2; }); /** Whether the WebGL backend will sometimes forward ops to the CPU. */ ENV.registerFlag('WEBGL_CPU_FORWARD', function () { return true; }); /** Whether the WebGL backend will always use f16 textures for rendering. */ ENV.registerFlag('WEBGL_FORCE_F16_TEXTURES', function () { return false; }); /** Whether to turn all packing related flags on. */ ENV.registerFlag('WEBGL_PACK', function () { return ENV.getBool('HAS_WEBGL'); }); /** Whether we will pack the batchnormalization op. */ ENV.registerFlag('WEBGL_PACK_NORMALIZATION', function () { return ENV.getBool('WEBGL_PACK'); }); /** Whether we will pack the clip op. */ ENV.registerFlag('WEBGL_PACK_CLIP', function () { return ENV.getBool('WEBGL_PACK'); }); /** Whether we will pack the depthwise conv op. */ ENV.registerFlag('WEBGL_PACK_DEPTHWISECONV', function () { return ENV.getBool('WEBGL_PACK'); }); /** Whether we will pack binary ops. */ ENV.registerFlag('WEBGL_PACK_BINARY_OPERATIONS', function () { return ENV.getBool('WEBGL_PACK'); }); /** Whether we will pack unary ops. 
*/ ENV.registerFlag('WEBGL_PACK_UNARY_OPERATIONS', function () { return ENV.getBool('WEBGL_PACK'); }); /** Whether we will pack array ops. */ ENV.registerFlag('WEBGL_PACK_ARRAY_OPERATIONS', function () { return ENV.getBool('WEBGL_PACK'); }); /** Whether we will pack image ops. */ ENV.registerFlag('WEBGL_PACK_IMAGE_OPERATIONS', function () { return ENV.getBool('WEBGL_PACK'); }); /** Whether we will pack reduce ops. */ ENV.registerFlag('WEBGL_PACK_REDUCE', function () { return ENV.getBool('WEBGL_PACK'); }); /** Whether packed WebGL kernels lazily unpack their outputs. */ ENV.registerFlag('WEBGL_LAZILY_UNPACK', function () { return ENV.getBool('WEBGL_PACK'); }); /** Whether we will use the im2col algorithm to speed up convolutions. */ ENV.registerFlag('WEBGL_CONV_IM2COL', function () { return ENV.getBool('WEBGL_PACK'); }); /** Whether we will pack conv2dTranspose op. */ ENV.registerFlag('WEBGL_PACK_CONV2DTRANSPOSE', function () { return ENV.getBool('WEBGL_PACK'); }); /** The maximum texture dimension. */ ENV.registerFlag('WEBGL_MAX_TEXTURE_SIZE', function () { return getWebGLMaxTextureSize(ENV.getNumber('WEBGL_VERSION')); }); /** The maximum texture dimension. */ ENV.registerFlag('WEBGL_MAX_TEXTURES_IN_SHADER', function () { return getMaxTexturesInShader(ENV.getNumber('WEBGL_VERSION')); }); /** * The disjoint_query_timer extension version. * 0: disabled, 1: EXT_disjoint_timer_query, 2: * EXT_disjoint_timer_query_webgl2. * In Firefox with WebGL 2.0, * EXT_disjoint_timer_query_webgl2 is not available, so we must use the * WebGL 1.0 extension. */ ENV.registerFlag('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION', function () { var webGLVersion = ENV.getNumber('WEBGL_VERSION'); if (webGLVersion === 0) { return 0; } return getWebGLDisjointQueryTimerVersion(webGLVersion); }); /** * Whether the timer object from the disjoint_query_timer extension gives * timing information that is reliable. 
*/ ENV.registerFlag('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_RELIABLE', function () { return ENV.getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION') > 0 && !tf.device_util.isMobile(); }); /** * Whether the device is physically capable of rendering to float32 textures. */ ENV.registerFlag('WEBGL_RENDER_FLOAT32_CAPABLE', function () { return isCapableOfRenderingToFloatTexture(ENV.getNumber('WEBGL_VERSION')); }); /** * Whether rendering to float32 textures is enabled. If disabled, renders to * float16 textures. */ ENV.registerFlag('WEBGL_RENDER_FLOAT32_ENABLED', function () { return ENV.getBool('WEBGL_FORCE_F16_TEXTURES') ? false : ENV.getBool('WEBGL_RENDER_FLOAT32_CAPABLE'); }); /** * Whether downloading float textures is enabled (16 or 32 bit). If disabled, * uses IEEE 754 encoding of the float32 values to 4 uint8 when downloading. */ ENV.registerFlag('WEBGL_DOWNLOAD_FLOAT_ENABLED', function () { return isDownloadFloatTextureEnabled(ENV.getNumber('WEBGL_VERSION')); }); /** Whether the fence API is available. */ ENV.registerFlag('WEBGL_FENCE_API_ENABLED', function () { return isWebGLFenceEnabled(ENV.getNumber('WEBGL_VERSION')); }); /** * Tensors with size <= than this will be uploaded as uniforms, not textures. */ ENV.registerFlag('WEBGL_SIZE_UPLOAD_UNIFORM', function () { // Use uniform uploads only when 32bit floats are supported. In // 16bit // environments there are problems with comparing a 16bit texture value // with a 32bit uniform value. var useUniforms = ENV.getBool('WEBGL_RENDER_FLOAT32_ENABLED'); return useUniforms ? 4 : 0; }); /** * If the total number of bytes allocated on the GPU is greater than this * number, we will aggressively delete textures upon disposal with * gl.deleteMatrixTexture, rather than making them available for reuse. * * Default value -1 indicates that we will never aggressively delete textures. 
*/ ENV.registerFlag('WEBGL_DELETE_TEXTURE_THRESHOLD', function () { return -1; }, function (threshold) { if (!(typeof threshold === 'number')) { throw new Error('WEBGL_DELETE_TEXTURE_THRESHOLD must be a number but ' + "got ".concat(threshold, ".")); } if (threshold < 0 && threshold !== -1) { throw new Error("WEBGL_DELETE_TEXTURE_THRESHOLD must be -1 (indicating never " + "delete) or at least 0, but got ".concat(threshold, ".")); } }); /** * Trigger a manual GL command flush if the threshold of time has passed since * previous Kernel execution. This can be useful for Andorid device where GL * command flush are delayed un til the end of javascript task. This value is * measured in millisecond. Typically you want to set this value to close to 1. * * Default value 1 for mobile chrome, and -1 for rest cases. -1 indicates that * we will not enforce manual flush and depend on system default flush schedule. */ ENV.registerFlag('WEBGL_FLUSH_THRESHOLD', function () { return tf.device_util.isMobile() ? 1 : -1; }, function (threshold) { if (!(typeof threshold === 'number')) { throw new Error('WEBGL_FLUSH_THRESHOLD must be a number but got ' + "".concat(threshold, ".")); } if (threshold < 0 && threshold !== -1) { throw new Error("WEBGL_FLUSH_THRESHOLD must be -1 (indicating never " + "manual flush) or at least 0, but got ".concat(threshold, ".")); } }); /** * Threshold for input tensor size that determines whether WebGL backend will * delegate computation to CPU. * * Default value is 128. */ ENV.registerFlag('CPU_HANDOFF_SIZE_THRESHOLD', function () { return 128; }); /** Whether we will use shapes uniforms. */ ENV.registerFlag('WEBGL_USE_SHAPES_UNIFORMS', function () { return false; }); /** * Threshold for last dimension of input tensor that determines whether * WebGL backend for the Top K op will delegate computation to CPU. If input * is smaller than threshold then CPU will be used * * Default value is 100000. 
*/ ENV.registerFlag('TOPK_LAST_DIM_CPU_HANDOFF_SIZE_THRESHOLD', function () { return 100000; }); /** * Threshold for K that determines whether * WebGL backend for the Top K op will delegate computation to CPU. If k * is larger than threshold then CPU will be used * * Default value is 128. */ ENV.registerFlag('TOPK_K_CPU_HANDOFF_THRESHOLD', function () { return 128; }); /** Whether we will use the experimental conv op. */ ENV.registerFlag('WEBGL_EXP_CONV', function () { return false; }); /** * If the device performance is low or if no hardware GPU is available, whether * software WebGL will be used. */ ENV.registerFlag('SOFTWARE_WEBGL_ENABLED', function () { return ENV.getBool('IS_TEST'); }); /** * For narrow texture (physical height or physical width is 1), if the length of * any texture edges exceed the threshold, the texture will be reshaped to be * more squarish. * * This flag is used to help some GPUs that could not provide correct * interpolations for long skinny triangles. We found Mali GPU probably has this * problem: https://github.com/tensorflow/tfjs/issues/6775. */ ENV.registerFlag('WEBGL_MAX_SIZE_FOR_NARROW_TEXTURE', function () { return Infinity; }); /** * If the flag is set to true, the max size of the narrow texture will be auto * computed and it will be considerred as a threshold to reshape the narrow * texture to be more squarish. * * This flag is used to help some GPUs that could not provide correct * interpolations for long skinny triangles. We found Mali GPU probably has this * problem: https://github.com/tensorflow/tfjs/issues/6775. */ ENV.registerFlag('WEBGL_AUTO_SQUARIFY_NARROW_TEXTURE_SHAPE', function () { return false; }); /** * Whether to use the customized isnan. It's only useful for webgl2 since webgl1 * doesn't have the builtin isnan. */ ENV.registerFlag('WEBGL2_ISNAN_CUSTOM', function () { return false; }); /** Experimental flag, whether enter compile only phase. 
*/ ENV.registerFlag('ENGINE_COMPILE_ONLY', function () { return false; }); /** * @license * Copyright 2018 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function getGlslDifferences() { var version; var attribute; var varyingVs; var varyingFs; var texture2D; var output; var defineOutput; var defineSpecialNaN; var defineSpecialInf; var defineRound; if (tf.env().getNumber('WEBGL_VERSION') === 2) { version = '#version 300 es'; attribute = 'in'; varyingVs = 'out'; varyingFs = 'in'; texture2D = 'texture'; output = 'outputColor'; defineOutput = 'out vec4 outputColor;'; // Use custom isnan definition to work across differences between // implementations on various platforms. While this should happen in ANGLE // we still see differences between android and windows (on chrome) when // using isnan directly. Since WebGL2 supports uint type and // floatBitsToUinT built-in function, we could implment isnan following // IEEE 754 rules. // NaN defination in IEEE 754-1985 is : // - sign = either 0 or 1. // - biased exponent = all 1 bits. // - fraction = anything except all 0 bits (since all 0 bits represents // infinity). // https://en.wikipedia.org/wiki/IEEE_754-1985#Representation_of_non-numbers defineSpecialNaN = tf.env().getBool('WEBGL2_ISNAN_CUSTOM') ? 
"\n bool isnan_custom(float val) {\n uint floatToUint = floatBitsToUint(val);\n return (floatToUint & 0x7fffffffu) > 0x7f800000u;\n }\n\n bvec4 isnan_custom(vec4 val) {\n return bvec4(isnan_custom(val.x),\n isnan_custom(val.y), isnan_custom(val.z), isnan_custom(val.w));\n }\n\n #define isnan(value) isnan_custom(value)\n " : ''; // In webgl 2 we do not need to specify a custom isinf so there is no // need for a special INFINITY constant. defineSpecialInf = ""; defineRound = "\n #define round(value) newRound(value)\n int newRound(float value) {\n return int(floor(value + 0.5));\n }\n\n ivec4 newRound(vec4 value) {\n return ivec4(floor(value + vec4(0.5)));\n }\n "; } else { version = ''; attribute = 'attribute'; varyingVs = 'varying'; varyingFs = 'varying'; texture2D = 'texture2D'; output = 'gl_FragColor'; defineOutput = ''; // WebGL1 has no built in isnan so we define one here. defineSpecialNaN = "\n #define isnan(value) isnan_custom(value)\n bool isnan_custom(float val) {\n return (val > 0. || val < 1. || val == 0.) ? false : true;\n }\n bvec4 isnan_custom(vec4 val) {\n return bvec4(isnan(val.x), isnan(val.y), isnan(val.z), isnan(val.w));\n }\n "; defineSpecialInf = "\n uniform float INFINITY;\n\n bool isinf(float val) {\n return abs(val) == INFINITY;\n }\n bvec4 isinf(vec4 val) {\n return equal(abs(val), vec4(INFINITY));\n }\n "; defineRound = "\n int round(float value) {\n return int(floor(value + 0.5));\n }\n\n ivec4 round(vec4 value) {\n return ivec4(floor(value + vec4(0.5)));\n }\n "; } return { version: version, attribute: attribute, varyingVs: varyingVs, varyingFs: varyingFs, texture2D: texture2D, output: output, defineOutput: defineOutput, defineSpecialNaN: defineSpecialNaN, defineSpecialInf: defineSpecialInf, defineRound: defineRound }; } /** * @license * Copyright 2018 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ /** * Produces GLSL code that derives logical coordinates from a flat * index. The code performs integer division with each stride and decrements * the index until the index equals the final dimension coordinate. */ function getLogicalCoordinatesFromFlatIndex(coords, shape, index) { if (index === void 0) { index = 'index'; } var strides = tf.util.computeStrides(shape); return strides .map(function (stride, i) { var line1 = "int ".concat(coords[i], " = ").concat(index, " / ").concat(stride); var line2 = i === strides.length - 1 ? "int ".concat(coords[i + 1], " = ").concat(index, " - ").concat(coords[i], " * ").concat(stride) : "index -= ".concat(coords[i], " * ").concat(stride); return "".concat(line1, "; ").concat(line2, ";"); }) .join(''); } function getOutputLogicalCoordinatesFromFlatIndexByUniform(coords, shape, index) { if (index === void 0) { index = 'index'; } var strides = tf.util.computeStrides(shape); return strides .map(function (_, i) { var line1 = "int ".concat(coords[i], " = ").concat(index, " / outShapeStrides[").concat(i, "]"); var line2 = i === strides.length - 1 ? "int ".concat(coords[i + 1], " = ").concat(index, " - ").concat(coords[i], " * outShapeStrides[").concat(i, "]") : "index -= ".concat(coords[i], " * outShapeStrides[").concat(i, "]"); return "".concat(line1, "; ").concat(line2, ";"); }) .join(''); } // Produces GLSL code that computes strides. 
function symbolicallyComputeStrides(indicesArr, variableName) { var numCoords = indicesArr.length; var shape = indicesArr.map(function (d) { return "".concat(variableName, "[").concat(d, "]"); }); var strides = new Array(numCoords - 1); strides[numCoords - 2] = shape[numCoords - 1]; for (var i = numCoords - 3; i >= 0; --i) { strides[i] = "(".concat(strides[i + 1], " * ").concat(shape[i + 1], ")"); } return strides; } function getLogicalCoordinatesFromFlatIndexByUniform(coords, variableName, index) { if (index === void 0) { index = 'index'; } var indicesArray = coords.map(function (_, i) { return i; }); var strides = symbolicallyComputeStrides(indicesArray, variableName); return strides .map(function (_, i) { var line1 = "int ".concat(coords[i], " = ").concat(index, " / ").concat(strides[i]); var line2 = i === strides.length - 1 ? "int ".concat(coords[i + 1], " = ").concat(index, " - ").concat(coords[i], " * ").concat(strides[i]) : "index -= ".concat(coords[i], " * ").concat(strides[i]); return "".concat(line1, "; ").concat(line2, ";"); }) .join(''); } /** * Produces GLSL that computes the flat index from 3D coordinates. 
*/ function getFlatIndexFrom3D(shape) { var strides = tf.util.computeStrides(shape).map(function (d) { return d.toString(); }); return "\n int getFlatIndex(ivec3 coords) {\n return coords.x * ".concat(strides[0], " + coords.y * ").concat(strides[1], " + coords.z;\n }\n"); } function getFlatIndexFrom3DOutput() { return "\n int getFlatIndex(ivec3 coords) {\n return coords.x * outShapeStrides[0] + coords.y * outShapeStrides[1] + coords.z;\n }\n"; } var ENCODE_FLOAT_SNIPPET = "\n const float FLOAT_MAX = 1.70141184e38;\n const float FLOAT_MIN = 1.17549435e-38;\n\n lowp vec4 encode_float(highp float v) {\n if (isnan(v)) {\n return vec4(255, 255, 255, 255);\n }\n\n highp float av = abs(v);\n\n if(av < FLOAT_MIN) {\n return vec4(0.0, 0.0, 0.0, 0.0);\n } else if(v > FLOAT_MAX) {\n return vec4(0.0, 0.0, 128.0, 127.0) / 255.0;\n } else if(v < -FLOAT_MAX) {\n return vec4(0.0, 0.0, 128.0, 255.0) / 255.0;\n }\n\n highp vec4 c = vec4(0,0,0,0);\n\n highp float e = floor(log2(av));\n highp float m = exp2(fract(log2(av))) - 1.0;\n\n c[2] = floor(128.0 * m);\n m -= c[2] / 128.0;\n c[1] = floor(32768.0 * m);\n m -= c[1] / 32768.0;\n c[0] = floor(8388608.0 * m);\n\n highp float ebias = e + 127.0;\n c[3] = floor(ebias / 2.0);\n ebias -= c[3] * 2.0;\n c[2] += floor(ebias) * 128.0;\n\n c[3] += 128.0 * step(0.0, -v);\n\n return c / 255.0;\n }\n"; var getBroadcastDims = tf.backend_util.getBroadcastDims; function makeShader(inputsInfo, outputShape, program) { var prefixSnippets = []; inputsInfo.forEach(function (x) { var size = tf.util.sizeFromShape(x.shapeInfo.logicalShape); // Snippet when we decided to upload the values as uniform. if (x.shapeInfo.isUniform) { prefixSnippets.push("uniform float ".concat(x.name).concat(size > 1 ? 
"[".concat(size, "]") : '', ";")); } else { prefixSnippets.push("uniform sampler2D ".concat(x.name, ";")); prefixSnippets.push("uniform int offset".concat(x.name, ";")); } if (program.enableShapeUniforms) { var uniformShape = getUniformInfoFromShape(program.packedInputs, x.shapeInfo.logicalShape, x.shapeInfo.texShape).uniformShape; switch (uniformShape.length) { case 1: prefixSnippets.push("uniform int ".concat(x.name, "Shape;")); break; case 2: prefixSnippets.push("uniform ivec2 ".concat(x.name, "Shape;")); break; case 3: prefixSnippets.push("uniform ivec3 ".concat(x.name, "Shape;")); break; case 4: prefixSnippets.push("uniform ivec4 ".concat(x.name, "Shape;")); break; } prefixSnippets.push("uniform ivec2 ".concat(x.name, "TexShape;")); } }); if (program.enableShapeUniforms) { switch (outputShape.logicalShape.length) { case 1: prefixSnippets.push("uniform int outShape;"); break; case 2: prefixSnippets.push("uniform ivec2 outShape;"); prefixSnippets.push("uniform int outShapeStrides;"); break; case 3: prefixSnippets.push("uniform ivec3 outShape;"); prefixSnippets.push("uniform ivec2 outShapeStrides;"); break; case 4: prefixSnippets.push("uniform ivec4 outShape;"); prefixSnippets.push("uniform ivec3 outShapeStrides;"); break; } prefixSnippets.push("uniform ivec2 outTexShape;"); } if (program.customUniforms) { program.customUniforms.forEach(function (d) { prefixSnippets.push("uniform ".concat(d.type, " ").concat(d.name).concat(d.arrayIndex ? 
"[".concat(d.arrayIndex, "]") : '', ";")); }); } var inputPrefixSnippet = prefixSnippets.join('\n'); var inputSamplingSnippet = inputsInfo .map(function (x) { return getInputSamplingSnippet(x, outputShape, program.packedInputs, program.enableShapeUniforms); }) .join('\n'); var outTexShape = outputShape.texShape; var glsl = getGlslDifferences(); var floatTextureSampleSnippet = getFloatTextureSampleSnippet(glsl); var outputSamplingSnippet; var floatTextureSetOutputSnippet; var shaderPrefix = getShaderPrefix(glsl); if (outputShape.isPacked) { outputSamplingSnippet = getPackedOutputSamplingSnippet(outputShape.logicalShape, outTexShape, program.enableShapeUniforms); floatTextureSetOutputSnippet = getFloatTextureSetRGBASnippet(glsl); } else { outputSamplingSnippet = getOutputSamplingSnippet(outputShape.logicalShape, outTexShape, program.enableShapeUniforms); floatTextureSetOutputSnippet = getFloatTextureSetRSnippet(glsl); } if (program.packedInputs) { shaderPrefix += SHADER_PACKED_PREFIX; } var source = [ shaderPrefix, floatTextureSampleSnippet, floatTextureSetOutputSnippet, inputPrefixSnippet, outputSamplingSnippet, inputSamplingSnippet, program.userCode ].join('\n'); return source; } function getSamplerFromInInfo(inInfo, enableShapeUniforms) { if (enableShapeUniforms === void 0) { enableShapeUniforms = false; } var shape = inInfo.shapeInfo.logicalShape; switch (shape.length) { case 0: return getSamplerScalar(inInfo, enableShapeUniforms); case 1: return getSampler1D(inInfo, enableShapeUniforms); case 2: return getSampler2D(inInfo, enableShapeUniforms); case 3: return getSampler3D(inInfo, enableShapeUniforms); case 4: return getSampler4D(inInfo, enableShapeUniforms); case 5: return getSampler5D(inInfo); case 6: return getSampler6D(inInfo); default: throw new Error("".concat(shape.length, "-D input sampling") + " is not yet supported"); } } function getPackedSamplerFromInInfo(inInfo, enableShapeUniforms) { var shape = inInfo.shapeInfo.logicalShape; switch (shape.length) { 
case 0: return getPackedSamplerScalar(inInfo); case 1: return getPackedSampler1D(inInfo, enableShapeUniforms); case 2: return getPackedSampler2D(inInfo, enableShapeUniforms); case 3: return getPackedSampler3D(inInfo, enableShapeUniforms); default: return getPackedSamplerND(inInfo, enableShapeUniforms); } } function getInputSamplingSnippet(inInfo, outShapeInfo, usesPackedTextures, enableShapeUniforms) { if (usesPackedTextures === void 0) { usesPackedTextures = false; } var res = ''; if (usesPackedTextures) { res += getPackedSamplerFromInInfo(inInfo, enableShapeUniforms); } else { res += getSamplerFromInInfo(inInfo, enableShapeUniforms); } var inShape = inInfo.shapeInfo.logicalShape; var outShape = outShapeInfo.logicalShape; if (inShape.length <= outShape.length) { if (usesPackedTextures) { res += getPackedSamplerAtOutputCoords(inInfo, outShapeInfo); } else { res += getSamplerAtOutputCoords(inInfo, outShapeInfo); } } return res; } function getPackedOutputSamplingSnippet(outShape, outTexShape, enableShapeUniforms) { switch (outShape.length) { case 0: return getOutputScalarCoords(); case 1: return getOutputPacked1DCoords(outShape, outTexShape, enableShapeUniforms); case 2: return getOutputPacked2DCoords(outShape, outTexShape, enableShapeUniforms); case 3: return getOutputPacked3DCoords(outShape, outTexShape, enableShapeUniforms); default: return getOutputPackedNDCoords(outShape, outTexShape, enableShapeUniforms); } } function getOutputSamplingSnippet(outShape, outTexShape, enableShapeUniforms) { switch (outShape.length) { case 0: return getOutputScalarCoords(); case 1: return getOutput1DCoords(outShape, outTexShape, enableShapeUniforms); case 2: return getOutput2DCoords(outShape, outTexShape, enableShapeUniforms); case 3: return getOutput3DCoords(outShape, outTexShape, enableShapeUniforms); case 4: return getOutput4DCoords(outShape, outTexShape, enableShapeUniforms); case 5: return getOutput5DCoords(outShape, outTexShape); case 6: return getOutput6DCoords(outShape, 
outTexShape); default: throw new Error("".concat(outShape.length, "-D output sampling is not yet supported")); } } function getFloatTextureSampleSnippet(glsl) { return "\n float sampleTexture(sampler2D textureSampler, vec2 uv) {\n return ".concat(glsl.texture2D, "(textureSampler, uv).r;\n }\n "); } function getFloatTextureSetRSnippet(glsl) { return "\n void setOutput(float val) {\n ".concat(glsl.output, " = vec4(val, 0, 0, 0);\n }\n "); } function getFloatTextureSetRGBASnippet(glsl) { return "\n void setOutput(vec4 val) {\n ".concat(glsl.output, " = val;\n }\n "); } function getShaderPrefix(glsl) { var SHADER_PREFIX = "".concat(glsl.version, "\n precision highp float;\n precision highp int;\n precision highp sampler2D;\n ").concat(glsl.varyingFs, " vec2 resultUV;\n ").concat(glsl.defineOutput, "\n const vec2 halfCR = vec2(0.5, 0.5);\n\n struct ivec5\n {\n int x;\n int y;\n int z;\n int w;\n int u;\n };\n\n struct ivec6\n {\n int x;\n int y;\n int z;\n int w;\n int u;\n int v;\n };\n\n uniform float NAN;\n ").concat(glsl.defineSpecialNaN, "\n ").concat(glsl.defineSpecialInf, "\n ").concat(glsl.defineRound, "\n\n int imod(int x, int y) {\n return x - y * (x / y);\n }\n\n int idiv(int a, int b, float sign) {\n int res = a / b;\n int mod = imod(a, b);\n if (sign < 0. 
&& mod != 0) {\n res -= 1;\n }\n return res;\n }\n\n //Based on the work of Dave Hoskins\n //https://www.shadertoy.com/view/4djSRW\n #define HASHSCALE1 443.8975\n float random(float seed){\n vec2 p = resultUV * seed;\n vec3 p3 = fract(vec3(p.xyx) * HASHSCALE1);\n p3 += dot(p3, p3.yzx + 19.19);\n return fract((p3.x + p3.y) * p3.z);\n }\n\n ").concat(SAMPLE_1D_SNIPPET, "\n ").concat(SAMPLE_2D_SNIPPET, "\n ").concat(SAMPLE_3D_SNIPPET, "\n "); return SHADER_PREFIX; } var SAMPLE_1D_SNIPPET = "\nvec2 uvFromFlat(int texNumR, int texNumC, int index) {\n int texR = index / texNumC;\n int texC = index - texR * texNumC;\n return (vec2(texC, texR) + halfCR) / vec2(texNumC, texNumR);\n}\nvec2 packedUVfrom1D(int texNumR, int texNumC, int index) {\n int texelIndex = index / 2;\n int texR = texelIndex / texNumC;\n int texC = texelIndex - texR * texNumC;\n return (vec2(texC, texR) + halfCR) / vec2(texNumC, texNumR);\n}\n"; var SAMPLE_2D_SNIPPET = "\nvec2 packedUVfrom2D(int texelsInLogicalRow, int texNumR,\n int texNumC, int row, int col) {\n int texelIndex = (row / 2) * texelsInLogicalRow + (col / 2);\n int texR = texelIndex / texNumC;\n int texC = texelIndex - texR * texNumC;\n return (vec2(texC, texR) + halfCR) / vec2(texNumC, texNumR);\n}\n"; var SAMPLE_3D_SNIPPET = "\nvec2 packedUVfrom3D(int texNumR, int texNumC,\n int texelsInBatch, int texelsInLogicalRow, int b,\n int row, int col) {\n int index = b * texelsInBatch + (row / 2) * texelsInLogicalRow + (col / 2);\n int texR = index / texNumC;\n int texC = index - texR * texNumC;\n return (vec2(texC, texR) + halfCR) / vec2(texNumC, texNumR);\n}\n"; var SHADER_PACKED_PREFIX = "\n float getChannel(vec4 frag, vec2 innerDims) {\n vec2 modCoord = mod(innerDims, 2.);\n return modCoord.x == 0. ?\n (modCoord.y == 0. ? frag.r : frag.g) :\n (modCoord.y == 0. ? frag.b : frag.a);\n }\n float getChannel(vec4 frag, int dim) {\n float modCoord = mod(float(dim), 2.);\n return modCoord == 0. ? 
frag.r : frag.g;\n }\n"; function getOutputScalarCoords() { return "\n int getOutputCoords() {\n return 0;\n }\n "; } function getOutputPacked1DCoords(shape, texShape, enableShapeUniforms) { var packedTexShape = [Math.ceil(texShape[0] / 2), Math.ceil(texShape[1] / 2)]; if (packedTexShape[0] === 1) { if (enableShapeUniforms) { return "\n int getOutputCoords() {\n return 2 * int(resultUV.x * ceil(float(outTexShape[1]) / 2.0));\n }\n "; } return "\n int getOutputCoords() {\n return 2 * int(resultUV.x * ".concat(packedTexShape[1], ".0);\n }\n "); } if (packedTexShape[1] === 1) { if (enableShapeUniforms) { return "\n int getOutputCoords() {\n return 2 * int(resultUV.y * ceil(float(outTexShape[0]) / 2.0));\n }\n "; } return "\n int getOutputCoords() {\n return 2 * int(resultUV.y * ".concat(packedTexShape[0], ".0);\n }\n "); } if (enableShapeUniforms) { return "\n int getOutputCoords() {\n ivec2 packedTexShape = ivec2(ceil(float(outTexShape[0]) / 2.0), ceil(float(outTexShape[1]) / 2.0));\n ivec2 resTexRC = ivec2(resultUV.yx *\n vec2(packedTexShape[0], packedTexShape[1]));\n return 2 * (resTexRC.x * packedTexShape[1] + resTexRC.y);\n }\n "; } return "\n int getOutputCoords() {\n ivec2 resTexRC = ivec2(resultUV.yx *\n vec2(".concat(packedTexShape[0], ", ").concat(packedTexShape[1], "));\n return 2 * (resTexRC.x * ").concat(packedTexShape[1], " + resTexRC.y);\n }\n "); } function getOutput1DCoords(shape, texShape, enableShapeUniforms) { if (texShape[0] === 1) { if (enableShapeUniforms) { return "\n int getOutputCoords() {\n return int(resultUV.x * float(outTexShape[1]));\n }\n "; } return "\n int getOutputCoords() {\n return int(resultUV.x * ".concat(texShape[1], ".0);\n }\n "); } if (texShape[1] === 1) { if (enableShapeUniforms) { return "\n int getOutputCoords() {\n return int(resultUV.y * float(outTexShape[0]));\n }\n "; } return "\n int getOutputCoords() {\n return int(resultUV.y * ".concat(texShape[0], ".0);\n }\n "); } if (enableShapeUniforms) { return "\n int 
getOutputCoords() {\n ivec2 resTexRC = ivec2(resultUV.yx *\n vec2(outTexShape[0], outTexShape[1]));\n return resTexRC.x * outTexShape[1] + resTexRC.y;\n }\n "; } return "\n int getOutputCoords() {\n ivec2 resTexRC = ivec2(resultUV.yx *\n vec2(".concat(texShape[0], ", ").concat(texShape[1], "));\n return resTexRC.x * ").concat(texShape[1], " + resTexRC.y;\n }\n "); } function getOutputPacked3DCoords(shape, texShape, enableShapeUniforms) { if (enableShapeUniforms) { return "\n ivec3 getOutputCoords() {\n ivec2 packedTexShape = ivec2(ceil(float(outTexShape[0]) / 2.0), ceil(float(outTexShape[1]) / 2.0));\n int texelsInLogicalRow = int(ceil(float(outShape[2]) / 2.0));\n int texelsInBatch = texelsInLogicalRow * int(ceil(float(outShape[1]) / 2.0));\n ivec2 resTexRC = ivec2(resultUV.yx *\n vec2(packedTexShape[0], packedTexShape[1]));\n int index = resTexRC.x * packedTexShape[1] + resTexRC.y;\n\n int b = index / texelsInBatch;\n index -= b * texelsInBatch;\n\n int r = 2 * (index / texelsInLogicalRow);\n int c = imod(index, texelsInLogicalRow) * 2;\n\n return ivec3(b, r, c);\n }\n "; } var packedTexShape = [Math.ceil(texShape[0] / 2), Math.ceil(texShape[1] / 2)]; var texelsInLogicalRow = Math.ceil(shape[2] / 2); var texelsInBatch = texelsInLogicalRow * Math.ceil(shape[1] / 2); return "\n ivec3 getOutputCoords() {\n ivec2 resTexRC = ivec2(resultUV.yx *\n vec2(".concat(packedTexShape[0], ", ").concat(packedTexShape[1], "));\n int index = resTexRC.x * ").concat(packedTexShape[1], " + resTexRC.y;\n\n int b = index / ").concat(texelsInBatch, ";\n index -= b * ").concat(texelsInBatch, ";\n\n int r = 2 * (index / ").concat(texelsInLogicalRow, ");\n int c = imod(index, ").concat(texelsInLogicalRow, ") * 2;\n\n return ivec3(b, r, c);\n }\n "); } function getOutput3DCoords(shape, texShape, enableShapeUniforms) { if (enableShapeUniforms) { var coordsFromIndexSnippet_1 = getOutputLogicalCoordinatesFromFlatIndexByUniform(['r', 'c', 'd'], shape); return "\n ivec3 getOutputCoords() {\n 
ivec2 resTexRC = ivec2(resultUV.yx *\n vec2(outTexShape[0], outTexShape[1]));\n int index = resTexRC.x * outTexShape[1] + resTexRC.y;\n ".concat(coordsFromIndexSnippet_1, "\n return ivec3(r, c, d);\n }\n"); } var coordsFromIndexSnippet = getLogicalCoordinatesFromFlatIndex(['r', 'c', 'd'], shape); return "\n ivec3 getOutputCoords() {\n ivec2 resTexRC = ivec2(resultUV.yx *\n vec2(".concat(texShape[0], ", ").concat(texShape[1], "));\n int index = resTexRC.x * ").concat(texShape[1], " + resTexRC.y;\n ").concat(coordsFromIndexSnippet, "\n return ivec3(r, c, d);\n }\n "); } function getOutputPackedNDCoords(shape, texShape, enableShapeUniforms) { if (enableShapeUniforms) { // TODO: support 5d and 6d return "\n ivec4 getOutputCoords() {\n ivec2 packedTexShape = ivec2(ceil(float(outTexShape[0]) / 2.0), ceil(float(outTexShape[1]) / 2.0));\n ivec2 resTexRC = ivec2(resultUV.yx *\n vec2(packedTexShape[0], packedTexShape[1]));\n int index = resTexRC.x * packedTexShape[1] + resTexRC.y;\n\n int texelsInLogicalRow = int(ceil(float(outShape[3]) / 2.0));\n int texelsInBatch = texelsInLogicalRow * int(ceil(float(outShape[2]) / 2.0));\n int texelsInBatchN = texelsInBatch * outShape[1];\n\n int b2 = index / texelsInBatchN;\n index -= b2 * texelsInBatchN;\n\n int b = index / texelsInBatch;\n index -= b * texelsInBatch;\n\n int r = 2 * (index / texelsInLogicalRow);\n int c = imod(index, texelsInLogicalRow) * 2;\n\n return ivec4(b2, b, r, c);\n }\n "; } var packedTexShape = [Math.ceil(texShape[0] / 2), Math.ceil(texShape[1] / 2)]; var texelsInLogicalRow = Math.ceil(shape[shape.length - 1] / 2); var texelsInBatch = texelsInLogicalRow * Math.ceil(shape[shape.length - 2] / 2); var texelsInBatchN = texelsInBatch; var batches = ""; var coords = 'b, r, c'; for (var b = 2; b < shape.length - 1; b++) { texelsInBatchN *= shape[shape.length - b - 1]; batches = "\n int b".concat(b, " = index / ").concat(texelsInBatchN, ";\n index -= b").concat(b, " * ").concat(texelsInBatchN, ";\n ") + batches; 
coords = "b".concat(b, ", ") + coords; } return "\n ivec".concat(shape.length, " getOutputCoords() {\n ivec2 resTexRC = ivec2(resultUV.yx *\n vec2(").concat(packedTexShape[0], ", ").concat(packedTexShape[1], "));\n int index = resTexRC.x * ").concat(packedTexShape[1], " + resTexRC.y;\n\n ").concat(batches, "\n\n int b = index / ").concat(texelsInBatch, ";\n index -= b * ").concat(texelsInBatch, ";\n\n int r = 2 * (index / ").concat(texelsInLogicalRow, ");\n int c = imod(index, ").concat(texelsInLogicalRow, ") * 2;\n\n return ivec").concat(shape.length, "(").concat(coords, ");\n }\n "); } function getOutput4DCoords(shape, texShape, enableShapeUniforms) { if (enableShapeUniforms) { var coordsFromIndexSnippet_2 = getOutputLogicalCoordinatesFromFlatIndexByUniform(['r', 'c', 'd', 'd2'], shape); return "\n ivec4 getOutputCoords() {\n ivec2 resTexRC = ivec2(resultUV.yx *\n vec2(outTexShape[0], outTexShape[1]));\n int index = resTexRC.x * outTexShape[1] + resTexRC.y;\n ".concat(coordsFromIndexSnippet_2, "\n return ivec4(r, c, d, d2);\n }\n "); } var coordsFromIndexSnippet = getLogicalCoordinatesFromFlatIndex(['r', 'c', 'd', 'd2'], shape); return "\n ivec4 getOutputCoords() {\n ivec2 resTexRC = ivec2(resultUV.yx *\n vec2(".concat(texShape[0], ", ").concat(texShape[1], "));\n int index = resTexRC.x * ").concat(texShape[1], " + resTexRC.y;\n ").concat(coordsFromIndexSnippet, "\n return ivec4(r, c, d, d2);\n }\n "); } function getOutput5DCoords(shape, texShape) { var coordsFromIndexSnippet = getLogicalCoordinatesFromFlatIndex(['r', 'c', 'd', 'd2', 'd3'], shape); return "\n ivec5 getOutputCoords() {\n ivec2 resTexRC = ivec2(resultUV.yx * vec2(".concat(texShape[0], ",\n ").concat(texShape[1], "));\n\n int index = resTexRC.x * ").concat(texShape[1], " + resTexRC.y;\n\n ").concat(coordsFromIndexSnippet, "\n\n ivec5 outShape = ivec5(r, c, d, d2, d3);\n return outShape;\n }\n "); } function getOutput6DCoords(shape, texShape) { var coordsFromIndexSnippet = 
getLogicalCoordinatesFromFlatIndex(['r', 'c', 'd', 'd2', 'd3', 'd4'], shape); return "\n ivec6 getOutputCoords() {\n ivec2 resTexRC = ivec2(resultUV.yx *\n vec2(".concat(texShape[0], ", ").concat(texShape[1], "));\n int index = resTexRC.x * ").concat(texShape[1], " + resTexRC.y;\n\n ").concat(coordsFromIndexSnippet, "\n\n ivec6 result = ivec6(r, c, d, d2, d3, d4);\n return result;\n }\n "); } function getOutputPacked2DCoords(shape, texShape, enableShapeUniforms) { var packedTexShape = [Math.ceil(texShape[0] / 2), Math.ceil(texShape[1] / 2)]; if (tf.util.arraysEqual(shape, texShape)) { if (enableShapeUniforms) { return "\n ivec2 getOutputCoords() {\n ivec2 packedTexShape = ivec2(ceil(float(outTexShape[0]) / 2.0), ceil(float(outTexShape[1]) / 2.0));\n return 2 * ivec2(resultUV.yx * vec2(packedTexShape[0], packedTexShape[1]));\n }\n "; } return "\n ivec2 getOutputCoords() {\n return 2 * ivec2(resultUV.yx * vec2(".concat(packedTexShape[0], ", ").concat(packedTexShape[1], "));\n }\n "); } // texels needed to accommodate a logical row var texelsInLogicalRow = Math.ceil(shape[1] / 2); /** * getOutputCoords * * resTexRC: The rows and columns of the texels. If you move over one * texel to the right in the packed texture, you are moving over one column * (not two). 
* * index: The texel index */ if (enableShapeUniforms) { return "\n ivec2 getOutputCoords() {\n ivec2 packedTexShape = ivec2(ceil(float(outTexShape[0]) / 2.0), ceil(float(outTexShape[1]) / 2.0));\n int texelsInLogicalRow = int(ceil(float(outShape[1]) / 2.0));\n ivec2 resTexRC = ivec2(resultUV.yx *\n vec2(packedTexShape[0], packedTexShape[1]));\n\n int index = resTexRC.x * packedTexShape[1] + resTexRC.y;\n int r = 2 * (index / texelsInLogicalRow);\n int c = imod(index, texelsInLogicalRow) * 2;\n\n return ivec2(r, c);\n }\n "; } return "\n ivec2 getOutputCoords() {\n ivec2 resTexRC = ivec2(resultUV.yx *\n vec2(".concat(packedTexShape[0], ", ").concat(packedTexShape[1], "));\n\n int index = resTexRC.x * ").concat(packedTexShape[1], " + resTexRC.y;\n int r = 2 * (index / ").concat(texelsInLogicalRow, ");\n int c = imod(index, ").concat(texelsInLogicalRow, ") * 2;\n\n return ivec2(r, c);\n }\n "); } function getOutput2DCoords(shape, texShape, enableShapeUniforms) { if (tf.util.arraysEqual(shape, texShape)) { if (enableShapeUniforms) { return "\n ivec2 getOutputCoords() {\n return ivec2(resultUV.yx * vec2(outTexShape[0], outTexShape[1]));\n }\n "; } return "\n ivec2 getOutputCoords() {\n return ivec2(resultUV.yx * vec2(".concat(texShape[0], ", ").concat(texShape[1], "));\n }\n "); } if (shape[1] === 1) { if (enableShapeUniforms) { return "\n ivec2 getOutputCoords() {\n ivec2 resTexRC = ivec2(resultUV.yx *\n vec2(outTexShape[0], outTexShape[1]));\n int index = resTexRC.x * outTexShape[1] + resTexRC.y;\n return ivec2(index, 0);\n }\n "; } return "\n ivec2 getOutputCoords() {\n ivec2 resTexRC = ivec2(resultUV.yx *\n vec2(".concat(texShape[0], ", ").concat(texShape[1], "));\n int index = resTexRC.x * ").concat(texShape[1], " + resTexRC.y;\n return ivec2(index, 0);\n }\n "); } if (shape[0] === 1) { if (enableShapeUniforms) { return "\n ivec2 getOutputCoords() {\n ivec2 resTexRC = ivec2(resultUV.yx *\n vec2(outTexShape[0], outTexShape[1]));\n int index = resTexRC.x * 
outTexShape[1] + resTexRC.y;\n return ivec2(0, index);\n }\n "; } return "\n ivec2 getOutputCoords() {\n ivec2 resTexRC = ivec2(resultUV.yx *\n vec2(".concat(texShape[0], ", ").concat(texShape[1], "));\n int index = resTexRC.x * ").concat(texShape[1], " + resTexRC.y;\n return ivec2(0, index);\n }\n "); } if (enableShapeUniforms) { return "\n ivec2 getOutputCoords() {\n ivec2 resTexRC = ivec2(resultUV.yx *\n vec2(outTexShape[0], outTexShape[1]));\n int index = resTexRC.x * outTexShape[1] + resTexRC.y;\n int r = index / outShape[1];\n int c = index - r * outShape[1];\n return ivec2(r, c);\n }\n "; } return "\n ivec2 getOutputCoords() {\n ivec2 resTexRC = ivec2(resultUV.yx *\n vec2(".concat(texShape[0], ", ").concat(texShape[1], "));\n int index = resTexRC.x * ").concat(texShape[1], " + resTexRC.y;\n int r = index / ").concat(shape[1], ";\n int c = index - r * ").concat(shape[1], ";\n return ivec2(r, c);\n }\n "); } function getFlatOffsetUniformName(texName) { return "offset".concat(texName); } function getPackedSamplerScalar(inputInfo) { var texName = inputInfo.name; var funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1); var glsl = getGlslDifferences(); return "\n vec4 ".concat(funcName, "() {\n return ").concat(glsl.texture2D, "(").concat(texName, ", halfCR);\n }\n "); } function getSamplerScalar(inputInfo, enableShapeUniforms) { var texName = inputInfo.name; var funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1); if (inputInfo.shapeInfo.isUniform) { return "float ".concat(funcName, "() {return ").concat(texName, ";}"); } var _a = __read(inputInfo.shapeInfo.texShape, 2), texNumR = _a[0], texNumC = _a[1]; if (texNumR === 1 && texNumC === 1) { return "\n float ".concat(funcName, "() {\n return sampleTexture(").concat(texName, ", halfCR);\n }\n "); } var offset = getFlatOffsetUniformName(texName); if (enableShapeUniforms) { return "\n float ".concat(funcName, "() {\n vec2 uv = uvFromFlat(").concat(texName, "TexShape[0], 
").concat(texName, "TexShape[1], ").concat(offset, ");\n return sampleTexture(").concat(texName, ", uv);\n }\n "); } var _b = __read(inputInfo.shapeInfo.texShape, 2), tNumR = _b[0], tNumC = _b[1]; return "\n float ".concat(funcName, "() {\n vec2 uv = uvFromFlat(").concat(tNumR, ", ").concat(tNumC, ", ").concat(offset, ");\n return sampleTexture(").concat(texName, ", uv);\n }\n "); } function getPackedSampler1D(inputInfo, enableShapeUniforms) { var texName = inputInfo.name; var funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1); var texShape = inputInfo.shapeInfo.texShape; var glsl = getGlslDifferences(); if (enableShapeUniforms) { return "\n vec4 ".concat(funcName, "(int index) {\n ivec2 packedTexShape = ivec2(ceil(float(").concat(texName, "TexShape[0]) / 2.0), ceil(float(").concat(texName, "TexShape[1]) / 2.0));\n vec2 uv = packedUVfrom1D(\n packedTexShape[0], packedTexShape[1], index);\n return ").concat(glsl.texture2D, "(").concat(texName, ", uv);\n }\n "); } var packedTexShape = [Math.ceil(texShape[0] / 2), Math.ceil(texShape[1] / 2)]; return "\n vec4 ".concat(funcName, "(int index) {\n vec2 uv = packedUVfrom1D(\n ").concat(packedTexShape[0], ", ").concat(packedTexShape[1], ", index);\n return ").concat(glsl.texture2D, "(").concat(texName, ", uv);\n }\n "); } function getSampler1D(inputInfo, enableShapeUniforms) { var texName = inputInfo.name; var funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1); if (inputInfo.shapeInfo.isUniform) { // Uniform arrays will be less than 65505 (no risk of float16 overflow). 
return "\n float ".concat(funcName, "(int index) {\n ").concat(getUniformSampler(inputInfo), "\n }\n "); } var texShape = inputInfo.shapeInfo.texShape; var tNumR = texShape[0]; var tNumC = texShape[1]; if (tNumC === 1 && tNumR === 1) { return "\n float ".concat(funcName, "(int index) {\n return sampleTexture(").concat(texName, ", halfCR);\n }\n "); } var offset = getFlatOffsetUniformName(texName); if (tNumC === 1) { if (enableShapeUniforms) { return "\n float ".concat(funcName, "(int index) {\n vec2 uv = vec2(0.5, (float(index + ").concat(offset, ") + 0.5) / float(").concat(texName, "TexShape[0]));\n return sampleTexture(").concat(texName, ", uv);\n }\n "); } return "\n float ".concat(funcName, "(int index) {\n vec2 uv = vec2(0.5, (float(index + ").concat(offset, ") + 0.5) / ").concat(tNumR, ".0);\n return sampleTexture(").concat(texName, ", uv);\n }\n "); } if (tNumR === 1) { if (enableShapeUniforms) { return "\n float ".concat(funcName, "(int index) {\n vec2 uv = vec2((float(index + ").concat(offset, ") + 0.5) / float(").concat(texName, "TexShape[1]), 0.5);\n return sampleTexture(").concat(texName, ", uv);\n }\n "); } return "\n float ".concat(funcName, "(int index) {\n vec2 uv = vec2((float(index + ").concat(offset, ") + 0.5) / ").concat(tNumC, ".0, 0.5);\n return sampleTexture(").concat(texName, ", uv);\n }\n "); } if (enableShapeUniforms) { return "\n float ".concat(funcName, "(int index) {\n vec2 uv = uvFromFlat(").concat(texName, "TexShape[0], ").concat(texName, "TexShape[1], index + ").concat(offset, ");\n return sampleTexture(").concat(texName, ", uv);\n }\n "); } return "\n float ".concat(funcName, "(int index) {\n vec2 uv = uvFromFlat(").concat(tNumR, ", ").concat(tNumC, ", index + ").concat(offset, ");\n return sampleTexture(").concat(texName, ", uv);\n }\n "); } function getPackedSampler2D(inputInfo, enableShapeUniforms) { var shape = inputInfo.shapeInfo.logicalShape; var texName = inputInfo.name; var funcName = 'get' + texName.charAt(0).toUpperCase() 
+ texName.slice(1); var texShape = inputInfo.shapeInfo.texShape; var texNumR = texShape[0]; var texNumC = texShape[1]; var glsl = getGlslDifferences(); if (texShape != null && tf.util.arraysEqual(shape, texShape)) { if (enableShapeUniforms) { return "\n vec4 ".concat(funcName, "(int row, int col) {\n vec2 uv = (vec2(col, row) + halfCR) / vec2(").concat(texName, "TexShape[1], ").concat(texName, "TexShape[0]);\n\n return ").concat(glsl.texture2D, "(").concat(texName, ", uv);\n }\n "); } return "\n vec4 ".concat(funcName, "(int row, int col) {\n vec2 uv = (vec2(col, row) + halfCR) / vec2(").concat(texNumC, ".0, ").concat(texNumR, ".0);\n\n return ").concat(glsl.texture2D, "(").concat(texName, ", uv);\n }\n "); } if (enableShapeUniforms) { return "\n vec4 ".concat(funcName, "(int row, int col) {\n ivec2 packedTexShape = ivec2(ceil(float(").concat(texName, "TexShape[0]) / 2.0), ceil(float(").concat(texName, "TexShape[1]) / 2.0));\n int valuesPerRow = int(ceil(float(").concat(texName, "Shape[1]) / 2.0));\n vec2 uv = packedUVfrom2D(valuesPerRow, packedTexShape[0], packedTexShape[1], row, col);\n return ").concat(glsl.texture2D, "(").concat(texName, ", uv);\n }\n "); } var packedTexShape = [Math.ceil(texShape[0] / 2), Math.ceil(texShape[1] / 2)]; var valuesPerRow = Math.ceil(shape[1] / 2); return "\n vec4 ".concat(funcName, "(int row, int col) {\n vec2 uv = packedUVfrom2D(").concat(valuesPerRow, ", ").concat(packedTexShape[0], ", ").concat(packedTexShape[1], ", row, col);\n return ").concat(glsl.texture2D, "(").concat(texName, ", uv);\n }\n "); } function getSampler2D(inputInfo, enableShapeUniforms) { var shape = inputInfo.shapeInfo.logicalShape; var texName = inputInfo.name; var funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1); var texShape = inputInfo.shapeInfo.texShape; if (texShape != null && tf.util.arraysEqual(shape, texShape)) { if (enableShapeUniforms) { return "\n float ".concat(funcName, "(int row, int col) {\n vec2 uv = (vec2(col, row) + 
halfCR) / vec2(").concat(texName, "TexShape[1], ").concat(texName, "TexShape[0]);\n return sampleTexture(").concat(texName, ", uv);\n }\n "); } var texNumR_1 = texShape[0]; var texNumC_1 = texShape[1]; return "\n float ".concat(funcName, "(int row, int col) {\n vec2 uv = (vec2(col, row) + halfCR) / vec2(").concat(texNumC_1, ".0, ").concat(texNumR_1, ".0);\n return sampleTexture(").concat(texName, ", uv);\n }\n "); } var _a = tf.util.squeezeShape(shape), newShape = _a.newShape, keptDims = _a.keptDims; var squeezedShape = newShape; if (squeezedShape.length < shape.length) { var newInputInfo = squeezeInputInfo(inputInfo, squeezedShape); var params = ['row', 'col']; return "\n ".concat(getSamplerFromInInfo(newInputInfo, enableShapeUniforms), "\n float ").concat(funcName, "(int row, int col) {\n return ").concat(funcName, "(").concat(getSqueezedParams(params, keptDims), ");\n }\n "); } if (inputInfo.shapeInfo.isUniform) { // Uniform arrays will be less than 65505 (no risk of float16 overflow). return "\n float ".concat(funcName, "(int row, int col) {\n int index = round(dot(vec2(row, col), vec2(").concat(shape[1], ", 1)));\n ").concat(getUniformSampler(inputInfo), "\n }\n "); } var texNumR = texShape[0]; var texNumC = texShape[1]; var offset = getFlatOffsetUniformName(texName); if (texNumC === 1) { // index is used directly as physical (no risk of float16 overflow). 
if (enableShapeUniforms) { return "\n float ".concat(funcName, "(int row, int col) {\n float index = dot(vec3(row, col, ").concat(offset, "), vec3(").concat(texName, "Shape[1], 1, 1));\n vec2 uv = vec2(0.5, (index + 0.5) / float(").concat(texName, "TexShape[0]));\n return sampleTexture(").concat(texName, ", uv);\n }\n "); } return "\n float ".concat(funcName, "(int row, int col) {\n float index = dot(vec3(row, col, ").concat(offset, "), vec3(").concat(shape[1], ", 1, 1));\n vec2 uv = vec2(0.5, (index + 0.5) / ").concat(texNumR, ".0);\n return sampleTexture(").concat(texName, ", uv);\n }\n "); } if (texNumR === 1) { // index is used directly as physical (no risk of float16 overflow). if (enableShapeUniforms) { return "\n float ".concat(funcName, "(int row, int col) {\n float index = dot(vec3(row, col, ").concat(offset, "), vec3(").concat(texName, "Shape[1], 1, 1));\n vec2 uv = vec2((index + 0.5) / float(").concat(texName, "TexShape[1]), 0.5);\n return sampleTexture(").concat(texName, ", uv);\n }\n "); } return "\n float ".concat(funcName, "(int row, int col) {\n float index = dot(vec3(row, col, ").concat(offset, "), vec3(").concat(shape[1], ", 1, 1));\n vec2 uv = vec2((index + 0.5) / ").concat(texNumC, ".0, 0.5);\n return sampleTexture(").concat(texName, ", uv);\n }\n "); } if (enableShapeUniforms) { return "\n float ".concat(funcName, "(int row, int col) {\n // Explicitly use integer operations as dot() only works on floats.\n int index = row * ").concat(texName, "Shape[1] + col + ").concat(offset, ";\n vec2 uv = uvFromFlat(").concat(texName, "TexShape[0], ").concat(texName, "TexShape[1], index);\n return sampleTexture(").concat(texName, ", uv);\n }\n "); } return "\n float ".concat(funcName, "(int row, int col) {\n // Explicitly use integer operations as dot() only works on floats.\n int index = row * ").concat(shape[1], " + col + ").concat(offset, ";\n vec2 uv = uvFromFlat(").concat(texNumR, ", ").concat(texNumC, ", index);\n return 
sampleTexture(").concat(texName, ", uv);\n }\n"); } function getPackedSampler3D(inputInfo, enableShapeUniforms) { var shape = inputInfo.shapeInfo.logicalShape; var texName = inputInfo.name; var funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1); var texShape = inputInfo.shapeInfo.texShape; var packedTexShape = [Math.ceil(texShape[0] / 2), Math.ceil(texShape[1] / 2)]; if (shape[0] === 1) { var squeezedShape = shape.slice(1); var keptDims = [1, 2]; var newInputInfo = squeezeInputInfo(inputInfo, squeezedShape); var params = ['b', 'row', 'col']; return "\n ".concat(getPackedSamplerFromInInfo(newInputInfo, enableShapeUniforms), "\n vec4 ").concat(funcName, "(int b, int row, int col) {\n return ").concat(funcName, "(").concat(getSqueezedParams(params, keptDims), ");\n }\n "); } var glsl = getGlslDifferences(); if (enableShapeUniforms) { return "\n vec4 ".concat(funcName, "(int b, int row, int col) {\n ivec2 packedTexShape = ivec2(ceil(float(").concat(texName, "TexShape[0]) / 2.0), ceil(float(").concat(texName, "TexShape[1]) / 2.0));\n int valuesPerRow = int(ceil(float(").concat(texName, "Shape[2]) / 2.0));\n int texelsInBatch = valuesPerRow * int(ceil(float(").concat(texName, "Shape[1]) / 2.0));\n vec2 uv = packedUVfrom3D(\n packedTexShape[0], packedTexShape[1], texelsInBatch, valuesPerRow, b, row, col);\n return ").concat(glsl.texture2D, "(").concat(texName, ", uv);\n }\n "); } var texNumR = packedTexShape[0]; var texNumC = packedTexShape[1]; var valuesPerRow = Math.ceil(shape[2] / 2); var texelsInBatch = valuesPerRow * Math.ceil(shape[1] / 2); return "\n vec4 ".concat(funcName, "(int b, int row, int col) {\n vec2 uv = packedUVfrom3D(\n ").concat(texNumR, ", ").concat(texNumC, ", ").concat(texelsInBatch, ", ").concat(valuesPerRow, ", b, row, col);\n return ").concat(glsl.texture2D, "(").concat(texName, ", uv);\n }\n "); } function getSampler3D(inputInfo, enableShapeUniforms) { var shape = inputInfo.shapeInfo.logicalShape; var texName = inputInfo.name; 
var funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1); var stride0 = shape[1] * shape[2]; var stride1 = shape[2]; var _a = tf.util.squeezeShape(shape), newShape = _a.newShape, keptDims = _a.keptDims; var squeezedShape = newShape; if (squeezedShape.length < shape.length) { var newInputInfo = squeezeInputInfo(inputInfo, squeezedShape); var params = ['row', 'col', 'depth']; return "\n ".concat(getSamplerFromInInfo(newInputInfo, enableShapeUniforms), "\n float ").concat(funcName, "(int row, int col, int depth) {\n return ").concat(funcName, "(").concat(getSqueezedParams(params, keptDims), ");\n }\n "); } if (inputInfo.shapeInfo.isUniform) { // Uniform arrays will be less than 65505 (no risk of float16 overflow). return "\n float ".concat(funcName, "(int row, int col, int depth) {\n int index = round(dot(vec3(row, col, depth),\n vec3(").concat(stride0, ", ").concat(stride1, ", 1)));\n ").concat(getUniformSampler(inputInfo), "\n }\n "); } var texShape = inputInfo.shapeInfo.texShape; var texNumR = texShape[0]; var texNumC = texShape[1]; var flatOffset = inputInfo.shapeInfo.flatOffset; if (texNumC === stride0 && flatOffset == null) { // texC is used directly as physical (no risk of float16 overflow). 
if (enableShapeUniforms) { return "\n float ".concat(funcName, "(int row, int col, int depth) {\n int stride1 = ").concat(texName, "Shape[2];\n float texR = float(row);\n float texC = dot(vec2(col, depth), vec2(stride1, 1));\n vec2 uv = (vec2(texC, texR) + halfCR) /\n vec2(").concat(texName, "TexShape[1], ").concat(texName, "TexShape[0]);\n return sampleTexture(").concat(texName, ", uv);\n }\n "); } return "\n float ".concat(funcName, "(int row, int col, int depth) {\n float texR = float(row);\n float texC = dot(vec2(col, depth), vec2(").concat(stride1, ", 1));\n vec2 uv = (vec2(texC, texR) + halfCR) /\n vec2(").concat(texNumC, ".0, ").concat(texNumR, ".0);\n return sampleTexture(").concat(texName, ", uv);\n }\n "); } if (texNumC === stride1 && flatOffset == null) { // texR is used directly as physical (no risk of float16 overflow). if (enableShapeUniforms) { return "\n float ".concat(funcName, "(int row, int col, int depth) {\n float texR = dot(vec2(row, col), vec2(").concat(texName, "Shape[1], 1));\n float texC = float(depth);\n vec2 uv = (vec2(texC, texR) + halfCR) / vec2(").concat(texName, "TexShape[1], ").concat(texName, "TexShape[0]);\n return sampleTexture(").concat(texName, ", uv);\n }\n "); } return "\n float ".concat(funcName, "(int row, int col, int depth) {\n float texR = dot(vec2(row, col), vec2(").concat(shape[1], ", 1));\n float texC = float(depth);\n vec2 uv = (vec2(texC, texR) + halfCR) / vec2(").concat(texNumC, ".0, ").concat(texNumR, ".0);\n return sampleTexture(").concat(texName, ", uv);\n }\n "); } var offset = getFlatOffsetUniformName(texName); if (enableShapeUniforms) { return "\n float ".concat(funcName, "(int row, int col, int depth) {\n // Explicitly use integer operations as dot() only works on floats.\n int stride0 = ").concat(texName, "Shape[1] * ").concat(texName, "Shape[2];\n int stride1 = ").concat(texName, "Shape[2];\n int index = row * stride0 + col * stride1 + depth + ").concat(offset, ";\n vec2 uv = uvFromFlat(").concat(texName, 
"TexShape[0], ").concat(texName, "TexShape[1], index);\n return sampleTexture(").concat(texName, ", uv);\n }\n "); } return "\n float ".concat(funcName, "(int row, int col, int depth) {\n // Explicitly use integer operations as dot() only works on floats.\n int index = row * ").concat(stride0, " + col * ").concat(stride1, " + depth + ").concat(offset, ";\n vec2 uv = uvFromFlat(").concat(texNumR, ", ").concat(texNumC, ", index);\n return sampleTexture(").concat(texName, ", uv);\n }\n "); } function getPackedSamplerND(inputInfo, enableShapeUniforms) { var texName = inputInfo.name; var funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1); var glsl = getGlslDifferences(); if (enableShapeUniforms) { // TODO: support 5d and 6d return "\n vec4 ".concat(funcName, "(int b2, int b, int row, int col) {\n int valuesPerRow = int(ceil(float(").concat(texName, "Shape[3]) / 2.0));\n int texelsInBatch = valuesPerRow * int(ceil(float(").concat(texName, "Shape[2]) / 2.0));\n int index = b * texelsInBatch + (row / 2) * valuesPerRow + (col / 2);\n texelsInBatch *= ").concat(texName, "Shape[1];\n index = b2 * texelsInBatch + index;\n ivec2 packedTexShape = ivec2(ceil(float(").concat(texName, "TexShape[0]) / 2.0), ceil(float(").concat(texName, "TexShape[1]) / 2.0));\n int texR = index / packedTexShape[1];\n int texC = index - texR * packedTexShape[1];\n vec2 uv = (vec2(texC, texR) + halfCR) / vec2(packedTexShape[1], packedTexShape[0]); return ").concat(glsl.texture2D, "(").concat(texName, ", uv);\n }\n "); } var shape = inputInfo.shapeInfo.logicalShape; var rank = shape.length; var texShape = inputInfo.shapeInfo.texShape; var packedTexShape = [Math.ceil(texShape[0] / 2), Math.ceil(texShape[1] / 2)]; var texNumR = packedTexShape[0]; var texNumC = packedTexShape[1]; var valuesPerRow = Math.ceil(shape[rank - 1] / 2); var texelsInBatch = valuesPerRow * Math.ceil(shape[rank - 2] / 2); var params = "int b, int row, int col"; var index = "b * ".concat(texelsInBatch, " + (row / 
2) * ").concat(valuesPerRow, " + (col / 2)"); for (var b = 2; b < rank - 1; b++) { params = "int b".concat(b, ", ") + params; texelsInBatch *= shape[rank - b - 1]; index = "b".concat(b, " * ").concat(texelsInBatch, " + ") + index; } return "\n vec4 ".concat(funcName, "(").concat(params, ") {\n int index = ").concat(index, ";\n int texR = index / ").concat(texNumC, ";\n int texC = index - texR * ").concat(texNumC, ";\n vec2 uv = (vec2(texC, texR) + halfCR) / vec2(").concat(texNumC, ", ").concat(texNumR, ");\n return ").concat(glsl.texture2D, "(").concat(texName, ", uv);\n }\n "); } function getSampler4D(inputInfo, enableShapeUniforms) { var shape = inputInfo.shapeInfo.logicalShape; var texName = inputInfo.name; var funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1); var stride2 = shape[3]; var stride1 = shape[2] * stride2; var stride0 = shape[1] * stride1; var _a = tf.util.squeezeShape(shape), newShape = _a.newShape, keptDims = _a.keptDims; if (newShape.length < shape.length) { var newInputInfo = squeezeInputInfo(inputInfo, newShape); var params = ['row', 'col', 'depth', 'depth2']; return "\n ".concat(getSamplerFromInInfo(newInputInfo, enableShapeUniforms), "\n float ").concat(funcName, "(int row, int col, int depth, int depth2) {\n return ").concat(funcName, "(").concat(getSqueezedParams(params, keptDims), ");\n }\n "); } if (inputInfo.shapeInfo.isUniform) { // Uniform arrays will be less than 65505 (no risk of float16 overflow). 
return "\n float ".concat(funcName, "(int row, int col, int depth, int depth2) {\n int index = round(dot(vec4(row, col, depth, depth2),\n vec4(").concat(stride0, ", ").concat(stride1, ", ").concat(stride2, ", 1)));\n ").concat(getUniformSampler(inputInfo), "\n }\n "); } var flatOffset = inputInfo.shapeInfo.flatOffset; var texShape = inputInfo.shapeInfo.texShape; var texNumR = texShape[0]; var texNumC = texShape[1]; var stride2Str = "int stride2 = ".concat(texName, "Shape[3];"); var stride1Str = "int stride1 = ".concat(texName, "Shape[2] * stride2;"); var stride0Str = "int stride0 = ".concat(texName, "Shape[1] * stride1;"); if (texNumC === stride0 && flatOffset == null) { // texC is used directly as physical (no risk of float16 overflow). if (enableShapeUniforms) { return "\n float ".concat(funcName, "(int row, int col, int depth, int depth2) {\n ").concat(stride2Str, "\n ").concat(stride1Str, "\n float texR = float(row);\n float texC =\n dot(vec3(col, depth, depth2),\n vec3(stride1, stride2, 1));\n vec2 uv = (vec2(texC, texR) + halfCR) /\n vec2(").concat(texName, "TexShape[1], ").concat(texName, "TexShape[0]);\n return sampleTexture(").concat(texName, ", uv);\n }\n "); } return "\n float ".concat(funcName, "(int row, int col, int depth, int depth2) {\n float texR = float(row);\n float texC =\n dot(vec3(col, depth, depth2),\n vec3(").concat(stride1, ", ").concat(stride2, ", 1));\n vec2 uv = (vec2(texC, texR) + halfCR) /\n vec2(").concat(texNumC, ".0, ").concat(texNumR, ".0);\n return sampleTexture(").concat(texName, ", uv);\n }\n "); } if (texNumC === stride2 && flatOffset == null) { // texR is used directly as physical (no risk of float16 overflow). 
if (enableShapeUniforms) { return "\n float ".concat(funcName, "(int row, int col, int depth, int depth2) {\n float texR = dot(vec3(row, col, depth),\n vec3(").concat(texName, "Shape[1] * ").concat(texName, "Shape[2], ").concat(texName, "Shape[2], 1));\n float texC = float(depth2);\n vec2 uv = (vec2(texC, texR) + halfCR) /\n vec2(").concat(texName, "TexShape[1], ").concat(texName, "TexShape[0]);\n return sampleTexture(").concat(texName, ", uv);\n }\n "); } return "\n float ".concat(funcName, "(int row, int col, int depth, int depth2) {\n float texR = dot(vec3(row, col, depth),\n vec3(").concat(shape[1] * shape[2], ", ").concat(shape[2], ", 1));\n float texC = float(depth2);\n vec2 uv = (vec2(texC, texR) + halfCR) /\n vec2(").concat(texNumC, ".0, ").concat(texNumR, ".0);\n return sampleTexture(").concat(texName, ", uv);\n }\n "); } var offset = getFlatOffsetUniformName(texName); if (enableShapeUniforms) { return "\n float ".concat(funcName, "(int row, int col, int depth, int depth2) {\n // Explicitly use integer operations as dot() only works on floats.\n ").concat(stride2Str, "\n ").concat(stride1Str, "\n ").concat(stride0Str, "\n int index = row * stride0 + col * stride1 +\n depth * stride2 + depth2;\n vec2 uv = uvFromFlat(").concat(texName, "TexShape[0], ").concat(texName, "TexShape[1], index + ").concat(offset, ");\n return sampleTexture(").concat(texName, ", uv);\n }\n "); } return "\n float ".concat(funcName, "(int row, int col, int depth, int depth2) {\n // Explicitly use integer operations as dot() only works on floats.\n int index = row * ").concat(stride0, " + col * ").concat(stride1, " +\n depth * ").concat(stride2, " + depth2;\n vec2 uv = uvFromFlat(").concat(texNumR, ", ").concat(texNumC, ", index + ").concat(offset, ");\n return sampleTexture(").concat(texName, ", uv);\n }\n "); } function getSampler5D(inputInfo) { var shape = inputInfo.shapeInfo.logicalShape; var texName = inputInfo.name; var funcName = 'get' + texName.charAt(0).toUpperCase() + 
texName.slice(1); var stride3 = shape[4]; var stride2 = shape[3] * stride3; var stride1 = shape[2] * stride2; var stride0 = shape[1] * stride1; var _a = tf.util.squeezeShape(shape), newShape = _a.newShape, keptDims = _a.keptDims; if (newShape.length < shape.length) { var newInputInfo = squeezeInputInfo(inputInfo, newShape); var params = ['row', 'col', 'depth', 'depth2', 'depth3']; return "\n ".concat(getSamplerFromInInfo(newInputInfo), "\n float ").concat(funcName, "(int row, int col, int depth, int depth2, int depth3) {\n return ").concat(funcName, "(").concat(getSqueezedParams(params, keptDims), ");\n }\n "); } if (inputInfo.shapeInfo.isUniform) { // Uniform arrays will be less than 65505 (no risk of float16 overflow). return "\n float ".concat(funcName, "(int row, int col, int depth, int depth2, int depth3) {\n float index = dot(\n vec4(row, col, depth, depth2),\n vec4(").concat(stride0, ", ").concat(stride1, ", ").concat(stride2, ", ").concat(stride3, ")) +\n depth3;\n ").concat(getUniformSampler(inputInfo), "\n }\n "); } var flatOffset = inputInfo.shapeInfo.flatOffset; var texShape = inputInfo.shapeInfo.texShape; var texNumR = texShape[0]; var texNumC = texShape[1]; if (texNumC === stride0 && flatOffset == null) { // texC is used directly as physical (no risk of float16 overflow). return "\n float ".concat(funcName, "(int row, int col, int depth, int depth2, int depth3) {\n int texR = row;\n float texC = dot(vec4(col, depth, depth2, depth3),\n vec4(").concat(stride1, ", ").concat(stride2, ", ").concat(stride3, ", 1));\n vec2 uv = (vec2(texC, texR) + halfCR) /\n vec2(").concat(texNumC, ".0, ").concat(texNumR, ".0);\n return sampleTexture(").concat(texName, ", uv);\n }\n "); } if (texNumC === stride3 && flatOffset == null) { // texR is used directly as physical (no risk of float16 overflow). 
return "\n float ".concat(funcName, "(int row, int col, int depth, int depth2, int depth3) {\n float texR = dot(\n vec4(row, col, depth, depth2),\n vec4(").concat(shape[1] * shape[2] * shape[3], ",\n ").concat(shape[2] * shape[3], ", ").concat(shape[3], ", 1));\n int texC = depth3;\n vec2 uv = (vec2(texC, texR) + halfCR) /\n vec2(").concat(texNumC, ".0, ").concat(texNumR, ".0);\n return sampleTexture(").concat(texName, ", uv);\n }\n "); } var offset = getFlatOffsetUniformName(texName); return "\n float ".concat(funcName, "(int row, int col, int depth, int depth2, int depth3) {\n // Explicitly use integer operations as dot() only works on floats.\n int index = row * ").concat(stride0, " + col * ").concat(stride1, " + depth * ").concat(stride2, " +\n depth2 * ").concat(stride3, " + depth3 + ").concat(offset, ";\n vec2 uv = uvFromFlat(").concat(texNumR, ", ").concat(texNumC, ", index);\n return sampleTexture(").concat(texName, ", uv);\n }\n "); } function getSampler6D(inputInfo) { var shape = inputInfo.shapeInfo.logicalShape; var texName = inputInfo.name; var funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1); var _a = tf.util.squeezeShape(shape), newShape = _a.newShape, keptDims = _a.keptDims; if (newShape.length < shape.length) { var newInputInfo = squeezeInputInfo(inputInfo, newShape); var params = ['row', 'col', 'depth', 'depth2', 'depth3', 'depth4']; return "\n ".concat(getSamplerFromInInfo(newInputInfo), "\n float ").concat(funcName, "(int row, int col, int depth,\n int depth2, int depth3, int depth4) {\n return ").concat(funcName, "(").concat(getSqueezedParams(params, keptDims), ");\n }\n "); } var stride4 = shape[5]; var stride3 = shape[4] * stride4; var stride2 = shape[3] * stride3; var stride1 = shape[2] * stride2; var stride0 = shape[1] * stride1; if (inputInfo.shapeInfo.isUniform) { // Uniform arrays will be less than 65505 (no risk of float16 overflow). 
return "\n float ".concat(funcName, "(int row, int col, int depth,\n int depth2, int depth3, int depth4) {\n int index = round(dot(\n vec4(row, col, depth, depth2),\n vec4(").concat(stride0, ", ").concat(stride1, ", ").concat(stride2, ", ").concat(stride3, ")) +\n dot(\n vec2(depth3, depth4),\n vec2(").concat(stride4, ", 1)));\n ").concat(getUniformSampler(inputInfo), "\n }\n "); } var flatOffset = inputInfo.shapeInfo.flatOffset; var texShape = inputInfo.shapeInfo.texShape; var texNumR = texShape[0]; var texNumC = texShape[1]; if (texNumC === stride0 && flatOffset == null) { // texC is used directly as physical (no risk of float16 overflow). return "\n float ".concat(funcName, "(int row, int col, int depth,\n int depth2, int depth3, int depth4) {\n int texR = row;\n float texC = dot(vec4(col, depth, depth2, depth3),\n vec4(").concat(stride1, ", ").concat(stride2, ", ").concat(stride3, ", ").concat(stride4, ")) +\n float(depth4);\n vec2 uv = (vec2(texC, texR) + halfCR) /\n vec2(").concat(texNumC, ".0, ").concat(texNumR, ".0);\n return sampleTexture(").concat(texName, ", uv);\n }\n "); } if (texNumC === stride4 && flatOffset == null) { // texR is used directly as physical (no risk of float16 overflow). 
return "\n float ".concat(funcName, "(int row, int col, int depth,\n int depth2, int depth3, int depth4) {\n float texR = dot(vec4(row, col, depth, depth2),\n vec4(").concat(shape[1] * shape[2] * shape[3] * shape[4], ",\n ").concat(shape[2] * shape[3] * shape[4], ",\n ").concat(shape[3] * shape[4], ",\n ").concat(shape[4], ")) + float(depth3);\n int texC = depth4;\n vec2 uv = (vec2(texC, texR) + halfCR) /\n vec2(").concat(texNumC, ".0, ").concat(texNumR, ".0);\n return sampleTexture(").concat(texName, ", uv);\n }\n "); } var offset = getFlatOffsetUniformName(texName); return "\n float ".concat(funcName, "(int row, int col, int depth,\n int depth2, int depth3, int depth4) {\n // Explicitly use integer operations as dot() only works on floats.\n int index = row * ").concat(stride0, " + col * ").concat(stride1, " + depth * ").concat(stride2, " +\n depth2 * ").concat(stride3, " + depth3 * ").concat(stride4, " + depth4 + ").concat(offset, ";\n vec2 uv = uvFromFlat(").concat(texNumR, ", ").concat(texNumC, ", index);\n return sampleTexture(").concat(texName, ", uv);\n }\n "); } function getUniformSampler(inputInfo) { var texName = inputInfo.name; var inSize = tf.util.sizeFromShape(inputInfo.shapeInfo.logicalShape); if (inSize < 2) { return "return ".concat(texName, ";"); } return "\n for (int i = 0; i < ".concat(inSize, "; i++) {\n if (i == index) {\n return ").concat(texName, "[i];\n }\n }\n "); } function getPackedSamplerAtOutputCoords(inputInfo, outShapeInfo) { var texName = inputInfo.name; var texFuncSnippet = texName.charAt(0).toUpperCase() + texName.slice(1); var funcName = 'get' + texFuncSnippet + 'AtOutCoords'; var inRank = inputInfo.shapeInfo.logicalShape.length; var outRank = outShapeInfo.logicalShape.length; var broadcastDims = getBroadcastDims(inputInfo.shapeInfo.logicalShape, outShapeInfo.logicalShape); var type = getCoordsDataType(outRank); var rankDiff = outRank - inRank; var coordsSnippet; var fields = ['x', 'y', 'z', 'w', 'u', 'v']; if (inRank === 0) { 
coordsSnippet = ''; } else if (outRank < 2 && broadcastDims.length >= 1) { coordsSnippet = 'coords = 0;'; } else { coordsSnippet = broadcastDims.map(function (d) { return "coords.".concat(fields[d + rankDiff], " = 0;"); }) .join('\n'); } var unpackedCoordsSnippet = ''; if (outRank < 2 && inRank > 0) { unpackedCoordsSnippet = 'coords'; } else { unpackedCoordsSnippet = inputInfo.shapeInfo.logicalShape .map(function (s, i) { return "coords.".concat(fields[i + rankDiff]); }) .join(', '); } var output = "return outputValue;"; var inSize = tf.util.sizeFromShape(inputInfo.shapeInfo.logicalShape); var isInputScalar = inSize === 1; var outSize = tf.util.sizeFromShape(outShapeInfo.logicalShape); var isOutputScalar = outSize === 1; if (inRank === 1 && !isInputScalar && !isOutputScalar) { output = "\n return vec4(outputValue.xy, outputValue.xy);\n "; } else if (isInputScalar && !isOutputScalar) { if (outRank === 1) { output = "\n return vec4(outputValue.x, outputValue.x, 0., 0.);\n "; } else { output = "\n return vec4(outputValue.x);\n "; } } else if (broadcastDims.length) { var rows = inRank - 2; var cols = inRank - 1; if (broadcastDims.indexOf(rows) > -1 && broadcastDims.indexOf(cols) > -1) { output = "return vec4(outputValue.x);"; } else if (broadcastDims.indexOf(rows) > -1) { output = "return vec4(outputValue.x, outputValue.y, " + "outputValue.x, outputValue.y);"; } else if (broadcastDims.indexOf(cols) > -1) { output = "return vec4(outputValue.xx, outputValue.zz);"; } } return "\n vec4 ".concat(funcName, "() {\n ").concat(type, " coords = getOutputCoords();\n ").concat(coordsSnippet, "\n vec4 outputValue = get").concat(texFuncSnippet, "(").concat(unpackedCoordsSnippet, ");\n ").concat(output, "\n }\n "); } function getSamplerAtOutputCoords(inputInfo, outShapeInfo) { var texName = inputInfo.name; var texFuncSnippet = texName.charAt(0).toUpperCase() + texName.slice(1); var funcName = 'get' + texFuncSnippet + 'AtOutCoords'; var outTexShape = outShapeInfo.texShape; var 
inTexShape = inputInfo.shapeInfo.texShape; var inRank = inputInfo.shapeInfo.logicalShape.length; var outRank = outShapeInfo.logicalShape.length; if (!inputInfo.shapeInfo.isUniform && inRank === outRank && inputInfo.shapeInfo.flatOffset == null && tf.util.arraysEqual(inTexShape, outTexShape)) { return "\n float ".concat(funcName, "() {\n return sampleTexture(").concat(texName, ", resultUV);\n }\n "); } var type = getCoordsDataType(outRank); var broadcastDims = getBroadcastDims(inputInfo.shapeInfo.logicalShape, outShapeInfo.logicalShape); var rankDiff = outRank - inRank; var coordsSnippet; var fields = ['x', 'y', 'z', 'w', 'u', 'v']; if (inRank === 0) { coordsSnippet = ''; } else if (outRank < 2 && broadcastDims.length >= 1) { coordsSnippet = 'coords = 0;'; } else { coordsSnippet = broadcastDims.map(function (d) { return "coords.".concat(fields[d + rankDiff], " = 0;"); }) .join('\n'); } var unpackedCoordsSnippet = ''; if (outRank < 2 && inRank > 0) { unpackedCoordsSnippet = 'coords'; } else { unpackedCoordsSnippet = inputInfo.shapeInfo.logicalShape .map(function (s, i) { return "coords.".concat(fields[i + rankDiff]); }) .join(', '); } return "\n float ".concat(funcName, "() {\n ").concat(type, " coords = getOutputCoords();\n ").concat(coordsSnippet, "\n return get").concat(texFuncSnippet, "(").concat(unpackedCoordsSnippet, ");\n }\n "); } function getCoordsDataType(rank) { if (rank <= 1) { return 'int'; } else if (rank === 2) { return 'ivec2'; } else if (rank === 3) { return 'ivec3'; } else if (rank === 4) { return 'ivec4'; } else if (rank === 5) { return 'ivec5'; } else if (rank === 6) { return 'ivec6'; } else { throw Error("GPU for rank ".concat(rank, " is not yet supported")); } } function getUniformInfoFromShape(isPacked, shape, texShape) { var _a = tf.util.squeezeShape(shape), newShape = _a.newShape, keptDims = _a.keptDims; var rank = shape.length; var useSqueezePackedShape = isPacked && rank === 3 && shape[0] === 1; var squeezeShape = useSqueezePackedShape ? 
shape.slice(1) : newShape; var useSqueezeShape = (!isPacked && rank > 1 && !tf.util.arraysEqual(shape, texShape) && newShape.length < rank) || useSqueezePackedShape; var uniformShape = useSqueezeShape ? squeezeShape : shape; return { useSqueezeShape: useSqueezeShape, uniformShape: uniformShape, keptDims: keptDims }; } /** Returns a new input info (a copy) that has a squeezed logical shape. */ function squeezeInputInfo(inInfo, squeezedShape) { // Deep copy. var newInputInfo = JSON.parse(JSON.stringify(inInfo)); newInputInfo.shapeInfo.logicalShape = squeezedShape; return newInputInfo; } function getSqueezedParams(params, keptDims) { return keptDims.map(function (d) { return params[d]; }).join(', '); } function compileProgram(gpgpu, program, inputs, output) { var inputInfos = inputs.map(function (input, i) { var shapeInfo = { logicalShape: input.shape, texShape: input.isUniform ? null : input.texData.texShape, isUniform: input.isUniform, isPacked: input.isUniform ? false : input.texData.isPacked, flatOffset: null }; if (input.texData != null && input.texData.slice != null && input.texData.slice.flatOffset > 0) { shapeInfo.flatOffset = input.texData.slice.flatOffset; } return { name: program.variableNames[i], shapeInfo: shapeInfo }; }); var inShapeInfos = inputInfos.map(function (x) { return x.shapeInfo; }); var outShapeInfo = { logicalShape: output.shape, texShape: output.texData.texShape, isUniform: false, isPacked: output.texData.isPacked, flatOffset: null }; var source = makeShader(inputInfos, outShapeInfo, program); var fragmentShader = createFragmentShader(gpgpu.gl, source); var webGLProgram = gpgpu.createProgram(fragmentShader); if (!tf.env().get('ENGINE_COMPILE_ONLY')) { gpgpu.buildVao(webGLProgram); return Object.assign({ program: program, fragmentShader: fragmentShader, source: source, webGLProgram: webGLProgram, inShapeInfos: inShapeInfos, outShapeInfo: outShapeInfo }, getUniformLocations(gpgpu, program, webGLProgram)); } else { return { program: program, 
fragmentShader: fragmentShader, source: source, webGLProgram: webGLProgram, inShapeInfos: inShapeInfos, outShapeInfo: outShapeInfo, variablesLocations: null, customUniformLocations: null, infLoc: null, nanLoc: null, outShapeLocation: null, outShapeStridesLocation: null, outTexShapeLocation: null }; } } function getUniformLocations(gpgpu, program, webGLProgram) { var e_1, _a, e_2, _b; var variablesLocations = []; var customUniformLocations = []; var outShapeLocation; var outTexShapeLocation; var outShapeStridesLocation; var infLoc = null; var nanLoc = null; // Add special uniforms (NAN, INFINITY) nanLoc = gpgpu.getUniformLocation(webGLProgram, 'NAN', false); if (tf.env().getNumber('WEBGL_VERSION') === 1) { infLoc = gpgpu.getUniformLocation(webGLProgram, 'INFINITY', false); } // Add user-defined uniforms var shouldThrow = false; try { for (var _c = __values(program.variableNames), _d = _c.next(); !_d.done; _d = _c.next()) { var varName = _d.value; var varLocs = { name: varName, uniform: gpgpu.getUniformLocation(webGLProgram, varName, shouldThrow), offset: gpgpu.getUniformLocation(webGLProgram, "offset".concat(varName), shouldThrow), }; if (program.enableShapeUniforms) { varLocs.shape = gpgpu.getUniformLocation(webGLProgram, "".concat(varName, "Shape"), shouldThrow); varLocs.texShape = gpgpu.getUniformLocation(webGLProgram, "".concat(varName, "TexShape"), shouldThrow); } variablesLocations.push(varLocs); } } catch (e_1_1) { e_1 = { error: e_1_1 }; } finally { try { if (_d && !_d.done && (_a = _c.return)) _a.call(_c); } finally { if (e_1) throw e_1.error; } } if (program.enableShapeUniforms) { outShapeLocation = gpgpu.getUniformLocation(webGLProgram, 'outShape', shouldThrow); outShapeStridesLocation = gpgpu.getUniformLocation(webGLProgram, 'outShapeStrides', shouldThrow); outTexShapeLocation = gpgpu.getUniformLocation(webGLProgram, 'outTexShape', shouldThrow); } if (program.customUniforms) { try { for (var _e = __values(program.customUniforms), _f = _e.next(); 
!_f.done; _f = _e.next()) { var d = _f.value; customUniformLocations.push(gpgpu.getUniformLocation(webGLProgram, d.name, shouldThrow)); } } catch (e_2_1) { e_2 = { error: e_2_1 }; } finally { try { if (_f && !_f.done && (_b = _e.return)) _b.call(_e); } finally { if (e_2) throw e_2.error; } } } return { variablesLocations: variablesLocations, customUniformLocations: customUniformLocations, infLoc: infLoc, nanLoc: nanLoc, outShapeLocation: outShapeLocation, outShapeStridesLocation: outShapeStridesLocation, outTexShapeLocation: outTexShapeLocation }; } function validateBinaryAndProgram(shapeInfos, inputs) { if (shapeInfos.length !== inputs.length) { throw Error("Binary was compiled with ".concat(shapeInfos.length, " inputs, but ") + "was executed with ".concat(inputs.length, " inputs")); } shapeInfos.forEach(function (s, i) { var shapeA = s.logicalShape; var input = inputs[i]; var shapeB = input.shape; if (!tf.util.arraysEqual(shapeA, shapeB)) { throw Error("Binary was compiled with different shapes than " + "the current args. Shapes ".concat(shapeA, " and ").concat(shapeB, " must match")); } // The input is uploaded as uniform. if (s.isUniform && input.isUniform) { return; } var texShapeA = s.texShape; var texShapeB = input.isUniform ? null : input.texData.texShape; if (!tf.util.arraysEqual(texShapeA, texShapeB)) { throw Error("Binary was compiled with different texture shapes than the" + " current args. 
Shape ".concat(texShapeA, " and ").concat(texShapeB, " must match")); } }); } function runProgram(gpgpu, binary, inputs, output, customUniformValues) { if (!binary.program.enableShapeUniforms) { validateBinaryAndProgram(binary.inShapeInfos, inputs); validateBinaryAndProgram([binary.outShapeInfo], [output]); } var outTex = output.texData.texture; var outTexShape = output.texData.texShape; if (output.texData.isPacked) { gpgpu.setOutputPackedMatrixTexture(outTex.texture, outTexShape[0], outTexShape[1]); } else { gpgpu.setOutputMatrixTexture(outTex.texture, outTexShape[0], outTexShape[1]); } gpgpu.setProgram(binary.webGLProgram); gpgpu.bindVertexArray(binary.webGLProgram.vao); // Set special uniforms (NAN, INFINITY) if (tf.env().getNumber('WEBGL_VERSION') === 1) { if (binary.infLoc !== null) { gpgpu.gl.uniform1f(binary.infLoc, Infinity); } } if (binary.nanLoc !== null) { gpgpu.gl.uniform1f(binary.nanLoc, NaN); } // Set user-defined inputs for (var i = 0; i < inputs.length; ++i) { var input = inputs[i]; var _a = binary.variablesLocations[i], varLoc = _a.uniform, varOffsetLoc = _a.offset, varShapeLoc = _a.shape, varTexShapeLoc = _a.texShape; if (varShapeLoc) { var uniformShape = getUniformInfoFromShape(binary.program.packedInputs, input.shape, input.texData.texShape).uniformShape; switch (uniformShape.length) { case 1: gpgpu.gl.uniform1iv(varShapeLoc, new Int32Array(uniformShape)); break; case 2: gpgpu.gl.uniform2iv(varShapeLoc, new Int32Array(uniformShape)); break; case 3: gpgpu.gl.uniform3iv(varShapeLoc, new Int32Array(uniformShape)); break; case 4: gpgpu.gl.uniform4iv(varShapeLoc, new Int32Array(uniformShape)); break; } } if (varTexShapeLoc) { gpgpu.gl.uniform2i(varTexShapeLoc, input.texData.texShape[0], input.texData.texShape[1]); } if (varLoc == null) { // The compiler inferred that this variable is not used in this shader. continue; } if (input.isUniform) { // Upload the values of the tensor as uniform. 
if (tf.util.sizeFromShape(input.shape) < 2) { gpgpu.gl.uniform1f(varLoc, input.uniformValues[0]); } else { var vals = input.uniformValues; if (!(vals instanceof Float32Array)) { vals = new Float32Array(vals); } gpgpu.gl.uniform1fv(varLoc, vals); } continue; } // If the input was sliced, upload the flat offset index. if (input.texData.slice != null && varOffsetLoc != null) { gpgpu.gl.uniform1i(varOffsetLoc, input.texData.slice.flatOffset); } gpgpu.setInputMatrixTexture(input.texData.texture.texture, varLoc, i); } var outShapeLoc = binary.outShapeLocation; if (outShapeLoc) { switch (output.shape.length) { case 1: gpgpu.gl.uniform1iv(outShapeLoc, new Int32Array(output.shape)); break; case 2: gpgpu.gl.uniform2iv(outShapeLoc, new Int32Array(output.shape)); break; case 3: gpgpu.gl.uniform3iv(outShapeLoc, new Int32Array(output.shape)); break; case 4: gpgpu.gl.uniform4iv(outShapeLoc, new Int32Array(output.shape)); break; } } if (binary.outShapeStridesLocation) { var strides = tf.util.computeStrides(output.shape); switch (output.shape.length) { case 2: gpgpu.gl.uniform1iv(binary.outShapeStridesLocation, new Int32Array(strides)); break; case 3: gpgpu.gl.uniform2iv(binary.outShapeStridesLocation, new Int32Array(strides)); break; case 4: gpgpu.gl.uniform3iv(binary.outShapeStridesLocation, new Int32Array(strides)); break; } } if (binary.outTexShapeLocation) { gpgpu.gl.uniform2i(binary.outTexShapeLocation, output.texData.texShape[0], output.texData.texShape[1]); } if (binary.program.customUniforms && customUniformValues) { for (var i = 0; i < binary.program.customUniforms.length; ++i) { var d = binary.program.customUniforms[i]; var customLoc = binary.customUniformLocations[i]; var customValue = customUniformValues[i]; if (d.type === 'float') { gpgpu.gl.uniform1fv(customLoc, customValue); } else if (d.type === 'vec2') { gpgpu.gl.uniform2fv(customLoc, customValue); } else if (d.type === 'vec3') { gpgpu.gl.uniform3fv(customLoc, customValue); } else if (d.type === 'vec4') { 
gpgpu.gl.uniform4fv(customLoc, customValue); } else if (d.type === 'int') { gpgpu.gl.uniform1iv(customLoc, customValue); } else if (d.type === 'ivec2') { gpgpu.gl.uniform2iv(customLoc, customValue); } else if (d.type === 'ivec3') { gpgpu.gl.uniform3iv(customLoc, customValue); } else if (d.type === 'ivec4') { gpgpu.gl.uniform4iv(customLoc, customValue); } else { throw Error("uniform type ".concat(d.type, " is not supported yet.")); } } } gpgpu.executeProgram(); } function makeShaderKey(program, inputs, output) { var keyInputs = ''; inputs.concat(output).forEach(function (x) { var hasOffset = x.texData != null && x.texData.slice != null && x.texData.slice.flatOffset > 0; // TODO: Remove the condition of !x.isUniform. if (program.enableShapeUniforms && !x.isUniform) { var xTexShape = x.texData.texShape; var _a = getUniformInfoFromShape(program.packedInputs, x.shape, xTexShape), useSqueezeShape = _a.useSqueezeShape, uniformShape = _a.uniformShape, keptDims = _a.keptDims; var rank1 = '', rank2 = '', rank34 = ''; if (uniformShape.length === 1 && program.packedInputs) { var packedTexShape = [Math.ceil(xTexShape[0] / 2), Math.ceil(xTexShape[1] / 2)]; rank1 = "".concat(packedTexShape[0] > 1, "_").concat(packedTexShape[1] > 1); } else if (uniformShape.length === 2 && !program.packedInputs) { rank2 = "".concat(uniformShape[0] > 1, "_").concat(uniformShape[1] > 1); } else if (uniformShape.length > 2 && !program.packedInputs) { var strides = tf.util.computeStrides(uniformShape); rank34 = "".concat(strides[0] === xTexShape[1], "_").concat(strides[strides.length - 1] === xTexShape[1]); } var xRank = x.shape.length; var isLogicalShapTexShapeEqual = uniformShape.length === 2 && tf.util.arraysEqual(x.shape, xTexShape); var isScalar = tf.util.sizeFromShape(x.shape) === 1; var broadcastDims = tf.backend_util.getBroadcastDims(x.shape, output.shape); var isInOutTexShapeEqual = !program.packedInputs && xRank === output.shape.length && tf.util.arraysEqual(xTexShape, 
output.texData.texShape); var isTexShapeGreaterThanOne = program.packedInputs || uniformShape.length > 2 ? '' : "".concat(xTexShape[0] > 1, "_").concat(xTexShape[1] > 1); // These key components are needed due to shader_compiler is embedding // them in the shader. // |xRank| is used to determine the coords length. See // get[Packed]SamplerAtOutputCoords. // |isInOutTexShapeEqual| is used to determine whether going to an // optimization path in getSamplerAtOutputCoords. // |useSqueezeShape| is extracted from squeezeInputInfo of // getSampler[2|3|4]D/getPackedSampler3D. // |isScalar| is extracted from isInputScalar/isOutputScalar in // getPackedSamplerAtOutputCoords. // |broadcastDims| is extracted from get[Packed]SamplerAtOutputCoords. // |isLogicalShapTexShapeEqual| is used in // getOutput[Packed]2DCoords/get[Packed]Sampler2D. // |rank1| is used in getOutputPacked1DCoords. // |rank2| is used in getOutput2DCoords. // |rank34| is used in getSampler3D/getSampler4D. // |isTexShapeGreaterThanOne| are used in // getSampler[Scalar|1D|2D]/getOutput1DCoords. keyInputs += "".concat(xRank, "_").concat(isInOutTexShapeEqual, "_").concat(useSqueezeShape ? keptDims : '', "_").concat(uniformShape.length, "_").concat(isScalar, "_").concat(broadcastDims, "_").concat(isLogicalShapTexShapeEqual, "_").concat(rank1, "_").concat(rank2, "_").concat(rank34, "_").concat(isTexShapeGreaterThanOne, "_").concat(hasOffset); } else { var texShape = x.isUniform ? 'uniform' : x.texData.texShape; keyInputs += "".concat(x.shape, "_").concat(texShape, "_").concat(hasOffset); } }); var keyUserCode = program.userCode; var key = program.constructor.name; // Fast string concat. See https://jsperf.com/string-concatenation/14. key += '_' + keyInputs + '_' + keyUserCode + "".concat(tf.env().getNumber('WEBGL_VERSION')); return key; } function useShapeUniforms(rank) { // TODO: Remove the limitaion of rank <= 4. 
return tf.env().getBool('WEBGL_USE_SHAPES_UNIFORMS') && rank <= 4; } /** * @license * Copyright 2019 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ var DecodeMatrixProgram = /** @class */ (function () { function DecodeMatrixProgram(outputShape) { this.variableNames = ['A']; this.packedInputs = false; this.packedOutput = true; this.outPackingScheme = PackingScheme.DENSE; this.customUniforms = [{ name: 'texShape', type: 'ivec2' }]; var glsl = getGlslDifferences(); this.outputShape = outputShape; this.enableShapeUniforms = useShapeUniforms(this.outputShape.length); this.userCode = "\n ivec3 outCoordsFromFlatIndex(int index) {\n ".concat(this.enableShapeUniforms ? getOutputLogicalCoordinatesFromFlatIndexByUniform(['r', 'c', 'd'], outputShape) : getLogicalCoordinatesFromFlatIndex(['r', 'c', 'd'], outputShape), "\n return ivec3(r, c, d);\n }\n\n void main() {\n ivec2 resTexRC = ivec2(resultUV.yx * vec2(texShape[0], texShape[1]));\n int index = 4 * (resTexRC.x * texShape[1] + resTexRC.y);\n\n vec4 result = vec4(0.);\n\n for (int i=0; i<4; i++) {\n int flatIndex = index + i;\n ivec3 rc = outCoordsFromFlatIndex(flatIndex);\n result[i] = getA(rc.x, rc.y, rc.z);\n }\n\n ").concat(glsl.output, " = result;\n }\n "); } return DecodeMatrixProgram; }()); /** * @license * Copyright 2019 Google LLC. All Rights Reserved. 
* Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ var DecodeMatrixPackedProgram = /** @class */ (function () { function DecodeMatrixPackedProgram(outputShape) { this.variableNames = ['A']; this.packedInputs = true; this.packedOutput = true; this.outPackingScheme = PackingScheme.DENSE; this.customUniforms = [{ name: 'texShape', type: 'ivec2' }]; var glsl = getGlslDifferences(); this.outputShape = outputShape; this.enableShapeUniforms = useShapeUniforms(this.outputShape.length); this.userCode = "\n ivec3 outCoordsFromFlatIndex(int index) {\n ".concat(this.enableShapeUniforms ? getOutputLogicalCoordinatesFromFlatIndexByUniform(['r', 'c', 'd'], outputShape) : getLogicalCoordinatesFromFlatIndex(['r', 'c', 'd'], outputShape), "\n return ivec3(r, c, d);\n }\n\n void main() {\n ivec2 resTexRC = ivec2(resultUV.yx * vec2(texShape[0], texShape[1]));\n int index = 4 * (resTexRC.x * texShape[1] + resTexRC.y);\n\n vec4 result = vec4(0.);\n\n for (int i=0; i<4; i++) {\n int flatIndex = index + i;\n ivec3 rc = outCoordsFromFlatIndex(flatIndex);\n result[i] = getChannel(getA(rc.x, rc.y, rc.z), vec2(rc.y, rc.z));\n }\n\n ").concat(glsl.output, " = result;\n }\n "); } return DecodeMatrixPackedProgram; }()); /** * @license * Copyright 2018 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ var EncodeFloatProgram = /** @class */ (function () { function EncodeFloatProgram(outputShape) { this.variableNames = ['A']; this.outTexUsage = TextureUsage.DOWNLOAD; var glsl = getGlslDifferences(); this.outputShape = outputShape; this.userCode = "\n ".concat(ENCODE_FLOAT_SNIPPET, "\n\n void main() {\n float x = getAAtOutCoords();\n ").concat(glsl.output, " = encode_float(x);\n }\n "); } return EncodeFloatProgram; }()); /** * @license * Copyright 2018 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ var EncodeFloatPackedProgram = /** @class */ (function () { function EncodeFloatPackedProgram(outputShape) { this.variableNames = ['A']; this.packedInputs = true; this.packedOutput = false; this.outTexUsage = TextureUsage.DOWNLOAD; var glsl = getGlslDifferences(); this.outputShape = outputShape; this.userCode = "\n ".concat(ENCODE_FLOAT_SNIPPET, "\n\n void main() {\n ivec3 coords = getOutputCoords();\n float x = getChannel(getAAtOutCoords(), vec2(coords.y, coords.z));\n ").concat(glsl.output, " = encode_float(x);\n }\n "); } return EncodeFloatPackedProgram; }()); /** * @license * Copyright 2018 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ var CHANNEL_CHAR_TO_INDEX_MAP = { 'R': 0, 'G': 1, 'B': 2, 'A': 3 }; var EncodeMatrixProgram = /** @class */ (function () { function EncodeMatrixProgram(outputShape, inputIsUnsignedByte, usedChannels) { if (inputIsUnsignedByte === void 0) { inputIsUnsignedByte = false; } if (usedChannels === void 0) { usedChannels = 'RGBA'; } this.variableNames = ['A']; this.customUniforms = [{ name: 'texShape', type: 'ivec2' }]; var glsl = getGlslDifferences(); this.outputShape = outputShape; this.enableShapeUniforms = useShapeUniforms(this.outputShape.length); var output = "result"; if (inputIsUnsignedByte) { output = "floor(result * 255. 
+ 0.5)"; } var mainLoop = ''; for (var usedChannelIndex = 0; usedChannelIndex < usedChannels.length; usedChannelIndex++) { var curChannel = usedChannels[usedChannelIndex]; mainLoop += "\n if(offset == ".concat(usedChannelIndex, ") {\n result = values[").concat(CHANNEL_CHAR_TO_INDEX_MAP[curChannel], "];\n }"); } this.userCode = "\n ".concat(this.enableShapeUniforms ? getFlatIndexFrom3DOutput() : getFlatIndexFrom3D(outputShape), "\n\n void main() {\n ivec3 coords = getOutputCoords();\n int flatIndex = getFlatIndex(coords);\n float result = 0.;\n int offset = imod(flatIndex, ").concat(usedChannels.length, ");\n\n flatIndex = idiv(flatIndex, ").concat(usedChannels.length, ", 1.);\n\n int r = flatIndex / texShape[1];\n if (r < texShape[0]) {\n int c = imod(flatIndex, texShape[1]);\n vec2 uv = (vec2(c, r) + halfCR) / vec2(texShape[1], texShape[0]);\n vec4 values = ").concat(glsl.texture2D, "(A, uv);\n ").concat(mainLoop, "\n }\n ").concat(glsl.output, " = vec4(").concat(output, ", 0., 0., 0.);\n }\n "); } return EncodeMatrixProgram; }()); /** * @license * Copyright 2018 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ /* This is how the shader encodes a tensor with shape = [2, 3, 5] (indices are [batch, row, col]). 
000|001 002|003 004|xxx 020|021 022|023 024|xxx ------- ------- ------- ------- ------- ------- 010|011 012|013 014|xxx xxx|xxx xxx|xxx xxx|xxx 100|101 102|103 104|xxx 120|121 122|123 124|xxx ------- ------- ------- ------- ------- ------- 110|111 112|113 114|xxx xxx|xxx xxx|xxx xxx|xxx Single texels contain only values from the same batch, and from adjacent rows and columns. */ var EncodeMatrixPackedProgram = /** @class */ (function () { function EncodeMatrixPackedProgram(outputShape, inputIsUnsignedByte) { if (inputIsUnsignedByte === void 0) { inputIsUnsignedByte = false; } this.variableNames = ['A']; this.packedInputs = false; this.packedOutput = true; this.customUniforms = [{ name: 'texShape', type: 'ivec2' }]; var glsl = getGlslDifferences(); this.outputShape = outputShape; this.enableShapeUniforms = useShapeUniforms(this.outputShape.length); var mainLoop = ''; var output = 'result'; if (inputIsUnsignedByte) { output = 'floor(result * 255. + 0.5)'; } for (var row = 0; row <= 1; row++) { for (var col = 0; col <= 1; col++) { var channel = row * 2 + col; mainLoop += "\n localCoords = coords;\n if(localCoords[2] + ".concat(col, " < ").concat(this.enableShapeUniforms ? 'outShape[2]' : "".concat(outputShape[2]), ") {\n localCoords[2] += ").concat(col, ";\n if (localCoords[1] + ").concat(row, " < ").concat(this.enableShapeUniforms ? 
'outShape[1]' : "".concat(outputShape[1]), ") {\n localCoords[1] += ").concat(row, ";\n\n flatIndex = getFlatIndex(localCoords);\n offset = imod(flatIndex, 4);\n\n flatIndex = idiv(flatIndex, 4, 1.);\n\n int r = flatIndex / texShape[1];\n int c = imod(flatIndex, texShape[1]);\n vec2 uv = (vec2(c, r) + halfCR) / vec2(texShape[1], texShape[0]);\n values = ").concat(glsl.texture2D, "(A, uv);\n\n if (offset == 0) {\n result[").concat(channel, "] = values[0];\n } else if (offset == 1) {\n result[").concat(channel, "] = values[1];\n } else if (offset == 2) {\n result[").concat(channel, "] = values[2];\n } else {\n result[").concat(channel, "] = values[3];\n }\n }\n }\n "); } } this.userCode = "\n ".concat(this.enableShapeUniforms ? getFlatIndexFrom3DOutput() : getFlatIndexFrom3D(outputShape), "\n\n void main() {\n ivec3 coords = getOutputCoords();\n\n vec4 result = vec4(0.);\n int flatIndex, r, c, offset;\n ivec3 localCoords;\n vec2 uv;\n vec4 values;\n\n ").concat(mainLoop, "\n\n ").concat(glsl.output, " = ").concat(output, ";\n }\n "); } return EncodeMatrixPackedProgram; }()); function createVertexShader(gl) { var glsl = getGlslDifferences(); var vertexShaderSource = "".concat(glsl.version, "\n precision highp float;\n ").concat(glsl.attribute, " vec3 clipSpacePos;\n ").concat(glsl.attribute, " vec2 uv;\n ").concat(glsl.varyingVs, " vec2 resultUV;\n\n void main() {\n gl_Position = vec4(clipSpacePos, 1);\n resultUV = uv;\n }"); return createVertexShader$1(gl, vertexShaderSource); } function createVertexBuffer(gl) { // [x y z u v] * [upper-left, lower-left, upper-right, lower-right] var vertexArray = new Float32Array([-1, 1, 0, 0, 1, -1, -1, 0, 0, 0, 1, 1, 0, 1, 1, 1, -1, 0, 1, 0]); return createStaticVertexBuffer(gl, vertexArray); } function createIndexBuffer(gl) { // OpenGL (and WebGL) have "CCW == front" winding var triangleVertexIndices = new Uint16Array([0, 1, 2, 2, 1, 3]); return createStaticIndexBuffer(gl, triangleVertexIndices); } function 
createAndConfigureTexture(gl, width, height, internalFormat, textureFormat, textureType) { validateTextureSize(width, height); var texture = createTexture(gl); var tex2d = gl.TEXTURE_2D; callAndCheck(gl, function () { return gl.bindTexture(tex2d, texture); }); callAndCheck(gl, function () { return gl.texParameteri(tex2d, gl.TEXTURE_WRAP_S, gl.CLAMP_TO_EDGE); }); callAndCheck(gl, function () { return gl.texParameteri(tex2d, gl.TEXTURE_WRAP_T, gl.CLAMP_TO_EDGE); }); callAndCheck(gl, function () { return gl.texParameteri(tex2d, gl.TEXTURE_MIN_FILTER, gl.NEAREST); }); callAndCheck(gl, function () { return gl.texParameteri(tex2d, gl.TEXTURE_MAG_FILTER, gl.NEAREST); }); if (tf.env().getNumber('WEBGL_VERSION') === 1) { callAndCheck(gl, function () { return gl.texImage2D(tex2d, 0, internalFormat, width, height, 0, textureFormat, textureType, null); }); } else { callAndCheck(gl, function () { return gl .texStorage2D(tex2d, 1, internalFormat, width, height); }); } callAndCheck(gl, function () { return gl.bindTexture(gl.TEXTURE_2D, null); }); return { texture: texture, texShape: [height, width] }; } function getInternalFormatForFloat32MatrixTexture(textureConfig) { return textureConfig.internalFormatFloat; } function createFloat32MatrixTexture(gl, rows, columns, textureConfig) { var _a = __read(getUnpackedMatrixTextureShapeWidthHeight(rows, columns), 2), width = _a[0], height = _a[1]; return createAndConfigureTexture(gl, width, height, getInternalFormatForFloat32MatrixTexture(textureConfig), textureConfig.textureFormatFloat, gl.FLOAT); } function getInternalFormatForFloat16MatrixTexture(textureConfig) { return textureConfig.internalFormatHalfFloat; } function createFloat16MatrixTexture(gl, rows, columns, textureConfig) { var _a = __read(getUnpackedMatrixTextureShapeWidthHeight(rows, columns), 2), width = _a[0], height = _a[1]; return createAndConfigureTexture(gl, width, height, getInternalFormatForFloat16MatrixTexture(textureConfig), textureConfig.textureFormatFloat, 
textureConfig.textureTypeHalfFloat); } function getInternalFormatForUnsignedBytesMatrixTexture(textureConfig) { return textureConfig.downloadTextureFormat; } function createUnsignedBytesMatrixTexture(gl, rows, columns, textureConfig) { var _a = __read(getUnpackedMatrixTextureShapeWidthHeight(rows, columns), 2), width = _a[0], height = _a[1]; return createAndConfigureTexture(gl, width, height, getInternalFormatForUnsignedBytesMatrixTexture(textureConfig), gl.RGBA, gl.UNSIGNED_BYTE); } function getInternalFormatForPackedMatrixTexture(textureConfig) { return textureConfig.internalFormatPackedFloat; } function createPackedMatrixTexture(gl, rows, columns, textureConfig) { var _a = __read(getPackedMatrixTextureShapeWidthHeight(rows, columns), 2), width = _a[0], height = _a[1]; return createAndConfigureTexture(gl, width, height, getInternalFormatForPackedMatrixTexture(textureConfig), gl.RGBA, gl.FLOAT); } function getInternalFormatForFloat16PackedMatrixTexture(textureConfig) { return textureConfig.internalFormatPackedHalfFloat; } function createFloat16PackedMatrixTexture(gl, rows, columns, textureConfig) { var _a = __read(getPackedMatrixTextureShapeWidthHeight(rows, columns), 2), width = _a[0], height = _a[1]; return createAndConfigureTexture(gl, width, height, getInternalFormatForFloat16PackedMatrixTexture(textureConfig), gl.RGBA, textureConfig.textureTypeHalfFloat); } function bindVertexProgramAttributeStreams(gl, program, vertexBuffer) { var posOffset = 0; // x is the first buffer element var uvOffset = 3 * 4; // uv comes after [x y z] var stride = (3 * 4) + (2 * 4); // xyz + uv, each entry is 4-byte float. 
callAndCheck(gl, function () { return gl.bindBuffer(gl.ARRAY_BUFFER, vertexBuffer); }); var success = bindVertexBufferToProgramAttribute(gl, program, 'clipSpacePos', vertexBuffer, 3, stride, posOffset); return success && bindVertexBufferToProgramAttribute(gl, program, 'uv', vertexBuffer, 2, stride, uvOffset); } function uploadDenseMatrixToTexture(gl, texture, width, height, data, textureConfig) { callAndCheck(gl, function () { return gl.bindTexture(gl.TEXTURE_2D, texture); }); var dataForUpload, texelDataType, internalFormat; if (data instanceof Uint8Array) { dataForUpload = new Uint8Array(width * height * 4); texelDataType = gl.UNSIGNED_BYTE; internalFormat = gl.RGBA; } else { dataForUpload = new Float32Array(width * height * 4); texelDataType = gl.FLOAT; internalFormat = textureConfig.internalFormatPackedFloat; } dataForUpload.set(data); if (tf.env().getNumber('WEBGL_VERSION') === 2) { callAndCheck(gl, function () { return gl.texSubImage2D(gl.TEXTURE_2D, 0, 0, 0, width, height, gl.RGBA, texelDataType, dataForUpload); }); } else { callAndCheck(gl, function () { return gl.texImage2D(gl.TEXTURE_2D, 0, internalFormat, width, height, 0, gl.RGBA, texelDataType, dataForUpload); }); } callAndCheck(gl, function () { return gl.bindTexture(gl.TEXTURE_2D, null); }); } function uploadPixelDataToTexture(gl, texture, pixels) { callAndCheck(gl, function () { return gl.bindTexture(gl.TEXTURE_2D, texture); }); if (pixels.data instanceof Uint8Array) { if (tf.env().getNumber('WEBGL_VERSION') === 2) { callAndCheck(gl, function () { return gl.texSubImage2D(gl.TEXTURE_2D, 0, 0, 0, pixels.width, pixels.height, gl.RGBA, gl.UNSIGNED_BYTE, pixels.data); }); } else { callAndCheck(gl, function () { return gl.texImage2D(gl.TEXTURE_2D, 0, gl.RGBA, pixels.width, pixels.height, 0, gl.RGBA, gl.UNSIGNED_BYTE, pixels.data); }); } } else { if (tf.env().getNumber('WEBGL_VERSION') === 2) { callAndCheck(gl, function () { return gl.texSubImage2D(gl.TEXTURE_2D, 0, 0, 0, gl.RGBA, gl.UNSIGNED_BYTE, 
pixels); }); } else { callAndCheck(gl, function () { return gl.texImage2D(gl.TEXTURE_2D, 0, gl.RGBA, gl.RGBA, gl.UNSIGNED_BYTE, pixels); }); } } callAndCheck(gl, function () { return gl.bindTexture(gl.TEXTURE_2D, null); }); } function createBufferFromOutputTexture(gl2, rows, columns, textureConfig) { // Create and bind the buffer. var buffer = gl2.createBuffer(); callAndCheck(gl2, function () { return gl2.bindBuffer(gl2.PIXEL_PACK_BUFFER, buffer); }); // Initialize the buffer to the size of the texture in bytes. var bytesPerFloat = 4; var valuesPerTexel = 4; var bufferSizeBytes = bytesPerFloat * valuesPerTexel * rows * columns; callAndCheck(gl2, function () { return gl2.bufferData(gl2.PIXEL_PACK_BUFFER, bufferSizeBytes, gl2.STREAM_READ); }); // Enqueue a command on the GPU command queue to copy of texture into the // buffer. callAndCheck(gl2, function () { return gl2.readPixels(0, 0, columns, rows, gl2.RGBA, gl2.FLOAT, 0); }); callAndCheck(gl2, function () { return gl2.bindBuffer(gl2.PIXEL_PACK_BUFFER, null); }); return buffer; } function downloadFloat32MatrixFromBuffer(gl, buffer, size) { var gl2 = gl; var downloadTarget = new Float32Array(size); gl2.bindBuffer(gl2.PIXEL_PACK_BUFFER, buffer); gl2.getBufferSubData(gl2.PIXEL_PACK_BUFFER, 0, downloadTarget); gl2.bindBuffer(gl2.PIXEL_PACK_BUFFER, null); return downloadTarget; } function downloadByteEncodedFloatMatrixFromOutputTexture(gl, rows, columns, textureConfig) { var _a = __read(getUnpackedMatrixTextureShapeWidthHeight(rows, columns), 2), w = _a[0], h = _a[1]; var numChannels = 4; var downloadTarget = new Uint8Array(getUnpackedArraySizeFromMatrixSize(rows * columns, numChannels)); callAndCheck(gl, function () { return gl.readPixels(0, 0, w, h, textureConfig.downloadTextureFormat, gl.UNSIGNED_BYTE, downloadTarget); }); // By wrapping the buffer in a Float32Array, we use native browser IEEE 754 // decoding of the 4 bytes that back each 32 bit float. 
return new Float32Array(downloadTarget.buffer); } function downloadPackedMatrixFromBuffer(gl, buffer, batch, rows, cols, physicalRows, physicalCols, textureConfig) { var gl2 = gl; var downloadTarget = new Float32Array(getPackedRGBAArraySizeFromMatrixShape(physicalRows, physicalCols)); gl2.bindBuffer(gl2.PIXEL_PACK_BUFFER, buffer); gl2.getBufferSubData(gl2.PIXEL_PACK_BUFFER, 0, downloadTarget); gl2.bindBuffer(gl2.PIXEL_PACK_BUFFER, null); return downloadTarget; } function downloadMatrixFromPackedOutputTexture(gl, physicalRows, physicalCols) { var packedRGBA = new Float32Array(physicalRows * physicalCols * 4); callAndCheck(gl, function () { return gl.readPixels(0, 0, physicalCols, physicalRows, gl.RGBA, gl.FLOAT, packedRGBA); }); return packedRGBA; } var gpgpu_util = { __proto__: null, bindVertexProgramAttributeStreams: bindVertexProgramAttributeStreams, createBufferFromOutputTexture: createBufferFromOutputTexture, createFloat16MatrixTexture: createFloat16MatrixTexture, createFloat16PackedMatrixTexture: createFloat16PackedMatrixTexture, createFloat32MatrixTexture: createFloat32MatrixTexture, createIndexBuffer: createIndexBuffer, createPackedMatrixTexture: createPackedMatrixTexture, createUnsignedBytesMatrixTexture: createUnsignedBytesMatrixTexture, createVertexBuffer: createVertexBuffer, createVertexShader: createVertexShader, downloadByteEncodedFloatMatrixFromOutputTexture: downloadByteEncodedFloatMatrixFromOutputTexture, downloadFloat32MatrixFromBuffer: downloadFloat32MatrixFromBuffer, downloadMatrixFromPackedOutputTexture: downloadMatrixFromPackedOutputTexture, downloadPackedMatrixFromBuffer: downloadPackedMatrixFromBuffer, getInternalFormatForFloat16MatrixTexture: getInternalFormatForFloat16MatrixTexture, getInternalFormatForFloat16PackedMatrixTexture: getInternalFormatForFloat16PackedMatrixTexture, getInternalFormatForFloat32MatrixTexture: getInternalFormatForFloat32MatrixTexture, getInternalFormatForPackedMatrixTexture: getInternalFormatForPackedMatrixTexture, 
getInternalFormatForUnsignedBytesMatrixTexture: getInternalFormatForUnsignedBytesMatrixTexture, uploadDenseMatrixToTexture: uploadDenseMatrixToTexture, uploadPixelDataToTexture: uploadPixelDataToTexture }; var GPGPUContext = /** @class */ (function () { function GPGPUContext(gl) { this.outputTexture = null; this.program = null; this.disposed = false; this.itemsToPoll = []; var glVersion = tf.env().getNumber('WEBGL_VERSION'); if (gl != null) { this.gl = gl; setWebGLContext(glVersion, gl); } else { this.gl = getWebGLContext(glVersion); } gl = this.gl; if (tf.env().getNumber('WEBGL_VERSION') === 2) { var gl2_1 = gl; this.createVertexArray = function () { return callAndCheck(gl2_1, function () { return gl2_1.createVertexArray(); }); }; this.bindVertexArray = function (vao) { return callAndCheck(gl2_1, function () { return gl2_1.bindVertexArray(vao); }); }; this.deleteVertexArray = function (vao) { return callAndCheck(gl2_1, function () { return gl2_1.deleteVertexArray(vao); }); }; this.getVertexArray = function () { return callAndCheck(gl2_1, function () { return gl2_1.getParameter(gl2_1.VERTEX_ARRAY_BINDING); }); }; } else if (gl != null) { var ext_1 = gl.getExtension('OES_vertex_array_object'); if (ext_1 == null) { throw new Error('All WebGL1 implementations are expected to offer' + ' OES_vertex_array_object.'); } this.createVertexArray = function () { return callAndCheck(gl, function () { return ext_1.createVertexArrayOES(); }); }; this.bindVertexArray = function (vao) { return callAndCheck(gl, function () { return ext_1.bindVertexArrayOES(vao); }); }; this.deleteVertexArray = function (vao) { return callAndCheck(gl, function () { return ext_1.deleteVertexArrayOES(vao); }); }; this.getVertexArray = function () { return callAndCheck(gl, function () { return gl.getParameter(ext_1.VERTEX_ARRAY_BINDING_OES); }); }; } // WebGL 2.0 enables texture floats without an extension. 
var COLOR_BUFFER_FLOAT = 'WEBGL_color_buffer_float'; var COLOR_BUFFER_HALF_FLOAT = 'EXT_color_buffer_half_float'; this.parallelCompilationExtension = this.gl.getExtension('KHR_parallel_shader_compile'); if (tf.env().getNumber('WEBGL_VERSION') === 1) { var TEXTURE_FLOAT = 'OES_texture_float'; var TEXTURE_HALF_FLOAT = 'OES_texture_half_float'; this.textureFloatExtension = getExtensionOrThrow(this.gl, TEXTURE_FLOAT); if (hasExtension(this.gl, TEXTURE_HALF_FLOAT)) { this.textureHalfFloatExtension = getExtensionOrThrow(this.gl, TEXTURE_HALF_FLOAT); } else if (tf.env().get('WEBGL_FORCE_F16_TEXTURES')) { throw new Error('GL context does not support half float textures, yet the ' + 'environment flag WEBGL_FORCE_F16_TEXTURES is set to true.'); } this.colorBufferFloatExtension = this.gl.getExtension(COLOR_BUFFER_FLOAT); if (hasExtension(this.gl, COLOR_BUFFER_HALF_FLOAT)) { this.colorBufferHalfFloatExtension = getExtensionOrThrow(this.gl, COLOR_BUFFER_HALF_FLOAT); } else if (tf.env().get('WEBGL_FORCE_F16_TEXTURES')) { throw new Error('GL context does not support color renderable half floats, yet ' + 'the environment flag WEBGL_FORCE_F16_TEXTURES is set to true.'); } } else { COLOR_BUFFER_FLOAT = 'EXT_color_buffer_float'; if (hasExtension(this.gl, COLOR_BUFFER_FLOAT)) { this.colorBufferFloatExtension = this.gl.getExtension(COLOR_BUFFER_FLOAT); } else if (hasExtension(this.gl, COLOR_BUFFER_HALF_FLOAT)) { this.colorBufferHalfFloatExtension = this.gl.getExtension(COLOR_BUFFER_HALF_FLOAT); } else { throw new Error('GL context does not support color renderable floats'); } } this.vertexBuffer = createVertexBuffer(this.gl); this.indexBuffer = createIndexBuffer(this.gl); this.framebuffer = createFramebuffer(this.gl); this.textureConfig = getTextureConfig(this.gl, this.textureHalfFloatExtension); } Object.defineProperty(GPGPUContext.prototype, "debug", { get: function () { return tf.env().getBool('DEBUG'); }, enumerable: false, configurable: true }); GPGPUContext.prototype.dispose = 
function () { var _this = this; if (this.disposed) { return; } if (this.program != null) { console.warn('Disposing a GPGPUContext that still has a bound WebGLProgram.' + ' This is probably a resource leak, delete the program with ' + 'GPGPUContext.deleteProgram before disposing.'); } if (this.outputTexture != null) { console.warn('Disposing a GPGPUContext that still has a bound output matrix ' + 'texture. This is probably a resource leak, delete the output ' + 'matrix texture with GPGPUContext.deleteMatrixTexture before ' + 'disposing.'); } var gl = this.gl; callAndCheck(gl, function () { return gl.finish(); }); callAndCheck(gl, function () { return gl.bindFramebuffer(gl.FRAMEBUFFER, null); }); callAndCheck(gl, function () { return gl.deleteFramebuffer(_this.framebuffer); }); callAndCheck(gl, function () { return gl.bindBuffer(gl.ARRAY_BUFFER, null); }); callAndCheck(gl, function () { return gl.bindBuffer(gl.ELEMENT_ARRAY_BUFFER, null); }); callAndCheck(gl, function () { return gl.deleteBuffer(_this.indexBuffer); }); this.disposed = true; }; GPGPUContext.prototype.createFloat32MatrixTexture = function (rows, columns) { this.throwIfDisposed(); return createFloat32MatrixTexture(this.gl, rows, columns, this.textureConfig); }; GPGPUContext.prototype.createFloat16MatrixTexture = function (rows, columns) { this.throwIfDisposed(); return createFloat16MatrixTexture(this.gl, rows, columns, this.textureConfig); }; GPGPUContext.prototype.createUnsignedBytesMatrixTexture = function (rows, columns) { this.throwIfDisposed(); return createUnsignedBytesMatrixTexture(this.gl, rows, columns, this.textureConfig); }; GPGPUContext.prototype.uploadPixelDataToTexture = function (texture, pixels) { this.throwIfDisposed(); uploadPixelDataToTexture(this.gl, texture, pixels); }; GPGPUContext.prototype.uploadDenseMatrixToTexture = function (texture, width, height, data) { this.throwIfDisposed(); uploadDenseMatrixToTexture(this.gl, texture, width, height, data, this.textureConfig); }; 
GPGPUContext.prototype.createFloat16PackedMatrixTexture = function (rows, columns) { this.throwIfDisposed(); return createFloat16PackedMatrixTexture(this.gl, rows, columns, this.textureConfig); }; GPGPUContext.prototype.createPackedMatrixTexture = function (rows, columns) { this.throwIfDisposed(); return createPackedMatrixTexture(this.gl, rows, columns, this.textureConfig); }; GPGPUContext.prototype.deleteMatrixTexture = function (texture) { var _this = this; this.throwIfDisposed(); if (this.outputTexture === texture) { unbindColorTextureFromFramebuffer(this.gl, this.framebuffer); this.outputTexture = null; } callAndCheck(this.gl, function () { return _this.gl.deleteTexture(texture); }); }; GPGPUContext.prototype.downloadByteEncodedFloatMatrixFromOutputTexture = function (texture, rows, columns) { var _this = this; return this.downloadMatrixDriver(texture, function () { return downloadByteEncodedFloatMatrixFromOutputTexture(_this.gl, rows, columns, _this.textureConfig); }); }; GPGPUContext.prototype.downloadPackedMatrixFromBuffer = function (buffer, batch, rows, columns, physicalRows, physicalCols) { return downloadPackedMatrixFromBuffer(this.gl, buffer, batch, rows, columns, physicalRows, physicalCols, this.textureConfig); }; GPGPUContext.prototype.downloadFloat32MatrixFromBuffer = function (buffer, size) { return downloadFloat32MatrixFromBuffer(this.gl, buffer, size); }; GPGPUContext.prototype.createBufferFromTexture = function (texture, rows, columns) { this.bindTextureToFrameBuffer(texture); var result = createBufferFromOutputTexture(this.gl, rows, columns, this.textureConfig); this.unbindTextureToFrameBuffer(); return result; }; GPGPUContext.prototype.createAndWaitForFence = function () { var fenceContext = this.createFence(this.gl); return this.pollFence(fenceContext); }; GPGPUContext.prototype.createFence = function (gl) { var _this = this; var query; var isFencePassed; if (tf.env().getBool('WEBGL_FENCE_API_ENABLED')) { var gl2_2 = gl; var sync_1 = 
gl2_2.fenceSync(gl2_2.SYNC_GPU_COMMANDS_COMPLETE, 0); gl.flush(); isFencePassed = function () { var status = gl2_2.clientWaitSync(sync_1, 0, 0); return status === gl2_2.ALREADY_SIGNALED || status === gl2_2.CONDITION_SATISFIED; }; query = sync_1; } else if (tf.env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION') > 0) { query = this.beginQuery(); this.endQuery(); isFencePassed = function () { return _this.isQueryAvailable(query, tf.env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION')); }; } else { // If we have no way to fence, return true immediately. This will fire in // WebGL 1.0 when there is no disjoint query timer. In this case, because // the fence passes immediately, we'll immediately ask for a download of // the texture, which will cause the UI thread to hang. isFencePassed = function () { return true; }; } return { query: query, isFencePassed: isFencePassed }; }; GPGPUContext.prototype.downloadMatrixFromPackedTexture = function (texture, physicalRows, physicalCols) { var _this = this; return this.downloadMatrixDriver(texture, function () { return downloadMatrixFromPackedOutputTexture(_this.gl, physicalRows, physicalCols); }); }; GPGPUContext.prototype.createProgram = function (fragmentShader) { var _this = this; this.throwIfDisposed(); var gl = this.gl; if (this.vertexShader == null) { this.vertexShader = createVertexShader(gl); } var program = createProgram(gl); callAndCheck(gl, function () { return gl.attachShader(program, _this.vertexShader); }); callAndCheck(gl, function () { return gl.attachShader(program, fragmentShader); }); linkProgram(gl, program); var program2 = Object.assign(program, { vao: this.createVertexArray() }); if (this.debug) { validateProgram(gl, program2); } return program2; }; GPGPUContext.prototype.buildVao = function (program) { var _this = this; this.setProgram(program); this.bindVertexArray(program.vao); var gl = this.gl; // Bind index buffer, and vertex buffers based on program attrib // locations. 
callAndCheck(gl, function () { return gl.bindBuffer(gl.ELEMENT_ARRAY_BUFFER, _this.indexBuffer); }); bindVertexProgramAttributeStreams(gl, program, this.vertexBuffer); }; GPGPUContext.prototype.deleteProgram = function (program) { var _this = this; this.throwIfDisposed(); if (program === this.program) { this.program = null; } if (program != null) { callAndCheck(this.gl, function () { return _this.gl.deleteProgram(program); }); this.deleteVertexArray(program.vao); } }; GPGPUContext.prototype.setProgram = function (program) { var _this = this; this.throwIfDisposed(); this.program = program; if (this.program != null) { if (this.debug) { validateProgram(this.gl, this.program); } } callAndCheck(this.gl, function () { return _this.gl.useProgram(program); }); }; GPGPUContext.prototype.getUniformLocation = function (program, uniformName, shouldThrow) { if (shouldThrow === void 0) { shouldThrow = true; } this.throwIfDisposed(); if (shouldThrow) { return getProgramUniformLocationOrThrow(this.gl, program, uniformName); } else { return getProgramUniformLocation(this.gl, program, uniformName); } }; GPGPUContext.prototype.getAttributeLocation = function (program, attribute) { var _this = this; this.throwIfDisposed(); return callAndCheck(this.gl, function () { return _this.gl.getAttribLocation(program, attribute); }); }; GPGPUContext.prototype.getUniformLocationNoThrow = function (program, uniformName) { this.throwIfDisposed(); return this.gl.getUniformLocation(program, uniformName); }; GPGPUContext.prototype.setInputMatrixTexture = function (inputMatrixTexture, uniformLocation, textureUnit) { this.throwIfDisposed(); this.throwIfNoProgram(); bindTextureToProgramUniformSampler(this.gl, inputMatrixTexture, uniformLocation, textureUnit); }; GPGPUContext.prototype.setOutputMatrixTexture = function (outputMatrixTexture, rows, columns) { this.setOutputMatrixTextureDriver(outputMatrixTexture, columns, rows); }; GPGPUContext.prototype.setOutputPackedMatrixTexture = function 
(outputPackedMatrixTexture, rows, columns) { this.throwIfDisposed(); var _a = __read(getPackedMatrixTextureShapeWidthHeight(rows, columns), 2), width = _a[0], height = _a[1]; this.setOutputMatrixTextureDriver(outputPackedMatrixTexture, width, height); }; GPGPUContext.prototype.setOutputMatrixWriteRegion = function (startRow, numRows, startColumn, numColumns) { this.setOutputMatrixWriteRegionDriver(startColumn, startRow, numColumns, numRows); }; GPGPUContext.prototype.setOutputPackedMatrixWriteRegion = function (startRow, numRows, startColumn, numColumns) { throw new Error('setOutputPackedMatrixWriteRegion not implemented.'); }; GPGPUContext.prototype.debugValidate = function () { if (this.program != null) { validateProgram(this.gl, this.program); } validateFramebuffer(this.gl); }; GPGPUContext.prototype.executeProgram = function () { this.throwIfDisposed(); this.throwIfNoProgram(); var gl = this.gl; if (this.debug) { var boundVao = this.getVertexArray(); console.assert(boundVao === this.program.vao, 'VAO changed between setProgram and executeProgram!'); this.debugValidate(); } callAndCheck(gl, function () { return gl.drawElements(gl.TRIANGLES, 6, gl.UNSIGNED_SHORT, 0); }); }; GPGPUContext.prototype.blockUntilAllProgramsCompleted = function () { var _this = this; this.throwIfDisposed(); callAndCheck(this.gl, function () { return _this.gl.finish(); }); }; GPGPUContext.prototype.getQueryTimerExtension = function () { if (this.disjointQueryTimerExtension == null) { this.disjointQueryTimerExtension = getExtensionOrThrow(this.gl, tf.env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION') === 2 ? 
'EXT_disjoint_timer_query_webgl2' : 'EXT_disjoint_timer_query'); } return this.disjointQueryTimerExtension; }; GPGPUContext.prototype.getQueryTimerExtensionWebGL2 = function () { return this.getQueryTimerExtension(); }; GPGPUContext.prototype.getQueryTimerExtensionWebGL1 = function () { return this.getQueryTimerExtension(); }; GPGPUContext.prototype.beginQuery = function () { if (tf.env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION') === 2) { var gl2 = this.gl; var ext_2 = this.getQueryTimerExtensionWebGL2(); var query_1 = gl2.createQuery(); gl2.beginQuery(ext_2.TIME_ELAPSED_EXT, query_1); return query_1; } var ext = this.getQueryTimerExtensionWebGL1(); var query = ext.createQueryEXT(); ext.beginQueryEXT(ext.TIME_ELAPSED_EXT, query); return query; }; GPGPUContext.prototype.endQuery = function () { if (tf.env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION') === 2) { var gl2 = this.gl; var ext_3 = this.getQueryTimerExtensionWebGL2(); gl2.endQuery(ext_3.TIME_ELAPSED_EXT); return; } var ext = this.getQueryTimerExtensionWebGL1(); ext.endQueryEXT(ext.TIME_ELAPSED_EXT); }; GPGPUContext.prototype.waitForQueryAndGetTime = function (query) { return __awaiter(this, void 0, void 0, function () { var _this = this; return __generator(this, function (_a) { switch (_a.label) { case 0: return [4 /*yield*/, tf.util.repeatedTry(function () { return _this.disposed || // while testing contexts are created / disposed // in rapid succession, so without this check we // may poll for the query timer indefinitely _this.isQueryAvailable(query, tf.env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION')); })]; case 1: _a.sent(); return [2 /*return*/, this.getQueryTime(query, tf.env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION'))]; } }); }); }; GPGPUContext.prototype.getQueryTime = function (query, queryTimerVersion) { if (queryTimerVersion === 0) { return null; } if (queryTimerVersion === 2) { var gl2 = this.gl; var timeElapsedNanos = 
gl2.getQueryParameter(query, gl2.QUERY_RESULT); // Return milliseconds. return timeElapsedNanos / 1000000; } else { var ext = this.getQueryTimerExtensionWebGL1(); var timeElapsedNanos = ext.getQueryObjectEXT(query, ext.QUERY_RESULT_EXT); // Return milliseconds. return timeElapsedNanos / 1000000; } }; GPGPUContext.prototype.isQueryAvailable = function (query, queryTimerVersion) { if (queryTimerVersion === 0) { return true; } if (queryTimerVersion === 2) { var gl2 = this.gl; var ext = this.getQueryTimerExtensionWebGL2(); var available = gl2.getQueryParameter(query, gl2.QUERY_RESULT_AVAILABLE); if (this.disjoint == null) { this.disjoint = this.gl.getParameter(ext.GPU_DISJOINT_EXT); } return available && !this.disjoint; } else { var ext = this.getQueryTimerExtensionWebGL1(); var available = ext.getQueryObjectEXT(query, ext.QUERY_RESULT_AVAILABLE_EXT); if (this.disjoint == null) { this.disjoint = this.gl.getParameter(ext.GPU_DISJOINT_EXT); } return available && !this.disjoint; } }; GPGPUContext.prototype.pollFence = function (fenceContext) { var _this = this; return new Promise(function (resolve) { _this.addItemToPoll(function () { return fenceContext.isFencePassed(); }, function () { return resolve(); }); }); }; GPGPUContext.prototype.pollItems = function () { // Find the last query that has finished. var index = linearSearchLastTrue(this.itemsToPoll.map(function (x) { return x.isDoneFn; })); for (var i = 0; i <= index; ++i) { var resolveFn = this.itemsToPoll[i].resolveFn; resolveFn(); } this.itemsToPoll = this.itemsToPoll.slice(index + 1); }; GPGPUContext.prototype.addItemToPoll = function (isDoneFn, resolveFn) { var _this = this; this.itemsToPoll.push({ isDoneFn: isDoneFn, resolveFn: resolveFn }); if (this.itemsToPoll.length > 1) { // We already have a running loop that polls. return; } // Start a new loop that polls. 
var scheduleFn = undefined; if ('setTimeoutCustom' in tf.env().platform) { scheduleFn = tf.env().platform.setTimeoutCustom.bind(tf.env().platform); } tf.util.repeatedTry(function () { _this.pollItems(); // End the loop if no more items to poll. return _this.itemsToPoll.length === 0; }, function () { return 0; }, null, scheduleFn); }; GPGPUContext.prototype.bindTextureToFrameBuffer = function (texture) { this.throwIfDisposed(); bindColorTextureToFramebuffer(this.gl, texture, this.framebuffer); if (this.debug) { validateFramebuffer(this.gl); } }; GPGPUContext.prototype.unbindTextureToFrameBuffer = function () { if (this.outputTexture != null) { bindColorTextureToFramebuffer(this.gl, this.outputTexture, this.framebuffer); if (this.debug) { validateFramebuffer(this.gl); } } else { unbindColorTextureFromFramebuffer(this.gl, this.framebuffer); } }; GPGPUContext.prototype.downloadMatrixDriver = function (texture, downloadAndDecode) { this.bindTextureToFrameBuffer(texture); var result = downloadAndDecode(); this.unbindTextureToFrameBuffer(); return result; }; GPGPUContext.prototype.setOutputMatrixTextureDriver = function (outputMatrixTextureMaybePacked, width, height) { this.throwIfDisposed(); var gl = this.gl; bindColorTextureToFramebuffer(gl, outputMatrixTextureMaybePacked, this.framebuffer); if (this.debug) { validateFramebuffer(gl); } this.outputTexture = outputMatrixTextureMaybePacked; callAndCheck(gl, function () { return gl.viewport(0, 0, width, height); }); callAndCheck(gl, function () { return gl.scissor(0, 0, width, height); }); }; GPGPUContext.prototype.setOutputMatrixWriteRegionDriver = function (x, y, width, height) { var _this = this; this.throwIfDisposed(); callAndCheck(this.gl, function () { return _this.gl.scissor(x, y, width, height); }); }; GPGPUContext.prototype.throwIfDisposed = function () { if (this.disposed) { throw new Error('Attempted to use disposed GPGPUContext.'); } }; GPGPUContext.prototype.throwIfNoProgram = function () { if (this.program == 
null) { throw new Error('No GPU program is currently set.'); } }; return GPGPUContext; }()); /** * Finds the index of the last true element using linear search. * Note: We can't do binary search because Chrome expects us to explicitly * test all fences before download: * https://github.com/tensorflow/tfjs/issues/1145 */ function linearSearchLastTrue(arr) { var i = 0; for (; i < arr.length; ++i) { var isDone = arr[i](); if (!isDone) { break; } } return i - 1; } /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the License); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an AS IS BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function simpleAbsImpl(vals) { var resultValues = new Float32Array(vals.length); for (var i = 0; i < vals.length; ++i) { resultValues[i] = Math.abs(vals[i]); } return resultValues; } /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ /** * Template that creates implementation for binary ops. Supports broadcast. */ function createSimpleBinaryKernelImpl(op) { return function (aShape, bShape, aVals, bVals, dtype) { var newShape = tf.backend_util.assertAndGetBroadcastShape(aShape, bShape); var resultRank = newShape.length; var resultStrides = tf.util.computeStrides(newShape); var resultSize = tf.util.sizeFromShape(newShape); var result = tf.util.getTypedArrayFromDType(dtype, resultSize); var aRank = aShape.length; var bRank = bShape.length; var aStrides = tf.util.computeStrides(aShape); var bStrides = tf.util.computeStrides(bShape); var aBroadcastDims = tf.backend_util.getBroadcastDims(aShape, newShape); var bBroadcastDims = tf.backend_util.getBroadcastDims(bShape, newShape); if (aBroadcastDims.length + bBroadcastDims.length === 0) { for (var i = 0; i < result.length; ++i) { result[i] = op(aVals[i % aVals.length], bVals[i % bVals.length]); } } else { var _loop_1 = function (i) { var loc = tf.util.indexToLoc(i, resultRank, resultStrides); var aLoc = loc.slice(-aRank); aBroadcastDims.forEach(function (d) { return aLoc[d] = 0; }); var aIndex = tf.util.locToIndex(aLoc, aRank, aStrides); var bLoc = loc.slice(-bRank); bBroadcastDims.forEach(function (d) { return bLoc[d] = 0; }); var bIndex = tf.util.locToIndex(bLoc, bRank, bStrides); result[i] = op(aVals[aIndex], bVals[bIndex]); }; for (var i = 0; i < result.length; ++i) { _loop_1(i); } } return [result, newShape]; }; } function castImpl(values, shape, inputType, dtype) { if (dtype === 'int32') { var resultValues = Int32Array.from(values); return [shape, 'int32', resultValues]; } if (dtype === 'bool') { // This is essentially the result of notEqual(x, 0). We avoid using // kernel notEqual to avoid circular dependency, i.e. binary_utils -> // cast -> notEqual -> binary_utils. 
var zero = tf.util.toTypedArray([0], inputType); var _a = __read(createSimpleBinaryKernelImpl(function (a, b) { return (a !== b) ? 1 : 0; })(shape, [], values, zero, 'bool'), 2), resultData = _a[0], resultShape = _a[1]; return [resultShape, 'bool', resultData]; } throw new Error("Error in Cast: failed to cast ".concat(inputType, " to ").concat(dtype)); } /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ var addImpl = createSimpleBinaryKernelImpl((function (a, b) { return a + b; })); /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ function bincountImpl(xVals, weightsVals, weightsDtype, weightsShape, size) { var weightsSize = tf.util.sizeFromShape(weightsShape); var outVals = tf.util.makeZerosTypedArray(size, weightsDtype); for (var i = 0; i < xVals.length; i++) { var value = xVals[i]; if (value < 0) { throw new Error('Input x must be non-negative!'); } if (value >= size) { continue; } if (weightsSize > 0) { outVals[value] += weightsVals[i]; } else { outVals[value] += 1; } } return outVals; } function bincountReduceImpl(xBuf, weightsBuf, size, binaryOutput) { if (binaryOutput === void 0) { binaryOutput = false; } var numRows = xBuf.shape[0]; var numCols = xBuf.shape[1]; var outBuf = tf.buffer([numRows, size], weightsBuf.dtype); for (var i = 0; i < numRows; i++) { for (var j = 0; j < numCols; j++) { var value = xBuf.get(i, j); if (value < 0) { throw new Error('Input x must be non-negative!'); } if (value >= size) { continue; } if (binaryOutput) { outBuf.set(1, i, value); } else { if (weightsBuf.size > 0) { outBuf.set(outBuf.get(i, value) + weightsBuf.get(i, j), i, value); } else { outBuf.set(outBuf.get(i, value) + 1, i, value); } } } } return outBuf; } /** * @license * Copyright 2023 Google LLC. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ var bitwiseAndImpl = createSimpleBinaryKernelImpl((function (a, b) { return a & b; })); /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ /** * Template that creates implementation for unary op. */ function createSimpleUnaryImpl(op) { return function (values, dtype, attrs) { var newValues = tf.util.getArrayFromDType(dtype, values.length); for (var i = 0; i < values.length; ++i) { newValues[i] = op(values[i], attrs); } return newValues; }; } /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the License); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an AS IS BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ var ceilImpl = createSimpleUnaryImpl(function (xi) { return Math.ceil(xi); }); /** * @license * Copyright 2020 Google LLC. All Rights Reserved. 
* Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function concatImpl$1(inputs, outShape, dtype, simplyConcat) { var outVals = tf.util.getArrayFromDType(dtype, tf.util.sizeFromShape(outShape)); if (simplyConcat && dtype !== 'string') { // Use built-in TypedArray.set() method for speed. var offset_1 = 0; inputs.forEach(function (input) { var size = tf.util.sizeFromShape(input.shape); outVals.set(input.vals, offset_1); offset_1 += size; }); } else { var colOffset_1 = 0; inputs.forEach(function (input) { var decodedData = dtype === 'string' ? tf.backend_util.fromUint8ToStringArray(input.vals) : input.vals; var tIdx = 0; for (var row = 0; row < input.shape[0]; ++row) { var resIdx = row * outShape[1] + colOffset_1; for (var col = 0; col < input.shape[1]; ++col) { outVals[resIdx + col] = decodedData[tIdx++]; } } colOffset_1 += input.shape[1]; }); } return outVals; } /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
* See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ var equalImpl = createSimpleBinaryKernelImpl(function (a, b) { return (a === b) ? 1 : 0; }); /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the License); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an AS IS BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ var expImpl = createSimpleUnaryImpl(function (xi) { return Math.exp(xi); }); /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the License); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an AS IS BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ var expm1Impl = createSimpleUnaryImpl(function (xi) { return Math.expm1(xi); }); /** * @license * Copyright 2020 Google LLC. All Rights Reserved. 
* Licensed under the Apache License, Version 2.0 (the License); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an AS IS BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ var floorImpl = createSimpleUnaryImpl(function (xi) { return Math.floor(xi); }); function gatherNdImpl(indicesData, paramsBuf, dtype, numSlices, sliceRank, sliceSize, strides, paramsShape, paramsSize) { var outBuf = tf.buffer([numSlices, sliceSize], dtype); for (var i = 0; i < numSlices; i++) { var index = []; var flattenIndex = 0; for (var j = 0; j < sliceRank; j++) { var dim = indicesData[i * sliceRank + j]; flattenIndex += dim * strides[j]; index.push(dim); } if (flattenIndex < 0 || flattenIndex >= paramsSize / sliceSize) { throw new Error("Invalid indices: ".concat(index, " does not index into ").concat(paramsShape)); } for (var k = 0; k < sliceSize; k++) { outBuf.values[i * sliceSize + k] = paramsBuf.get.apply(paramsBuf, __spreadArray([], __read(paramsBuf.indexToLoc(flattenIndex * sliceSize + k)), false)); } } return outBuf; } /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
* See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function gatherV2Impl(xBuf, indicesBuf, flattenOutputShape) { var outBuf = tf.buffer(flattenOutputShape, xBuf.dtype); for (var i = 0; i < outBuf.size; ++i) { var newLoc = outBuf.indexToLoc(i); var originalLoc = newLoc.slice(); var batchIdx = originalLoc[0]; var indicesIdx = originalLoc[2]; var indicesIndex = indicesBuf.locToIndex([batchIdx, indicesIdx]); originalLoc[2] = indicesBuf.values[indicesIndex]; var originalIndex = xBuf.locToIndex(originalLoc); if (0 <= originalIndex && originalIndex < xBuf.values.length) { outBuf.values[i] = xBuf.values[originalIndex]; } // Else, index is out of bounds, so leave the default zero val in outBuf. } return outBuf; } /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ var greaterImpl = createSimpleBinaryKernelImpl(function (a, b) { return (a > b) ? 1 : 0; }); /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ var greaterEqualImpl = createSimpleBinaryKernelImpl(function (a, b) { return (a >= b) ? 1 : 0; }); /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ var lessImpl = createSimpleBinaryKernelImpl(function (a, b) { return (a < b) ? 1 : 0; }); /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ var lessEqualImpl = createSimpleBinaryKernelImpl(function (a, b) { return (a <= b) ? 1 : 0; }); /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function linSpaceImpl(start, stop, num) { var step = (stop - start) / (num - 1); var values = tf.util.makeZerosTypedArray(num, 'float32'); values[0] = start; for (var i = 1; i < values.length; i++) { values[i] = values[i - 1] + step; } return values; } /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the License); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an AS IS BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ var logImpl = createSimpleUnaryImpl(function (xi) { return Math.log(xi); }); /** * @license * Copyright 2020 Google LLC. All Rights Reserved. 
* Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function maxImpl$1(aVals, reduceSize, outShape, dtype) { var vals = tf.util.getTypedArrayFromDType(dtype, tf.util.sizeFromShape(outShape)); for (var i = 0; i < vals.length; ++i) { var offset = i * reduceSize; var max = aVals[offset]; for (var j = 0; j < reduceSize; ++j) { var value = aVals[offset + j]; if (Number.isNaN(value) || value > max) { // comparison with NaN always return false max = value; } } vals[i] = max; } return vals; } /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ var maximumImpl = createSimpleBinaryKernelImpl((function (aValue, bValue) { return Math.max(aValue, bValue); })); /** * @license * Copyright 2020 Google LLC. All Rights Reserved. 
* Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ var minimumImpl = createSimpleBinaryKernelImpl((function (aValue, bValue) { return Math.min(aValue, bValue); })); /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ var multiplyImpl = createSimpleBinaryKernelImpl((function (aValue, bValue) { return aValue * bValue; })); function negImpl(xVals, xShape, xDtype) { var minusOne = tf.util.createScalarValue(-1, xDtype); return multiplyImpl([], xShape, minusOne, xVals, xDtype); } /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ var notEqualImpl = createSimpleBinaryKernelImpl((function (a, b) { return (a !== b) ? 1 : 0; })); /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function transposeImpl$1(xVals, xShape, dtype, perm, newShape) { var xRank = xShape.length; var xSize = tf.util.sizeFromShape(xShape); var xStrides = tf.util.computeStrides(xShape); var newStrides = tf.util.computeStrides(newShape); var result = tf.util.getTypedArrayFromDType(dtype, tf.util.sizeFromShape(newShape)); for (var i = 0; i < xSize; ++i) { var loc = tf.util.indexToLoc(i, xRank, xStrides); // Permute location. 
var newLoc = new Array(loc.length); for (var i_1 = 0; i_1 < newLoc.length; i_1++) { newLoc[i_1] = loc[perm[i_1]]; } var newIndex = tf.util.locToIndex(newLoc, xRank, newStrides); result[newIndex] = xVals[i]; } return result; } function prodImpl(xShape, xDtype, xVals, reductionAxes) { var _a = __read(tf.backend_util.computeOutAndReduceShapes(xShape, reductionAxes), 2), outShape = _a[0], reduceShape = _a[1]; var outDtype = tf.upcastType(xDtype, 'int32'); var outVals = tf.util.makeZerosTypedArray(tf.util.sizeFromShape(outShape), outDtype); var reduceSize = tf.util.sizeFromShape(reduceShape); for (var i = 0; i < outVals.length; ++i) { var offset = i * reduceSize; var prod_1 = 1; for (var j = 0; j < reduceSize; ++j) { prod_1 *= xVals[offset + j]; } outVals[i] = prod_1; } return { outVals: outVals, outShape: outShape, outDtype: outDtype }; } function validateIndices(indices, indicesShape, numParams) { indices.forEach(function (index, i) { if (index < 0 || index >= numParams) { var locString = tf.util.indexToLoc(i, indicesShape.length, tf.util.computeStrides(indicesShape)) .join(','); throw new Error("indices[".concat(locString, "] = ").concat(index, " is not in [0, ").concat(numParams, ")")); } }); } function validateSplits(paramsNestedSplits, numParamsDenseValues) { // Validate for (var dim = 0; dim < paramsNestedSplits.length; ++dim) { var splits = paramsNestedSplits[dim]; var lastSplit = (dim === paramsNestedSplits.length - 1) ? numParamsDenseValues : paramsNestedSplits[dim + 1].length; if (splits.length === 0) { throw new Error('Ragged splits may not be empty'); } if (splits[0] < 0) { throw new Error('Ragged splits must be non-negative'); } if (splits[splits.length - 1] > lastSplit) { throw new Error('Ragged splits must not point past values'); } for (var i = 1; i < splits.length; ++i) { if (splits[i - 1] > splits[i]) { throw new Error('Ragged splits must be sorted in ascending order'); } } } } // Construct the `splits` output tensors, encoded using a nested vector. 
// Also find the slices of values that need to be copied, and store them
// in `valueSlices`. The total number of values that will be copied (which
// we need for allocating the output values tensor) is stored in `numValues`.
//
// Returns `{outSplits, valueSlices, numValues}`:
// - `outSplits` holds one plain number array per output splits tensor.
// - `valueSlices` holds `[start, limit)` ranges into the params dense values.
function makeSplits(indices, indicesShape, paramsNestedSplits, numParamsDenseValues) {
    var valueSlices = [];
    var numValues = 0;
    var numSplits = indicesShape.length - 1 + paramsNestedSplits.length;
    var outSplits = new Array(numSplits).fill(null).map(function () {
        return [0];
    });
    validateSplits(paramsNestedSplits, numParamsDenseValues);
    // Add `splits` that come from all but the last dimension of the dense
    // Tensor `indices`. In particular, for each dimension D, we add a
    // splits tensor whose values are:
    //   range(reduceProd(indices.shape[:D+1]) + 1) * indices.shape[D+1]
    // E.g., if indices.shape=[2, 3, 4] then we will add splits tensors:
    //   [0, 3, 6]                    # length=2+1, stride=3
    //   [0, 4, 8, 12, 16, 20, 24]    # length=2*3+1, stride=4
    var nrows = 1;
    for (var dim = 0; dim < indicesShape.length - 1; ++dim) {
        nrows *= indicesShape[dim];
        var rowLength = indicesShape[dim + 1];
        for (var i = 1; i < nrows + 1; ++i) {
            outSplits[dim].push(i * rowLength);
        }
    }
    // Add `splits` that come from `paramsNestedSplits`. Starting with the
    // outermost ragged dimension (i.e., the first `splits` tensor), we work
    // our way in, finding the range of values that should be copied. As we
    // go, we update the output `splits` for each dimension with the appropriate
    // values. In particular, the *lengths* of the slices from `param_splits`
    // should be copied to generate corresponding slice lengths in the output
    // splits. E.g., if we are copying a ragged row with length 4, then we
    // should add a new split point to outSplits that is 4 greater than the
    // previous split point in outSplits.
    for (var n = 0; n < indices.length; ++n) {
        var start = indices[n];
        var limit = indices[n] + 1;
        // Copy splits.
        for (var level = 0; level < paramsNestedSplits.length; ++level) {
            var splits = paramsNestedSplits[level];
            var outDim = level + indicesShape.length - 1;
            if (outDim >= 0) {
                var outSplitsOutDim = outSplits[outDim];
                // Shift the copied split points so they continue from the last
                // split point already written for this output dimension.
                var delta = outSplitsOutDim[outSplitsOutDim.length - 1] - splits[start];
                for (var j = start; j < limit; ++j) {
                    outSplitsOutDim.push(splits[j + 1] + delta);
                }
            }
            // Descend one ragged level: entries [start, limit) of this level
            // map to entries [splits[start], splits[limit]) of the next one
            // (or of the dense values after the innermost level).
            start = splits[start];
            limit = splits[limit];
        }
        if (limit !== start) {
            valueSlices.push([start, limit]);
            numValues += limit - start;
        }
    }
    return { outSplits: outSplits, valueSlices: valueSlices, numValues: numValues };
}
// Copies each plain-array splits vector produced by makeSplits into an array
// allocated with tf.util.getArrayFromDType('int32', ...), one per input entry.
function getSplits(outSplits) {
    var splitsOut = [];
    for (var i = 0; i < outSplits.length; ++i) {
        var source = outSplits[i];
        var splits = tf.util.getArrayFromDType('int32', source.length);
        for (var j = 0; j < source.length; ++j) {
            splits[j] = source[j];
        }
        splitsOut.push(splits);
    }
    return splitsOut;
}
// Collapses `orig` to exactly `numOutDims` dimensions. Missing dims are padded
// with 1; all trailing dims are multiplied into the last kept dimension.
// E.g. ([2, 3, 4], 2) -> [2, 12] and ([5], 2) -> [5, 1].
function computeFlatOuterDims(orig, numOutDims) {
    var outDims = orig.slice(0, numOutDims);
    while (outDims.length < numOutDims) {
        outDims.push(1);
    }
    for (var inDim = numOutDims; inDim < orig.length; inDim++) {
        outDims[numOutDims - 1] *= orig[inDim];
    }
    return outDims;
}
// For each slice `[start, limit)` in `valueSlices`, append rows
// `paramsDenseValues[start:limit]` (limit exclusive) to `values`. `valueSize`
// indicates the number of scalars contained in each value paramsDenseValues[i].
function writeValueSlices(paramsDenseValues, paramsDenseValuesShape, valueSlices, valueSize, values, valuesShape) { var e_1, _a; var denseM = computeFlatOuterDims(paramsDenseValuesShape, 2)[1]; var valuesM = computeFlatOuterDims(valuesShape, 2)[1]; var outPos = 0; try { for (var valueSlices_1 = __values(valueSlices), valueSlices_1_1 = valueSlices_1.next(); !valueSlices_1_1.done; valueSlices_1_1 = valueSlices_1.next()) { var slice = valueSlices_1_1.value; for (var i = slice[0]; i < slice[1]; ++i) { for (var j = 0; j < valueSize; ++j) { values[outPos * valuesM + j] = paramsDenseValues[i * denseM + j]; } ++outPos; } } } catch (e_1_1) { e_1 = { error: e_1_1 }; } finally { try { if (valueSlices_1_1 && !valueSlices_1_1.done && (_a = valueSlices_1.return)) _a.call(valueSlices_1); } finally { if (e_1) throw e_1.error; } } } function getValues(paramsDenseValues, paramsDenseValuesShape, paramsDenseValuesDType, valueSlices, numValues) { var valuesShape = paramsDenseValuesShape.slice(); valuesShape[0] = numValues; var valuesOut = tf.util.getArrayFromDType(paramsDenseValuesDType, tf.util.sizeFromShape(valuesShape)); var numElements = paramsDenseValues.length; var valueSize = numElements === 0 ? 
0 : (numElements / paramsDenseValuesShape[0]); writeValueSlices(paramsDenseValues, paramsDenseValuesShape, valueSlices, valueSize, valuesOut, valuesShape); return [valuesOut, valuesShape]; } function raggedGatherImpl(paramsNestedSplits, paramsNestedSplitsShapes, paramsDenseValues, paramsDenseValuesShape, paramsDenseValuesDType, indices, indicesShape, outputRaggedRank) { if (paramsNestedSplits.length === 0) { throw new Error('paramsNestedSplits must be non empty'); } if (paramsNestedSplitsShapes[0].length === 0) { throw new Error('Split tensors must not be scalars'); } var numParams = paramsNestedSplitsShapes[0][0] - 1; validateIndices(indices, indicesShape, numParams); if (paramsDenseValuesShape.length === 0) { throw new Error('params.rank must be nonzero'); } var numParamsDenseValues = paramsDenseValuesShape[0]; // Calculate the `splits`, and store the value slices that we need to // copy in `valueSlices`. var _a = makeSplits(indices, indicesShape, paramsNestedSplits, numParamsDenseValues), outSplits = _a.outSplits, valueSlices = _a.valueSlices, numValues = _a.numValues; // Write the output tensors. var outputNestedSplits = getSplits(outSplits); var outputDenseValues = getValues(paramsDenseValues, paramsDenseValuesShape, paramsDenseValuesDType, valueSlices, numValues); return [outputNestedSplits, outputDenseValues[0], outputDenseValues[1]]; } /** * @license * Copyright 2022 Google LLC. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ var INT32_MAX = 2147483647; function raggedRangeImpl(starts, startsShape, startsDType, limits, limitsShape, deltas, deltasShape) { // Check input tensor shapes. if (startsShape.length > 1) { throw new Error('starts must be a scalar or vector'); } if (limitsShape.length > 1) { throw new Error('limits must be a scalar or vector'); } if (deltasShape.length > 1) { throw new Error('deltas must be a scalar or vector'); } // Determine which tensors we need to broadcast. var broadcastStarts = startsShape.length === 0; var broadcastLimits = limitsShape.length === 0; var broadcastDeltas = deltasShape.length === 0; // nRows (number of output rows) is the size of the non-broadcast inputs, // or 1 if all inputs are scalars. var inSizes = []; if (!broadcastStarts) { inSizes.push(startsShape[0]); } if (!broadcastLimits) { inSizes.push(limitsShape[0]); } if (!broadcastDeltas) { inSizes.push(deltasShape[0]); } for (var i = 1; i < inSizes.length; ++i) { if (inSizes[i] !== inSizes[i - 1]) { throw new Error('starts, limits, and deltas must have the same shape'); } } var nRows = inSizes.length === 0 ? 1 : inSizes[0]; // Construct the rtNestedSplits tensor. var rtNestedSplits = tf.util.getArrayFromDType('int32', nRows + 1); rtNestedSplits[0] = 0; for (var row = 0; row < nRows; ++row) { var start = broadcastStarts ? starts[0] : starts[row]; var limit = broadcastLimits ? limits[0] : limits[row]; var delta = broadcastDeltas ? deltas[0] : deltas[row]; if (delta === 0) { throw new Error('Requires delta != 0'); } var size = // The number of elements in the specified range. void 0; // The number of elements in the specified range. 
if (((delta > 0) && (limit < start)) || ((delta < 0) && (limit > start))) { size = 0; } else { size = Math.ceil(Math.abs((limit - start) / delta)); if (size > INT32_MAX) { throw new Error("Requires ((limit - start) / delta) <= ".concat(INT32_MAX)); } } rtNestedSplits[row + 1] = rtNestedSplits[row] + size; } var nVals = rtNestedSplits[nRows]; // Construct the rtDenseValues tensor. var rtDenseValues = tf.util.getArrayFromDType(startsDType, nVals); var valueIndex = 0; for (var row = 0; row < nRows; ++row) { var rowSize = rtNestedSplits[row + 1] - rtNestedSplits[row]; var value = broadcastStarts ? starts[0] : starts[row]; var delta = broadcastDeltas ? deltas[0] : deltas[row]; for (var i = 0; i < rowSize; ++i) { rtDenseValues[valueIndex++] = value; value += delta; } } return [rtNestedSplits, rtDenseValues]; } var RowPartitionType = tf.backend_util.RowPartitionType; // Based on // https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/kernels/ragged_tensor_to_tensor_op.cc var RaggedTensorToTensorOp = /** @class */ (function () { function RaggedTensorToTensorOp(shape, shapeShape, values, valuesShape, valuesDType, defaultValue, defaultValueShape, rowPartitionValues, rowPartitionValuesShapes, rowPartitionTypeStrings) { this.shape = shape; this.shapeShape = shapeShape; this.values = values; this.valuesShape = valuesShape; this.valuesDType = valuesDType; this.defaultValue = defaultValue; this.defaultValueShape = defaultValueShape; this.rowPartitionValues = rowPartitionValues; this.rowPartitionValuesShapes = rowPartitionValuesShapes; this.rowPartitionTypes = tf.backend_util.getRowPartitionTypesHelper(rowPartitionTypeStrings); this.raggedRank = tf.backend_util.getRaggedRank(this.rowPartitionTypes); } RaggedTensorToTensorOp.prototype.getRowPartitionTypeByDimension = function (dimension) { if (this.rowPartitionTypes[0] === RowPartitionType.FIRST_DIM_SIZE) { return this.rowPartitionTypes[dimension + 1]; } else { return this.rowPartitionTypes[dimension]; } }; // 
Returns the relationship between dimension and dimension + 1. RaggedTensorToTensorOp.prototype.getRowPartitionTensor = function (dimension) { if (this.rowPartitionTypes[0] === RowPartitionType.FIRST_DIM_SIZE) { return this.rowPartitionValues[dimension + 1]; } else { return this.rowPartitionValues[dimension]; } }; RaggedTensorToTensorOp.prototype.getMaxWidth = function (dimension) { var rowPartitionTensor = this.getRowPartitionTensor(dimension - 1); switch (this.getRowPartitionTypeByDimension(dimension - 1)) { case RowPartitionType.VALUE_ROWIDS: return RaggedTensorToTensorOp.getMaxWidthValueRowID(rowPartitionTensor); case RowPartitionType.ROW_SPLITS: return RaggedTensorToTensorOp.getMaxWidthRowSplit(rowPartitionTensor); default: throw new Error("Cannot handle partition type ".concat(RowPartitionType[this.getRowPartitionTypeByDimension(dimension - 1)])); } }; RaggedTensorToTensorOp.getMaxWidthRowSplit = function (rowSplit) { var tensorLength = rowSplit.length; if (tensorLength === 0 || tensorLength === 1) { return 0; } var maxWidth = 0; for (var i = 0; i < tensorLength - 1; ++i) { var currentWidth = rowSplit[i + 1] - rowSplit[i]; if (currentWidth > maxWidth) { maxWidth = currentWidth; } } return maxWidth; }; RaggedTensorToTensorOp.getMaxWidthValueRowID = function (valueRowIds) { var indexLength = valueRowIds.length; if (indexLength === 0) { return 0; } var firstEqualIndex = 0; var firstEqualIndexValue = valueRowIds[0]; var maxWidth = 0; for (var i = 1; i < indexLength; ++i) { var value = valueRowIds[i]; if (value !== firstEqualIndexValue) { firstEqualIndexValue = value; maxWidth = Math.max(i - firstEqualIndex, maxWidth); firstEqualIndex = i; } } return Math.max(indexLength - firstEqualIndex, maxWidth); }; RaggedTensorToTensorOp.prototype.tensorShapeFromTensor = function (t, tShape, isPartial) { if (isPartial === void 0) { isPartial = true; } if (tShape.length === 0) { if (t[0] === -1) { return []; } throw new Error("The only valid scalar shape tensor is the fully 
unknown shape specified as -1."); } // MakePartialShape/MakeShapeHelper. return makeShape(t, isPartial); }; RaggedTensorToTensorOp.prototype.calculateOutputSize = function (firstDim) { var valueShape = this.valuesShape; var defaultValueShape = this.defaultValueShape; tf.backend_util.validateDefaultValueShape(defaultValueShape, valueShape); var shape = this.tensorShapeFromTensor(this.shape, this.shapeShape); var outputShape = tf.backend_util.combineRaggedTensorToTensorShapes(this.raggedRank, shape, valueShape); var result = outputShape; if (result[0] < 0) { result[0] = firstDim; } for (var i = 1; i <= this.raggedRank; ++i) { if (result[i] < 0) { result[i] = this.getMaxWidth(i); } } return result; }; /** * The outputIndex represents the index in the output tensor * where the first element of a particular dimension would be written. * If it is -1, it indicates that the index is out of scope. * Example, given firstDimension = 10, firstDimensionOutput = 6, * and outputIndexMultiplier = 100: * result = [0 100 200 300 400 500 -1 -1 -1 -1] * If firstDimensionOutput = 11 instead, then: * result = [0 100 200 300 400 500 600 700 800 900] */ RaggedTensorToTensorOp.prototype.calculateFirstParentOutputIndex = function (firstDimension, outputIndexMultiplier, firstDimensionOutput) { var minDimension = Math.min(firstDimension, firstDimensionOutput); var result = []; var currentOutputIndex = 0; for (var i = 0; i < minDimension; ++i, currentOutputIndex += outputIndexMultiplier) { result.push(currentOutputIndex); } for (var i = minDimension; i < firstDimension; ++i) { result.push(-1); } tf.util.assert(result.length === firstDimension, function () { return 'Final length of result must be equal to firstDimension.'; }); return result; }; RaggedTensorToTensorOp.prototype.calculateOutputIndexRowSplit = function (rowSplit, parentOutputIndex, outputIndexMultiplier, outputSize) { var rowSplitSize = rowSplit.length; var result = []; for (var i = 0; i < rowSplitSize - 1; ++i) { var rowLength = 
rowSplit[i + 1] - rowSplit[i]; var realLength = Math.min(outputSize, rowLength); var parentOutputIndexCurrent = parentOutputIndex[i]; if (parentOutputIndexCurrent === -1) { realLength = 0; } for (var j = 0; j < realLength; ++j) { result.push(parentOutputIndexCurrent); parentOutputIndexCurrent += outputIndexMultiplier; } for (var j = 0; j < rowLength - realLength; ++j) { result.push(-1); } } if (rowSplitSize > 0 && result.length !== rowSplit[rowSplitSize - 1]) { throw new Error('Invalid row split size.'); } return result; }; // Calculate the output index of the first element of a list. // The parentOutputIndex is the same computation for the previous list. // -1 indicates an element or list that is out of range. // The outputIndexMultiplier is the number of output indices one moves // forward for each column. // E.g., given: // valueRowIds:[0 1 2 2 2 3 5 5 6] // parentOutputIndex:[1000 1100 2000 2100 -1 3000 4000] // outputIndexMultiplier: 10 // outputSize: 2 // You get: // result = [1000 1100 2000 2010 -1 2100 -1 -1 3000] // result[0] = parentOutputIndex[valueRowIds[0]] // result[1] = parentOutputIndex[valueRowIds[1]] // result[2] = parentOutputIndex[valueRowIds[2]] // result[3] = parentOutputIndex[valueRowIds[2] + 10] // result[4] = -1 because it is the third element the size is 2. 
// result[5] = parentOutputIndex[valueRowIds[3]] // result[6] = -1 because parentOutputIndex[valueRowIds[6]] == -1 // result[7] = -1 because parentOutputIndex[valueRowIds[6]] == -1 // result[8] = parentOutputIndex[valueRowIds[7]] RaggedTensorToTensorOp.prototype.calculateOutputIndexValueRowID = function (valueRowIds, parentOutputIndex, outputIndexMultiplier, outputSize) { var indexSize = valueRowIds.length; var result = []; if (indexSize === 0) { return []; } var currentOutputColumn = 0; var currentValueRowId = valueRowIds[0]; if (currentValueRowId >= parentOutputIndex.length) { throw new Error("Got currentValueRowId=".concat(currentValueRowId, ", which is not less than ").concat(parentOutputIndex.length)); } var currentOutputIndex = parentOutputIndex[currentValueRowId]; result.push(currentOutputIndex); for (var i = 1; i < indexSize; ++i) { var nextValueRowId = valueRowIds[i]; if (nextValueRowId === currentValueRowId) { if (currentOutputIndex >= 0) { ++currentOutputColumn; if (currentOutputColumn < outputSize) { currentOutputIndex += outputIndexMultiplier; } else { currentOutputIndex = -1; } } } else { currentOutputColumn = 0; currentValueRowId = nextValueRowId; if (nextValueRowId >= parentOutputIndex.length) { throw new Error("Got nextValueRowId=".concat(nextValueRowId, " which is not less than ").concat(parentOutputIndex.length)); } currentOutputIndex = parentOutputIndex[nextValueRowId]; } result.push(currentOutputIndex); } if (result.length !== valueRowIds.length) { throw new Error('Invalid row ids.'); } return result; }; RaggedTensorToTensorOp.prototype.calculateOutputIndex = function (dimension, parentOutputIndex, outputIndexMultiplier, outputSize) { var rowPartitionTensor = this.getRowPartitionTensor(dimension); var partitionType = this.getRowPartitionTypeByDimension(dimension); switch (partitionType) { case RowPartitionType.VALUE_ROWIDS: return this.calculateOutputIndexValueRowID(rowPartitionTensor, parentOutputIndex, outputIndexMultiplier, outputSize); case 
RowPartitionType.ROW_SPLITS: if (rowPartitionTensor.length - 1 > parentOutputIndex.length) { throw new Error("Row partition size is greater than output size: ".concat(rowPartitionTensor.length - 1, " > ").concat(parentOutputIndex.length)); } return this.calculateOutputIndexRowSplit(rowPartitionTensor, parentOutputIndex, outputIndexMultiplier, outputSize); default: throw new Error("Unsupported partition type: ".concat(RowPartitionType[partitionType])); } }; RaggedTensorToTensorOp.prototype.getFirstDimensionSize = function () { var firstPartitionTensor = this.rowPartitionValues[0]; if (this.rowPartitionTypes.length === 0) { throw new Error('No row_partition_types given.'); } var firstPartitionType = this.rowPartitionTypes[0]; switch (firstPartitionType) { case RowPartitionType.FIRST_DIM_SIZE: return firstPartitionTensor[0]; case RowPartitionType.VALUE_ROWIDS: throw new Error('Cannot handle VALUE_ROWIDS in first dimension.'); case RowPartitionType.ROW_SPLITS: return this.rowPartitionValuesShapes[0][0] - 1; default: throw new Error("Cannot handle type ".concat(RowPartitionType[firstPartitionType])); } }; RaggedTensorToTensorOp.prototype.compute = function () { var firstPartitionTensor = this.rowPartitionValues[0]; if (firstPartitionTensor.length <= 0) { throw new Error('Invalid first partition input. ' + 'Tensor requires at least one element.'); } var firstDimension = this.getFirstDimensionSize(); var outputSize = this.calculateOutputSize(firstDimension); var multiplier = new Array(this.raggedRank + 1); multiplier[multiplier.length - 1] = 1; for (var i = multiplier.length - 2; i >= 0; --i) { multiplier[i] = multiplier[i + 1] * outputSize[i + 1]; } // Full size of the tensor. 
var outputShape = makeShape(outputSize, false); var outputTensor = tf.util.getArrayFromDType(this.valuesDType, tf.util.sizeFromShape(outputShape)); var fullSize = multiplier[0] * outputSize[0]; if (fullSize > 0) { var outputIndex = this.calculateFirstParentOutputIndex(firstDimension, multiplier[0], outputSize[0]); for (var i = 1; i <= this.raggedRank; ++i) { var newOutputIndex = this.calculateOutputIndex(i - 1, outputIndex, multiplier[i], outputSize[i]); outputIndex = newOutputIndex; } this.setOutput(this.raggedRank, outputIndex, outputTensor, outputShape); } return [outputShape, outputTensor]; }; RaggedTensorToTensorOp.prototype.setOutput = function (raggedRank, outputIndex, outputTensor, outputShape) { if (outputTensor.length === 0) { return; } var valuesBase = this.values; var outputBase = outputTensor; var elementShape = outputShape.slice(); elementShape = elementShape.slice(raggedRank + 1); var valueElementSize = tf.util.sizeFromShape(elementShape); var outputIndexSize = outputIndex.length; // Broadcast the default value to value_element_size. (We can skip this // if defaultValueTensor.size == 1, since we use fill when that's true.) var defaultValue = this.defaultValue; if (defaultValue.length !== valueElementSize && defaultValue.length !== 1) { var srcShape_1 = this.defaultValueShape; tf.tidy(function () { var defaultValueTensor = tf.reshape(defaultValue, srcShape_1); var bCastDefault = tf.broadcastTo(defaultValueTensor, elementShape); defaultValue = bCastDefault.dataSync(); }); } // Loop through the outputIndex array, finding contiguous regions that // should be copied. Once we find the end of a contiguous region, copy it // and add any necessary padding (with defaultValue). 
var srcStart = 0; // Start of contiguous region (in values) var dstStart = 0; // Destination for contiguous region (in output) var dstEnd = 0; // Destination for contiguous region (in output) for (var srcI = 0; srcI <= outputIndexSize; ++srcI) { // dstI is the destination where the value at srcI should be copied. var dstI = srcI < outputIndexSize ? outputIndex[srcI] : -1; // If we're still in a contiguous region, then update dstEnd go to the // next srcI. if (dstI === dstEnd) { ++dstEnd; continue; } // We found the end of contiguous region. This can be because we found // a gap (dstI > dstEnd), or a source value that shouldn't be copied // because it's out-of-bounds (dstI == -1), or the end of the tensor // (dstI === -1). if (dstStart < dstEnd) { // Copy the contiguous region. var src = valuesBase.subarray(srcStart * valueElementSize); var dst = outputBase.subarray(dstStart * valueElementSize); var nVals = (dstEnd - dstStart) * valueElementSize; copyArray(dst, src, nVals); } // Add any necessary padding (w/ defaultValue). if (srcI >= outputIndexSize) { // We reached the end of values: pad to the end of output. var outputSize = outputTensor.length; dstI = Math.floor(outputSize / valueElementSize); } if (dstI > dstEnd) { if (this.defaultValue.length === 1) { outputBase .subarray(dstEnd * valueElementSize, dstI * valueElementSize) .fill(this.defaultValue[0]); dstEnd = dstI; } else { while (dstI > dstEnd) { var dst = outputBase.slice(dstEnd * valueElementSize); copyArray(dst, defaultValue, valueElementSize); ++dstEnd; } } } // Update indices. if (dstI < 0) { // srcI should be skipped -- leave it out of the contiguous region. srcStart = srcI + 1; dstStart = dstEnd; } else { // srcI should be copied -- include it in the contiguous region. 
srcStart = srcI; dstStart = dstEnd; dstEnd = dstStart + 1; } } }; return RaggedTensorToTensorOp; }()); function copyArray(dst, src, size) { for (var i = 0; i < size; i++) { dst[i] = src[i]; } } function makeShape(shape, isPartial) { var e_1, _a; var out = []; try { for (var shape_1 = __values(shape), shape_1_1 = shape_1.next(); !shape_1_1.done; shape_1_1 = shape_1.next()) { var dim = shape_1_1.value; if (dim < 0) { if (!isPartial) { throw new Error("Dimension ".concat(dim, " must be >= 0")); } if (dim < -1) { throw new Error("Dimension ".concat(dim, " must be >= -1")); } dim = -1; } out.push(dim); } } catch (e_1_1) { e_1 = { error: e_1_1 }; } finally { try { if (shape_1_1 && !shape_1_1.done && (_a = shape_1.return)) _a.call(shape_1); } finally { if (e_1) throw e_1.error; } } return out; } function raggedTensorToTensorImpl(shape, shapesShape, values, valuesShape, valuesDType, defaultValue, defaultValueShape, rowPartitionValues, rowPartitionValuesShapes, rowPartitionTypes) { return new RaggedTensorToTensorOp(shape, shapesShape, values, valuesShape, valuesDType, defaultValue, defaultValueShape, rowPartitionValues, rowPartitionValuesShapes, rowPartitionTypes) .compute(); } /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ function rangeImpl(start, stop, step, dtype) { var sameStartStop = start === stop; var increasingRangeNegativeStep = start < stop && step < 0; var decreasingRangePositiveStep = stop < start && step > 1; if (sameStartStop || increasingRangeNegativeStep || decreasingRangePositiveStep) { return tf.util.makeZerosTypedArray(0, dtype); } var numElements = Math.abs(Math.ceil((stop - start) / step)); var values = tf.util.makeZerosTypedArray(numElements, dtype); if (stop < start && step === 1) { // Auto adjust the step's sign if it hasn't been set // (or was set to 1) step = -1; } values[0] = start; for (var i = 1; i < values.length; i++) { values[i] = values[i - 1] + step; } return values; } /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the License); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an AS IS BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ var rsqrtImpl = createSimpleUnaryImpl(function (xi) { return 1 / Math.sqrt(xi); }); /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function scatterImpl(indices, updates, shape, outputSize, sliceSize, numUpdates, sliceRank, strides, defaultValue, sumDupeIndices) { var flattenShape = [outputSize / sliceSize, sliceSize]; var indicesData = indices.values; var updatesData = updates.values; if (outputSize === 0) { return tf.buffer(shape, updates.dtype); } var outBuf = (defaultValue instanceof tf.TensorBuffer) ? defaultValue : tf.buffer(flattenShape, updates.dtype); if (typeof defaultValue === 'string') { outBuf.values.fill(defaultValue); } else if (typeof defaultValue === 'number') { outBuf.values.fill(defaultValue); } else if (typeof defaultValue === 'boolean') { outBuf.values.fill(+defaultValue); } for (var i = 0; i < numUpdates; i++) { var index = []; var flattenIndex = 0; for (var j = 0; j < sliceRank; j++) { var dim = indicesData[i * sliceRank + j]; index.push(dim); flattenIndex += dim * strides[j]; } if (flattenIndex < 0 || flattenIndex >= outputSize / sliceSize) { throw new Error("Invalid indices: ".concat(index, " does not index into ").concat(shape)); } for (var k = 0; k < sliceSize; k++) { if (sumDupeIndices) { outBuf.values[flattenIndex * sliceSize + k] += updatesData[i * sliceSize + k]; } else { outBuf.values[flattenIndex * sliceSize + k] = updates.rank === 0 ? updatesData[0] : updatesData[i * sliceSize + k]; } } } return outBuf; } /** * @license * Copyright 2020 Google LLC. All Rights Reserved. 
* Licensed under the Apache License, Version 2.0 (the License); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an AS IS BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ var sigmoidImpl = createSimpleUnaryImpl(function (xi) { return 1 / (1 + Math.exp(-xi)); }); function sliceImpl(vals, begin, size, shape, dtype) { var isContinous = tf.slice_util.isSliceContinous(shape, begin, size); var length = tf.util.sizeFromShape(size); var xStrides = tf.util.computeStrides(shape); if (isContinous) { var flatOffset = tf.slice_util.computeFlatOffset(begin, xStrides); if (dtype === 'string') { return vals.slice(flatOffset, flatOffset + length); } return vals.subarray(flatOffset, flatOffset + length); } var decodedData = dtype === 'string' ? tf.backend_util.fromUint8ToStringArray(vals) : vals; var inBuf = tf.buffer(shape, dtype, decodedData); var outBuf = tf.buffer(size, dtype); for (var i = 0; i < outBuf.size; ++i) { var outLoc = outBuf.indexToLoc(i); var inLoc = outLoc.map(function (idx, j) { return idx + begin[j]; }); outBuf.set.apply(outBuf, __spreadArray([inBuf.get.apply(inBuf, __spreadArray([], __read(inLoc), false))], __read(outLoc), false)); } if (dtype === 'string') { return tf.backend_util.fromStringArrayToUint8(outBuf.values); } return outBuf.values; } /** * @license * Copyright 2021 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function sparseFillEmptyRowsImpl(indices, indicesShape, indicesDType, values, valuesDType, denseShape, defaultValue) { var indicesCount = indicesShape[0]; var denseRows = denseShape[0]; var emptyRowIndicator = new Array(denseRows); var reverseIndexMap = new Array(indicesCount); var rank = indicesShape[1]; if (denseRows === 0) { if (indicesCount !== 0) { throw new Error(tf.backend_util.getSparseFillEmptyRowsIndicesDenseShapeMismatch(indicesCount)); } var outputIndices = tf.util.getArrayFromDType(indicesDType, 0); var outputValues = tf.util.getArrayFromDType(valuesDType, 0); return [ outputIndices, [0, rank], outputValues, emptyRowIndicator, reverseIndexMap ]; } var rowsAreOrdered = true; var lastIndicesRow = 0; var csrOffset = new Array(denseRows).fill(0); for (var i = 0; i < indicesCount; ++i) { // indices is a 2d tensor with shape of [N, rank] var row = indices[i * rank]; if (row < 0) { throw new Error(tf.backend_util.getSparseFillEmptyRowsNegativeIndexErrorMessage(i, row)); } if (row >= denseRows) { throw new Error(tf.backend_util.getSparseFillEmptyRowsOutOfRangeIndexErrorMessage(i, row, denseRows)); } ++csrOffset[row]; rowsAreOrdered = rowsAreOrdered && (row >= lastIndicesRow); lastIndicesRow = row; } var allRowsFull = true; for (var row = 0; row < denseRows; ++row) { // csrOffset here describes the number of elements in this dense row var rowEmpty = (csrOffset[row] === 0); emptyRowIndicator[row] = rowEmpty; allRowsFull = allRowsFull && !rowEmpty; // In filled version, each row has 
at least one element. csrOffset[row] = Math.max(csrOffset[row], 1); // Update csrOffset to represent the number of elements up to and // including denseRows + 1: // csrOffset[0] == #{elements of row 0} // csrOffset[1] == #{elements of row 1} + #{elements of row 0} // .. // csrOffset[i] == starting index for elements in row i + 1. if (row > 0) { csrOffset[row] += csrOffset[row - 1]; } } if (allRowsFull && rowsAreOrdered) { var outputIndices = indices; var outputValues = values; for (var i = 0; i < indicesCount; ++i) { reverseIndexMap[i] = i; } return [ outputIndices, [indicesCount, rank], outputValues, emptyRowIndicator, reverseIndexMap ]; } else { var fullIndicesCount = csrOffset[denseRows - 1]; var outputIndices = tf.util.getArrayFromDType(indicesDType, fullIndicesCount * rank); var outputValues = tf.util.getArrayFromDType(valuesDType, fullIndicesCount); var filledCount = new Array(denseRows).fill(0); // Fill in values for rows that are not missing for (var i = 0; i < indicesCount; ++i) { // indices is a 2d tensor with shape of [N, rank] var row = indices[i * rank]; var offset = filledCount[row]; var outputI = ((row === 0) ? 0 : csrOffset[row - 1]) + offset; filledCount[row]++; // Increment the filled count for this row. for (var j = 0; j < rank; ++j) { // indices and outputIndices are 2d tensors with shape of [N, rank] outputIndices[outputI * rank + j] = indices[i * rank + j]; } outputValues[outputI] = values[i]; // We'll need this reverse index map to backprop correctly. reverseIndexMap[i] = outputI; } // Fill in values for rows that are missing for (var row = 0; row < denseRows; ++row) { var rowCount = filledCount[row]; if (rowCount === 0) { // We haven't filled this row var startingIndex = (row === 0) ? 0 : csrOffset[row - 1]; // Remaining index values were set to zero already. // Just need to set the row index in the right location. 
// outputIndices is a 2d tensor with shape of [N, rank] outputIndices[startingIndex * rank + 0] = row; for (var col = 1; col < rank; ++col) { outputIndices[startingIndex * rank + col] = 0; } outputValues[startingIndex] = defaultValue; } } return [ outputIndices, [fullIndicesCount, rank], outputValues, emptyRowIndicator, reverseIndexMap ]; } } /** * @license * Copyright 2021 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function sparseReshapeImpl(inputIndices, inputIndicesShape, inputDType, inputShape, targetShape) { var denseSize = tf.util.sizeFromShape(inputShape); var nnz = inputIndicesShape[0]; var outputRank = targetShape.length; // Compute the output shape. Determine product of specified dimensions, and // find the index of the unspecified one. 
var outputShape = []; var product = 1; var unknownIndex = -1; for (var d = 0; d < outputRank; ++d) { var size = targetShape[d]; if (size === -1) { if (unknownIndex !== -1) { throw new Error(tf.backend_util .getSparseReshapeMultipleNegativeOneOutputDimErrorMessage(unknownIndex, d)); } unknownIndex = d; outputShape.push(1); } else { if (size < 0) { throw new Error(tf.backend_util.getSparseReshapeNegativeOutputDimErrorMessage(d, size)); } product *= size; outputShape.push(size); } } if (unknownIndex !== -1) { if (product <= 0) { throw new Error(tf.backend_util.getSparseReshapeEmptyTensorZeroOutputDimErrorMessage()); } var missing = Math.trunc(denseSize / product); if (product * missing !== denseSize) { throw new Error(tf.backend_util.getSparseReshapeInputOutputMultipleErrorMessage(inputShape, outputShape)); } outputShape[unknownIndex] = missing; } var outputSize = tf.util.sizeFromShape(outputShape); if (outputSize !== denseSize) { throw new Error(tf.backend_util.getSparseReshapeInputOutputMismatchErrorMessage(inputShape, outputShape)); } var inputRank = inputShape.length; var inputStrides = []; if (inputRank > 0) { inputStrides[inputRank - 1] = 1; for (var d = inputRank - 2; d >= 0; --d) { inputStrides[d] = inputStrides[d + 1] * inputShape[d + 1]; } } var outputStrides = []; if (outputRank > 0) { outputStrides[outputRank - 1] = 1; for (var d = outputRank - 2; d >= 0; --d) { outputStrides[d] = outputStrides[d + 1] * outputShape[d + 1]; } } var newIndices = tf.util.getArrayFromDType(inputDType, nnz * outputRank); for (var i = 0; i < nnz; ++i) { var id = 0; for (var j = 0; j < inputRank; ++j) { // inputIndices is a 2d tensor with shape of [nnz, inputRank] id += inputIndices[i * inputRank + j] * inputStrides[j]; } for (var j = 0; j < outputRank; ++j) { // newIndices is a 2d tensor with shape of [nnz, outputRank] newIndices[i * outputRank + j] = Math.trunc(id / outputStrides[j]); id %= outputStrides[j]; } } return [newIndices, [nnz, outputRank], outputShape]; } /** * 
@license * Copyright 2021 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function sparseSegmentReductionImpl(input, inputShape, inputDType, indices, segmentIds, isMean, defaultValue) { if (isMean === void 0) { isMean = false; } if (defaultValue === void 0) { defaultValue = 0; } var numIndices = indices.length; // Flatten the array to two dimensions var inputFlat = [inputShape[0], input.length / inputShape[0]]; var numCol = inputFlat[1]; // Note that the current implementation assumes that segmentIds values are // sorted. var lastSegmentIdPlusOne = numIndices > 0 ? segmentIds[numIndices - 1] + 1 : 0; var outputRows = lastSegmentIdPlusOne; if (outputRows < 0) { throw new Error(tf.backend_util.getSparseSegmentReductionNegativeSegmentIdsErrorMessage()); } var outputShape = inputShape.slice(); outputShape[0] = outputRows; var outputLength = outputShape.reduce(function (product, value) { return product * value; }, 1); // Output array is initialized with the value 0 by default. var output = tf.util.getArrayFromDType(inputDType, outputLength); // Note that we do not initialize the output buffer with a default value, so // we need to explicitly set missing indices to the default value. 
if (numIndices === 0) { if (outputRows > 0) { output.fill(defaultValue); } return [output, outputShape]; } if (outputRows <= 0) { throw new Error(tf.backend_util.getSparseSegmentReductionNegativeSegmentIdsErrorMessage()); } var start = 0, end = 1; // Index from which the output is not initialized. var uninitializedIndex = 0; var outIndex = segmentIds[start]; while (true) { // We initialize nextIndex to 0 to avoid may be uninitialized warning var nextIndex = 0; if (end < numIndices) { nextIndex = segmentIds[end]; if (outIndex === nextIndex) { ++end; continue; } // We have a new segment here. Verify that the segment ids are growing. if (outIndex >= nextIndex) { throw new Error(tf.backend_util .getSparseSegmentReductionNonIncreasingSegmentIdsErrorMessage()); } } if (outIndex < 0 || outIndex >= outputRows) { throw new Error(tf.backend_util.getSparseSegmentReductionSegmentIdOutOfRangeErrorMessage(outIndex, outputRows)); } // If there is a gap between two indices, we need to set that gap to the // default value. if (outIndex > uninitializedIndex) { output.fill(defaultValue, uninitializedIndex * numCol, outIndex * numCol); } for (var i = start; i < end; ++i) { var index = indices[i]; if (index < 0 || index >= inputFlat[0]) { throw new Error(tf.backend_util.getSparseSegmentReductionIndicesOutOfRangeErrorMessage(i, indices[i], inputFlat[0])); } for (var j = 0; j < numCol; j++) { output[outIndex * numCol + j] += input[index * numCol + j]; } } if (isMean) { for (var j = 0; j < numCol; j++) { output[outIndex * numCol + j] /= end - start; } } start = end; ++end; uninitializedIndex = outIndex + 1; outIndex = nextIndex; if (end > numIndices) { break; } } // Fill the gap at the end with the default value. if (uninitializedIndex < outputRows) { output.fill(defaultValue, uninitializedIndex * numCol, outputRows * numCol); } return [output, outputShape]; } /** * @license * Copyright 2020 Google LLC. All Rights Reserved. 
* Licensed under the Apache License, Version 2.0 (the License); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an AS IS BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ var sqrtImpl = createSimpleUnaryImpl(function (xi) { return Math.sqrt(xi); }); /** * @license * Copyright 2023 Google LLC. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ var staticRegexReplaceImpl = createSimpleUnaryImpl(function (x, attrs) { var pattern = attrs.pattern, replaceGlobal = attrs.replaceGlobal, rewrite = attrs.rewrite; // TODO(mattSoulanille): Don't create a regex each time. return x.replace(new RegExp(pattern, replaceGlobal ? 
'g' : ''), rewrite); }); function stridedSliceImpl(outShape, xBuf, strides, begin) { var outBuf = tf.buffer(outShape, xBuf.dtype); for (var i = 0; i < outBuf.size; i++) { var loc = outBuf.indexToLoc(i); var newLoc = new Array(loc.length); for (var j = 0; j < newLoc.length; j++) { newLoc[j] = loc[j] * strides[j] + begin[j]; } outBuf.set.apply(outBuf, __spreadArray([xBuf.get.apply(xBuf, __spreadArray([], __read(newLoc), false))], __read(loc), false)); } return outBuf; } /** * @license * Copyright 2021 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ /** * The StringNGramsOp class creates ngrams from ragged string data. * The constructor contains all attributes related to the operation such as * padding widths and strings, and the compute function can be used to * compute the ngrams for different ragged tensor inputs. 
*/ var StringNGramsOp = /** @class */ (function () { function StringNGramsOp(separator, nGramWidths, leftPad, rightPad, padWidth, preserveShortSequences) { this.separator = tf.util.encodeString(separator); this.nGramWidths = nGramWidths; this.leftPad = tf.util.encodeString(leftPad); this.rightPad = tf.util.encodeString(rightPad); this.padWidth = padWidth; this.preserveShort = preserveShortSequences; } StringNGramsOp.prototype.getPadWidth = function (nGramWidth) { // Ngrams can be padded with either a fixed pad width or a dynamic pad // width depending on the 'padWidth' arg, but in no case should the padding // ever be wider than 'nGramWidth' - 1. return Math.min(this.padWidth < 0 ? nGramWidth - 1 : this.padWidth, nGramWidth - 1); }; StringNGramsOp.prototype.getNumNGrams = function (length, nGramWidth) { var padWidth = this.getPadWidth(nGramWidth); return Math.max(0, ((length + 2 * padWidth) - nGramWidth) + 1); }; StringNGramsOp.prototype.createNGrams = function (data, splitIndex, output, outputStartIndex, numNGrams, nGramWidth) { var _loop_1 = function (nGramIndex) { var padWidth = this_1.getPadWidth(nGramWidth); var leftPadding = Math.max(0, padWidth - nGramIndex); var rightPadding = Math.max(0, padWidth - (numNGrams - (nGramIndex + 1))); var numTokens = nGramWidth - (leftPadding + rightPadding); var dataStartIndex = splitIndex + (leftPadding > 0 ? 0 : nGramIndex - padWidth); // Calculate the total expected size of the nGram so we can reserve the // correct amount of space in the string. var nGramSize = 0; // Size of the left padding. nGramSize += leftPadding * this_1.leftPad.length; // Size of the tokens. for (var n = 0; n < numTokens; ++n) { nGramSize += data[dataStartIndex + n].length; } // Size of the right padding. nGramSize += rightPadding * this_1.rightPad.length; // Size of the separators. var numSeparators = leftPadding + rightPadding + numTokens - 1; nGramSize += numSeparators * this_1.separator.length; // Build the nGram. 
output[outputStartIndex + nGramIndex] = new Uint8Array(nGramSize); var nGram = output[outputStartIndex + nGramIndex]; var nextNGramIndex = 0; var appendToNGram = function (str) { return str.forEach(function (value) { return nGram[nextNGramIndex++] = value; }); }; for (var n = 0; n < leftPadding; ++n) { appendToNGram(this_1.leftPad); appendToNGram(this_1.separator); } // Only output first numTokens - 1 pairs of data and separator for (var n = 0; n < numTokens - 1; ++n) { appendToNGram(data[dataStartIndex + n]); appendToNGram(this_1.separator); } // Handle case when there are no tokens or no right padding as these // can result in consecutive separators. if (numTokens > 0) { // If we have tokens, then output last and then pair each separator // with the right padding that follows, to ensure nGram ends either with // the token or with the right pad. appendToNGram(data[dataStartIndex + numTokens - 1]); for (var n = 0; n < rightPadding; ++n) { appendToNGram(this_1.separator); appendToNGram(this_1.rightPad); } } else { // If we don't have tokens, then the last item inserted into the nGram // has been the separator from the left padding loop above. Hence, // output right pad and separator and make sure to finish with a // padding, not a separator. for (var n = 0; n < rightPadding - 1; ++n) { appendToNGram(this_1.rightPad); appendToNGram(this_1.separator); } appendToNGram(this_1.rightPad); } }; var this_1 = this; for (var nGramIndex = 0; nGramIndex < numNGrams; ++nGramIndex) { _loop_1(nGramIndex); } }; // Data and splits together form the definition of the ragged tensor, // where data is 1 dimensional and contains the values of the tensor // and splits denotes the indices at which each row starts. StringNGramsOp.prototype.compute = function (data, splits) { var _this = this; // Validate that the splits are valid indices into data, only if there are // splits specified. 
var inputDataSize = data.length; var splitsSize = splits.length; if (splitsSize > 0) { var prevSplit = splits[0]; if (prevSplit !== 0) { throw new Error("First split value must be 0, got ".concat(prevSplit)); } for (var i = 1; i < splitsSize; ++i) { var validSplits = splits[i] >= prevSplit; validSplits = validSplits && (splits[i] <= inputDataSize); if (!validSplits) { throw new Error("Invalid split value ".concat(splits[i], ", must be in [").concat(prevSplit, ", ").concat(inputDataSize, "]")); } prevSplit = splits[i]; } if (prevSplit !== inputDataSize) { throw new Error("Last split value must be data size. Expected ".concat(inputDataSize, ", got ").concat(prevSplit)); } } var numBatchItems = splitsSize - 1; var nGramsSplits = tf.util.getArrayFromDType('int32', splitsSize); // If there is no data or size, return an empty ragged tensor. if (inputDataSize === 0 || splitsSize === 0) { var empty = new Array(inputDataSize); for (var i = 0; i <= numBatchItems; ++i) { nGramsSplits[i] = 0; } return [empty, nGramsSplits]; } nGramsSplits[0] = 0; var _loop_2 = function (i) { var length = splits[i] - splits[i - 1]; var numNGrams = 0; this_2.nGramWidths.forEach(function (nGramWidth) { numNGrams += _this.getNumNGrams(length, nGramWidth); }); if (this_2.preserveShort && length > 0 && numNGrams === 0) { numNGrams = 1; } nGramsSplits[i] = nGramsSplits[i - 1] + numNGrams; }; var this_2 = this; for (var i = 1; i <= numBatchItems; ++i) { _loop_2(i); } var nGrams = new Array(nGramsSplits[numBatchItems]); var _loop_3 = function (i) { var splitIndex = splits[i]; var outputStartIdx = nGramsSplits[i]; this_3.nGramWidths.forEach(function (nGramWidth) { var length = splits[i + 1] - splits[i]; var numNGrams = _this.getNumNGrams(length, nGramWidth); _this.createNGrams(data, splitIndex, nGrams, outputStartIdx, numNGrams, nGramWidth); outputStartIdx += numNGrams; }); // If we're preserving short sequences, check to see if no sequence was // generated by comparing the current output start idx to 
the original // one (nGramSplitsdata). If no ngrams were generated, then they will // be equal (since we increment outputStartIdx by numNGrams every // time we create a set of ngrams.) if (this_3.preserveShort && outputStartIdx === nGramsSplits[i]) { var dataLength = splits[i + 1] - splits[i]; // One legitimate reason to not have any ngrams when this.preserveShort // is true is if the sequence itself is empty. In that case, move on. if (dataLength === 0) { return "continue"; } // We don't have to worry about dynamic padding sizes here: if padding // was dynamic, every sequence would have had sufficient padding to // generate at least one nGram. var nGramWidth = dataLength + 2 * this_3.padWidth; var numNGrams = 1; this_3.createNGrams(data, splitIndex, nGrams, outputStartIdx, numNGrams, nGramWidth); } }; var this_3 = this; for (var i = 0; i < numBatchItems; ++i) { _loop_3(i); } return [nGrams, nGramsSplits]; }; return StringNGramsOp; }()); function stringNGramsImpl(data, dataSplits, separator, nGramWidths, leftPad, rightPad, padWidth, preserveShortSequences) { return new StringNGramsOp(separator, nGramWidths, leftPad, rightPad, padWidth, preserveShortSequences) .compute(data, dataSplits); } /** * @license * Copyright 2021 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ function split(str, delimiters, skipEmpty, result) { if (!str.length) { return; } // When the delimiter is empty, the input is split into individual characters. if (delimiters.length === 0) { for (var i = 0; i < str.length; ++i) { result.push(str.subarray(i, i + 1)); } return; } // When there is one delimiter, the input is split only at that delimiter. if (delimiters.length === 1) { var delimiter = delimiters[0]; var f = str.indexOf(delimiter); while (f !== -1) { var token = str.subarray(0, f); if (!skipEmpty || token.length !== 0) { result.push(token); } str = str.subarray(f + 1); f = str.indexOf(delimiter); } if (!skipEmpty || str.length !== 0) { result.push(str); } return; } // When there are multiple delimiters, the input is split at every instance // one of the delimiters appears. var tokenStart = 0; for (var i = 0; i < str.length + 1; i++) { if ((i === str.length) || (delimiters.indexOf(str[i]) !== -1)) { var token = str.subarray(tokenStart, i); if (!skipEmpty || token.length !== 0) { result.push(token); } tokenStart = i + 1; } } } function stringSplitImpl(input, delimiter, skipEmpty) { var batchSize = input.length; // Empty delimiter means split the input character by character. 
var tokens = []; var outputSize = 0; var maxNumEntries = 0; var numIndices = new Array(batchSize); for (var i = 0; i < batchSize; ++i) { var prevTokensLength = tokens.length; split(input[i], delimiter, skipEmpty, tokens); var nEntries = tokens.length - prevTokensLength; numIndices[i] = nEntries; outputSize += nEntries; maxNumEntries = Math.max(maxNumEntries, nEntries); } var indices = tf.util.getArrayFromDType('int32', outputSize * 2); var values = new Array(outputSize); var shape = [batchSize, maxNumEntries]; var c = 0; for (var i = 0; i < batchSize; ++i) { for (var j = 0; j < numIndices[i]; ++j) { // indices is a 2d tensor with shape of [outputSize, 2] indices[c * 2] = i; indices[c * 2 + 1] = j; values[c] = tokens[c]; ++c; } } return [indices, values, shape]; } /** * @license * Copyright 2021 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function stringToHashBucketFastImpl(input, numBuckets) { var output = tf.util.getArrayFromDType('int32', input.length); for (var i = 0; i < input.length; ++i) { output[i] = tf.util.fingerPrint64(input[i]).modulo(numBuckets).getLowBitsUnsigned(); } return output; } /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ var subImpl = createSimpleBinaryKernelImpl((function (aValue, bValue) { return aValue - bValue; })); /** * @license * Copyright 2019 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ /** * An implementation of the tile kernel shared between webgl and cpu for string * tensors only. 
*/ function tileImpl(xBuf, reps) { var newShape = new Array(xBuf.rank); for (var i = 0; i < newShape.length; i++) { newShape[i] = xBuf.shape[i] * reps[i]; } var result = tf.buffer(newShape, xBuf.dtype); for (var i = 0; i < result.values.length; ++i) { var newLoc = result.indexToLoc(i); var originalLoc = new Array(xBuf.rank); for (var j = 0; j < originalLoc.length; j++) { originalLoc[j] = newLoc[j] % xBuf.shape[j]; } var originalIndex = xBuf.locToIndex(originalLoc); result.values[i] = xBuf.values[originalIndex]; } return result; } var comparePair = function (a, b) { var valueDiff = b.value - a.value; return valueDiff === 0 ? a.index - b.index : valueDiff; }; /** * Partitions array where all elements smaller than the (k+1) smallest element * are found to the left of it, and all larger to the right of it. * Based on the Floyd-Rivest Algorithm, ref: * https://en.wikipedia.org/wiki/Floyd%E2%80%93Rivest_algorithm * @param array: Array to partition * @param left: Left index for the interval * @param right: Right index for the interval * @param k: Desired index value, where array[k] is the (k+1)th smallest element * when left = 0 */ function select$1(array, k, left, right) { if (left === void 0) { left = 0; } if (right === void 0) { right = array.length - 1; } while (right > left) { // Use select recursively to sample a smaller set of size s // the arbitrary constants 600 and 0.5 are used in the original // version to minimize execution time. 
if (right - left > 600) { var n = right - left + 1; var i_1 = k - left + 1; var z = Math.log(n); var s = 0.5 * Math.exp(2 * z / 3); var sd = 0.5 * Math.sqrt(z * s * (n - s) / n) * Math.sign(i_1 - n / 2); var newLeft = Math.max(left, Math.floor(k - i_1 * s / n + sd)); var newRight = Math.min(right, Math.floor(k + (n - i_1) * s / n + sd)); select$1(array, k, newLeft, newRight); } // partition the elements between left and right around t var t = array[k]; var i = left; var j = right; tf.util.swap(array, left, k); if (comparePair(array[right], t) > 0) { tf.util.swap(array, left, right); } while (i < j) { tf.util.swap(array, i, j); i++; j--; while (comparePair(array[i], t) < 0) { i = i + 1; } while (comparePair(array[j], t) > 0) { j = j - 1; } } if (comparePair(array[left], t) === 0) { tf.util.swap(array, left, j); } else { j = j + 1; tf.util.swap(array, j, right); } // Adjust left and right towards the boundaries of the subset // containing the (k - left + 1)th smallest element. if (j <= k) { left = j + 1; } if (k <= j) { right = j - 1; } } } function topKImpl(x, xShape, xDtype, k, sorted) { // Reshape into a 2d tensor [batch, lastDim] and compute topk along lastDim. 
var lastDim = xShape[xShape.length - 1]; var _a = __read([x.length / lastDim, lastDim], 2), batch = _a[0], size = _a[1]; var allTopKVals = tf.util.getTypedArrayFromDType(xDtype, batch * k); var allTopKIndices = tf.util.getTypedArrayFromDType('int32', batch * k); var _loop_1 = function (b) { var offset = b * size; var vals = x.subarray(offset, offset + size); var valAndInd = new Array(vals.length); vals.forEach(function (value, index) { return valAndInd[index] = { value: value, index: index }; }); if (k < valAndInd.length) { select$1(valAndInd, k); valAndInd = valAndInd.slice(0, k); } if (sorted) { valAndInd.sort(comparePair); } var outOffset = b * k; var topKVals = allTopKVals.subarray(outOffset, outOffset + k); var topKIndices = allTopKIndices.subarray(outOffset, outOffset + k); for (var i = 0; i < k; i++) { topKVals[i] = valAndInd[i].value; topKIndices[i] = valAndInd[i].index; } }; for (var b = 0; b < batch; b++) { _loop_1(b); } // Reshape back to the original input shape, except that the last // dimension is k. var outputShape = xShape.slice(); outputShape[outputShape.length - 1] = k; return [ tf.buffer(outputShape, xDtype, allTopKVals), tf.buffer(outputShape, 'int32', allTopKIndices) ]; } /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function uniqueImpl(values, axis, shape, dtype) { // Normalize and validate axis. 
var $axis = tf.util.parseAxisParam(axis, shape)[0]; // Calculate the new shape that is suitable for extracting data along the // given axis. // // The rank is 3. // The size of the 1st dimension is the size of all the axes < the given axis. // The size of the 2nd dimension is the same as the size of the given axis. // The size of the 3rd dimension is the size of all the axes > the given axis. // // For example, for a 4D tensor with shape=[2, 3, 5, 4] and axis=2, the // newShape would be: [2*3, 5, 4]. // // Note that this is not the final output shape. This will be the shape for an // intermediate TensorBuffer (see inputBuffer below) to allow us to extract // values along the given axis. To demonstrate how it works, consider the // following example: // // Input: a 3D tensor, with shape [1, 2, 3] // [ // [ // [1,2,3], // [4,5,6] // ] // ] // Axis: 2 (the last axis). // Along axis 2, we expect to extract 3 tensors: [1,4], [2,5], [3,6]. // // For this example, newShape would be: [2, 3, 1], where 2 is calculated from // 1*2. The re-shaped data would look like: // // [ // [ // [1], [2], [3] // ], // [ // [4], [5], [6] // ] // ] // // Then, we can construct a 3-level nested loop by the following dimension // order to extract the values along the axis (dimension1): // i: dimension1 // 0,1,2 (newShape[1]) // m: dimension0 // 0,1 (newShape[0]) // n: dimension2 // 0 (newShape[2]) // // m, i, n // --------- // Iteration 0: data at [0, 0, 0] => "1" // Iteration 1: data at [1, 0, 0] => "4" // We got [1,4]. // Iteration 2: data at [0, 1, 0] => "2" // Iteration 3: data at [1, 1, 0] => "5" // We got [2,5]. // Iteration 4: data at [0, 2, 0] => "3" // Iteration 5: data at [1, 2, 0] => "6" // We got [3,6]. 
var newShape = [1, shape[0], 1]; for (var i = 0; i < $axis; i++) { newShape[0] *= shape[i]; } newShape[1] = shape[$axis]; for (var i = $axis + 1; i < shape.length; i++) { newShape[2] *= shape[i]; } // A map from unique elements (their string representations) to their values // in "indices" (below). var uniqueElements = new Map(); // The indices of each unique element in the original tensor along the given // axis. It is 1D and has the same size as the given axis. var indices = new Int32Array(shape[$axis]); // Create a buffer so we can easily extract value at a given location. var inputBuffer = new tf.TensorBuffer(newShape, dtype, values); // The indices along the given axis that have unique elements. This is a // de-duped version of "indices" above. var uniqueIndices = []; var is1DTensor = newShape[0] === 1 && newShape[2] === 1; for (var i = 0; i < shape[$axis]; i++) { // Extract values along the axis. var element = void 0; if (is1DTensor) { // Fast path for 1D tensor input. element = values[i].toString(); } else { var axisValues = []; for (var m = 0; m < newShape[0]; m++) { for (var n = 0; n < newShape[2]; n++) { axisValues.push(inputBuffer.get(m, i, n)); } } element = axisValues.join(','); } // Dedup and update various indices. var existingIndex = uniqueElements.get(element); if (existingIndex != null) { indices[i] = existingIndex; } else { var uniqueIndex = uniqueElements.size; uniqueElements.set(element, uniqueIndex); indices[i] = uniqueIndex; uniqueIndices.push(i); } } // Now we know where each of the unique elements are located along the axis // (uniqueIndices). Extract them from input buffer and store them in the // output buffer. 
var outputTmpShape = newShape.slice(); outputTmpShape[1] = uniqueElements.size; var outputBuffer = new tf.TensorBuffer(outputTmpShape, dtype); uniqueIndices.forEach(function (uniqueElementIndex, i) { for (var m = 0; m < newShape[0]; m++) { for (var n = 0; n < newShape[2]; n++) { outputBuffer.set(inputBuffer.get(m, uniqueElementIndex, n), m, i, n); } } }); // The output shape can be calculated from the input shape with the size of // the given axis replaced by the number of unique elements along that axis. var outputShape = shape.slice(); outputShape[$axis] = outputTmpShape[1]; return { outputValues: outputBuffer.values, outputShape: outputShape, indices: indices, }; } /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ var addImplCPU = addImpl, bincountImplCPU = bincountImpl, bincountReduceImplCPU = bincountReduceImpl, bitwiseAndImplCPU = bitwiseAndImpl, castImplCPU = castImpl, ceilImplCPU = ceilImpl, concatImplCPU = concatImpl$1, equalImplCPU = equalImpl, expImplCPU = expImpl, expm1ImplCPU = expm1Impl, floorImplCPU = floorImpl, gatherNdImplCPU = gatherNdImpl, gatherV2ImplCPU = gatherV2Impl, greaterImplCPU = greaterImpl, greaterEqualImplCPU = greaterEqualImpl, lessImplCPU = lessImpl, lessEqualImplCPU = lessEqualImpl, linSpaceImplCPU = linSpaceImpl, logImplCPU = logImpl, maxImplCPU = maxImpl$1, maximumImplCPU = maximumImpl, minimumImplCPU = minimumImpl, multiplyImplCPU = multiplyImpl, negImplCPU = negImpl, notEqualImplCPU = notEqualImpl, prodImplCPU = prodImpl, raggedGatherImplCPU = raggedGatherImpl, raggedRangeImplCPU = raggedRangeImpl, raggedTensorToTensorImplCPU = raggedTensorToTensorImpl, rangeImplCPU = rangeImpl, rsqrtImplCPU = rsqrtImpl, scatterImplCPU = scatterImpl, sigmoidImplCPU = sigmoidImpl, simpleAbsImplCPU = simpleAbsImpl, sliceImplCPU = sliceImpl, sparseFillEmptyRowsImplCPU = sparseFillEmptyRowsImpl, sparseReshapeImplCPU = sparseReshapeImpl, sparseSegmentReductionImplCPU = sparseSegmentReductionImpl, sqrtImplCPU = sqrtImpl, staticRegexReplaceImplCPU = staticRegexReplaceImpl, stridedSliceImplCPU = stridedSliceImpl, stringNGramsImplCPU = stringNGramsImpl, stringSplitImplCPU = stringSplitImpl, stringToHashBucketFastImplCPU = stringToHashBucketFastImpl, subImplCPU = subImpl, tileImplCPU = tileImpl, topKImplCPU = topKImpl, transposeImplCPU = transposeImpl$1, uniqueImplCPU = uniqueImpl; /** * @license * Copyright 2018 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function getVecChannels(name, rank) { return ['x', 'y', 'z', 'w', 'u', 'v'].slice(0, rank).map(function (d) { return "".concat(name, ".").concat(d); }); } function getChannels(name, rank) { if (rank === 1) { return [name]; } return getVecChannels(name, rank); } function getSourceCoords$2(rank, dims) { if (rank === 1) { return 'rc'; } var coords = ''; for (var i = 0; i < rank; i++) { coords += dims[i]; if (i < rank - 1) { coords += ','; } } return coords; } /** * @license * Copyright 2018 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ var PackProgram = /** @class */ (function () { function PackProgram(outputShape) { this.variableNames = ['A']; this.packedInputs = false; this.packedOutput = true; // Only input / output 3D tensors. 
this.outputShape = outputShape; this.rank = outputShape.length; this.enableShapeUniforms = useShapeUniforms(this.outputShape.length); if (this.rank === 0) { this.userCode = "\n void main() {\n setOutput(vec4(getA(), 0., 0., 0.));\n }\n "; } else { var channels = getChannels('rc', this.rank); var dtype = getCoordsDataType(this.rank); var outOfBoundsCondition = this.getOutOfBoundsCondition(channels); var setup = this.getSetup(channels); var output = this.getOutput(channels); this.userCode = "\n void main() {\n ".concat(dtype, " rc = getOutputCoords();\n\n if(").concat(outOfBoundsCondition, ") {\n setOutput(vec4(0));\n } else {\n ").concat(setup, "\n\n setOutput(vec4(").concat(output, "));\n }\n }\n "); } } PackProgram.prototype.getSourceCoordsArr = function (dims) { var coords = []; for (var row = 0; row <= 1; row++) { for (var col = 0; col <= 1; col++) { var coord = "".concat(row === 0 ? 'r' : 'rp1', ", ").concat(col === 0 ? 'c' : 'cp1'); for (var d = 2; d < this.rank; d++) { coord = "".concat(dims[dims.length - 1 - d], ",") + coord; } coords.push(coord); } } return coords; }; PackProgram.prototype.getOutOfBoundsCondition = function (dims) { if (this.rank === 1) { return "rc > ".concat(this.enableShapeUniforms ? 'outShape' : this.outputShape[0]); } var cond = ''; for (var i = this.rank - 2; i < this.rank; i++) { cond += "".concat(dims[i], " >= ").concat(this.enableShapeUniforms ? "outShape[".concat(i, "]") : this.outputShape[i]); if (i < this.rank - 1) { cond += '||'; } } return cond; }; PackProgram.prototype.getSetup = function (dims) { if (this.rank === 1) { return ''; } var innerDims = dims.slice(-2); var col = this.enableShapeUniforms ? "outShape[".concat(this.rank, " - 1]") : this.outputShape[this.rank - 1]; var row = this.enableShapeUniforms ? 
"outShape[".concat(this.rank, " - 2]") : this.outputShape[this.rank - 2]; return "\n int r = ".concat(innerDims[0], ";\n int c = ").concat(innerDims[1], ";\n int rp1 = r + 1;\n int cp1 = c + 1;\n\n bool cEdge = cp1 >= ").concat(col, ";\n bool rEdge = rp1 >= ").concat(row, ";\n "); }; PackProgram.prototype.getOutput = function (dims) { var sourceCoords = this.getSourceCoordsArr(dims); if (this.rank === 1) { var outShape = this.enableShapeUniforms ? 'outShape' : this.outputShape[0]; return "getA(rc), (rc + 1 >= ".concat(outShape, " ? 0. : getA(rc + 1)), 0, 0"); } return "getA(".concat(sourceCoords[0], "),\n cEdge ? 0. : getA(").concat(sourceCoords[1], "),\n rEdge ? 0. : getA(").concat(sourceCoords[2], "),\n rEdge || cEdge ? 0. : getA(").concat(sourceCoords[3], ")"); }; return PackProgram; }()); /** * @license * Copyright 2018 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ var ReshapePackedProgram = /** @class */ (function () { function ReshapePackedProgram(outputShape, inputShape) { this.variableNames = ['A']; this.packedInputs = true; this.packedOutput = true; this.customUniforms = [{ name: 'inputShape', type: 'ivec3' }]; this.outputShape = outputShape; this.enableShapeUniforms = useShapeUniforms(this.outputShape.length); var mainLoop = ""; for (var i = 0; i < 4; i++) { var thisRC = "thisRC = rc;"; if (i % 2 === 1) { thisRC += "thisRC.z += 1;"; } if (i > 1) { thisRC += "thisRC.y += 1;"; } mainLoop += "\n ".concat(thisRC, "\n ").concat(i > 0 ? "if(thisRC.y < rows && thisRC.z < cols){" : '', "\n int flatIndex = getFlatIndex(thisRC);\n\n ivec3 inputRC = inputCoordsFromReshapedOutCoords(flatIndex);\n vec2 inputRCInnerDims = vec2(float(inputRC.y),float(inputRC.z));\n\n result[").concat(i, "] =\n getChannel(getA(inputRC.x, inputRC.y, inputRC.z), inputRCInnerDims);\n ").concat(i > 0 ? '}' : '', "\n "); } this.userCode = "\n ".concat(getReshapedInputCoords(inputShape, this.enableShapeUniforms), "\n ").concat(this.enableShapeUniforms ? getFlatIndexFrom3DOutput() : getFlatIndexFrom3D(outputShape), "\n\n void main() {\n ivec3 rc = getOutputCoords();\n\n vec4 result = vec4(0.);\n\n ivec3 thisRC;\n int rows = ").concat(this.enableShapeUniforms ? 'outShape[1]' : outputShape[1], ";\n int cols = ").concat(this.enableShapeUniforms ? 'outShape[2]' : outputShape[2], ";\n\n ").concat(mainLoop, "\n\n setOutput(result);\n }\n "); } return ReshapePackedProgram; }()); function getReshapedInputCoords(shape, enableShapeUniforms) { var coordsFromIndexSnippet = enableShapeUniforms ? 
getLogicalCoordinatesFromFlatIndexByUniform(['r', 'c', 'd'], 'inputShape') : getLogicalCoordinatesFromFlatIndex(['r', 'c', 'd'], shape); return "\n ivec3 inputCoordsFromReshapedOutCoords(int index) {\n ".concat(coordsFromIndexSnippet, "\n return ivec3(r, c, d);\n }\n "); } var TextureManager = /** @class */ (function () { function TextureManager(gpgpu) { this.gpgpu = gpgpu; this.numUsedTextures = 0; this.numFreeTextures = 0; this._numBytesAllocated = 0; // Number of bytes that have been allocated and available for reuse. this._numBytesFree = 0; this.freeTextures = {}; this.usedTextures = {}; this.logEnabled = false; } TextureManager.prototype.acquireTexture = function (shapeRC, usage, isPacked) { var physicalTexType = getPhysicalFromLogicalTextureType(usage, isPacked); var shapeKey = getKeyFromTextureShape(shapeRC, physicalTexType, isPacked); if (!(shapeKey in this.freeTextures)) { this.freeTextures[shapeKey] = []; } if (!(shapeKey in this.usedTextures)) { this.usedTextures[shapeKey] = []; } var texBytes = computeBytes(shapeRC, physicalTexType, this.gpgpu.gl, this.gpgpu.textureConfig, isPacked); if (this.freeTextures[shapeKey].length > 0) { this.numFreeTextures--; this.numUsedTextures++; this._numBytesFree -= texBytes; this.log(); var newTexture_1 = this.freeTextures[shapeKey].pop(); this.usedTextures[shapeKey].push(newTexture_1); return newTexture_1; } var newTexture; if (physicalTexType === PhysicalTextureType.PACKED_2X2_FLOAT32) { newTexture = this.gpgpu.createPackedMatrixTexture(shapeRC[0], shapeRC[1]); } else if (physicalTexType === PhysicalTextureType.PACKED_2X2_FLOAT16) { newTexture = this.gpgpu.createFloat16PackedMatrixTexture(shapeRC[0], shapeRC[1]); } else if (physicalTexType === PhysicalTextureType.UNPACKED_FLOAT32) { newTexture = this.gpgpu.createFloat32MatrixTexture(shapeRC[0], shapeRC[1]); } else if (physicalTexType === PhysicalTextureType.UNPACKED_FLOAT16) { newTexture = this.gpgpu.createFloat16MatrixTexture(shapeRC[0], shapeRC[1]); } else if 
(physicalTexType === PhysicalTextureType.PACKED_4X1_UNSIGNED_BYTE) { newTexture = this.gpgpu.createUnsignedBytesMatrixTexture(shapeRC[0], shapeRC[1]); } this.usedTextures[shapeKey].push(newTexture); this.numUsedTextures++; this._numBytesAllocated += texBytes; this.log(); return newTexture; }; TextureManager.prototype.releaseTexture = function (texture, shape, logicalTexType, isPacked) { if (this.freeTextures == null) { // Already disposed. return; } var physicalTexType = getPhysicalFromLogicalTextureType(logicalTexType, isPacked); var shapeKey = getKeyFromTextureShape(shape, physicalTexType, isPacked); if (!(shapeKey in this.freeTextures)) { this.freeTextures[shapeKey] = []; } var texBytes = computeBytes(shape, physicalTexType, this.gpgpu.gl, this.gpgpu.textureConfig, isPacked); var deleteTexThreshold = tf.env() .getNumber('WEBGL_DELETE_TEXTURE_THRESHOLD'); if (deleteTexThreshold !== -1 && this._numBytesAllocated > deleteTexThreshold) { this.gpgpu.deleteMatrixTexture(texture.texture); this._numBytesAllocated -= texBytes; } else { this.freeTextures[shapeKey].push(texture); this.numFreeTextures++; this._numBytesFree += texBytes; } this.numUsedTextures--; var texList = this.usedTextures[shapeKey]; var texIndex = texList && texList.indexOf(texture); if (texIndex == null || texIndex < 0) { throw new Error('Cannot release a texture that was never provided by this ' + 'texture manager'); } texList[texIndex] = texList[texList.length - 1]; texList.pop(); this.log(); }; TextureManager.prototype.log = function () { if (!this.logEnabled) { return; } var total = this.numFreeTextures + this.numUsedTextures; console.log('Free/Used', "".concat(this.numFreeTextures, " / ").concat(this.numUsedTextures), "(".concat(total, ")")); var freeRatio = this._numBytesFree / this._numBytesAllocated; console.log("Bytes allocated: ".concat(this._numBytesAllocated)); console.log("Bytes unused: ".concat(this._numBytesFree, " (").concat(Math.round(100 * freeRatio), "%)")); }; 
Object.defineProperty(TextureManager.prototype, "numBytesAllocated", { get: function () { return this._numBytesAllocated; }, enumerable: false, configurable: true }); Object.defineProperty(TextureManager.prototype, "numBytesFree", { get: function () { return this._numBytesFree; }, enumerable: false, configurable: true }); TextureManager.prototype.getNumUsedTextures = function () { return this.numUsedTextures; }; TextureManager.prototype.getNumFreeTextures = function () { return this.numFreeTextures; }; TextureManager.prototype.dispose = function () { var _this = this; if (this.freeTextures == null) { // Already disposed. return; } for (var texShape in this.freeTextures) { this.freeTextures[texShape].forEach(function (tex) { _this.gpgpu.deleteMatrixTexture(tex.texture); }); } for (var texShape in this.usedTextures) { this.usedTextures[texShape].forEach(function (tex) { _this.gpgpu.deleteMatrixTexture(tex.texture); }); } // TODO: Assign non-null value (empty object) to textures after disposed. this.freeTextures = null; this.usedTextures = null; this.numUsedTextures = 0; this.numFreeTextures = 0; this._numBytesAllocated = 0; this._numBytesFree = 0; }; return TextureManager; }()); function numBytesForInternalFormat(gl, internalFormat) { // tslint:disable-next-line:no-any var glany = gl; if (internalFormat === glany.R32F) { return 4; } else if (internalFormat === glany.R16F) { return 2; } else if (internalFormat === glany.RGBA32F) { return 16; } else if (internalFormat === gl.RGBA) { return 16; } else if (internalFormat === glany.RGBA16F) { return 8; } else if (internalFormat === glany.RGBA8) { return 4; } throw new Error("Unknown internal format ".concat(internalFormat)); } function computeBytes(shape, physicalTexType, gl, textureConfig, isPacked) { // It is not possible to infer packed status from the texture type because // depending on the textureConfig, different texture types may resolve to the // same internal format (e.g. 
in WebGL1, the internal format for // UNPACKED_FLOAT16 textures is gl.RGBA). Therefore we pass in `isPacked` // explicitly. var internalFormat = internalFormatForPhysicalTexType(physicalTexType, textureConfig); var numElements; if (isPacked) { var _a = __read(getPackedMatrixTextureShapeWidthHeight(shape[0], shape[1]), 2), packedWidth = _a[0], packedHeight = _a[1]; numElements = packedWidth * packedHeight; } else { var _b = __read(getUnpackedMatrixTextureShapeWidthHeight(shape[0], shape[1]), 2), width = _b[0], height = _b[1]; numElements = width * height; } var bytesPerElement = numBytesForInternalFormat(gl, internalFormat); return numElements * bytesPerElement; } function internalFormatForPhysicalTexType(physicalTexType, textureConfig) { switch (physicalTexType) { case PhysicalTextureType.PACKED_2X2_FLOAT32: return getInternalFormatForPackedMatrixTexture(textureConfig); case PhysicalTextureType.PACKED_2X2_FLOAT16: return getInternalFormatForFloat16PackedMatrixTexture(textureConfig); case PhysicalTextureType.UNPACKED_FLOAT32: return getInternalFormatForFloat32MatrixTexture(textureConfig); case PhysicalTextureType.UNPACKED_FLOAT16: return getInternalFormatForFloat16MatrixTexture(textureConfig); case PhysicalTextureType.PACKED_4X1_UNSIGNED_BYTE: return getInternalFormatForUnsignedBytesMatrixTexture(textureConfig); default: throw new Error("Unknown physical texture type ".concat(physicalTexType)); } } function getPhysicalTextureForRendering(isPacked) { if (tf.env().getBool('WEBGL_RENDER_FLOAT32_ENABLED')) { if (isPacked) { return PhysicalTextureType.PACKED_2X2_FLOAT32; } return PhysicalTextureType.UNPACKED_FLOAT32; } if (isPacked) { return PhysicalTextureType.PACKED_2X2_FLOAT16; } return PhysicalTextureType.UNPACKED_FLOAT16; } function getPhysicalFromLogicalTextureType(logicalTexType, isPacked) { if (logicalTexType === TextureUsage.UPLOAD) { return PhysicalTextureType.PACKED_2X2_FLOAT32; } else if (logicalTexType === TextureUsage.RENDER || logicalTexType == null) { 
return getPhysicalTextureForRendering(isPacked); } else if (logicalTexType === TextureUsage.DOWNLOAD || logicalTexType === TextureUsage.PIXELS) { return PhysicalTextureType.PACKED_4X1_UNSIGNED_BYTE; } throw new Error("Unknown logical texture type ".concat(logicalTexType)); } function getKeyFromTextureShape(shapeRowsCol, physicalTexType, isPacked) { return "".concat(shapeRowsCol[0], "_").concat(shapeRowsCol[1], "_").concat(physicalTexType, "_").concat(isPacked); } /** * @license * Copyright 2017 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ var UnaryOpProgram = /** @class */ (function () { function UnaryOpProgram(aShape, opSnippet) { this.variableNames = ['A']; this.outputShape = aShape; this.enableShapeUniforms = useShapeUniforms(this.outputShape.length); this.userCode = "\n float unaryOperation(float x) {\n ".concat(opSnippet, "\n }\n\n void main() {\n float x = getAAtOutCoords();\n float y = unaryOperation(x);\n\n setOutput(y);\n }\n "); } return UnaryOpProgram; }()); var CHECK_NAN_SNIPPET$1 = "if (isnan(x)) return x;"; var LINEAR$1 = "return x;"; var ABS$1 = "return abs(x);"; var ELU$2 = "return (x >= 0.0) ? x : (exp(x) - 1.0);"; var RELU$2 = CHECK_NAN_SNIPPET$1 + "\n return (x < 0.0) ? 0.0 : x;\n"; var RELU6$2 = CHECK_NAN_SNIPPET$1 + "\n return (x < 0.0) ? 
0.0 : min(6.0, x);\n"; var CLONE = 'return x;'; var SIGMOID$2 = "return 1.0 / (1.0 + exp(-1.0 * x));"; /** * @license * Copyright 2018 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ var LINEAR = "return x;"; var ELU$1 = "\n vec4 result;\n\n result.r = (x.r >= 0.0) ? x.r : (exp(x.r) - 1.0);\n result.g = (x.g >= 0.0) ? x.g : (exp(x.g) - 1.0);\n result.b = (x.b >= 0.0) ? x.b : (exp(x.b) - 1.0);\n result.a = (x.a >= 0.0) ? x.a : (exp(x.a) - 1.0);\n\n return result;\n"; var RELU$1 = "\n vec4 result = x * vec4(greaterThanEqual(x, vec4(0.0)));\n bvec4 isNaN = isnan(x);\n\n result.r = isNaN.r ? x.r : result.r;\n result.g = isNaN.g ? x.g : result.g;\n result.b = isNaN.b ? x.b : result.b;\n result.a = isNaN.a ? x.a : result.a;\n\n return result;\n"; var RELU6$1 = "\n vec4 result = min(x, vec4(6.)) * vec4(greaterThanEqual(x, vec4(0.0)));\n bvec4 isNaN = isnan(x);\n\n result.r = isNaN.r ? x.r : result.r;\n result.g = isNaN.g ? x.g : result.g;\n result.b = isNaN.b ? x.b : result.b;\n result.a = isNaN.a ? 
x.a : result.a;\n\n return result;\n"; var SIGMOID$1 = "return 1.0 / (1.0 + exp(-1.0 * x));"; var UnaryOpPackedProgram = /** @class */ (function () { function UnaryOpPackedProgram(aShape, opSnippet) { this.variableNames = ['A']; this.packedInputs = true; this.packedOutput = true; this.outputShape = aShape; this.enableShapeUniforms = useShapeUniforms(this.outputShape.length); this.userCode = "\n vec4 unaryOperation(vec4 x) {\n ".concat(opSnippet, "\n }\n\n void main() {\n vec4 x = getAAtOutCoords();\n vec4 y = unaryOperation(x);\n\n setOutput(y);\n }\n "); } return UnaryOpPackedProgram; }()); /** * @license * Copyright 2018 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ var UnpackProgram = /** @class */ (function () { function UnpackProgram(outputShape) { this.variableNames = ['A']; this.packedInputs = true; this.packedOutput = false; this.outputShape = outputShape; this.enableShapeUniforms = useShapeUniforms(this.outputShape.length); var rank = outputShape.length; var channels = getChannels('rc', rank); var dtype = getCoordsDataType(rank); var sourceCoords = getSourceCoords$2(rank, channels); var innerDims = channels.slice(-2); var coords = rank <= 1 ? 
'rc' : "vec2(".concat(innerDims.join(','), ")"); this.userCode = "\n void main() {\n ".concat(dtype, " rc = getOutputCoords();\n vec4 packedInput = getA(").concat(sourceCoords, ");\n\n setOutput(getChannel(packedInput, ").concat(coords, "));\n }\n "); } return UnpackProgram; }()); var whereImpl = tf.kernel_impls.whereImpl; var EPSILON_FLOAT32 = 1e-7; var EPSILON_FLOAT16 = 1e-4; var binaryCaches = {}; function getBinaryCache(webGLVersion) { if (webGLVersion in binaryCaches) { return binaryCaches[webGLVersion]; } binaryCaches[webGLVersion] = {}; return binaryCaches[webGLVersion]; } // Empirically determined constant used to determine size threshold for handing // off execution to the CPU. var CPU_HANDOFF_SIZE_THRESHOLD = tf.env().getNumber('CPU_HANDOFF_SIZE_THRESHOLD'); // Empirically determined constant used to decide the number of MB on GPU // before we warn about high memory use. The MB are this constant * screen area // * dpi / 1024 / 1024. var BEFORE_PAGING_CONSTANT = 600; function numMBBeforeWarning() { if (tf.env().global.screen == null) { return 1024; // 1 GB. } return (tf.env().global.screen.height * tf.env().global.screen.width * window.devicePixelRatio) * BEFORE_PAGING_CONSTANT / 1024 / 1024; } var MathBackendWebGL = /** @class */ (function (_super) { __extends(MathBackendWebGL, _super); function MathBackendWebGL(gpuResource) { var _this = _super.call(this) || this; // Maps data ids that have a pending read operation, to list of subscribers. _this.pendingRead = new WeakMap(); // List of data ids that are scheduled for disposal, but are waiting on a // pending read operation. _this.pendingDisposal = new WeakSet(); // Used to count the number of 'shallow' sliced tensors that point to the // same data id. _this.dataRefCount = new WeakMap(); _this.numBytesInGPU = 0; // Accumulated time spent (including blocking) in uploading data to webgl. _this.uploadWaitMs = 0; // Accumulated time spent (including blocking in downloading data from webgl. 
_this.downloadWaitMs = 0; // record the last manual GL Flush time. _this.lastGlFlushTime = 0; _this.warnedAboutMemory = false; _this.pendingDeletes = 0; _this.disposed = false; if (!tf.env().getBool('HAS_WEBGL')) { throw new Error('WebGL is not supported on this device'); } var newGPGPU; if (gpuResource != null) { if (gpuResource instanceof GPGPUContext) { newGPGPU = gpuResource; } else { var gl = getWebGLContext(tf.env().getNumber('WEBGL_VERSION'), gpuResource); newGPGPU = new GPGPUContext(gl); } _this.binaryCache = {}; _this.gpgpuCreatedLocally = false; } else { var gl = getWebGLContext(tf.env().getNumber('WEBGL_VERSION')); newGPGPU = new GPGPUContext(gl); _this.binaryCache = getBinaryCache(tf.env().getNumber('WEBGL_VERSION')); _this.gpgpuCreatedLocally = true; } _this.gpgpu = newGPGPU; _this.canvas = _this.gpgpu.gl.canvas; _this.textureManager = new TextureManager(_this.gpgpu); _this.numMBBeforeWarning = numMBBeforeWarning(); _this.texData = new tf.DataStorage(_this, tf.engine()); return _this; } MathBackendWebGL.prototype.nextDataId = function () { return MathBackendWebGL.nextDataId++; }; MathBackendWebGL.prototype.numDataIds = function () { return this.texData.numDataIds() - this.pendingDeletes; }; // Writes a new entry to the data store with a WebGL texture, and registers it // to the texture manager. MathBackendWebGL.prototype.writeTexture = function (texture, shape, dtype, texHeight, texWidth, channels) { // Temporarily create an tensor info to make the texture compatible with // the runWebGLProgram's input. var input = this.makeTensorInfo(shape, dtype); var inData = this.texData.get(input.dataId); // Even though the input texture could be unpacked or dense packed, it is // always considered as unpacked for EncodeMatrixProgram. inData.isPacked = false; // Bind texture to the input tensor. 
inData.texture = { texture: texture, texShape: [texHeight, texWidth] }; inData.texShape = [texHeight, texWidth]; var shapeAs3D = getShapeAs3D(shape); var program = new EncodeMatrixProgram(shapeAs3D, false /* isByteArray */, channels); var output = this.runWebGLProgram(program, [input], dtype, [[texHeight, texWidth]]); output.shape = shape; // Unbind the texture from the input tensor to avoid the texture being // released. inData.texture = null; this.disposeIntermediateTensorInfo(input); return output.dataId; }; MathBackendWebGL.prototype.write = function (values, shape, dtype) { if (tf.env().getBool('WEBGL_CHECK_NUMERICAL_PROBLEMS') || tf.env().getBool('DEBUG')) { this.checkNumericalProblems(values); } if (dtype === 'complex64' && values != null) { throw new Error("Cannot write to a complex64 dtype. " + "Please use tf.complex(real, imag)."); } var dataId = { id: this.nextDataId() }; this.texData.set(dataId, { shape: shape, dtype: dtype, values: values, usage: TextureUsage.UPLOAD, refCount: 1 }); return dataId; }; /** Return refCount of a `TensorData`. */ MathBackendWebGL.prototype.refCount = function (dataId) { if (this.texData.has(dataId)) { var tensorData = this.texData.get(dataId); return tensorData.refCount; } return 0; }; /** Increase refCount of a `TextureData`. */ MathBackendWebGL.prototype.incRef = function (dataId) { var texData = this.texData.get(dataId); texData.refCount++; }; /** Decrease refCount of a `TextureData`. */ MathBackendWebGL.prototype.decRef = function (dataId) { if (this.texData.has(dataId)) { var texData = this.texData.get(dataId); texData.refCount--; } }; MathBackendWebGL.prototype.move = function (dataId, values, shape, dtype, refCount) { if (tf.env().getBool('DEBUG')) { this.checkNumericalProblems(values); } if (dtype === 'complex64') { throw new Error("Cannot write to a complex64 dtype. 
" + "Please use tf.complex(real, imag)."); }
        // Store the CPU-side values; the entry has no texture until
        // uploadToGPU runs for it.
        this.texData.set(dataId, { shape: shape, dtype: dtype, values: values, usage: TextureUsage.UPLOAD, refCount: refCount });
    };
    /**
     * Disposes a temporary TensorInfo created while running a program
     * (delegates to disposeData with its dataId).
     */
    MathBackendWebGL.prototype.disposeIntermediateTensorInfo = function (tensorInfo) {
        this.disposeData(tensorInfo.dataId);
    };
    /**
     * Synchronously downloads the values of `dataId` to the CPU and caches
     * them on its texData entry (via convertAndCacheOnCPU).
     *
     * - A sliced tensor is first cloned into its own texture and read from
     *   the clone.
     * - Values already present on the CPU are returned from the cache.
     * - complex64 tensors are read as separate real/imag parts and merged.
     * While a profiling timer is active, the elapsed time is added to
     * `downloadWaitMs`.
     */
    MathBackendWebGL.prototype.readSync = function (dataId) {
        var texData = this.texData.get(dataId);
        var values = texData.values, dtype = texData.dtype, complexTensorInfos = texData.complexTensorInfos, slice = texData.slice, shape = texData.shape, isPacked = texData.isPacked;
        // The presence of `slice` indicates this tensor is a shallow slice of a
        // different tensor, and is using that original tensor's texture. Run
        // `clone` in order to copy that texture and read from it.
        if (slice != null) {
            var program = void 0;
            if (isPacked) {
                program = new UnaryOpPackedProgram(shape, CLONE);
            }
            else {
                program = new UnaryOpProgram(shape, CLONE);
            }
            var res = this.runWebGLProgram(program, [{ dataId: dataId, shape: shape, dtype: dtype }], dtype);
            var data = this.readSync(res.dataId);
            this.disposeIntermediateTensorInfo(res);
            return data;
        }
        if (values != null) {
            return this.convertAndCacheOnCPU(dataId);
        }
        // NOTE(review): only reached when `values` is null, so this returns
        // null for string tensors; presumably string data always lives on the
        // CPU and is handled by the branch above -- confirm.
        if (dtype === 'string') {
            return values;
        }
        var shouldTimeProgram = this.activeTimers != null;
        var start;
        if (shouldTimeProgram) {
            start = tf.util.now();
        }
        var result;
        if (dtype === 'complex64') {
            var realValues = this.readSync(complexTensorInfos.real.dataId);
            var imagValues = this.readSync(complexTensorInfos.imag.dataId);
            result = tf.backend_util.mergeRealAndImagArrays(realValues, imagValues);
        }
        else {
            result = this.getValuesFromTexture(dataId);
        }
        if (shouldTimeProgram) {
            this.downloadWaitMs += tf.util.now() - start;
        }
        return this.convertAndCacheOnCPU(dataId, result);
    };
    /**
     * Asynchronously downloads the values of `dataId` to the CPU.
     *
     * Concurrent reads of the same dataId share one download: later callers
     * are queued in `pendingRead` and resolved with the same values. A
     * disposeData call made while the read is pending is deferred (tracked in
     * `pendingDisposal`) and carried out at the end of this read.
     *
     * The body is TypeScript's down-levelled async/await (__awaiter /
     * __generator): each `case` label is a resume point after a yield, so
     * statement order is significant.
     */
    MathBackendWebGL.prototype.read = function (dataId) {
        return __awaiter(this, void 0, void 0, function () {
            var subscribers_1, texData, values, shape, slice, dtype, complexTensorInfos, isPacked, program, res, data, buffer, tmpDownloadTarget, tmpData, vals, ps, realValues, imagValues, size, gl_1, dTypeVals, subscribers;
            var _b;
            return __generator(this, function (_c) {
                switch (_c.label) {
                    case 0:
                        // Another read of this dataId is in flight: wait for its result
                        // instead of issuing a second download.
                        if (this.pendingRead.has(dataId)) {
                            subscribers_1 = this.pendingRead.get(dataId);
                            return [2 /*return*/, new Promise(function (resolve) { return subscribers_1.push(resolve); })];
                        }
                        texData = this.texData.get(dataId);
                        values = texData.values, shape = texData.shape, slice = texData.slice, dtype = texData.dtype, complexTensorInfos = texData.complexTensorInfos, isPacked = texData.isPacked;
                        // The presence of `slice` indicates this tensor is a shallow slice of a
                        // different tensor, and is using that original tensor's texture. Run
                        // `clone` in order to copy that texture and read from it.
                        if (slice != null) {
                            program = void 0;
                            if (isPacked) {
                                program = new UnaryOpPackedProgram(shape, CLONE);
                            }
                            else {
                                program = new UnaryOpProgram(shape, CLONE);
                            }
                            res = this.runWebGLProgram(program, [{ dataId: dataId, shape: shape, dtype: dtype }], dtype);
                            data = this.read(res.dataId);
                            this.disposeIntermediateTensorInfo(res);
                            return [2 /*return*/, data];
                        }
                        if (values != null) {
                            return [2 /*return*/, this.convertAndCacheOnCPU(dataId)];
                        }
                        if (tf.env().getBool('DEBUG')) {
                            // getBool('WEBGL_DOWNLOAD_FLOAT_ENABLED') caused a blocking GPU call.
                            // For performance reasons, only check it for debugging. In production,
                            // it doesn't handle this use case anyway, so behavior is not changed.
                            if (!tf.env().getBool('WEBGL_DOWNLOAD_FLOAT_ENABLED') &&
                                tf.env().getNumber('WEBGL_VERSION') === 2) {
                                throw new Error("tensor.data() with WEBGL_DOWNLOAD_FLOAT_ENABLED=false and " +
                                    "WEBGL_VERSION=2 not yet supported.");
                            }
                        }
                        buffer = null;
                        if (dtype !== 'complex64' && tf.env().get('WEBGL_BUFFER_SUPPORTED')) {
                            // Possibly copy the texture into a buffer before inserting a fence.
                            // The decoded temporary and the buffer are released in case 5.
                            tmpDownloadTarget = this.decode(dataId);
                            tmpData = this.texData.get(tmpDownloadTarget.dataId);
                            buffer = (_b = this.gpgpu).createBufferFromTexture.apply(_b, __spreadArray([tmpData.texture.texture], __read(getDenseTexShape(shape)), false));
                        }
                        // Registered before the first yield so that concurrent read() and
                        // disposeData() calls observe the in-flight read.
                        this.pendingRead.set(dataId, []);
                        if (!(dtype !== 'complex64')) return [3 /*break*/, 2];
                        // Create a fence and wait for it to resolve.
                        return [4 /*yield*/, this.gpgpu.createAndWaitForFence()];
                    case 1:
                        // Create a fence and wait for it to resolve.
                        _c.sent();
                        _c.label = 2;
                    case 2:
                        if (!(dtype === 'complex64')) return [3 /*break*/, 4];
                        return [4 /*yield*/, Promise.all([
                                this.read(complexTensorInfos.real.dataId),
                                this.read(complexTensorInfos.imag.dataId)
                            ])];
                    case 3:
                        ps = _c.sent();
                        realValues = ps[0];
                        imagValues = ps[1];
                        vals = tf.backend_util.mergeRealAndImagArrays(realValues, imagValues);
                        return [3 /*break*/, 5];
                    case 4:
                        if (buffer == null) {
                            vals = this.getValuesFromTexture(dataId);
                        }
                        else {
                            size = tf.util.sizeFromShape(shape);
                            vals = this.gpgpu.downloadFloat32MatrixFromBuffer(buffer, size);
                        }
                        _c.label = 5;
                    case 5:
                        if (tmpDownloadTarget != null) {
                            this.disposeIntermediateTensorInfo(tmpDownloadTarget);
                        }
                        if (buffer != null) {
                            gl_1 = this.gpgpu.gl;
                            callAndCheck(gl_1, function () { return gl_1.deleteBuffer(buffer); });
                        }
                        dTypeVals = this.convertAndCacheOnCPU(dataId, vals);
                        subscribers = this.pendingRead.get(dataId);
                        this.pendingRead.delete(dataId);
                        // Notify all pending reads.
                        subscribers.forEach(function (resolve) { return resolve(dTypeVals); });
                        // Carry out a disposal that was deferred while this read was pending.
                        if (this.pendingDisposal.has(dataId)) {
                            this.pendingDisposal.delete(dataId);
                            if (this.disposeData(dataId)) {
                                tf.engine().removeDataId(dataId, this);
                            }
                            this.pendingDeletes--;
                        }
                        return [2 /*return*/, dTypeVals];
                }
            });
        });
    };
    /**
     * Read tensor to a new texture that is densely packed for ease of use.
     * @param dataId The source tensor.
     * @param options
     *     customTexShape: Optional. If set, will use the user defined texture
     *     shape to create the texture.
*/ MathBackendWebGL.prototype.readToGPU = function (dataId, options) { if (options === void 0) { options = {}; } var texData = this.texData.get(dataId); var values = texData.values, shape = texData.shape, slice = texData.slice, dtype = texData.dtype, isPacked = texData.isPacked, texture = texData.texture; if (dtype === 'complex64') { throw new Error('Does not support reading texture for complex64 dtype.'); } // The presence of `slice` indicates this tensor is a shallow slice of a // different tensor, and is using that original tensor's texture. Run // `clone` in order to copy that texture and read from it. if (slice != null) { var program = void 0; if (isPacked) { program = new UnaryOpPackedProgram(shape, CLONE); } else { program = new UnaryOpProgram(shape, CLONE); } var res = this.runWebGLProgram(program, [{ dataId: dataId, shape: shape, dtype: dtype }], dtype); var gpuResouorce = this.readToGPU(res, options); this.disposeIntermediateTensorInfo(res); return gpuResouorce; } if (texture == null) { if (values != null) { throw new Error('Data is not on GPU but on CPU.'); } else { throw new Error('There is no data on GPU or CPU.'); } } // Decode the texture so that it is stored densely (using four channels). var tmpTarget = this.decode(dataId, options.customTexShape); // Make engine track this tensor, so that we can dispose it later. var tensorRef = tf.engine().makeTensorFromTensorInfo(tmpTarget); var tmpData = this.texData.get(tmpTarget.dataId); return Object.assign({ tensorRef: tensorRef }, tmpData.texture); }; MathBackendWebGL.prototype.bufferSync = function (t) { var data = this.readSync(t.dataId); if (t.dtype === 'string') { try { // Decode the bytes into string. 
var strings = data.map(function (d) { return tf.util.decodeString(d); });
                return tf.buffer(t.shape, t.dtype, strings);
            }
            catch (_a) {
                throw new Error('Failed to decode encoded string bytes into utf-8');
            }
        }
        return tf.buffer(t.shape, t.dtype, data);
    };
    /**
     * Throws if any entry of `values` fails canBeRepresented(). The message
     * suggests enabling float32 rendering when WEBGL_RENDER_FLOAT32_CAPABLE is
     * set. No-op when `values` is null.
     */
    MathBackendWebGL.prototype.checkNumericalProblems = function (values) {
        if (values == null) {
            return;
        }
        for (var i = 0; i < values.length; i++) {
            var num = values[i];
            if (!canBeRepresented(num)) {
                if (tf.env().getBool('WEBGL_RENDER_FLOAT32_CAPABLE')) {
                    throw Error("The value ".concat(num, " cannot be represented with your ") +
                        "current settings. Consider enabling float32 rendering: " +
                        "'tf.env().set('WEBGL_RENDER_FLOAT32_ENABLED', true);'");
                }
                throw Error("The value ".concat(num, " cannot be represented on this device."));
            }
        }
    };
    /**
     * Downloads the float values stored in `dataId`'s texture and returns the
     * first sizeFromShape(shape) of them.
     *
     * With WEBGL_DOWNLOAD_FLOAT_ENABLED, the texture is decoded into a dense
     * temporary and downloaded directly. Otherwise it is run through
     * EncodeFloat(Packed)Program and downloaded as byte-encoded floats.
     * Temporaries are disposed before returning.
     */
    MathBackendWebGL.prototype.getValuesFromTexture = function (dataId) {
        var _b;
        var _c = this.texData.get(dataId), shape = _c.shape, dtype = _c.dtype, isPacked = _c.isPacked;
        var size = tf.util.sizeFromShape(shape);
        if (tf.env().getBool('WEBGL_DOWNLOAD_FLOAT_ENABLED')) {
            var tmpTarget = this.decode(dataId);
            var tmpData_1 = this.texData.get(tmpTarget.dataId);
            var vals_1 = (_b = this.gpgpu)
                .downloadMatrixFromPackedTexture.apply(_b, __spreadArray([tmpData_1.texture.texture], __read(getDenseTexShape(shape)), false)).subarray(0, size);
            this.disposeIntermediateTensorInfo(tmpTarget);
            return vals_1;
        }
        var shouldUsePackedProgram = tf.env().getBool('WEBGL_PACK') && isPacked === true;
        var outputShape = shouldUsePackedProgram ? getShapeAs3D(shape) : shape;
        var program = shouldUsePackedProgram ?
            new EncodeFloatPackedProgram(outputShape) :
            new EncodeFloatProgram(outputShape);
        var output = this.runWebGLProgram(program, [{ shape: outputShape, dtype: dtype, dataId: dataId }], 'float32');
        var tmpData = this.texData.get(output.dataId);
        var vals = this.gpgpu
            .downloadByteEncodedFloatMatrixFromOutputTexture(tmpData.texture.texture, tmpData.texShape[0], tmpData.texShape[1])
            .subarray(0, size);
        this.disposeIntermediateTensorInfo(output);
        return vals;
    };
    /** Whether disjoint-query GPU timers are flagged as reliable. */
    MathBackendWebGL.prototype.timerAvailable = function () {
        return tf.env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_RELIABLE') > 0;
    };
    /**
     * Runs `f` while collecting a timer entry for each program it executes.
     * Returns a promise of {uploadWaitMs, downloadWaitMs, kernelMs, wallMs}.
     *
     * A nested call pushes its timer list into the enclosing one. With
     * reliable disjoint-query timers, kernelMs is the sum of the query times
     * and getExtraProfileInfo lists them per program name. Otherwise kernelMs
     * is an `{error}` object. The upload/download counters are reset after
     * the query times have been collected.
     */
    MathBackendWebGL.prototype.time = function (f) {
        var _this = this;
        var oldActiveTimers = this.activeTimers;
        var newActiveTimers = [];
        var outerMostTime = false;
        if (this.programTimersStack == null) {
            this.programTimersStack = newActiveTimers;
            outerMostTime = true;
        }
        else {
            this.activeTimers.push(newActiveTimers);
        }
        this.activeTimers = newActiveTimers;
        f();
        // needing to split these up because util.flatten only accepts certain types
        var flattenedActiveTimerQueries = tf.util.flatten(this.activeTimers.map(function (d) { return d.query; }))
            .filter(function (d) { return d != null; });
        var flattenedActiveTimerNames = tf.util.flatten(this.activeTimers.map(function (d) { return d.name; }))
            .filter(function (d) { return d != null; });
        this.activeTimers = oldActiveTimers;
        if (outerMostTime) {
            this.programTimersStack = null;
        }
        var res = {
            uploadWaitMs: this.uploadWaitMs,
            downloadWaitMs: this.downloadWaitMs,
            kernelMs: null,
            wallMs: null // will be filled by the engine
        };
        return (function () { return __awaiter(_this, void 0, void 0, function () {
            var kernelMs_1;
            return __generator(this, function (_b) {
                switch (_b.label) {
                    case 0:
                        if (!(tf.env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_RELIABLE') > 0)) return [3 /*break*/, 2];
                        return [4 /*yield*/, Promise.all(flattenedActiveTimerQueries)];
                    case 1:
                        kernelMs_1 = _b.sent();
                        res['kernelMs'] = tf.util.sum(kernelMs_1);
                        res['getExtraProfileInfo'] = function () {
                            return kernelMs_1
                                .map(function (d, i) { return ({ name: flattenedActiveTimerNames[i], ms: d }); })
                                .map(function (d) { return "".concat(d.name, ": ").concat(d.ms); })
                                .join(', ');
                        };
                        return [3 /*break*/, 3];
                    case 2:
                        res['kernelMs'] = {
                            error: 'WebGL query timers are not supported in this environment.'
                        };
                        _b.label = 3;
                    case 3:
                        this.uploadWaitMs = 0;
                        this.downloadWaitMs = 0;
                        return [2 /*return*/, res];
                }
            });
        }); })();
    };
    /** Reports GPU memory as tracked by this backend and its texture manager. */
    MathBackendWebGL.prototype.memory = function () {
        return {
            unreliable: false,
            numBytesInGPU: this.numBytesInGPU,
            numBytesInGPUAllocated: this.textureManager.numBytesAllocated,
            numBytesInGPUFree: this.textureManager.numBytesFree
        };
    };
    /**
     * Starts a program timer: a GPU query when disjoint-query timers are
     * reliable, otherwise a CPU record {startMs, endMs}.
     */
    MathBackendWebGL.prototype.startTimer = function () {
        if (tf.env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_RELIABLE') > 0) {
            return this.gpgpu.beginQuery();
        }
        return { startMs: tf.util.now(), endMs: null };
    };
    /** Ends a timer started by startTimer and returns the (updated) query. */
    MathBackendWebGL.prototype.endTimer = function (query) {
        if (tf.env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_RELIABLE') > 0) {
            this.gpgpu.endQuery();
            return query;
        }
        query.endMs = tf.util.now();
        return query;
    };
    /** Resolves to the time measured by a startTimer/endTimer query. */
    MathBackendWebGL.prototype.getQueryTime = function (query) {
        return __awaiter(this, void 0, void 0, function () {
            var timerQuery;
            return __generator(this, function (_b) {
                if (tf.env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_RELIABLE') > 0) {
                    return [2 /*return*/, this.gpgpu.waitForQueryAndGetTime(query)];
                }
                timerQuery = query;
                return [2 /*return*/, timerQuery.endMs - timerQuery.startMs];
            });
        });
    };
    /**
     * Decrease the RefCount on the dataId and dispose the memory if the dataId
     * has 0 refCount. If there is a pending read on the data, the disposal is
     * added to the pending-disposal queue. Return true if the dataId is removed
     * from backend or the backend does not contain the dataId, false if the
     * dataId is not removed. Memory may or may not be released even when dataId
     * is removed, which also depends on dataRefCount, see `releaseGPUData`.
* @param dataId
     * @param force Optional, remove the data regardless of refCount
     * @returns true if the entry was removed or was already absent; false if
     *     it is still referenced, already pending disposal, or deferred
     *     behind a pending read.
     */
    MathBackendWebGL.prototype.disposeData = function (dataId, force) {
        if (force === void 0) { force = false; }
        if (this.pendingDisposal.has(dataId)) {
            return false;
        }
        // No-op if already disposed.
        if (!this.texData.has(dataId)) {
            return true;
        }
        // if force flag is set, change refCount to 0, this would ensure disposal
        // when added to the pendingDisposal queue. Memory may or may not be
        // released, which also depends on dataRefCount, see `releaseGPUData`.
        if (force) {
            this.texData.get(dataId).refCount = 0;
        }
        else {
            this.texData.get(dataId).refCount--;
        }
        if (!force && this.texData.get(dataId).refCount > 0) {
            return false;
        }
        // A read is in flight: defer; read() performs the disposal once the
        // values have been delivered.
        if (this.pendingRead.has(dataId)) {
            this.pendingDisposal.add(dataId);
            this.pendingDeletes++;
            return false;
        }
        this.releaseGPUData(dataId);
        var complexTensorInfos = this.texData.get(dataId).complexTensorInfos;
        if (complexTensorInfos != null) {
            this.disposeData(complexTensorInfos.real.dataId, force);
            this.disposeData(complexTensorInfos.imag.dataId, force);
        }
        this.texData.delete(dataId);
        return true;
    };
    /**
     * Drops this entry's hold on its texture. Shared textures (slices) are
     * counted in `dataRefCount` under the slice's origDataId. The texture is
     * returned to the texture manager, and numBytesInGPU reduced, only when
     * no other reference remains. The entry's texture fields are cleared
     * either way.
     */
    MathBackendWebGL.prototype.releaseGPUData = function (dataId) {
        var _b = this.texData.get(dataId), texture = _b.texture, dtype = _b.dtype, texShape = _b.texShape, usage = _b.usage, isPacked = _b.isPacked, slice = _b.slice;
        var key = slice && slice.origDataId || dataId;
        var refCount = this.dataRefCount.get(key);
        if (refCount > 1) {
            this.dataRefCount.set(key, refCount - 1);
        }
        else {
            this.dataRefCount.delete(key);
            if (texture != null) {
                this.numBytesInGPU -= this.computeBytes(texShape, dtype);
                this.textureManager.releaseTexture(texture, texShape, usage, isPacked);
            }
        }
        var texData = this.texData.get(dataId);
        texData.texture = null;
        texData.texShape = null;
        texData.isPacked = false;
        texData.slice = null;
    };
    /** Uploads `dataId` to the GPU if needed and returns its WebGL texture. */
    MathBackendWebGL.prototype.getTexture = function (dataId) {
        this.uploadToGPU(dataId);
        return this.texData.get(dataId).texture.texture;
    };
    /**
     * Returns internal information for
the specific data bucket. Used in unit * tests. */ MathBackendWebGL.prototype.getDataInfo = function (dataId) { return this.texData.get(dataId); }; /* Tests whether all the inputs to an op are small and on the CPU. This heuristic determines when it would be faster to execute a kernel on the CPU. WebGL kernels opt into running this check and forwarding when appropriate. TODO(https://github.com/tensorflow/tfjs/issues/872): Develop a more sustainable strategy for optimizing backend execution of ops. */ MathBackendWebGL.prototype.shouldExecuteOnCPU = function (inputs, sizeThreshold) { var _this = this; if (sizeThreshold === void 0) { sizeThreshold = CPU_HANDOFF_SIZE_THRESHOLD; } return tf.env().getBool('WEBGL_CPU_FORWARD') && inputs.every(function (input) { return _this.texData.get(input.dataId).texture == null && tf.util.sizeFromShape(input.shape) < sizeThreshold; }); }; MathBackendWebGL.prototype.getGPGPUContext = function () { return this.gpgpu; }; MathBackendWebGL.prototype.where = function (condition) { tf.backend_util.warn('tf.where() in webgl locks the UI thread. ' + 'Call tf.whereAsync() instead'); var condVals = condition.dataSync(); return whereImpl(condition.shape, condVals); }; MathBackendWebGL.prototype.packedUnaryOp = function (x, op, dtype) { var program = new UnaryOpPackedProgram(x.shape, op); var outInfo = this.compileAndRun(program, [x], dtype); return tf.engine().makeTensorFromTensorInfo(outInfo); }; // TODO(msoulanille) remove this once the backend has been modularized // a copy is needed here to break a circular dependency. // Also remove the op from unary_op. MathBackendWebGL.prototype.abs = function (x) { // TODO: handle cases when x is complex. 
if (this.shouldExecuteOnCPU([x]) && x.dtype !== 'complex64') { var outValues = simpleAbsImplCPU(this.texData.get(x.dataId).values); return this.makeOutput(x.shape, x.dtype, outValues); } if (tf.env().getBool('WEBGL_PACK_UNARY_OPERATIONS')) { return this.packedUnaryOp(x, ABS$1, x.dtype); } var program = new UnaryOpProgram(x.shape, ABS$1); var outInfo = this.compileAndRun(program, [x]); return tf.engine().makeTensorFromTensorInfo(outInfo); }; MathBackendWebGL.prototype.makeTensorInfo = function (shape, dtype, values) { var dataId; if (dtype === 'string' && values != null && values.length > 0 && tf.util.isString(values[0])) { var encodedValues = values.map(function (d) { return tf.util.encodeString(d); }); dataId = this.write(encodedValues, shape, dtype); } else { dataId = this.write(values, shape, dtype); } this.texData.get(dataId).usage = null; return { dataId: dataId, shape: shape, dtype: dtype }; }; MathBackendWebGL.prototype.makeOutput = function (shape, dtype, values) { return tf.engine().makeTensorFromTensorInfo(this.makeTensorInfo(shape, dtype, values), this); }; MathBackendWebGL.prototype.unpackTensor = function (input) { var program = new UnpackProgram(input.shape); return this.runWebGLProgram(program, [input], input.dtype); }; MathBackendWebGL.prototype.packTensor = function (input) { var program = new PackProgram(input.shape); var preventEagerUnpackingOutput = true; return this.runWebGLProgram(program, [input], input.dtype, null /* customUniformValues */, preventEagerUnpackingOutput); }; MathBackendWebGL.prototype.packedReshape = function (input, afterShape) { var input3DShape = __spreadArray([ getBatchDim(input.shape) ], __read(getRowsCols(input.shape)), false); var input3D = { dtype: input.dtype, shape: input3DShape, dataId: input.dataId }; var afterShapeAs3D = __spreadArray([ getBatchDim(afterShape) ], __read(getRowsCols(afterShape)), false); var program = new ReshapePackedProgram(afterShapeAs3D, input3DShape); var preventEagerUnpackingOfOutput = true; 
var customValues = [input3DShape]; var output = this.runWebGLProgram(program, [input3D], input.dtype, customValues, preventEagerUnpackingOfOutput); return { dataId: output.dataId, shape: afterShape, dtype: output.dtype }; }; MathBackendWebGL.prototype.decode = function (dataId, customTexShape) { var texData = this.texData.get(dataId); var isPacked = texData.isPacked, shape = texData.shape, dtype = texData.dtype; if (customTexShape != null) { var size = tf.util.sizeFromShape(shape); var texSize = customTexShape[0] * customTexShape[1] * 4; tf.util.assert(size <= texSize, function () { return 'customTexShape is too small. ' + 'Row * Column * 4 should be equal or larger than the ' + 'size of the tensor data.'; }); } var shapeAs3D = getShapeAs3D(shape); var program; if (isPacked) { program = new DecodeMatrixPackedProgram(shapeAs3D); } else { program = new DecodeMatrixProgram(shapeAs3D); } var preventEagerUnpackingOfOutput = true; var customValues = [customTexShape != null ? customTexShape : getDenseTexShape(shapeAs3D)]; var out = this.runWebGLProgram(program, [{ shape: shapeAs3D, dtype: dtype, dataId: dataId }], dtype, customValues, preventEagerUnpackingOfOutput, customTexShape); return { dtype: dtype, shape: shape, dataId: out.dataId }; }; MathBackendWebGL.prototype.runWebGLProgram = function (program, inputs, outputDtype, customUniformValues, preventEagerUnpackingOfOutput, customTexShape) { var _this = this; if (preventEagerUnpackingOfOutput === void 0) { preventEagerUnpackingOfOutput = false; } var output = this.makeTensorInfo(program.outputShape, outputDtype); var outData = this.texData.get(output.dataId); if (program.packedOutput) { outData.isPacked = true; } if (program.outPackingScheme === PackingScheme.DENSE) { var texelShape = customTexShape != null ? 
customTexShape : getDenseTexShape(program.outputShape); // For a densely packed output, we explicitly set texShape // so it doesn't get assigned later according to our typical packing // scheme wherein a single texel can only contain values from adjacent // rows/cols. outData.texShape = texelShape.map(function (d) { return d * 2; }); } if (program.outTexUsage != null) { outData.usage = program.outTexUsage; } if (tf.util.sizeFromShape(output.shape) === 0) { // Short-circuit the computation since the result is empty (has 0 in its // shape). outData.values = tf.util.getTypedArrayFromDType(output.dtype, 0); return output; } var dataToDispose = []; var inputsData = inputs.map(function (input) { if (input.dtype === 'complex64') { throw new Error("GPGPUProgram does not support complex64 input. For complex64 " + "dtypes, please separate the program into real and imaginary " + "parts."); } var texData = _this.texData.get(input.dataId); if (texData.texture == null) { if (!program.packedInputs && tf.util.sizeFromShape(input.shape) <= tf.env().getNumber('WEBGL_SIZE_UPLOAD_UNIFORM')) { // Upload small tensors that live on the CPU as uniforms, not as // textures. Do this only when the environment supports 32bit floats // due to problems when comparing 16bit floats with 32bit floats. // TODO(https://github.com/tensorflow/tfjs/issues/821): Make it // possible for packed shaders to sample from uniforms. return { shape: input.shape, texData: null, isUniform: true, uniformValues: texData.values }; } // This ensures that if a packed program's inputs have not yet been // uploaded to the GPU, they get uploaded as packed right off the bat. if (program.packedInputs) { texData.isPacked = true; texData.shape = input.shape; } } _this.uploadToGPU(input.dataId); if (!!texData.isPacked !== !!program.packedInputs) { input = texData.isPacked ? 
_this.unpackTensor(input) : _this.packTensor(input); dataToDispose.push(input); texData = _this.texData.get(input.dataId); } else if (texData.isPacked && !isReshapeFree(texData.shape, input.shape)) { // This is a special case where a texture exists for a tensor // but the shapes are incompatible (due to packing constraints) because // the tensor did not have a chance to go through the packed reshape // shader. This only happens when we reshape the *same* tensor to form // *distinct* inputs to an op, e.g. dotting a vector with itself. This // case will disappear once packed uploading is the default. var savedInput = input; var targetShape = input.shape; input.shape = texData.shape; input = _this.packedReshape(input, targetShape); dataToDispose.push(input); texData = _this.texData.get(input.dataId); savedInput.shape = targetShape; } return { shape: input.shape, texData: texData, isUniform: false }; }); this.uploadToGPU(output.dataId); var outputData = { shape: output.shape, texData: outData, isUniform: false }; var key = makeShaderKey(program, inputsData, outputData); var binary = this.getAndSaveBinary(key, function () { return compileProgram(_this.gpgpu, program, inputsData, outputData); }); var shouldTimeProgram = this.activeTimers != null; var query; if (shouldTimeProgram) { query = this.startTimer(); } if (!tf.env().get('ENGINE_COMPILE_ONLY')) { runProgram(this.gpgpu, binary, inputsData, outputData, customUniformValues); } dataToDispose.forEach(function (info) { return _this.disposeIntermediateTensorInfo(info); }); if (shouldTimeProgram) { query = this.endTimer(query); this.activeTimers.push({ name: program.constructor.name, query: this.getQueryTime(query) }); } var glFlushThreshold = tf.env().getNumber('WEBGL_FLUSH_THRESHOLD'); // Manually GL flush requested if (glFlushThreshold > 0) { var time = tf.util.now(); if ((time - this.lastGlFlushTime) > glFlushThreshold) { this.gpgpu.gl.flush(); this.lastGlFlushTime = time; } } if 
(!tf.env().getBool('WEBGL_LAZILY_UNPACK') && outData.isPacked && preventEagerUnpackingOfOutput === false) { var unpacked = this.unpackTensor(output); this.disposeIntermediateTensorInfo(output); return unpacked; } return output; }; MathBackendWebGL.prototype.compileAndRun = function (program, inputs, outputDtype, customUniformValues, preventEagerUnpackingOfOutput) { if (preventEagerUnpackingOfOutput === void 0) { preventEagerUnpackingOfOutput = false; } outputDtype = outputDtype || inputs[0].dtype; var outInfo = this.runWebGLProgram(program, inputs, outputDtype, customUniformValues, preventEagerUnpackingOfOutput); return outInfo; }; MathBackendWebGL.prototype.getAndSaveBinary = function (key, getBinary) { if (!(key in this.binaryCache)) { this.binaryCache[key] = getBinary(); } return this.binaryCache[key]; }; MathBackendWebGL.prototype.getTextureManager = function () { return this.textureManager; }; MathBackendWebGL.prototype.dispose = function () { var _this = this; if (this.disposed) { return; } // Avoid disposing the compiled webgl programs during unit testing because // it slows down test execution. if (!tf.env().getBool('IS_TEST')) { var allKeys = Object.keys(this.binaryCache); allKeys.forEach(function (key) { _this.gpgpu.deleteProgram(_this.binaryCache[key].webGLProgram); delete _this.binaryCache[key]; }); } this.textureManager.dispose(); if (this.canvas != null && (typeof (HTMLCanvasElement) !== 'undefined' && this.canvas instanceof HTMLCanvasElement)) { this.canvas.remove(); } else { this.canvas = null; } if (this.gpgpuCreatedLocally) { this.gpgpu.program = null; this.gpgpu.dispose(); } this.disposed = true; }; MathBackendWebGL.prototype.floatPrecision = function () { var _this = this; if (this.floatPrecisionValue == null) { this.floatPrecisionValue = tf.tidy(function () { if (!tf.env().get('WEBGL_RENDER_FLOAT32_ENABLED')) { // Momentarily switching DEBUG flag to false so we don't throw an // error trying to upload a small value. 
var debugFlag = tf.env().getBool('DEBUG'); tf.env().set('DEBUG', false); var underflowCheckValue = _this.abs(tf.scalar(1e-8)).dataSync()[0]; tf.env().set('DEBUG', debugFlag); if (underflowCheckValue > 0) { return 32; } } return 16; }); } return this.floatPrecisionValue; }; /** Returns the smallest representable number. */ MathBackendWebGL.prototype.epsilon = function () { return this.floatPrecision() === 32 ? EPSILON_FLOAT32 : EPSILON_FLOAT16; }; MathBackendWebGL.prototype.uploadToGPU = function (dataId) { var _b; var texData = this.texData.get(dataId); var shape = texData.shape, dtype = texData.dtype, values = texData.values, texture = texData.texture, usage = texData.usage, isPacked = texData.isPacked; if (texture != null) { // Array is already on GPU. No-op. return; } var shouldTimeProgram = this.activeTimers != null; var start; if (shouldTimeProgram) { start = tf.util.now(); } var texShape = texData.texShape; if (texShape == null) { // This texShape may not be the final texture shape. For packed or dense // textures, the texShape will be changed when textures are created. texShape = getTextureShapeFromLogicalShape(shape, isPacked); texData.texShape = texShape; } if (values != null) { var shapeAs3D = getShapeAs3D(shape); var program = void 0; var width = texShape[1], height = texShape[0]; var isByteArray = values instanceof Uint8Array || values instanceof Uint8ClampedArray; // texture for float array is PhysicalTextureType.PACKED_2X2_FLOAT32, we // need to make sure the upload uses the same packed size if (isPacked || !isByteArray) { _b = __read(getPackedMatrixTextureShapeWidthHeight(texShape[0], texShape[1]), 2), width = _b[0], height = _b[1]; } if (isPacked) { program = new EncodeMatrixPackedProgram(shapeAs3D, isByteArray); } else { program = new EncodeMatrixProgram(shapeAs3D, isByteArray); } // TexShape for float array needs to be the original shape, which byte // array needs to be packed size. 
This allow the data upload shape to be // matched with texture creation logic. var tempDenseInputTexShape = isByteArray ? [height, width] : texShape; var tempDenseInputHandle = this.makeTensorInfo(tempDenseInputTexShape, dtype); var tempDenseInputTexData = this.texData.get(tempDenseInputHandle.dataId); if (isByteArray) { tempDenseInputTexData.usage = TextureUsage.PIXELS; } else { tempDenseInputTexData.usage = TextureUsage.UPLOAD; } tempDenseInputTexData.texShape = tempDenseInputTexShape; this.gpgpu.uploadDenseMatrixToTexture(this.getTexture(tempDenseInputHandle.dataId), width, height, values); var customValues = [[height, width]]; // We want the output to remain packed regardless of the value of // WEBGL_PACK. var preventEagerUnpacking = true; var encodedOutputTarget = this.runWebGLProgram(program, [tempDenseInputHandle], dtype, customValues, preventEagerUnpacking); // Have the original texture assume the identity of the encoded output. var outputTexData = this.texData.get(encodedOutputTarget.dataId); texData.texShape = outputTexData.texShape; texData.isPacked = outputTexData.isPacked; texData.usage = outputTexData.usage; if (!tf.env().get('ENGINE_COMPILE_ONLY')) { texData.texture = outputTexData.texture; // Once uploaded, don't store the values on cpu. 
texData.values = null; this.texData.delete(encodedOutputTarget.dataId); } else { this.disposeData(encodedOutputTarget.dataId); } this.disposeIntermediateTensorInfo(tempDenseInputHandle); if (shouldTimeProgram) { this.uploadWaitMs += tf.util.now() - start; } } else { var newTexture = this.acquireTexture(texShape, usage, dtype, isPacked); texData.texture = newTexture; } }; MathBackendWebGL.prototype.convertAndCacheOnCPU = function (dataId, float32Values) { var texData = this.texData.get(dataId); var dtype = texData.dtype; if (float32Values != null) { texData.values = float32ToTypedArray(float32Values, dtype); } return texData.values; }; MathBackendWebGL.prototype.acquireTexture = function (texShape, texType, dtype, isPacked) { this.numBytesInGPU += this.computeBytes(texShape, dtype); if (!this.warnedAboutMemory && this.numBytesInGPU > this.numMBBeforeWarning * 1024 * 1024) { var mb = (this.numBytesInGPU / 1024 / 1024).toFixed(2); this.warnedAboutMemory = true; console.warn("High memory usage in GPU: ".concat(mb, " MB, ") + "most likely due to a memory leak"); } return this.textureManager.acquireTexture(texShape, texType, isPacked); }; MathBackendWebGL.prototype.computeBytes = function (shape, dtype) { return shape[0] * shape[1] * tf.util.bytesPerElement(dtype); }; MathBackendWebGL.prototype.checkCompileCompletion = function () { var e_1, _b; try { for (var _c = __values(Object.entries(this.binaryCache)), _d = _c.next(); !_d.done; _d = _c.next()) { var _e = __read(_d.value, 2), binary = _e[1]; this.checkCompletion_(binary); } } catch (e_1_1) { e_1 = { error: e_1_1 }; } finally { try { if (_d && !_d.done && (_b = _c.return)) _b.call(_c); } finally { if (e_1) throw e_1.error; } } }; MathBackendWebGL.prototype.checkCompileCompletionAsync = function () { return __awaiter(this, void 0, void 0, function () { var ps, _b, _c, _d, binary, _loop_1, _e, _f, _g, binary; var e_2, _h, e_3, _j; var _this = this; return __generator(this, function (_k) { ps = []; if 
(this.gpgpu.parallelCompilationExtension) {
                    try {
                        for (_b = __values(Object.entries(this.binaryCache)), _c = _b.next(); !_c.done; _c = _b.next()) {
                            _d = __read(_c.value, 2), binary = _d[1];
                            ps.push(this.checkCompletionAsync_(binary));
                        }
                    }
                    catch (e_2_1) { e_2 = { error: e_2_1 }; }
                    finally {
                        try {
                            if (_c && !_c.done && (_h = _b.return)) _h.call(_b);
                        }
                        finally { if (e_2) throw e_2.error; }
                    }
                    return [2 /*return*/, Promise.all(ps)];
                }
                else {
                    // Without the extension, each binary is checked synchronously
                    // inside a promise executor; an error thrown there rejects that
                    // promise.
                    _loop_1 = function (binary) {
                        var p = new Promise(function (resolve) {
                            try {
                                _this.checkCompletion_(binary);
                                resolve(true);
                            }
                            catch (error) {
                                throw error;
                            }
                        });
                        ps.push(p);
                    };
                    try {
                        for (_e = __values(Object.entries(this.binaryCache)), _f = _e.next(); !_f.done; _f = _e.next()) {
                            _g = __read(_f.value, 2), binary = _g[1];
                            _loop_1(binary);
                        }
                    }
                    catch (e_3_1) { e_3 = { error: e_3_1 }; }
                    finally {
                        try {
                            if (_f && !_f.done && (_j = _e.return)) _j.call(_e);
                        }
                        finally { if (e_3) throw e_3.error; }
                    }
                    return [2 /*return*/, Promise.all(ps)];
                }
            });
        });
    };
    /**
     * Polls COMPLETION_STATUS_KHR once per frame (tf.nextFrame) until the
     * program reports completion, then validates it with checkCompletion_.
     */
    MathBackendWebGL.prototype.checkCompletionAsync_ = function (binary) {
        return __awaiter(this, void 0, void 0, function () {
            return __generator(this, function (_b) {
                switch (_b.label) {
                    case 0:
                        if (!this.gpgpu.gl.getProgramParameter(binary.webGLProgram, this.gpgpu.parallelCompilationExtension.COMPLETION_STATUS_KHR)) return [3 /*break*/, 1];
                        return [2 /*return*/, this.checkCompletion_(binary)];
                    case 1: return [4 /*yield*/, tf.nextFrame()];
                    case 2:
                        _b.sent();
                        return [2 /*return*/, this.checkCompletionAsync_(binary)];
                }
            });
        });
    };
    /**
     * Returns true if the program linked. Otherwise logs the program info log
     * (and the fragment shader source/log on a compile failure) and throws.
     */
    MathBackendWebGL.prototype.checkCompletion_ = function (binary) {
        if (this.gpgpu.gl.getProgramParameter(binary.webGLProgram, this.gpgpu.gl.LINK_STATUS) === false) {
            console.log(this.gpgpu.gl.getProgramInfoLog(binary.webGLProgram));
            if (this.gpgpu.gl.getShaderParameter(binary.fragmentShader, this.gpgpu.gl.COMPILE_STATUS) === false) {
                logShaderSourceAndInfoLog(binary.source, this.gpgpu.gl.getShaderInfoLog(binary.fragmentShader));
                throw new Error('Failed to compile fragment shader.');
            }
            throw new Error('Failed to link vertex and fragment shaders.');
        }
        return true;
    };
    /**
     * For every cached binary, builds its VAO and stores the uniform
     * locations returned by getUniformLocations on the binary.
     */
    MathBackendWebGL.prototype.getUniformLocations = function () {
        var e_4, _b;
        try {
            for (var _c = __values(Object.values(this.binaryCache)), _d = _c.next(); !_d.done; _d = _c.next()) {
                var binary = _d.value;
                // TODO: Iterating through all binaries to build VAOs is supposed to be in
                // a separate function, like 'setVaos'. However, to avoid breaking changes
                // for the users using parallel compile feature now, buildVao is silently
                // added here.
                this.gpgpu.buildVao(binary.webGLProgram);
                var _e = getUniformLocations(this.gpgpu, binary.program, binary.webGLProgram), variablesLocations = _e.variablesLocations, customUniformLocations = _e.customUniformLocations, infLoc = _e.infLoc, nanLoc = _e.nanLoc, outShapeLocation = _e.outShapeLocation, outShapeStridesLocation = _e.outShapeStridesLocation, outTexShapeLocation = _e.outTexShapeLocation;
                binary.variablesLocations = variablesLocations;
                binary.customUniformLocations = customUniformLocations;
                binary.infLoc = infLoc;
                binary.nanLoc = nanLoc;
                binary.outShapeLocation = outShapeLocation;
                binary.outShapeStridesLocation = outShapeStridesLocation;
                binary.outTexShapeLocation = outTexShapeLocation;
            }
        }
        catch (e_4_1) { e_4 = { error: e_4_1 }; }
        finally {
            try {
                if (_d && !_d.done && (_b = _c.return)) _b.call(_c);
            }
            finally { if (e_4) throw e_4.error; }
        }
    };
    /**
     * Create a TF.js tensor out of an existing WebGL texture. A new texture will
     * be created.
     * @param values Texture descriptor {texture, height, width, channels};
     *     `channels` defaults to 'RGBA'.
     * @throws Error if `values.texture` is not a valid texture for the active
     *     backend's WebGL context.
     * NOTE(review): the defaulted `channels` is written back onto the caller's
     * `values` object, and the engine's active backend is used rather than
     * `this` -- confirm both are intended.
     */
    MathBackendWebGL.prototype.createTensorFromGPUData = function (values, shape, dtype) {
        values.channels = values.channels || 'RGBA';
        var texture = values.texture, height = values.height, width = values.width, channels = values.channels;
        var backend = tf.engine().backend;
        // Have to throw an error, otherwise WebGL just warns and returns wrong
        // values.
        if (!backend.gpgpu.gl.isTexture(texture)) {
            throw new Error("The texture is invalid. Also, please make sure the texture and " +
                "the TFJS WebGL backend are using the same canvas. If you want to " +
                "use your own custom canvas, you have to create and use the custom " +
                "TFJS WebGL backend created from the canvas through " +
                "'new tf.MathBackendWebGL(customCanvas)'.");
        }
        var dataId = backend.writeTexture(texture, shape, dtype, height, width, channels);
        return tf.engine().makeTensorFromDataId(dataId, shape, dtype, backend);
    };
    return MathBackendWebGL;
}(tf.KernelBackend));
MathBackendWebGL.nextDataId = 0;
/**
 * Converts downloaded float32 values to the typed array for `dtype`:
 * float32/complex64 are returned as-is; int32/bool are rounded into a new
 * Int32Array/Uint8Array. Throws for any other dtype.
 */
function float32ToTypedArray(a, dtype) {
    if (dtype === 'float32' || dtype === 'complex64') {
        return a;
    }
    else if (dtype === 'int32' || dtype === 'bool') {
        var result = (dtype === 'int32') ? new Int32Array(a.length) :
            new Uint8Array(a.length);
        for (var i = 0; i < result.length; ++i) {
            result[i] = Math.round(a[i]);
        }
        return result;
    }
    else {
        throw new Error("Unknown dtype ".concat(dtype));
    }
}

/** @license See the LICENSE file. */
// This code is auto-generated, do not modify this file!
var version = '4.15.0';

/**
 * @license
 * Copyright 2019 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
/**
 * Enforce use of half precision textures if available on the platform.
 *
 * @doc {heading: 'Environment', namespace: 'webgl'}
 */
function forceHalfFloat() {
    tf.env().set('WEBGL_FORCE_F16_TEXTURES', true);
}

/**
 * @license
 * Copyright 2020 Google Inc.
All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ if (tf.device_util.isBrowser()) { tf.registerBackend('webgl', function () { return new MathBackendWebGL(); }, 2 /* priority */); } var webgl = { forceHalfFloat: forceHalfFloat }; /** * @license * Copyright 2017 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* =============================================================================
 */
// GLSL snippet for scalar binary ops: returns whichever operand is NaN, checking
// `a` before `b`. It is not referenced in this block; presumably it is spliced
// into op snippets elsewhere in the file.
var CHECK_NAN_SNIPPET = "\n if (isnan(a)) return a;\n if (isnan(b)) return b;\n";
/**
 * Elementwise binary op over unpacked `float` values.
 *
 * @param op GLSL statements used as the body of
 *     `float binaryOperation(float a, float b)`.
 * @param aShape Shape of input `A`.
 * @param bShape Shape of input `B`.
 *
 * The output shape comes from tf.backend_util.assertAndGetBroadcastShape on the
 * two input shapes (presumably it throws when they do not broadcast — confirm).
 */
var BinaryOpProgram = /** @class */ (function () {
    function BinaryOpProgram(op, aShape, bShape) {
        this.variableNames = ['A', 'B'];
        this.outputShape = tf.backend_util.assertAndGetBroadcastShape(aShape, bShape);
        // useShapeUniforms is defined elsewhere in the file and decides, from
        // the output rank, whether shape uniforms are enabled.
        this.enableShapeUniforms = useShapeUniforms(this.outputShape.length);
        this.userCode = "\n float binaryOperation(float a, float b) {\n ".concat(op, "\n }\n\n void main() {\n float a = getAAtOutCoords();\n float b = getBAtOutCoords();\n setOutput(binaryOperation(a, b));\n }\n ");
    }
    return BinaryOpProgram;
}());
/**
 * @license
 * Copyright 2018 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
// Packed (vec4) counterpart of CHECK_NAN_SNIPPET: overwrites each lane of
// `result` with NAN where the matching lane of `isNaN` is set. Assumes the code
// it is spliced into declares `vec4 result` and `bvec4 isNaN`.
var CHECK_NAN_SNIPPET_PACKED = "\n result.r = isNaN.r ? NAN : result.r;\n result.g = isNaN.g ? NAN : result.g;\n result.b = isNaN.b ? NAN : result.b;\n result.a = isNaN.a ? NAN : result.a;\n";
/**
 * Elementwise binary op over packed inputs and outputs (`vec4` per texel).
 *
 * @param op GLSL statements used as the body of
 *     `vec4 binaryOperation(vec4 a, vec4 b)`.
 * @param aShape Shape of input `A`.
 * @param bShape Shape of input `B`.
 * @param checkOutOfBounds Defaults to false. When true, zeroes the components
 *     of each output vec4 that lie past the end of the output:
 *     - scalar / size-1 output: y, z and w are zeroed;
 *     - rank 1: y is zeroed when coords + 1 reaches the length, z and w always;
 *     - rank >= 2: y when the next column is out of bounds, z when the next row
 *       is, and w when either is.
 */
var BinaryOpPackedProgram = /** @class */ (function () {
    function BinaryOpPackedProgram(op, aShape, bShape, checkOutOfBounds) {
        if (checkOutOfBounds === void 0) { checkOutOfBounds = false; }
        this.variableNames = ['A', 'B'];
        this.supportsBroadcasting = true;
        this.packedInputs = true;
        this.packedOutput = true;
        this.outputShape = tf.backend_util.assertAndGetBroadcastShape(aShape, bShape);
        var rank = this.outputShape.length;
        this.enableShapeUniforms = useShapeUniforms(rank);
        var checkOutOfBoundsString = '';
        if (checkOutOfBounds) {
            if (rank === 0 || tf.util.sizeFromShape(this.outputShape) === 1) {
                checkOutOfBoundsString = "\n result.y = 0.;\n result.z = 0.;\n result.w = 0.;\n ";
            }
            else {
                var dtype = getCoordsDataType(rank);
                checkOutOfBoundsString = "\n ".concat(dtype, " coords = getOutputCoords();\n ");
                if (rank === 1) {
                    // With shape uniforms the bound is read from `outShape` at run
                    // time; otherwise the length is baked into the shader text.
                    if (this.enableShapeUniforms) {
                        checkOutOfBoundsString += "\n result.y = (coords + 1) >= outShape ? 0. : result.y;\n result.z = 0.;\n result.w = 0.;\n ";
                    }
                    else {
                        checkOutOfBoundsString += "\n result.y = (coords + 1) >= ".concat(this.outputShape[0], " ? 0. : result.y;\n result.z = 0.;\n result.w = 0.;\n ");
                    }
                }
                else {
                    // Only the last two coordinates (row, column) are tested.
                    var channels = getChannels('coords', rank);
                    if (this.enableShapeUniforms) {
                        checkOutOfBoundsString += "\n bool nextRowOutOfBounds =\n (".concat(channels[rank - 2], " + 1) >= outShape[").concat(rank, " - 2];\n bool nextColOutOfBounds =\n (").concat(channels[rank - 1], " + 1) >= outShape[").concat(rank, " - 1];\n result.y = nextColOutOfBounds ? 0. : result.y;\n result.z = nextRowOutOfBounds ? 0. : result.z;\n result.w = nextColOutOfBounds || nextRowOutOfBounds ? 0. : result.w;\n ");
                    }
                    else {
                        checkOutOfBoundsString += "\n bool nextRowOutOfBounds =\n (".concat(channels[rank - 2], " + 1) >= ").concat(this.outputShape[rank - 2], ";\n bool nextColOutOfBounds =\n (").concat(channels[rank - 1], " + 1) >= ").concat(this.outputShape[rank - 1], ";\n result.y = nextColOutOfBounds ? 0. : result.y;\n result.z = nextRowOutOfBounds ? 0. : result.z;\n result.w = nextColOutOfBounds || nextRowOutOfBounds ? 0. : result.w;\n ");
                    }
                }
            }
        }
        this.userCode = "\n vec4 binaryOperation(vec4 a, vec4 b) {\n ".concat(op, "\n }\n\n void main() {\n vec4 a = getAAtOutCoords();\n vec4 b = getBAtOutCoords();\n\n vec4 result = binaryOperation(a, b);\n ").concat(checkOutOfBoundsString, "\n\n setOutput(result);\n }\n ");
    }
    return BinaryOpPackedProgram;
}());
/**
 * @license
 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
/**
 * Identity kernel: copies no data. It returns a TensorInfo that reuses
 * `x.dataId` and bumps that data's refCount through backend.incRef.
 */
function identity(args) {
    var inputs = args.inputs, backend = args.backend;
    var x = inputs.x;
    backend.incRef(x.dataId);
    return { dataId: x.dataId, shape: x.shape, dtype: x.dtype };
}
var identityConfig = {
    kernelName: tf.Identity,
    backendName: 'webgl',
    kernelFunc: identity
};
/**
 * @license
 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
* ============================================================================= */ /** * In WebGL data is stored in GPU textures which can't be efficiently copied, so * complex tensors share data with their real and imaginary components. Complex * tensors' reference to the components is tracked by refCount on the individual * component. The refCounts are increased by the identity call. * * When a complex tensor is disposed, it will reduce the refCount on the * components by calling disposeData on each. */ function complex(args) { var inputs = args.inputs, backend = args.backend; var real = inputs.real, imag = inputs.imag; var complexInfo = backend.makeTensorInfo(real.shape, 'complex64'); var complex = backend.texData.get(complexInfo.dataId); var realTensorInfo = identity({ inputs: { x: real }, backend: backend }); var imagTensorInfo = identity({ inputs: { x: imag }, backend: backend }); complex.complexTensorInfos = { real: realTensorInfo, imag: imagTensorInfo }; return complexInfo; } var complexConfig = { kernelName: tf.Complex, backendName: 'webgl', kernelFunc: complex }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ var LEAKYRELU = "return (a < 0.) ? 
b * a : a;"; var LEAKYRELU_PACKED = "\n vec4 aLessThanZero = vec4(lessThan(a, vec4(0.)));\n return (aLessThanZero * (b * a)) + ((vec4(1.0) - aLessThanZero) * a);\n"; function leakyRelu(args) { var inputs = args.inputs, backend = args.backend, attrs = args.attrs; var x = inputs.x; var alpha = attrs.alpha; var $alpha = backend.makeTensorInfo([], 'float32', tf.util.createScalarValue(alpha, 'float32')); var program = tf.env().getBool('WEBGL_PACK_BINARY_OPERATIONS') ? new BinaryOpPackedProgram(LEAKYRELU_PACKED, x.shape, $alpha.shape) : new BinaryOpProgram(LEAKYRELU, x.shape, $alpha.shape); var result = backend.runWebGLProgram(program, [x, $alpha], 'float32'); backend.disposeIntermediateTensorInfo($alpha); return result; } var leakyReluConfig = { kernelName: tf.LeakyRelu, backendName: 'webgl', kernelFunc: leakyRelu }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ var PRELU = "return (a < 0.) ? b * a : a;"; var PRELU_PACKED = "\n vec4 aLessThanZero = vec4(lessThan(a, vec4(0.)));\n return (aLessThanZero * (b * a)) + ((vec4(1.0) - aLessThanZero) * a);\n"; function prelu(args) { var inputs = args.inputs, backend = args.backend; var x = inputs.x, alpha = inputs.alpha; var program = tf.env().getBool('WEBGL_PACK_BINARY_OPERATIONS') ? 
new BinaryOpPackedProgram(PRELU_PACKED, x.shape, alpha.shape) : new BinaryOpProgram(PRELU, x.shape, alpha.shape); return backend.runWebGLProgram(program, [x, alpha], 'float32'); } var preluConfig = { kernelName: tf.Prelu, backendName: 'webgl', kernelFunc: prelu }; var CHECK_NAN_SNIPPET_UNARY = "if (isnan(x)) return x;"; /** * Template that creates a `KernelFunc` for unary ops. * @param opSnippet Op snippet to create `UnaryOpProgram`. * @param packedOpSnippet Op snippet to create `UnaryOpPackedProgram`. * @param dtype Optional. If set, the result has this dtype. Otherwise, the * result has the same dtype as the first input. This is mainly used in * comparison kernels, such as Equal, Less, Greater, etc. */ function unaryKernelFunc(_a) { var opSnippet = _a.opSnippet, packedOpSnippet = _a.packedOpSnippet, cpuKernelImpl = _a.cpuKernelImpl, dtype = _a.dtype; return function (_a) { var inputs = _a.inputs, backend = _a.backend; var x = inputs.x; var webglBackend = backend; var $dtype = dtype || x.dtype; if (webglBackend.shouldExecuteOnCPU([x]) && cpuKernelImpl != null) { var xData = webglBackend.texData.get(x.dataId); var outValues = cpuKernelImpl(xData.values, $dtype); return webglBackend.makeTensorInfo(x.shape, $dtype, outValues); } var shouldUsePackedProgram = tf.env().getBool('WEBGL_PACK_UNARY_OPERATIONS') && packedOpSnippet != null; var program; if (shouldUsePackedProgram) { program = new UnaryOpPackedProgram(x.shape, packedOpSnippet); } else { program = new UnaryOpProgram(x.shape, opSnippet); } return webglBackend.runWebGLProgram(program, [x], $dtype); }; } /** * Template that creates a `KernelFunc` for binary ops. * @param opSnippet Op snippet to create `BinaryOpProgram`. * @param packedOpSnippet Op snippet to create `BinaryOpPackedProgram`. * @param checkOutOfBoundsForPackedProgram Whether to set checkOutOfBounds=true * when creating BinaryOpPackedProgram. * @param dtype Optional. If set, the result has this dtype. 
Otherwise, the * result has the same dtype as the first input. This is mainly used in * comparison kernels, such as Equal, Less, Greater, etc. */ function binaryKernelFunc(_a) { var opSnippet = _a.opSnippet, packedOpSnippet = _a.packedOpSnippet, _b = _a.checkOutOfBounds, checkOutOfBounds = _b === void 0 ? false : _b, _c = _a.supportsComplex, supportsComplex = _c === void 0 ? false : _c, cpuKernelImpl = _a.cpuKernelImpl, dtype = _a.dtype; return function (_a) { var inputs = _a.inputs, backend = _a.backend; var a = inputs.a, b = inputs.b; var webglBackend = backend; if (supportsComplex && a.dtype === 'complex64') { var aData = webglBackend.texData.get(a.dataId); var bData = webglBackend.texData.get(b.dataId); var _b = __read([ [aData.complexTensorInfos.real, bData.complexTensorInfos.real], [aData.complexTensorInfos.imag, bData.complexTensorInfos.imag] ].map(function (complexParts) { var _a = __read(complexParts, 2), aPart = _a[0], bPart = _a[1]; var aHandle = { dataId: aPart.dataId, dtype: aPart.dtype, shape: a.shape }; var bHandle = { dataId: bPart.dataId, dtype: bPart.dtype, shape: b.shape }; var program = new BinaryOpProgram(opSnippet, a.shape, b.shape); return webglBackend.runWebGLProgram(program, [aHandle, bHandle], tf.upcastType(aPart.dtype, bPart.dtype)); }), 2), real = _b[0], imag = _b[1]; var complexOutput = complex({ inputs: { real: real, imag: imag }, backend: webglBackend }); webglBackend.disposeIntermediateTensorInfo(real); webglBackend.disposeIntermediateTensorInfo(imag); // TODO(annxingyuan): Implement CPU forwarding for complex inputs. return complexOutput; } var $dtype = dtype || tf.upcastType(a.dtype, b.dtype); if ((a.dtype === 'string' || b.dtype === 'string' || webglBackend.shouldExecuteOnCPU([a, b])) && cpuKernelImpl != null) { var aVals = webglBackend.texData.get(a.dataId).values; var bVals = webglBackend.texData.get(b.dataId).values; var decodedAVals = a.dtype === 'string' ? 
// tslint:disable-next-line: no-any tf.backend_util.fromUint8ToStringArray(aVals) : aVals; var decodedBVals = a.dtype === 'string' ? // tslint:disable-next-line: no-any tf.backend_util.fromUint8ToStringArray(bVals) : bVals; var _c = __read(cpuKernelImpl(a.shape, b.shape, decodedAVals, decodedBVals, $dtype), 2), outValues = _c[0], outShape = _c[1]; var out = webglBackend.makeTensorInfo(outShape, $dtype); var outData = webglBackend.texData.get(out.dataId); outData.values = outValues; return out; } var shouldUsePackedProgram = tf.env().getBool('WEBGL_PACK_BINARY_OPERATIONS') && packedOpSnippet != null; var program; if (shouldUsePackedProgram) { program = new BinaryOpPackedProgram(packedOpSnippet, a.shape, b.shape, checkOutOfBounds); } else { program = new BinaryOpProgram(opSnippet, a.shape, b.shape); } return webglBackend.runWebGLProgram(program, [a, b], $dtype); }; } function mapActivationToShaderProgram(activation, packed) { if (packed === void 0) { packed = false; } if (activation === 'linear') { if (packed) { return LINEAR; } return LINEAR$1; } else if (activation === 'relu') { if (packed) { return RELU$1; } return RELU$2; } else if (activation === 'elu') { if (packed) { return ELU$1; } return ELU$2; } else if (activation === 'relu6') { if (packed) { return RELU6$1; } return RELU6$2; } else if (activation === 'prelu') { if (packed) { return PRELU_PACKED; } return PRELU; } else if (activation === 'leakyrelu') { if (packed) { return LEAKYRELU_PACKED; } return LEAKYRELU; } else if (activation === 'sigmoid') { if (packed) { return SIGMOID$1; } return SIGMOID$2; } throw new Error("Activation ".concat(activation, " has not been implemented for the WebGL backend.")); } /** * @license * Copyright 2018 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
/**
 * Batched matrix multiply over packed (vec4) inputs and output, with optional
 * fused bias add and activation.
 *
 * @param aShape, bShape Shapes of the two matrices. They are indexed as
 *     [0] = batch, [1], [2] — presumably [batch, rows, cols]; confirm at call
 *     sites.
 * @param outputShape Output shape (used as-is).
 * @param transposeA, transposeB Whether each operand is read transposed. This
 *     changes the sample coordinates and swizzles used in the inner loop.
 * @param addBias When true, adds the `bias` input to the result.
 * @param activation GLSL body of the fused activation, or null for none.
 * @param hasPreluActivation When true, the activation reads its `b` operand
 *     from the extra `preluActivationWeights` input.
 * @param hasLeakyreluActivation When true, the activation reads its `b`
 *     operand from the extra `leakyreluAlpha` input.
 *
 * If the batch sizes differ, the smaller batch is broadcast by wrapping its
 * index with imod(rc.x, batch).
 */
var MatMulPackedProgram = /** @class */ (function () {
    function MatMulPackedProgram(aShape, bShape, outputShape, transposeA, transposeB, addBias, activation, hasPreluActivation, hasLeakyreluActivation) {
        if (transposeA === void 0) { transposeA = false; }
        if (transposeB === void 0) { transposeB = false; }
        if (addBias === void 0) { addBias = false; }
        if (activation === void 0) { activation = null; }
        if (hasPreluActivation === void 0) { hasPreluActivation = false; }
        if (hasLeakyreluActivation === void 0) { hasLeakyreluActivation = false; }
        this.variableNames = ['matrixA', 'matrixB'];
        this.packedInputs = true;
        this.packedOutput = true;
        this.outputShape = outputShape;
        this.enableShapeUniforms = useShapeUniforms(this.outputShape.length);
        var sharedDim = transposeA ? aShape[1] : aShape[2];
        // The inner loop steps through the shared dimension 2 elements at a time
        // (it samples at `i * 2`), hence the ceil(sharedDim / 2) iterations.
        var sharedDimensionPacked = Math.ceil(sharedDim / 2);
        var aSample = transposeA ? 'i * 2, rc.y' : 'rc.y, i * 2';
        var bSample = transposeB ? 'rc.z, i * 2' : 'i * 2, rc.z';
        var aSwizzle = transposeA ? ['a.xxyy', 'a.zzww'] : ['a.xxzz', 'a.yyww'];
        var bSwizzle = transposeB ? ['b.xzxz', 'b.ywyw'] : ['b.xyxy', 'b.zwzw'];
        var activationSnippet = '', applyActivationSnippet = '';
        if (activation) {
            if (hasPreluActivation) {
                activationSnippet = "vec4 activation(vec4 a) {\n vec4 b = getPreluActivationWeightsAtOutCoords();\n ".concat(activation, "\n }");
            }
            else if (hasLeakyreluActivation) {
                activationSnippet = "vec4 activation(vec4 a) {\n vec4 b = getLeakyreluAlphaAtOutCoords();\n ".concat(activation, "\n }");
            }
            else {
                activationSnippet = "vec4 activation(vec4 x) {\n ".concat(activation, "\n }");
            }
            applyActivationSnippet = "result = activation(result);";
        }
        var addBiasSnippet = addBias ? 'result += getBiasAtOutCoords();' : '';
        // Extra inputs are appended in a fixed order: bias, then PReLU weights,
        // then LeakyReLU alpha.
        if (addBias) {
            this.variableNames.push('bias');
        }
        if (hasPreluActivation) {
            this.variableNames.push('preluActivationWeights');
        }
        if (hasLeakyreluActivation) {
            this.variableNames.push('leakyreluAlpha');
        }
        var batchASnippet = 'rc.x';
        var batchBSnippet = 'rc.x';
        if (aShape[0] < bShape[0]) {
            batchASnippet = "imod(rc.x, ".concat(aShape[0], ")");
        }
        else if (bShape[0] < aShape[0]) {
            batchBSnippet = "imod(rc.x, ".concat(bShape[0], ")");
        }
        this.userCode = "\n ".concat(activationSnippet, "\n // Don't use uniform for sharedDimensionPacked for performance.\n const float sharedDimension = ").concat(sharedDimensionPacked, ".0;\n\n vec4 dot2x2ARowBCol(ivec3 rc) {\n vec4 result = vec4(0);\n int batchA = ").concat(batchASnippet, ";\n int batchB = ").concat(batchBSnippet, ";\n for (int i = 0; i < ").concat(sharedDimensionPacked, "; i++) {\n vec4 a = getMatrixA(batchA, ").concat(aSample, ");\n vec4 b = getMatrixB(batchB, ").concat(bSample, ");\n\n // These swizzled products need to be separately added.\n // See: https://github.com/tensorflow/tfjs/issues/1735\n result += (").concat(aSwizzle[0], " * ").concat(bSwizzle[0], ");\n result += (").concat(aSwizzle[1], " * ").concat(bSwizzle[1], ");\n }\n return result;\n }\n\n void main() {\n ivec3 rc = getOutputCoords();\n vec4 result = dot2x2ARowBCol(rc);\n\n ").concat(addBiasSnippet, "\n\n ").concat(applyActivationSnippet, "\n\n setOutput(result);\n }\n ");
    }
    return MatMulPackedProgram;
}());
/**
 * @license
 * Copyright 2018 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
// Complex multiplication with A = Ar + i*Ai and B = Br + i*Bi:
//   (Ar + i*Ai)(Br + i*Bi) = (Ar*Br - Ai*Bi) + i*(Ar*Bi + Ai*Br)
// so the real part is Yr = Ar*Br - Ai*Bi and the imaginary part is
// Yi = Ar*Bi + Ai*Br.
var COMPLEX_MULTIPLY = {
    REAL: 'return areal * breal - aimag * bimag;',
    IMAG: 'return areal * bimag + aimag * breal;'
};
/**
 * Elementwise binary op over complex inputs given as four separate unpacked
 * tensors (AReal, AImag, BReal, BImag). `op` is the body of
 * `float binaryOpComplex(float areal, float aimag, float breal, float bimag)`,
 * and it produces one real-valued output (e.g. one of the COMPLEX_MULTIPLY
 * parts).
 */
var BinaryOpComplexProgram = /** @class */ (function () {
    function BinaryOpComplexProgram(op, aShape, bShape) {
        this.variableNames = ['AReal', 'AImag', 'BReal', 'BImag'];
        this.outputShape = tf.backend_util.assertAndGetBroadcastShape(aShape, bShape);
        this.userCode = "\n float binaryOpComplex(\n float areal, float aimag, float breal, float bimag) {\n ".concat(op, "\n }\n\n void main() {\n float areal = getARealAtOutCoords();\n float aimag = getAImagAtOutCoords();\n float breal = getBRealAtOutCoords();\n float bimag = getBImagAtOutCoords();\n setOutput(binaryOpComplex(areal, aimag, breal, bimag));\n }\n ");
    }
    return BinaryOpComplexProgram;
}());
var MUL = 'return a * b;';
/**
 * Multiply kernel. It takes one of three paths:
 *  - complex64 `a`: runs two BinaryOpComplexPrograms (real and imaginary
 *    parts) and recombines them with `complex`. NOTE(review): only `a.dtype`
 *    is checked; `b` is assumed to be complex64 as well.
 *  - the backend chooses CPU: uses multiplyImplCPU on the stored values;
 *  - otherwise: runs the packed or unpacked MUL program, depending on the
 *    WEBGL_PACK_BINARY_OPERATIONS flag.
 * The non-complex result dtype is upcastType(a.dtype, b.dtype).
 */
function multiply(args) {
    var inputs = args.inputs, backend = args.backend;
    var a = inputs.a, b = inputs.b;
    var dtype = tf.backend_util.upcastType(a.dtype, b.dtype);
    if (a.dtype === 'complex64') {
        var aData = backend.texData.get(a.dataId);
        var bData = backend.texData.get(b.dataId);
        var realProgram = new BinaryOpComplexProgram(COMPLEX_MULTIPLY.REAL, a.shape, b.shape);
        var imagProgram = new BinaryOpComplexProgram(COMPLEX_MULTIPLY.IMAG, a.shape, b.shape);
        // Handles for the four components, in the order expected by
        // BinaryOpComplexProgram.variableNames.
        var inputs_1 = [
            {
                dataId: aData.complexTensorInfos.real.dataId,
                dtype: aData.complexTensorInfos.real.dtype,
                shape: a.shape
            },
            {
                dataId: aData.complexTensorInfos.imag.dataId,
                dtype: aData.complexTensorInfos.imag.dtype,
                shape: a.shape
            },
            {
                dataId: bData.complexTensorInfos.real.dataId,
                dtype: bData.complexTensorInfos.real.dtype,
                shape: b.shape
            },
            {
                dataId: bData.complexTensorInfos.imag.dataId,
                dtype: bData.complexTensorInfos.imag.dtype,
                shape: b.shape
            }
        ];
        var realPart = backend.runWebGLProgram(realProgram, inputs_1, 'float32');
        var imagPart = backend.runWebGLProgram(imagProgram, inputs_1, 'float32');
        var complexOutput = complex({ inputs: { real: realPart, imag: imagPart }, backend: backend });
        backend.disposeIntermediateTensorInfo(realPart);
        backend.disposeIntermediateTensorInfo(imagPart);
        // TODO(annxingyuan): CPU forwarding for complex inputs.
        return complexOutput;
    }
    if (backend.shouldExecuteOnCPU([a, b])) {
        var aData = backend.texData.get(a.dataId);
        var bData = backend.texData.get(b.dataId);
        var _a = __read(multiplyImplCPU(a.shape, b.shape, aData.values, bData.values, dtype), 2), outValues = _a[0], outShape = _a[1];
        var out = backend.makeTensorInfo(outShape, dtype);
        var outData = backend.texData.get(out.dataId);
        outData.values = outValues;
        return out;
    }
    var program;
    if (tf.env().getBool('WEBGL_PACK_BINARY_OPERATIONS')) {
        program = new BinaryOpPackedProgram(MUL, a.shape, b.shape);
    }
    else {
        program = new BinaryOpProgram(MUL, a.shape, b.shape);
    }
    return backend.runWebGLProgram(program, [a, b], dtype);
}
var multiplyConfig = {
    kernelName: tf.Multiply,
    backendName: 'webgl',
    kernelFunc: multiply
};
/**
 * Reshapes a packed tensor by running ReshapePackedProgram on 3D views of the
 * input and target shapes. The views are built from getBatchDim and
 * getRowsCols, both defined elsewhere. `true` is passed as the
 * preventEagerUnpackingOfOutput argument of runWebGLProgram, presumably so
 * the output stays packed. The result reuses the program output's dataId but
 * reports `afterShape`.
 */
function packedReshape(input, afterShape, backend) {
    var input3DShape = __spreadArray([getBatchDim(input.shape)], __read(getRowsCols(input.shape)), false);
    var input3D = {
        dtype: input.dtype,
        shape: input3DShape,
        dataId: input.dataId
    };
    var afterShapeAs3D = __spreadArray([getBatchDim(afterShape)], __read(getRowsCols(afterShape)), false);
    var program = new ReshapePackedProgram(afterShapeAs3D, input3DShape);
    var preventEagerUnpackingOfOutput = true;
    var customValues = [input3DShape];
    var output = backend.runWebGLProgram(program, [input3D], input.dtype, customValues, preventEagerUnpackingOfOutput);
    return { dataId: output.dataId, shape: afterShape, dtype: output.dtype };
}
/**
 * @license
 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function reshape(args) { var inputs = args.inputs, backend = args.backend, attrs = args.attrs; var x = inputs.x; var shape = attrs.shape; var webglBackend = backend; var xSize = tf.util.sizeFromShape(x.shape); var $shape = tf.util.inferFromImplicitShape(shape, xSize); var $xSize = tf.util.sizeFromShape($shape); tf.util.assert(xSize === $xSize, function () { return "The new shape (".concat($shape, ") has ").concat($xSize, " elements and the old ") + "shape (".concat(x.shape, ") has ").concat(xSize, " elements. The new shape and old ") + "shape must have the same number of elements."; }); var xTexData = webglBackend.texData.get(x.dataId); if (xTexData.isPacked && !isReshapeFree(x.shape, $shape) && !(xTexData.texture !== null && isReshapeFree(xTexData.shape, $shape))) { return packedReshape(x, $shape, webglBackend); } webglBackend.incRef(x.dataId); return { dataId: x.dataId, shape: $shape, dtype: x.dtype }; } var reshapeConfig = { kernelName: tf.Reshape, backendName: 'webgl', kernelFunc: reshape }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* =============================================================================
 */
/**
 * One stage of a windowed mean/sum reduction over a 2D input [batch, inSize].
 *
 * Output shape is [batchSize, outSize]. Output element (b, o) sums inputs
 * o*windowSize .. o*windowSize + windowSize - 1 of row b. When `divisor` is
 * given, each value is first scaled by 1/divisor. Values are consumed four
 * at a time with dot(); a remainder of 1-3 values is zero-padded. When
 * inSize is not a multiple of windowSize, reads past inSize return 0.0.
 *
 * @param reduceInfo {windowSize, batchSize, inSize, outSize}.
 * @param divisor Optional. When null/undefined, a plain sum is computed.
 */
var MeanProgram = /** @class */ (function () {
    function MeanProgram(reduceInfo, divisor) {
        this.variableNames = ['x'];
        var windowSize = reduceInfo.windowSize, batchSize = reduceInfo.batchSize, inSize = reduceInfo.inSize, outSize = reduceInfo.outSize;
        this.outputShape = [batchSize, outSize];
        var windowSizeNearestVec4 = Math.floor(windowSize / 4) * 4;
        var windowSizeVec4Remainder = windowSize % 4;
        var updateSnippet = "sumValue += dot(values, ones);";
        if (divisor != null) {
            var denominator = 1 / divisor;
            // Integral values are printed with toPrecision(2) (e.g. 1 -> "1.0") so
            // the shader receives a float literal rather than an int literal.
            updateSnippet = "sumValue += dot(values * ".concat(tf.util.isInt(denominator) ? denominator.toPrecision(2) : denominator, ", ones);");
        }
        var checkOutOfBounds = '';
        if (inSize % windowSize > 0) {
            checkOutOfBounds = "\n if (inIdx < 0 || inIdx >= ".concat(inSize, ") {\n return 0.0;\n }\n ");
        }
        // The remainder branches test JS booleans spliced into the shader as
        // `true`/`false`, so at most one branch is live for a given program.
        this.userCode = "\n const vec4 ones = vec4(1.0, 1.0, 1.0, 1.0);\n\n float getValue(int batch, int inIdx) {\n ".concat(checkOutOfBounds, "\n return getX(batch, inIdx);\n }\n\n void main() {\n ivec2 coords = getOutputCoords();\n int batch = coords[0];\n int outIdx = coords[1];\n int inOffset = outIdx * ").concat(windowSize, ";\n\n float sumValue = 0.0;\n\n for (int i = 0; i < ").concat(windowSizeNearestVec4, "; i += 4) {\n int inIdx = inOffset + i;\n vec4 values = vec4(\n getValue(batch, inIdx),\n getValue(batch, inIdx + 1),\n getValue(batch, inIdx + 2),\n getValue(batch, inIdx + 3)\n );\n\n ").concat(updateSnippet, "\n }\n\n int inIdx = inOffset + ").concat(windowSizeNearestVec4, ";\n if (").concat(windowSizeVec4Remainder === 1, ") {\n vec4 values = vec4(getValue(batch, inIdx), 0.0, 0.0, 0.0);\n\n ").concat(updateSnippet, "\n } else if (").concat(windowSizeVec4Remainder === 2, ") {\n vec4 values = vec4(\n getValue(batch, inIdx),\n getValue(batch, inIdx + 1), 0.0, 0.0);\n\n ").concat(updateSnippet, "\n } else if (").concat(windowSizeVec4Remainder === 3, ") {\n vec4 values = vec4(\n getValue(batch, inIdx),\n getValue(batch, inIdx + 1),\n getValue(batch, inIdx + 2), 0.0);\n\n ").concat(updateSnippet, "\n }\n setOutput(sumValue);\n }\n ");
    }
    return MeanProgram;
}());
/**
 * @license
 * Copyright 2017 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
/**
 * One stage of a windowed reduction over a 2D input [batch, inSize]; the output
 * shape is [batchSize, outSize].
 *
 * @param reduceInfo {windowSize, batchSize, inSize, outSize}.
 * @param reduceType The visible branches handle 'sum', 'prod', 'min', 'max',
 *     'all' and 'any'. Window padding and out-of-range reads use the
 *     type-specific initializationValue (0 for sum/any, 1 for prod/all, and
 *     +/- 1/1e-20 for min/max).
 */
var ReduceProgram = /** @class */ (function () {
    function ReduceProgram(reduceInfo, reduceType) {
        this.variableNames = ['x'];
        var windowSize = reduceInfo.windowSize, batchSize = reduceInfo.batchSize, inSize = reduceInfo.inSize, outSize = reduceInfo.outSize;
        this.outputShape = [batchSize, outSize];
        var initializationValue = '0.0';
        var compareOp = "";
        if (reduceType === 'prod') {
            initializationValue = '1.0';
        }
        else if (reduceType === 'min') {
            // WebGL on Firefox Linux can't compile 1/0 so we do 1/eps.
            initializationValue = '1.0 / 1e-20';
            compareOp = "min";
        }
        else if (reduceType === 'max') {
            // WebGL on Firefox Linux can't compile 1/0 so we do 1/eps.
            initializationValue = '-1.0 / 1e-20';
            compareOp = "max";
        }
        // Default (min/max): fold the four lanes of minMaxValue with reduceType.
        var returnValue = "".concat(reduceType, "(").concat(reduceType, "(").concat(reduceType, "(") +
            'minMaxValue[0], minMaxValue[1]), minMaxValue[2]), minMaxValue[3])';
        if (reduceType === 'sum') {
            returnValue = "sumValue";
        }
        else if (reduceType === 'prod') {
            returnValue = "prodValue";
        }
        else if (reduceType === 'all') {
            returnValue = "allValue";
        }
        else if (reduceType === 'any') {
            returnValue = "anyValue";
        }
        var windowSizeNearestVec4 = Math.floor(windowSize / 4) * 4;
        var windowSizeVec4Remainder = windowSize % 4;
        // NOTE(review): the min/max path applies compareOp twice; the second call
        // is redundant. For sum/prod, compareOp is "" and the dead else-branch
        // text becomes `minMaxValue = (values, minMaxValue);`, which must still
        // compile. Changing any of this text changes the generated shader (and
        // its cache key), so it is left as is.
        var updateSnippet = "\n if (".concat(reduceType === 'sum', ") {\n sumValue += dot(values, ones);\n } else if (").concat(reduceType === 'prod', ") {\n vec2 tmp = vec2(values[0], values[1]) * vec2(values[2], values[3]);\n prodValue *= tmp[0] * tmp[1];\n } else {\n minMaxValue = ").concat(compareOp, "(values, minMaxValue);\n if (").concat(reduceType === 'min', " || ").concat(reduceType === 'max', ") {\n minMaxValue = ").concat(compareOp, "(values, minMaxValue);\n bvec4 isNaN = isnan(values);\n if (isNaN.r || isNaN.g || isNaN.b || isNaN.a) {\n minMaxValue = vec4(NAN);\n }\n }\n }\n ");
        var vecType = "vec4";
        // all/any operate on bvec4 values and keep their running result as a
        // 0/1 float.
        if (reduceType === 'all') {
            initializationValue = '1.0';
            updateSnippet = "\n bool reducedAllValue = all(values);\n float floatedReducedAllValue = float(reducedAllValue);\n allValue = float(allValue >= 1.0 && floatedReducedAllValue >= 1.0);\n ";
            vecType = "bvec4";
        }
        else if (reduceType === 'any') {
            initializationValue = '0.0';
            updateSnippet = "\n bool reducedAnyValue = any(values);\n float floatedReducedAnyValue = float(reducedAnyValue);\n anyValue = float(anyValue >= 1.0 || floatedReducedAnyValue >= 1.0);\n ";
            vecType = "bvec4";
        }
        var checkOutOfBounds = '';
        if (inSize % windowSize > 0) {
            checkOutOfBounds = "\n if (inIdx < 0 || inIdx >= ".concat(inSize, ") {\n return initializationValue;\n }\n ");
        }
        this.userCode = "\n const float initializationValue = ".concat(initializationValue, ";\n const vec4 ones = vec4(1.0, 1.0, 1.0, 1.0);\n\n float getValue(int batch, int inIdx) {\n ").concat(checkOutOfBounds, "\n return getX(batch, inIdx);\n }\n\n void main() {\n ivec2 coords = getOutputCoords();\n int batch = coords[0];\n int outIdx = coords[1];\n int inOffset = outIdx * ").concat(windowSize, ";\n\n vec4 minMaxValue = vec4(").concat(initializationValue, ");\n float prodValue = 1.0;\n float sumValue = 0.0;\n float allValue = 1.0;\n float anyValue = 0.0;\n\n for (int i = 0; i < ").concat(windowSizeNearestVec4, "; i += 4) {\n int inIdx = inOffset + i;\n ").concat(vecType, " values = ").concat(vecType, "(\n getValue(batch, inIdx),\n getValue(batch, inIdx + 1),\n getValue(batch, inIdx + 2),\n getValue(batch, inIdx + 3)\n );\n\n ").concat(updateSnippet, "\n }\n\n int inIdx = inOffset + ").concat(windowSizeNearestVec4, ";\n if (").concat(windowSizeVec4Remainder === 1, ") {\n ").concat(vecType, " values = ").concat(vecType, "(\n getValue(batch, inIdx),\n initializationValue,\n initializationValue,\n initializationValue\n );\n\n ").concat(updateSnippet, "\n } else if (").concat(windowSizeVec4Remainder === 2, ") {\n ").concat(vecType, " values = ").concat(vecType, "(\n getValue(batch, inIdx),\n getValue(batch, inIdx + 1),\n initializationValue,\n initializationValue\n );\n\n ").concat(updateSnippet, "\n } else if (").concat(windowSizeVec4Remainder === 3, ") {\n ").concat(vecType, " values = ").concat(vecType, "(\n getValue(batch, inIdx),\n getValue(batch, inIdx + 1),\n getValue(batch, inIdx + 2),\n initializationValue\n );\n\n ").concat(updateSnippet, "\n }\n setOutput(").concat(returnValue, ");\n }\n ");
    }
    return ReduceProgram;
}());
/**
 * @license
 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ // Returns an array of configuration objects that describe each stage of the // reduction. function getReductionStages(inShape) { var stages = []; while (stages.length === 0 || stages[stages.length - 1].outSize !== 1) { var outSize = stages.length ? stages[stages.length - 1].outSize : inShape[1]; var windowSize = tf.backend_util.computeOptimalWindowSize(outSize); stages.push({ inSize: outSize, windowSize: windowSize, outSize: Math.ceil(outSize / windowSize) }); } return stages; } function reduce(x, dtype, reductionType, backend) { var reductionStages = getReductionStages(x.shape); var result = x; for (var i = 0; i < reductionStages.length; i++) { var _a = reductionStages[i], inSize = _a.inSize, windowSize = _a.windowSize, outSize = _a.outSize; var program = void 0; var previousResult = void 0; if (reductionType === 'mean') { program = i === 0 ? new MeanProgram({ windowSize: windowSize, inSize: inSize, batchSize: x.shape[0], outSize: outSize }, inSize) : new MeanProgram({ windowSize: windowSize, inSize: inSize, batchSize: x.shape[0], outSize: outSize }); } else { program = new ReduceProgram({ windowSize: windowSize, inSize: inSize, batchSize: x.shape[0], outSize: outSize }, reductionType); } previousResult = result; result = backend.runWebGLProgram(program, [result], dtype); if (previousResult.dataId !== x.dataId) { backend.disposeIntermediateTensorInfo(previousResult); } } return result; } /** * @license * Copyright 2017 Google LLC. All Rights Reserved. 
* Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ var TransposeProgram = /** @class */ (function () { function TransposeProgram(aShape, newDim) { this.variableNames = ['A']; var outputShape = new Array(aShape.length); for (var i = 0; i < outputShape.length; i++) { outputShape[i] = aShape[newDim[i]]; } this.outputShape = outputShape; this.rank = outputShape.length; var dtype = getCoordsDataType(this.rank); var switched = getSwitchedCoords(newDim); this.userCode = "\n void main() {\n ".concat(dtype, " resRC = getOutputCoords();\n setOutput(getA(").concat(switched, "));\n }\n "); } return TransposeProgram; }()); function getSwitchedCoords(newDim) { var rank = newDim.length; if (rank > 6) { throw Error("Transpose for rank ".concat(rank, " is not yet supported")); } var originalOrder = ['resRC.x', 'resRC.y', 'resRC.z', 'resRC.w', 'resRC.u', 'resRC.v']; var switchedCoords = new Array(rank); for (var i = 0; i < newDim.length; i++) { switchedCoords[newDim[i]] = originalOrder[i]; } return switchedCoords.join(); } /** * @license * Copyright 2019 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
var TransposePackedProgram = /** @class */ (function () {
    /**
     * Packed-texture WebGL transpose. Each output texel is a vec4 holding the
     * output positions (r, c), (r, c+1), (r+1, c), (r+1, c+1), where r and c
     * are the last two output coordinates; out-of-range positions stay 0.
     *
     * @param {number[]} aShape Shape of the input tensor A.
     * @param {number[]} newDim Axis permutation; output axis i takes input
     *     axis newDim[i].
     * @throws {Error} If the output rank is greater than 6.
     */
    function TransposePackedProgram(aShape, newDim) {
        this.variableNames = ['A'];
        this.packedInputs = true;
        this.packedOutput = true;
        var outputShape = new Array(aShape.length);
        for (var i = 0; i < outputShape.length; i++) {
            outputShape[i] = aShape[newDim[i]];
        }
        this.outputShape = outputShape;
        this.rank = outputShape.length;
        if (this.rank > 6) {
            throw Error("Packed transpose for rank ".concat(this.rank, " is not yet supported."));
        }
        var dtype = getCoordsDataType(this.rank);
        // GLSL names of the output coordinate components of `rc`, as produced
        // by getVecChannels (defined elsewhere in this file).
        var outputOrder = getVecChannels('rc', this.rank);
        // Coordinates used to read A: input axis newDim[i] is indexed by output
        // component i.
        var switchedOrder = new Array(this.rank);
        for (var i = 0; i < newDim.length; i++) {
            switchedOrder[newDim[i]] = outputOrder[i];
        }
        // The last two input coordinates; passed to getChannel, presumably to
        // pick the lane within the packed texel of A (confirm in getChannel).
        var innerDims = "vec2(".concat(switchedOrder.slice(-2).join(), ")");
        // Pre-increments the innermost output coordinate and checks it is
        // still in range. NOTE: this mutates `rc` inside the shader.
        var nextColumn = "++".concat(outputOrder[this.rank - 1], " < ").concat(outputShape[this.rank - 1]);
        var getc = "getChannel(getA(".concat(switchedOrder.join(), "), ").concat(innerDims, ")");
        // Shader order matters: result[0] at rc; ++col -> result[1]; --col and
        // ++row -> result[2]; ++col again -> result[3].
        this.userCode = "\n void main() {\n ".concat(dtype, " rc = getOutputCoords();\n vec4 result = vec4(0.);\n result[0] = ").concat(getc, ";\n if(").concat(nextColumn, ") {\n result[1] = ").concat(getc, ";\n }\n --").concat(outputOrder[this.rank - 1], ";\n if(++").concat(outputOrder[this.rank - 2], " < ").concat(outputShape[this.rank - 2], ") {\n result[2] = ").concat(getc, ";\n if(").concat(nextColumn, ") {\n result[3] = ").concat(getc, ";\n }\n }\n setOutput(result);\n }\n ");
    }
    return TransposePackedProgram;
}());
/**
 * @license
 *
Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function transposeImpl(x, perm, backend) { var program = tf.env().getBool('WEBGL_PACK_ARRAY_OPERATIONS') ? new TransposePackedProgram(x.shape, perm) : new TransposeProgram(x.shape, perm); return backend.runWebGLProgram(program, [x], x.dtype); } function sumImpl(x, axis, keepDims, backend) { var reductionIndices = axis; var xRank = x.shape.length; var origAxes = tf.util.parseAxisParam(reductionIndices, x.shape); var axes = origAxes; var permutedAxes = tf.backend_util.getAxesPermutation(axes, xRank); var sumInputIsTransposed = permutedAxes != null; var sumInput = x; if (sumInputIsTransposed) { sumInput = transposeImpl(x, permutedAxes, backend); axes = tf.backend_util.getInnerMostAxes(axes.length, xRank); } tf.backend_util.assertAxesAreInnerMostDims('sum', axes, xRank); var _a = __read(tf.backend_util.computeOutAndReduceShapes(sumInput.shape, axes), 2), sumOutShape = _a[0], reduceShape = _a[1]; var outShape = sumOutShape; if (keepDims) { // rather than reshape at the end, set the target shape here. 
outShape = tf.backend_util.expandShapeToKeepDim(sumOutShape, origAxes); } var inSize = tf.util.sizeFromShape(reduceShape); var xSize = tf.util.sizeFromShape(x.shape); var batchSize = xSize / inSize; var reshapedInput = reshape({ inputs: { x: sumInput }, attrs: { shape: [batchSize, inSize] }, backend: backend }); var outType = tf.sumOutType(x.dtype); var reduced = reduce(reshapedInput, outType, 'sum', backend); var out = reshape({ inputs: { x: reduced }, attrs: { shape: outShape }, backend: backend }); backend.disposeIntermediateTensorInfo(reshapedInput); backend.disposeIntermediateTensorInfo(reduced); if (sumInputIsTransposed) { backend.disposeIntermediateTensorInfo(sumInput); } return out; } /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function sum(args) { var inputs = args.inputs, backend = args.backend, attrs = args.attrs; var x = inputs.x; var axis = attrs.axis, keepDims = attrs.keepDims; return sumImpl(x, axis, keepDims, backend); } var sumConfig = { kernelName: tf.Sum, backendName: 'webgl', kernelFunc: sum }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function transpose(args) { var inputs = args.inputs, backend = args.backend, attrs = args.attrs; var x = inputs.x; var perm = attrs.perm; var webglBackend = backend; var xRank = x.shape.length; var newShape = new Array(xRank); for (var i = 0; i < newShape.length; i++) { newShape[i] = x.shape[perm[i]]; } var out; if (webglBackend.shouldExecuteOnCPU([x])) { var xTexData = webglBackend.texData.get(x.dataId); var values = xTexData.values; var outValues = transposeImplCPU(values, x.shape, x.dtype, perm, newShape); out = webglBackend.makeTensorInfo(newShape, x.dtype); var outData = webglBackend.texData.get(out.dataId); outData.values = outValues; } else { out = transposeImpl(x, perm, webglBackend); } return out; } var transposeConfig = { kernelName: tf.Transpose, backendName: 'webgl', kernelFunc: transpose }; // Empirically determined minimal shared dimension in matmul before we forward // to a.mul(b).sum() in order to take advantage of GPU parallelism. See // https://github.com/tensorflow/tfjs-core/pull/1379 for benchmarks. var MATMUL_SHARED_DIM_THRESHOLD = 1000; function batchMatMulImpl(_a) { var e_1, _b; var a = _a.a, b = _a.b, transposeA = _a.transposeA, transposeB = _a.transposeB, backend = _a.backend, _c = _a.bias, bias = _c === void 0 ? null : _c, _d = _a.preluActivationWeights, preluActivationWeights = _d === void 0 ? null : _d, _e = _a.leakyreluAlpha, leakyreluAlpha = _e === void 0 ? 0 : _e, _f = _a.activation, activation = _f === void 0 ? 
null : _f; var aRank = a.shape.length; var bRank = b.shape.length; var innerShapeA = transposeA ? a.shape[aRank - 2] : a.shape[aRank - 1]; var innerShapeB = transposeB ? b.shape[bRank - 1] : b.shape[bRank - 2]; var outerShapeA = transposeA ? a.shape[aRank - 1] : a.shape[aRank - 2]; var outerShapeB = transposeB ? b.shape[bRank - 2] : b.shape[bRank - 1]; var outerDimsA = a.shape.slice(0, -2); var outerDimsB = b.shape.slice(0, -2); var batchDimA = tf.util.sizeFromShape(outerDimsA); var batchDimB = tf.util.sizeFromShape(outerDimsB); var outShapeOuterDims = tf.broadcast_util.assertAndGetBroadcastShape(a.shape.slice(0, -2), b.shape.slice(0, -2)); var outShape = outShapeOuterDims.concat([outerShapeA, outerShapeB]); tf.util.assert(innerShapeA === innerShapeB, function () { return "Error in matMul: inner shapes (".concat(innerShapeA, ") and (") + "".concat(innerShapeB, ") of Tensors with shapes ").concat(a.shape, " and ") + "".concat(b.shape, " and transposeA=").concat(transposeA) + " and transposeB=".concat(transposeB, " must match."); }); var a3dShape = transposeA ? [batchDimA, innerShapeA, outerShapeA] : [batchDimA, outerShapeA, innerShapeA]; var b3dShape = transposeB ? [batchDimB, outerShapeB, innerShapeB] : [batchDimB, innerShapeB, outerShapeB]; // The rest of the implementation is designed to operate on rank-3 tensors var a3d = reshape({ inputs: { x: a }, backend: backend, attrs: { shape: a3dShape } }); var b3d = reshape({ inputs: { x: b }, backend: backend, attrs: { shape: b3dShape } }); var intermediates = [a3d, b3d]; var batchDim = Math.max(batchDimA, batchDimB); var sharedDim = transposeA ? a3d.shape[1] : a3d.shape[2]; var hasBias = bias != null; var hasPreluActivationWeights = preluActivationWeights != null; var hasLeakyreluAlpha = activation === 'leakyrelu'; var fusedActivation = activation != null ? 
mapActivationToShaderProgram(activation, true) : null; var containsFusedOps = hasBias || hasPreluActivationWeights || hasLeakyreluAlpha || fusedActivation != null; var out; // Since the matrices are vectors, it is faster to call mul().sum() // because sum() is O(sqrt(N)) due to divide-and-conquer. if ((outerShapeA === 1 || outerShapeB === 1) && sharedDim > MATMUL_SHARED_DIM_THRESHOLD && containsFusedOps === false) { var aVec = a3d; var bVec = b3d; if (transposeA) { aVec = transpose({ inputs: { x: a3d }, backend: backend, attrs: { perm: [0, 2, 1] } }); intermediates.push(aVec); } if (transposeB) { bVec = transpose({ inputs: { x: b3d }, backend: backend, attrs: { perm: [0, 2, 1] } }); intermediates.push(bVec); } var shouldReshapeA = outerShapeB !== 1; var shouldReshapeB = outerShapeB === 1; var aVec3d = aVec; if (shouldReshapeA) { aVec3d = reshape({ inputs: { x: aVec }, backend: backend, attrs: { shape: [batchDim, sharedDim, 1] } }); intermediates.push(aVec3d); } var axis = outerShapeB === 1 ? 
2 : 1; var bVec3d = bVec; if (shouldReshapeB) { bVec3d = reshape({ inputs: { x: bVec }, backend: backend, attrs: { shape: [batchDim, 1, sharedDim] } }); intermediates.push(bVec3d); } var product = multiply({ inputs: { a: aVec3d, b: bVec3d }, backend: backend }); out = sum({ inputs: { x: product }, backend: backend, attrs: { axis: axis, keepDims: true } }); intermediates.push(product); } else { var dtype = tf.upcastType(a.dtype, b.dtype); var program = new MatMulPackedProgram(a3dShape, b3dShape, [batchDim, outerShapeA, outerShapeB], transposeA, transposeB, hasBias, fusedActivation, hasPreluActivationWeights, hasLeakyreluAlpha); var inputs = [a3d, b3d]; if (bias != null) { inputs.push(bias); } if (hasPreluActivationWeights) { inputs.push(preluActivationWeights); } if (hasLeakyreluAlpha) { var $leakyreluAlpha = backend.makeTensorInfo([], 'float32', tf.util.createScalarValue(leakyreluAlpha, 'float32')); inputs.push($leakyreluAlpha); intermediates.push($leakyreluAlpha); } out = backend.runWebGLProgram(program, inputs, dtype); } var outReshaped = reshape({ inputs: { x: out }, backend: backend, attrs: { shape: outShape } }); intermediates.push(out); try { for (var intermediates_1 = __values(intermediates), intermediates_1_1 = intermediates_1.next(); !intermediates_1_1.done; intermediates_1_1 = intermediates_1.next()) { var i = intermediates_1_1.value; backend.disposeIntermediateTensorInfo(i); } } catch (e_1_1) { e_1 = { error: e_1_1 }; } finally { try { if (intermediates_1_1 && !intermediates_1_1.done && (_b = intermediates_1.return)) _b.call(intermediates_1); } finally { if (e_1) throw e_1.error; } } return outReshaped; } /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the License); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an AS IS BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function _fusedMatMul(args) { var inputs = args.inputs, backend = args.backend, attrs = args.attrs; var a = inputs.a, b = inputs.b, bias = inputs.bias, preluActivationWeights = inputs.preluActivationWeights; var transposeA = attrs.transposeA, transposeB = attrs.transposeB, activation = attrs.activation, leakyreluAlpha = attrs.leakyreluAlpha; return batchMatMulImpl({ a: a, b: b, transposeA: transposeA, transposeB: transposeB, backend: backend, bias: bias, preluActivationWeights: preluActivationWeights, leakyreluAlpha: leakyreluAlpha, activation: activation }); } var _fusedMatMulConfig = { kernelName: tf._FusedMatMul, backendName: 'webgl', kernelFunc: _fusedMatMul, }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ var ABS = "return abs(x);"; function abs(args) { var inputs = args.inputs, backend = args.backend; var x = inputs.x; // TODO: handle cases when x is complex. Once the cpu implementation // can handle complex values, refactor to use unaryKernelFunc. if (backend.shouldExecuteOnCPU([x]) && x.dtype !== 'complex64') { var xData = backend.texData.get(x.dataId); var outValues = simpleAbsImplCPU(xData.values); return backend.makeTensorInfo(x.shape, x.dtype, outValues); } var program; if (tf.env().getBool('WEBGL_PACK_UNARY_OPERATIONS')) { program = new UnaryOpPackedProgram(x.shape, ABS); } else { program = new UnaryOpProgram(x.shape, ABS); } return backend.runWebGLProgram(program, [x], x.dtype); } var absConfig = { kernelName: tf.Abs, backendName: 'webgl', kernelFunc: abs }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ var ACOS = CHECK_NAN_SNIPPET$1 + "\n if (abs(x) > 1.) {\n return NAN;\n }\n return acos(x);\n"; var acos = unaryKernelFunc({ opSnippet: ACOS }); var acosConfig = { kernelName: tf.Acos, backendName: 'webgl', kernelFunc: acos, }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ var ACOSH = CHECK_NAN_SNIPPET$1 + "\n if (x < 1.0) return NAN;\nreturn log(x + sqrt(x * x - 1.0));"; var acosh = unaryKernelFunc({ opSnippet: ACOSH }); var acoshConfig = { kernelName: tf.Acosh, backendName: 'webgl', kernelFunc: acosh, }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ var ADD = 'return a + b;'; var addKernelFunc = binaryKernelFunc({ opSnippet: ADD, packedOpSnippet: ADD, supportsComplex: true, cpuKernelImpl: addImplCPU }); var addConfig = { kernelName: tf.Add, backendName: 'webgl', kernelFunc: addKernelFunc }; /** * @license * Copyright 2019 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ var AddNProgram = /** @class */ (function () { function AddNProgram(outputShape, shapes) { this.outputShape = []; this.outputShape = outputShape; this.variableNames = shapes.map(function (_, i) { return "T".concat(i); }); var snippets = []; // Get target elements from every input tensor. this.variableNames.forEach(function (variable) { snippets.push("float v".concat(variable, " = get").concat(variable, "AtOutCoords();")); }); // Calculate the sum of all elements. var operation = this.variableNames .map(function (variable) { return "v".concat(variable); }) .join(' + '); this.userCode = "\n void main() {\n ".concat(snippets.join('\n '), "\n\n float result = ").concat(operation, ";\n setOutput(result);\n }\n "); } return AddNProgram; }()); /** * @license * Copyright 2019 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ var AddNPackedProgram = /** @class */ (function () { function AddNPackedProgram(outputShape, shapes) { this.outputShape = []; this.packedInputs = true; this.packedOutput = true; this.outputShape = outputShape; this.variableNames = shapes.map(function (_, i) { return "T".concat(i); }); var snippets = []; // Get target elements from every input tensor. this.variableNames.forEach(function (variable) { snippets.push("vec4 v".concat(variable, " = get").concat(variable, "AtOutCoords();")); }); // Calculate the sum of all elements. var operation = this.variableNames .map(function (variable) { return "v".concat(variable); }) .join(' + '); this.userCode = "\n void main() {\n ".concat(snippets.join('\n '), "\n\n vec4 result = ").concat(operation, ";\n setOutput(result);\n }\n "); } return AddNPackedProgram; }()); /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function addN(args) { var inputs = args.inputs, backend = args.backend; var tensors = inputs; if (tensors.length === 1) { return identity({ inputs: { x: tensors[0] }, backend: backend }); } // Limit the number of uploaded textures for optimization. 
if (tensors.length > tf.env().getNumber('WEBGL_MAX_TEXTURES_IN_SHADER')) { var midIndex = Math.floor(tensors.length / 2); var leftSide = addN({ inputs: tensors.slice(0, midIndex), backend: backend }); var rightSide = addN({ inputs: tensors.slice(midIndex), backend: backend }); return addN({ inputs: [leftSide, rightSide], backend: backend }); } var dtype = tensors.map(function (t) { return t.dtype; }).reduce(function (d1, d2) { return tf.upcastType(d1, d2); }); var shapes = tensors.map(function (t) { return t.shape; }); // We can make sure shapes are identical in op level. var usePackedOp = tf.env().getBool('WEBGL_PACK'); var program = usePackedOp ? new AddNPackedProgram(tensors[0].shape, shapes) : new AddNProgram(tensors[0].shape, shapes); return backend.runWebGLProgram(program, tensors, dtype); } var addNConfig = { kernelName: tf.AddN, backendName: 'webgl', kernelFunc: addN }; function all(args) { var inputs = args.inputs, backend = args.backend, attrs = args.attrs; var x = inputs.x; var axis = attrs.axis, keepDims = attrs.keepDims; var xRank = x.shape.length; var origAxes = tf.util.parseAxisParam(axis, x.shape); var axes = origAxes; var permutedAxes = tf.backend_util.getAxesPermutation(axes, xRank); var permutedX = x; if (permutedAxes != null) { permutedX = transpose({ inputs: { x: x }, backend: backend, attrs: { perm: permutedAxes } }); axes = tf.backend_util.getInnerMostAxes(axes.length, xRank); } tf.backend_util.assertAxesAreInnerMostDims('all', axes, xRank); var _a = __read(tf.backend_util.computeOutAndReduceShapes(permutedX.shape, axes), 2), outShape = _a[0], reduceShape = _a[1]; var inSize = tf.util.sizeFromShape(reduceShape); var a2D = reshape({ inputs: { x: permutedX }, backend: backend, attrs: { shape: [-1, inSize] } }); var reduced = reduce(a2D, a2D.dtype, 'all', backend); var res; if (keepDims) { var newShape = tf.backend_util.expandShapeToKeepDim(outShape, origAxes); res = reshape({ inputs: { x: reduced }, backend: backend, attrs: { shape: newShape } 
}); } else { res = reshape({ inputs: { x: reduced }, backend: backend, attrs: { shape: outShape } }); } backend.disposeIntermediateTensorInfo(a2D); backend.disposeIntermediateTensorInfo(reduced); if (permutedAxes != null) { backend.disposeIntermediateTensorInfo(permutedX); } return res; } var allConfig = { kernelName: tf.All, backendName: 'webgl', kernelFunc: all }; function any(args) { var inputs = args.inputs, backend = args.backend, attrs = args.attrs; var x = inputs.x; var axis = attrs.axis, keepDims = attrs.keepDims; var xRank = x.shape.length; var origAxes = tf.util.parseAxisParam(axis, x.shape); var axes = origAxes; var permutedAxes = tf.backend_util.getAxesPermutation(axes, xRank); var permutedX = x; if (permutedAxes != null) { permutedX = transpose({ inputs: { x: x }, backend: backend, attrs: { perm: permutedAxes } }); axes = tf.backend_util.getInnerMostAxes(axes.length, xRank); } tf.backend_util.assertAxesAreInnerMostDims('any', axes, xRank); var _a = __read(tf.backend_util.computeOutAndReduceShapes(permutedX.shape, axes), 2), outShape = _a[0], reduceShape = _a[1]; var inSize = tf.util.sizeFromShape(reduceShape); var a2D = reshape({ inputs: { x: permutedX }, backend: backend, attrs: { shape: [-1, inSize] } }); var reduced = reduce(a2D, a2D.dtype, 'any', backend); var res; if (keepDims) { var newShape = tf.backend_util.expandShapeToKeepDim(outShape, origAxes); res = reshape({ inputs: { x: reduced }, backend: backend, attrs: { shape: newShape } }); } else { res = reshape({ inputs: { x: reduced }, backend: backend, attrs: { shape: outShape } }); } backend.disposeIntermediateTensorInfo(a2D); backend.disposeIntermediateTensorInfo(reduced); if (permutedAxes != null) { backend.disposeIntermediateTensorInfo(permutedX); } return res; } var anyConfig = { kernelName: tf.Any, backendName: 'webgl', kernelFunc: any }; /** * @license * Copyright 2017 Google LLC. All Rights Reserved. 
* Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ var ArgMinMaxProgram = /** @class */ (function () { function ArgMinMaxProgram(reduceInfo, op, firstPass) { this.variableNames = ['A']; var windowSize = reduceInfo.windowSize, batchSize = reduceInfo.batchSize, outSize = reduceInfo.outSize; if (!firstPass) { this.variableNames.push('bestIndicesA'); } this.outputShape = [batchSize, outSize]; var compOp = (op === 'max') ? '>' : '<'; var indexSnippet = firstPass ? 'inOffset + i;' : 'round(getBestIndicesA(batch, inOffset + i));'; this.userCode = "\n void main() {\n ivec2 coords = getOutputCoords();\n int batch = coords[0];\n int outIdx = coords[1];\n int inOffset = outIdx * ".concat(windowSize, ";\n\n int bestIndex = inOffset;\n float bestValue = getA(batch, bestIndex);\n\n for (int i = 0; i < ").concat(windowSize, "; i++) {\n int inIdx = ").concat(indexSnippet, ";\n float candidate = getA(batch, inIdx);\n if (candidate ").concat(compOp, " bestValue) {\n bestValue = candidate;\n bestIndex = inIdx;\n }\n }\n setOutput(float(bestIndex));\n }\n "); } return ArgMinMaxProgram; }()); /** * @license * Copyright 2019 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */
var ArgMinMaxPackedProgram = /** @class */ (function () {
    /**
     * Packed GPGPU program for one pass of an arg-min / arg-max reduction
     * along the last axis. Inputs and output are packed, and each invocation
     * writes a vec4 covering four output elements (R, G, B, A) laid out over
     * the last two output coordinates: R=(row, col), G=(row, col+1),
     * A=(row+1, col+1), B=(row+1, col).
     *
     * @param shape Shape being reduced this pass (`x` on the first pass, the
     *     previous pass's `bestIndicesA` afterwards). Must have rank > 2.
     * @param windowSize Number of consecutive last-axis elements scanned per
     *     output element.
     * @param op 'max' compares with greaterThan; anything else with lessThan.
     * @param firstPass When false, an extra `bestIndicesA` input is declared
     *     and candidate indices are read from it instead of being taken as
     *     positions in `A`.
     */
    function ArgMinMaxPackedProgram(shape, windowSize, op, firstPass) {
        this.variableNames = ['A'];
        this.packedInputs = true;
        this.packedOutput = true;
        tf.util.assert(shape.length > 2, function () { return "Packed arg".concat(op.charAt(0).toUpperCase() + op.slice(1), " supports only inputs with rank above 2."); });
        var inSize = shape[shape.length - 1];
        var outSize = Math.ceil(inSize / windowSize);
        // The reduced axis is dropped; if more than one window remains, a
        // trailing axis of size outSize is kept for the next pass.
        this.outputShape = shape.slice(0, -1);
        if (outSize > 1) {
            this.outputShape.push(outSize);
        }
        if (!firstPass) {
            this.variableNames.push('bestIndicesA');
        }
        var outShape = this.outputShape;
        var rank = outShape.length;
        var dtype = getCoordsDataType(rank);
        var coords = getChannels('coords', rank);
        var sourceLocSetup;
        var sourceRank;
        if (outSize === 1) {
            // Output lost the reduced axis: source coordinates get one extra
            // trailing component, initialised to 0.
            sourceRank = rank + 1;
            var sourceLocDType = getCoordsDataType(sourceRank);
            sourceLocSetup = "\n ".concat(sourceLocDType, " sourceLocR = ").concat(sourceLocDType, "(").concat(coords.join(), ", 0);\n ++").concat(coords[rank - 1], ";\n ").concat(sourceLocDType, " sourceLocG = ").concat(sourceLocDType, "(").concat(coords.join(), ", 0);\n ++").concat(coords[rank - 2], ";\n ").concat(sourceLocDType, " sourceLocA = ").concat(sourceLocDType, "(").concat(coords.join(), ", 0);\n --").concat(coords[rank - 1], ";\n ").concat(sourceLocDType, " sourceLocB = ").concat(sourceLocDType, "(").concat(coords.join(), ", 0);\n --").concat(coords[rank - 2], ";");
        }
        else {
            // Output keeps a window axis, so source and output ranks match.
            sourceRank = rank;
            sourceLocSetup = "\n ".concat(dtype, " sourceLocR = coords;\n ++").concat(coords[rank - 1], ";\n ").concat(dtype, " sourceLocG = coords;\n ++").concat(coords[rank - 2], ";\n ").concat(dtype, " sourceLocA = coords;\n --").concat(coords[rank - 1], ";\n ").concat(dtype, " sourceLocB = coords;\n --").concat(coords[rank - 2], ";");
        }
        var channels = ['x', 'y', 'z', 'w', 'u', 'v'].slice(0, sourceRank);
        var inChannel = '.' + channels[sourceRank - 1]; // e.g. ".z" for rank 3.
        var intChannels = channels.map(function (x) { return 'int ' + x; });
        var srcRCoords = getChannels('sourceLocR', sourceRank - 1).concat('inIdx.r');
        var srcGCoords = getChannels('sourceLocG', sourceRank - 1).concat('inIdx.g');
        var srcBCoords = getChannels('sourceLocB', sourceRank - 1).concat('inIdx.b');
        var srcACoords = getChannels('sourceLocA', sourceRank - 1).concat('inIdx.a');
        var compOp = (op === 'max') ? 'greaterThan' : 'lessThan';
        // Later passes replace the raw positions with the previous winners.
        var fetchCandidateIdx = firstPass ? '' : "\n inIdx = round(vec4(getBestIndicesAChannel(".concat(srcRCoords.join(), "),\n getBestIndicesAChannel(").concat(srcGCoords.join(), "),\n getBestIndicesAChannel(").concat(srcBCoords.join(), "),\n getBestIndicesAChannel(").concat(srcACoords.join(), ")));");
        // Lanes past the last output column/row read 0. instead of sampling.
        var fetchValue = "vec4(\n getAChannel(".concat(srcRCoords.join(), "),\n hasNextCol ? getAChannel(").concat(srcGCoords.join(), ") : 0.,\n hasNextRow ? getAChannel(").concat(srcBCoords.join(), ") : 0.,\n hasNextRow && hasNextCol ? getAChannel(").concat(srcACoords.join(), ") : 0.)");
        var getBestIndicesAChannelSnippet = firstPass ? '' : "\n float getBestIndicesAChannel(".concat(intChannels.join(), ") {\n return getChannel(getBestIndicesA(").concat(channels.join(), "),\n vec2(").concat(channels.slice(-2).join(), "));\n }");
        // The shader masks `replace` with (1 - isnan(candidate)), so NaN
        // candidates never displace the current best; comparisons are strict.
        this.userCode = "\n float getAChannel(".concat(intChannels.join(), ") {\n return getChannel(getA(").concat(channels.join(), "),\n vec2(").concat(channels.slice(-2).join(), "));\n }\n ").concat(getBestIndicesAChannelSnippet, "\n void main() {\n ").concat(dtype, " coords = getOutputCoords();\n bool hasNextCol = ").concat(coords[rank - 1], " < ").concat(outShape[rank - 1] - 1, ";\n bool hasNextRow = ").concat(coords[rank - 2], " < ").concat(outShape[rank - 2] - 1, ";\n ").concat(sourceLocSetup, "\n ivec4 srcIdx = ivec4(sourceLocR").concat(inChannel, ", sourceLocG").concat(inChannel, ",\n sourceLocB").concat(inChannel, ", sourceLocA").concat(inChannel, ") * ").concat(windowSize, ";\n ivec4 inIdx = srcIdx;\n vec4 bestIndex = vec4(inIdx);\n vec4 bestValue = ").concat(fetchValue, ";\n\n for (int i = 0; i < ").concat(windowSize, "; i++) {\n inIdx = srcIdx;\n ").concat(fetchCandidateIdx, "\n vec4 candidate = ").concat(fetchValue, ";\n bvec4 nan = isnan(candidate);\n bvec4 replace = bvec4(\n vec4(").concat(compOp, "(candidate, bestValue)) * (vec4(1.0) - vec4(nan)));\n\n bestValue = vec4(replace.x ? candidate.x : bestValue.x,\n replace.y ? candidate.y : bestValue.y,\n replace.z ? candidate.z : bestValue.z,\n replace.w ? candidate.w : bestValue.w);\n bestIndex = mix(bestIndex, vec4(inIdx), vec4(replace));\n srcIdx++;\n }\n setOutput(bestIndex);\n }\n ");
    }
    return ArgMinMaxPackedProgram;
}());
/**
 * Unpacked arg-min/max over the second dimension of a 2D tensor, done as a
 * chain of windowed passes. Each pass runs ArgMinMaxProgram. While the pass
 * output still has more than one column, this recurses with that output as
 * `bestIndicesA` and then disposes it.
 *
 * @param backend WebGL backend used to run programs and dispose intermediates.
 * @param x 2D input read as [batchSize, inSize] (from x.shape[0], x.shape[1]).
 * @param reduceType Forwarded as the program's `op` (presumably 'max'/'min').
 * @param bestIndicesA Previous pass's winners; null (default) on the first pass.
 * @returns int32 TensorInfo whose second dimension is 1.
 */
function argReduce(backend, x, reduceType, bestIndicesA) {
    if (bestIndicesA === void 0) { bestIndicesA = null; }
    var batchSize = x.shape[0];
    var inSize = x.shape[1];
    if (bestIndicesA != null) {
        // Later passes reduce over the previous pass's candidate list.
        batchSize = bestIndicesA.shape[0];
        inSize = bestIndicesA.shape[1];
    }
    var windowSize = tf.backend_util.computeOptimalWindowSize(inSize);
    var reduceInfo = { windowSize: windowSize, inSize: inSize, batchSize: batchSize, outSize: Math.ceil(inSize / windowSize) };
    var program = new ArgMinMaxProgram(reduceInfo, reduceType, bestIndicesA == null);
    var inputs = [x];
    if (bestIndicesA != null) {
        inputs.push(bestIndicesA);
    }
    var output = backend.runWebGLProgram(program, inputs, 'int32');
    // No need to run another GPGPU program.
    if (output.shape[1] === 1) {
        return output;
    }
    var result = argReduce(backend, x, reduceType, output);
    backend.disposeIntermediateTensorInfo(output);
    return result;
}
/**
 * Packed counterpart of argReduce over the last axis of `x` (rank > 2).
 * Recurses while the pass output keeps the same rank as `x`, i.e. while more
 * than one window remained. Each intermediate output is disposed after the
 * recursive call returns.
 *
 * @param bestIndicesA Previous pass's winners; null (default) on the first pass.
 * @returns int32 TensorInfo with the last axis of `x` removed.
 */
function argReducePacked(backend, x, reduceType, bestIndicesA) {
    if (bestIndicesA === void 0) { bestIndicesA = null; }
    var inShape = bestIndicesA != null ? bestIndicesA.shape : x.shape;
    var inSize = inShape[inShape.length - 1];
    var windowSize = tf.backend_util.computeOptimalWindowSize(inSize);
    var program = new ArgMinMaxPackedProgram(inShape, windowSize, reduceType, bestIndicesA == null);
    var inputs = bestIndicesA == null ? [x] : [x, bestIndicesA];
    var output = backend.runWebGLProgram(program, inputs, 'int32');
    // Same rank as x means a window axis is still present: reduce again.
    if (output.shape.length === x.shape.length) {
        var result = argReducePacked(backend, x, reduceType, output);
        backend.disposeIntermediateTensorInfo(output);
        return result;
    }
    return output;
}
/**
 * Shared arg-min/max entry point for a single axis that must already be the
 * innermost dimension of `x` (asserted).
 *
 * If WEBGL_PACK_REDUCE is off or x has rank <= 2, the unpacked path runs: x is
 * unpacked if needed, reshaped to [-1, inSize], reduced with argReduce and
 * reshaped to the output shape. All intermediates are disposed. Otherwise the
 * packed multi-pass path runs; the packed program itself requires rank > 2.
 *
 * @param reduceType 'max' or 'min' (also used to build the assertion's op name).
 * @returns int32 TensorInfo of indices with the reduced axis removed.
 */
function argMinMaxReduce(backend, x, axis, reduceType) {
    var axes = [axis];
    tf.backend_util.assertAxesAreInnerMostDims('arg' + reduceType.charAt(0).toUpperCase() + reduceType.slice(1), axes, x.shape.length);
    if (!tf.env().getBool('WEBGL_PACK_REDUCE') || x.shape.length <= 2) {
        var intermediateTensorInfos = [];
        // Eagerly unpack x input since it is passed in to all the shaders which
        // require unpacked inputs.
        var xtexData = backend.texData.get(x.dataId);
        var xIsPacked = xtexData !== null && xtexData.isPacked;
        var xUnPacked = x;
        if (xIsPacked) {
            xUnPacked = backend.unpackTensor(x);
            intermediateTensorInfos.push(xUnPacked);
        }
        var _a = __read(tf.backend_util.computeOutAndReduceShapes(xUnPacked.shape, axes), 2), outShape = _a[0], reduceShape = _a[1];
        var inSize = tf.util.sizeFromShape(reduceShape);
        var a2D = reshape({ inputs: { x: xUnPacked }, backend: backend, attrs: { shape: [-1, inSize] } });
        intermediateTensorInfos.push(a2D);
        var reduced = argReduce(backend, a2D, reduceType);
        intermediateTensorInfos.push(reduced);
        var reshaped = reshape({ inputs: { x: reduced }, backend: backend, attrs: { shape: outShape } });
        intermediateTensorInfos.forEach(function (t) { return backend.disposeIntermediateTensorInfo(t); });
        return reshaped;
    }
    return argReducePacked(backend, x, reduceType);
}
/** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License.
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function argMax(args) { var inputs = args.inputs, backend = args.backend, attrs = args.attrs; var x = inputs.x; var axis = attrs.axis; var axes = tf.util.parseAxisParam(axis, x.shape); var permutedAxes = tf.backend_util.getAxesPermutation(axes, x.shape.length); var $x = x; var intermediateTensorInfos = []; if (permutedAxes != null) { $x = transpose({ inputs: { x: x }, backend: backend, attrs: { perm: permutedAxes } }); intermediateTensorInfos.push($x); axes = tf.backend_util.getInnerMostAxes(axes.length, $x.shape.length); } tf.backend_util.assertAxesAreInnerMostDims('argMax', [axes[0]], $x.shape.length); var out = argMinMaxReduce(backend, $x, axes[0], 'max'); intermediateTensorInfos.forEach(function (t) { return backend.disposeIntermediateTensorInfo(t); }); return out; } var argMaxConfig = { kernelName: tf.ArgMax, backendName: 'webgl', kernelFunc: argMax }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ function argMin(args) { var inputs = args.inputs, backend = args.backend, attrs = args.attrs; var x = inputs.x; var axis = attrs.axis; var axes = tf.util.parseAxisParam(axis, x.shape); var permutedAxes = tf.backend_util.getAxesPermutation(axes, x.shape.length); var $x = x; var intermediateTensorInfos = []; if (permutedAxes != null) { $x = transpose({ inputs: { x: x }, backend: backend, attrs: { perm: permutedAxes } }); intermediateTensorInfos.push($x); axes = tf.backend_util.getInnerMostAxes(axes.length, $x.shape.length); } tf.backend_util.assertAxesAreInnerMostDims('argMin', [axes[0]], $x.shape.length); var out = argMinMaxReduce(backend, $x, axes[0], 'min'); intermediateTensorInfos.forEach(function (t) { return backend.disposeIntermediateTensorInfo(t); }); return out; } var argMinConfig = { kernelName: tf.ArgMin, backendName: 'webgl', kernelFunc: argMin }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ var ASIN = CHECK_NAN_SNIPPET$1 + "\n if (abs(x) > 1.) {\n return NAN;\n }\n return asin(x);\n"; var asin = unaryKernelFunc({ opSnippet: ASIN }); var asinConfig = { kernelName: tf.Asin, backendName: 'webgl', kernelFunc: asin, }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. 
* Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ var ASINH = CHECK_NAN_SNIPPET$1 + "return log(x + sqrt(x * x + 1.0));"; var asinh = unaryKernelFunc({ opSnippet: ASINH }); var asinhConfig = { kernelName: tf.Asinh, backendName: 'webgl', kernelFunc: asinh, }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ var ATAN = CHECK_NAN_SNIPPET$1 + "\n return atan(x);\n"; var atan = unaryKernelFunc({ opSnippet: ATAN }); var atanConfig = { kernelName: tf.Atan, backendName: 'webgl', kernelFunc: atan, }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ var ATAN2 = CHECK_NAN_SNIPPET + "\n return atan(a, b);\n"; var ATAN2_PACKED = "\n vec4 result = atan(a, b);\n bvec4 isNaNA = isnan(a);\n bvec4 isNaNB = isnan(b);\n bvec4 isNaN = bvec4(isNaNA.x || isNaNB.x, isNaNA.y || isNaNB.y, isNaNA.z || isNaNB.z, isNaNA.w || isNaNB.w);\n " + CHECK_NAN_SNIPPET_PACKED + "\n return result;\n"; var atan2 = binaryKernelFunc({ opSnippet: ATAN2, packedOpSnippet: ATAN2_PACKED }); var atan2Config = { kernelName: tf.Atan2, backendName: 'webgl', kernelFunc: atan2, }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ var ATANH = CHECK_NAN_SNIPPET$1 + "\n if ((x < -1.0) || (x > 1.0)) return NAN;\nreturn (log(1.0 + x) - log(1.0 - x)) / 2.0;"; var atanh = unaryKernelFunc({ opSnippet: ATANH }); var atanhConfig = { kernelName: tf.Atanh, backendName: 'webgl', kernelFunc: atanh, }; /** * @license * Copyright 2017 Google LLC. 
All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */
var Pool2DProgram = /** @class */ (function () {
    /**
     * GPGPU program for 2D pooling. The output is addressed as
     * [batch, row, col, channel]: batch = coords[0], spatial = coords.yz,
     * channel = coords[3].
     *
     * @param convInfo Pooling geometry: strides, dilations, effective filter
     *     size, padding, and input/output sizes.
     * @param poolType 'avg' averages. Any other value takes the max path:
     *     accumulation always uses `max`, and poolType is spliced into
     *     `returnValue` as a GLSL function name (presumably always 'max').
     * @param computePositions Emit the position of the max element instead of
     *     its value. Rejected for 'avg'.
     * @param flattenPositions With computePositions, emit the flattened input
     *     offset instead of the offset inside the filter window
     *     (wR * effectiveFilterWidth + wC).
     * @param includeBatchInIndex With flattenPositions, include the batch
     *     dimension in the flattened offset.
     */
    function Pool2DProgram(convInfo, poolType, computePositions, flattenPositions, includeBatchInIndex) {
        if (flattenPositions === void 0) { flattenPositions = false; }
        if (includeBatchInIndex === void 0) { includeBatchInIndex = false; }
        this.variableNames = ['x'];
        if (poolType === 'avg' && computePositions) {
            throw new Error('Cannot compute positions for average pool.');
        }
        var filterWidth = convInfo.filterWidth;
        var strideHeight = convInfo.strideHeight;
        var strideWidth = convInfo.strideWidth;
        var dilationHeight = convInfo.dilationHeight;
        var dilationWidth = convInfo.dilationWidth;
        var effectiveFilterHeight = convInfo.effectiveFilterHeight;
        var effectiveFilterWidth = convInfo.effectiveFilterWidth;
        var padTop = convInfo.padInfo.top;
        var padLeft = convInfo.padInfo.left;
        this.outputShape = convInfo.outShape;
        var isAvgPool = poolType === 'avg';
        var batchFlattenPositionStr = "((batch * ".concat(convInfo.inHeight, " + xR) * ").concat(convInfo.inWidth, " + xC) * ").concat(convInfo.inChannels, " + d");
        var flattenPositionStr = "(xR * ".concat(convInfo.inWidth, " + xC) * ").concat(convInfo.inChannels, " + d");
        var initializationValue = '0.0';
        if (!isAvgPool) {
            // WebGL on Firefox Linux can't compile 1/0 so we do 1/eps.
            initializationValue = '-1.0 / 1e-20';
        }
        if (computePositions) {
            // `>=` lets a later element equal to the running max take over, so
            // ties resolve to the last visited position. Out-of-range rows and
            // columns are skipped.
            var compareOp_1 = '>=';
            this.userCode = "\n const ivec2 strides = ivec2(".concat(strideHeight, ", ").concat(strideWidth, ");\n const ivec2 pads = ivec2(").concat(padTop, ", ").concat(padLeft, ");\n\n void main() {\n ivec4 coords = getOutputCoords();\n int batch = coords[0];\n int d = coords[3];\n\n ivec2 xRCCorner = coords.yz * strides - pads;\n int xRCorner = xRCCorner.x;\n int xCCorner = xRCCorner.y;\n\n // max/min x(?, ?, d) to get y(yR, yC, d).\n // ? = to be determined\n float minMaxValue = 0.0;\n float minMaxValueFound = 0.0;\n int minMaxPosition = 0;\n float avgValue = 0.0;\n\n for (int wR = 0; wR < ").concat(effectiveFilterHeight, ";\n wR += ").concat(dilationHeight, ") {\n int xR = xRCorner + wR;\n\n if (xR < 0 || xR >= ").concat(convInfo.inHeight, ") {\n continue;\n }\n\n for (int wC = 0; wC < ").concat(effectiveFilterWidth, ";\n wC += ").concat(dilationWidth, ") {\n int xC = xCCorner + wC;\n\n if (xC < 0 || xC >= ").concat(convInfo.inWidth, ") {\n continue;\n }\n\n float value = getX(batch, xR, xC, d);\n\n // If a min / max value has already been found, use it. If not,\n // use the current value.\n float currMinMaxValue = mix(\n value, minMaxValue, minMaxValueFound);\n if (value ").concat(compareOp_1, " currMinMaxValue) {\n minMaxValue = value;\n minMaxValueFound = 1.0;\n minMaxPosition = ").concat(flattenPositions ? (includeBatchInIndex ? batchFlattenPositionStr : flattenPositionStr) : "wR * ".concat(effectiveFilterWidth, " + wC"), ";\n }\n }\n }\n setOutput(float(minMaxPosition));\n }\n ");
            return;
        }
        // Value path: the max pool accumulates with vec4 `max` and collapses
        // the four lanes at the end; the average divides by the number of
        // in-bounds reads, clamped to at least 1.
        var compareOp = 'max';
        var returnValue = "".concat(poolType, "(").concat(poolType, "(").concat(poolType, "(") + 'minMaxValue[0], minMaxValue[1]), minMaxValue[2]), minMaxValue[3])';
        if (poolType === 'avg') {
            returnValue = "avgValue / max(count, 1.0)";
        }
        // Filter columns are read four at a time, plus a 1-3 column remainder.
        var filterWidthNearestVec4 = Math.floor(filterWidth / 4) * 4;
        var filterWidthVec4Remainder = filterWidth % 4;
        var updateSnippet = "\n if (".concat(isAvgPool, ") {\n avgValue += dot(values, ones);\n } else {\n minMaxValue = ").concat(compareOp, "(values, minMaxValue);\n }\n ");
        // In the shader, getValue() returns initializationValue for
        // out-of-range columns and increments `count` only for in-range reads.
        // Out-of-range rows are skipped outright.
        // NOTE(review): the remainder starts at xCCorner + filterWidthNearestVec4,
        // not scaled by dilationWidth as in the vec4 loop (wC * dilationWidth).
        // This looks correct only when dilationWidth == 1 -- confirm callers.
        this.userCode = "\n const ivec2 strides = ivec2(".concat(strideHeight, ", ").concat(strideWidth, ");\n const ivec2 pads = ivec2(").concat(padTop, ", ").concat(padLeft, ");\n const float initializationValue = ").concat(initializationValue, ";\n const vec4 ones = vec4(1.0, 1.0, 1.0, 1.0);\n\n float count = 0.0;\n\n float getValue(int batch, int xR, int xC, int d) {\n if (xC < 0 || xC >= ").concat(convInfo.inWidth, ") {\n return initializationValue;\n }\n count += 1.0;\n return getX(batch, xR, xC, d);\n }\n\n void main() {\n ivec4 coords = getOutputCoords();\n int batch = coords[0];\n int d = coords[3];\n\n ivec2 xRCCorner = coords.yz * strides - pads;\n int xRCorner = xRCCorner.x;\n int xCCorner = xRCCorner.y;\n\n // max/min x(?, ?, d) to get y(yR, yC, d).\n // ? = to be determined\n vec4 minMaxValue = vec4(").concat(initializationValue, ");\n float avgValue = 0.0;\n count = 0.0;\n\n for (int wR = 0; wR < ").concat(effectiveFilterHeight, ";\n wR += ").concat(dilationHeight, ") {\n int xR = xRCorner + wR;\n\n if (xR < 0 || xR >= ").concat(convInfo.inHeight, ") {\n continue;\n }\n\n for (int wC = 0; wC < ").concat(filterWidthNearestVec4, "; wC += 4) {\n int xC = xCCorner + wC * ").concat(dilationWidth, ";\n\n vec4 values = vec4(\n getValue(batch, xR, xC, d),\n getValue(batch, xR, xC + ").concat(dilationWidth, ", d),\n getValue(batch, xR, xC + 2 * ").concat(dilationWidth, ", d),\n getValue(batch, xR, xC + 3 * ").concat(dilationWidth, ", d)\n );\n\n ").concat(updateSnippet, "\n }\n\n int xC = xCCorner + ").concat(filterWidthNearestVec4, ";\n if (").concat(filterWidthVec4Remainder === 1, ") {\n vec4 values = vec4(\n getValue(batch, xR, xC, d),\n initializationValue,\n initializationValue,\n initializationValue\n );\n\n ").concat(updateSnippet, "\n } else if (").concat(filterWidthVec4Remainder === 2, ") {\n vec4 values = vec4(\n getValue(batch, xR, xC, d),\n getValue(batch, xR, xC + ").concat(dilationWidth, ", d),\n initializationValue,\n initializationValue\n );\n\n ").concat(updateSnippet, "\n } else if (").concat(filterWidthVec4Remainder === 3, ") {\n vec4 values = vec4(\n getValue(batch, xR, xC, d),\n getValue(batch, xR, xC + ").concat(dilationWidth, ", d),\n getValue(batch, xR, xC + 2 * ").concat(dilationWidth, ", d),\n initializationValue\n );\n\n ").concat(updateSnippet, "\n }\n }\n setOutput(").concat(returnValue, ");\n }\n ");
    }
    return Pool2DProgram;
}());
var Pool3DProgram = /** @class */ (function () {
    /**
     * GPGPU program for 3D pooling. The output is addressed through ivec5
     * coords as [batch, depth, row, col, channel]: batch = coords.x, spatial =
     * (coords.y, coords.z, coords.w), channel = coords.u.
     *
     * Parameters mirror Pool2DProgram. With computePositions and without
     * flattenPositions, the emitted position is
     * wD * effectiveFilterHeight * effectiveFilterWidth + wR * effectiveFilterWidth + wC.
     */
    function Pool3DProgram(convInfo, poolType, computePositions, flattenPositions, includeBatchInIndex) {
        if (flattenPositions === void 0) { flattenPositions = false; }
        if (includeBatchInIndex === void 0) { includeBatchInIndex = false; }
        this.variableNames = ['x'];
        if (poolType === 'avg' && computePositions) {
            throw new Error('Cannot compute positions for average pool.');
        }
        var filterWidth = convInfo.filterWidth;
        var strideDepth = convInfo.strideDepth;
        var strideHeight = convInfo.strideHeight;
        var strideWidth = convInfo.strideWidth;
        var dilationDepth = convInfo.dilationDepth;
        var dilationHeight = convInfo.dilationHeight;
        var dilationWidth = convInfo.dilationWidth;
        var effectiveFilterDepth = convInfo.effectiveFilterDepth;
        var effectiveFilterHeight = convInfo.effectiveFilterHeight;
        var effectiveFilterWidth = convInfo.effectiveFilterWidth;
        var padFront = convInfo.padInfo.front;
        var padTop = convInfo.padInfo.top;
        var padLeft = convInfo.padInfo.left;
        this.outputShape = convInfo.outShape;
        var isAvgPool = poolType === 'avg';
        var initializationValue = '0.0';
        if (!isAvgPool) {
            // WebGL on Firefox Linux can't compile 1/0 so we do 1/eps.
            initializationValue = '-1.0 / 1e-20';
        }
        if (computePositions) {
            // `>=`: ties resolve to the last visited position. Out-of-range
            // depth slices, rows and columns are skipped.
            var compareOp_2 = '>=';
            this.userCode = "\n const ivec3 strides =\n ivec3(".concat(strideDepth, ", ").concat(strideHeight, ", ").concat(strideWidth, ");\n const ivec3 pads = ivec3(").concat(padFront, ", ").concat(padTop, ", ").concat(padLeft, ");\n\n void main() {\n ivec5 coords = getOutputCoords();\n int batch = coords.x;\n int ch = coords.u;\n\n ivec3 xCorner = ivec3(coords.y, coords.z, coords.w) * strides - pads;\n int xDCorner = xCorner.x;\n int xRCorner = xCorner.y;\n int xCCorner = xCorner.z;\n\n // max/min x(?, ?, ?, ch) to get y(yD, yR, yC, ch).\n // ? = to be determined\n float minMaxValue = 0.0;\n float minMaxValueFound = 0.0;\n int minMaxPosition = 0;\n\n for (int wD = 0; wD < ").concat(effectiveFilterDepth, ";\n wD += ").concat(dilationDepth, ") {\n int xD = xDCorner + wD;\n\n if (xD < 0 || xD >= ").concat(convInfo.inDepth, ") {\n continue;\n }\n\n for (int wR = 0; wR < ").concat(effectiveFilterHeight, ";\n wR += ").concat(dilationHeight, ") {\n int xR = xRCorner + wR;\n\n if (xR < 0 || xR >= ").concat(convInfo.inHeight, ") {\n continue;\n }\n\n for (int wC = 0; wC < ").concat(effectiveFilterWidth, ";\n wC += ").concat(dilationWidth, ") {\n int xC = xCCorner + wC;\n\n if (xC < 0 || xC >= ").concat(convInfo.inWidth, ") {\n continue;\n }\n\n float value = getX(batch, xD, xR, xC, ch);\n\n // If a min / max value has already been found, use it. If not,\n // use the current value.\n float currMinMaxValue = mix(\n value, minMaxValue, minMaxValueFound);\n if (value ").concat(compareOp_2, " currMinMaxValue) {\n minMaxValue = value;\n minMaxValueFound = 1.0;\n minMaxPosition = ").concat(flattenPositions ? (includeBatchInIndex ? "(((batch * ".concat(convInfo.inDepth, " + xD) * ").concat(convInfo.inHeight, " + xR) * ").concat(convInfo.inWidth, " + xC) * ").concat(convInfo.inChannels, " + ch") : "((xD * ".concat(convInfo.inHeight, " + xR) * ").concat(convInfo.inWidth, " + xC) * ").concat(convInfo.inChannels, " + ch")) : "wD * ".concat(effectiveFilterHeight, " * ").concat(effectiveFilterWidth, " +\n wR * ").concat(effectiveFilterWidth, " + wC"), ";\n }\n }\n }\n }\n setOutput(float(minMaxPosition));\n }\n ");
            return;
        }
        var compareOp = 'max';
        var returnValue = "".concat(poolType, "(").concat(poolType, "(").concat(poolType, "(") + 'minMaxValue[0], minMaxValue[1]), minMaxValue[2]), minMaxValue[3])';
        if (poolType === 'avg') {
            // Use `max(count, 1.0)` instead of `count` in case count === 0.0.
            // If count === 0.0, `avgValue` is always 0.0 and we change `count`'s
            // value to avoid dividing zero.
            returnValue = "avgValue / max(count, 1.0)";
        }
        // Filter columns are read four at a time, plus a 1-3 column remainder.
        var filterWidthNearestVec4 = Math.floor(filterWidth / 4) * 4;
        var filterWidthVec4Remainder = filterWidth % 4;
        var updateSnippet = "\n if (".concat(isAvgPool, ") {\n avgValue += dot(values, ones);\n } else {\n minMaxValue = ").concat(compareOp, "(values, minMaxValue);\n }\n ");
        // NOTE(review): as in Pool2DProgram, the remainder column offset is not
        // scaled by dilationWidth; this looks correct only for dilationWidth == 1.
        this.userCode = "\n const ivec3 strides =\n ivec3(".concat(strideDepth, ", ").concat(strideHeight, ", ").concat(strideWidth, ");\n const ivec3 pads = ivec3(").concat(padFront, ", ").concat(padTop, ", ").concat(padLeft, ");\n const float initializationValue = ").concat(initializationValue, ";\n const vec4 ones = vec4(1.0, 1.0, 1.0, 1.0);\n\n float count = 0.0;\n\n float getValue(int batch, int xD, int xR, int xC, int ch) {\n if (xC < 0 || xC >= ").concat(convInfo.inWidth, ") {\n return initializationValue;\n }\n count += 1.0;\n return getX(batch, xD, xR, xC, ch);\n }\n\n void main() {\n ivec5 coords = getOutputCoords();\n int batch = coords.x;\n int ch = coords.u;\n\n ivec3 xCorner = ivec3(coords.y, coords.z, coords.w) * strides - pads;\n int xDCorner = xCorner.x;\n int xRCorner = xCorner.y;\n int xCCorner = xCorner.z;\n\n // max/min x(?, ?, ?, d) to get y(yD, yR, yC, ch).\n // ? = to be determined\n vec4 minMaxValue = vec4(").concat(initializationValue, ");\n float avgValue = 0.0;\n count = 0.0;\n\n for (int wD = 0; wD < ").concat(effectiveFilterDepth, ";\n wD += ").concat(dilationDepth, ") {\n int xD = xDCorner + wD;\n\n if (xD < 0 || xD >= ").concat(convInfo.inDepth, ") {\n continue;\n }\n\n for (int wR = 0; wR < ").concat(effectiveFilterHeight, ";\n wR += ").concat(dilationHeight, ") {\n int xR = xRCorner + wR;\n\n if (xR < 0 || xR >= ").concat(convInfo.inHeight, ") {\n continue;\n }\n\n for (int wC = 0; wC < ").concat(filterWidthNearestVec4, "; wC += 4) {\n int xC = xCCorner + wC * ").concat(dilationWidth, ";\n\n vec4 values = vec4(\n getValue(batch, xD, xR, xC, ch),\n getValue(batch, xD, xR, xC + ").concat(dilationWidth, ", ch),\n getValue(batch, xD, xR, xC + 2 * ").concat(dilationWidth, ", ch),\n getValue(batch, xD, xR, xC + 3 * ").concat(dilationWidth, ", ch)\n );\n\n ").concat(updateSnippet, "\n }\n\n int xC = xCCorner + ").concat(filterWidthNearestVec4, ";\n if (").concat(filterWidthVec4Remainder === 1, ") {\n vec4 values = vec4(\n getValue(batch, xD, xR, xC, ch),\n initializationValue,\n initializationValue,\n initializationValue\n );\n\n ").concat(updateSnippet, "\n } else if (").concat(filterWidthVec4Remainder === 2, ") {\n vec4 values = vec4(\n getValue(batch, xD, xR, xC, ch),\n getValue(batch, xD, xR, xC + ").concat(dilationWidth, ", ch),\n initializationValue,\n initializationValue\n );\n\n ").concat(updateSnippet, "\n } else if (").concat(filterWidthVec4Remainder === 3, ") {\n vec4 values = vec4(\n getValue(batch, xD, xR, xC, ch),\n getValue(batch, xD, xR, xC + ").concat(dilationWidth, ", ch),\n getValue(batch, xD, xR, xC + 2 * ").concat(dilationWidth, ", ch),\n initializationValue\n );\n\n ").concat(updateSnippet, "\n }\n }\n }\n setOutput(").concat(returnValue, ");\n }\n ");
    }
    return Pool3DProgram;
}());
/** * @license * Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function avgPool(args) { var inputs = args.inputs, backend = args.backend, attrs = args.attrs; var x = inputs.x; assertNotComplex(x, 'avgPool'); var filterSize = attrs.filterSize, strides = attrs.strides, pad = attrs.pad, dimRoundingMode = attrs.dimRoundingMode; var dilations = 1; tf.util.assert(tf.backend_util.eitherStridesOrDilationsAreOne(strides, dilations), function () { return 'Error in avgPool: Either strides or dilations must be 1. ' + "Got strides ".concat(strides, " and dilations '").concat(dilations, "'"); }); var convInfo = tf.backend_util.computePool2DInfo(x.shape, filterSize, strides, dilations, pad, dimRoundingMode); if (convInfo.filterWidth === 1 && convInfo.filterHeight === 1 && tf.util.arraysEqual(convInfo.inShape, convInfo.outShape)) { return identity({ inputs: { x: x }, backend: backend }); } var avgPoolProgram = new Pool2DProgram(convInfo, 'avg', false); return backend.runWebGLProgram(avgPoolProgram, [x], 'float32'); } var avgPoolConfig = { kernelName: tf.AvgPool, backendName: 'webgl', kernelFunc: avgPool }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function avgPool3D(args) { var inputs = args.inputs, backend = args.backend, attrs = args.attrs; var x = inputs.x; var filterSize = attrs.filterSize, strides = attrs.strides, pad = attrs.pad, dimRoundingMode = attrs.dimRoundingMode, dataFormat = attrs.dataFormat; var dilations = [1, 1, 1]; var convInfo = tf.backend_util.computePool3DInfo(x.shape, filterSize, strides, dilations, pad, dimRoundingMode, dataFormat); var avgPoolProgram = new Pool3DProgram(convInfo, 'avg', false); return backend.runWebGLProgram(avgPoolProgram, [x], 'float32'); } var avgPool3DConfig = { kernelName: tf.AvgPool3D, backendName: 'webgl', kernelFunc: avgPool3D }; /** * @license * Copyright 2017 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */
/**
 * Unpacked GLSL program computing the gradient of 2-D average pooling
 * with respect to the pooled input.
 *
 * The output has the forward input's shape (`convInfo.inShape`) and the only
 * sampled texture is `dy`. Each output coordinate is read as (b, r, c, d).
 * For that coordinate the shader walks the (possibly dilated) filter window
 * and maps every tap back to a `dy` position by dividing by the stride.
 * Taps that fall outside `dy` or between stride steps (`fract(...) > 0.0`)
 * are skipped.
 *
 * Every accepted `dy` value is scaled by the same constant
 * 1 / (filterHeight * filterWidth). The divisor therefore does not depend on
 * how many window cells fell inside the input.
 *
 * @param convInfo Pooling geometry from `backend_util.computePool2DInfo`.
 */
var AvgPool2DBackpropProgram = /** @class */ (function () {
    function AvgPool2DBackpropProgram(convInfo) {
        this.variableNames = ['dy'];
        this.outputShape = convInfo.inShape;
        var filterHeight = convInfo.filterHeight;
        var filterWidth = convInfo.filterWidth;
        var strideHeight = convInfo.strideHeight;
        var strideWidth = convInfo.strideWidth;
        var dilationHeight = convInfo.dilationHeight;
        var dilationWidth = convInfo.dilationWidth;
        var effectiveFilterHeight = convInfo.effectiveFilterHeight;
        var effectiveFilterWidth = convInfo.effectiveFilterWidth;
        // Offsets subtracted from the output coordinate to find the first dy
        // row/column whose window could cover it (derived from the forward padding).
        var padTop = effectiveFilterHeight - 1 - convInfo.padInfo.top;
        var padLeft = effectiveFilterWidth - 1 - convInfo.padInfo.left;
        var avgMultiplier = 1 / (filterHeight * filterWidth);
        // All geometry is baked into the shader source as literals.
        this.userCode = "\n const ivec2 pads = ivec2(".concat(padTop, ", ").concat(padLeft, ");\n const float avgMultiplier = float(").concat(avgMultiplier, ");\n\n void main() {\n ivec4 coords = getOutputCoords();\n int b = coords[0];\n int d = coords[3];\n\n ivec2 dyRCCorner = coords.yz - pads;\n int dyRCorner = dyRCCorner.x;\n int dyCCorner = dyRCCorner.y;\n\n // Convolve dy(?, ?, d) with pos mask(:, :, d) to get dx(xR, xC, d).\n // ? = to be determined. : = across all values in that axis.\n float dotProd = 0.0;\n for (int wR = 0; wR < ").concat(effectiveFilterHeight, ";\n wR += ").concat(dilationHeight, ") {\n float dyR = float(dyRCorner + wR) / ").concat(strideHeight, ".0;\n\n if (dyR < 0.0 || dyR >= ").concat(convInfo.outHeight, ".0 || fract(dyR) > 0.0) {\n continue;\n }\n int idyR = int(dyR);\n\n for (int wC = 0; wC < ").concat(effectiveFilterWidth, ";\n wC+= ").concat(dilationWidth, ") {\n float dyC = float(dyCCorner + wC) / ").concat(strideWidth, ".0;\n\n if (dyC < 0.0 || dyC >= ").concat(convInfo.outWidth, ".0 ||\n fract(dyC) > 0.0) {\n continue;\n }\n int idyC = int(dyC);\n\n float dyValue = getDy(b, idyR, idyC, d);\n\n dotProd += dyValue * avgMultiplier;\n }\n }\n setOutput(dotProd);\n }\n ");
    }
    return AvgPool2DBackpropProgram;
}());
/**
 * Unpacked GLSL program computing the gradient of 3-D average pooling
 * with respect to the pooled input.
 *
 * This is the same scheme as AvgPool2DBackpropProgram, with an extra depth
 * loop. Output coordinates are read as (batch, d, r, c, ch) via
 * `coords.x/.y/.z/.w/.u`. `ivec5` is presumably a helper type supplied by the
 * shader preamble; confirm against the shader compiler.
 *
 * Every accepted `dy` value is scaled by the constant
 * 1 / (filterDepth * filterHeight * filterWidth).
 *
 * @param convInfo Pooling geometry from `backend_util.computePool3DInfo`.
 */
var AvgPool3DBackpropProgram = /** @class */ (function () {
    function AvgPool3DBackpropProgram(convInfo) {
        this.variableNames = ['dy'];
        this.outputShape = convInfo.inShape;
        var filterDepth = convInfo.filterDepth;
        var filterHeight = convInfo.filterHeight;
        var filterWidth = convInfo.filterWidth;
        var strideDepth = convInfo.strideDepth;
        var strideHeight = convInfo.strideHeight;
        var strideWidth = convInfo.strideWidth;
        var dilationDepth = convInfo.dilationDepth;
        var dilationHeight = convInfo.dilationHeight;
        var dilationWidth = convInfo.dilationWidth;
        var effectiveFilterDepth = convInfo.effectiveFilterDepth;
        var effectiveFilterHeight = convInfo.effectiveFilterHeight;
        var effectiveFilterWidth = convInfo.effectiveFilterWidth;
        // Same corner-offset derivation as the 2-D program, per spatial axis.
        var padFront = effectiveFilterDepth - 1 - convInfo.padInfo.front;
        var padTop = effectiveFilterHeight - 1 - convInfo.padInfo.top;
        var padLeft = effectiveFilterWidth - 1 - convInfo.padInfo.left;
        var avgMultiplier = 1 / (filterDepth * filterHeight * filterWidth);
        this.userCode = "\n const ivec3 pads = ivec3(".concat(padFront, ", ").concat(padTop, ", ").concat(padLeft, ");\n const float avgMultiplier = float(").concat(avgMultiplier, ");\n\n void main() {\n ivec5 coords = getOutputCoords();\n int batch = coords.x;\n int ch = coords.u;\n\n ivec3 dyCorner = ivec3(coords.y, coords.z, coords.w) - pads;\n int dyDCorner = dyCorner.x;\n int dyRCorner = dyCorner.y;\n int dyCCorner = dyCorner.z;\n\n // Convolve dy(?, ?, ?, d) with pos mask(:, :, :, ch) to get\n // dx(xD, xR, xC, ch).\n // ? = to be determined. : = across all values in that axis.\n float dotProd = 0.0;\n\n for (int wD = 0; wD < ").concat(effectiveFilterDepth, ";\n wD += ").concat(dilationDepth, ") {\n float dyD = float(dyDCorner + wD) / ").concat(strideDepth, ".0;\n\n if (dyD < 0.0 || dyD >= ").concat(convInfo.outDepth, ".0 || fract(dyD) > 0.0) {\n continue;\n }\n int idyD = int(dyD);\n\n for (int wR = 0; wR < ").concat(effectiveFilterHeight, ";\n wR += ").concat(dilationHeight, ") {\n float dyR = float(dyRCorner + wR) / ").concat(strideHeight, ".0;\n\n if (dyR < 0.0 || dyR >= ").concat(convInfo.outHeight, ".0 ||\n fract(dyR) > 0.0) {\n continue;\n }\n int idyR = int(dyR);\n\n for (int wC = 0; wC < ").concat(effectiveFilterWidth, ";\n wC += ").concat(dilationWidth, ") {\n float dyC = float(dyCCorner + wC) / ").concat(strideWidth, ".0;\n\n if (dyC < 0.0 || dyC >= ").concat(convInfo.outWidth, ".0 ||\n fract(dyC) > 0.0) {\n continue;\n }\n int idyC = int(dyC);\n\n float dyValue = getDy(batch, idyD, idyR, idyC, ch);\n\n dotProd += dyValue * avgMultiplier;\n }\n }\n }\n setOutput(dotProd);\n }\n ");
    }
    return AvgPool3DBackpropProgram;
}());
/** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function avgPool3DGrad(args) { var inputs = args.inputs, backend = args.backend, attrs = args.attrs; var dy = inputs.dy, input = inputs.input; var x = input; var filterSize = attrs.filterSize, strides = attrs.strides, pad = attrs.pad, dimRoundingMode = attrs.dimRoundingMode; var dilations = [1, 1, 1]; var convInfo = tf.backend_util.computePool3DInfo(x.shape, filterSize, strides, dilations, pad, dimRoundingMode); var avgPoolBackpropProgram = new AvgPool3DBackpropProgram(convInfo); return backend.runWebGLProgram(avgPoolBackpropProgram, [dy], x.dtype); } var avgPool3DGradConfig = { kernelName: tf.AvgPool3DGrad, backendName: 'webgl', kernelFunc: avgPool3DGrad }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ function avgPoolGrad(args) { var inputs = args.inputs, backend = args.backend, attrs = args.attrs; var dy = inputs.dy, input = inputs.input; var x = input; assertNotComplex([dy, input], 'avgPoolGrad'); var filterSize = attrs.filterSize, strides = attrs.strides, pad = attrs.pad; var convInfo = tf.backend_util.computePool2DInfo(x.shape, filterSize, strides, 1 /* dilations */, pad); var avgPoolBackpropProgram = new AvgPool2DBackpropProgram(convInfo); return backend.runWebGLProgram(avgPoolBackpropProgram, [dy], x.dtype); } var avgPoolGradConfig = { kernelName: tf.AvgPoolGrad, backendName: 'webgl', kernelFunc: avgPoolGrad }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function batchMatMul(args) { var inputs = args.inputs, backend = args.backend, attrs = args.attrs; var a = inputs.a, b = inputs.b; var transposeA = attrs.transposeA, transposeB = attrs.transposeB; return batchMatMulImpl({ a: a, b: b, transposeA: transposeA, transposeB: transposeB, backend: backend }); } var batchMatMulConfig = { kernelName: tf.BatchMatMul, backendName: 'webgl', kernelFunc: batchMatMul, }; /** * @license * Copyright 2017 Google LLC. All Rights Reserved. 
* Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */
/**
 * Unpacked GLSL program for batch normalization:
 *   out = (x - mean) * scale * inversesqrt(variance + varianceEpsilon) + offset
 *
 * `offset` and `scale` are optional. When a shape is null the corresponding
 * input is not sampled, and the literal 0.0 (offset) or 1.0 (scale) is used.
 * Each provided shape is passed to `assertAndGetBroadcastShape` against
 * `xShape`, presumably to reject non-broadcastable inputs.
 * `varianceEpsilon` is written into the shader source as a literal.
 *
 * @param xShape Shape of `x`; also the output shape.
 * @param meanShape Shape of `mean`.
 * @param varianceShape Shape of `variance`.
 * @param offsetShape Shape of `offset`, or null if absent.
 * @param scaleShape Shape of `scale`, or null if absent.
 * @param varianceEpsilon Value added to the variance before inversesqrt.
 */
var BatchNormProgram = /** @class */ (function () {
    function BatchNormProgram(xShape, meanShape, varianceShape, offsetShape, scaleShape, varianceEpsilon) {
        this.outputShape = [];
        // Texture inputs in order; 'offset'/'scale' are appended only when present.
        this.variableNames = ['x', 'mean', 'variance'];
        tf.backend_util.assertAndGetBroadcastShape(xShape, meanShape);
        tf.backend_util.assertAndGetBroadcastShape(xShape, varianceShape);
        var offsetSnippet = '0.0';
        if (offsetShape != null) {
            tf.backend_util.assertAndGetBroadcastShape(xShape, offsetShape);
            this.variableNames.push('offset');
            offsetSnippet = 'getOffsetAtOutCoords()';
        }
        var scaleSnippet = '1.0';
        if (scaleShape != null) {
            tf.backend_util.assertAndGetBroadcastShape(xShape, scaleShape);
            this.variableNames.push('scale');
            scaleSnippet = 'getScaleAtOutCoords()';
        }
        this.outputShape = xShape;
        // dot(vec3(x, -mean, offset), vec3(inv, inv, 1)) == (x - mean) * inv + offset.
        this.userCode = "\n void main() {\n float x = getXAtOutCoords();\n float mean = getMeanAtOutCoords();\n float variance = getVarianceAtOutCoords();\n float offset = ".concat(offsetSnippet, ";\n float scale = ").concat(scaleSnippet, ";\n float inv = scale * inversesqrt(variance + float(").concat(varianceEpsilon, "));\n setOutput(dot(vec3(x, -mean, offset), vec3(inv, inv, 1)));\n }\n ");
    }
    return BatchNormProgram;
}());
/** * @license * Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */
/**
 * Packed (vec4-per-texel) variant of BatchNormProgram. It computes
 *   out = (x - mean) * scale * inversesqrt(variance + varianceEpsilon) + offset
 * on four values at a time.
 *
 * An absent `offset` falls back to vec4(0.0); an absent `scale` falls back to
 * vec4(1.0). Shape checks and epsilon handling match the unpacked program.
 *
 * @param xShape Shape of `x`; also the output shape.
 * @param meanShape Shape of `mean`.
 * @param varianceShape Shape of `variance`.
 * @param offsetShape Shape of `offset`, or null if absent.
 * @param scaleShape Shape of `scale`, or null if absent.
 * @param varianceEpsilon Value added to the variance before inversesqrt.
 */
var BatchNormPackedProgram = /** @class */ (function () {
    function BatchNormPackedProgram(xShape, meanShape, varianceShape, offsetShape, scaleShape, varianceEpsilon) {
        this.packedInputs = true;
        this.packedOutput = true;
        // Texture inputs in order; 'offset'/'scale' are appended only when present.
        this.variableNames = ['x', 'mean', 'variance'];
        tf.backend_util.assertAndGetBroadcastShape(xShape, meanShape);
        tf.backend_util.assertAndGetBroadcastShape(xShape, varianceShape);
        var offsetSnippet = 'vec4(0.0)';
        if (offsetShape != null) {
            tf.backend_util.assertAndGetBroadcastShape(xShape, offsetShape);
            this.variableNames.push('offset');
            offsetSnippet = 'getOffsetAtOutCoords()';
        }
        var scaleSnippet = 'vec4(1.0)';
        if (scaleShape != null) {
            tf.backend_util.assertAndGetBroadcastShape(xShape, scaleShape);
            this.variableNames.push('scale');
            scaleSnippet = 'getScaleAtOutCoords()';
        }
        this.outputShape = xShape;
        this.userCode = "\n void main() {\n vec4 offset = ".concat(offsetSnippet, ";\n vec4 scale = ").concat(scaleSnippet, ";\n\n vec4 x = getXAtOutCoords();\n vec4 mean = getMeanAtOutCoords();\n vec4 variance = getVarianceAtOutCoords();\n\n vec4 inv = scale * inversesqrt(variance + vec4(").concat(varianceEpsilon, "));\n\n setOutput((x - mean) * inv + offset);\n }\n ");
    }
    return BatchNormPackedProgram;
}());
/** * @license * Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ var batchNorm = function (_a) { var inputs = _a.inputs, backend = _a.backend, attrs = _a.attrs; var x = inputs.x, mean = inputs.mean, variance = inputs.variance, offset = inputs.offset, scale = inputs.scale; tf.util.assert(mean.shape.length === variance.shape.length, function () { return 'Batch normalization gradient requires mean and variance to have ' + 'equal ranks.'; }); tf.util.assert(offset == null || mean.shape.length === offset.shape.length, function () { return 'Batch normalization gradient requires mean and offset to have ' + 'equal ranks.'; }); tf.util.assert(scale == null || mean.shape.length === scale.shape.length, function () { return 'Batch normalization gradient requires mean and scale to have ' + 'equal ranks.'; }); var varianceEpsilon = attrs.varianceEpsilon; if (varianceEpsilon == null) { varianceEpsilon = 0.001; } var finalInputs = [x, mean, variance]; var offsetShape = null; if (offset != null) { offsetShape = offset.shape; finalInputs.push(offset); } var scaleShape = null; if (scale != null) { scaleShape = scale.shape; finalInputs.push(scale); } var program = tf.env().getBool('WEBGL_PACK_NORMALIZATION') ? 
new BatchNormPackedProgram(x.shape, mean.shape, variance.shape, offsetShape, scaleShape, varianceEpsilon) : new BatchNormProgram(x.shape, mean.shape, variance.shape, offsetShape, scaleShape, varianceEpsilon); var output = backend.runWebGLProgram(program, finalInputs, finalInputs[0].dtype); return output; }; var batchNormConfig = { kernelName: tf.FusedBatchNorm, backendName: 'webgl', kernelFunc: batchNorm, }; /** * @license * Copyright 2017 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ var SliceProgram = /** @class */ (function () { function SliceProgram(destSize) { this.variableNames = ['source']; this.outputShape = destSize; this.rank = destSize.length; var dtype = getCoordsDataType(this.rank); this.customUniforms = [{ name: 'start', arrayIndex: this.rank, type: 'int' }]; var sourceCoords = getCoords$1(this.rank); var body; var coordSum = destSize.map(function (_, i) { return "sourceLoc.".concat(coords[i], " = start[").concat(i, "] + coords.").concat(coords[i], ";"); }); body = "\n ".concat(dtype, " sourceLoc;\n ").concat(dtype, " coords = getOutputCoords();\n ").concat(coordSum.join('\n'), "\n "); this.userCode = "\n void main() {\n ".concat(body, "\n setOutput(getSource(").concat(sourceCoords, "));\n }\n "); } return SliceProgram; }()); var coords = ['x', 'y', 'z', 'w', 'u', 'v']; function getCoords$1(rank) { if (rank === 1) { return 'sourceLoc'; } else if (rank <= 6) { return coords.slice(0, rank).map(function (x) { return 'sourceLoc.' + x; }).join(','); } else { throw Error("Slicing for rank ".concat(rank, " is not yet supported")); } } /** * @license * Copyright 2019 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */
/**
 * Packed GLSL program that copies a window of `source` into an output of
 * shape `destSize`.
 *
 * Each output texel covers up to a 2x2 block over the last two axes:
 *   result.x = (r, c)
 *   result.y = (r, c + 1)
 *   result.z = (r + 1, c)
 *   result.w = (r + 1, c + 1)
 * Each component is filled only while the corresponding output coordinate
 * stays below `destSize`; components left unset remain 0.
 *
 * The begin offsets come from the `start` int-array uniform. `getChannel`
 * (defined elsewhere) presumably picks the matching component out of the
 * packed source texel.
 *
 * @param destSize Output (slice) shape.
 */
var SlicePackedProgram = /** @class */ (function () {
    function SlicePackedProgram(destSize) {
        this.variableNames = ['source'];
        this.packedInputs = true;
        this.packedOutput = true;
        this.outputShape = destSize;
        this.rank = destSize.length;
        this.customUniforms = [{ name: 'start', arrayIndex: this.rank, type: 'int' }];
        var dtype = getCoordsDataType(this.rank);
        // Local `coords` (GLSL channel expressions) shadows the module-level `coords` array here.
        var coords = getChannels('coords', this.rank);
        var sourceLoc = getChannels('sourceLoc', this.rank);
        var innerDims = this.rank === 1 ? 'sourceLoc' : "vec2(".concat(sourceLoc.slice(-2).join(), ")");
        var getChannel = "getChannel(getSource(".concat(sourceLoc.join(), "), ").concat(innerDims, ")");
        // Upper row: (r, c) and, if in bounds, (r, c + 1). sourceLoc is restored afterwards.
        var upperRow = "\n result.x = ".concat(getChannel, ";\n if (++").concat(coords[this.rank - 1], " < ").concat(destSize[this.rank - 1], ") {\n ++").concat(sourceLoc[this.rank - 1], ";\n result.y = ").concat(getChannel, ";\n --").concat(sourceLoc[this.rank - 1], ";\n }\n ");
        // Lower row (rank >= 2 only): (r + 1, c) and, if in bounds, (r + 1, c + 1).
        var lowerRow = this.rank === 1 ? '' : "\n --".concat(coords[this.rank - 1], ";\n if (++").concat(coords[this.rank - 2], " < ").concat(destSize[this.rank - 2], ") {\n ++").concat(sourceLoc[this.rank - 2], ";\n result.z = ").concat(getChannel, ";\n if (++").concat(coords[this.rank - 1], " < ").concat(destSize[this.rank - 1], ") {\n ++").concat(sourceLoc[this.rank - 1], ";\n result.w = ").concat(getChannel, ";\n }\n }\n ");
        // Ranks <= 4 add the offsets as one vector expression; higher ranks assign per component.
        var sourceLocSetup = this.rank <= 4 ?
            "sourceLoc = coords +\n ".concat(dtype, "(").concat(destSize.map(function (_, i) { return "start[".concat(i, "]"); }).join(), ");") :
            destSize.map(function (_, i) { return "".concat(sourceLoc[i], " = ").concat(coords[i], " + start[").concat(i, "];"); })
                .join('\n');
        this.userCode = "\n void main() {\n ".concat(dtype, " coords = getOutputCoords();\n ").concat(dtype, " sourceLoc;\n ").concat(sourceLocSetup, "\n vec4 result = vec4(0.);\n ").concat(upperRow, "\n ").concat(lowerRow, "\n setOutput(result);\n }\n ");
    }
    return SlicePackedProgram;
}());
/**
 * Creates a slice that aliases `x`'s GPU data instead of copying it.
 *
 * The new tensor's texData is a copy of `x`'s, with its own refCount (1),
 * shape and dtype. It also gets a `slice` record holding the flat element
 * offset, accumulated when `x` is itself a slice, and the root `origDataId`.
 * The shared ref count of that root data id is then incremented. A missing
 * entry is treated as 1, so the result becomes 2.
 *
 * @param x Tensor to slice.
 * @param begin Per-axis begin indices.
 * @param size Output shape.
 * @param backend The WebGL backend.
 * @returns The aliased slice TensorInfo.
 */
function shallowSlice(x, begin, size, backend) {
    var xTexData = backend.texData.get(x.dataId);
    var t = backend.makeTensorInfo(size, x.dtype);
    var newTexData = backend.texData.get(t.dataId);
    // Copy texture data from the original tensor.
    Object.assign(newTexData, xTexData);
    newTexData.refCount = 1;
    newTexData.shape = size;
    newTexData.dtype = x.dtype;
    var flatOffset = tf.slice_util.computeFlatOffset(begin, tf.util.computeStrides(x.shape));
    if (xTexData.slice) {
        // We are slicing an already sliced tensor, so we have to accumulate
        // the offset.
        flatOffset += xTexData.slice.flatOffset;
    }
    newTexData.slice = {
        flatOffset: flatOffset,
        // Point to the original dataId, which is used to do ref counting.
        origDataId: xTexData.slice && xTexData.slice.origDataId || x.dataId
    };
    // Increase the ref count for that data bucket.
    var refCount = backend.dataRefCount.get(newTexData.slice.origDataId) || 1;
    backend.dataRefCount.set(newTexData.slice.origDataId, refCount + 1);
    return t;
}
/**
 * WebGL kernel for `tf.Slice`.
 *
 * The kernel picks one of four paths, in order:
 *  1. An empty result returns an empty tensor.
 *  2. When the backend prefers the CPU, or the dtype is 'string', the slice
 *     is computed with `sliceImplCPU`.
 *  3. A packed input, or a slice that is not contiguous in memory, runs a
 *     slice shader with `$begin` as the `start` uniform.
 *  4. Otherwise `x` is uploaded to the GPU and aliased via `shallowSlice`.
 *
 * @param args Kernel arguments: `inputs.x`, `backend`, and `attrs`
 *     (`begin`, `size`).
 * @returns The sliced TensorInfo.
 */
function slice(args) {
    var inputs = args.inputs, backend = args.backend, attrs = args.attrs;
    var x = inputs.x;
    var begin = attrs.begin, size = attrs.size;
    // Normalizes begin/size (e.g. size -1 entries) into explicit per-axis values.
    var _a = __read(tf.slice_util.parseSliceParams(x, begin, size), 2), $begin = _a[0], $size = _a[1];
    tf.slice_util.assertParamsValid(x, $begin, $size);
    if (tf.util.sizeFromShape($size) === 0) {
        return backend.makeTensorInfo($size, x.dtype, []);
    }
    // Run on cpu if dtype is string. For string, the backend represents it
    // as Uint8Array[], where each Uint8Array is a character. Given that the
    // computation is only on the outer array, uploading the whole data onto
    // gpu is wasteful. Also, currently webgl doesn't have a design to
    // upload and retrieve Uint8Array[] between cpu and gpu. Therefore, we
    // just run the kernel on cpu if dtype is string.
    if (backend.shouldExecuteOnCPU([x]) || x.dtype === 'string') {
        var xTexData = backend.texData.get(x.dataId);
        var outValues = sliceImplCPU(xTexData.values, $begin, $size, x.shape, x.dtype);
        return backend.makeTensorInfo($size, x.dtype, outValues);
    }
    var isPacked = backend.texData.get(x.dataId).isPacked;
    var isContinous = tf.slice_util.isSliceContinous(x.shape, $begin, $size);
    if (isPacked || !isContinous) {
        var program = tf.env().getBool('WEBGL_PACK_ARRAY_OPERATIONS') ?
            new SlicePackedProgram($size) :
            new SliceProgram($size);
        // Supplies the `start` uniform declared by both slice programs.
        var customValues = [$begin];
        return backend.runWebGLProgram(program, [x], x.dtype, customValues);
    }
    backend.uploadToGPU(x.dataId);
    return shallowSlice(x, $begin, $size, backend);
}
var sliceConfig = { kernelName: tf.Slice, backendName: 'webgl', kernelFunc: slice };
/** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License.
* ============================================================================= */ var batchToSpaceND = function (args) { var inputs = args.inputs, backend = args.backend, attrs = args.attrs; var x = inputs.x; var blockShape = attrs.blockShape, crops = attrs.crops; tf.util.assert(x.shape.length <= 4, function () { return 'batchToSpaceND for rank > 4 with a WebGL backend not ' + 'implemented yet'; }); var prod = blockShape.reduce(function (a, b) { return a * b; }); var reshaped = tf.backend_util.getReshaped(x.shape, blockShape, prod); var permuted = tf.backend_util.getPermuted(reshaped.length, blockShape.length); var reshapedPermuted = tf.backend_util.getReshapedPermuted(x.shape, blockShape, prod); var sliceBeginCoords = tf.backend_util.getSliceBeginCoords(crops, blockShape.length); var sliceSize = tf.backend_util.getSliceSize(reshapedPermuted, crops, blockShape.length); var toDispose = []; var reshapedIntermediate = reshape({ inputs: { x: x }, backend: backend, attrs: { shape: reshaped } }); var transposedIntermediate = transpose({ inputs: { x: reshapedIntermediate }, backend: backend, attrs: { perm: permuted } }); var reshapedIntermediate2 = reshape({ inputs: { x: transposedIntermediate }, backend: backend, attrs: { shape: reshapedPermuted } }); var sliced = slice({ inputs: { x: reshapedIntermediate2 }, backend: backend, attrs: { begin: sliceBeginCoords, size: sliceSize } }); toDispose.push(reshapedIntermediate); toDispose.push(transposedIntermediate); toDispose.push(reshapedIntermediate2); toDispose.forEach(function (t) { return backend.disposeIntermediateTensorInfo(t); }); return sliced; }; var batchToSpaceNDConfig = { kernelName: tf.BatchToSpaceND, backendName: 'webgl', kernelFunc: batchToSpaceND }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function bincount(args) { var inputs = args.inputs, backend = args.backend, attrs = args.attrs; var x = inputs.x, weights = inputs.weights; var size = attrs.size; var xVals = backend.readSync(x.dataId); var weightsVals = backend.readSync(weights.dataId); var outVals = bincountImplCPU(xVals, weightsVals, weights.dtype, weights.shape, size); return backend.makeTensorInfo([size], weights.dtype, outVals); } var bincountConfig = { kernelName: tf.Bincount, backendName: 'webgl', kernelFunc: bincount }; var BITWISEAND = "\n int r = int(a.r) & int(b.r);\n int g = int(a.g) & int(b.g);\n int rb = int(a.b) & int(b.b);\n int ra = int(a.a) & int(b.a);\n return vec4(r, g, rb, ra);\n"; var BITWISEAND_UNPACKED = "\n return float(int(a.r) & int(b.r));\n"; function bitwiseAnd(args) { var inputs = args.inputs, backend = args.backend; var a = inputs.a, b = inputs.b; var shouldUsePackedProgram = tf.env().getBool('WEBGL_PACK_BINARY_OPERATIONS'); var versionNumber = tf.env().getNumber('WEBGL_VERSION'); // The type of a and b are ensured to be `int32` in core, therefore no need to // consider other type situations. 
if ((backend.shouldExecuteOnCPU([a, b])) || versionNumber === 1) { var aVals = backend.texData.get(a.dataId).values; var bVals = backend.texData.get(b.dataId).values; var _a = __read(bitwiseAndImplCPU(a.shape, b.shape, aVals, bVals, a.dtype), 2), outValues = _a[0], outShape = _a[1]; var out = backend.makeTensorInfo(outShape, a.dtype); var outData = backend.texData.get(out.dataId); outData.values = outValues; return out; } var program; if (shouldUsePackedProgram) { program = new BinaryOpPackedProgram(BITWISEAND, a.shape, b.shape, false); } else { program = new BinaryOpProgram(BITWISEAND_UNPACKED, a.shape, b.shape); } return backend.runWebGLProgram(program, [a, b], a.dtype); } var bitwiseAndConfig = { kernelName: tf.BitwiseAnd, backendName: 'webgl', kernelFunc: bitwiseAnd }; /** * @license * Copyright 2021 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ function broadcastArgs(args) { var inputs = args.inputs, backend = args.backend; var s0 = inputs.s0, s1 = inputs.s1; var s0Vals = backend.readSync(s0.dataId); var s1Vals = backend.readSync(s1.dataId); var broadcastShape = tf.backend_util.assertAndGetBroadcastShape(Array.from(s0Vals), Array.from(s1Vals)); return backend.makeTensorInfo([broadcastShape.length], 'int32', Int32Array.from(broadcastShape)); } var broadcastArgsConfig = { kernelName: tf.BroadcastArgs, backendName: 'webgl', kernelFunc: broadcastArgs }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ var NOT_EQUAL = "return float(a != b);"; var notEqual = binaryKernelFunc({ opSnippet: NOT_EQUAL, cpuKernelImpl: notEqualImplCPU, dtype: 'bool' }); var notEqualConfig = { kernelName: tf.NotEqual, backendName: 'webgl', kernelFunc: notEqual, }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function real(args) { var inputs = args.inputs, backend = args.backend; var input = inputs.input; var inputData = backend.texData.get(input.dataId); return identity({ inputs: { x: inputData.complexTensorInfos.real }, backend: backend }); } var realConfig = { kernelName: tf.Real, backendName: 'webgl', kernelFunc: real }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ var TO_INT = "return float(int(x));"; function int(input, backend) { var program = new UnaryOpProgram(input.shape, TO_INT); var output = backend.runWebGLProgram(program, [input], 'int32'); return { dataId: output.dataId, shape: output.shape, dtype: output.dtype }; } function cast(args) { var inputs = args.inputs, backend = args.backend, attrs = args.attrs; var x = inputs.x; var dtype = attrs.dtype; // Casting to complex64. 
if (dtype === 'complex64') { if (x.dtype === 'complex64') { return identity({ inputs: { x: x }, backend: backend }); } // TODO(annxingyuan): Import kernel function once zeros is modularized. var zerosTensor = tf__namespace.zeros(x.shape); var floatX = cast({ inputs: { x: x }, backend: backend, attrs: { dtype: 'float32' } }); var result = complex({ inputs: { real: floatX, imag: zerosTensor }, backend: backend }); zerosTensor.dispose(); backend.disposeIntermediateTensorInfo(floatX); return result; } // Casting from complex64 if (x.dtype === 'complex64') { var realPart = real({ inputs: { input: x }, backend: backend }); var result = cast({ inputs: { x: realPart }, backend: backend, attrs: { dtype: dtype } }); backend.disposeIntermediateTensorInfo(realPart); return result; } if (!tf.util.hasEncodingLoss(x.dtype, dtype)) { // We don't change the underlying data, since we cast to higher // precision. var result = identity({ inputs: { x: x }, backend: backend }); return { dataId: result.dataId, shape: result.shape, dtype: dtype }; } if (backend.shouldExecuteOnCPU([x])) { var values = backend.texData.get(x.dataId).values; var _a = __read(castImplCPU(values, x.shape, x.dtype, dtype), 3), resultShape = _a[0], resultType = _a[1], resultData = _a[2]; return backend.makeTensorInfo(resultShape, resultType, resultData); } if (dtype === 'int32') { return int(x, backend); } if (dtype === 'bool') { var zerosTensorInfo = backend.makeTensorInfo([], 'bool', tf.util.getTypedArrayFromDType('bool', 1)); var binaryInputs = { a: x, b: zerosTensorInfo }; var result = notEqual({ inputs: binaryInputs, backend: backend }); backend.disposeIntermediateTensorInfo(zerosTensorInfo); return result; } throw new Error("Error in Cast: failed to cast ".concat(x.dtype, " to ").concat(dtype)); } var castConfig = { kernelName: tf.Cast, backendName: 'webgl', kernelFunc: cast }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. 
* Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ var CEIL = "return ceil(x);"; var ceil = unaryKernelFunc({ opSnippet: CEIL, packedOpSnippet: CEIL, cpuKernelImpl: ceilImplCPU }); var ceilConfig = { kernelName: tf.Ceil, backendName: 'webgl', kernelFunc: ceil }; /** * @license * Copyright 2017 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ var ClipProgram = /** @class */ (function () { function ClipProgram(aShape) { this.variableNames = ['A']; this.customUniforms = [ { name: 'minVal', type: 'float' }, { name: 'maxVal', type: 'float' } ]; this.outputShape = aShape; this.userCode = "\n\n void main() {\n float value = getAAtOutCoords();\n if (isnan(value)) {\n setOutput(value);\n return;\n }\n\n setOutput(clamp(value, minVal, maxVal));\n }\n "; } return ClipProgram; }()); /** * @license * Copyright 2018 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ var ClipPackedProgram = /** @class */ (function () { function ClipPackedProgram(aShape) { this.variableNames = ['A']; this.packedInputs = true; this.packedOutput = true; this.customUniforms = [ { name: 'minVal', type: 'float' }, { name: 'maxVal', type: 'float' } ]; this.outputShape = aShape; this.userCode = "\n void main() {\n vec4 value = getAAtOutCoords();\n\n if (any(isnan(value))) {\n setOutput(value);\n return;\n }\n\n setOutput(clamp(value, vec4(minVal), vec4(maxVal)));\n }\n "; } return ClipPackedProgram; }()); /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function clipByValue(args) { var inputs = args.inputs, backend = args.backend, attrs = args.attrs; var x = inputs.x; var clipValueMin = attrs.clipValueMin, clipValueMax = attrs.clipValueMax; var program; if (tf.env().getBool('WEBGL_PACK_CLIP')) { program = new ClipPackedProgram(x.shape); } else { program = new ClipProgram(x.shape); } var customValues = [[clipValueMin], [clipValueMax]]; return backend.runWebGLProgram(program, [x], x.dtype, customValues); } var clipByValueConfig = { kernelName: tf.ClipByValue, backendName: 'webgl', kernelFunc: clipByValue }; /** * @license * Copyright 2018 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* =============================================================================
 */
/**
 * WebGL program computing the element-wise magnitude |re + i*im| from separate
 * `real` and `imag` inputs. With mx = max(|re|, |im|), the result is
 * mx * length(vec2(1, min(|re|, |im|) / mx)), which keeps the intermediate
 * values from underflowing. It is 0 when mx is 0.
 *
 * @param shape output shape.
 */
var ComplexAbsProgram = function ComplexAbsProgram(shape) {
    this.variableNames = ['real', 'imag'];
    this.outputShape = shape;
    this.userCode =
        "\n void main() {\n" +
        " float re = abs(getRealAtOutCoords());\n" +
        " float im = abs(getImagAtOutCoords());\n" +
        " float mx = max(re, im);\n\n" +
        " // sadly the length function in glsl is not underflow-safe\n" +
        " // (at least not on Intel GPUs). So the safe solution is\n" +
        " // to ensure underflow-safety in all cases.\n" +
        " setOutput(\n" +
        " mx == 0.0 ? 0.0 : mx * length(vec2(1, min(re, im)/mx))\n" +
        " );\n" +
        " }\n ";
};
/**
 * @license
 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
// Returns a TensorInfo with the complex shape and the dataId of the
// underlying part. We need to do this because a reshaped complex tensor is
// not reflected in its parts.
function makeComplexComponentTensorInfo(complexTensor, complexPart) { return { dataId: complexPart.dataId, dtype: complexPart.dtype, shape: complexTensor.shape }; } function complexAbs(args) { var inputs = args.inputs, backend = args.backend; var x = inputs.x; var xData = backend.texData.get(x.dataId); var program = new ComplexAbsProgram(x.shape); var programInputs = [ makeComplexComponentTensorInfo(x, xData.complexTensorInfos.real), makeComplexComponentTensorInfo(x, xData.complexTensorInfos.imag), ]; return backend.runWebGLProgram(program, programInputs, programInputs[0].dtype); } var complexAbsConfig = { kernelName: tf.ComplexAbs, backendName: 'webgl', kernelFunc: complexAbs }; /** * @license * Copyright 2017 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ var ConcatProgram = /** @class */ (function () { // Concats 2d tensors along axis=1. See comments in MathBackendWebGL.concat(). 
function ConcatProgram(shapes) { this.outputShape = []; this.outputShape = tf.backend_util.computeOutShape(shapes, 1 /* axis */); this.variableNames = shapes.map(function (_, i) { return "T".concat(i); }); var offsets = new Array(shapes.length - 1); offsets[0] = shapes[0][1]; for (var i = 1; i < offsets.length; i++) { offsets[i] = offsets[i - 1] + shapes[i][1]; } var snippets = ["if (yC < ".concat(offsets[0], ") setOutput(getT0(yR, yC));")]; for (var i = 1; i < offsets.length; i++) { var shift = offsets[i - 1]; snippets.push("else if (yC < ".concat(offsets[i], ") ") + "setOutput(getT".concat(i, "(yR, yC-").concat(shift, "));")); } var lastIndex = offsets.length; var lastShift = offsets[offsets.length - 1]; snippets.push("else setOutput(getT".concat(lastIndex, "(yR, yC-").concat(lastShift, "));")); this.userCode = "\n void main() {\n ivec2 coords = getOutputCoords();\n int yR = coords.x;\n int yC = coords.y;\n\n ".concat(snippets.join('\n '), "\n }\n "); } return ConcatProgram; }()); /** * @license * Copyright 2019 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* =============================================================================
 */
/**
 * Packed (vec4 texel) WebGL program that concatenates the tensors described by
 * `shapes` along `axis`.
 *
 * The generated shader defines getValue(<output coords>). It tests the
 * coordinate on `axis` against each input's range in turn and reads from the
 * input T<i> that owns it. The coordinate is shifted back by that input's
 * start offset (see shiftedChannels). main() then fills one output texel from
 * the last two dimensions:
 * - r = (row, col), always set;
 * - g = (row, col + 1), set only if col + 1 is in range;
 * - a = (row + 1, col + 1), set only if row + 1 is in range;
 * - b = (row + 1, col), set only if row + 1 and col are in range.
 * Slots that are not set stay 0.
 *
 * NOTE(review): uses coords[rank - 2] and channels.slice(-2), so it assumes
 * rank >= 2; concatImpl only selects it when the first input's rank is > 1.
 *
 * @param shapes input shapes; input i is named `T<i>` in the shader.
 * @param axis concatenation axis, assumed non-negative (concat() normalizes it
 *     with parseAxisParam before calling concatImpl).
 */
var ConcatPackedProgram = /** @class */ (function () {
    function ConcatPackedProgram(shapes, axis) {
        this.packedInputs = true;
        this.packedOutput = true;
        this.outputShape = [];
        this.outputShape = tf.backend_util.computeOutShape(shapes, axis);
        var shape = this.outputShape;
        var rank = shape.length;
        var dtype = getCoordsDataType(rank);
        var coords = getChannels('coords', rank);
        var channels = ['x', 'y', 'z', 'w', 'u', 'v'].slice(0, rank);
        this.variableNames = shapes.map(function (_, i) { return "T".concat(i); });
        // offsets[i] is the running sum of the inputs' extents along `axis`,
        // i.e. the exclusive end of input i. The last input needs no entry.
        var offsets = new Array(shapes.length - 1);
        offsets[0] = shapes[0][axis];
        for (var i = 1; i < offsets.length; i++) {
            offsets[i] = offsets[i - 1] + shapes[i][axis];
        }
        var channel = channels[axis];
        var lastChannels = channels.slice(-2);
        var allChannels = channels.join();
        // First input: no shift is needed.
        var getValueSnippet = "if (".concat(channel, " < ").concat(offsets[0], ") {\n return getChannel(\n getT0(").concat(allChannels, "), vec2(").concat(lastChannels.join(), "));\n }");
        for (var i = 1; i < offsets.length; i++) {
            var shift_1 = offsets[i - 1];
            // Note: the >= comparison below may seem unnecessary given the check
            // above but is needed to workaround branch execution issues on some
            // devices. It makes all the conditions exclusive without relying on
            // execution order.
            getValueSnippet += "\n if (".concat(channel, " < ").concat(offsets[i], " && ").concat(channel, " >= ").concat(offsets[i - 1], ") {\n return getChannel(\n getT").concat(i, "(").concat(shiftedChannels(channels, channel, shift_1), "),\n vec2(").concat(shiftedChannels(lastChannels, channel, shift_1), "));\n }");
        }
        // Fallthrough: anything past the last offset belongs to the last input.
        var lastIndex = offsets.length;
        var shift = offsets[offsets.length - 1];
        getValueSnippet += "\n return getChannel(\n getT".concat(lastIndex, "(").concat(shiftedChannels(channels, channel, shift), "),\n vec2(").concat(shiftedChannels(lastChannels, channel, shift), "));");
        this.userCode = "\n float getValue(".concat(channels.map(function (x) { return 'int ' + x; }), ") {\n ").concat(getValueSnippet, "\n }\n\n void main() {\n ").concat(dtype, " coords = getOutputCoords();\n vec4 result = vec4(getValue(").concat(coords, "), 0., 0., 0.);\n\n ").concat(coords[rank - 1], " = ").concat(coords[rank - 1], " + 1;\n if (").concat(coords[rank - 1], " < ").concat(shape[rank - 1], ") {\n result.g = getValue(").concat(coords, ");\n }\n\n ").concat(coords[rank - 2], " = ").concat(coords[rank - 2], " + 1;\n if (").concat(coords[rank - 2], " < ").concat(shape[rank - 2], ") {\n result.a = getValue(").concat(coords, ");\n }\n\n ").concat(coords[rank - 1], " = ").concat(coords[rank - 1], " - 1;\n if (").concat(coords[rank - 2], " < ").concat(shape[rank - 2], " &&\n ").concat(coords[rank - 1], " < ").concat(shape[rank - 1], ") {\n result.b = getValue(").concat(coords, ");\n }\n setOutput(result);\n }\n ");
    }
    return ConcatPackedProgram;
}());
/**
 * Return an expression for coordinates into a vector where a given channel
 * will be offset by [shift].
 *
 * @param channels the channels to consider
 * @param channel the channel we want shifted
 * @param shift the amount to subtract from the channel.
 *
 * @returns a string of the form 'x, y-[shift], z' where any one channel can
 * have the shift applied.
*/ function shiftedChannels(channels, channel, shift) { var channelIdx = channels.indexOf(channel); var res = channels.map(function (c, idx) { if (idx === channelIdx) { return "".concat(c, " - ").concat(shift); } else { return c; } }); return res.join(); } /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function imag(args) { var inputs = args.inputs, backend = args.backend; var input = inputs.input; var inputData = backend.texData.get(input.dataId); return identity({ inputs: { x: inputData.complexTensorInfos.imag }, backend: backend }); } var imagConfig = { kernelName: tf.Imag, backendName: 'webgl', kernelFunc: imag }; function concatImpl(inputs, axis, backend) { var e_1, _a; var dtype = inputs[0].dtype; if (dtype === 'complex64') { var reals = inputs.map(function (t) { return real({ inputs: { input: t }, backend: backend }); }); var imags = inputs.map(function (t) { return imag({ inputs: { input: t }, backend: backend }); }); var realConcated = concatImpl(reals, axis, backend); var imagConcated = concatImpl(imags, axis, backend); var result_1 = complex({ inputs: { real: realConcated, imag: imagConcated }, backend: backend }); reals.forEach(function (r) { return backend.disposeIntermediateTensorInfo(r); }); imags.forEach(function (i) { return backend.disposeIntermediateTensorInfo(i); }); 
backend.disposeIntermediateTensorInfo(realConcated); backend.disposeIntermediateTensorInfo(imagConcated); return result_1; } var runOnCpu = backend.shouldExecuteOnCPU(inputs); // Run on cpu if dtype is string. For string, the backend represents it // as Uint8Array[], where each Uint8Array is a character. Given that the // computation is only on the outer array, uploading the whole data onto // gpu is wasteful. Also, currently webgl doesn't have a design to // upload and retrieve Uint8Array[] between cpu and gpu. Therefore, we // just run the kernel on cpu if dtype is string. if (dtype === 'string') { runOnCpu = true; } if (runOnCpu) { // Any concat of n-dimensional tensors across any axis can be reduced to // a concatenation of two-dimensional tensors across the axis 1 by first // partitioning the axes of the original tensors into those less than the // axis to be concatenated and the rest. Then reshape the tensors // into a two-dimensional tensor by collapsing these two sets of axes and // concatenate the resulting matrices across the axis 1, finally reshaping // the result to have the proper shape. var tensors2D_1 = inputs.map(function (t) { var innerSize = tf.util.sizeFromShape(t.shape.slice(axis)); var shape = [-1, innerSize]; return reshape({ inputs: { x: t }, backend: backend, attrs: { shape: shape } }); }); var inputsValShapes = tensors2D_1.map(function (t) { return { vals: backend.readSync(t.dataId), shape: t.shape }; }); // Concats 2d tensors along axis=1. 
var outShape_1 = tf.backend_util.computeOutShape(tensors2D_1.map(function (t) { return t.shape; }), 1 /* axis */); var simplyConcat = tensors2D_1[0].shape[0] === 1; var outVals = concatImplCPU(inputsValShapes, outShape_1, dtype, simplyConcat); var finalOutShape = tf.backend_util.computeOutShape(inputs.map(function (t) { return t.shape; }), axis); var outInfo = backend.makeTensorInfo(finalOutShape, dtype, outVals); tensors2D_1.forEach(function (t) { return backend.disposeIntermediateTensorInfo(t); }); return outInfo; } // Keep only non-empty tensors (ignore tensors with 0 in their shape). var $inputs = inputs.filter(function (t) { return tf.util.sizeFromShape(t.shape) > 0; }); var shouldPack = tf.env().getBool('WEBGL_PACK_ARRAY_OPERATIONS') && $inputs[0].shape.length > 1; if ($inputs.length === 1) { // Clone tensor. var program_1 = shouldPack ? new UnaryOpProgram(inputs[0].shape, CLONE) : new UnaryOpPackedProgram(inputs[0].shape, CLONE); return backend.runWebGLProgram(program_1, inputs, dtype); } var maxTexturesInShader = tf.env().getNumber('WEBGL_MAX_TEXTURES_IN_SHADER'); if ($inputs.length > maxTexturesInShader) { var reducedInputs = []; for (var i = 0; i < $inputs.length; i += maxTexturesInShader) { var subArray = $inputs.slice(i, i + maxTexturesInShader); reducedInputs.push(concatImpl(subArray, axis, backend)); } var result_2 = concatImpl(reducedInputs, axis, backend); try { for (var reducedInputs_1 = __values(reducedInputs), reducedInputs_1_1 = reducedInputs_1.next(); !reducedInputs_1_1.done; reducedInputs_1_1 = reducedInputs_1.next()) { var i = reducedInputs_1_1.value; backend.disposeIntermediateTensorInfo(i); } } catch (e_1_1) { e_1 = { error: e_1_1 }; } finally { try { if (reducedInputs_1_1 && !reducedInputs_1_1.done && (_a = reducedInputs_1.return)) _a.call(reducedInputs_1); } finally { if (e_1) throw e_1.error; } } return result_2; } if (shouldPack) { var program_2 = new ConcatPackedProgram($inputs.map(function (t) { return t.shape; }), axis); return 
backend.runWebGLProgram(program_2, $inputs, dtype); } var _b = computeTensors2D($inputs, axis, backend), tensors2D = _b.tensors2D, outShape = _b.outShape; var program = new ConcatProgram(tensors2D.map(function (t) { return t.shape; })); var result = backend.runWebGLProgram(program, tensors2D, dtype); tensors2D.forEach(function (r) { return backend.disposeIntermediateTensorInfo(r); }); var reshapedResult = reshape({ inputs: { x: result }, attrs: { shape: outShape }, backend: backend }); backend.disposeIntermediateTensorInfo(result); return reshapedResult; } function computeTensors2D(inputs, axis, backend) { // Any concat of n-dimensional tensors across any axis can be reduced to // a concatenation of two-dimensional tensors across the axis 1 by first // partitioning the axes of the original tensors into those less than the // axis to be concatenated and the rest. Then reshape the tensors // into a two-dimensional tensor by collapsing these two sets of axes and // concatenate the resulting matrices across the axis 1, finally reshaping // the result to have the proper shape. var outShape = tf.backend_util.computeOutShape(inputs.map(function (t) { return t.shape; }), axis); var tensors2D = inputs.map(function (x) { return reshape({ inputs: { x: x }, attrs: { shape: [-1, tf.util.sizeFromShape(x.shape.slice(axis))] }, backend: backend }); }); return { tensors2D: tensors2D, outShape: outShape }; } /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
* See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function concat(args) { var inputs = args.inputs, backend = args.backend, attrs = args.attrs; var axis = attrs.axis; var $axis = tf.util.parseAxisParam(axis, inputs[0].shape)[0]; var shapes = inputs.map(function (t) { return t.shape; }); tf.backend_util.assertParamsConsistent(shapes, $axis); var outShape = tf.backend_util.computeOutShape(inputs.map(function (t) { return t.shape; }), $axis); if (tf.util.sizeFromShape(outShape) === 0) { return backend.makeTensorInfo(outShape, inputs[0].dtype, []); } // Keep only non-empty tensors (ignore tensors with 0 in their shape). var $inputs = inputs.filter(function (t) { return tf.util.sizeFromShape(t.shape) > 0; }); if ($inputs.length === 1) { return identity({ inputs: { x: $inputs[0] }, backend: backend }); } return concatImpl($inputs, $axis, backend); } var concatConfig = { kernelName: tf.Concat, backendName: 'webgl', kernelFunc: concat }; /** * @license * Copyright 2017 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* =============================================================================
 */
/**
 * Unpacked WebGL program for a 2D convolution of input `x` with filter `W`,
 * optionally fused with a bias add and an activation.
 *
 * @param convInfo convolution geometry. Fields read here: outShape,
 *     padInfo.top/left, strideHeight/Width, dilationHeight/Width,
 *     filterHeight/Width, inChannels, inHeight, inWidth and dataFormat.
 * @param addBias if true, a `bias` input is added and its value at the output
 *     coordinates is added to the result. Defaults to false.
 * @param activation GLSL statements forming the body of `activation()`, or
 *     null (the default) for no activation.
 * @param hasPreluActivationWeights if true, a `preluActivationWeights` input
 *     is added. When an activation is given, that input's value is available
 *     to the activation body as `b`, and the activation argument is named `a`.
 * @param hasLeakyreluAlpha if true, a `leakyreluAlpha` input is added. It is
 *     exposed as `b` in the same way, unless PReLU weights take precedence.
 */
var Conv2DProgram = /** @class */ (function () {
    function Conv2DProgram(convInfo, addBias, activation, hasPreluActivationWeights, hasLeakyreluAlpha) {
        if (addBias === void 0) { addBias = false; }
        if (activation === void 0) { activation = null; }
        if (hasPreluActivationWeights === void 0) { hasPreluActivationWeights = false; }
        if (hasLeakyreluAlpha === void 0) { hasLeakyreluAlpha = false; }
        this.variableNames = ['x', 'W'];
        this.outputShape = convInfo.outShape;
        var padTop = convInfo.padInfo.top;
        var padLeft = convInfo.padInfo.left;
        var strideHeight = convInfo.strideHeight;
        var strideWidth = convInfo.strideWidth;
        var dilationHeight = convInfo.dilationHeight;
        var dilationWidth = convInfo.dilationWidth;
        var filterHeight = convInfo.filterHeight;
        var filterWidth = convInfo.filterWidth;
        // The shader's inner channel loop consumes 4 input channels per step
        // using vec4 dot products. The 1-3 leftover channels are handled by the
        // remainder branches that follow the loop.
        var inputDepthNearestVec4 = Math.floor(convInfo.inChannels / 4) * 4;
        var inputDepthVec4Remainder = convInfo.inChannels % 4;
        // Positions of row, column and channel within the 4D output coordinates:
        // NHWC for 'channelsLast', NCHW otherwise.
        var isChannelsLast = convInfo.dataFormat === 'channelsLast';
        var rowDim = isChannelsLast ? 1 : 2;
        var colDim = isChannelsLast ? 2 : 3;
        var channelDim = isChannelsLast ? 3 : 1;
        var activationSnippet = '', applyActivationSnippet = '';
        if (activation) {
            if (hasPreluActivationWeights) {
                activationSnippet = "float activation(float a) {\n float b = getPreluActivationWeightsAtOutCoords();\n ".concat(activation, "\n }");
            }
            else if (hasLeakyreluAlpha) {
                activationSnippet = "float activation(float a) {\n float b = getLeakyreluAlphaAtOutCoords();\n ".concat(activation, "\n }");
            }
            else {
                activationSnippet = "\n float activation(float x) {\n ".concat(activation, "\n }\n ");
            }
            applyActivationSnippet = "result = activation(result);";
        }
        var addBiasSnippet = addBias ? 'result += getBiasAtOutCoords();' : '';
        if (addBias) {
            this.variableNames.push('bias');
        }
        if (hasPreluActivationWeights) {
            this.variableNames.push('preluActivationWeights');
        }
        if (hasLeakyreluAlpha) {
            this.variableNames.push('leakyreluAlpha');
        }
        // JS booleans such as isChannelsLast and (inputDepthVec4Remainder === k)
        // are interpolated into the GLSL as the literals `true` / `false`. Each
        // of those shader branches is therefore fixed when the source is built.
        this.userCode = "\n ".concat(activationSnippet, "\n\n const ivec2 strides = ivec2(").concat(strideHeight, ", ").concat(strideWidth, ");\n const ivec2 pads = ivec2(").concat(padTop, ", ").concat(padLeft, ");\n\n void main() {\n ivec4 coords = getOutputCoords();\n int batch = coords[0];\n int d2 = coords[").concat(channelDim, "];\n\n ivec2 xRCCorner =\n ivec2(coords[").concat(rowDim, "], coords[").concat(colDim, "]) * strides - pads;\n int xRCorner = xRCCorner.x;\n int xCCorner = xRCCorner.y;\n\n // Convolve x(?, ?, d1) with w(:, :, d1, d2) to get y(yR, yC, d2).\n // ? = to be determined. : = across all values in that axis.\n float dotProd = 0.0;\n for (int wR = 0; wR < ").concat(filterHeight, "; wR++) {\n int xR = xRCorner + wR * ").concat(dilationHeight, ";\n\n if (xR < 0 || xR >= ").concat(convInfo.inHeight, ") {\n continue;\n }\n\n for (int wC = 0; wC < ").concat(filterWidth, "; wC++) {\n int xC = xCCorner + wC * ").concat(dilationWidth, ";\n\n if (xC < 0 || xC >= ").concat(convInfo.inWidth, ") {\n continue;\n }\n\n for (int d1 = 0; d1 < ").concat(inputDepthNearestVec4, "; d1 += 4) {\n vec4 wValues = vec4(\n getW(wR, wC, d1, d2),\n getW(wR, wC, d1 + 1, d2),\n getW(wR, wC, d1 + 2, d2),\n getW(wR, wC, d1 + 3, d2)\n );\n\n if (").concat(isChannelsLast, ") {\n vec4 xValues = vec4(\n getX(batch, xR, xC, d1),\n getX(batch, xR, xC, d1 + 1),\n getX(batch, xR, xC, d1 + 2),\n getX(batch, xR, xC, d1 + 3)\n );\n dotProd += dot(xValues, wValues);\n } else {\n vec4 xValues = vec4(\n getX(batch, d1, xR, xC),\n getX(batch, d1 + 1, xR, xC),\n getX(batch, d1 + 2, xR, xC),\n getX(batch, d1 + 3, xR, xC)\n );\n dotProd += dot(xValues, wValues);\n }\n }\n\n if (").concat(inputDepthVec4Remainder === 1, ") {\n\n if (").concat(isChannelsLast, ") {\n dotProd +=\n getX(batch, xR, xC, ").concat(inputDepthNearestVec4, ") *\n getW(wR, wC, ").concat(inputDepthNearestVec4, ", d2);\n } else {\n dotProd +=\n getX(batch, ").concat(inputDepthNearestVec4, ", xR, xC) *\n getW(wR, wC, ").concat(inputDepthNearestVec4, ", d2);\n }\n\n } else if (").concat(inputDepthVec4Remainder === 2, ") {\n vec2 wValues = vec2(\n getW(wR, wC, ").concat(inputDepthNearestVec4, ", d2),\n getW(wR, wC, ").concat(inputDepthNearestVec4, " + 1, d2)\n );\n\n if (").concat(isChannelsLast, ") {\n vec2 xValues = vec2(\n getX(batch, xR, xC, ").concat(inputDepthNearestVec4, "),\n getX(batch, xR, xC, ").concat(inputDepthNearestVec4, " + 1)\n );\n dotProd += dot(xValues, wValues);\n } else {\n vec2 xValues = vec2(\n getX(batch, ").concat(inputDepthNearestVec4, ", xR, xC),\n getX(batch, ").concat(inputDepthNearestVec4, " + 1, xR, xC)\n );\n dotProd += dot(xValues, wValues);\n }\n\n } else if (").concat(inputDepthVec4Remainder === 3, ") {\n vec3 wValues = vec3(\n getW(wR, wC, ").concat(inputDepthNearestVec4, ", d2),\n getW(wR, wC, ").concat(inputDepthNearestVec4, " + 1, d2),\n getW(wR, wC, ").concat(inputDepthNearestVec4, " + 2, d2)\n );\n\n if (").concat(isChannelsLast, ") {\n vec3 xValues = vec3(\n getX(batch, xR, xC, ").concat(inputDepthNearestVec4, "),\n getX(batch, xR, xC, ").concat(inputDepthNearestVec4, " + 1),\n getX(batch, xR, xC, ").concat(inputDepthNearestVec4, " + 2)\n );\n dotProd += dot(xValues, wValues);\n } else {\n vec3 xValues = vec3(\n getX(batch, ").concat(inputDepthNearestVec4, ", xR, xC),\n getX(batch, ").concat(inputDepthNearestVec4, " + 1, xR, xC),\n getX(batch, ").concat(inputDepthNearestVec4, " + 2, xR, xC)\n );\n dotProd += dot(xValues, wValues);\n }\n\n }\n }\n }\n\n float result = dotProd;\n ").concat(addBiasSnippet, "\n ").concat(applyActivationSnippet, "\n setOutput(result);\n }\n ");
    }
    return Conv2DProgram;
}());
/**
 * Unpacked WebGL program for a 3D convolution of input `x` with filter `W`.
 * There is no fused bias or activation.
 *
 * @param convInfo convolution geometry. Fields read here: outShape,
 *     padInfo.front/top/left, strideDepth/Height/Width,
 *     dilationDepth/Height/Width, filterDepth/Height/Width, inChannels,
 *     inDepth, inHeight and inWidth.
 */
var Conv3DProgram = /** @class */ (function () {
    function Conv3DProgram(convInfo) {
        this.variableNames = ['x', 'W'];
        this.outputShape = convInfo.outShape;
        var padFront = convInfo.padInfo.front;
        var padTop = convInfo.padInfo.top;
        var padLeft = convInfo.padInfo.left;
        var strideDepth = convInfo.strideDepth;
        var strideHeight = convInfo.strideHeight;
        var strideWidth = convInfo.strideWidth;
        var dilationDepth = convInfo.dilationDepth;
        var dilationHeight = convInfo.dilationHeight;
        var dilationWidth = convInfo.dilationWidth;
        var filterDepth = convInfo.filterDepth;
        var filterHeight = convInfo.filterHeight;
        var filterWidth = convInfo.filterWidth;
        // Input channels are consumed 4 at a time with vec4 dot products. The
        // 1-3 leftover channels are handled by the remainder branches.
        var inputDepthNearestVec4 = Math.floor(convInfo.inChannels / 4) * 4;
        var inputDepthVec4Remainder = convInfo.inChannels % 4;
        // The shader reads output coordinates as an ivec5: batch in .x, the
        // output depth/row/column in .y/.z/.w, and the output channel in .u.
        this.userCode = "\n const ivec3 strides = ivec3(".concat(strideDepth, ", ").concat(strideHeight, ", ").concat(strideWidth, ");\n const ivec3 pads = ivec3(").concat(padFront, ", ").concat(padTop, ", ").concat(padLeft, ");\n\n void main() {\n ivec5 coords = getOutputCoords();\n int batch = coords.x;\n int d2 = coords.u;\n\n ivec3 xFRCCorner = ivec3(coords.y, coords.z, coords.w) * strides - pads;\n int xFCorner = xFRCCorner.x;\n int xRCorner = xFRCCorner.y;\n int xCCorner = xFRCCorner.z;\n\n // Convolve x(?, ?, ?, d1) with w(:, :, :, d1, d2) to get\n // y(yF, yR, yC, d2). ? = to be determined. : = across all\n // values in that axis.\n float dotProd = 0.0;\n for (int wF = 0; wF < ").concat(filterDepth, "; wF++) {\n int xF = xFCorner + wF * ").concat(dilationDepth, ";\n\n if (xF < 0 || xF >= ").concat(convInfo.inDepth, ") {\n continue;\n }\n\n for (int wR = 0; wR < ").concat(filterHeight, "; wR++) {\n int xR = xRCorner + wR * ").concat(dilationHeight, ";\n\n if (xR < 0 || xR >= ").concat(convInfo.inHeight, ") {\n continue;\n }\n\n for (int wC = 0; wC < ").concat(filterWidth, "; wC++) {\n int xC = xCCorner + wC * ").concat(dilationWidth, ";\n\n if (xC < 0 || xC >= ").concat(convInfo.inWidth, ") {\n continue;\n }\n\n for (int d1 = 0; d1 < ").concat(inputDepthNearestVec4, "; d1 += 4) {\n vec4 xValues = vec4(\n getX(batch, xF, xR, xC, d1),\n getX(batch, xF, xR, xC, d1 + 1),\n getX(batch, xF, xR, xC, d1 + 2),\n getX(batch, xF, xR, xC, d1 + 3)\n );\n vec4 wValues = vec4(\n getW(wF, wR, wC, d1, d2),\n getW(wF, wR, wC, d1 + 1, d2),\n getW(wF, wR, wC, d1 + 2, d2),\n getW(wF, wR, wC, d1 + 3, d2)\n );\n\n dotProd += dot(xValues, wValues);\n }\n\n if (").concat(inputDepthVec4Remainder === 1, ") {\n dotProd +=\n getX(batch, xF, xR, xC, ").concat(inputDepthNearestVec4, ") *\n getW(wF, wR, wC, ").concat(inputDepthNearestVec4, ", d2);\n } else if (").concat(inputDepthVec4Remainder === 2, ") {\n vec2 xValues = vec2(\n getX(batch, xF, xR, xC, ").concat(inputDepthNearestVec4, "),\n getX(batch, xF, xR, xC, ").concat(inputDepthNearestVec4, " + 1)\n );\n vec2 wValues = vec2(\n getW(wF, wR, wC, ").concat(inputDepthNearestVec4, ", d2),\n getW(wF, wR, wC, ").concat(inputDepthNearestVec4, " + 1, d2)\n );\n dotProd += dot(xValues, wValues);\n } else if (").concat(inputDepthVec4Remainder === 3, ") {\n vec3 xValues = vec3(\n getX(batch, xF, xR, xC, ").concat(inputDepthNearestVec4, "),\n getX(batch, xF, xR, xC, ").concat(inputDepthNearestVec4, " + 1),\n getX(batch, xF, xR, xC, ").concat(inputDepthNearestVec4, " + 2)\n );\n vec3 wValues = vec3(\n getW(wF, wR, wC, ").concat(inputDepthNearestVec4, ", d2),\n getW(wF, wR, wC, ").concat(inputDepthNearestVec4, " + 1, d2),\n getW(wF, wR, wC, ").concat(inputDepthNearestVec4, " + 2, d2)\n );\n dotProd += dot(xValues, wValues);\n }\n }\n }\n }\n setOutput(dotProd);\n }\n ");
    }
    return Conv3DProgram;
}());
/**
 * @license
 * Copyright 2022 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
var Conv2DPackedProgram = /** @class */ (function () {
    function Conv2DPackedProgram(convInfo, addBias, activation, hasPreluActivation, hasLeakyReluAlpha) {
        if (addBias === void 0) { addBias = false; }
        if (activation === void 0) { activation = null; }
        if (hasPreluActivation === void 0) { hasPreluActivation = false; }
        if (hasLeakyReluAlpha === void 0) { hasLeakyReluAlpha = false; }
        this.variableNames = ['x', 'W'];
        this.packedInputs = true;
        this.packedOutput = true;
        this.customUniforms = [
            { name: 'pads', type: 'ivec2' },
            { name: 'strides', type: 'ivec2' },
            { name: 'dilations', type: 'ivec2' },
            { name: 'inDims', type: 'ivec2' },
        ];
        this.outputShape = convInfo.outShape;
        this.enableShapeUniforms = useShapeUniforms(this.outputShape.length);
        var padLeft = convInfo.padInfo.left;
        var strideWidth = convInfo.strideWidth;
        var dilationWidth = convInfo.dilationWidth;
        var filterHeight = convInfo.filterHeight;
        var filterWidth = convInfo.filterWidth;
        var texelsAcross = filterWidth;
        var mainLoop = "\n int xR; int xC; int
xCOffset;\n vec4 wTexel; vec4 previous; vec4 final;"; for (var c = 0; c < filterWidth; c++) { mainLoop += "\n vec4 xTexelC".concat(c * 2, ";\n int xTexelC").concat(c * 2, "Ready;\n vec4 xTexelC").concat(c * 2 + 1, ";\n int xTexelC").concat(c * 2 + 1, "Ready;\n vec4 xC").concat(c, ";"); } /** * This vectorized implementation works by gathering the values needed for * each output channel's dot product into vec4's and then multiplying them * all together (this happens in the final double for-loop below). Most of * the main loop consists of constructing these vec4's with the minimum * number of texture2D calls, which means making use of all four returned * values from a texture2D call at once. */ mainLoop += "\n for (int r = 0; r < ".concat(filterHeight, "; r++) {\n for (int d1 = 0; d1 < ").concat(convInfo.inChannels, "; d1 += 2) {\n "); for (var c = 0; c < filterWidth; c++) { mainLoop += "\n xTexelC".concat(c * 2, " = vec4(0.0);\n xTexelC").concat(c * 2, "Ready = 0;\n xTexelC").concat(c * 2 + 1, " = vec4(0.0);\n xTexelC").concat(c * 2 + 1, "Ready = 0;\n xC").concat(c, " = vec4(0.0);"); } mainLoop += "\n xR = xRCorner + r * dilations[0];\n if (xR >=0 && xR < inDims[0]) {\n "; for (var texelC = 0; texelC < (texelsAcross + 1) / 2; texelC++) { var colIndex = texelC * 2; mainLoop += "\n xC = xCCorner + ".concat(colIndex * dilationWidth, ";\n "); if (strideWidth === 1) { if (colIndex < filterWidth) { // If padding is odd, the outer texels have to be composed. if (padLeft % 2 === 1) { // TODO: Ensure vec4 previous does not result in redundant sample, // and avoid setting xTexelRC's that exceed the boundary in the // first place rather than resetting them to vec4(0)). // To compute xCOffset: // - If padding is odd, we must add 1 to ensure we ask for an // even-numbered row. // - We subtract 2 to access the previous texel. 
mainLoop += "\n xCOffset = xC + 1;\n if (xCOffset >= 0 && xCOffset < inDims[1] && xTexelC".concat(colIndex, "Ready == 0) {\n xTexelC").concat(colIndex, " = getX(batch, xR, xCOffset, d1);\n\n // Need to manually clear unused channels in case\n // we're reading from recycled texture.\n if (xCOffset + 1 >= inDims[1]) {\n xTexelC").concat(colIndex, ".zw = vec2(0.0);\n }\n xTexelC").concat(colIndex, "Ready = 1;\n }\n "); // This texel has been read in previous iteration if the dilation // is 1. if (dilationWidth === 1 && colIndex > 0) { mainLoop += "\n xC".concat(colIndex, " = vec4(xTexelC").concat(colIndex - 2, ".zw, xTexelC").concat(colIndex, ".xy);\n "); } else { mainLoop += "\n xCOffset = xC + 1 - 2;\n\n if (xCOffset >= 0 && xCOffset < inDims[1]) {\n previous = getX(batch, xR, xCOffset, d1);\n\n // Need to manually clear unused channels in case\n // we're reading from recycled texture.\n if (xCOffset + 1 >= inDims[1]) {\n previous.zw = vec2(0.0);\n }\n\n xC".concat(colIndex, " = vec4(previous.zw, xTexelC").concat(colIndex, ".xy);\n } else {\n xC").concat(colIndex, " = vec4(0.0, 0.0, xTexelC").concat(colIndex, ".xy);\n }\n "); } } else { // Padding is even, so xRC corresponds to a single texel. mainLoop += "\n if (xC >= 0 && xC < inDims[1] && xTexelC".concat(colIndex, "Ready == 0) {\n xTexelC").concat(colIndex, " = getX(batch, xR, xC, d1);\n if (xC + 1 >= inDims[1]) {\n xTexelC").concat(colIndex, ".zw = vec2(0.0);\n }\n xTexelC").concat(colIndex, "Ready = 1;\n }\n\n xC").concat(colIndex, " = xTexelC").concat(colIndex, ";\n "); } if (colIndex + 1 < filterWidth) { // If dilation is even, the second entry should match the first // (either both are composed or both are single samples). But if // dilation is odd, then the second entry should be the opposite // of the first (if the first is composed, the second is a single // sample, and vice versa.) var nextTexelOffset = padLeft % 2 === 0 ? 
tf.util.nearestLargerEven(dilationWidth) : dilationWidth; if ((dilationWidth % 2 === 0 && padLeft % 2 === 1) || (dilationWidth % 2 !== 0 && padLeft % 2 !== 1)) { mainLoop += "\n xCOffset = xC + imod(pads[1], 2) + ".concat(nextTexelOffset, ";\n\n if (xCOffset >= 0 && xCOffset < inDims[1] && xTexelC").concat(colIndex + 1, "Ready == 0) {\n xTexelC").concat(colIndex + 1, " = getX(batch, xR, xCOffset, d1);\n\n // Need to manually clear unused channels in case\n // we're reading from recycled texture.\n if (xCOffset + 1 >= inDims[1]) {\n xTexelC").concat(colIndex + 1, ".zw = vec2(0.0);\n }\n xTexelC").concat(colIndex + 1, "Ready = 1;\n }\n "); // If dilation > 1 then the xRC's will not be able to share any // values, so each xRC will require two unique calls to getX. if (dilationWidth > 1) { mainLoop += "\n xCOffset -= 2;\n if (xCOffset >= 0 && xCOffset < inDims[1]) {\n previous = getX(batch, xR, xCOffset, d1);\n xC".concat(colIndex + 1, " = vec4(previous.zw, xTexelC").concat(colIndex + 1, ".xy);\n } else {\n xC").concat(colIndex + 1, " = vec4(0.0, 0.0, xTexelC").concat(colIndex + 1, ".xy);\n }\n "); } else { mainLoop += "\n xC".concat(colIndex + 1, " = vec4(xTexelC").concat(colIndex, ".zw, xTexelC").concat(colIndex + 1, ".xy);\n "); } } else { // If dilation is 1 and padding is odd, we have already read the // texel when constructing the previous x value. Here we can // simply skip the texture read. 
if (nextTexelOffset === 1) { mainLoop += "\n xC".concat(colIndex + 1, " = xTexelC").concat(colIndex, ";\n "); } else { mainLoop += "\n xCOffset = xC + ".concat(nextTexelOffset, ";\n\n if (xCOffset >= 0 && xCOffset < inDims[1] && xTexelC").concat(colIndex + 1, "Ready == 0) {\n xTexelC").concat(colIndex + 1, " = getX(batch, xR, xCOffset, d1);\n if (xCOffset + 1 >= inDims[1]) {\n xTexelC").concat(colIndex + 1, ".zw = vec2(0.0);\n }\n xTexelC").concat(colIndex + 1, "Ready = 1;\n }\n\n xC").concat(colIndex + 1, " = xTexelC").concat(colIndex + 1, ";\n "); } } } } } else { // stride === 2 if (colIndex < filterWidth) { // Depending on whether padLeft is even or odd, we want either the // xy or zw channels from X texels for xC${colIndex}. If padLeft is // even, xC${colIndex +1} is simply the zw channels of texels we've // already sampled. But if padLeft is odd, xC{$c + 1}.zw will // need to come from the xy channels of a new texel, hence the ` // vec4 // final` initialized below. if (padLeft % 2 === 1) { mainLoop += "\n xCOffset = xC + 1 - strides[1];\n if(xCOffset >= 0 && xCOffset < inDims[1] && xTexelC".concat(colIndex, "Ready == 0) {\n xTexelC").concat(colIndex, " = getX(batch, xR, xCOffset, d1);\n // Need to manually clear unused channels in case\n // we're reading from recycled texture.\n if (xCOffset + 1 >= inDims[1]) {\n xTexelC").concat(colIndex, ".zw = vec2(0.0);\n }\n xTexelC").concat(colIndex, "Ready = 1;\n }\n\n if(xC + 1 >= 0 && xC + 1 < inDims[1] && xTexelC").concat(colIndex + 1, "Ready == 0) {\n xTexelC").concat(colIndex + 1, " = getX(batch, xR, xC + 1, d1);\n // Need to manually clear unused channels in case\n // we're reading from recycled texture.\n if (xC + 2 >= inDims[1]) {\n xTexelC").concat(colIndex + 1, ".zw = vec2(0.0);\n }\n xTexelC").concat(colIndex + 1, "Ready = 1;\n }\n\n xC").concat(colIndex, " = vec4(xTexelC").concat(colIndex, ".zw, xTexelC").concat(colIndex + 1, ".zw);\n "); if (colIndex + 1 < filterWidth) { mainLoop += "\n final = 
vec4(0.0);\n xCOffset = xC + 1 + strides[1];\n if(xCOffset >= 0 && xCOffset < inDims[1]) {\n final = getX(batch, xR, xCOffset, d1);\n }\n xC".concat(colIndex + 1, " = vec4(xTexelC").concat(colIndex + 1, ".xy, final.xy);\n "); } } else { mainLoop += "\n if(xC >= 0 && xC < inDims[1] && xTexelC".concat(colIndex, "Ready == 0) {\n xTexelC").concat(colIndex, " = getX(batch, xR, xC, d1);\n if (xC + 1 >= inDims[1]) {\n xTexelC").concat(colIndex, ".zw = vec2(0.0);\n }\n xTexelC").concat(colIndex, "Ready = 1;\n }\n\n xCOffset = xC + strides[1];\n if(xCOffset >= 0 && xCOffset < inDims[1] && xTexelC").concat(colIndex + 1, "Ready == 0) {\n xTexelC").concat(colIndex + 1, " = getX(batch, xR, xCOffset, d1);\n if (xCOffset + 1 >= inDims[1]) {\n xTexelC").concat(colIndex + 1, ".zw = vec2(0.);\n }\n xTexelC").concat(colIndex + 1, "Ready = 1;\n }\n\n xC").concat(colIndex, " = vec4(\n xTexelC").concat(colIndex, ".xy, xTexelC").concat(colIndex + 1, ".xy);\n "); if (colIndex + 1 < filterWidth) { mainLoop += "\n xC".concat(colIndex + 1, " = vec4(xTexelC").concat(colIndex, ".zw, xTexelC").concat(colIndex + 1, ".zw);\n "); } } } } // localize the dotProd accumulation within the loop, the theory is for // GPU with limited cache, accumulate sum across large amount of // veriables will cause lots of cache misses. (i.e. 
5x5 filter will have // 50 variables) if (colIndex < filterWidth) { mainLoop += "\n wTexel = getW(r, ".concat(colIndex, ", d1, d2);\n dotProd += xC").concat(colIndex, ".xxzz * vec4(wTexel.xy, wTexel.xy);\n if(d1 + 1 < ").concat(convInfo.inChannels, ") {\n dotProd += xC").concat(colIndex, ".yyww * vec4(wTexel.zw, wTexel.zw);\n }\n "); if (colIndex + 1 < filterWidth) { mainLoop += "\n wTexel = getW(r, ".concat(colIndex + 1, ", d1, d2);\n dotProd += xC").concat(colIndex + 1, ".xxzz * vec4(wTexel.xy, wTexel.xy);\n if(d1 + 1 < ").concat(convInfo.inChannels, ") {\n dotProd += xC").concat(colIndex + 1, ".yyww * vec4(wTexel.zw, wTexel.zw);\n }\n "); } } } mainLoop += "\n }\n "; mainLoop += "\n }\n "; mainLoop += "\n }\n "; var activationSnippet = '', applyActivationSnippet = ''; if (activation) { if (hasPreluActivation) { activationSnippet = "vec4 activation(vec4 a) {\n vec4 b = getPreluActivationWeightsAtOutCoords();\n ".concat(activation, "\n }"); } else if (hasLeakyReluAlpha) { activationSnippet = "vec4 activation(vec4 a) {\n vec4 b = getLeakyreluAlphaAtOutCoords();\n ".concat(activation, "\n }"); } else { activationSnippet = "vec4 activation(vec4 x) {\n ".concat(activation, "\n }"); } applyActivationSnippet = "result = activation(result);"; } var addBiasSnippet = addBias ? 
'result += getBiasAtOutCoords();' : ''; if (addBias) { this.variableNames.push('bias'); } if (hasPreluActivation) { this.variableNames.push('preluActivationWeights'); } if (hasLeakyReluAlpha) { this.variableNames.push('leakyreluAlpha'); } this.userCode = "\n ".concat(activationSnippet, "\n\n void main() {\n ivec4 coords = getOutputCoords();\n int batch = coords.x;\n ivec2 xRCCorner = coords.yz * strides - pads;\n int d2 = coords.w;\n int xRCorner = xRCCorner.x;\n int xCCorner = xRCCorner.y;\n\n //intialize dotProd with a small epsilon seems to reduce GPU accuracy loss.\n vec4 dotProd = vec4(0.000000000000001);\n\n ").concat(mainLoop, "\n\n vec4 result = dotProd - vec4(0.000000000000001);\n ").concat(addBiasSnippet, "\n ").concat(applyActivationSnippet, "\n setOutput(result);\n }\n "); } return Conv2DPackedProgram; }()); /** * @license * Copyright 2019 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ var Im2ColPackedProgram = /** @class */ (function () { function Im2ColPackedProgram(outputShape, convInfo) { this.variableNames = ['A']; this.packedInputs = true; this.packedOutput = true; this.customUniforms = [ { name: 'inputShape', type: 'ivec4' }, { name: 'pad', type: 'ivec2' }, { name: 'stride', type: 'ivec2' }, { name: 'dilation', type: 'ivec2' }, { name: 'inChannels', type: 'int' }, { name: 'itemsPerBlockRow', type: 'int' }, { name: 'outWidth', type: 'int' }, ]; this.outputShape = outputShape; this.enableShapeUniforms = useShapeUniforms(this.outputShape.length); var dataFormat = convInfo.dataFormat; var glsl = getGlslDifferences(); var isChannelsLast = dataFormat === 'channelsLast'; var rowDim = isChannelsLast ? 1 : 2; var colDim = isChannelsLast ? 2 : 3; var boundsCheckingSnippet = this.enableShapeUniforms ? 'if(blockIndex < outShape[2] && pos < outShape[1]) {' : "if(blockIndex < ".concat(outputShape[2], " && pos < ").concat(outputShape[1], ") {"); var unrolled = ""; for (var row = 0; row <= 1; row++) { for (var col = 0; col <= 1; col++) { unrolled += "\n blockIndex = rc.z + ".concat(col, ";\n pos = rc.y + ").concat(row, ";\n\n ").concat(boundsCheckingSnippet, "\n offsetY = int(blockIndex / outWidth) * stride[0] - pad[0];\n d0 = offsetY + dilation[0] * (pos / itemsPerBlockRow);\n\n if(d0 < inputShape[").concat(rowDim, "] && d0 >= 0) {\n // Use custom imod instead mod. 
On Intel GPU, mod may generate\n // unexpected value.\n // https://github.com/tensorflow/tfjs/issues/5447\n offsetX = imod(blockIndex, outWidth) * stride[1] - pad[1];\n d1 = offsetX + dilation[1] * (imod(pos, itemsPerBlockRow) /\n inChannels);\n\n if(d1 < inputShape[").concat(colDim, "] && d1 >= 0) {\n\n ch = imod(pos, inChannels);\n\n if (").concat(isChannelsLast, ") {\n innerDims = vec2(d1, ch);\n result[").concat(row * 2 + col, "] = getChannel(\n getA(rc.x, d0, int(innerDims.x),\n int(innerDims.y)), innerDims);\n } else {\n innerDims = vec2(d0, d1);\n result[").concat(row * 2 + col, "] = getChannel(\n getA(rc.x, ch, int(innerDims.x),\n int(innerDims.y)), innerDims);\n }\n }\n }\n }\n "); } } this.userCode = "\n void main() {\n ivec3 rc = getOutputCoords();\n\n vec4 result = vec4(0);\n\n int blockIndex, pos, offsetY, d0, offsetX, d1, ch;\n vec2 innerDims;\n\n ".concat(unrolled, "\n\n ").concat(glsl.output, " = result;\n }\n "); } return Im2ColPackedProgram; }()); // Both conv2dByMatMul and conv2dWithIm2Row fuse height and width into one // dimension to compute batchMatMul, so bias and activation weights are also // supposed to fuse the two dimensions into one. // // This function computes the target shape for fusing height and width // dimensions. Returning null means the shape is already compatible. // // Even though the bias is not supposed to be a 3-D or a 4-D (including // batch) tensor and PReLU activiation weights is not supposed to be a 4-D // tensor, we still need to support them, because we haven't disabled // them for NHWC format. // https://github.com/tensorflow/tfjs/blob/b53bd47e880367ae57493f0ea628abaf08db2d5d/tfjs-core/src/ops/fused/conv2d.ts#L181-L196 function getShapeForBatchMatMul(shape, isChannelsLast) { var length = shape.length; if (length >= 3) { return isChannelsLast ? 
__spreadArray(__spreadArray([], __read(shape.slice(0, -3) /* batch */), false), [ shape[length - 3] * shape[length - 2] /* height * width */, shape[length - 1] /* channel */ ], false) : __spreadArray(__spreadArray([], __read(shape.slice(0, -3) /* batch */), false), [ shape[length - 3] /* channel */, shape[length - 2] * shape[length - 1] /* height * width */ ], false); } else if (!isChannelsLast && length === 1 && shape[0] > 1) { return [shape[0], 1]; } else { return null; } } // For 1x1 kernels that iterate through every point in the input, convolution // can be expressed as matrix multiplication (without need for memory // remapping). function conv2dByMatMul(_a) { var e_1, _b; var x = _a.x, filter = _a.filter, convInfo = _a.convInfo, backend = _a.backend, _c = _a.bias, bias = _c === void 0 ? null : _c, _d = _a.preluActivationWeights, preluActivationWeights = _d === void 0 ? null : _d, _e = _a.leakyreluAlpha, leakyreluAlpha = _e === void 0 ? 0 : _e, _f = _a.activation, activation = _f === void 0 ? null : _f; // Reshapes conv2D input to 2D tensors, uses matMul and then reshape the // result from 2D to 4D. 
var xShape = x.shape; var xTexData = backend.texData.get(x.dataId); var sharedMatMulDim = convInfo.inChannels; var outerShapeX = xShape[0] * xShape[1] * xShape[2]; var outerShapeFilter = convInfo.outChannels; var isChannelsLast = convInfo.dataFormat === 'channelsLast'; var transposeA = false; var transposeB = false; var out; var intermediates = []; if (preluActivationWeights != null) { var targetShape = getShapeForBatchMatMul(preluActivationWeights.shape, isChannelsLast); if (targetShape != null) { preluActivationWeights = reshape({ inputs: { x: preluActivationWeights }, backend: backend, attrs: { shape: targetShape } }); intermediates.push(preluActivationWeights); } } if (bias != null) { var targetShape = getShapeForBatchMatMul(bias.shape, isChannelsLast); if (targetShape != null) { bias = reshape({ inputs: { x: bias }, backend: backend, attrs: { shape: targetShape } }); intermediates.push(bias); } } // TODO: Once reduction ops are packed, batchMatMul will always be packed // and we can remove this condition. var batchMatMulWillBeUnpacked = (outerShapeX === 1 || outerShapeFilter === 1) && sharedMatMulDim > MATMUL_SHARED_DIM_THRESHOLD; // The algorithm in the if condition assumes (1) the output will be packed, // (2) x is packed, (3) x isChannelsLast, (4) x's packed texture is already // on GPU, (5) col is odd, (6) the width, height and inChannels are the same // for xTexData.shape and xShape. var canOptimize = !batchMatMulWillBeUnpacked && xTexData.isPacked && isChannelsLast && xTexData.texture != null && xShape[2] % 2 !== 0 && tf.util.arraysEqual(xTexData.shape.slice(-3), xShape.slice(-3)); if (canOptimize) { // We avoid expensive packed 2x2 reshape by padding col count to next, // even number. When col is odd, the result of packed batchMatMul is // the same (has the same texture layout and and values in the texture) as // it is for next even col. 
We make the odd-cols tensor to look like // even-cols tensor before the operation and, after the batchMatMul, // fix the even-cols result to have odd number of cols. var targetShape = xShape[0] * xShape[1] * (xShape[2] + 1); var xReshaped_1 = { dataId: x.dataId, shape: [1, targetShape, convInfo.inChannels], dtype: x.dtype }; // xTexData.shape gets referenced from GPGPUBinary.inShapeInfos. // Decrementing col count, after batchMatMul->...->compileProgram leads to // invalid col count within the reference in GPGPUBinary.inShapeInfos. // Alternative fix would be to provide a copy to GPGPUBinary.inShapeInfos // in compileProgram method, but that would affect compilation of all // programs - instead, provide a copy here, with even col count, before // calling batchMatMul->...->compileProgram and after that, the original // xTexData.shape is restored. var originalXTexDataShape = xTexData.shape; xTexData.shape = xTexData.shape.slice(); xTexData.shape[xTexData.shape.length - 2]++; tf.util.assert(isReshapeFree(xTexData.shape, xReshaped_1.shape), function () { return "packed reshape ".concat(xTexData.shape, " to ").concat(xReshaped_1.shape, " isn't free"); }); var filterReshaped = reshape({ inputs: { x: filter }, backend: backend, attrs: { shape: [1, convInfo.inChannels, convInfo.outChannels] } }); intermediates.push(filterReshaped); var pointwiseConv = batchMatMulImpl({ a: xReshaped_1, b: filterReshaped, backend: backend, transposeA: transposeA, transposeB: transposeB, bias: bias, activation: activation, preluActivationWeights: preluActivationWeights, leakyreluAlpha: leakyreluAlpha }); var pointwiseConvTexData = backend.texData.get(pointwiseConv.dataId); tf.util.assert(pointwiseConvTexData.isPacked, function () { return 'batchMatMul result is expected to be packed'; }); // Restore the input shape to original. xTexData.shape = originalXTexDataShape; // Set the output shape - there is no need for expensive reshape as data // layout is already correct. 
pointwiseConvTexData.shape = convInfo.outShape; out = identity({ inputs: { x: pointwiseConv }, backend: backend }); out.shape = convInfo.outShape; intermediates.push(pointwiseConv); } else { var numCols = convInfo.outHeight * convInfo.outWidth; var xReshaped = reshape({ inputs: { x: x }, backend: backend, attrs: { shape: isChannelsLast ? [convInfo.batchSize, numCols, convInfo.inChannels] : [convInfo.batchSize, convInfo.inChannels, numCols] } }); var filterReshaped = reshape({ inputs: { x: filter }, backend: backend, attrs: { shape: [1, convInfo.inChannels, convInfo.outChannels] } }); var result = batchMatMulImpl({ a: isChannelsLast ? xReshaped : filterReshaped, b: isChannelsLast ? filterReshaped : xReshaped, transposeA: !isChannelsLast, transposeB: transposeB, backend: backend, bias: bias, activation: activation, preluActivationWeights: preluActivationWeights, leakyreluAlpha: leakyreluAlpha }); out = reshape({ inputs: { x: result }, backend: backend, attrs: { shape: convInfo.outShape } }); intermediates.push(xReshaped); intermediates.push(filterReshaped); intermediates.push(result); } try { for (var intermediates_1 = __values(intermediates), intermediates_1_1 = intermediates_1.next(); !intermediates_1_1.done; intermediates_1_1 = intermediates_1.next()) { var i = intermediates_1_1.value; backend.disposeIntermediateTensorInfo(i); } } catch (e_1_1) { e_1 = { error: e_1_1 }; } finally { try { if (intermediates_1_1 && !intermediates_1_1.done && (_b = intermediates_1.return)) _b.call(intermediates_1); } finally { if (e_1) throw e_1.error; } } return out; } // Implements the im2row algorithm as outlined in "High Performance // Convolutional Neural Networks for Document Processing" (Suvisoft, 2006) function conv2dWithIm2Row(_a) { var e_2, _b; var x = _a.x, filter = _a.filter, convInfo = _a.convInfo, backend = _a.backend, _c = _a.bias, bias = _c === void 0 ? null : _c, _d = _a.preluActivationWeights, preluActivationWeights = _d === void 0 ? 
null : _d, _e = _a.leakyreluAlpha, leakyreluAlpha = _e === void 0 ? 0 : _e, _f = _a.activation, activation = _f === void 0 ? null : _f; // Rearranges conv2d input so each block to be convolved over forms the // column of a new matrix with shape [filterWidth * filterHeight * // inChannels, outHeight * outWidth]. The filter is also rearranged so each // output channel forms a row of a new matrix with shape [outChannels, // filterWidth * filterHeight * inChannels]. The convolution is then // computed by multiplying these matrices and reshaping the result. var filterWidth = convInfo.filterWidth, filterHeight = convInfo.filterHeight, inChannels = convInfo.inChannels, outWidth = convInfo.outWidth, outHeight = convInfo.outHeight, dataFormat = convInfo.dataFormat; var isChannelsLast = dataFormat === 'channelsLast'; var sharedDim = filterWidth * filterHeight * inChannels; var numCols = outHeight * outWidth; var x2ColShape = [convInfo.batchSize, sharedDim, numCols]; var transposeA = true; var transposeB = false; var intermediates = []; if (preluActivationWeights != null) { var targetShape = getShapeForBatchMatMul(preluActivationWeights.shape, isChannelsLast); if (targetShape != null) { preluActivationWeights = reshape({ inputs: { x: preluActivationWeights }, backend: backend, attrs: { shape: targetShape } }); intermediates.push(preluActivationWeights); } } if (bias != null) { var targetShape = getShapeForBatchMatMul(bias.shape, isChannelsLast); if (targetShape != null) { bias = reshape({ inputs: { x: bias }, backend: backend, attrs: { shape: targetShape } }); intermediates.push(bias); } } var w2Row = reshape({ inputs: { x: filter }, backend: backend, attrs: { shape: [1, sharedDim, tf.util.sizeFromShape(filter.shape) / sharedDim] } }); intermediates.push(w2Row); var im2ColProgram = new Im2ColPackedProgram(x2ColShape, convInfo); var customValues = [ x.shape, [convInfo.padInfo.top, convInfo.padInfo.left], [convInfo.strideHeight, convInfo.strideWidth], [convInfo.dilationHeight, 
convInfo.dilationWidth], [convInfo.inChannels], [convInfo.filterWidth * convInfo.inChannels], [convInfo.outWidth] ]; var im2Col = backend.runWebGLProgram(im2ColProgram, [x], 'float32', customValues); var im2ColReshaped = reshape({ inputs: { x: im2Col }, backend: backend, attrs: { shape: x2ColShape } }); intermediates.push(im2Col); intermediates.push(im2ColReshaped); var hasBias = bias != null; var hasPreluActivationWeights = preluActivationWeights != null; var hasLeakyreluAlpha = activation === 'leakyrelu'; var fusedActivation = activation ? mapActivationToShaderProgram(activation, true) : null; var matmulProgram = new MatMulPackedProgram(isChannelsLast ? im2ColReshaped.shape : w2Row.shape, isChannelsLast ? w2Row.shape : im2ColReshaped.shape, isChannelsLast ? [convInfo.batchSize, numCols, convInfo.outChannels] : [convInfo.batchSize, convInfo.outChannels, numCols], transposeA, transposeB, hasBias, fusedActivation, hasPreluActivationWeights, hasLeakyreluAlpha); var inputs = isChannelsLast ? 
[im2ColReshaped, w2Row] : [w2Row, im2ColReshaped]; if (bias) { inputs.push(bias); } if (hasPreluActivationWeights) { inputs.push(preluActivationWeights); } if (hasLeakyreluAlpha) { var $leakyreluAlpha = backend.makeTensorInfo([], 'float32', tf.util.createScalarValue(leakyreluAlpha, 'float32')); inputs.push($leakyreluAlpha); intermediates.push($leakyreluAlpha); } var product = backend.runWebGLProgram(matmulProgram, inputs, 'float32'); var out = reshape({ inputs: { x: product }, backend: backend, attrs: { shape: convInfo.outShape } }); intermediates.push(product); try { for (var intermediates_2 = __values(intermediates), intermediates_2_1 = intermediates_2.next(); !intermediates_2_1.done; intermediates_2_1 = intermediates_2.next()) { var i = intermediates_2_1.value; backend.disposeIntermediateTensorInfo(i); } } catch (e_2_1) { e_2 = { error: e_2_1 }; } finally { try { if (intermediates_2_1 && !intermediates_2_1.done && (_b = intermediates_2.return)) _b.call(intermediates_2); } finally { if (e_2) throw e_2.error; } } return out; } /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ function conv2d(args) { var inputs = args.inputs, backend = args.backend, attrs = args.attrs; var x = inputs.x, filter = inputs.filter; var strides = attrs.strides, pad = attrs.pad, dataFormat = attrs.dataFormat, dilations = attrs.dilations, dimRoundingMode = attrs.dimRoundingMode; var $dataFormat = tf.backend_util.convertConv2DDataFormat(dataFormat); var convInfo = tf.backend_util.computeConv2DInfo(x.shape, filter.shape, strides, dilations, pad, dimRoundingMode, false /* depthwise */, $dataFormat); var out; if (convInfo.filterHeight === 1 && convInfo.filterWidth === 1 && convInfo.dilationHeight === 1 && convInfo.dilationWidth === 1 && convInfo.strideHeight === 1 && convInfo.strideWidth === 1 && (convInfo.padInfo.type === 'SAME' || convInfo.padInfo.type === 'VALID')) { out = conv2dByMatMul({ x: x, filter: filter, convInfo: convInfo, backend: backend }); } else if (convInfo.strideWidth <= 2 && $dataFormat === 'channelsLast' && tf.env().getBool('WEBGL_EXP_CONV')) { var program = new Conv2DPackedProgram(convInfo); var customValues = [ [convInfo.padInfo.top, convInfo.padInfo.left], [convInfo.strideHeight, convInfo.strideWidth], [convInfo.dilationHeight, convInfo.dilationWidth], [convInfo.inHeight, convInfo.inWidth] ]; out = backend.runWebGLProgram(program, [x, filter], 'float32', customValues); } else if (tf.env().getBool('WEBGL_CONV_IM2COL')) { out = conv2dWithIm2Row({ x: x, filter: filter, convInfo: convInfo, backend: backend }); } else { var program = new Conv2DProgram(convInfo); out = backend.runWebGLProgram(program, [x, filter], 'float32'); } var outReshaped = reshape({ inputs: { x: out }, backend: backend, attrs: { shape: convInfo.outShape } }); backend.disposeIntermediateTensorInfo(out); return outReshaped; } var conv2DConfig = { kernelName: tf.Conv2D, backendName: 'webgl', kernelFunc: conv2d, }; /** * @license * Copyright 2017 Google LLC. All Rights Reserved. 
* Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */
/**
 * Unpacked WebGL program for the gradient of a 2-D convolution with respect
 * to its filter (dw). Inputs are 'x' and 'dy'; the output has shape
 * `convInfo.filterShape`.
 *
 * For each output element (wR, wC, d1, d2), the shader sums x * dy over every
 * batch index and every dy position (yR, yC). A position is included only if
 * its input coordinate xR = wR + yR * strideHeight - padTop (and likewise xC)
 * lies inside the input. Dilation is not read by this program.
 */
var Conv2DDerFilterProgram = /** @class */ (function () {
    function Conv2DDerFilterProgram(convInfo) {
        this.variableNames = ['x', 'dy'];
        this.outputShape = convInfo.filterShape;
        var strideHeight = convInfo.strideHeight;
        var strideWidth = convInfo.strideWidth;
        var padTop = convInfo.padInfo.top;
        var padLeft = convInfo.padInfo.left;
        var isChannelsLast = convInfo.dataFormat === 'channelsLast';
        // Loop bounds, strides and pads are spliced into the shader as literals.
        // The data-format choice is made here, so only the matching read
        // pattern is emitted.
        this.userCode = "\n void main() {\n ivec4 coords = getOutputCoords();\n int wR = coords.x;\n int wC = coords.y;\n int d1 = coords.z;\n int d2 = coords.w;\n\n // Convolve x(?, ?, d1) with dy(:, :, d2) to get dw(wR, wC, d1, d2).\n // ? = to be determined. : = across all values in that axis.\n float dotProd = 0.0;\n\n for (int b = 0; b < ".concat(convInfo.batchSize, "; b++) {\n for (int yR = 0; yR < ").concat(convInfo.outHeight, "; yR++) {\n int xR = wR + yR * ").concat(strideHeight, " - ").concat(padTop, ";\n\n if (xR < 0 || xR >= ").concat(convInfo.inHeight, ") {\n continue;\n }\n\n for (int yC = 0; yC < ").concat(convInfo.outWidth, "; yC++) {\n int xC = wC + yC * ").concat(strideWidth, " - ").concat(padLeft, ";\n\n if (xC < 0 || xC >= ").concat(convInfo.inWidth, ") {\n continue;\n }\n\n ").concat(isChannelsLast ?
            "float dyValue = getDy(b, yR, yC, d2);\n float xValue = getX(b, xR, xC, d1);\n dotProd += (xValue * dyValue);" :
            "float dyValue = getDy(b, d2, yR, yC);\n float xValue = getX(b, d1, xR, xC);\n dotProd += (xValue * dyValue);", "\n }\n }\n }\n setOutput(dotProd);\n }\n ");
    }
    return Conv2DDerFilterProgram;
}());
/**
 * Unpacked WebGL program for the gradient of a 2-D convolution with respect
 * to its input (dx). Inputs are 'dy' and 'W'; the output has shape
 * `convInfo.inShape`.
 *
 * The shader reads the filter spatially flipped (wRPerm = filterHeight - 1 -
 * wR, and likewise for columns). A dy coordinate (corner + w) / stride is used
 * only when it is a non-negative integer below the dy extent. The sum runs over
 * all `convInfo.outChannels` (d2).
 */
var Conv2DDerInputProgram = /** @class */ (function () {
    function Conv2DDerInputProgram(convInfo) {
        this.variableNames = ['dy', 'W'];
        this.outputShape = convInfo.inShape;
        var filterHeight = convInfo.filterHeight;
        var filterWidth = convInfo.filterWidth;
        var strideHeight = convInfo.strideHeight;
        var strideWidth = convInfo.strideWidth;
        var isChannelsLast = convInfo.dataFormat === 'channelsLast';
        // Padding applied on the dy side: filter size - 1 - forward pad.
        var padTop = filterHeight - 1 - convInfo.padInfo.top;
        var padLeft = filterWidth - 1 - convInfo.padInfo.left;
        // Positions of the row/column/channel axes within the 4-D output
        // coordinates for the two data formats.
        var rowDim = isChannelsLast ? 1 : 2;
        var colDim = isChannelsLast ? 2 : 3;
        var channelDim = isChannelsLast ? 3 : 1;
        // `isChannelsLast` is spliced into the GLSL as a literal true/false `if`
        // condition that selects the dy read pattern.
        this.userCode = "\n const ivec2 pads = ivec2(".concat(padTop, ", ").concat(padLeft, ");\n\n void main() {\n ivec4 coords = getOutputCoords();\n int batch = coords[0];\n int d1 = coords[").concat(channelDim, "];\n\n ivec2 dyCorner = ivec2(coords[").concat(rowDim, "], coords[").concat(colDim, "]) - pads;\n int dyRCorner = dyCorner.x;\n int dyCCorner = dyCorner.y;\n\n // Convolve dy(?, ?, d2) with w(:, :, d1, d2) to compute dx(xR, xC, d1).\n // ? = to be determined. : = across all values in that axis.\n float dotProd = 0.0;\n for (int wR = 0; wR < ").concat(filterHeight, "; wR++) {\n float dyR = float(dyRCorner + wR) / ").concat(strideHeight, ".0;\n\n if (dyR < 0.0 || dyR >= ").concat(convInfo.outHeight, ".0 || fract(dyR) > 0.0) {\n continue;\n }\n int idyR = int(dyR);\n\n int wRPerm = ").concat(filterHeight, " - 1 - wR;\n\n for (int wC = 0; wC < ").concat(filterWidth, "; wC++) {\n float dyC = float(dyCCorner + wC) / ").concat(strideWidth, ".0;\n\n if (dyC < 0.0 || dyC >= ").concat(convInfo.outWidth, ".0 ||\n fract(dyC) > 0.0) {\n continue;\n }\n int idyC = int(dyC);\n\n int wCPerm = ").concat(filterWidth, " - 1 - wC;\n\n for (int d2 = 0; d2 < ").concat(convInfo.outChannels, "; d2++) {\n\n if (").concat(isChannelsLast, ") {\n float xValue = getDy(batch, idyR, idyC, d2);\n float wValue = getW(wRPerm, wCPerm, d1, d2);\n dotProd += xValue * wValue;\n } else {\n float xValue = getDy(batch, d2, idyR, idyC);\n float wValue = getW(wRPerm, wCPerm, d1, d2);\n dotProd += xValue * wValue;\n }\n\n }\n }\n }\n setOutput(dotProd);\n }\n ");
    }
    return Conv2DDerInputProgram;
}());
/**
 * Unpacked WebGL program for the gradient of a 3-D convolution with respect
 * to its filter. Inputs are 'x' and 'dy'; the output has shape
 * `convInfo.filterShape` and is addressed with 5-D coordinates
 * (wF, wR, wC, d1, d2).
 *
 * The shader sums x * dy over batch and every dy position (yF, yR, yC) whose
 * input coordinate (w + y * stride - pad, per axis) lies inside the input. Both
 * tensors are read with the channel index last.
 */
var Conv3DDerFilterProgram = /** @class */ (function () {
    function Conv3DDerFilterProgram(convInfo) {
        this.variableNames = ['x', 'dy'];
        this.outputShape = convInfo.filterShape;
        var strideDepth = convInfo.strideDepth;
        var strideHeight = convInfo.strideHeight;
        var strideWidth = convInfo.strideWidth;
        var padFront = convInfo.padInfo.front;
        var padTop = convInfo.padInfo.top;
        var padLeft = convInfo.padInfo.left;
        this.userCode = "\n void main() {\n ivec5 coords = getOutputCoords();\n int wF = coords.x;\n int wR = coords.y;\n int wC = coords.z;\n int d1 = coords.w;\n int d2 = coords.u;\n\n float dotProd = 0.0;\n\n for (int b = 0; b < ".concat(convInfo.batchSize, "; b++) {\n for (int yF = 0; yF < ").concat(convInfo.outDepth, "; yF++) {\n int xF = wF + yF * ").concat(strideDepth, " - ").concat(padFront, ";\n\n if (xF < 0 || xF >= ").concat(convInfo.inDepth, ") {\n continue;\n }\n\n for (int yR = 0; yR < ").concat(convInfo.outHeight, "; yR++) {\n int xR = wR + yR * ").concat(strideHeight, " - ").concat(padTop, ";\n\n if (xR < 0 || xR >= ").concat(convInfo.inHeight, ") {\n continue;\n }\n\n for (int yC = 0; yC < ").concat(convInfo.outWidth, "; yC++) {\n int xC = wC + yC * ").concat(strideWidth, " - ").concat(padLeft, ";\n\n if (xC < 0 || xC >= ").concat(convInfo.inWidth, ") {\n continue;\n }\n\n float dyValue = getDy(b, yF, yR, yC, d2);\n float xValue = getX(b, xF, xR, xC, d1);\n dotProd += (xValue * dyValue);\n }\n }\n }\n }\n setOutput(dotProd);\n }\n ");
    }
    return Conv3DDerFilterProgram;
}());
var Conv3DDerInputProgram = /** @class */ (function () { function Conv3DDerInputProgram(convInfo) { this.variableNames = ['dy', 'W']; this.outputShape = convInfo.inShape; var filterDepth = convInfo.filterDepth; var filterHeight = convInfo.filterHeight; var filterWidth = convInfo.filterWidth; var strideDepth = convInfo.strideDepth; var strideHeight = convInfo.strideHeight; var strideWidth = convInfo.strideWidth; var padFront = filterDepth - 1 - convInfo.padInfo.front; var padTop = filterHeight - 1 - convInfo.padInfo.top; var padLeft = filterWidth - 1 - convInfo.padInfo.left; this.userCode = "\n const ivec3 pads = ivec3(".concat(padFront, ", ").concat(padTop, ", ").concat(padLeft, ");\n\n void main() {\n ivec5 coords = getOutputCoords();\n int batch = coords.x;\n int d1 = coords.u;\n\n\n ivec3 dyCorner = ivec3(coords.y, coords.z, coords.w) - pads;\n int dyFCorner = dyCorner.x;\n int dyRCorner = dyCorner.y;\n int dyCCorner = dyCorner.z;\n\n float dotProd = 0.0;\n for (int wF = 0; wF < ").concat(filterDepth, "; wF++) {\n float dyF = float(dyFCorner + wF) / ").concat(strideDepth, ".0;\n\n if (dyF < 0.0 || dyF >= ").concat(convInfo.outDepth, ".0 || fract(dyF) > 0.0) {\n continue;\n }\n int idyF = int(dyF);\n\n int wFPerm = ").concat(filterDepth, " - 1 - wF;\n\n for (int wR = 0; wR < ").concat(filterHeight, "; wR++) {\n float dyR = float(dyRCorner + wR) 
/ ").concat(strideHeight, ".0;\n\n if (dyR < 0.0 || dyR >= ").concat(convInfo.outHeight, ".0 ||\n fract(dyR) > 0.0) {\n continue;\n }\n int idyR = int(dyR);\n\n int wRPerm = ").concat(filterHeight, " - 1 - wR;\n\n for (int wC = 0; wC < ").concat(filterWidth, "; wC++) {\n float dyC = float(dyCCorner + wC) / ").concat(strideWidth, ".0;\n\n if (dyC < 0.0 || dyC >= ").concat(convInfo.outWidth, ".0 ||\n fract(dyC) > 0.0) {\n continue;\n }\n int idyC = int(dyC);\n\n int wCPerm = ").concat(filterWidth, " - 1 - wC;\n\n for (int d2 = 0; d2 < ").concat(convInfo.outChannels, "; d2++) {\n float xValue = getDy(batch, idyF, idyR, idyC, d2);\n float wValue = getW(wFPerm, wRPerm, wCPerm, d1, d2);\n dotProd += xValue * wValue;\n }\n }\n }\n }\n setOutput(dotProd);\n }\n "); } return Conv3DDerInputProgram; }()); /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ function conv2DBackpropFilter(args) { var inputs = args.inputs, backend = args.backend, attrs = args.attrs; var x = inputs.x, dy = inputs.dy; var strides = attrs.strides, pad = attrs.pad, dataFormat = attrs.dataFormat, dimRoundingMode = attrs.dimRoundingMode, filterShape = attrs.filterShape; var $dataFormat = tf.backend_util.convertConv2DDataFormat(dataFormat); var convInfo = tf.backend_util.computeConv2DInfo(x.shape, filterShape, strides, 1 /* dilations */, pad, dimRoundingMode, false /* depthwise */, $dataFormat); var program = new Conv2DDerFilterProgram(convInfo); return backend.runWebGLProgram(program, [x, dy], 'float32'); } var conv2DBackpropFilterConfig = { kernelName: tf.Conv2DBackpropFilter, backendName: 'webgl', kernelFunc: conv2DBackpropFilter, }; /** * @license * Copyright 2023 Google LLC. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */
/**
 * Packed WebGL program computing the gradient of a 2D convolution with respect
 * to its input. It reads `dy` and the filter `W` as packed textures and writes
 * a packed output of shape `convInfo.inShape`.
 *
 * The strides are not inlined into the shader. They are read from the
 * `strides` vec2 uniform declared in `customUniforms`, whose value the caller
 * supplies when running the program. Filter size, padding and the dy extents
 * are inlined into the GLSL source.
 *
 * The GLSL reads output coords as [batch, row, col, channel], i.e. it assumes
 * channels-last data.
 *
 * @param {Object} convInfo Conv2D geometry, presumably from
 *     `backend_util.computeConv2DInfo` (TODO confirm for all callers).
 */
var Conv2DDerInputPackedProgram = /** @class */ (function () {
  function Conv2DDerInputPackedProgram(convInfo) {
    this.variableNames = ['dy', 'W'];
    this.packedInputs = true;
    this.packedOutput = true;
    this.customUniforms = [
      { name: 'strides', type: 'vec2' },
    ];
    this.outputShape = convInfo.inShape;
    this.enableShapeUniforms = useShapeUniforms(this.outputShape.length);
    var filterHeight = convInfo.filterHeight;
    var filterWidth = convInfo.filterWidth;
    // Padding for the transposed ("full") correlation: a forward pad p on a
    // filter of size k becomes k - 1 - p on the dy side.
    var padTop = filterHeight - 1 - convInfo.padInfo.top;
    var padLeft = filterWidth - 1 - convInfo.padInfo.left;
    // Shader outline:
    // - Each output texel covers two adjacent output columns: result.xy
    //   accumulates column coords[2] (via idyC), and result.zw accumulates
    //   column coords[2] + 1 (via idyC2).
    // - A dy position contributes only when (corner + w) / stride is a
    //   non-negative integer inside the dy extent.
    // - The filter is read spatially flipped (wRPerm / wCPerm).
    // NOTE(review): choosing .xy vs .zw by the parity of idyC/idyC2, and
    // reusing dySample when idyC / 2 == idyC2 / 2, assumes the tfjs 2x2
    // packed-texel layout over the two innermost dims — confirm.
    this.userCode = "\n const ivec2 pads = ivec2(".concat(padTop, ", ").concat(padLeft, ");\n\n void main() {\n ivec4 coords = getOutputCoords();\n int batch = coords[0];\n int d1 = coords[3];\n\n ivec2 dyCorner = ivec2(coords[1], coords[2]) - pads;\n int dyRCorner = dyCorner.x;\n int dyCCorner = dyCorner.y;\n\n vec4 result = vec4(0.);\n for (int wR = 0; wR < ")
        .concat(filterHeight, "; wR++) {\n float dyR = float(dyRCorner + wR) / strides[0];\n if (dyR < 0.0 || dyR >= ")
        .concat(convInfo.outHeight, ".0 || fract(dyR) > 0.0) {\n continue;\n }\n int idyR = int(dyR);\n int wRPerm = ")
        .concat(filterHeight, " - 1 - wR;\n\n for (int wC = 0; wC < ")
        .concat(filterWidth, "; wC++) {\n int wCPerm = ")
        .concat(filterWidth, " - 1 - wC;\n\n float dyC = float(dyCCorner + wC) / strides[1];\n bool idyCVal = (dyC >= 0.0) && (dyC < ")
        .concat(convInfo.outWidth, ".0)\n && (fract(dyC) == 0.0);\n int idyC = int(dyC);\n\n float dyC2 = float(dyCCorner + wC + 1) / strides[1];\n bool idyCVal2 = (dyC2 >= 0.0) && (dyC2 < ")
        .concat(convInfo.outWidth, ".0)\n && (fract(dyC2) == 0.0);\n int idyC2 = int(dyC2);\n\n if (idyCVal && idyCVal2) {\n for (int d2 = 0; d2 < ")
        .concat(convInfo.outChannels, "; d2 += 2) {\n vec4 wValue = getW(wRPerm, wCPerm, d1, d2);\n vec4 dySample = getDy(batch, idyR, idyC, d2);\n vec4 dySample2 = (idyC / 2 == idyC2 / 2) ?\n dySample : getDy(batch, idyR, idyC2, d2);\n\n vec2 dyValue = mod(float(idyC), 2.) == 0. ?\n dySample.xy : dySample.zw;\n result.xy += vec2(dot(dyValue, wValue.xy),\n dot(dyValue, wValue.zw));\n\n dyValue = mod(float(idyC2), 2.) == 0. ?\n dySample2.xy : dySample2.zw;\n result.zw += vec2(dot(dyValue, wValue.xy),\n dot(dyValue, wValue.zw));\n }\n } else if (idyCVal) {\n for (int d2 = 0; d2 < ")
        .concat(convInfo.outChannels, "; d2 += 2) {\n vec4 wValue = getW(wRPerm, wCPerm, d1, d2);\n vec4 dySample = getDy(batch, idyR, idyC, d2);\n vec2 dyValue = mod(float(idyC), 2.) == 0. ?\n dySample.xy : dySample.zw;\n result.xy += vec2(dot(dyValue, wValue.xy),\n dot(dyValue, wValue.zw));\n }\n } else if (idyCVal2) {\n for (int d2 = 0; d2 < ")
        .concat(convInfo.outChannels, "; d2 += 2) {\n vec4 wValue = getW(wRPerm, wCPerm, d1, d2);\n vec4 dySample = getDy(batch, idyR, idyC2, d2);\n vec2 dyValue = mod(float(idyC2), 2.) == 0. ?\n dySample.xy : dySample.zw;\n result.zw += vec2(dot(dyValue, wValue.xy),\n dot(dyValue, wValue.zw));\n }\n }\n }\n }\n setOutput(result);\n }\n ");
  }
  return Conv2DDerInputPackedProgram;
}());
/**
 * @license
 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
* ============================================================================= */ function conv2DBackpropInput(args) { var inputs = args.inputs, backend = args.backend, attrs = args.attrs; var dy = inputs.dy, filter = inputs.filter; var inputShape = attrs.inputShape, strides = attrs.strides, pad = attrs.pad, dataFormat = attrs.dataFormat, dimRoundingMode = attrs.dimRoundingMode; var $dataFormat = tf.backend_util.convertConv2DDataFormat(dataFormat); var convInfo = tf.backend_util.computeConv2DInfo(inputShape, filter.shape, strides, 1 /* dilations */, pad, dimRoundingMode, false, $dataFormat); if (tf.env().getBool('WEBGL_PACK_CONV2DTRANSPOSE') && $dataFormat === 'channelsLast') { var customValues = [ [convInfo.strideHeight, convInfo.strideWidth], ]; var program = new Conv2DDerInputPackedProgram(convInfo); return backend.runWebGLProgram(program, [dy, filter], 'float32', customValues); } else { var program = new Conv2DDerInputProgram(convInfo); return backend.runWebGLProgram(program, [dy, filter], 'float32'); } } var conv2DBackpropInputConfig = { kernelName: tf.Conv2DBackpropInput, backendName: 'webgl', kernelFunc: conv2DBackpropInput, }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ function conv3D(args) { var inputs = args.inputs, backend = args.backend, attrs = args.attrs; var x = inputs.x, filter = inputs.filter; var strides = attrs.strides, pad = attrs.pad, dilations = attrs.dilations; var convInfo = tf.backend_util.computeConv3DInfo(x.shape, filter.shape, strides, dilations, pad); var program = new Conv3DProgram(convInfo); return backend.runWebGLProgram(program, [x, filter], 'float32'); } var conv3DConfig = { kernelName: tf.Conv3D, backendName: 'webgl', kernelFunc: conv3D, }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function conv3DBackpropFilterV2(args) { var inputs = args.inputs, backend = args.backend, attrs = args.attrs; var x = inputs.x, dy = inputs.dy; var strides = attrs.strides, pad = attrs.pad, filterShape = attrs.filterShape; var convInfo = tf.backend_util.computeConv3DInfo(x.shape, filterShape, strides, 1 /* dilations */, pad); var program = new Conv3DDerFilterProgram(convInfo); return backend.runWebGLProgram(program, [x, dy], 'float32'); } var conv3DBackpropFilterV2Config = { kernelName: tf.Conv3DBackpropFilterV2, backendName: 'webgl', kernelFunc: conv3DBackpropFilterV2 }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. 
* Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function conv3DBackpropInput(args) { var inputs = args.inputs, backend = args.backend, attrs = args.attrs; var dy = inputs.dy, filter = inputs.filter; var pad = attrs.pad, strides = attrs.strides, inputShape = attrs.inputShape; var convInfo = tf.backend_util.computeConv3DInfo(inputShape, filter.shape, strides, 1 /* dilations */, pad); var program = new Conv3DDerInputProgram(convInfo); return backend.runWebGLProgram(program, [dy, filter], 'float32'); } var conv3DBackpropInputConfig = { kernelName: tf.Conv3DBackpropInputV2, backendName: 'webgl', kernelFunc: conv3DBackpropInput, }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ var COS = CHECK_NAN_SNIPPET_UNARY + "\n return cos(x);\n"; var COS_PACKED = "\n vec4 result = cos(x);\n bvec4 isNaN = isnan(x);\n ".concat(CHECK_NAN_SNIPPET_PACKED, "\n return result;\n"); var cos = unaryKernelFunc({ opSnippet: COS, packedOpSnippet: COS_PACKED }); var cosConfig = { kernelName: tf.Cos, backendName: 'webgl', kernelFunc: cos, }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ var COSH = "\n float e2x = exp(-x);\n return (e2x + 1.0 / e2x) / 2.0;\n"; var cosh = unaryKernelFunc({ opSnippet: COSH }); var coshConfig = { kernelName: tf.Cosh, backendName: 'webgl', kernelFunc: cosh, }; /** * @license * Copyright 2017 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */
/**
 * WebGL program for `CropAndResize`. For each box it samples a
 * cropHeight x cropWidth patch from the image selected by that box's batch
 * index. Sampling is bilinear when `method === 'bilinear'` and
 * nearest-neighbor otherwise.
 *
 * Output shape is [numBoxes, cropHeight, cropWidth, depth]. The sample
 * position for output row y is y1 * (H - 1) + y * (y2 - y1) * (H - 1) /
 * (cropHeight - 1); the column formula is analogous. This implies box
 * coordinates are fractions of (H - 1) / (W - 1). When the crop size along a
 * dimension is 1, the box centre 0.5 * (y1 + y2) * (H - 1) is used instead.
 *
 * - Samples falling outside [0, H - 1] x [0, W - 1] are written as
 *   `extrapolationValue`.
 * - If a box's batch index lies outside [0, batch), the shader returns
 *   without calling setOutput. NOTE(review): the resulting value for those
 *   pixels is not set here — confirm the expected default.
 *
 * @param {number[]} imageShape [batch, imageHeight, imageWidth, depth].
 * @param {number[]} boxShape Shape of the boxes input; only [0] (numBoxes) is used.
 * @param {number[]} cropSize [cropHeight, cropWidth].
 * @param {string} method 'bilinear' selects bilinear sampling; anything else
 *     selects nearest-neighbor.
 * @param {number} extrapolationValue Value emitted for out-of-image samples.
 */
var CropAndResizeProgram = /** @class */ (function () {
  function CropAndResizeProgram(imageShape, boxShape, cropSize, method, extrapolationValue) {
    this.variableNames = ['Image', 'Boxes', 'BoxInd'];
    this.outputShape = [];
    // `__read(arr, n)` is a destructuring helper defined elsewhere in this
    // bundle; presumably it yields the first n elements of `arr`.
    var _a = __read(imageShape, 4), batch = _a[0], imageHeight = _a[1], imageWidth = _a[2], depth = _a[3];
    var _b = __read(boxShape, 1), numBoxes = _b[0];
    var _c = __read(cropSize, 2), cropHeight = _c[0], cropWidth = _c[1];
    this.outputShape = [numBoxes, cropHeight, cropWidth, depth];
    // Inlined into the shader as `<methodId> == 1`: 1 = bilinear, 0 = nearest.
    var methodId = method === 'bilinear' ? 1 : 0;
    // Largest valid source index per axis, formatted as a GLSL float literal.
    var _d = __read(["".concat(imageHeight - 1, ".0"), "".concat(imageWidth - 1, ".0")], 2), inputHeightFloat = _d[0], inputWidthFloat = _d[1];
    // [ratio, scale expression, sample-position expression] per axis. A crop
    // size of 1 degenerates to sampling the box centre.
    var _e = __read(cropHeight > 1 ?
        [
          "".concat((imageHeight - 1) / (cropHeight - 1)),
          '(y2-y1) * height_ratio',
          "y1*".concat(inputHeightFloat, " + float(y)*(height_scale)"),
        ] :
        [
          '0.0',
          '0.0',
          "0.5 * (y1+y2) * ".concat(inputHeightFloat),
        ], 3), heightRatio = _e[0], heightScale = _e[1], inY = _e[2];
    var _f = __read(cropWidth > 1 ?
        [
          "".concat((imageWidth - 1) / (cropWidth - 1)),
          '(x2-x1) * width_ratio',
          "x1*".concat(inputWidthFloat, " + float(x)*(width_scale)"),
        ] :
        [
          '0.0',
          '0.0',
          "0.5 * (x1+x2) * ".concat(inputWidthFloat),
        ], 3), widthRatio = _f[0], widthScale = _f[1], inX = _f[2];
    // Reference implementation
    // tslint:disable-next-line:max-line-length
    // https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/kernels/crop_and_resize_op_gpu.cu.cc
    this.userCode = "\n const float height_ratio = float(".concat(heightRatio, ");\n const float width_ratio = float(")
        .concat(widthRatio, ");\n void main() {\n ivec4 coords = getOutputCoords();\n int b = coords[0];\n int y = coords[1];\n int x = coords[2];\n int d = coords[3];\n\n // get box vals\n float y1 = getBoxes(b,0);\n float x1 = getBoxes(b,1);\n float y2 = getBoxes(b,2);\n float x2 = getBoxes(b,3);\n\n // get image in batch index\n int bInd = round(getBoxInd(b));\n if(bInd < 0 || bInd >= ")
        .concat(batch, ") {\n return;\n }\n\n float height_scale = ")
        .concat(heightScale, ";\n float width_scale = ")
        .concat(widthScale, ";\n\n float in_y = ")
        .concat(inY, ";\n if( in_y < 0.0 || in_y > ")
        .concat(inputHeightFloat, " ) {\n setOutput(float(")
        .concat(extrapolationValue, "));\n return;\n }\n float in_x = ")
        .concat(inX, ";\n if( in_x < 0.0 || in_x > ")
        .concat(inputWidthFloat, " ) {\n setOutput(float(")
        .concat(extrapolationValue, "));\n return;\n }\n\n vec2 sourceFracIndexCR = vec2(in_x,in_y);\n if(")
        .concat(methodId, " == 1) {\n // Compute the four integer indices.\n ivec2 sourceFloorCR = ivec2(sourceFracIndexCR);\n ivec2 sourceCeilCR = ivec2(ceil(sourceFracIndexCR));\n\n float topLeft = getImage(b, sourceFloorCR.y, sourceFloorCR.x, d);\n float bottomLeft = getImage(b, sourceCeilCR.y, sourceFloorCR.x, d);\n float topRight = getImage(b, sourceFloorCR.y, sourceCeilCR.x, d);\n float bottomRight = getImage(b, sourceCeilCR.y, sourceCeilCR.x, d);\n\n vec2 fracCR = sourceFracIndexCR - vec2(sourceFloorCR);\n\n float top = topLeft + (topRight - topLeft) * fracCR.x;\n float bottom = bottomLeft + (bottomRight - bottomLeft) * fracCR.x;\n float newValue = top + (bottom - top) * fracCR.y;\n setOutput(newValue);\n } else {\n // Compute the coordinators of nearest neighbor point.\n ivec2 sourceNearestCR = ivec2(floor(\n sourceFracIndexCR + vec2(0.5,0.5)));\n float newValue = getImage(b, sourceNearestCR.y, sourceNearestCR.x, d);\n setOutput(newValue);\n }\n }\n ");
  }
  return CropAndResizeProgram;
}());
/**
 * @license
 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
* ============================================================================= */ var cropAndResize = function (args) { var inputs = args.inputs, backend = args.backend, attrs = args.attrs; var image = inputs.image, boxes = inputs.boxes, boxInd = inputs.boxInd; var cropSize = attrs.cropSize, method = attrs.method, extrapolationValue = attrs.extrapolationValue; var program = new CropAndResizeProgram(image.shape, boxes.shape, cropSize, method, extrapolationValue); return backend.runWebGLProgram(program, [image, boxes, boxInd], 'float32'); }; var cropAndResizeConfig = { kernelName: tf.CropAndResize, backendName: 'webgl', kernelFunc: cropAndResize }; var CumOpType; (function (CumOpType) { CumOpType["Prod"] = "*"; CumOpType["Sum"] = "+"; })(CumOpType || (CumOpType = {})); var CumProgram = /** @class */ (function () { function CumProgram(op, outputShape, exclusive, reverse) { this.op = op; this.outputShape = outputShape; this.variableNames = ['x']; this.customUniforms = [{ name: 'index', type: 'float' }]; var rank = this.outputShape.length; var initVal = this.op === CumOpType.Prod ? '1.0' : '0.0'; var val = exclusive ? initVal : "getX(".concat(getCoords(rank, 'coords', this.op), ")"); var length = this.outputShape[this.outputShape.length - 1]; var condition = ''; var idxString = ''; // When exclusive is set, the cum op becomes roll op that copies the // value from the previous index based on the direction specified by the // reverse flag. if (exclusive) { condition = reverse ? "end != ".concat(length - 1) : 'end != 0'; idxString = reverse ? 'end + 1' : 'end - 1'; } else { condition = reverse ? "end + pow2 < ".concat(length) : 'end >= pow2'; idxString = (reverse ? 
'end + pow2' : 'end - pow2'); } this.userCode = "\n void main() {\n ".concat(getCoordsDataType(rank), " coords = getOutputCoords();\n int end = ").concat(getFinalCoord(rank, 'coords', this.op), ";\n float val = ").concat(val, ";\n int pow2 = int(pow(2.0, index));\n if (").concat(condition, ") {\n int idx = ").concat(idxString, ";\n ").concat(getFinalCoord(rank, 'coords', this.op), " = idx;\n val ").concat(this.op, "= getX(").concat(getCoords(rank, 'coords', this.op), ");\n }\n setOutput(val);\n }\n "); } return CumProgram; }()); function getCoords(rank, name, op) { if (rank === 1) { return "".concat(name); } else if (rank === 2) { return "".concat(name, ".x, ").concat(name, ".y"); } else if (rank === 3) { return "".concat(name, ".x, ").concat(name, ".y, ").concat(name, ".z"); } else if (rank === 4) { return "".concat(name, ".x, ").concat(name, ".y, ").concat(name, ".z, ").concat(name, ".w"); } else { throw new Error("Cumulative ".concat(op, " for rank ").concat(rank, " is not yet supported")); } } function getFinalCoord(rank, name, op) { if (rank === 1) { return "".concat(name); } else if (rank === 2) { return "".concat(name, ".y"); } else if (rank === 3) { return "".concat(name, ".z"); } else if (rank === 4) { return "".concat(name, ".w"); } else { throw new Error("Cumulative ".concat(op, " for rank ").concat(rank, " is not yet supported")); } } /** * @license * Copyright 2022 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ function cumImpl(op, x, backend, axis, exclusive, reverse) { var xRank = x.shape.length; var permutation = tf.backend_util.getAxesPermutation([axis], xRank); var permutedX = x; if (permutation != null) { permutedX = transpose({ inputs: { x: x }, backend: backend, attrs: { perm: permutation } }); } var permutedAxis = tf.backend_util.getInnerMostAxes(1, xRank)[0]; if (permutedAxis !== xRank - 1) { throw new Error("WebGL cumprod shader expects an inner-most axis=".concat(x.shape.length - 1, " ") + "but got axis=".concat(axis)); } var size = permutedX.shape[permutedAxis]; var result = identity({ inputs: { x: permutedX }, backend: backend }); // Use cum parallel algorithm, inspired by: // https://developer.nvidia.com/gpugems/gpugems3/part-vi-gpu-computing/chapter-39-parallel-prefix-sum-scan-cuda // Note: although the algorithm is called sum, it works for any associtative // operator with an identity. for (var i = 0; i <= Math.ceil(Math.log2(size)) - 1; i++) { var program = new CumProgram(op, permutedX.shape, false, reverse); var customValues = [[i]]; var prevResult = result; result = backend.runWebGLProgram(program, [result], result.dtype, customValues); backend.disposeIntermediateTensorInfo(prevResult); } // For exclusive cum, shift the end result in the direction of product or sum // and add 1 for product or 0 for sum to the front index. 
if (exclusive) { var program = new CumProgram(op, permutedX.shape, exclusive, reverse); var prevResult = result; result = backend.runWebGLProgram(program, [result], result.dtype); backend.disposeIntermediateTensorInfo(prevResult); } if (permutation != null) { var reversePermutation = tf.backend_util.getUndoAxesPermutation(permutation); var reverseTransposedResult = transpose({ inputs: { x: result }, backend: backend, attrs: { perm: reversePermutation } }); backend.disposeIntermediateTensorInfo(result); backend.disposeIntermediateTensorInfo(permutedX); return reverseTransposedResult; } return result; } /** * @license * Copyright 2022 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function cumprod(args) { var inputs = args.inputs, backend = args.backend, attrs = args.attrs; var x = inputs.x; var axis = attrs.axis, exclusive = attrs.exclusive, reverse = attrs.reverse; return cumImpl(CumOpType.Prod, x, backend, axis, exclusive, reverse); } var cumprodConfig = { kernelName: tf.Cumprod, backendName: 'webgl', kernelFunc: cumprod }; /** * @license * Copyright 2022 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function cumsum(args) { var inputs = args.inputs, backend = args.backend, attrs = args.attrs; var x = inputs.x; var axis = attrs.axis, exclusive = attrs.exclusive, reverse = attrs.reverse; return cumImpl(CumOpType.Sum, x, backend, axis, exclusive, reverse); } var cumsumConfig = { kernelName: tf.Cumsum, backendName: 'webgl', kernelFunc: cumsum }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ function denseBincount(args) { var inputs = args.inputs, backend = args.backend, attrs = args.attrs; var x = inputs.x, weights = inputs.weights; var size = attrs.size, binaryOutput = attrs.binaryOutput; if (x.shape.length === 1) { var xVals = backend.readSync(x.dataId); var weightsVals = backend.readSync(weights.dataId); var outVals = bincountImplCPU(xVals, weightsVals, weights.dtype, weights.shape, size); return backend.makeTensorInfo([size], weights.dtype, outVals); } else if (x.shape.length === 2) { var xBuf = backend.bufferSync(x); var weightsBuf = backend.bufferSync(weights); var outBuf = bincountReduceImplCPU(xBuf, weightsBuf, size, binaryOutput); return backend.makeTensorInfo(outBuf.shape, weights.dtype, outBuf.values); } throw new Error("Error in denseBincount: input must be at most rank 2, but got rank" + "".concat(x.shape.length, ".")); } var denseBincountConfig = { kernelName: tf.DenseBincount, backendName: 'webgl', kernelFunc: denseBincount }; /** * @license * Copyright 2018 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ var DepthToSpaceProgram = /** @class */ (function () { function DepthToSpaceProgram(outputShape, blockSize, dataFormat) { this.variableNames = ['x']; this.outputShape = []; this.outputShape = outputShape; this.blockSize = blockSize; this.dataFormat = dataFormat; this.userCode = "\n void main() {\n ivec4 coords = getOutputCoords();\n int b = coords[0];\n int h = ".concat(this.getHeightCoordString(), ";\n int w = ").concat(this.getWidthCoordString(), ";\n int d = ").concat(this.getDepthCoordString(), ";\n\n int in_h = h / ").concat(blockSize, ";\n int offset_h = imod(h, ").concat(blockSize, ");\n int in_w = w / ").concat(blockSize, ";\n int offset_w = imod(w, ").concat(blockSize, ");\n int offset_d = (offset_h * ").concat(blockSize, " + offset_w) *\n ").concat(this.getOutputDepthSize(), ";\n int in_d = d + offset_d;\n\n float result = ").concat(this.getInputSamplingString(), ";\n setOutput(result);\n }\n "); } DepthToSpaceProgram.prototype.getHeightCoordString = function () { if (this.dataFormat === 'NHWC') { return "coords[1]"; } else { return "coords[2]"; } }; DepthToSpaceProgram.prototype.getWidthCoordString = function () { if (this.dataFormat === 'NHWC') { return "coords[2]"; } else { return "coords[3]"; } }; DepthToSpaceProgram.prototype.getDepthCoordString = function () { if (this.dataFormat === 'NHWC') { return "coords[3]"; } else { return "coords[1]"; } }; DepthToSpaceProgram.prototype.getOutputDepthSize = function () { if (this.dataFormat === 'NHWC') { return this.outputShape[3]; } else { return this.outputShape[1]; } }; DepthToSpaceProgram.prototype.getInputSamplingString = function () { if (this.dataFormat === 'NHWC') { return "getX(b, in_h, in_w, in_d)"; } else { return "getX(b, in_d, in_h, in_w)"; } }; return DepthToSpaceProgram; }()); /** * @license * Copyright 2020 Google LLC. All Rights Reserved. 
* Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function depthToSpace(args) { var inputs = args.inputs, backend = args.backend, attrs = args.attrs; var x = inputs.x; var blockSize = attrs.blockSize, dataFormat = attrs.dataFormat; var batchSize = x.shape[0]; var inputHeight = (dataFormat === 'NHWC') ? x.shape[1] : x.shape[2]; var inputWidth = (dataFormat === 'NHWC') ? x.shape[2] : x.shape[3]; var inputDepth = (dataFormat === 'NHWC') ? x.shape[3] : x.shape[1]; var outputHeight = inputHeight * blockSize; var outputWidth = inputWidth * blockSize; var outputDepth = inputDepth / (blockSize * blockSize); var outputShape = (dataFormat === 'NHWC') ? [batchSize, outputHeight, outputWidth, outputDepth] : [batchSize, outputDepth, outputHeight, outputWidth]; var program = new DepthToSpaceProgram(outputShape, blockSize, dataFormat); return backend.runWebGLProgram(program, [x], x.dtype); } var depthToSpaceConfig = { kernelName: tf.DepthToSpace, backendName: 'webgl', kernelFunc: depthToSpace }; /** * @license * Copyright 2017 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ var DepthwiseConv2DProgram = /** @class */ (function () { function DepthwiseConv2DProgram(convInfo, addBias, activation, hasPreluActivation, hasLeakyReluAlpha) { if (addBias === void 0) { addBias = false; } if (activation === void 0) { activation = null; } if (hasPreluActivation === void 0) { hasPreluActivation = false; } if (hasLeakyReluAlpha === void 0) { hasLeakyReluAlpha = false; } this.variableNames = ['x', 'W']; this.customUniforms = [ { name: 'pads', type: 'ivec2' }, { name: 'strides', type: 'ivec2' }, { name: 'dilations', type: 'ivec2' }, { name: 'inDims', type: 'ivec2' }, ]; this.outputShape = convInfo.outShape; this.enableShapeUniforms = useShapeUniforms(this.outputShape.length); var filterHeight = convInfo.filterHeight; var filterWidth = convInfo.filterWidth; var channelMul = convInfo.outChannels / convInfo.inChannels; var activationSnippet = '', applyActivationSnippet = ''; if (activation) { if (hasPreluActivation) { activationSnippet = "float activation(float a) {\n float b = getPreluActivationWeightsAtOutCoords();\n ".concat(activation, "\n }"); } else if (hasLeakyReluAlpha) { activationSnippet = "float activation(float a) {\n float b = getLeakyreluAlphaAtOutCoords();\n ".concat(activation, "\n }"); } else { activationSnippet = "\n float activation(float x) {\n ".concat(activation, "\n }\n "); } applyActivationSnippet = "result = activation(result);"; } var addBiasSnippet = addBias ? 
'result += getBiasAtOutCoords();' : ''; if (addBias) { this.variableNames.push('bias'); } if (hasPreluActivation) { this.variableNames.push('preluActivationWeights'); } if (hasLeakyReluAlpha) { this.variableNames.push('leakyreluAlpha'); } this.userCode = "\n ".concat(activationSnippet, "\n\n void main() {\n ivec4 coords = getOutputCoords();\n int batch = coords.x;\n ivec2 xRCCorner = coords.yz * strides - pads;\n int d2 = coords.w;\n int d1 = d2 / ").concat(channelMul, ";\n int q = d2 - d1 * ").concat(channelMul, ";\n\n int xRCorner = xRCCorner.x;\n int xCCorner = xRCCorner.y;\n\n // Convolve x(?, ?, d1) with w(:, :, d1, q) to get y(yR, yC, d2).\n // ? = to be determined. : = across all values in that axis.\n float dotProd = 0.0;\n // TO DO(dsmilkov): Flatten the two for loops and vec4 the operations.\n for (int wR = 0; wR < ").concat(filterHeight, "; wR++) {\n int xR = xRCorner + wR * dilations[0];\n\n if (xR < 0 || xR >= inDims[0]) {\n continue;\n }\n\n for (int wC = 0; wC < ").concat(filterWidth, "; wC++) {\n int xC = xCCorner + wC * dilations[1];\n\n if (xC < 0 || xC >= inDims[1]) {\n continue;\n }\n\n float xVal = getX(batch, xR, xC, d1);\n float wVal = getW(wR, wC, d1, q);\n dotProd += xVal * wVal;\n }\n }\n\n float result = dotProd;\n ").concat(addBiasSnippet, "\n ").concat(applyActivationSnippet, "\n setOutput(result);\n }\n "); } return DepthwiseConv2DProgram; }()); /** * @license * Copyright 2018 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
* See the License for the specific language governing permissions and
 * limitations under the License.
 * ============================================================================= */
/**
 * Packed (vec4 in, vec4 out) WebGL program for depthwise 2D convolution.
 * Both inputs and the output are marked packed. Padding, strides, dilations
 * and input spatial dims arrive as ivec2 custom uniforms.
 *
 * The column loop handles `strideWidth === 1` explicitly and treats every
 * other stride as 2. In this file it is only constructed by
 * depthwiseConv2dNative, which requires strideWidth <= 2 and a channel
 * multiplier of 1.
 */
var DepthwiseConvPacked2DProgram = /** @class */ (function () {
    /**
     * @param {Object} convInfo Reads outShape, outChannels, inChannels,
     *     padInfo.left, strideWidth, dilationWidth, filterHeight, filterWidth.
     * @param {boolean=} addBias Adds a 'bias' input and `getBiasAtOutCoords()`.
     * @param {?string=} activation GLSL body of a vec4 activation function.
     * @param {boolean=} hasPreluActivation Adds a 'preluActivationWeights' input.
     * @param {boolean=} hasLeakyReluAlpha Adds a 'leakyreluAlpha' input.
     */
    function DepthwiseConvPacked2DProgram(convInfo, addBias, activation, hasPreluActivation, hasLeakyReluAlpha) {
        if (addBias === void 0) { addBias = false; }
        if (activation === void 0) { activation = null; }
        if (hasPreluActivation === void 0) { hasPreluActivation = false; }
        if (hasLeakyReluAlpha === void 0) { hasLeakyReluAlpha = false; }
        this.variableNames = ['x', 'W'];
        this.packedInputs = true;
        this.packedOutput = true;
        this.customUniforms = [
            { name: 'pads', type: 'ivec2' },
            { name: 'strides', type: 'ivec2' },
            { name: 'dilations', type: 'ivec2' },
            { name: 'inDims', type: 'ivec2' },
        ];
        this.outputShape = convInfo.outShape;
        this.enableShapeUniforms = useShapeUniforms(this.outputShape.length);
        var channelMul = convInfo.outChannels / convInfo.inChannels;
        var padLeft = convInfo.padInfo.left;
        var strideWidth = convInfo.strideWidth;
        var dilationWidth = convInfo.dilationWidth;
        var filterHeight = convInfo.filterHeight;
        var filterWidth = convInfo.filterWidth;
        var texelsAcross = filterWidth;
        // Shader-local declarations. For each filter column c there are two raw
        // texels (xTexelC{2c}, xTexelC{2c+1}), each with a "Ready" flag, plus
        // the composed vec4 xC{c} that feeds the dot product.
        var mainLoop = "\n int xR; int xC; int xCOffset;\n vec4 wTexel; vec4 previous; vec4 final;";
        for (var c = 0; c < filterWidth; c++) {
            mainLoop += "\n vec4 xTexelC".concat(c * 2, ";\n int xTexelC").concat(c * 2, "Ready;\n vec4 xTexelC").concat(c * 2 + 1, ";\n int xTexelC").concat(c * 2 + 1, "Ready;\n vec4 xC").concat(c, ";");
        }
        /**
         * This vectorized implementation works by gathering the values needed for
         * each output channel's dot product into vec4's and then multiplying them
         * all together (this happens in the final double for-loop below). Most of
         * the main loop consists of constructing these vec4's with the minimum
         * number of texture2D calls, which means making use of all four returned
         * values from a texture2D call at once.
         */
        mainLoop += "\n for (int r = 0; r < ".concat(filterHeight, "; r++) {\n ");
        // At the start of every filter row, clear all texels and their Ready
        // flags. Within a row, a texel is fetched with getX only while its
        // flag is still 0.
        for (var c = 0; c < filterWidth; c++) {
            mainLoop += "\n xTexelC".concat(c * 2, " = vec4(0.0);\n xTexelC").concat(c * 2, "Ready = 0;\n xTexelC").concat(c * 2 + 1, " = vec4(0.0);\n xTexelC").concat(c * 2 + 1, "Ready = 0;\n xC").concat(c, " = vec4(0.0);");
        }
        mainLoop += "\n xR = xRCorner + r * dilations[0];\n if (xR >=0 && xR < inDims[0]) {\n ";
        // Each iteration handles two filter columns (colIndex, colIndex + 1).
        // NOTE(review): `(texelsAcross + 1) / 2` is floating-point division. For
        // an even filterWidth there is one extra iteration in which
        // colIndex >= filterWidth, so only the `xC = ...` line below is emitted.
        for (var texelC = 0; texelC < (texelsAcross + 1) / 2; texelC++) {
            var colIndex = texelC * 2;
            mainLoop += "\n xC = xCCorner + ".concat(colIndex * dilationWidth, ";\n ");
            if (strideWidth === 1) {
                if (colIndex < filterWidth) {
                    // If padding is odd, the outer texels have to be composed.
                    if (padLeft % 2 === 1) {
                        // TODO: Ensure vec4 previous does not result in redundant sample,
                        // and avoid setting xTexelRC's that exceed the boundary in the
                        // first place rather than resetting them to vec4(0)).
                        // To compute xCOffset:
                        // - If padding is odd, we must add 1 to ensure we ask for an
                        // even-numbered row.
                        // - We subtract 2 to access the previous texel.
                        mainLoop += "\n xCOffset = xC + 1;\n if (xCOffset >= 0 && xCOffset < inDims[1] && xTexelC".concat(colIndex, "Ready == 0) {\n xTexelC").concat(colIndex, " = getX(batch, xR, xCOffset, d1);\n\n // Need to manually clear unused channels in case\n // we're reading from recycled texture.\n if (xCOffset + 1 >= inDims[1]) {\n xTexelC").concat(colIndex, ".zw = vec2(0.0);\n }\n xTexelC").concat(colIndex, "Ready = 1;\n }\n ");
                        // This texel has been read in previous iteration if the dilation
                        // is 1.
                        if (dilationWidth === 1 && colIndex > 0) {
                            mainLoop += "\n xC".concat(colIndex, " = vec4(xTexelC").concat(colIndex - 2, ".zw, xTexelC").concat(colIndex, ".xy);\n ");
                        }
                        else {
                            mainLoop += "\n xCOffset = xC + 1 - 2;\n\n if (xCOffset >= 0 && xCOffset < inDims[1]) {\n previous = getX(batch, xR, xCOffset, d1);\n\n // Need to manually clear unused channels in case\n // we're reading from recycled texture.\n if (xCOffset + 1 >= inDims[1]) {\n previous.zw = vec2(0.0);\n }\n\n xC".concat(colIndex, " = vec4(previous.zw, xTexelC").concat(colIndex, ".xy);\n } else {\n xC").concat(colIndex, " = vec4(0.0, 0.0, xTexelC").concat(colIndex, ".xy);\n }\n ");
                        }
                    }
                    else {
                        // Padding is even, so xRC corresponds to a single texel.
                        mainLoop += "\n if (xC >= 0 && xC < inDims[1] && xTexelC".concat(colIndex, "Ready == 0) {\n xTexelC").concat(colIndex, " = getX(batch, xR, xC, d1);\n if (xC + 1 >= inDims[1]) {\n xTexelC").concat(colIndex, ".zw = vec2(0.0);\n }\n xTexelC").concat(colIndex, "Ready = 1;\n }\n\n xC").concat(colIndex, " = xTexelC").concat(colIndex, ";\n ");
                    }
                    if (colIndex + 1 < filterWidth) {
                        // If dilation is even, the second entry should match the first
                        // (either both are composed or both are single samples). But if
                        // dilation is odd, then the second entry should be the opposite
                        // of the first (if the first is composed, the second is a single
                        // sample, and vice versa.)
                        var nextTexelOffset = padLeft % 2 === 0 ?
                            tf.util.nearestLargerEven(dilationWidth) :
                            dilationWidth;
                        if ((dilationWidth % 2 === 0 && padLeft % 2 === 1) ||
                            (dilationWidth % 2 !== 0 && padLeft % 2 !== 1)) {
                            mainLoop += "\n xCOffset = xC + imod(pads[1], 2) + ".concat(nextTexelOffset, ";\n\n if (xCOffset >= 0 && xCOffset < inDims[1] && xTexelC").concat(colIndex + 1, "Ready == 0) {\n xTexelC").concat(colIndex + 1, " = getX(batch, xR, xCOffset, d1);\n\n // Need to manually clear unused channels in case\n // we're reading from recycled texture.\n if (xCOffset + 1 >= inDims[1]) {\n xTexelC").concat(colIndex + 1, ".zw = vec2(0.0);\n }\n xTexelC").concat(colIndex + 1, "Ready = 1;\n }\n ");
                            // If dilation > 1 then the xRC's will not be able to share any
                            // values, so each xRC will require two unique calls to getX.
                            if (dilationWidth > 1) {
                                mainLoop += "\n xCOffset -= 2;\n if (xCOffset >= 0 && xCOffset < inDims[1]) {\n previous = getX(batch, xR, xCOffset, d1);\n xC".concat(colIndex + 1, " = vec4(previous.zw, xTexelC").concat(colIndex + 1, ".xy);\n } else {\n xC").concat(colIndex + 1, " = vec4(0.0, 0.0, xTexelC").concat(colIndex + 1, ".xy);\n }\n ");
                            }
                            else {
                                mainLoop += "\n xC".concat(colIndex + 1, " = vec4(xTexelC").concat(colIndex, ".zw, xTexelC").concat(colIndex + 1, ".xy);\n ");
                            }
                        }
                        else {
                            // If dilation is 1 and padding is odd, we have already read the
                            // texel when constructing the previous x value. Here we can
                            // simply skip the texture read.
                            if (nextTexelOffset === 1) {
                                mainLoop += "\n xC".concat(colIndex + 1, " = xTexelC").concat(colIndex, ";\n ");
                            }
                            else {
                                mainLoop += "\n xCOffset = xC + ".concat(nextTexelOffset, ";\n\n if (xCOffset >= 0 && xCOffset < inDims[1] && xTexelC").concat(colIndex + 1, "Ready == 0) {\n xTexelC").concat(colIndex + 1, " = getX(batch, xR, xCOffset, d1);\n if (xCOffset + 1 >= inDims[1]) {\n xTexelC").concat(colIndex + 1, ".zw = vec2(0.0);\n }\n xTexelC").concat(colIndex + 1, "Ready = 1;\n }\n\n xC").concat(colIndex + 1, " = xTexelC").concat(colIndex + 1, ";\n ");
                            }
                        }
                    }
                }
            }
            else { // stride === 2
                if (colIndex < filterWidth) {
                    // Depending on whether padLeft is even or odd, we want either the
                    // xy or zw channels from X texels for xC${colIndex}. If padLeft is
                    // even, xC${colIndex + 1} is simply the zw channels of texels we've
                    // already sampled. But if padLeft is odd, xC${colIndex + 1}.zw will
                    // need to come from the xy channels of a new texel, hence the
                    // `vec4 final` initialized below.
                    if (padLeft % 2 === 1) {
                        mainLoop += "\n xCOffset = xC + 1 - strides[1];\n if(xCOffset >= 0 && xCOffset < inDims[1] && xTexelC".concat(colIndex, "Ready == 0) {\n xTexelC").concat(colIndex, " = getX(batch, xR, xCOffset, d1);\n // Need to manually clear unused channels in case\n // we're reading from recycled texture.\n if (xCOffset + 1 >= inDims[1]) {\n xTexelC").concat(colIndex, ".zw = vec2(0.0);\n }\n xTexelC").concat(colIndex, "Ready = 1;\n }\n\n if(xC + 1 >= 0 && xC + 1 < inDims[1] && xTexelC").concat(colIndex + 1, "Ready == 0) {\n xTexelC").concat(colIndex + 1, " = getX(batch, xR, xC + 1, d1);\n // Need to manually clear unused channels in case\n // we're reading from recycled texture.\n if (xC + 2 >= inDims[1]) {\n xTexelC").concat(colIndex + 1, ".zw = vec2(0.0);\n }\n xTexelC").concat(colIndex + 1, "Ready = 1;\n }\n\n xC").concat(colIndex, " = vec4(xTexelC").concat(colIndex, ".zw, xTexelC").concat(colIndex + 1, ".zw);\n ");
                        if (colIndex + 1 < filterWidth) {
                            mainLoop += "\n final = vec4(0.0);\n xCOffset = xC + 1 + strides[1];\n if(xCOffset >= 0 && xCOffset < inDims[1]) {\n final = getX(batch, xR, xCOffset, d1);\n }\n xC".concat(colIndex + 1, " = vec4(xTexelC").concat(colIndex + 1, ".xy, final.xy);\n ");
                        }
                    }
                    else {
                        mainLoop += "\n if(xC >= 0 && xC < inDims[1] && xTexelC".concat(colIndex, "Ready == 0) {\n xTexelC").concat(colIndex, " = getX(batch, xR, xC, d1);\n if (xC + 1 >= inDims[1]) {\n xTexelC").concat(colIndex, ".zw = vec2(0.0);\n }\n xTexelC").concat(colIndex, "Ready = 1;\n }\n\n xCOffset = xC + strides[1];\n if(xCOffset >= 0 && xCOffset < inDims[1] && xTexelC").concat(colIndex + 1, "Ready == 0) {\n xTexelC").concat(colIndex + 1, " = getX(batch, xR, xCOffset, d1);\n if (xCOffset + 1 >= inDims[1]) {\n xTexelC").concat(colIndex + 1, ".zw = vec2(0.);\n }\n xTexelC").concat(colIndex + 1, "Ready = 1;\n }\n\n xC").concat(colIndex, " = vec4(\n xTexelC").concat(colIndex, ".xy, xTexelC").concat(colIndex + 1, ".xy);\n ");
                        if (colIndex + 1 < filterWidth) {
                            mainLoop += "\n xC".concat(colIndex + 1, " = vec4(xTexelC").concat(colIndex, ".zw, xTexelC").concat(colIndex + 1, ".zw);\n ");
                        }
                    }
                }
            }
            // Localize the dotProd accumulation within the loop. The reasoning is
            // that on a GPU with a limited cache, accumulating a sum across a large
            // number of variables causes many cache misses (e.g. a 5x5 filter would
            // have 50 variables).
            if (colIndex < filterWidth) {
                mainLoop += "\n wTexel = getW(r, ".concat(colIndex, ", d1, q);\n dotProd += xC").concat(colIndex, " * vec4(wTexel.xz, wTexel.xz);\n ");
                if (colIndex + 1 < filterWidth) {
                    mainLoop += "\n wTexel = getW(r, ".concat(colIndex + 1, ", d1, q);\n dotProd += xC").concat(colIndex + 1, " * vec4(wTexel.xz, wTexel.xz);\n ");
                }
            }
        }
        // Close the `if (xR in range)` block and the row loop opened above.
        mainLoop += "\n }\n ";
        mainLoop += "\n }\n ";
        var activationSnippet = '', applyActivationSnippet = '';
        if (activation) {
            if (hasPreluActivation) {
                activationSnippet = "vec4 activation(vec4 a) {\n vec4 b = getPreluActivationWeightsAtOutCoords();\n ".concat(activation, "\n }");
            }
            else if (hasLeakyReluAlpha) {
                activationSnippet = "vec4 activation(vec4 a) {\n vec4 b = getLeakyreluAlphaAtOutCoords();\n ".concat(activation, "\n }");
            }
            else {
                activationSnippet = "vec4 activation(vec4 x) {\n ".concat(activation, "\n }");
            }
            applyActivationSnippet = "result = activation(result);";
        }
        var addBiasSnippet = addBias ? 'result += getBiasAtOutCoords();' : '';
        if (addBias) {
            this.variableNames.push('bias');
        }
        if (hasPreluActivation) {
            this.variableNames.push('preluActivationWeights');
        }
        if (hasLeakyReluAlpha) {
            this.variableNames.push('leakyreluAlpha');
        }
        // dotProd starts at a small epsilon that is subtracted again at the end.
        // The shader comment records this as an empirical accuracy workaround.
        this.userCode = "\n ".concat(activationSnippet, "\n\n void main() {\n ivec4 coords = getOutputCoords();\n int batch = coords.x;\n ivec2 xRCCorner = coords.yz * strides - pads;\n int d2 = coords.w;\n int d1 = d2 / ").concat(channelMul, ";\n int q = d2 - d1 * ").concat(channelMul, ";\n int xRCorner = xRCCorner.x;\n int xCCorner = xRCCorner.y;\n\n //intialize dotProd with a small epsilon seems to reduce GPU accuracy loss.\n vec4 dotProd = vec4(0.000000000000001);\n\n ").concat(mainLoop, "\n\n vec4 result = dotProd - vec4(0.000000000000001);\n ").concat(addBiasSnippet, "\n ").concat(applyActivationSnippet, "\n setOutput(result);\n }\n ");
    }
    return DepthwiseConvPacked2DProgram;
}());
/**
 * @license
 * Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function depthwiseConv2dNative(args) { var inputs = args.inputs, backend = args.backend, attrs = args.attrs; var x = inputs.x, filter = inputs.filter; var strides = attrs.strides, pad = attrs.pad, dilations = attrs.dilations, dimRoundingMode = attrs.dimRoundingMode; var $dilations = dilations; if ($dilations == null) { $dilations = [1, 1]; } tf.util.assert(tf.backend_util.eitherStridesOrDilationsAreOne(strides, $dilations), function () { return 'Error in depthwiseConv2d: Either strides or dilations must be ' + "1. 
Got strides ".concat(strides, " and dilations '").concat($dilations, "'"); }); var convInfo = tf.backend_util.computeConv2DInfo(x.shape, filter.shape, strides, $dilations, pad, dimRoundingMode, true /* depthwise */); var program; if (tf.env().getBool('WEBGL_PACK_DEPTHWISECONV') && convInfo.strideWidth <= 2 && convInfo.outChannels / convInfo.inChannels === 1) { program = new DepthwiseConvPacked2DProgram(convInfo); } else { program = new DepthwiseConv2DProgram(convInfo); } var customValues = [ [convInfo.padInfo.top, convInfo.padInfo.left], [convInfo.strideHeight, convInfo.strideWidth], [convInfo.dilationHeight, convInfo.dilationWidth], [convInfo.inHeight, convInfo.inWidth] ]; return backend.runWebGLProgram(program, [x, filter], 'float32', customValues); } var depthwiseConv2dNativeConfig = { kernelName: tf.DepthwiseConv2dNative, backendName: 'webgl', kernelFunc: depthwiseConv2dNative, }; /** * @license * Copyright 2018 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ var DepthwiseConv2DDerFilterProgram = /** @class */ (function () { function DepthwiseConv2DDerFilterProgram(convInfo) { this.variableNames = ['x', 'dy']; this.outputShape = convInfo.filterShape; var strideHeight = convInfo.strideHeight; var strideWidth = convInfo.strideWidth; var padTop = convInfo.padInfo.top; var padLeft = convInfo.padInfo.left; var channelMul = convInfo.outChannels / convInfo.inChannels; this.userCode = "\n void main() {\n ivec4 coords = getOutputCoords();\n int wR = coords.x;\n int wC = coords.y;\n int d1 = coords.z;\n int dm = coords.w;\n int d2 = d1 * ".concat(channelMul, " + dm;\n\n float dotProd = 0.0;\n\n // TO DO: Vec4 over the batch size\n for (int b = 0; b < ").concat(convInfo.batchSize, "; b++) {\n for (int yR = 0; yR < ").concat(convInfo.outHeight, "; yR++) {\n int xR = wR + yR * ").concat(strideHeight, " - ").concat(padTop, ";\n\n if (xR < 0 || xR >= ").concat(convInfo.inHeight, ") {\n continue;\n }\n\n for (int yC = 0; yC < ").concat(convInfo.outWidth, "; yC++) {\n int xC = wC + yC * ").concat(strideWidth, " - ").concat(padLeft, ";\n\n if (xC < 0 || xC >= ").concat(convInfo.inWidth, ") {\n continue;\n }\n\n float dyValue = getDy(b, yR, yC, d2);\n float xValue = getX(b, xR, xC, d1);\n dotProd += (xValue * dyValue);\n }\n }\n }\n setOutput(dotProd);\n }\n "); } return DepthwiseConv2DDerFilterProgram; }()); var DepthwiseConv2DDerInputProgram = /** @class */ (function () { function DepthwiseConv2DDerInputProgram(convInfo) { this.variableNames = ['dy', 'W']; this.outputShape = convInfo.inShape; var filterHeight = convInfo.filterHeight; var filterWidth = convInfo.filterWidth; var strideHeight = convInfo.strideHeight; var strideWidth = convInfo.strideWidth; var padTop = filterHeight - 1 - convInfo.padInfo.top; var padLeft = filterWidth - 1 - convInfo.padInfo.left; var channelMul = convInfo.outChannels / convInfo.inChannels; this.userCode = "\n const ivec2 
pads = ivec2(".concat(padTop, ", ").concat(padLeft, ");\n\n void main() {\n ivec4 coords = getOutputCoords();\n int batch = coords[0];\n int d1 = coords[3];\n ivec2 dyCorner = coords.yz - pads;\n int dyRCorner = dyCorner.x;\n int dyCCorner = dyCorner.y;\n\n float dotProd = 0.0;\n\n for (int wR = 0; wR < ").concat(filterHeight, "; wR++) {\n float dyR = float(dyRCorner + wR) / ").concat(strideHeight, ".0;\n\n if (dyR < 0.0 || dyR >= ").concat(convInfo.outHeight, ".0 || fract(dyR) > 0.0) {\n continue;\n }\n int idyR = int(dyR);\n\n int wRPerm = ").concat(filterHeight, " - 1 - wR;\n\n for (int wC = 0; wC < ").concat(filterWidth, "; wC++) {\n float dyC = float(dyCCorner + wC) / ").concat(strideWidth, ".0;\n\n if (dyC < 0.0 || dyC >= ").concat(convInfo.outWidth, ".0 ||\n fract(dyC) > 0.0) {\n continue;\n }\n int idyC = int(dyC);\n\n int wCPerm = ").concat(filterWidth, " - 1 - wC;\n\n // TO DO: Vec4 over the channelMul\n for (int dm = 0; dm < ").concat(channelMul, "; dm++) {\n int d2 = d1 * ").concat(channelMul, " + dm;\n float xValue = getDy(batch, idyR, idyC, d2);\n float wValue = getW(wRPerm, wCPerm, d1, dm);\n dotProd += xValue * wValue;\n }\n }\n }\n setOutput(dotProd);\n }\n "); } return DepthwiseConv2DDerInputProgram; }()); /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ function depthwiseConv2dNativeBackpropFilter(args) { var inputs = args.inputs, backend = args.backend, attrs = args.attrs; var x = inputs.x, dy = inputs.dy; var strides = attrs.strides, dilations = attrs.dilations, pad = attrs.pad, dimRoundingMode = attrs.dimRoundingMode, filterShape = attrs.filterShape; var convInfo = tf.backend_util.computeConv2DInfo(x.shape, filterShape, strides, dilations, pad, dimRoundingMode, true /* depthwise */); var program = new DepthwiseConv2DDerFilterProgram(convInfo); return backend.runWebGLProgram(program, [x, dy], 'float32'); } var depthwiseConv2dNativeBackpropFilterConfig = { kernelName: tf.DepthwiseConv2dNativeBackpropFilter, backendName: 'webgl', kernelFunc: depthwiseConv2dNativeBackpropFilter }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ function depthwiseConv2dNativeBackpropInput(args) { var inputs = args.inputs, backend = args.backend, attrs = args.attrs; var dy = inputs.dy, filter = inputs.filter; var strides = attrs.strides, dilations = attrs.dilations, pad = attrs.pad, dimRoundingMode = attrs.dimRoundingMode, inputShape = attrs.inputShape; var convInfo = tf.backend_util.computeConv2DInfo(inputShape, filter.shape, strides, dilations, pad, dimRoundingMode, true /* depthwise */); var program = new DepthwiseConv2DDerInputProgram(convInfo); return backend.runWebGLProgram(program, [dy, filter], 'float32'); } var depthwiseConv2dNativeBackpropInputConfig = { kernelName: tf.DepthwiseConv2dNativeBackpropInput, backendName: 'webgl', kernelFunc: depthwiseConv2dNativeBackpropInput }; /** * @license * Copyright 2019 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ var DiagProgram = /** @class */ (function () { function DiagProgram(size) { this.variableNames = ['X']; this.outputShape = [size, size]; this.userCode = "\n void main() {\n ivec2 coords = getOutputCoords();\n float val = coords[0] == coords[1] ? 
getX(coords[0]) : 0.0;\n setOutput(val);\n }\n "; } return DiagProgram; }()); function diag(args) { var inputs = args.inputs, backend = args.backend; var x = inputs.x; var outShape = __spreadArray(__spreadArray([], __read(x.shape), false), __read(x.shape), false); var xSize = tf.util.sizeFromShape(x.shape); var flat = reshape({ inputs: { x: x }, backend: backend, attrs: { shape: [xSize] } }); var program = new DiagProgram(xSize); var res = backend.runWebGLProgram(program, [flat], flat.dtype); var out = reshape({ inputs: { x: res }, backend: backend, attrs: { shape: outShape } }); backend.disposeIntermediateTensorInfo(flat); backend.disposeIntermediateTensorInfo(res); return out; } var diagConfig = { kernelName: tf.Diag, backendName: 'webgl', kernelFunc: diag }; /** * @license * Copyright 2017 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ var Dilation2DProgram = /** @class */ (function () { function Dilation2DProgram(convInfo) { this.variableNames = ['x', 'W']; this.outputShape = convInfo.outShape; var inHeight = convInfo.inHeight, inWidth = convInfo.inWidth, padInfo = convInfo.padInfo, strideHeight = convInfo.strideHeight, strideWidth = convInfo.strideWidth, filterHeight = convInfo.filterHeight, filterWidth = convInfo.filterWidth, dilationHeight = convInfo.dilationHeight, dilationWidth = convInfo.dilationWidth; var padTop = padInfo.top, padLeft = padInfo.left; this.userCode = "\n const ivec2 strides = ivec2(".concat(strideHeight, ", ").concat(strideWidth, ");\n const ivec2 pads = ivec2(").concat(padTop, ", ").concat(padLeft, ");\n const float neg_infinity = -3.4e38;\n\n void main() {\n ivec4 coords = getOutputCoords();\n int batch = coords.x;\n int d1 = coords.w;\n ivec2 outTopLeftCorner =\n coords.yz * strides - pads;\n int hBeg = outTopLeftCorner.x;\n int wBeg = outTopLeftCorner.y;\n\n float curVal = neg_infinity;\n for (int h = 0; h < ").concat(filterHeight, "; h++) {\n int hIn = hBeg + h * ").concat(dilationHeight, ";\n\n if (hIn >= 0 && hIn < ").concat(inHeight, ") {\n for (int w = 0; w < ").concat(filterWidth, "; w++) {\n int wIn = wBeg + w * ").concat(dilationWidth, ";\n\n if (wIn >= 0 && wIn < ").concat(inWidth, ") {\n float xVal = getX(batch, hIn, wIn, d1);\n float wVal = getW(h, w, d1);\n\n float val = xVal + wVal;\n if (val > curVal) {\n curVal = val;\n }\n }\n }\n }\n }\n\n float result = curVal;\n setOutput(result);\n }\n "); } return Dilation2DProgram; }()); /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function dilation2D(args) { var inputs = args.inputs, backend = args.backend, attrs = args.attrs; var x = inputs.x, filter = inputs.filter; var strides = attrs.strides, pad = attrs.pad, dilations = attrs.dilations; var convInfo = tf.backend_util.computeDilation2DInfo(x.shape, filter.shape, strides, pad, 'NHWC' /* dataFormat */, dilations); var out; var program = new Dilation2DProgram(convInfo); out = backend.runWebGLProgram(program, [x, filter], 'float32'); var outReshaped = reshape({ inputs: { x: out }, backend: backend, attrs: { shape: convInfo.outShape } }); backend.disposeIntermediateTensorInfo(out); return outReshaped; } var dilation2DConfig = { kernelName: tf.Dilation2D, backendName: 'webgl', kernelFunc: dilation2D, }; function einsum(args) { var e_1, _a, e_2, _b; var inputs = args.inputs, backend = args.backend, attrs = args.attrs; var equation = attrs.equation; var tensors = inputs; var _c = tf.backend_util.decodeEinsumEquation(equation, tensors.length), allDims = _c.allDims, summedDims = _c.summedDims, idDims = _c.idDims; tf.backend_util.checkEinsumDimSizes(allDims.length, idDims, tensors); var _d = tf.backend_util.getEinsumComputePath(summedDims, idDims), path = _d.path, steps = _d.steps; var nSteps = steps.length; var out = null; var numDimsRemaining = allDims.length; var tensorsToDispose = []; for (var i = 0; i < nSteps; ++i) { try { for (var _e = (e_1 = void 0, __values(steps[i])), _f = _e.next(); !_f.done; _f = _e.next()) { var idTerm = _f.value; var _g = 
tf.backend_util.getEinsumPermutation(numDimsRemaining, idDims[idTerm]), perm = _g.permutationIndices, dimsToExpand = _g.expandDims; var x = void 0; if (tf.backend_util.isIdentityPermutation(perm)) { x = tensors[idTerm]; } else { x = transpose({ inputs: { x: tensors[idTerm] }, backend: backend, attrs: { perm: perm } }); tensorsToDispose.push(x); } var targetShape = x.shape.slice(); for (var k = 0; k < dimsToExpand.length; ++k) { targetShape.splice(dimsToExpand[k], 0, 1); } if (!tf.util.arraysEqual(x.shape, targetShape)) { x = reshape({ inputs: { x: x }, backend: backend, attrs: { shape: targetShape } }); tensorsToDispose.push(x); } if (out === null) { out = x; } else { // tslint:disable-next-line: no-unnecessary-type-assertion out = multiply({ inputs: { a: x, b: out }, backend: backend }); tensorsToDispose.push(out); } } } catch (e_1_1) { e_1 = { error: e_1_1 }; } finally { try { if (_f && !_f.done && (_a = _e.return)) _a.call(_e); } finally { if (e_1) throw e_1.error; } } if (i < nSteps - 1) { if (path[i] >= 0) { out = sum({ inputs: { x: out }, backend: backend, attrs: { axis: path[i] - (allDims.length - numDimsRemaining), keepDims: false } }); tensorsToDispose.push(out); } numDimsRemaining--; } } try { // Clean up intermediate tensors. for (var tensorsToDispose_1 = __values(tensorsToDispose), tensorsToDispose_1_1 = tensorsToDispose_1.next(); !tensorsToDispose_1_1.done; tensorsToDispose_1_1 = tensorsToDispose_1.next()) { var tensorInfo = tensorsToDispose_1_1.value; if (tensorInfo === out) { continue; } backend.disposeIntermediateTensorInfo(tensorInfo); } } catch (e_2_1) { e_2 = { error: e_2_1 }; } finally { try { if (tensorsToDispose_1_1 && !tensorsToDispose_1_1.done && (_b = tensorsToDispose_1.return)) _b.call(tensorsToDispose_1); } finally { if (e_2) throw e_2.error; } } return out; } var einsumConfig = { kernelName: tf.Einsum, backendName: 'webgl', kernelFunc: einsum }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. 
* Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ var ELU = "return (x >= 0.0) ? x : (exp(x) - 1.0);"; var ELU_PACKED = "\n vec4 result;\n\n result.r = (x.r >= 0.0) ? x.r : (exp(x.r) - 1.0);\n result.g = (x.g >= 0.0) ? x.g : (exp(x.g) - 1.0);\n result.b = (x.b >= 0.0) ? x.b : (exp(x.b) - 1.0);\n result.a = (x.a >= 0.0) ? x.a : (exp(x.a) - 1.0);\n\n return result;\n"; var elu = unaryKernelFunc({ opSnippet: ELU, packedOpSnippet: ELU_PACKED }); var eluConfig = { kernelName: tf.Elu, backendName: 'webgl', kernelFunc: elu }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ var ELU_DER = "return (b >= 0.0) ? 
a : a * (b + 1.0);"; var ELU_DER_PACKED = "\n vec4 bGTEZero = vec4(greaterThanEqual(b, vec4(0.)));\n return (bGTEZero * a) + ((vec4(1.0) - bGTEZero) * (a * (b + vec4(1.0))));\n"; var eluGrad = function (args) { var inputs = args.inputs, backend = args.backend; var dy = inputs.dy, y = inputs.y; var program = tf.env().getBool('WEBGL_PACK_BINARY_OPERATIONS') ? new BinaryOpPackedProgram(ELU_DER_PACKED, dy.shape, y.shape) : new BinaryOpProgram(ELU_DER, dy.shape, y.shape); return backend.runWebGLProgram(program, [dy, y], dy.dtype); }; var eluGradConfig = { kernelName: tf.EluGrad, backendName: 'webgl', kernelFunc: eluGrad }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ var PACKED_EQUAL = "\n return vec4(equal(a, b));\n"; var EQUAL = "return float(a == b);"; var equal = binaryKernelFunc({ opSnippet: EQUAL, packedOpSnippet: PACKED_EQUAL, dtype: 'bool', cpuKernelImpl: equalImplCPU, }); var equalConfig = { kernelName: tf.Equal, backendName: 'webgl', kernelFunc: equal }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ var ERF = "\n // Error function is calculated approximately with elementary function.\n // See \"Handbook of Mathematical Functions with Formulas,\n // Graphs, and Mathematical Tables\", Abramowitz and Stegun.\n float p = ".concat(tf.backend_util.ERF_P, ";\n float a1 = ").concat(tf.backend_util.ERF_A1, ";\n float a2 = ").concat(tf.backend_util.ERF_A2, ";\n float a3 = ").concat(tf.backend_util.ERF_A3, ";\n float a4 = ").concat(tf.backend_util.ERF_A4, ";\n float a5 = ").concat(tf.backend_util.ERF_A5, ";\n\n float sign = sign(x);\n x = abs(x);\n float t = 1.0 / (1.0 + p * x);\n return sign * (1.0 - (((((a5*t + a4)*t) + a3)*t + a2)*t + a1)*t*exp(-x*x));\n"); var erf = unaryKernelFunc({ opSnippet: ERF }); var erfConfig = { kernelName: tf.Erf, backendName: 'webgl', kernelFunc: erf, }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ var EXP = CHECK_NAN_SNIPPET_UNARY + "\n return exp(x);\n"; var EXP_PACKED = "\n vec4 result = exp(x);\n bvec4 isNaN = isnan(x);\n result.r = isNaN.r ? x.r : result.r;\n result.g = isNaN.g ? x.g : result.g;\n result.b = isNaN.b ? x.b : result.b;\n result.a = isNaN.a ? x.a : result.a;\n\n return result;\n"; var exp = unaryKernelFunc({ opSnippet: EXP, packedOpSnippet: EXP_PACKED, cpuKernelImpl: expImplCPU, dtype: 'float32', }); var expConfig = { kernelName: tf.Exp, backendName: 'webgl', kernelFunc: exp }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the License); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an AS IS BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function expandDims(args) { var inputs = args.inputs, attrs = args.attrs, backend = args.backend; var dim = attrs.dim; var input = inputs.input; var inputRank = input.shape.length; var newShape = input.shape.slice(); var $dim = dim; if (dim < 0) { // Negative value is counted from the tail of rank. 
tf.util.assert(-(inputRank + 1) <= dim, function () { return "Axis must be in the interval [".concat(-(inputRank + 1), ", ").concat(inputRank, "]"); }); $dim = inputRank + dim + 1; } newShape.splice($dim, 0, 1); return reshape({ inputs: { x: input }, backend: backend, attrs: { shape: newShape } }); } var expandDimsConfig = { kernelName: tf.ExpandDims, backendName: 'webgl', kernelFunc: expandDims, }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ var EXPM1 = "return exp(x) - 1.0;"; var expm1 = unaryKernelFunc({ opSnippet: EXPM1, packedOpSnippet: EXPM1, cpuKernelImpl: expm1ImplCPU }); var expm1Config = { kernelName: tf.Expm1, backendName: 'webgl', kernelFunc: expm1 }; /** * @license * Copyright 2018 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ var FFTProgram = /** @class */ (function () { function FFTProgram(component, inputShape, inverse) { this.variableNames = ['real', 'imag']; var innerDim = inputShape[1]; this.outputShape = inputShape; var exponentMultiplierSnippet = inverse ? "2.0 * ".concat(Math.PI) : "-2.0 * ".concat(Math.PI); var resultDenominator = inverse ? "".concat(innerDim, ".0") : '1.0'; var opString; if (component === 'real') { opString = 'return real * expR - imag * expI;'; } else if (component === 'imag') { opString = 'return real * expI + imag * expR;'; } else { throw new Error("FFT component must be either \"real\" or \"imag\", got ".concat(component, ".")); } this.userCode = "\n const float exponentMultiplier = ".concat(exponentMultiplierSnippet, ";\n\n float unaryOpComplex(float real, float expR, float imag, float expI) {\n ").concat(opString, "\n }\n\n float mulMatDFT(int batch, int index) {\n float indexRatio = float(index) / float(").concat(innerDim, ");\n float exponentMultiplierTimesIndexRatio =\n exponentMultiplier * indexRatio;\n\n float result = 0.0;\n\n for (int i = 0; i < ").concat(innerDim, "; i++) {\n // x = (-2|2 * PI / N) * index * i;\n float x = exponentMultiplierTimesIndexRatio * float(i);\n float expR = cos(x);\n float expI = sin(x);\n float real = getReal(batch, i);\n float imag = getImag(batch, i);\n\n result +=\n unaryOpComplex(real, expR, imag, expI) / ").concat(resultDenominator, ";\n }\n\n return result;\n }\n\n void main() {\n ivec2 coords = getOutputCoords();\n setOutput(mulMatDFT(coords[0], coords[1]));\n }\n "); } return FFTProgram; }()); /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function fftImpl(x, inverse, backend) { var xData = backend.texData.get(x.dataId); var inputSize = tf.util.sizeFromShape(x.shape); // Collapse all outer dimensions to a single batch dimension. var innerDimensionSize = x.shape[x.shape.length - 1]; var batch = inputSize / innerDimensionSize; var input2D = reshape({ inputs: { x: x }, backend: backend, attrs: { shape: [batch, innerDimensionSize] } }); var xShape = input2D.shape; var realProgram = new FFTProgram('real', xShape, inverse); var imagProgram = new FFTProgram('imag', xShape, inverse); var inputs = [ { dataId: xData.complexTensorInfos.real.dataId, dtype: xData.complexTensorInfos.real.dtype, shape: xShape }, { dataId: xData.complexTensorInfos.imag.dataId, dtype: xData.complexTensorInfos.imag.dtype, shape: xShape } ]; var realPart = backend.runWebGLProgram(realProgram, inputs, 'float32'); var imagPart = backend.runWebGLProgram(imagProgram, inputs, 'float32'); var complexOutput = complex({ inputs: { real: realPart, imag: imagPart }, backend: backend }); backend.disposeIntermediateTensorInfo(realPart); backend.disposeIntermediateTensorInfo(imagPart); var complexOutputReshaped = reshape({ inputs: { x: complexOutput }, backend: backend, attrs: { shape: x.shape } }); backend.disposeIntermediateTensorInfo(input2D); backend.disposeIntermediateTensorInfo(complexOutput); return complexOutputReshaped; } /** * @license * Copyright 2020 Google LLC. All Rights Reserved. 
* Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function fft(args) { var inputs = args.inputs, backend = args.backend; var input = inputs.input; return fftImpl(input, false /* inverse */, backend); } var fftConfig = { kernelName: tf.FFT, backendName: 'webgl', kernelFunc: fft }; /** * @license * Copyright 2019 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ var FillProgram = /** @class */ (function () { function FillProgram(shape, value) { this.outputShape = []; this.customUniforms = [{ name: 'value', type: 'float' }]; this.variableNames = ['x']; this.outputShape = shape; this.userCode = "\n void main() {\n // Input can be obtained from uniform value.\n setOutput(value);\n }\n "; } return FillProgram; }()); /** * @license * Copyright 2020 Google LLC. All Rights Reserved. 
* Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function fill(args) { var backend = args.backend, attrs = args.attrs; var shape = attrs.shape, value = attrs.value; var dtype = attrs.dtype; dtype = dtype || tf.util.inferDtype(value); if (dtype === 'string') { // String type should be handled in CPU memory. var values = tf.util.getArrayFromDType(dtype, tf.util.sizeFromShape(shape)); values.fill(value); return backend.makeTensorInfo(shape, dtype, values); } else { var program = new FillProgram(shape, value); var customValues = [[value]]; return backend.runWebGLProgram(program, [], dtype, customValues); } } var fillConfig = { kernelName: tf.Fill, backendName: 'webgl', kernelFunc: fill }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ var FlipLeftRightProgram = /** @class */ (function () { function FlipLeftRightProgram(imageShape) { this.variableNames = ['Image']; this.outputShape = []; var imageWidth = imageShape[2]; this.outputShape = imageShape; this.userCode = "\n void main() {\n ivec4 coords = getOutputCoords();\n int x = coords[2];\n\n int coordX = ".concat(imageWidth, " - x - 1;\n float outputValue;\n if(coordX >= 0 && coordX < ").concat(imageWidth, ") {\n outputValue = getImage(coords[0], coords[1], coordX, coords[3]);\n } else {\n outputValue = getImage(coords[0], coords[1], coords[2], coords[3]);\n }\n setOutput(outputValue);\n }\n "); } return FlipLeftRightProgram; }()); /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ var flipLeftRightConfig = { kernelName: tf.FlipLeftRight, backendName: 'webgl', kernelFunc: function (_a) { var inputs = _a.inputs, backend = _a.backend; var image = inputs.image; var webglBackend = backend; var program = new FlipLeftRightProgram(image.shape); var output = webglBackend.runWebGLProgram(program, [image], image.dtype); return output; } }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ var FLOOR = "return floor(x);"; var floor = unaryKernelFunc({ opSnippet: FLOOR, packedOpSnippet: FLOOR, cpuKernelImpl: floorImplCPU }); var floorConfig = { kernelName: tf.Floor, backendName: 'webgl', kernelFunc: floor, }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ // We use native integer division to deal with floating point imprecision. Since // we implement floor division and glsl implements truncated division, we // correct for this by subtracting 1 from result when the result is negative and // there is a remainder. 
var INT_DIV = "\n float s = sign(a) * sign(b);\n int ia = round(a);\n int ib = round(b);\n if (ib != 0) {\n // Windows (D3D) wants guaranteed non-zero int division at compile-time.\n return float(idiv(ia, ib, s));\n } else {\n return NAN;\n }\n"; var INT_DIV_PACKED = "\n ivec4 ia = round(a);\n ivec4 ib = round(b);\n bvec4 cond = notEqual(ib, ivec4(0));\n ivec4 result = ivec4(0);\n vec4 s = sign(a) * sign(b);\n\n // Windows (D3D) wants guaranteed non-zero int division at compile-time.\n if (cond[0]) {\n result[0] = idiv(ia[0], ib[0], s[0]);\n }\n if (cond[1]) {\n result[1] = idiv(ia[1], ib[1], s[1]);\n }\n if (cond[2]) {\n result[2] = idiv(ia[2], ib[2], s[2]);\n }\n if (cond[3]) {\n result[3] = idiv(ia[3], ib[3], s[3]);\n }\n return vec4(result);\n"; var floorDiv = binaryKernelFunc({ opSnippet: INT_DIV, packedOpSnippet: INT_DIV_PACKED, dtype: 'int32' }); var floorDivConfig = { kernelName: tf.FloorDiv, backendName: 'webgl', kernelFunc: floorDiv }; var FromPixelsProgram = /** @class */ (function () { function FromPixelsProgram(outputShape) { this.variableNames = ['A']; var glsl = getGlslDifferences(); var _a = __read(outputShape, 2), height = _a[0], width = _a[1]; this.outputShape = outputShape; this.userCode = "\n void main() {\n ivec3 coords = getOutputCoords();\n int texR = coords[0];\n int texC = coords[1];\n int depth = coords[2];\n vec2 uv = (vec2(texC, texR) + halfCR) / vec2(".concat(width, ".0, ").concat(height, ".0);\n\n vec4 values = ").concat(glsl.texture2D, "(A, uv);\n float value;\n if (depth == 0) {\n value = values.r;\n } else if (depth == 1) {\n value = values.g;\n } else if (depth == 2) {\n value = values.b;\n } else if (depth == 3) {\n value = values.a;\n }\n\n setOutput(floor(value * 255.0 + 0.5));\n }\n "); } return FromPixelsProgram; }()); var FromPixelsPackedProgram = /** @class */ (function () { function FromPixelsPackedProgram(outputShape) { this.variableNames = ['A']; this.packedInputs = false; this.packedOutput = true; var glsl = 
getGlslDifferences(); var _a = __read(outputShape, 2), height = _a[0], width = _a[1]; this.outputShape = outputShape; this.userCode = "\n void main() {\n ivec3 coords = getOutputCoords();\n int texR = coords[0];\n int texC = coords[1];\n int depth = coords[2];\n\n vec4 result = vec4(0.);\n\n for(int row=0; row<=1; row++) {\n for(int col=0; col<=1; col++) {\n texC = coords[1] + row;\n depth = coords[2] + col;\n\n vec2 uv = (vec2(texC, texR) + halfCR) /\n vec2(".concat(width, ".0, ").concat(height, ".0);\n vec4 values = ").concat(glsl.texture2D, "(A, uv);\n float value;\n if (depth == 0) {\n value = values.r;\n } else if (depth == 1) {\n value = values.g;\n } else if (depth == 2) {\n value = values.b;\n } else if (depth == 3) {\n value = values.a;\n }\n\n result[row * 2 + col] = floor(value * 255.0 + 0.5);\n }\n }\n\n ").concat(glsl.output, " = result;\n }\n "); } return FromPixelsPackedProgram; }()); var fromPixelsConfig = { kernelName: tf.FromPixels, backendName: 'webgl', kernelFunc: fromPixels, }; var fromPixels2DContext; var willReadFrequently = tf.env().getBool('CANVAS2D_WILL_READ_FREQUENTLY_FOR_GPU'); function fromPixels(args) { var inputs = args.inputs, backend = args.backend, attrs = args.attrs; var pixels = inputs.pixels; var numChannels = attrs.numChannels; var isVideo = typeof (HTMLVideoElement) !== 'undefined' && pixels instanceof HTMLVideoElement; var isImage = typeof (HTMLImageElement) !== 'undefined' && pixels instanceof HTMLImageElement; var _a = __read(isVideo ? 
[ pixels.videoWidth, pixels.videoHeight ] : [pixels.width, pixels.height], 2), width = _a[0], height = _a[1]; var texShape = [height, width]; var outShape = [height, width, numChannels]; if (isImage || isVideo) { var newWillReadFrequently = tf.env().getBool('CANVAS2D_WILL_READ_FREQUENTLY_FOR_GPU'); if (fromPixels2DContext == null || newWillReadFrequently !== willReadFrequently) { willReadFrequently = newWillReadFrequently; fromPixels2DContext = document.createElement('canvas').getContext('2d', { willReadFrequently: willReadFrequently }); } fromPixels2DContext.canvas.width = width; fromPixels2DContext.canvas.height = height; fromPixels2DContext.drawImage(pixels, 0, 0, width, height); pixels = fromPixels2DContext.canvas; } var tempPixelHandle = backend.makeTensorInfo(texShape, 'int32'); // This is a byte texture with pixels. backend.texData.get(tempPixelHandle.dataId).usage = TextureUsage.PIXELS; backend.gpgpu.uploadPixelDataToTexture(backend.getTexture(tempPixelHandle.dataId), pixels); var program = tf.env().getBool('WEBGL_PACK') ? new FromPixelsPackedProgram(outShape) : new FromPixelsProgram(outShape); var res = backend.runWebGLProgram(program, [tempPixelHandle], 'int32'); backend.disposeData(tempPixelHandle.dataId); return res; } /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ function fusedConv2d(args) { var inputs = args.inputs, backend = args.backend, attrs = args.attrs; var x = inputs.x, filter = inputs.filter, bias = inputs.bias, preluActivationWeights = inputs.preluActivationWeights; var strides = attrs.strides, pad = attrs.pad, dataFormat = attrs.dataFormat, dilations = attrs.dilations, dimRoundingMode = attrs.dimRoundingMode, activation = attrs.activation, leakyreluAlpha = attrs.leakyreluAlpha; var $dataFormat = tf.backend_util.convertConv2DDataFormat(dataFormat); var convInfo = tf.backend_util.computeConv2DInfo(x.shape, filter.shape, strides, dilations, pad, dimRoundingMode, false /* depthwise */, $dataFormat); var out; var intermediates = []; var hasBias = bias != null; var hasPreluActivationWeights = preluActivationWeights != null; var hasLeakyreluAlpha = activation === 'leakyrelu'; var prepareInputs = function () { var inputs = [x, filter]; // If the input is a 1-D tensor, align it with the channels. // // For fusedConv2d, the inputs (x, W, bias, preluActivationWeights) are // supposed to be aligned with the dataFormat. The 4-D tensor inputs or // scalar inputs are originally aligned, but the 1-D tensor inputs are // supposed to be aligned with the channels (only bias and PReLU activation // weights could be a 1-D tensor). 
var alignInputWithDataFormat = function (input, dataFormat) { if (dataFormat === 'NCHW' && input.shape.length === 1 && input.shape[0] !== 1) { var alignedInput = reshape({ inputs: { x: input }, backend: backend, attrs: { shape: [input.shape[0], 1, 1] } }); intermediates.push(alignedInput); return alignedInput; } return input; }; if (hasBias) { inputs.push(alignInputWithDataFormat(bias, dataFormat)); } if (hasPreluActivationWeights) { inputs.push(alignInputWithDataFormat(preluActivationWeights, dataFormat)); } if (hasLeakyreluAlpha) { var $leakyreluAlpha = backend.makeTensorInfo([], 'float32', tf.util.createScalarValue(leakyreluAlpha, 'float32')); inputs.push($leakyreluAlpha); intermediates.push($leakyreluAlpha); } return inputs; }; if (convInfo.filterHeight === 1 && convInfo.filterWidth === 1 && convInfo.dilationHeight === 1 && convInfo.dilationWidth === 1 && convInfo.strideHeight === 1 && convInfo.strideWidth === 1 && (convInfo.padInfo.type === 'SAME' || convInfo.padInfo.type === 'VALID')) { out = conv2dByMatMul({ x: x, filter: filter, convInfo: convInfo, backend: backend, bias: bias, activation: activation, preluActivationWeights: preluActivationWeights, leakyreluAlpha: leakyreluAlpha }); } else if (convInfo.strideWidth <= 2 && $dataFormat === 'channelsLast' && tf.env().getBool('WEBGL_EXP_CONV')) { var fusedActivation = activation ? 
mapActivationToShaderProgram(activation, true) : null; var program = new Conv2DPackedProgram(convInfo, hasBias, fusedActivation, hasPreluActivationWeights, hasLeakyreluAlpha); var customValues = [ [convInfo.padInfo.top, convInfo.padInfo.left], [convInfo.strideHeight, convInfo.strideWidth], [convInfo.dilationHeight, convInfo.dilationWidth], [convInfo.inHeight, convInfo.inWidth] ]; var inputs_1 = prepareInputs(); out = backend.runWebGLProgram(program, inputs_1, 'float32', customValues); } else if (tf.env().getBool('WEBGL_CONV_IM2COL')) { out = conv2dWithIm2Row({ x: x, filter: filter, convInfo: convInfo, backend: backend, bias: bias, activation: activation, preluActivationWeights: preluActivationWeights, leakyreluAlpha: leakyreluAlpha }); } else { var fusedActivation = activation ? mapActivationToShaderProgram(activation, false) : null; var program = new Conv2DProgram(convInfo, hasBias, fusedActivation, hasPreluActivationWeights, hasLeakyreluAlpha); var inputs_2 = prepareInputs(); out = backend.runWebGLProgram(program, inputs_2, 'float32'); } var outReshaped = reshape({ inputs: { x: out }, backend: backend, attrs: { shape: convInfo.outShape } }); intermediates.push(out); intermediates.forEach(function (t) { return backend.disposeIntermediateTensorInfo(t); }); return outReshaped; } var fusedConv2DConfig = { kernelName: tf.FusedConv2D, backendName: 'webgl', kernelFunc: fusedConv2d, }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
* See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function fusedDepthwiseConv2D(args) { var inputs = args.inputs, backend = args.backend, attrs = args.attrs; var x = inputs.x, filter = inputs.filter, bias = inputs.bias, preluActivationWeights = inputs.preluActivationWeights; var strides = attrs.strides, pad = attrs.pad, dilations = attrs.dilations, dimRoundingMode = attrs.dimRoundingMode, activation = attrs.activation, leakyreluAlpha = attrs.leakyreluAlpha; var intermediates = []; var $dilations = dilations; if ($dilations == null) { $dilations = [1, 1]; } tf.util.assert(tf.backend_util.eitherStridesOrDilationsAreOne(strides, $dilations), function () { return 'Error in depthwiseConv2d: Either strides or dilations must be ' + "1. Got strides ".concat(strides, " and dilations '").concat($dilations, "'"); }); var convInfo = tf.backend_util.computeConv2DInfo(x.shape, filter.shape, strides, $dilations, pad, dimRoundingMode, true /* depthwise */); var shouldPackDepthwiseConv = tf.env().getBool('WEBGL_PACK_DEPTHWISECONV') && convInfo.strideWidth <= 2 && convInfo.outChannels / convInfo.inChannels === 1; var fusedActivation = activation ? 
mapActivationToShaderProgram(activation, shouldPackDepthwiseConv) : null; var programInputs = [x, filter]; var hasBias = bias != null; var hasPreluActivationWeights = preluActivationWeights != null; var hasLeakyreluAlpha = activation === 'leakyrelu'; if (hasBias) { programInputs.push(bias); } if (hasPreluActivationWeights) { programInputs.push(preluActivationWeights); } if (hasLeakyreluAlpha) { var $leakyreluAlpha = backend.makeTensorInfo([], 'float32', tf.util.createScalarValue(leakyreluAlpha, 'float32')); programInputs.push($leakyreluAlpha); intermediates.push($leakyreluAlpha); } var program; if (shouldPackDepthwiseConv) { program = new DepthwiseConvPacked2DProgram(convInfo, hasBias, fusedActivation, hasPreluActivationWeights, hasLeakyreluAlpha); } else { program = new DepthwiseConv2DProgram(convInfo, hasBias, fusedActivation, hasPreluActivationWeights, hasLeakyreluAlpha); } var customValues = [ [convInfo.padInfo.top, convInfo.padInfo.left], [convInfo.strideHeight, convInfo.strideWidth], [convInfo.dilationHeight, convInfo.dilationWidth], [convInfo.inHeight, convInfo.inWidth] ]; var result = backend.runWebGLProgram(program, programInputs, 'float32', customValues); intermediates.forEach(function (t) { return backend.disposeIntermediateTensorInfo(t); }); return result; } var fusedDepthwiseConv2DConfig = { kernelName: tf.FusedDepthwiseConv2D, backendName: 'webgl', kernelFunc: fusedDepthwiseConv2D, }; var GatherNDProgram = /** @class */ (function () { function GatherNDProgram(sliceDim, strides, shape, paramsShape) { this.sliceDim = sliceDim; this.strides = strides; this.paramsShape = paramsShape; this.variableNames = ['x', 'indices']; this.outputShape = shape; var dtype = getCoordsDataType(shape.length); var mainLoop = "\n int index;"; for (var j = 0; j < this.sliceDim; j++) { mainLoop += "\n index = round(getIndices(coords[0], ".concat(j, "));\n out_of_bounds = out_of_bounds || index < 0;\n out_of_bounds = out_of_bounds || index >= ").concat(this.paramsShape[j], ";\n 
flattenIndex += index * ").concat(this.strides[j], ";"); } this.userCode = "\n void main() {\n ".concat(dtype, " coords = getOutputCoords();\n int flattenIndex = 0;\n bool out_of_bounds = false;\n\n ").concat(mainLoop, "\n\n setOutput(out_of_bounds ? 0.0 : getX(flattenIndex, coords[1]));\n }\n "); } return GatherNDProgram; }()); function gatherNd(args) { var inputs = args.inputs, backend = args.backend; var params = inputs.params, indices = inputs.indices; var indicesShape = indices.shape; var sliceRank = indicesShape[indicesShape.length - 1]; var paramsSize = tf.util.sizeFromShape(params.shape); var _a = __read(tf.backend_util.prepareAndValidate(params, indices), 4), resultShape = _a[0], numSlices = _a[1], sliceSize = _a[2], strides = _a[3]; var flattenIndices = reshape({ inputs: { x: indices }, backend: backend, attrs: { shape: [numSlices, sliceRank] } }); var flattenX = reshape({ inputs: { x: params }, backend: backend, attrs: { shape: [(tf.util.sizeFromShape(params.shape) / sliceSize), sliceSize] } }); if (backend.shouldExecuteOnCPU([params, indices]) || params.dtype === 'string') { var indicesData = backend.readSync(indices.dataId); var paramsBuf = backend.bufferSync(params); var outValue = gatherNdImplCPU(indicesData, paramsBuf, params.dtype, numSlices, sliceRank, sliceSize, strides, params.shape, paramsSize); return backend.makeTensorInfo(resultShape, params.dtype, outValue.values); } var program = new GatherNDProgram(sliceRank, strides, [numSlices, sliceSize], params.shape); var res = backend.runWebGLProgram(program, [flattenX, flattenIndices], flattenX.dtype); var reshaped = reshape({ inputs: { x: res }, backend: backend, attrs: { shape: resultShape } }); backend.disposeIntermediateTensorInfo(flattenIndices); backend.disposeIntermediateTensorInfo(flattenX); backend.disposeIntermediateTensorInfo(res); return reshaped; } var gatherNdConfig = { kernelName: tf.GatherNd, backendName: 'webgl', kernelFunc: gatherNd }; /** * @license * Copyright 2017 Google LLC. 
All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ var GatherProgram = /** @class */ (function () { function GatherProgram(aShape, outputShape) { this.variableNames = ['A', 'indices']; this.outputShape = outputShape; this.rank = outputShape.length; var dtype = getCoordsDataType(this.rank); var sourceCoords = getSourceCoords$1(aShape); this.userCode = "\n void main() {\n ".concat(dtype, " resRC = getOutputCoords();\n int index = int(getIndices(resRC.x, resRC.z));\n float inBounds = (index >= 0) && (index < ").concat(aShape[2], ") ? 1.0 : 0.0;\n setOutput(inBounds * getA(").concat(sourceCoords, "));\n }\n "); } return GatherProgram; }()); // The input and output are always flattened into rank 4 tensors. function getSourceCoords$1(aShape, axis) { var currentCoords = ['resRC.x', 'resRC.y', 'resRC.z', 'resRC.w']; var sourceCoords = []; for (var i = 0; i < aShape.length; i++) { if (i === 2) { sourceCoords.push('index'); } else { sourceCoords.push("".concat(currentCoords[i])); } } return sourceCoords.join(); } /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function gatherV2(args) { var inputs = args.inputs, backend = args.backend, attrs = args.attrs; var x = inputs.x, indices = inputs.indices; var axis = attrs.axis, batchDims = attrs.batchDims; var parsedAxis = tf.util.parseAxisParam(axis, x.shape)[0]; if (tf.env().get('DEBUG')) { // In debug mode, throw error when any index is out of bound. // Otherwise, just fill out of bounds with zeroes. var indicesVals = backend.readSync(indices.dataId); var axisDim_1 = x.shape[parsedAxis]; var _loop_1 = function (i) { var index = indicesVals[i]; tf.util.assert(index <= axisDim_1 - 1 && index >= 0, function () { return "GatherV2: the index value ".concat(index, " is not in [0, ").concat(axisDim_1 - 1, "]"); }); }; for (var i = 0; i < indicesVals.length; ++i) { _loop_1(i); } } var shapeInfo = tf.backend_util.segment_util.collectGatherOpShapeInfo(x, indices, parsedAxis, batchDims); var indicesSize = tf.util.sizeFromShape(indices.shape); var toDispose = []; var flattenX = reshape({ inputs: { x: x }, backend: backend, attrs: { shape: [ shapeInfo.batchSize, shapeInfo.outerSize, shapeInfo.dimSize, shapeInfo.sliceSize ] } }); var flattenIndex = reshape({ inputs: { x: indices }, backend: backend, attrs: { shape: [shapeInfo.batchSize, indicesSize / shapeInfo.batchSize] } }); toDispose.push(flattenX); toDispose.push(flattenIndex); var flattenOutputShape = [ shapeInfo.batchSize, shapeInfo.outerSize, indicesSize / shapeInfo.batchSize, shapeInfo.sliceSize ]; if (backend.shouldExecuteOnCPU([x, indices]) || 
x.dtype === 'string') { var indicesBuf = backend.bufferSync(flattenIndex); var xBuf = backend.bufferSync(flattenX); var outBuf = gatherV2ImplCPU(xBuf, indicesBuf, flattenOutputShape); toDispose.forEach(function (t) { return backend.disposeIntermediateTensorInfo(t); }); return backend.makeTensorInfo(shapeInfo.outputShape, outBuf.dtype, outBuf.values); } var program = new GatherProgram(flattenX.shape, flattenOutputShape); var res = backend.runWebGLProgram(program, [flattenX, flattenIndex], flattenX.dtype); toDispose.push(res); var reshaped = reshape({ inputs: { x: res }, backend: backend, attrs: { shape: shapeInfo.outputShape } }); toDispose.forEach(function (t) { return backend.disposeIntermediateTensorInfo(t); }); return reshaped; } var gatherV2Config = { kernelName: tf.GatherV2, backendName: 'webgl', kernelFunc: gatherV2 }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ var GREATER = "return float(a > b);"; var GREATER_PACKED = "\n return vec4(greaterThan(a, b));\n"; var greater = binaryKernelFunc({ opSnippet: GREATER, packedOpSnippet: GREATER_PACKED, cpuKernelImpl: greaterImplCPU, dtype: 'bool' }); var greaterConfig = { kernelName: tf.Greater, backendName: 'webgl', kernelFunc: greater }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. 
* Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ var GREATER_EQUAL = "return float(a >= b);"; var GREATER_EQUAL_PACKED = "\n return vec4(greaterThanEqual(a, b));\n"; var greaterEqual = binaryKernelFunc({ opSnippet: GREATER_EQUAL, packedOpSnippet: GREATER_EQUAL_PACKED, dtype: 'bool', cpuKernelImpl: greaterEqualImplCPU }); var greaterEqualConfig = { kernelName: tf.GreaterEqual, backendName: 'webgl', kernelFunc: greaterEqual }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function ifft(args) { var inputs = args.inputs, backend = args.backend; var input = inputs.input; return fftImpl(input, true /* inverse */, backend); } var ifftConfig = { kernelName: tf.IFFT, backendName: 'webgl', kernelFunc: ifft }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. 
* Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ var IS_FINITE = "return float(!isnan(x) && !isinf(x));"; var isFinite = unaryKernelFunc({ opSnippet: IS_FINITE, dtype: 'bool' }); var isFiniteConfig = { kernelName: tf.IsFinite, backendName: 'webgl', kernelFunc: isFinite, }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ var IS_INF = "return float(isinf(x));"; var isInf = unaryKernelFunc({ opSnippet: IS_INF, dtype: 'bool' }); var isInfConfig = { kernelName: tf.IsInf, backendName: 'webgl', kernelFunc: isInf, }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ var IS_NAN = "return float(isnan(x));"; var isNaN = unaryKernelFunc({ opSnippet: IS_NAN, dtype: 'bool' }); var isNaNConfig = { kernelName: tf.IsNan, backendName: 'webgl', kernelFunc: isNaN, }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ var LESS = "return float(a < b);"; var LESS_PACKED = "\n return vec4(lessThan(a, b));\n"; var less = binaryKernelFunc({ opSnippet: LESS, packedOpSnippet: LESS_PACKED, cpuKernelImpl: lessImplCPU, dtype: 'bool' }); var lessConfig = { kernelName: tf.Less, backendName: 'webgl', kernelFunc: less }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ var LESS_EQUAL = "return float(a <= b);"; var LESS_EQUAL_PACKED = "\n return vec4(lessThanEqual(a, b));\n"; var lessEqual = binaryKernelFunc({ opSnippet: LESS_EQUAL, packedOpSnippet: LESS_EQUAL_PACKED, cpuKernelImpl: lessEqualImplCPU, dtype: 'bool' }); var lessEqualConfig = { kernelName: tf.LessEqual, backendName: 'webgl', kernelFunc: lessEqual }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function linSpace(args) { var backend = args.backend, attrs = args.attrs; var start = attrs.start, stop = attrs.stop, num = attrs.num; // TODO: Use CPU implementation due to the precision problem in Safari. var outVals = linSpaceImplCPU(start, stop, num); return backend.makeTensorInfo([outVals.length], 'float32', outVals); } var linSpaceConfig = { kernelName: tf.LinSpace, backendName: 'webgl', kernelFunc: linSpace }; /** * @license * Copyright 2020 Google LLC. 
All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ // Windows chrome return 0 if the input is negative value. We will specifically // return NaN if the input is 0 to solve compatiblity issue. var LOG = CHECK_NAN_SNIPPET_UNARY + "\n return x < 0.0 ? 0./0. : log(x);\n"; var LOG_PACKED = "\n vec4 result = log(x);\n bvec4 isNaN = isnan(x);\n result.r = isNaN.r ? x.r : (x.r < 0.0 ? 0./0. : result.r);\n result.g = isNaN.g ? x.g : (x.g < 0.0 ? 0./0. : result.g);\n result.b = isNaN.b ? x.b : (x.b < 0.0 ? 0./0. : result.b);\n result.a = isNaN.a ? x.a : (x.a < 0.0 ? 0./0. : result.a);\n return result;\n"; var log = unaryKernelFunc({ opSnippet: LOG, packedOpSnippet: LOG_PACKED, cpuKernelImpl: logImplCPU }); var logConfig = { kernelName: tf.Log, backendName: 'webgl', kernelFunc: log }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ var LOG1P = CHECK_NAN_SNIPPET_UNARY + "\n return log(1.0 + x);\n"; var log1p = unaryKernelFunc({ opSnippet: LOG1P }); var log1pConfig = { kernelName: tf.Log1p, backendName: 'webgl', kernelFunc: log1p, }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ var LOGICAL_AND = "return float(a >= 1.0 && b >= 1.0);"; var LOGICAL_AND_PACKED = "\n return vec4(\n vec4(greaterThanEqual(a, vec4(1.0))) *\n vec4(greaterThanEqual(b, vec4(1.0))));\n"; var logicalAnd = binaryKernelFunc({ opSnippet: LOGICAL_AND, packedOpSnippet: LOGICAL_AND_PACKED, dtype: 'bool' }); var logicalAndConfig = { kernelName: tf.LogicalAnd, backendName: 'webgl', kernelFunc: logicalAnd }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ var LOGICAL_NOT = "return float(!(x >= 1.0));"; var logicalNot = unaryKernelFunc({ opSnippet: LOGICAL_NOT }); var logicalNotConfig = { kernelName: tf.LogicalNot, backendName: 'webgl', kernelFunc: logicalNot, }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ var LOGICAL_OR = "return float(a >= 1.0 || b >= 1.0);"; var LOGICAL_OR_PACKED = "\n return min(\n vec4(greaterThanEqual(a, vec4(1.0))) +\n vec4(greaterThanEqual(b, vec4(1.0))),\n vec4(1.0));\n"; var logicalOr = binaryKernelFunc({ opSnippet: LOGICAL_OR, packedOpSnippet: LOGICAL_OR_PACKED, dtype: 'bool' }); var logicalOrConfig = { kernelName: tf.LogicalOr, backendName: 'webgl', kernelFunc: logicalOr }; /** * @license * Copyright 2017 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ var LRNProgram = /** @class */ (function () { function LRNProgram(xShape, radius, bias, alpha, beta) { this.variableNames = ['x']; this.outputShape = []; var rad = radius; var maxD = xShape[3] - 1; this.outputShape = xShape; // optimize pow(bias + alpha * sum, -beta) // src: https://github.com/tensorflow/tensorflow/.. // blob/26033a1644a9c4a5fbe3170ab2e864b6a4ccd4ca/.. // tensorflow/core/kernels/mkl_lrn_op.cc#L320 var powOperator; var basis = "float(".concat(bias, ") + float(").concat(alpha, ") * sum"); if (beta === 0.5) { powOperator = "inversesqrt(".concat(basis, ")"); } else if (beta === 1.0) { powOperator = "1.0/(".concat(basis, ")"); } else { powOperator = "exp(log(".concat(basis, ") * float(-").concat(beta, "));"); } this.userCode = "\n void main() {\n ivec4 coords = getOutputCoords();\n int b = coords[0];\n int r = coords[1];\n int c = coords[2];\n int d = coords[3];\n float x = getX(b, r, c, d);\n float sum = 0.0;\n for (int j = -".concat(rad, "; j <= ").concat(rad, "; j++) {\n int idx = d + j;\n if (idx >= 0 && idx <= ").concat(maxD, ") {\n float z = getX(b, r, c, idx);\n sum += z * z;\n }\n }\n float val = x * ").concat(powOperator, ";\n setOutput(val);\n }\n "); } return LRNProgram; }()); /** * @license * Copyright 2019 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ var LRNPackedProgram = /** @class */ (function () { function LRNPackedProgram(xShape, radius, bias, alpha, beta) { this.variableNames = ['x']; this.outputShape = []; this.packedInputs = true; this.packedOutput = true; var rad = radius; var maxD = xShape[3] - 1; this.outputShape = xShape; // optimize pow(bias + alpha * sum, -beta) // src: https://github.com/tensorflow/tensorflow/.. // blob/26033a1644a9c4a5fbe3170ab2e864b6a4ccd4ca/.. // tensorflow/core/kernels/mkl_lrn_op.cc#L320 var powOperator; var basis = "float(".concat(bias, ") + float(").concat(alpha, ") * sum"); if (beta === 0.5) { powOperator = "inversesqrt(".concat(basis, ")"); } else if (beta === 1.0) { powOperator = "1.0/(".concat(basis, ")"); } else { powOperator = "exp(log(".concat(basis, ") * float(-").concat(beta, "));"); } this.userCode = "\n void main() {\n ivec4 coords = getOutputCoords();\n int b = coords.x;\n int r = coords.y;\n int c = coords.z;\n int d = coords.w;\n\n bool hasNextCol = d < ".concat(this.outputShape[3], ";\n bool hasNextRow = c < ").concat(this.outputShape[2], ";\n\n vec4 sum = vec4(0.);\n vec4 xFragAtOutputCoords = getX(b, r, c, d);\n\n vec4 xAtOutputCoords = vec4(\n getChannel(xFragAtOutputCoords, vec2(c, d)),\n hasNextCol ?\n getChannel(xFragAtOutputCoords, vec2(c, d + 1)) : 0.0,\n hasNextRow ?\n getChannel(xFragAtOutputCoords , vec2(c + 1, d)) : 0.0,\n (hasNextRow && hasNextCol) ?\n getChannel(xFragAtOutputCoords, vec2(c + 1, d + 1)) : 0.0\n );\n\n int firstChannel = d - ").concat(rad, ";\n vec2 cache = vec2(0.);\n if(firstChannel >= 0){\n vec4 firstChannelFrag = getX(b, r, c, firstChannel);\n cache.x = getChannel(firstChannelFrag, vec2(c, firstChannel));\n if(hasNextRow){\n cache.y = getChannel(firstChannelFrag, vec2(c + 1, firstChannel));\n }\n }\n\n ivec2 depth = ivec2(d, d + 1);\n for (int j = - ").concat(rad, "; j <= ").concat(rad, "; j++) {\n ivec2 idx = depth + j;\n bvec2 aboveLowerBound 
= greaterThanEqual(idx, ivec2(0));\n bvec2 belowUpperBound = lessThanEqual(idx, ivec2(").concat(maxD, "));\n\n bool depthInRange = aboveLowerBound.x && belowUpperBound.x;\n bool depthPlusOneInRange = aboveLowerBound.y && belowUpperBound.y;\n\n if(depthInRange || depthPlusOneInRange){\n vec4 z = vec4(0.);\n vec4 xFragAtCurrentDepth;\n z.xz = cache.xy;\n if(depthPlusOneInRange && hasNextCol){\n xFragAtCurrentDepth = idx.y != d ?\n getX(b, r, c, idx.y) : xFragAtOutputCoords;\n z.y = getChannel(xFragAtCurrentDepth, vec2(c, idx.y));\n if(hasNextRow){\n z.w = getChannel(xFragAtCurrentDepth, vec2(c + 1, idx.y));\n }\n }\n cache.xy = z.yw;\n sum += z * z;\n }\n }\n vec4 result = xAtOutputCoords * ").concat(powOperator, ";\n setOutput(result);\n }\n "); } return LRNPackedProgram; }()); /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ var lrn = function (args) { var inputs = args.inputs, backend = args.backend, attrs = args.attrs; var x = inputs.x; var depthRadius = attrs.depthRadius, bias = attrs.bias, alpha = attrs.alpha, beta = attrs.beta; var program = tf.env().getBool('WEBGL_PACK_NORMALIZATION') ? 
new LRNPackedProgram(x.shape, depthRadius, bias, alpha, beta) : new LRNProgram(x.shape, depthRadius, bias, alpha, beta); return backend.runWebGLProgram(program, [x], x.dtype); }; // tslint:disable-next-line: variable-name var LRNConfig = { kernelName: tf.LRN, backendName: 'webgl', kernelFunc: lrn }; /** * @license * Copyright 2018 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ var LRNGradProgram = /** @class */ (function () { function LRNGradProgram(inputShape, depthRadius, bias, alpha, beta) { this.variableNames = ['inputImage', 'outputImage', 'dy']; this.outputShape = []; this.outputShape = inputShape; this.depth = inputShape[3]; this.depthRadius = depthRadius; this.bias = bias; this.alpha = alpha; this.beta = beta; this.userCode = "\n void main() {\n ivec4 coords = getOutputCoords();\n int b = coords[0];\n int r = coords[1];\n int c = coords[2];\n\n float result = 0.0;\n for (int d = 0; d < ".concat(this.depth, "; ++d) {\n int depthBegin = int(max(0.0, float(d - ").concat(depthRadius, ")));\n int depthEnd = int(min(float(").concat(this.depth, "),\n float(d + ").concat(depthRadius, " + 1)));\n\n const int MIN_DEPTH_BEGIN = 0;\n const int MAX_DEPTH_END = ").concat(this.depth, ";\n\n float norm = 0.0;\n for (int k = MIN_DEPTH_BEGIN; k < MAX_DEPTH_END; ++k) {\n if (k < depthBegin){\n continue;\n }\n else if (k >= depthBegin && k < depthEnd) {\n norm += getInputImage(b, 
r, c, k) * getInputImage(b, r, c, k);\n }\n else {\n break;\n }\n }\n\n norm = float(").concat(alpha, ") * norm + float(").concat(bias, ");\n\n for(int k = MIN_DEPTH_BEGIN; k < MAX_DEPTH_END; ++k){\n if (k < depthBegin){\n continue;\n }\n else if (k >= depthBegin && k < depthEnd){\n float dyi = -2.0 * float(").concat(alpha, ")\n * float(").concat(beta, ")\n * getInputImage(b, r, c, k) * getOutputImage(b, r, c, d)\n / norm;\n if (k == d) {\n dyi += pow(norm, -1.0 * ").concat(beta, ");\n }\n if (k == coords[3]) {\n dyi *= getDy(b, r, c, d);\n result += dyi;\n }\n }\n else {\n break;\n }\n }\n }\n setOutput(result);\n }\n "); } return LRNGradProgram; }()); /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ var lrnGrad = function (args) { var inputs = args.inputs, backend = args.backend, attrs = args.attrs; var x = inputs.x, y = inputs.y, dy = inputs.dy; var depthRadius = attrs.depthRadius, bias = attrs.bias, alpha = attrs.alpha, beta = attrs.beta; var program = new LRNGradProgram(x.shape, depthRadius, bias, alpha, beta); return backend.runWebGLProgram(program, [x, y, dy], x.dtype); }; // tslint:disable-next-line: variable-name var LRNGradConfig = { kernelName: tf.LRNGrad, backendName: 'webgl', kernelFunc: lrnGrad }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. 
* Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function maxImpl(x, reduceShape, outShape, backend) { var inSize = tf.util.sizeFromShape(reduceShape); var xSize = tf.util.sizeFromShape(x.shape); var batchSize = xSize / inSize; var reshapedInput = reshape({ inputs: { x: x }, attrs: { shape: [batchSize, inSize] }, backend: backend }); var reduced = reduce(reshapedInput, x.dtype, 'max', backend); var reshapedOutput = reshape({ inputs: { x: reduced }, attrs: { shape: outShape }, backend: backend }); backend.disposeIntermediateTensorInfo(reshapedInput); backend.disposeIntermediateTensorInfo(reduced); return reshapedOutput; } function max(args) { var inputs = args.inputs, backend = args.backend, attrs = args.attrs; var x = inputs.x; var reductionIndices = attrs.reductionIndices, keepDims = attrs.keepDims; var xRank = x.shape.length; var origAxes = tf.util.parseAxisParam(reductionIndices, x.shape); var axes = origAxes; var permutedAxes = tf.backend_util.getAxesPermutation(axes, xRank); var maxInputIsTransposed = permutedAxes != null; var shouldExecuteOnCPU = backend.shouldExecuteOnCPU([x]); var maxInput = x; if (maxInputIsTransposed) { if (shouldExecuteOnCPU) { var xTexData = backend.texData.get(maxInput.dataId); var values = xTexData.values; var newShape = new Array(xRank); for (var i = 0; i < newShape.length; i++) { newShape[i] = x.shape[permutedAxes[i]]; } var maxInputValues = transposeImplCPU(values, 
x.shape, x.dtype, permutedAxes, newShape); maxInput = backend.makeTensorInfo(newShape, x.dtype); var maxInputData = backend.texData.get(maxInput.dataId); maxInputData.values = maxInputValues; } else { maxInput = transposeImpl(x, permutedAxes, backend); } axes = tf.backend_util.getInnerMostAxes(axes.length, xRank); } tf.backend_util.assertAxesAreInnerMostDims('max', axes, xRank); var _a = __read(tf.backend_util.computeOutAndReduceShapes(maxInput.shape, axes), 2), maxOutShape = _a[0], reduceShape = _a[1]; var outShape = maxOutShape; if (keepDims) { // rather than reshape at the end, set the target shape here. outShape = tf.backend_util.expandShapeToKeepDim(maxOutShape, origAxes); } var out; if (shouldExecuteOnCPU) { var xTexData = backend.texData.get(maxInput.dataId); var values = xTexData.values; var outValues = maxImplCPU(values, tf.util.sizeFromShape(reduceShape), outShape, x.dtype); out = backend.makeTensorInfo(outShape, x.dtype); var outData = backend.texData.get(out.dataId); outData.values = outValues; } else { out = maxImpl(maxInput, reduceShape, outShape, backend); } if (maxInputIsTransposed) { backend.disposeIntermediateTensorInfo(maxInput); } return out; } var maxConfig = { kernelName: tf.Max, backendName: 'webgl', kernelFunc: max }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ var MAXIMUM = CHECK_NAN_SNIPPET + "\n return max(a, b);\n"; var MAXIMUM_PACKED = "\n vec4 result = vec4(max(a, b));\n bvec4 isNaNA = isnan(a);\n bvec4 isNaNB = isnan(b);\n bvec4 isNaN = bvec4(isNaNA.x || isNaNB.x, isNaNA.y || isNaNB.y, isNaNA.z || isNaNB.z, isNaNA.w || isNaNB.w);\n " + CHECK_NAN_SNIPPET_PACKED + "\n return result;\n"; var maximum = binaryKernelFunc({ opSnippet: MAXIMUM, packedOpSnippet: MAXIMUM_PACKED, cpuKernelImpl: maximumImplCPU }); var maximumConfig = { kernelName: tf.Maximum, backendName: 'webgl', kernelFunc: maximum }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function maxPool(args) { var inputs = args.inputs, backend = args.backend, attrs = args.attrs; var x = inputs.x; assertNotComplex(x, 'maxPool'); var filterSize = attrs.filterSize, strides = attrs.strides, pad = attrs.pad, dimRoundingMode = attrs.dimRoundingMode; var dilations = 1; tf.util.assert(tf.backend_util.eitherStridesOrDilationsAreOne(strides, dilations), function () { return 'Error in maxPool: Either strides or dilations must be 1. 
' + "Got strides ".concat(strides, " and dilations '").concat(dilations, "'"); }); var convInfo = tf.backend_util.computePool2DInfo(x.shape, filterSize, strides, dilations, pad, dimRoundingMode); if (convInfo.filterWidth === 1 && convInfo.filterHeight === 1 && tf.util.arraysEqual(convInfo.inShape, convInfo.outShape)) { return identity({ inputs: { x: x }, backend: backend }); } var maxPoolProgram = new Pool2DProgram(convInfo, 'max', false); return backend.runWebGLProgram(maxPoolProgram, [x], x.dtype); } var maxPoolConfig = { kernelName: tf.MaxPool, backendName: 'webgl', kernelFunc: maxPool }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function maxPool3d(args) { var inputs = args.inputs, backend = args.backend, attrs = args.attrs; var x = inputs.x; var filterSize = attrs.filterSize, strides = attrs.strides, pad = attrs.pad, dataFormat = attrs.dataFormat, dimRoundingMode = attrs.dimRoundingMode; var dilations = [1, 1, 1]; var convInfo = tf.backend_util.computePool3DInfo(x.shape, filterSize, strides, dilations, pad, dimRoundingMode, dataFormat); var maxPoolProgram = new Pool3DProgram(convInfo, 'max', false); return backend.runWebGLProgram(maxPoolProgram, [x], x.dtype); } var maxPool3DConfig = { kernelName: tf.MaxPool3D, backendName: 'webgl', kernelFunc: maxPool3d }; /** * @license * Copyright 2017 Google LLC. All Rights Reserved. 
* Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ var MaxPool2DBackpropProgram = /** @class */ (function () { function MaxPool2DBackpropProgram(convInfo) { this.variableNames = ['dy', 'maxPos']; this.outputShape = convInfo.inShape; var strideHeight = convInfo.strideHeight; var strideWidth = convInfo.strideWidth; var dilationHeight = convInfo.dilationHeight; var effectiveFilterHeight = convInfo.effectiveFilterHeight; var effectiveFilterWidth = convInfo.effectiveFilterWidth; var padTop = effectiveFilterHeight - 1 - convInfo.padInfo.top; var padLeft = effectiveFilterWidth - 1 - convInfo.padInfo.left; var lastIndex = effectiveFilterHeight * effectiveFilterWidth - 1; this.userCode = "\n const ivec2 pads = ivec2(".concat(padTop, ", ").concat(padLeft, ");\n\n void main() {\n ivec4 coords = getOutputCoords();\n int b = coords[0];\n int d = coords[3];\n\n ivec2 dyRCCorner = coords.yz - pads;\n int dyRCorner = dyRCCorner.x;\n int dyCCorner = dyRCCorner.y;\n\n // Convolve dy(?, ?, d) with pos mask(:, :, d) to get dx(xR, xC, d).\n // ? = to be determined. 
: = across all values in that axis.\n float dotProd = 0.0;\n for (int wR = 0; wR < ").concat(effectiveFilterHeight, ";\n wR += ").concat(dilationHeight, ") {\n float dyR = float(dyRCorner + wR) / ").concat(strideHeight, ".0;\n\n if (dyR < 0.0 || dyR >= ").concat(convInfo.outHeight, ".0 || fract(dyR) > 0.0) {\n continue;\n }\n int idyR = int(dyR);\n\n for (int wC = 0; wC < ").concat(effectiveFilterWidth, "; wC++) {\n float dyC = float(dyCCorner + wC) / ").concat(strideWidth, ".0;\n\n if (dyC < 0.0 || dyC >= ").concat(convInfo.outWidth, ".0 ||\n fract(dyC) > 0.0) {\n continue;\n }\n int idyC = int(dyC);\n\n float dyValue = getDy(b, idyR, idyC, d);\n int maxPosValue = ").concat(lastIndex, " - int(getMaxPos(b, idyR, idyC, d));\n\n // Get the current value, check it against the value from the\n // position matrix.\n int curPosValue = wR * ").concat(effectiveFilterWidth, " + wC;\n float mask = float(maxPosValue == curPosValue ? 1.0 : 0.0);\n\n dotProd += dyValue * mask;\n }\n }\n setOutput(dotProd);\n }\n "); } return MaxPool2DBackpropProgram; }()); var MaxPool3DBackpropProgram = /** @class */ (function () { function MaxPool3DBackpropProgram(convInfo) { this.variableNames = ['dy', 'maxPos']; this.outputShape = convInfo.inShape; var strideDepth = convInfo.strideDepth; var strideHeight = convInfo.strideHeight; var strideWidth = convInfo.strideWidth; var dilationDepth = convInfo.dilationDepth; var dilationHeight = convInfo.dilationHeight; var dilationWidth = convInfo.dilationWidth; var effectiveFilterDepth = convInfo.effectiveFilterDepth; var effectiveFilterHeight = convInfo.effectiveFilterHeight; var effectiveFilterWidth = convInfo.effectiveFilterWidth; var padFront = effectiveFilterDepth - 1 - convInfo.padInfo.front; var padTop = effectiveFilterHeight - 1 - convInfo.padInfo.top; var padLeft = effectiveFilterWidth - 1 - convInfo.padInfo.left; var lastIndex = effectiveFilterDepth * effectiveFilterHeight * effectiveFilterWidth - 1; this.userCode = "\n const ivec3 pads = 
ivec3(".concat(padFront, ", ").concat(padTop, ", ").concat(padLeft, ");\n\n void main() {\n ivec5 coords = getOutputCoords();\n int batch = coords.x;\n int ch = coords.u;\n\n ivec3 dyCorner = ivec3(coords.y, coords.z, coords.w) - pads;\n int dyDCorner = dyCorner.x;\n int dyRCorner = dyCorner.y;\n int dyCCorner = dyCorner.z;\n\n // Convolve dy(?, ?, ?, ch) with pos mask(:, :, :, d) to get\n // dx(xD, xR, xC, ch).\n // ? = to be determined. : = across all values in that axis.\n float dotProd = 0.0;\n\n for (int wD = 0; wD < ").concat(effectiveFilterDepth, ";\n wD += ").concat(dilationDepth, ") {\n float dyD = float(dyDCorner + wD) / ").concat(strideDepth, ".0;\n\n if (dyD < 0.0 || dyD >= ").concat(convInfo.outDepth, ".0 || fract(dyD) > 0.0) {\n continue;\n }\n int idyD = int(dyD);\n\n for (int wR = 0; wR < ").concat(effectiveFilterHeight, ";\n wR += ").concat(dilationHeight, ") {\n float dyR = float(dyRCorner + wR) / ").concat(strideHeight, ".0;\n\n if (dyR < 0.0 || dyR >= ").concat(convInfo.outHeight, ".0 ||\n fract(dyR) > 0.0) {\n continue;\n }\n int idyR = int(dyR);\n\n for (int wC = 0; wC < ").concat(effectiveFilterWidth, ";\n wC += ").concat(dilationWidth, ") {\n float dyC = float(dyCCorner + wC) / ").concat(strideWidth, ".0;\n\n if (dyC < 0.0 || dyC >= ").concat(convInfo.outWidth, ".0 ||\n fract(dyC) > 0.0) {\n continue;\n }\n int idyC = int(dyC);\n\n float dyValue = getDy(batch, idyD, idyR, idyC, ch);\n int maxPosValue = ").concat(lastIndex, " -\n int(getMaxPos(batch, idyD, idyR, idyC, ch));\n\n // Get the current value, check it against the value from the\n // position matrix.\n int curPosValue =\n wD * ").concat(effectiveFilterHeight, " * ").concat(effectiveFilterWidth, " +\n wR * ").concat(effectiveFilterWidth, " + wC;\n float mask = float(maxPosValue == curPosValue ? 1.0 : 0.0);\n\n dotProd += dyValue * mask;\n }\n }\n }\n setOutput(dotProd);\n }\n "); } return MaxPool3DBackpropProgram; }()); /** * @license * Copyright 2020 Google LLC. All Rights Reserved. 
* Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function maxPool3DGrad(args) { var inputs = args.inputs, backend = args.backend, attrs = args.attrs; var dy = inputs.dy, input = inputs.input; var x = input; var filterSize = attrs.filterSize, strides = attrs.strides, pad = attrs.pad, dimRoundingMode = attrs.dimRoundingMode; var dilations = [1, 1, 1]; var convInfo = tf.backend_util.computePool3DInfo(x.shape, filterSize, strides, dilations, pad, dimRoundingMode); var maxPool3dPositionsProgram = new Pool3DProgram(convInfo, 'max', true /* get positions */); var maxPool3dPositions = backend.runWebGLProgram(maxPool3dPositionsProgram, [x], x.dtype); var maxPoolBackpropProgram = new MaxPool3DBackpropProgram(convInfo); var result = backend.runWebGLProgram(maxPoolBackpropProgram, [dy, maxPool3dPositions], x.dtype); backend.disposeIntermediateTensorInfo(maxPool3dPositions); return result; } var maxPool3DGradConfig = { kernelName: tf.MaxPool3DGrad, backendName: 'webgl', kernelFunc: maxPool3DGrad }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function maxPoolGrad(args) { var inputs = args.inputs, backend = args.backend, attrs = args.attrs; var dy = inputs.dy, input = inputs.input, output = inputs.output; var x = input; assertNotComplex([input, output], 'maxPoolGrad'); var filterSize = attrs.filterSize, strides = attrs.strides, pad = attrs.pad, dimRoundingMode = attrs.dimRoundingMode; var convInfo = tf.backend_util.computePool2DInfo(x.shape, filterSize, strides, 1 /* dilations */, pad, dimRoundingMode); var getPositions = true; var maxPoolPositionsProgram = new Pool2DProgram(convInfo, 'max', getPositions); var maxPoolPositions = backend.runWebGLProgram(maxPoolPositionsProgram, [x], x.dtype); var maxPoolBackPropProgram = new MaxPool2DBackpropProgram(convInfo); var result = backend.runWebGLProgram(maxPoolBackPropProgram, [dy, maxPoolPositions], x.dtype); backend.disposeIntermediateTensorInfo(maxPoolPositions); return result; } var maxPoolGradConfig = { kernelName: tf.MaxPoolGrad, backendName: 'webgl', kernelFunc: maxPoolGrad }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
* See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function maxPoolWithArgmaxImpl(x, includeBatchInIndex, convInfo, backend) { var program = new Pool2DProgram(convInfo, 'max', false); var poolOutput = backend.runWebGLProgram(program, [x], 'float32'); program = new Pool2DProgram(convInfo, 'max', true, true, includeBatchInIndex); var indexOutput = backend.runWebGLProgram(program, [x], 'float32'); return [poolOutput, indexOutput]; } var maxPoolWithArgmaxConfig = { kernelName: tf.MaxPoolWithArgmax, backendName: 'webgl', kernelFunc: function (_a) { var inputs = _a.inputs, attrs = _a.attrs, backend = _a.backend; var x = inputs.x; var filterSize = attrs.filterSize, strides = attrs.strides, pad = attrs.pad, includeBatchInIndex = attrs.includeBatchInIndex; var webglBackend = backend; tf.util.assert(x.shape.length === 4, function () { return "Error in maxPool: input must be rank 4 but got rank ".concat(x.shape.length, "."); }); var dilations = [1, 1]; tf.util.assert(tf.backend_util.eitherStridesOrDilationsAreOne(strides, dilations), function () { return 'Error in maxPool: Either strides or dilations must be 1. ' + "Got strides ".concat(strides, " and dilations '").concat(dilations, "'"); }); var convInfo = tf.backend_util.computePool2DInfo(x.shape, filterSize, strides, dilations, pad); var _b = __read(maxPoolWithArgmaxImpl(x, includeBatchInIndex, convInfo, webglBackend), 2), result = _b[0], indexes = _b[1]; return [result, indexes]; } }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function meanImpl(x, reduceShape, outShape, backend) { var inSize = tf.util.sizeFromShape(reduceShape); var xSize = tf.util.sizeFromShape(x.shape); var batchSize = xSize / inSize; var reshapedInput = reshape({ inputs: { x: x }, attrs: { shape: [batchSize, inSize] }, backend: backend }); var reduced = reduce(reshapedInput, 'float32', 'mean', backend); var reshapedOutput = reshape({ inputs: { x: reduced }, attrs: { shape: outShape }, backend: backend }); backend.disposeIntermediateTensorInfo(reshapedInput); backend.disposeIntermediateTensorInfo(reduced); return reshapedOutput; } var meanConfig = { kernelName: tf.Mean, backendName: 'webgl', kernelFunc: function (_a) { var e_1, _b; var inputs = _a.inputs, attrs = _a.attrs, backend = _a.backend; var x = inputs.x; var keepDims = attrs.keepDims, axis = attrs.axis; var webglBackend = backend; var xRank = x.shape.length; var origAxes = tf.util.parseAxisParam(axis, x.shape); var axes = origAxes; var permutedAxes = tf.backend_util.getAxesPermutation(axes, xRank); var meanInputIsTransposed = permutedAxes != null; var shouldExecuteOnCPU = webglBackend.shouldExecuteOnCPU([x]); var intermediates = []; var meanInput = x; if (meanInputIsTransposed) { if (shouldExecuteOnCPU) { var xTexData = webglBackend.texData.get(meanInput.dataId); var values = xTexData.values; var newShape = new Array(xRank); for (var i = 0; i < newShape.length; i++) { newShape[i] = x.shape[permutedAxes[i]]; } var meanInputValues = transposeImplCPU(values, x.shape, x.dtype, 
permutedAxes, newShape); meanInput = webglBackend.makeTensorInfo(newShape, x.dtype); var meanInputData = webglBackend.texData.get(meanInput.dataId); meanInputData.values = meanInputValues; } else { meanInput = transposeImpl(x, permutedAxes, webglBackend); } intermediates.push(meanInput); axes = tf.backend_util.getInnerMostAxes(axes.length, xRank); } tf.backend_util.assertAxesAreInnerMostDims('sum', axes, xRank); var _c = __read(tf.backend_util.computeOutAndReduceShapes(meanInput.shape, axes), 2), meanOutShape = _c[0], reduceShape = _c[1]; var outShape = meanOutShape; if (keepDims) { // rather than reshape at the end, set the target shape here. outShape = tf.backend_util.expandShapeToKeepDim(meanOutShape, origAxes); } var out = meanImpl(meanInput, reduceShape, outShape, webglBackend); try { for (var intermediates_1 = __values(intermediates), intermediates_1_1 = intermediates_1.next(); !intermediates_1_1.done; intermediates_1_1 = intermediates_1.next()) { var i = intermediates_1_1.value; webglBackend.disposeIntermediateTensorInfo(i); } } catch (e_1_1) { e_1 = { error: e_1_1 }; } finally { try { if (intermediates_1_1 && !intermediates_1_1.done && (_b = intermediates_1.return)) _b.call(intermediates_1); } finally { if (e_1) throw e_1.error; } } return out; } }; function min(args) { var inputs = args.inputs, backend = args.backend, attrs = args.attrs; var x = inputs.x; var axis = attrs.axis, keepDims = attrs.keepDims; var xRank = x.shape.length; var origAxes = tf.util.parseAxisParam(axis, x.shape); var axes = origAxes; var permutedAxes = tf.backend_util.getAxesPermutation(axes, xRank); var permutedX = x; if (permutedAxes != null) { permutedX = transpose({ inputs: { x: x }, backend: backend, attrs: { perm: permutedAxes } }); axes = tf.backend_util.getInnerMostAxes(axes.length, x.shape.length); } tf.backend_util.assertAxesAreInnerMostDims('min', axes, xRank); var _a = __read(tf.backend_util.computeOutAndReduceShapes(permutedX.shape, axes), 2), outShape = _a[0], 
reduceShape = _a[1]; var inSize = tf.util.sizeFromShape(reduceShape); var a2D = reshape({ inputs: { x: permutedX }, backend: backend, attrs: { shape: [-1, inSize] } }); var reduced = reduce(a2D, a2D.dtype, 'min', backend); var res; if (keepDims) { var newShape = tf.backend_util.expandShapeToKeepDim(outShape, origAxes); res = reshape({ inputs: { x: reduced }, backend: backend, attrs: { shape: newShape } }); } else { res = reshape({ inputs: { x: reduced }, backend: backend, attrs: { shape: outShape } }); } backend.disposeIntermediateTensorInfo(a2D); backend.disposeIntermediateTensorInfo(reduced); if (permutedAxes != null) { backend.disposeIntermediateTensorInfo(permutedX); } return res; } var minConfig = { kernelName: tf.Min, backendName: 'webgl', kernelFunc: min }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */
// Unpacked GLSL body for `minimum`: the shared NaN check, then `min(a, b)`.
var MINIMUM = CHECK_NAN_SNIPPET + "\n return min(a, b);\n";
// Packed (vec4) GLSL body for `minimum`: component-wise min plus a
// per-component NaN mask consumed by the shared packed NaN snippet.
var MINIMUM_PACKED = "\n vec4 result = vec4(min(a, b));\n bvec4 isNaNA = isnan(a);\n bvec4 isNaNB = isnan(b);\n bvec4 isNaN = bvec4(isNaNA.x || isNaNB.x, isNaNA.y || isNaNB.y, isNaNA.z || isNaNB.z, isNaNA.w || isNaNB.w);\n " + CHECK_NAN_SNIPPET_PACKED + "\n return result;\n";
var minimum = binaryKernelFunc({ opSnippet: MINIMUM, packedOpSnippet: MINIMUM_PACKED, cpuKernelImpl: minimumImplCPU });
var minimumConfig = { kernelName: tf.Minimum, backendName: 'webgl', kernelFunc: minimum };
/** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */
/**
 * Unpacked WebGL program for `MirrorPad`.
 *
 * Output dim i has size `paddings[i][0] + xShape[i] + paddings[i][1]`. An
 * output coordinate before `start` maps to `2 * start - outC - offset`, and one
 * at or past `end` maps to `2 * (end - 1) - outC + offset`. The result is then
 * shifted by `-start` to index `x`. `offset` is 0 for mode 'reflect' (the edge
 * element is not repeated) and 1 for any other mode (presumably 'symmetric',
 * where the edge element is repeated).
 */
var MirrorPadProgram = /** @class */ (function () {
    /**
     * @param {number[]} xShape Shape of the input.
     * @param {Array<[number, number]>} paddings Per-axis [before, after] pad.
     * @param {string} mode 'reflect' or another mirror mode (see above).
     */
    function MirrorPadProgram(xShape, paddings, mode) {
        this.variableNames = ['x'];
        this.outputShape = paddings.map(function (p, i) { return p[0] /* beforePad */ + xShape[i] + p[1]; } /* afterPad */);
        var rank = xShape.length;
        var dtype = getCoordsDataType(rank);
        // First and one-past-last output coordinate that maps directly onto x.
        var start = paddings.map(function (p) { return p[0]; }).join(',');
        var end = paddings.map(function (p, i) { return p[0] + xShape[i]; }).join(',');
        // NOTE(review): only four coordinate names are listed, so a rank above
        // 4 would not index every axis; confirm the op's rank limit.
        var unpackedCoords = ['coords[0]', 'coords[1]', 'coords[2]', 'coords[3]'].slice(0, rank);
        var offset = mode === 'reflect' ? 0 : 1;
        if (rank === 1) {
            // Scalar-int variant: coordinates are a single `int`.
            this.userCode = "\n int start = ".concat(start, ";\n int end = ").concat(end, ";\n\n void main() {\n int outC = getOutputCoords();\n if (outC < start) {\n outC = start * 2 - outC - ").concat(offset, ";\n } else if(outC >= end) {\n outC = (end - 1) * 2 - outC + ").concat(offset, ";\n }\n setOutput(getX(outC - start));\n }\n ");
            return;
        }
        // Vector variant: reflect each axis independently in a loop.
        this.userCode = "\n ".concat(dtype, " start = ").concat(dtype, "(").concat(start, ");\n ").concat(dtype, " end = ").concat(dtype, "(").concat(end, ");\n\n void main() {\n ").concat(dtype, " outC = getOutputCoords();\n for (int i = 0; i < ").concat(rank, "; i++) {\n if (outC[i] < start[i]) {\n outC[i] = start[i] * 2 - outC[i] - ").concat(offset, ";\n } else if(outC[i] >= end[i]) {\n outC[i] = (end[i] - 1) * 2 - outC[i] + ").concat(offset, ";\n }\n }\n ").concat(dtype, " coords = outC - start;\n setOutput(getX(").concat(unpackedCoords, "));\n }\n ");
    }
    return MirrorPadProgram;
}());
/** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License.
* ============================================================================= */
/**
 * Example shader code for
 * `mirrorPad(tf.tensor1d([1, 2, 3], 'int32'), [[2, 2]], 'reflect')`
 * ```
 * const int start = int(2);
 * const int end = int(5);
 *
 * void main() {
 *    int outputLoc = getOutputCoords();
 *    vec4 result = vec4(0.);
 *
 *    int rc = outputLoc;
 *
 *    int source = rc;
 *    if (source < start) {
 *      source = start * 2 - source - 0;
 *    } else if (source >= end) {
 *      source = (end - 1) * 2 - source + 0;
 *    }
 *    source -= start;
 *
 *    result[0] = getChannel(getX(source), source);
 *    rc += 1;
 *    if(rc < 7) {
 *      int source = rc;
 *      if (source < start) {
 *        source = start * 2 - source - 0;
 *      } else if (source >= end) {
 *        source = (end - 1) * 2 - source + 0;
 *      }
 *      source -= start;
 *
 *      result[1] = getChannel(getX(source), source);
 *    }
 *
 *    setOutput(result);
 *  }
 * ```
 * (The bound is the output length, 2 + 3 + 2 = 7.)
 */
/**
 * Packed WebGL program for `MirrorPad`.
 *
 * Each output texel holds up to four values: the last axis at `rc` and
 * `rc + 1`, and (for rank > 1) the second-to-last axis at `rc` and `rc + 1`.
 * Each value is guarded by a bound check against the output shape. For rank
 * > 1 the per-axis reflection is done branch-free: `lt`/`gte` are 0/1 masks
 * selecting between the unchanged, before-pad and after-pad formulas.
 */
var MirrorPadPackedProgram = /** @class */ (function () {
    /**
     * @param {number[]} xShape Shape of the input.
     * @param {Array<[number, number]>} paddings Per-axis [before, after] pad.
     * @param {string} mode 'reflect' gives offset 0; any other mode gives 1.
     */
    function MirrorPadPackedProgram(xShape, paddings, mode) {
        this.variableNames = ['x'];
        this.packedInputs = true;
        this.packedOutput = true;
        this.outputShape = paddings.map(function (p, i) { return p[0] /* beforePad */ + xShape[i] + p[1]; } /* afterPad */);
        var rank = xShape.length;
        var dtype = getCoordsDataType(rank);
        var start = paddings.map(function (p) { return p[0]; }).join(',');
        var end = paddings.map(function (p, i) { return p[0] + xShape[i]; }).join(',');
        var coords = getChannels('rc', rank);
        var source = getChannels('source', rank);
        // Bound check for the neighbouring element along the last axis.
        var cLimit = "".concat(coords[rank - 1], " < ").concat(this.outputShape[rank - 1]);
        // Coordinates passed to getChannel to pick a value out of a packed texel.
        var innerDims = rank === 1 ? 'source' : "vec2(".concat(source.slice(-2).join(), ")");
        var offset = mode === 'reflect' ? 0 : 1;
        var mainLoop = '';
        if (rank === 1) {
            // GLSL snippet mapping `rc` to the mirrored source index in x.
            var padSetup = "\n ".concat(dtype, " source = rc;\n if (source < start) {\n source = start * 2 - source - ").concat(offset, ";\n } else if (source >= end) {\n source = (end - 1) * 2 - source + ").concat(offset, ";\n }\n source -= start;\n ");
            mainLoop = "\n ".concat(dtype, " rc = outputLoc;\n ").concat(padSetup, "\n result[0] = getChannel(getX(").concat(source.join(), "), ").concat(innerDims, ");\n ").concat(coords[rank - 1], " += 1;\n if(").concat(cLimit, ") {\n ").concat(padSetup, "\n result[1] = getChannel(getX(").concat(source.join(), "), ").concat(innerDims, ");\n }\n ");
        }
        else {
            // Branch-free vector form of the same mapping, applied to every axis.
            var padSetup = "\n ".concat(dtype, " source = rc;\n ").concat(dtype, " lt = ").concat(dtype, "(lessThan(source, start));\n ").concat(dtype, " gte = ").concat(dtype, "(greaterThanEqual(source, end));\n ").concat(dtype, " orig = 1 - (lt + gte);\n source = orig * source +\n lt * (start * 2 - source - ").concat(offset, ") +\n gte * ((end - 1) * 2 - source + ").concat(offset, ");\n source -= start;\n ");
            mainLoop = "\n ".concat(dtype, " rc = outputLoc;\n ").concat(padSetup, "\n result[0] = getChannel(getX(").concat(source.join(), "), ").concat(innerDims, ");\n ").concat(coords[rank - 1], " += 1;\n if(").concat(cLimit, ") {\n ").concat(padSetup, "\n result[1] = getChannel(getX(").concat(source.join(), "), ").concat(innerDims, ");\n }\n rc = outputLoc;\n ").concat(coords[rank - 2], " += 1;\n if(").concat(coords[rank - 2], " < ").concat(this.outputShape[rank - 2], ") {\n ").concat(padSetup, "\n result[2] = getChannel(getX(").concat(source.join(), "), ").concat(innerDims, ");\n ").concat(coords[rank - 1], " += 1;\n if(").concat(cLimit, ") {\n ").concat(padSetup, "\n result[3] = getChannel(getX(").concat(source.join(), "), ").concat(innerDims, ");\n }\n }\n ");
        }
        this.userCode = "\n const ".concat(dtype, " start = ").concat(dtype, "(").concat(start, ");\n const ").concat(dtype, " end = ").concat(dtype, "(").concat(end, ");\n\n void main() {\n ").concat(dtype, " outputLoc = getOutputCoords();\n vec4 result = vec4(0.);\n ").concat(mainLoop, "\n setOutput(result);\n }\n ");
    }
    return MirrorPadPackedProgram;
}());
/** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */
/**
 * WebGL kernel for `MirrorPad`. Uses the packed program when the
 * `WEBGL_PACK_ARRAY_OPERATIONS` flag is on, otherwise the unpacked one.
 *
 * @param _a Kernel arguments: `inputs.x`, `backend`, and `attrs`
 *     (`paddings`, `mode`).
 * @returns The padded tensor info, with `x`'s dtype.
 */
var mirrorPadKernelFunc = function (_a) {
    var inputs = _a.inputs, backend = _a.backend, attrs = _a.attrs;
    var x = inputs.x;
    var paddings = attrs.paddings, mode = attrs.mode;
    var program = tf.env().getBool('WEBGL_PACK_ARRAY_OPERATIONS') ?
        new MirrorPadPackedProgram(x.shape, paddings, mode) :
        new MirrorPadProgram(x.shape, paddings, mode);
    var output = backend.runWebGLProgram(program, [x], x.dtype);
    return output;
};
var mirrorPadConfig = {
    kernelName: tf.MirrorPad,
    backendName: 'webgl',
    kernelFunc: mirrorPadKernelFunc,
};
/** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License.
* ============================================================================= */ var MOD = "if (b == 0.0) return NAN;\n return mod(a, b);"; var MOD_PACKED = "\n vec4 result = mod(a, b);\n bvec4 isNaN = equal(b, vec4(0.0));\n " + CHECK_NAN_SNIPPET_PACKED + "\n return result;\n"; var mod = binaryKernelFunc({ opSnippet: MOD, packedOpSnippet: MOD_PACKED, }); var modConfig = { kernelName: tf.Mod, backendName: 'webgl', kernelFunc: mod }; /** * @license * Copyright 2017 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ var MultinomialProgram = /** @class */ (function () { function MultinomialProgram(batchSize, numOutcomes, numSamples) { this.variableNames = ['probs']; this.customUniforms = [{ name: 'seed', type: 'float' }]; this.outputShape = [batchSize, numSamples]; this.userCode = "\n void main() {\n ivec2 coords = getOutputCoords();\n int batch = coords[0];\n\n float r = random(seed);\n float cdf = 0.0;\n\n for (int i = 0; i < ".concat(numOutcomes - 1, "; i++) {\n cdf += getProbs(batch, i);\n\n if (r < cdf) {\n setOutput(float(i));\n return;\n }\n }\n\n // If no other event happened, last event happened.\n setOutput(float(").concat(numOutcomes - 1, "));\n }\n "); } return MultinomialProgram; }()); /** * @license * Copyright 2020 Google LLC. All Rights Reserved. 
* Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ // Without the equality check div produces 0.9999 for a = b, which when // floored can cause errors. var DIV = "\nif (a == b) {\n return 1.0;\n};\nreturn a / b;"; // We do the same as in ./binaryop_gpu, with vec4 and ivec4. // On Linux, the vectorized implementation produces NaNs when a and b are 0. var DIV_PACKED = "\n // vec4 one = vec4(equal(a, b));\n // return one + (vec4(1.0) - one) * a / b;\n vec4 result = a / b;\n if(a.x == b.x) {\n result.x = 1.;\n }\n if(a.y == b.y) {\n result.y = 1.;\n }\n if(a.z == b.z) {\n result.z = 1.;\n }\n if(a.w == b.w) {\n result.w = 1.;\n }\n\n return result;\n"; var realDiv = binaryKernelFunc({ opSnippet: DIV, packedOpSnippet: DIV_PACKED, checkOutOfBounds: true }); var realDivConfig = { kernelName: tf.RealDiv, backendName: 'webgl', kernelFunc: realDiv, }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
* See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ var SUB = 'return a - b;'; var sub = binaryKernelFunc({ opSnippet: SUB, packedOpSnippet: SUB, supportsComplex: true, cpuKernelImpl: subImplCPU }); var subConfig = { kernelName: tf.Sub, backendName: 'webgl', kernelFunc: sub }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ function softmax(args) { var inputs = args.inputs, backend = args.backend, attrs = args.attrs; var logits = inputs.logits; var dim = attrs.dim; var axes = tf.util.parseAxisParam([dim], logits.shape); var maxLogit = max({ inputs: { x: logits }, backend: backend, attrs: { reductionIndices: axes, keepDims: false } }); var expandedShape = tf.backend_util.expandShapeToKeepDim(maxLogit.shape, axes); var maxLogitsReshaped = reshape({ inputs: { x: maxLogit }, backend: backend, attrs: { shape: expandedShape } }); var a = sub({ inputs: { a: logits, b: maxLogitsReshaped }, backend: backend }); var b = exp({ inputs: { x: a }, backend: backend }); var sumExp = sum({ inputs: { x: b }, backend: backend, attrs: { axis: axes, keepDims: false } }); var sumExpReshaped = reshape({ inputs: { x: sumExp }, backend: backend, attrs: { shape: expandedShape } }); var res = realDiv({ inputs: { a: b, b: sumExpReshaped }, backend: backend }); backend.disposeIntermediateTensorInfo(maxLogit); backend.disposeIntermediateTensorInfo(maxLogitsReshaped); backend.disposeIntermediateTensorInfo(a); backend.disposeIntermediateTensorInfo(b); backend.disposeIntermediateTensorInfo(sumExp); backend.disposeIntermediateTensorInfo(sumExpReshaped); return res; } var softmaxConfig = { kernelName: tf.Softmax, backendName: 'webgl', kernelFunc: softmax }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
* See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function multinomial(args) { var inputs = args.inputs, backend = args.backend, attrs = args.attrs; var logits = inputs.logits; var numSamples = attrs.numSamples, seed = attrs.seed, normalized = attrs.normalized; var probs = normalized ? logits : softmax({ inputs: { logits: logits }, backend: backend, attrs: { dim: logits.shape.length - 1 } }); var batchSize = probs.shape[0]; var numOutcomes = probs.shape[1]; var program = new MultinomialProgram(batchSize, numOutcomes, numSamples); var customValues = [[seed]]; var res = backend.runWebGLProgram(program, [probs], 'int32', customValues); if (!normalized) { backend.disposeIntermediateTensorInfo(probs); } return res; } var multinomialConfig = { kernelName: tf.Multinomial, backendName: 'webgl', kernelFunc: multinomial }; var NEG = CHECK_NAN_SNIPPET$1 + "\n return -x;\n"; var NEG_PACKED = "\n vec4 result = -x;\n bvec4 isNaN = isnan(x);\n\n result.r = isNaN.r ? x.r : result.r;\n result.g = isNaN.g ? x.g : result.g;\n result.b = isNaN.b ? x.b : result.b;\n result.a = isNaN.a ? x.a : result.a;\n\n return result;\n"; // This doesn't use unaryKernelFunc because negImplCPU is not of type // SimpleUnaryKernelImplCPU. 
function neg(args) { var inputs = args.inputs, backend = args.backend; var x = inputs.x; if (backend.shouldExecuteOnCPU([x])) { var xData = backend.texData.get(x.dataId); var _a = __read(negImplCPU(xData.values, x.shape, x.dtype), 2), outValues = _a[0], newShape = _a[1]; return backend.makeTensorInfo(newShape, x.dtype, outValues); } var program; if (tf.env().getBool('WEBGL_PACK_UNARY_OPERATIONS')) { program = new UnaryOpPackedProgram(x.shape, NEG_PACKED); } else { program = new UnaryOpProgram(x.shape, NEG); } return backend.runWebGLProgram(program, [x], x.dtype); } var negConfig = { kernelName: tf.Neg, backendName: 'webgl', kernelFunc: neg }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ var nonMaxSuppressionV3Impl = tf.kernel_impls.nonMaxSuppressionV3Impl; function nonMaxSuppressionV3(args) { tf.backend_util.warn('tf.nonMaxSuppression() in webgl locks the UI thread. 
' + 'Call tf.nonMaxSuppressionAsync() instead'); var inputs = args.inputs, backend = args.backend, attrs = args.attrs; var boxes = inputs.boxes, scores = inputs.scores; var maxOutputSize = attrs.maxOutputSize, iouThreshold = attrs.iouThreshold, scoreThreshold = attrs.scoreThreshold; var boxesVals = backend.readSync(boxes.dataId); var scoresVals = backend.readSync(scores.dataId); var selectedIndices = nonMaxSuppressionV3Impl(boxesVals, scoresVals, maxOutputSize, iouThreshold, scoreThreshold).selectedIndices; return backend.makeTensorInfo([selectedIndices.length], 'int32', new Int32Array(selectedIndices)); } var nonMaxSuppressionV3Config = { kernelName: tf.NonMaxSuppressionV3, backendName: 'webgl', kernelFunc: nonMaxSuppressionV3 }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ var nonMaxSuppressionV4Impl = tf.kernel_impls.nonMaxSuppressionV4Impl; function nonMaxSuppressionV4(args) { tf.backend_util.warn('tf.nonMaxSuppression() in webgl locks the UI thread. 
' + 'Call tf.nonMaxSuppressionAsync() instead'); var inputs = args.inputs, backend = args.backend, attrs = args.attrs; var boxes = inputs.boxes, scores = inputs.scores; var maxOutputSize = attrs.maxOutputSize, iouThreshold = attrs.iouThreshold, scoreThreshold = attrs.scoreThreshold, padToMaxOutputSize = attrs.padToMaxOutputSize; var boxesVals = backend.readSync(boxes.dataId); var scoresVals = backend.readSync(scores.dataId); var _a = nonMaxSuppressionV4Impl(boxesVals, scoresVals, maxOutputSize, iouThreshold, scoreThreshold, padToMaxOutputSize), selectedIndices = _a.selectedIndices, validOutputs = _a.validOutputs; return [ backend.makeTensorInfo([selectedIndices.length], 'int32', new Int32Array(selectedIndices)), backend.makeTensorInfo([], 'int32', new Int32Array([validOutputs])) ]; } var nonMaxSuppressionV4Config = { kernelName: tf.NonMaxSuppressionV4, backendName: 'webgl', kernelFunc: nonMaxSuppressionV4 }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ var nonMaxSuppressionV5Impl = tf.kernel_impls.nonMaxSuppressionV5Impl; function nonMaxSuppressionV5(args) { tf.backend_util.warn('tf.nonMaxSuppression() in webgl locks the UI thread. 
' + 'Call tf.nonMaxSuppressionAsync() instead'); var inputs = args.inputs, backend = args.backend, attrs = args.attrs; var boxes = inputs.boxes, scores = inputs.scores; var maxOutputSize = attrs.maxOutputSize, iouThreshold = attrs.iouThreshold, scoreThreshold = attrs.scoreThreshold, softNmsSigma = attrs.softNmsSigma; var boxesVals = backend.readSync(boxes.dataId); var scoresVals = backend.readSync(scores.dataId); var maxOutputSizeVal = maxOutputSize; var iouThresholdVal = iouThreshold; var scoreThresholdVal = scoreThreshold; var softNmsSigmaVal = softNmsSigma; var _a = nonMaxSuppressionV5Impl(boxesVals, scoresVals, maxOutputSizeVal, iouThresholdVal, scoreThresholdVal, softNmsSigmaVal), selectedIndices = _a.selectedIndices, selectedScores = _a.selectedScores; return [ backend.makeTensorInfo([selectedIndices.length], 'int32', new Int32Array(selectedIndices)), backend.makeTensorInfo([selectedScores.length], 'float32', new Float32Array(selectedScores)) ]; } var nonMaxSuppressionV5Config = { kernelName: tf.NonMaxSuppressionV5, backendName: 'webgl', kernelFunc: nonMaxSuppressionV5 }; /** * @license * Copyright 2017 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ var OneHotProgram = /** @class */ (function () { function OneHotProgram(numIndices, depth, onValue, offValue) { this.variableNames = ['indices']; this.outputShape = [numIndices, depth]; this.userCode = "\n void main() {\n ivec2 coords = getOutputCoords();\n int index = round(getIndices(coords.x));\n setOutput(mix(float(".concat(offValue, "), float(").concat(onValue, "),\n float(index == coords.y)));\n }\n "); } return OneHotProgram; }()); var oneHot = function (args) { var inputs = args.inputs, backend = args.backend, attrs = args.attrs; var indices = inputs.indices; var dtype = attrs.dtype, depth = attrs.depth, onValue = attrs.onValue, offValue = attrs.offValue; var indicesSize = tf.util.sizeFromShape(indices.shape); var program = new OneHotProgram(indicesSize, depth, onValue, offValue); var reshaped = reshape({ inputs: { x: indices }, backend: backend, attrs: { shape: [indicesSize] } }); var result = backend.runWebGLProgram(program, [reshaped], dtype); backend.disposeIntermediateTensorInfo(reshaped); var outShape = __spreadArray(__spreadArray([], __read(indices.shape), false), [depth], false); var out = reshape({ inputs: { x: result }, backend: backend, attrs: { shape: outShape } }); backend.disposeIntermediateTensorInfo(result); return out; }; var oneHotConfig = { kernelName: tf.OneHot, backendName: 'webgl', kernelFunc: oneHot }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
* See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function zerosLike(args) { var inputs = args.inputs, backend = args.backend; var x = inputs.x; if (x.dtype === 'complex64') { var realPart = real({ inputs: { input: x }, backend: backend }); var r = zerosLike({ inputs: { x: realPart }, backend: backend }); var imagPart = imag({ inputs: { input: x }, backend: backend }); var i = zerosLike({ inputs: { x: imagPart }, backend: backend }); var result = complex({ inputs: { real: r, imag: i }, backend: backend }); backend.disposeIntermediateTensorInfo(realPart); backend.disposeIntermediateTensorInfo(r); backend.disposeIntermediateTensorInfo(imagPart); backend.disposeIntermediateTensorInfo(i); return result; } else { return fill({ attrs: { shape: x.shape, dtype: x.dtype, value: x.dtype === 'string' ? '' : 0 }, backend: backend }); } } var zerosLikeConfig = { kernelName: tf.ZerosLike, backendName: 'webgl', kernelFunc: zerosLike }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ function onesLike(args) { var inputs = args.inputs, backend = args.backend; var x = inputs.x; if (x.dtype === 'string') { throw new Error('onesLike is not supported under string dtype'); } else if (x.dtype === 'complex64') { var realPart = real({ inputs: { input: x }, backend: backend }); var r = onesLike({ inputs: { x: realPart }, backend: backend }); var imagPart = imag({ inputs: { input: x }, backend: backend }); var i = zerosLike({ inputs: { x: imagPart }, backend: backend }); var result = complex({ inputs: { real: r, imag: i }, backend: backend }); backend.disposeIntermediateTensorInfo(realPart); backend.disposeIntermediateTensorInfo(r); backend.disposeIntermediateTensorInfo(imagPart); backend.disposeIntermediateTensorInfo(i); return result; } else { // TODO(cais, smilkov): Add WebGL shader for onesLike: // https://github.com/tensorflow/tfjs/issues/1293 return fill({ attrs: { shape: x.shape, dtype: x.dtype, value: 1 }, backend: backend }); } } var onesLikeConfig = { kernelName: tf.OnesLike, backendName: 'webgl', kernelFunc: onesLike }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ function pack(args) { var inputs = args.inputs, backend = args.backend, attrs = args.attrs; var axis = attrs.axis; if (inputs.length === 1) { return expandDims({ inputs: { input: inputs[0] }, backend: backend, attrs: { dim: axis } }); } var shape = inputs[0].shape; var dtype = inputs[0].dtype; inputs.forEach(function (t) { tf.util.assertShapesMatch(shape, t.shape, 'All tensors passed to stack must have matching shapes'); tf.util.assert(dtype === t.dtype, function () { return 'All tensors passed to stack must have matching dtypes'; }); }); var intermediateTensorInfos = []; var expandedTensors = inputs.map(function (t) { var expandedT = expandDims({ inputs: { input: t }, backend: backend, attrs: { dim: axis } }); intermediateTensorInfos.push(expandedT); return expandedT; }); var result = concat({ inputs: expandedTensors, backend: backend, attrs: { axis: axis } }); intermediateTensorInfos.forEach(function (t) { return backend.disposeIntermediateTensorInfo(t); }); return result; } var packConfig = { kernelName: tf.Pack, backendName: 'webgl', kernelFunc: pack }; /** * @license * Copyright 2017 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ var PadProgram = /** @class */ (function () { function PadProgram(xShape, paddings, constantValue) { this.variableNames = ['x']; this.customUniforms = [{ name: 'value', type: 'float' }]; this.outputShape = paddings.map(function (p, i) { return p[0] /* beforePad */ + xShape[i] + p[1]; } /* afterPad */); var rank = xShape.length; var type = getCoordsDataType(rank); var start = paddings.map(function (p) { return p[0]; }).join(','); var end = paddings.map(function (p, i) { return p[0] + xShape[i]; }).join(','); var unpackedCoords = ['coords[0]', 'coords[1]', 'coords[2]', 'coords[3]'].slice(0, rank); if (rank === 1) { this.userCode = "\n int start = ".concat(start, ";\n int end = ").concat(end, ";\n\n void main() {\n int outC = getOutputCoords();\n if (outC < start || outC >= end) {\n setOutput(value);\n } else {\n setOutput(getX(outC - start));\n }\n }\n "); return; } this.userCode = "\n ".concat(type, " start = ").concat(type, "(").concat(start, ");\n ").concat(type, " end = ").concat(type, "(").concat(end, ");\n\n void main() {\n ").concat(type, " outC = getOutputCoords();\n if (any(lessThan(outC, start)) || any(greaterThanEqual(outC, end))) {\n setOutput(value);\n } else {\n ").concat(type, " coords = outC - start;\n setOutput(getX(").concat(unpackedCoords, "));\n }\n }\n "); } return PadProgram; }()); /** * @license * Copyright 2019 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
* See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ var PadPackedProgram = /** @class */ (function () { function PadPackedProgram(xShape, paddings, constantValue) { this.variableNames = ['x']; this.packedInputs = true; this.packedOutput = true; this.customUniforms = [{ name: 'value', type: 'float' }]; this.outputShape = paddings.map(function (p, i) { return p[0] /* beforePad */ + xShape[i] + p[1]; } /* afterPad */); var rank = xShape.length; var dtype = getCoordsDataType(rank); var start = paddings.map(function (p) { return p[0]; }).join(','); var end = paddings.map(function (p, i) { return p[0] + xShape[i]; }).join(','); var coords = getChannels('rc', rank); var source = getChannels('source', rank); var cLimit = "".concat(coords[rank - 1], " < ").concat(this.outputShape[rank - 1]); var innerDims = rank === 1 ? 'source' : "vec2(".concat(source.slice(-2).join(), ")"); var componentSetup = [ "".concat(dtype, " rc = outputLoc;"), "".concat(coords[rank - 1], " += 1;\n if(").concat(cLimit, ") {\n "), rank === 1 ? '' : "}\n rc = outputLoc;\n ".concat(coords[rank - 2], " += 1;\n if(").concat(coords[rank - 2], " < ").concat(this.outputShape[rank - 2], ") {"), rank === 1 ? '' : " ".concat(coords[rank - 1], " += 1;\n if(").concat(cLimit, ") {") ]; var paddingArea = rank === 1 ? 'rc < start || rc >= end' : 'any(lessThan(rc, start)) || any(greaterThanEqual(rc, end))'; var mainLoop = ''; for (var i = 0, j = rank === 1 ? 2 : 4; i < j; i++) { mainLoop += "\n ".concat(componentSetup[i], "\n if (").concat(paddingArea, ") {\n result[").concat(i, "] = float(value);\n } else {\n ").concat(dtype, " source = rc - start;\n result[").concat(i, "] = getChannel(getX(").concat(source.join(), "), ").concat(innerDims, ");\n }\n "); } mainLoop += (rank === 1 ? 
"} " : "}}"); this.userCode = "\n const ".concat(dtype, " start = ").concat(dtype, "(").concat(start, ");\n const ").concat(dtype, " end = ").concat(dtype, "(").concat(end, ");\n\n void main() {\n ").concat(dtype, " outputLoc = getOutputCoords();\n vec4 result = vec4(0.);\n ").concat(mainLoop, "\n setOutput(result);\n }\n "); } return PadPackedProgram; }()); /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ var padV2 = function (args) { var inputs = args.inputs, backend = args.backend, attrs = args.attrs; var x = inputs.x; var paddings = attrs.paddings, constantValue = attrs.constantValue; if (tf.util.sizeFromShape(x.shape) === 0) { // Short-circuit the computation, since x doesn't have value, only // the shape is used to compute output shape to pad. var outputShape = paddings.map(function (p, i) { return p[0] /* beforePad */ + x.shape[i] + p[1]; } /* afterPad */); return fill({ backend: backend, attrs: { shape: outputShape, value: constantValue, dtype: x.dtype } }); } var program = tf.env().getBool('WEBGL_PACK_ARRAY_OPERATIONS') ? 
new PadPackedProgram(x.shape, paddings, constantValue) : new PadProgram(x.shape, paddings, constantValue); var customValues = [[constantValue]]; return backend.runWebGLProgram(program, [x], x.dtype, customValues); }; var padV2Config = { kernelName: tf.PadV2, backendName: 'webgl', kernelFunc: padV2 }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ var POW = "\n if(a < 0.0 && floor(b) < b){\n return NAN;\n }\n if (b == 0.0) {\n return 1.0;\n }\n return (round(mod(b, 2.0)) != 1) ?\n pow(abs(a), b) : sign(a) * pow(abs(a), b);\n"; var POW_PACKED = "\n // isModRound1 has 1 for components with round(mod(b, 2.0)) == 1, 0 otherwise.\n vec4 isModRound1 = vec4(equal(round(mod(b, 2.0)), ivec4(1)));\n vec4 multiplier = sign(a) * isModRound1 + (vec4(1.0) - isModRound1);\n vec4 result = multiplier * pow(abs(a), b);\n\n // Ensure that a^0 = 1, including 0^0 = 1 as this correspond to TF and JS\n bvec4 isExpZero = equal(b, vec4(0.0));\n result.r = isExpZero.r ? 1.0 : result.r;\n result.g = isExpZero.g ? 1.0 : result.g;\n result.b = isExpZero.b ? 1.0 : result.b;\n result.a = isExpZero.a ? 
1.0 : result.a;\n\n bvec4 isNaN1 = lessThan(a, vec4(0.0));\n bvec4 isNaN2 = lessThan(floor(b), b);\n bvec4 isNaN = bvec4(isNaN1.x && isNaN2.x, isNaN1.y && isNaN2.y, isNaN1.z && isNaN2.z, isNaN1.w && isNaN2.w);\n " + CHECK_NAN_SNIPPET_PACKED + "\n return result;\n"; var pow = binaryKernelFunc({ opSnippet: POW, packedOpSnippet: POW_PACKED }); var powConfig = { kernelName: tf.Pow, backendName: 'webgl', kernelFunc: pow }; function prod(args) { var inputs = args.inputs, backend = args.backend, attrs = args.attrs; var x = inputs.x; var axis = attrs.axis, keepDims = attrs.keepDims; var xRank = x.shape.length; var toDispose = []; var origAxes = tf.util.parseAxisParam(axis, x.shape); var axes = origAxes; var permutedAxes = tf.backend_util.getAxesPermutation(axes, xRank); var permutedX = x; if (permutedAxes != null) { permutedX = transpose({ inputs: { x: x }, backend: backend, attrs: { perm: permutedAxes } }); axes = tf.backend_util.getInnerMostAxes(axes.length, xRank); toDispose.push(permutedX); } tf.backend_util.assertAxesAreInnerMostDims('prod', axes, xRank); var res; if (backend.shouldExecuteOnCPU([permutedX])) { var xVals = backend.texData.get(permutedX.dataId).values; var _a = prodImplCPU(permutedX.shape, permutedX.dtype, xVals, axes), outVals = _a.outVals, outShape = _a.outShape, outDtype = _a.outDtype; res = backend.makeTensorInfo(outShape, outDtype, outVals); } else { var _b = __read(tf.backend_util.computeOutAndReduceShapes(permutedX.shape, axes), 2), outShape = _b[0], reduceShape = _b[1]; var inSize = tf.util.sizeFromShape(reduceShape); var a2D = reshape({ inputs: { x: permutedX }, backend: backend, attrs: { shape: [-1, inSize] } }); var outputDType = tf.sumOutType(x.dtype); var reduced = reduce(a2D, outputDType, 'prod', backend); res = reshape({ inputs: { x: reduced }, backend: backend, attrs: { shape: outShape } }); toDispose.push(a2D); toDispose.push(reduced); } if (keepDims) { toDispose.push(res); var newShape = tf.backend_util.expandShapeToKeepDim(res.shape, 
origAxes); res = reshape({ inputs: { x: res }, backend: backend, attrs: { shape: newShape } }); } toDispose.forEach(function (t) { return backend.disposeIntermediateTensorInfo(t); }); return res; } var prodConfig = { kernelName: tf.Prod, backendName: 'webgl', kernelFunc: prod }; function raggedGather(args) { var inputs = args.inputs, backend = args.backend, attrs = args.attrs; var paramsNestedSplits = inputs.paramsNestedSplits, paramsDenseValues = inputs.paramsDenseValues, indices = inputs.indices; attrs.outputRaggedRank; var $paramsNestedSplits = paramsNestedSplits.map(function (t) { return backend.readSync(t.dataId); }); var $paramsNestedSplitsShapes = paramsNestedSplits.map(function (t) { return t.shape; }); var $paramsDenseValues = backend.readSync(paramsDenseValues.dataId); var $indices = backend.readSync(indices.dataId); var _a = __read(raggedGatherImplCPU($paramsNestedSplits, $paramsNestedSplitsShapes, $paramsDenseValues, paramsDenseValues.shape, paramsDenseValues.dtype, $indices, indices.shape), 3), outputNestedSplits = _a[0], outputDenseValues = _a[1], outputDenseValuesShape = _a[2]; var outputNestedSplitsTensors = outputNestedSplits.map(function (splits) { return backend.makeTensorInfo([splits.length], 'int32', splits); }); var outputDenseValuesTensor = backend.makeTensorInfo(outputDenseValuesShape, paramsDenseValues.dtype, outputDenseValues); return outputNestedSplitsTensors.concat([outputDenseValuesTensor]); } var raggedGatherConfig = { kernelName: tf.RaggedGather, backendName: 'webgl', kernelFunc: raggedGather, }; function raggedRange(args) { var inputs = args.inputs, backend = args.backend; var starts = inputs.starts, limits = inputs.limits, deltas = inputs.deltas; var $starts = backend.readSync(starts.dataId); var $limits = backend.readSync(limits.dataId); var $deltas = backend.readSync(deltas.dataId); var _a = __read(raggedRangeImplCPU($starts, starts.shape, starts.dtype, $limits, limits.shape, $deltas, deltas.shape), 2), rtNestedSplitsData = _a[0], 
rtDenseValuesData = _a[1]; var rtNestedSplits = backend.makeTensorInfo([rtNestedSplitsData.length], 'int32', rtNestedSplitsData); var rtDenseValues = backend.makeTensorInfo([rtDenseValuesData.length], starts.dtype, rtDenseValuesData); return [rtNestedSplits, rtDenseValues]; } var raggedRangeConfig = { kernelName: tf.RaggedRange, backendName: 'webgl', kernelFunc: raggedRange, }; function raggedTensorToTensor(args) { var inputs = args.inputs, backend = args.backend, attrs = args.attrs; var shape = inputs.shape, values = inputs.values, defaultValue = inputs.defaultValue, rowPartitionTensors = inputs.rowPartitionTensors; var rowPartitionTypes = attrs.rowPartitionTypes; var $shape = backend.readSync(shape.dataId); var $values = backend.readSync(values.dataId); var $defaultValue = backend.readSync(defaultValue.dataId); var $rowPartitionValues = rowPartitionTensors.map(function (t) { return backend.readSync(t.dataId); }); var rowPartitionValuesShapes = rowPartitionTensors.map(function (t) { return t.shape; }); var _a = __read(raggedTensorToTensorImplCPU($shape, shape.shape, $values, values.shape, values.dtype, $defaultValue, defaultValue.shape, $rowPartitionValues, rowPartitionValuesShapes, rowPartitionTypes), 2), outputShape = _a[0], output = _a[1]; return backend.makeTensorInfo(outputShape, values.dtype, output); } var raggedTensorToTensorConfig = { kernelName: tf.RaggedTensorToTensor, backendName: 'webgl', kernelFunc: raggedTensorToTensor, }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
* See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ var range = function (args) { var backend = args.backend, attrs = args.attrs; var start = attrs.start, stop = attrs.stop, step = attrs.step, dtype = attrs.dtype; var values = rangeImplCPU(start, stop, step, dtype); return backend.makeTensorInfo([values.length], dtype, values); }; var rangeConfig = { kernelName: tf.Range, backendName: 'webgl', kernelFunc: range }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ var RECIPROCAL = "return 1.0 / x;"; var reciprocal = unaryKernelFunc({ opSnippet: RECIPROCAL }); var reciprocalConfig = { kernelName: tf.Reciprocal, backendName: 'webgl', kernelFunc: reciprocal, }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
* See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ var RELU = CHECK_NAN_SNIPPET$1 + "\n return (x < 0.0) ? 0.0 : x;\n"; var RELU_PACKED = "\n vec4 result = x * vec4(greaterThanEqual(x, vec4(0.0)));\n bvec4 isNaN = isnan(x);\n\n result.r = isNaN.r ? x.r : result.r;\n result.g = isNaN.g ? x.g : result.g;\n result.b = isNaN.b ? x.b : result.b;\n result.a = isNaN.a ? x.a : result.a;\n\n return result;\n"; var relu = unaryKernelFunc({ opSnippet: RELU, packedOpSnippet: RELU_PACKED }); var reluConfig = { kernelName: tf.Relu, backendName: 'webgl', kernelFunc: relu }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ var RELU6 = CHECK_NAN_SNIPPET$1 + "\n return (x < 0.0) ? 0.0 : min(6.0, x);\n"; var RELU6_PACKED = "\n vec4 result = min(x, vec4(6.)) * vec4(greaterThanEqual(x, vec4(0.0)));\n bvec4 isNaN = isnan(x);\n\n result.r = isNaN.r ? x.r : result.r;\n result.g = isNaN.g ? x.g : result.g;\n result.b = isNaN.b ? x.b : result.b;\n result.a = isNaN.a ? x.a : result.a;\n\n return result;\n"; var relu6 = unaryKernelFunc({ opSnippet: RELU6, packedOpSnippet: RELU6_PACKED }); var relu6Config = { kernelName: tf.Relu6, backendName: 'webgl', kernelFunc: relu6 }; /** * @license * Copyright 2017 Google LLC. 
All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ var ResizeBilinearProgram = /** @class */ (function () { function ResizeBilinearProgram(inputShape, newHeight, newWidth, alignCorners, halfPixelCenters) { this.variableNames = ['A']; this.outputShape = []; var _a = __read(inputShape, 4), batch = _a[0], oldHeight = _a[1], oldWidth = _a[2], depth = _a[3]; this.outputShape = [batch, newHeight, newWidth, depth]; var effectiveInSize = [ (alignCorners && newHeight > 1) ? oldHeight - 1 : oldHeight, (alignCorners && newWidth > 1) ? oldWidth - 1 : oldWidth ]; var effectiveOutSize = [ (alignCorners && newHeight > 1) ? newHeight - 1 : newHeight, (alignCorners && newWidth > 1) ? 
newWidth - 1 : newWidth ]; var sourceFracIndexRC; if (halfPixelCenters) { sourceFracIndexRC = "(vec2(yRC) + vec2(0.5)) * effectiveInputOverOutputRatioRC" + " - vec2(0.5)"; } else { sourceFracIndexRC = "vec2(yRC) * effectiveInputOverOutputRatioRC"; } this.userCode = "\n const vec2 effectiveInputOverOutputRatioRC = vec2(\n ".concat(effectiveInSize[0] / effectiveOutSize[0], ",\n ").concat(effectiveInSize[1] / effectiveOutSize[1], ");\n const vec2 inputShapeRC = vec2(").concat(oldHeight, ".0, ").concat(oldWidth, ".0);\n\n void main() {\n ivec4 coords = getOutputCoords();\n int b = coords[0];\n int d = coords[3];\n ivec2 yRC = coords.yz;\n\n // Fractional source index.\n vec2 sourceFracIndexRC = ").concat(sourceFracIndexRC, ";\n\n // Compute the four integer indices.\n ivec2 sourceFloorRC = ivec2(max(sourceFracIndexRC, vec2(0.0)));\n ivec2 sourceCeilRC = ivec2(\n min(inputShapeRC - 1.0, ceil(sourceFracIndexRC)));\n\n float topLeft = getA(b, sourceFloorRC.x, sourceFloorRC.y, d);\n float bottomLeft = getA(b, sourceCeilRC.x, sourceFloorRC.y, d);\n float topRight = getA(b, sourceFloorRC.x, sourceCeilRC.y, d);\n float bottomRight = getA(b, sourceCeilRC.x, sourceCeilRC.y, d);\n\n vec2 fracRC = sourceFracIndexRC - vec2(sourceFloorRC);\n\n float top = topLeft + (topRight - topLeft) * fracRC.y;\n float bottom = bottomLeft + (bottomRight - bottomLeft) * fracRC.y;\n float newValue = top + (bottom - top) * fracRC.x;\n\n setOutput(newValue);\n }\n "); } return ResizeBilinearProgram; }()); /** * @license * Copyright 2019 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
* See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ var ResizeBilinearPackedProgram = /** @class */ (function () { function ResizeBilinearPackedProgram(inputShape, newHeight, newWidth, alignCorners, halfPixelCenters) { this.variableNames = ['A']; this.packedInputs = true; this.packedOutput = true; this.outputShape = []; var _a = __read(inputShape, 4), batch = _a[0], oldHeight = _a[1], oldWidth = _a[2], depth = _a[3]; this.outputShape = [batch, newHeight, newWidth, depth]; var effectiveInSize = [ (alignCorners && newHeight > 1) ? oldHeight - 1 : oldHeight, (alignCorners && newWidth > 1) ? oldWidth - 1 : oldWidth ]; var effectiveOutSize = [ (alignCorners && newHeight > 1) ? newHeight - 1 : newHeight, (alignCorners && newWidth > 1) ? newWidth - 1 : newWidth ]; var sourceFracIndexRC; if (halfPixelCenters) { sourceFracIndexRC = "(vec3(yRC) + vec3(0.5)) * " + "effectiveInputOverOutputRatioRC - vec3(0.5)"; } else { sourceFracIndexRC = "vec3(yRC) * effectiveInputOverOutputRatioRC"; } this.userCode = "\n const vec3 effectiveInputOverOutputRatioRC = vec3(\n ".concat(effectiveInSize[0] / effectiveOutSize[0], ",\n ").concat(effectiveInSize[1] / effectiveOutSize[1], ",\n ").concat(effectiveInSize[1] / effectiveOutSize[1], ");\n const vec3 inputShapeRC = vec3(").concat(oldHeight, ".0, ").concat(oldWidth, ".0,\n ").concat(oldWidth, ".0);\n\n float getAValue(int b, int r, int c, int d) {\n return getChannel(getA(b, r, c, d), vec2(c, d));\n }\n\n void main() {\n ivec4 coords = getOutputCoords();\n int b = coords[0];\n int d = coords[3];\n // Calculate values for next column in yRC.z.\n ivec3 yRC = coords.yzz + ivec3(0, 0, 1);\n\n // Fractional source index.\n vec3 sourceFracIndexRC = ").concat(sourceFracIndexRC, ";\n\n // Compute the four integer indices.\n ivec3 sourceFloorRC = ivec3(max(sourceFracIndexRC, vec3(0.0)));\n ivec3 sourceCeilRC = 
ivec3(\n min(inputShapeRC - 1.0, ceil(sourceFracIndexRC)));\n\n // Should we calculate next column and row elements in 2x2 packed cell.\n bool hasNextCol = d < ").concat(depth - 1, ";\n bool hasNextRow = coords.z < ").concat(newWidth - 1, ";\n\n // In parallel, construct four corners for all four components in\n // packed 2x2 cell.\n vec4 topLeft = vec4(\n getAValue(b, sourceFloorRC.x, sourceFloorRC.y, d),\n hasNextCol ? getAValue(b, sourceFloorRC.x, sourceFloorRC.y, d + 1)\n : 0.0,\n hasNextRow ? getAValue(b, sourceFloorRC.x, sourceFloorRC.z, d)\n : 0.0,\n (hasNextRow && hasNextCol) ?\n getAValue(b, sourceFloorRC.x, sourceFloorRC.z, d + 1) : 0.0);\n\n vec4 bottomLeft = vec4(\n getAValue(b, sourceCeilRC.x, sourceFloorRC.y, d),\n hasNextCol ? getAValue(b, sourceCeilRC.x, sourceFloorRC.y, d + 1)\n : 0.0,\n hasNextRow ? getAValue(b, sourceCeilRC.x, sourceFloorRC.z, d)\n : 0.0,\n (hasNextRow && hasNextCol) ?\n getAValue(b, sourceCeilRC.x, sourceFloorRC.z, d + 1) : 0.0);\n\n vec4 topRight = vec4(\n getAValue(b, sourceFloorRC.x, sourceCeilRC.y, d),\n hasNextCol ? getAValue(b, sourceFloorRC.x, sourceCeilRC.y, d + 1)\n : 0.0,\n hasNextRow ? getAValue(b, sourceFloorRC.x, sourceCeilRC.z, d)\n : 0.0,\n (hasNextRow && hasNextCol) ?\n getAValue(b, sourceFloorRC.x, sourceCeilRC.z, d + 1) : 0.0);\n\n vec4 bottomRight = vec4(\n getAValue(b, sourceCeilRC.x, sourceCeilRC.y, d),\n hasNextCol ? getAValue(b, sourceCeilRC.x, sourceCeilRC.y, d + 1)\n : 0.0,\n hasNextRow ? 
getAValue(b, sourceCeilRC.x, sourceCeilRC.z, d)\n : 0.0,\n (hasNextRow && hasNextCol) ?\n getAValue(b, sourceCeilRC.x, sourceCeilRC.z, d + 1) : 0.0);\n\n vec3 fracRC = sourceFracIndexRC - vec3(sourceFloorRC);\n\n vec4 top = mix(topLeft, topRight, fracRC.yyzz);\n vec4 bottom = mix(bottomLeft, bottomRight, fracRC.yyzz);\n vec4 newValue = mix(top, bottom, fracRC.x);\n\n setOutput(newValue);\n }\n "); } return ResizeBilinearPackedProgram; }()); function resizeBilinear(args) { var inputs = args.inputs, backend = args.backend, attrs = args.attrs; var images = inputs.images; var alignCorners = attrs.alignCorners, halfPixelCenters = attrs.halfPixelCenters, size = attrs.size; var _a = __read(size, 2), newHeight = _a[0], newWidth = _a[1]; var program = tf.env().getBool('WEBGL_PACK_IMAGE_OPERATIONS') ? new ResizeBilinearPackedProgram(images.shape, newHeight, newWidth, alignCorners, halfPixelCenters) : new ResizeBilinearProgram(images.shape, newHeight, newWidth, alignCorners, halfPixelCenters); return backend.runWebGLProgram(program, [images], 'float32'); } var resizeBilinearConfig = { kernelName: tf.ResizeBilinear, backendName: 'webgl', kernelFunc: resizeBilinear }; /** * @license * Copyright 2018 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ var ResizeBilinearBackpropProgram = /** @class */ (function () { function ResizeBilinearBackpropProgram(dyShape, inputShape, alignCorners) { this.variableNames = ['dy']; this.outputShape = []; this.outputShape = inputShape; var _a = __read(inputShape, 3), xHeight = _a[1], xWidth = _a[2]; var _b = __read(dyShape, 3), yHeight = _b[1], yWidth = _b[2]; // In the backwards pass, we want to find the pixels that were generated for // each pixel in the input image the forward pass and add the corresponding // coefficient from dy to the gradient (with some interpolation). var effectiveXSize = [ (alignCorners && yHeight > 1) ? xHeight - 1 : xHeight, (alignCorners && yWidth > 1) ? xWidth - 1 : xWidth ]; var effectiveYSize = [ (alignCorners && yHeight > 1) ? yHeight - 1 : yHeight, (alignCorners && yWidth > 1) ? yWidth - 1 : yWidth ]; var heightScale = effectiveXSize[0] / effectiveYSize[0]; var widthScale = effectiveXSize[1] / effectiveYSize[1]; var invHeightScale = 1 / heightScale; var invWidthScale = 1 / widthScale; // This defines the size of the window of values around a particular // index in dy that we want to search for contributions to dx. 
var winHeight = (Math.ceil(invHeightScale) * 2) + 2; var winWidth = (Math.ceil(invWidthScale) * 2) + 2; this.userCode = "\n void main() {\n ivec4 coords = getOutputCoords();\n int b = coords[0];\n int d = coords[3];\n int r = coords[1];\n int c = coords[2];\n\n float accumulator = 0.0;\n\n const float heightScale = float(".concat(heightScale, ");\n const float widthScale = float(").concat(widthScale, ");\n\n const float invHeightScale = float(").concat(invHeightScale, ");\n const float invWidthScale = float(").concat(invWidthScale, ");\n\n const int winHeight = int(").concat(winHeight, ");\n const int winWidth = int(").concat(winWidth, ");\n\n // Compute bounds for where in dy we will look\n float startRLerp = floor(float(r) * invHeightScale);\n int startDyR = int(startRLerp - float(winHeight / 2));\n\n float startCLerp = floor(float(c) * invWidthScale);\n int startDyC = int(startCLerp - float(winWidth / 2));\n\n // Loop over dy\n for (int dyROffset = 0; dyROffset < winHeight; dyROffset++) {\n int dyR = dyROffset + startDyR;\n\n // Guard against the window exceeding the bounds of dy\n if (dyR < 0 || dyR >= ").concat(yHeight, ") {\n continue;\n }\n\n for (int dyCOffset = 0; dyCOffset < winWidth; dyCOffset++) {\n int dyC = dyCOffset + startDyC;\n\n // Guard against the window exceeding the bounds of dy\n if (dyC < 0 || dyC >= ").concat(yWidth, ") {\n continue;\n }\n\n float dxR = float(dyR) * heightScale;\n int topDxRIndex = int(floor(dxR));\n int bottomDxRIndex = int(min(ceil(dxR), ").concat(xHeight - 1, ".0));\n float dxRLerp = dxR - float(topDxRIndex);\n float inverseDxRLerp = 1.0 - dxRLerp;\n\n float dxC = float(dyC) * widthScale;\n int leftDxCIndex = int(floor(dxC));\n int rightDxCIndex = int(min(ceil(dxC), ").concat(xWidth - 1, ".0));\n float dxCLerp = dxC - float(leftDxCIndex);\n float inverseDxCLerp = 1.0 - dxCLerp;\n\n if (r == topDxRIndex && c == leftDxCIndex) {\n // topLeft\n accumulator +=\n getDy(b, dyR, dyC, d) * inverseDxRLerp * inverseDxCLerp;\n }\n\n 
if (r == topDxRIndex && c == rightDxCIndex) {\n // topRight\n accumulator += getDy(b, dyR, dyC, d) * inverseDxRLerp * dxCLerp;\n }\n\n if (r == bottomDxRIndex && c == leftDxCIndex) {\n // bottomLeft\n accumulator += getDy(b, dyR, dyC, d) * dxRLerp * inverseDxCLerp;\n }\n\n if (r == bottomDxRIndex && c == rightDxCIndex) {\n // bottomRight\n accumulator += getDy(b, dyR, dyC, d) * dxRLerp * dxCLerp;\n }\n }\n }\n // End loop over dy\n\n setOutput(accumulator);\n }\n "); } return ResizeBilinearBackpropProgram; }()); /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function resizeBilinearGrad(args) { var inputs = args.inputs, backend = args.backend, attrs = args.attrs; var images = inputs.images, dy = inputs.dy; var alignCorners = attrs.alignCorners; var program = new ResizeBilinearBackpropProgram(dy.shape, images.shape, alignCorners); return backend.runWebGLProgram(program, [dy], dy.dtype); } var resizeBilinearGradConfig = { kernelName: tf.ResizeBilinearGrad, backendName: 'webgl', kernelFunc: resizeBilinearGrad }; /** * @license * Copyright 2018 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ var ResizeNearestNeighborProgram = /** @class */ (function () { function ResizeNearestNeighborProgram(inputShape, newHeight, newWidth, alignCorners, halfPixelCenters) { this.variableNames = ['A']; this.outputShape = []; var _a = __read(inputShape, 4), batch = _a[0], oldHeight = _a[1], oldWidth = _a[2], depth = _a[3]; this.outputShape = [batch, newHeight, newWidth, depth]; var effectiveInSize = [ (alignCorners && newHeight > 1) ? oldHeight - 1 : oldHeight, (alignCorners && newWidth > 1) ? oldWidth - 1 : oldWidth ]; var effectiveOutSize = [ (alignCorners && newHeight > 1) ? newHeight - 1 : newHeight, (alignCorners && newWidth > 1) ? newWidth - 1 : newWidth ]; // When align corners is false, we rounds the value with floor. var roundBase = alignCorners ? 
'0.5' : '0.0'; var sourceFracIndexRC; if (halfPixelCenters) { sourceFracIndexRC = "max((vec2(yRC) + vec2(0.5)) * effectiveInputOverOutputRatioRC" + ", vec2(0.0))"; } else { sourceFracIndexRC = "vec2(yRC) * effectiveInputOverOutputRatioRC"; } this.userCode = "\n const vec2 effectiveInputOverOutputRatioRC = vec2(\n ".concat(effectiveInSize[0] / effectiveOutSize[0], ",\n ").concat(effectiveInSize[1] / effectiveOutSize[1], ");\n const vec2 inputShapeRC = vec2(").concat(oldHeight, ".0, ").concat(oldWidth, ".0);\n\n void main() {\n ivec4 coords = getOutputCoords();\n int b = coords[0];\n int d = coords[3];\n ivec2 yRC = coords.yz;\n\n // Fractional source index.\n vec2 sourceFracIndexRC = ").concat(sourceFracIndexRC, ";\n\n // Compute the coordinators of nearest neighbor point.\n ivec2 sourceNearestRC = ivec2(\n min(inputShapeRC - 1.0, floor(sourceFracIndexRC + ").concat(roundBase, ")));\n float newValue = getA(b, sourceNearestRC.x, sourceNearestRC.y, d);\n\n setOutput(newValue);\n }\n "); } return ResizeNearestNeighborProgram; }()); /** * @license * Copyright 2019 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ var ResizeNearestNeighborPackedProgram = /** @class */ (function () { function ResizeNearestNeighborPackedProgram(inputShape, newHeight, newWidth, alignCorners, halfPixelCenters) { this.variableNames = ['A']; this.packedInputs = true; this.packedOutput = true; this.outputShape = []; var _a = __read(inputShape, 4), batch = _a[0], oldHeight = _a[1], oldWidth = _a[2], depth = _a[3]; this.outputShape = [batch, newHeight, newWidth, depth]; var effectiveInSize = [ (alignCorners && newHeight > 1) ? oldHeight - 1 : oldHeight, (alignCorners && newWidth > 1) ? oldWidth - 1 : oldWidth ]; var effectiveOutSize = [ (alignCorners && newHeight > 1) ? newHeight - 1 : newHeight, (alignCorners && newWidth > 1) ? newWidth - 1 : newWidth ]; // When align corners is false, we rounds the value with floor. var roundBase = alignCorners ? '0.5' : '0.0'; var sourceFracIndexRC; if (halfPixelCenters) { sourceFracIndexRC = "max((vec3(yRC) + vec3(0.5)) * " + "effectiveInputOverOutputRatioRC, vec3(0.0))"; } else { sourceFracIndexRC = "vec3(yRC) * effectiveInputOverOutputRatioRC"; } this.userCode = "\n const vec3 effectiveInputOverOutputRatioRC = vec3(\n ".concat(effectiveInSize[0] / effectiveOutSize[0], ",\n ").concat(effectiveInSize[1] / effectiveOutSize[1], ",\n ").concat(effectiveInSize[1] / effectiveOutSize[1], ");\n const vec3 inputShapeRC = vec3(").concat(oldHeight, ".0, ").concat(oldWidth, ".0,\n ").concat(oldWidth, ".0);\n\n float getAValue(int b, int r, int c, int d) {\n return getChannel(getA(b, r, c, d), vec2(c, d));\n }\n\n void main() {\n ivec4 coords = getOutputCoords();\n int b = coords[0];\n int d = coords[3];\n // Calculate values for next column in yRC.z.\n ivec3 yRC = coords.yzz + ivec3(0, 0, 1);\n\n // Fractional source index.\n vec3 sourceFracIndexRC = ").concat(sourceFracIndexRC, ";\n\n // Compute the coordinators of nearest neighbor point.\n ivec3 sourceNearestRC = ivec3(\n min(inputShapeRC - 
1.0, floor(sourceFracIndexRC + ").concat(roundBase, ")));\n\n // Should we calculate next column and row elements in 2x2 packed cell.\n bool hasNextCol = d < ").concat(depth - 1, ";\n bool hasNextRow = coords.z < ").concat(newWidth - 1, ";\n\n vec4 newValue = vec4(\n getAValue(b, sourceNearestRC.x, sourceNearestRC.y, d),\n hasNextCol ? getAValue(b, sourceNearestRC.x, sourceNearestRC.y, d + 1)\n : 0.0,\n hasNextRow ? getAValue(b, sourceNearestRC.x, sourceNearestRC.z, d)\n : 0.0,\n (hasNextRow && hasNextCol) ?\n getAValue(b, sourceNearestRC.x, sourceNearestRC.z, d + 1) : 0.0);\n\n setOutput(newValue);\n }\n "); } return ResizeNearestNeighborPackedProgram; }()); function resizeNearestNeighbor(args) { var inputs = args.inputs, backend = args.backend, attrs = args.attrs; var images = inputs.images; var alignCorners = attrs.alignCorners, halfPixelCenters = attrs.halfPixelCenters, size = attrs.size; var _a = __read(size, 2), newHeight = _a[0], newWidth = _a[1]; var program = tf.env().getBool('WEBGL_PACK_IMAGE_OPERATIONS') ? new ResizeNearestNeighborPackedProgram(images.shape, newHeight, newWidth, alignCorners, halfPixelCenters) : new ResizeNearestNeighborProgram(images.shape, newHeight, newWidth, alignCorners, halfPixelCenters); return backend.runWebGLProgram(program, [images], images.dtype); } var resizeNearestNeighborConfig = { kernelName: tf.ResizeNearestNeighbor, backendName: 'webgl', kernelFunc: resizeNearestNeighbor }; /** * @license * Copyright 2018 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
* See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */
/**
 * WebGL program computing the gradient of ResizeNearestNeighbor with respect
 * to its input image. Every output (dx) texel (b, r, c, d) scans a window of
 * dy and sums each dy value whose forward-pass nearest source pixel, as
 * recomputed in the shader, is (r, c).
 *
 * NOTE(review): only `alignCorners` is taken here; there is no
 * `halfPixelCenters` parameter, so the recomputed source mapping ignores
 * half-pixel centers even though the forward kernel accepts them. Confirm
 * this is intended.
 */
var ResizeNearestNeigborBackpropProgram = /** @class */ (function () {
    /**
     * @param dyShape Shape of the incoming gradient; indices 1 and 2 are read
     *     as its height and width.
     * @param inputShape Shape of the forward-pass input. It becomes the output
     *     shape; indices 1 and 2 are read as height and width, and the shader
     *     addresses it with ivec4 coords (b, r, c, d).
     * @param alignCorners Used as a JS condition and also spliced into the
     *     GLSL source as text; presumably a boolean. TODO confirm callers.
     */
    function ResizeNearestNeigborBackpropProgram(dyShape, inputShape, alignCorners) {
        this.variableNames = ['dy'];
        this.outputShape = [];
        this.outputShape = inputShape;
        var _a = __read(inputShape, 3), xHeight = _a[1], xWidth = _a[2];
        var _b = __read(dyShape, 3), yHeight = _b[1], yWidth = _b[2];
        // In the backwards pass, we want to find the pixels that were generated for
        // each pixel in the input image the forward pass and add the corresponding
        // coefficient from dy to the gradient (with some interpolation).
        var effectiveXSize = [
            (alignCorners && yHeight > 1) ? xHeight - 1 : xHeight,
            (alignCorners && yWidth > 1) ? xWidth - 1 : xWidth
        ];
        var effectiveYSize = [
            (alignCorners && yHeight > 1) ? yHeight - 1 : yHeight,
            (alignCorners && yWidth > 1) ? yWidth - 1 : yWidth
        ];
        // Input extent divided by dy extent per axis (input pixels per dy
        // pixel), followed by the inverse ratios (dy pixels per input pixel).
        var heightScale = effectiveXSize[0] / effectiveYSize[0];
        var widthScale = effectiveXSize[1] / effectiveYSize[1];
        var invHeightScale = 1 / heightScale;
        var invWidthScale = 1 / widthScale;
        // This defines the size of the window of values around a particular
        // index in dy that we want to search for contributions to dx.
        var winHeight = (Math.ceil(invHeightScale) * 2) + 2;
        var winWidth = (Math.ceil(invWidthScale) * 2) + 2;
        // The shader starts its dy window winHeight / 2 rows (winWidth / 2
        // columns) before floor(r * invHeightScale) (floor(c * invWidthScale)),
        // skips positions outside dy, and accumulates dy wherever the
        // recomputed nearest source pixel equals (r, c). With alignCorners the
        // source coordinate is rounded, otherwise floored; either way it is
        // clamped to the last input row / column.
        this.userCode = "\n void main() {\n ivec4 coords = getOutputCoords();\n int b = coords[0];\n int d = coords[3];\n int r = coords[1];\n int c = coords[2];\n\n float accumulator = 0.0;\n\n const float heightScale = float("
            .concat(heightScale, ");\n const float widthScale = float(")
            .concat(widthScale, ");\n\n const float invHeightScale = float(")
            .concat(invHeightScale, ");\n const float invWidthScale = float(")
            .concat(invWidthScale, ");\n\n const int winHeight = int(")
            .concat(winHeight, ");\n const int winWidth = int(")
            .concat(winWidth, ");\n\n // Compute bounds for where in dy we will look\n float startRLerp = floor(float(r) * invHeightScale);\n int startDyR = int(floor(startRLerp - float(winHeight / 2)));\n\n float startCLerp = floor(float(c) * invWidthScale);\n int startDyC = int(floor(startCLerp - float(winWidth / 2)));\n\n // Loop over dy\n for (int dyROffset = 0; dyROffset < winHeight; dyROffset++) {\n int dyR = dyROffset + startDyR;\n\n // Guard against the window exceeding the bounds of dy\n if (dyR < 0 || dyR >= ")
            .concat(yHeight, ") {\n continue;\n }\n\n for (int dyCOffset = 0; dyCOffset < winWidth; dyCOffset++) {\n int dyC = dyCOffset + startDyC;\n\n // Guard against the window exceeding the bounds of dy\n if (dyC < 0 || dyC >= ")
            .concat(yWidth, ") {\n continue;\n }\n\n float sourceFracRow =\n float(")
            .concat(effectiveXSize[0], ") *\n (float(dyR) / float(")
            .concat(effectiveYSize[0], "));\n\n float sourceFracCol =\n float(")
            .concat(effectiveXSize[1], ") *\n (float(dyC) / float(")
            .concat(effectiveYSize[1], "));\n\n int sourceNearestRow = int(min(\n float(int(")
            .concat(xHeight, ") - 1),\n ")
            .concat(alignCorners, " ? float(round(sourceFracRow)) :\n float(floor(sourceFracRow))));\n\n int sourceNearestCol = int(min(\n float(int(")
            .concat(xWidth, ") - 1),\n ")
            .concat(alignCorners, " ? float(round(sourceFracCol)) :\n float(floor(sourceFracCol))));\n\n if (r == sourceNearestRow && c == sourceNearestCol) {\n accumulator += getDy(b, dyR, dyC, d);\n }\n }\n }\n // End loop over dy\n\n setOutput(accumulator);\n }\n ");
    }
    return ResizeNearestNeigborBackpropProgram;
}());
/** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */
/**
 * WebGL kernel for ResizeNearestNeighborGrad.
 *
 * `images` is consulted only for its shape (the gradient's output shape);
 * the shader reads `dy` alone. The result takes dy's dtype.
 */
function resizeNearestNeighborGrad(args) {
    var inputs = args.inputs, backend = args.backend, attrs = args.attrs;
    var images = inputs.images, dy = inputs.dy;
    var alignCorners = attrs.alignCorners;
    var program = new ResizeNearestNeigborBackpropProgram(dy.shape, images.shape, alignCorners);
    return backend.runWebGLProgram(program, [dy], dy.dtype);
}
// Kernel registration record for the WebGL backend.
var resizeNearestNeighborGradConfig = {
    kernelName: tf.ResizeNearestNeighborGrad,
    backendName: 'webgl',
    kernelFunc: resizeNearestNeighborGrad
};
/** * @license * Copyright 2017 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ var ReverseProgram = /** @class */ (function () { function ReverseProgram(xShape, axis) { this.variableNames = ['x']; var rank = xShape.length; if (rank > 4) { throw new Error("WebGL backend: Reverse of rank-".concat(rank, " tensor is not yet supported")); } this.outputShape = xShape; if (rank === 1) { this.userCode = "\n void main() {\n int coord = getOutputCoords();\n setOutput(getX(".concat(xShape[0], " - coord - 1));\n }\n "); return; } var getInCoord = function (i) { if (axis.indexOf(i) !== -1 && xShape[i] !== 1) { return "".concat(xShape[i], " - coords[").concat(i, "] - 1"); } return "coords[".concat(i, "]"); }; var inCoords = xShape.map(function (_, i) { return getInCoord(i); }).join(','); var type = getCoordsDataType(rank); this.userCode = "\n void main() {\n ".concat(type, " coords = getOutputCoords();\n setOutput(getX(").concat(inCoords, "));\n }\n "); } return ReverseProgram; }()); /** * @license * Copyright 2019 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ var ReversePackedProgram = /** @class */ (function () { function ReversePackedProgram(xShape, axis) { this.variableNames = ['x']; this.packedInputs = true; this.packedOutput = true; var rank = xShape.length; if (rank > 4) { throw new Error("WebGL backend: Reverse of rank-".concat(rank, " tensor is not yet supported")); } this.outputShape = xShape; var channels = getChannels('rc', rank); var nextColumn = "".concat(channels[rank - 1], " + 1 < ").concat(this.outputShape[rank - 1]); var nextRow = "".concat(channels[rank - 2], " + 1 < ").concat(this.outputShape[rank - 2]); var type = getCoordsDataType(rank); if (rank === 1) { this.userCode = "\n void main(){\n int rc = getOutputCoords();\n vec4 result = vec4(0.);\n result.r = getChannel(getX(".concat(xShape[0], " - rc - 1),\n ").concat(xShape[0], " - rc - 1);\n if(").concat(nextColumn, "){\n result.g = getChannel(getX(").concat(xShape[0], " - (rc + 1) - 1),\n ").concat(xShape[0], " - (rc + 1) - 1);\n }\n setOutput(result);\n }\n "); } else { this.userCode = "\n void main() {\n ".concat(type, " rc = getOutputCoords();\n vec4 result = vec4(0.);\n result.r = ").concat(getR(channels.slice()), ";\n if(").concat(nextColumn, "){\n result.g = ").concat(getG(channels.slice()), ";\n }\n if(").concat(nextRow, ") {\n result.b = ").concat(getB(channels.slice()), ";\n if(").concat(nextColumn, ") {\n result.a = ").concat(getA(channels.slice()), ";\n }\n }\n setOutput(result);\n }\n "); } function getR(channels) { return getChannel(channels); } function getG(channels) { channels[rank - 1] = '(' + channels[rank - 1] + " + 1)"; return getChannel(channels); } function getB(channels) { channels[rank - 2] = '(' + channels[rank - 2] + " + 1)"; return getChannel(channels); } function getA(channels) { channels[rank - 1] = '(' + channels[rank - 1] + " + 1)"; channels[rank - 2] = '(' + channels[rank - 2] + " + 1)"; return getChannel(channels); } function 
getChannel(channels) { var inCoordsArray = xShape.map(function (_, i) { return getInCoord(i, channels); }); var inCoords = inCoordsArray.join(','); var innerDims = inCoordsArray.slice(-2).join(','); return "getChannel(getX(".concat(inCoords, "), vec2(").concat(innerDims, "))"); } function getInCoord(i, channels1) { if (axis.indexOf(i) !== -1 && xShape[i] !== 1) { return "".concat(xShape[i], " - ").concat(channels1[i], " - 1"); } else { return "".concat(channels1[i]); } } } return ReversePackedProgram; }()); /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function reverse(args) { var inputs = args.inputs, backend = args.backend, attrs = args.attrs; var x = inputs.x; var dims = attrs.dims; var xRank = x.shape.length; var $dims = tf.util.parseAxisParam(dims, x.shape); if (xRank === 0) { return identity({ inputs: { x: x }, backend: backend }); } var program = tf.env().getBool('WEBGL_PACK_ARRAY_OPERATIONS') ? new ReversePackedProgram(x.shape, $dims) : new ReverseProgram(x.shape, $dims); return backend.runWebGLProgram(program, [x], x.dtype); } var reverseConfig = { kernelName: tf.Reverse, backendName: 'webgl', kernelFunc: reverse }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ var RotateProgram = /** @class */ (function () { function RotateProgram(imageShape, fillValue) { this.variableNames = ['Image']; this.outputShape = []; this.customUniforms = [{ name: 'params', type: 'vec4' }]; var imageHeight = imageShape[1]; var imageWidth = imageShape[2]; this.outputShape = imageShape; var fillSnippet = ''; if (typeof fillValue === 'number') { fillSnippet = "float outputValue = ".concat(fillValue.toFixed(2), ";"); } else { fillSnippet = "\n vec3 fill = vec3(".concat(fillValue.join(','), ");\n float outputValue = fill[coords[3]];"); } this.userCode = "\n void main() {\n ivec4 coords = getOutputCoords();\n int x = coords[2];\n int y = coords[1];\n float coordXFloat = (float(x) - params[0]) * params[3] -\n (float(y) - params[1]) * params[2];\n float coordYFloat = (float(x) - params[0]) * params[2] +\n (float(y) - params[1]) * params[3];\n int coordX = int(round(coordXFloat + params[0]));\n int coordY = int(round(coordYFloat + params[1]));\n ".concat(fillSnippet, "\n if(coordX >= 0 && coordX < ").concat(imageWidth, " && coordY >= 0 && coordY < ").concat(imageHeight, ") {\n outputValue = getImage(coords[0], coordY, coordX, coords[3]);\n }\n setOutput(outputValue);\n }\n "); } return RotateProgram; }()); var rotateWithOffsetConfig = { kernelName: tf.RotateWithOffset, backendName: 'webgl', kernelFunc: function (_a) { var inputs = _a.inputs, attrs = _a.attrs, backend = _a.backend; var image = inputs.image; var radians = attrs.radians, fillValue = attrs.fillValue, center = 
attrs.center; var webglBackend = backend; var program = new RotateProgram(image.shape, fillValue); var _b = __read(tf.backend_util.getImageCenter(center, image.shape[1], image.shape[2]), 2), centerX = _b[0], centerY = _b[1]; var customValues = [[centerX, centerY, Math.sin(radians), Math.cos(radians)]]; var output = webglBackend.runWebGLProgram(program, [image], image.dtype, customValues); return output; } }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ var ROUND = "\n // OpenGL ES does not support round function.\n // The algorithm is based on banker's rounding.\n float base = floor(x);\n if ((x - base) < 0.5) {\n return floor(x);\n } else if ((x - base) > 0.5) {\n return ceil(x);\n } else {\n if (mod(base, 2.0) == 0.0) {\n return base;\n } else {\n return base + 1.0;\n }\n }\n"; var round = unaryKernelFunc({ opSnippet: ROUND }); var roundConfig = { kernelName: tf.Round, backendName: 'webgl', kernelFunc: round, }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ var RSQRT = "return inversesqrt(x);"; var rsqrt = unaryKernelFunc({ opSnippet: RSQRT, cpuKernelImpl: rsqrtImplCPU }); var rsqrtConfig = { kernelName: tf.Rsqrt, backendName: 'webgl', kernelFunc: rsqrt }; /** * @license * Copyright 2018 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ var ScatterProgram = /** @class */ (function () { function ScatterProgram(updateSize, sliceDim, indicesRank, updatesRank, strides, shape, summingDupeIndex, defaultIsTensor) { if (defaultIsTensor === void 0) { defaultIsTensor = false; } this.variableNames = ['updates', 'indices', 'defaultValue']; this.outputShape = shape; var stridesType = getCoordsDataType(strides.length); var dtype = getCoordsDataType(shape.length); var indicesString = ''; if (indicesRank === 1) { indicesString = 'i'; } else if (indicesRank === 2) { indicesString = 'i, j'; } var indicesSnippet = "getIndices(".concat(indicesString, ")"); var updatesString = ''; if (updatesRank === 1) { updatesString = 'i'; } else if (updatesRank === 2) { updatesString = 'i, coords[1]'; } var updatesSnippet = "getUpdates(".concat(updatesString, ")"); var defaultValuesString = ''; if (defaultIsTensor) { defaultValuesString = 'coords[0], coords[1]'; } var defaultValueSnippet = "getDefaultValue(".concat(defaultValuesString, ")"); var strideString = sliceDim > 1 ? 'strides[j]' : 'strides'; this.userCode = "\n ".concat(stridesType, " strides = ").concat(stridesType, "(").concat(strides, ");\n\n void main() {\n ").concat(dtype, " coords = getOutputCoords();\n float sum = 0.0;\n bool found = false;\n for (int i = 0; i < ").concat(updateSize, "; i++) {\n int flattenedIndex = 0;\n for (int j = 0; j < ").concat(sliceDim, "; j++) {\n int index = round(").concat(indicesSnippet, ");\n flattenedIndex += index * ").concat(strideString, ";\n }\n if (flattenedIndex == coords[0]) {\n sum += ").concat(updatesSnippet, ";\n found = true;\n }\n }\n setOutput(mix(").concat(defaultValueSnippet, ", sum, float(found)));\n }\n "); } return ScatterProgram; }()); /** * @license * Copyright 2023 Google LLC. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ var ScatterPackedProgram = /** @class */ (function () { function ScatterPackedProgram(updateSize, sliceDim, indicesRank, updatesRank, strides, shape, summingDupeIndex, defaultIsTensor) { if (defaultIsTensor === void 0) { defaultIsTensor = false; } this.variableNames = ['updates', 'indices', 'defaultValue']; this.packedInputs = true; this.packedOutput = true; this.outputShape = shape; var stridesType = getCoordsDataType(strides.length); var dtype = getCoordsDataType(shape.length); var indicesString = ''; if (indicesRank === 1) { indicesString = 'i'; } else if (indicesRank === 2) { indicesString = 'i, j'; } var indicesSnippet = "getIndices(".concat(indicesString, ")"); var updatesString = ''; if (updatesRank === 1) { updatesString = 'i'; } else if (updatesRank === 2) { updatesString = 'i, coords[1]'; } var updatesSnippet = "getUpdates(".concat(updatesString, ")"); var defaultValuesString = ''; if (defaultIsTensor) { defaultValuesString = 'coords[0], coords[1]'; } var defaultValueSnippet = "getDefaultValue(".concat(defaultValuesString, ")"); var strideString = sliceDim > 1 ? 'strides[j]' : 'strides'; var strideString2 = sliceDim > 1 ? 
'strides[j + 1]' : 'strides'; this.userCode = "\n ".concat(stridesType, " strides = ").concat(stridesType, "(").concat(strides, ");\n\n void main() {\n ").concat(dtype, " coords = getOutputCoords();\n vec4 sum = vec4(0.);\n vec4 found = vec4(0.);\n for (int i = 0; i < ").concat(updateSize, "; i+=2) {\n ivec2 flattenedIndex = ivec2(0);\n for (int j = 0; j < ").concat(sliceDim, "; j+=2) {\n ivec4 index = round(").concat(indicesSnippet, ");\n flattenedIndex += index.xz * ").concat(strideString, ";\n if (j + 1 < ").concat(sliceDim, ") {\n flattenedIndex += index.yw * ").concat(strideString2, ";\n }\n }\n if (flattenedIndex[0] == coords[0] || flattenedIndex[1] == coords[0] ||\n flattenedIndex[0] == coords[0] + 1 || flattenedIndex[1] == coords[0] + 1) {\n vec4 updVals = ").concat(updatesSnippet, ";\n if (flattenedIndex[0] == coords[0]) {\n sum.xy += updVals.xy;\n found.xy = vec2(1.);\n } else if (flattenedIndex[0] == coords[0] + 1) {\n sum.zw += updVals.xy;\n found.zw = vec2(1.);\n }\n if (flattenedIndex[1] == coords[0]) {\n sum.xy += updVals.zw;\n found.xy = vec2(1.);\n } else if (flattenedIndex[1] == coords[0] + 1) {\n sum.zw += updVals.zw;\n found.zw = vec2(1.);\n }\n }\n }\n setOutput(mix(").concat(defaultValueSnippet, ", sum, found));\n }\n "); } return ScatterPackedProgram; }()); /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ function scatterNd(args) { var inputs = args.inputs, backend = args.backend, attrs = args.attrs; var indices = inputs.indices, updates = inputs.updates; var shape = attrs.shape; var _a = tf.backend_util.calculateShapes(updates, indices, shape), sliceRank = _a.sliceRank, numUpdates = _a.numUpdates, sliceSize = _a.sliceSize, strides = _a.strides, outputSize = _a.outputSize; var flattenShape = [outputSize / sliceSize, sliceSize]; if (outputSize === 0) { return backend.makeTensorInfo(shape, indices.dtype); } var flattenIndices = reshape({ inputs: { x: indices }, backend: backend, attrs: { shape: [numUpdates, sliceRank] } }); var flattenX = reshape({ inputs: { x: updates }, backend: backend, attrs: { shape: [numUpdates, sliceSize] } }); var defaultValue = backend.makeTensorInfo([], 'float32', new Float32Array([0])); // scalar(0) var program; if (tf.env().getBool('WEBGL_PACK')) { program = new ScatterPackedProgram(numUpdates, sliceRank, flattenIndices.shape.length, flattenX.shape.length, strides, flattenShape); } else { program = new ScatterProgram(numUpdates, sliceRank, flattenIndices.shape.length, flattenX.shape.length, strides, flattenShape); } var res = backend.runWebGLProgram(program, [flattenX, flattenIndices, defaultValue], flattenX.dtype); var reshaped = reshape({ inputs: { x: res }, backend: backend, attrs: { shape: shape } }); backend.disposeIntermediateTensorInfo(flattenIndices); backend.disposeIntermediateTensorInfo(flattenX); backend.disposeIntermediateTensorInfo(res); backend.disposeIntermediateTensorInfo(defaultValue); return reshaped; } var scatterNdConfig = { kernelName: tf.ScatterNd, backendName: 'webgl', kernelFunc: scatterNd }; /** * @license * Copyright 2022 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ var SearchSortedProgram = /** @class */ (function () { function SearchSortedProgram(batchSize, numInputs, numValues, side) { this.variableNames = ['sortedSequence', 'values']; this.customUniforms = [{ name: 'numInputs', type: 'int' }]; this.outputShape = [batchSize, numValues]; var webGL2LoopHead = 'while (left < right) {'; // WebGL1 doesn't accept non constant loop conditions, so upper bound loop // iterations. var webGL1LoopHead = "for (int i = 0; i < ".concat(Math.ceil(Math.log2(numInputs + 1)), "; ++i) { if (left >= right) break;"); var loopHead = tf.env().getNumber('WEBGL_VERSION') === 2 ? webGL2LoopHead : webGL1LoopHead; // left corresponds to lower bound and right to upper bound. var boundComparator = side === 'left' ? '<' : '<='; this.userCode = "\n int findBound(int batch, float value) {\n int left = 0;\n int right = numInputs;\n int mid;\n ".concat(loopHead, "\n mid = (left + right) / 2;\n if (getSortedSequence(batch, mid) ").concat(boundComparator, " value) {\n left = mid + 1;\n } else {\n right = mid;\n }\n }\n return right;\n }\n\n void main() {\n ivec2 coords = getOutputCoords();\n int batch = coords[0];\n int valueIndex = coords[1];\n\n float value = getValues(batch, valueIndex);\n\n setOutput(float(findBound(batch, value)));\n }\n "); } return SearchSortedProgram; }()); /** * @license * Copyright 2022 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function searchSorted(args) { var inputs = args.inputs, backend = args.backend, attrs = args.attrs; var sortedSequence = inputs.sortedSequence, values = inputs.values; var side = attrs.side; var program = new SearchSortedProgram(sortedSequence.shape[0], sortedSequence.shape[1], values.shape[1], side); var customValues = [[sortedSequence.shape[1]]]; return backend.runWebGLProgram(program, [sortedSequence, values], 'int32', customValues); } var searchSortedConfig = { kernelName: tf.SearchSorted, backendName: 'webgl', kernelFunc: searchSorted, }; /** * @license * Copyright 2017 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ var SelectProgram = /** @class */ (function () { function SelectProgram(cRank, shape, rank) { this.variableNames = ['c', 'a', 'b']; this.outputShape = shape; var cCoords; var abCoords; if (rank > 4) { throw Error("Where for rank ".concat(rank, " is not yet supported")); } if (rank === 1) { abCoords = "resRC"; cCoords = "resRC"; } else { var currentCoords = ['resRC.x', 'resRC.y', 'resRC.z', 'resRC.w']; var cCoordVars = []; var abCoordVars = []; for (var i = 0; i < shape.length; i++) { abCoordVars.push("".concat(currentCoords[i])); if (i < cRank) { cCoordVars.push("".concat(currentCoords[i])); } } cCoords = cCoordVars.join(); abCoords = abCoordVars.join(); } var dtype = getCoordsDataType(rank); this.userCode = "\n void main() {\n ".concat(dtype, " resRC = getOutputCoords();\n float cVal = getC(").concat(cCoords, ");\n if (cVal >= 1.0) {\n setOutput(getA(").concat(abCoords, "));\n } else {\n setOutput(getB(").concat(abCoords, "));\n }\n }\n "); } return SelectProgram; }()); /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ function select(args) { var inputs = args.inputs, backend = args.backend; var condition = inputs.condition, t = inputs.t, e = inputs.e; var program = new SelectProgram(condition.shape.length, t.shape, t.shape.length); return backend.runWebGLProgram(program, [condition, t, e], tf.upcastType(t.dtype, e.dtype)); } var selectConfig = { kernelName: tf.Select, backendName: 'webgl', kernelFunc: select }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ var SELU = "\n // Stable and Attracting Fixed Point (0, 1) for Normalized Weights.\n // see: https://arxiv.org/abs/1706.02515\n float scaleAlpha = ".concat(tf.backend_util.SELU_SCALEALPHA, ";\n float scale = ").concat(tf.backend_util.SELU_SCALE, ";\n return (x >= 0.0) ? scale * x : scaleAlpha * (exp(x) - 1.0);\n"); var selu = unaryKernelFunc({ opSnippet: SELU }); var seluConfig = { kernelName: tf.Selu, backendName: 'webgl', kernelFunc: selu, }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ var SIGMOID = CHECK_NAN_SNIPPET_UNARY + "\n return 1.0 / (1.0 + exp(-1.0 * x));\n"; var SIGMOID_PACKED = "\n vec4 result = 1.0 / (1.0 + exp(-1.0 * x));\n bvec4 isNaN = isnan(x);\n\n result.r = isNaN.r ? x.r : result.r;\n result.g = isNaN.g ? x.g : result.g;\n result.b = isNaN.b ? x.b : result.b;\n result.a = isNaN.a ? x.a : result.a;\n\n return result;\n"; var sigmoid = unaryKernelFunc({ opSnippet: SIGMOID, packedOpSnippet: SIGMOID_PACKED, cpuKernelImpl: sigmoidImplCPU }); var sigmoidConfig = { kernelName: tf.Sigmoid, backendName: 'webgl', kernelFunc: sigmoid, }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ // Sign does not propagate NANs. 
var SIGN = "\n if (isnan(x)) { return 0.0; }\n return sign(x);\n"; var sign = unaryKernelFunc({ opSnippet: SIGN }); var signConfig = { kernelName: tf.Sign, backendName: 'webgl', kernelFunc: sign, }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ var SIN = CHECK_NAN_SNIPPET_UNARY + "\n return sin(x);\n"; var SIN_PACKED = "\n vec4 result = sin(x);\n bvec4 isNaN = isnan(x);\n ".concat(CHECK_NAN_SNIPPET_PACKED, "\n return result;\n"); var sin = unaryKernelFunc({ opSnippet: SIN, packedOpSnippet: SIN_PACKED }); var sinConfig = { kernelName: tf.Sin, backendName: 'webgl', kernelFunc: sin, }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ var SINH = "\n float e2x = exp(x);\n return (e2x - 1.0 / e2x) / 2.0;\n"; var sinh = unaryKernelFunc({ opSnippet: SINH }); var sinhConfig = { kernelName: tf.Sinh, backendName: 'webgl', kernelFunc: sinh, }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ var SOFTPLUS = "\n float epsilon = 1.1920928955078125e-7;\n float threshold = log(epsilon) + 2.0;\n\n bool too_large = x > -threshold;\n bool too_small = x < threshold;\n\n float result;\n float exp_x = exp(x);\n\n if (too_large){\n result = x;\n }\n else if (too_small){\n result = exp_x;\n }\n else{\n result = log(exp_x + 1.0);\n }\n return result;\n"; var softplus = unaryKernelFunc({ opSnippet: SOFTPLUS }); var softplusConfig = { kernelName: tf.Softplus, backendName: 'webgl', kernelFunc: softplus, }; var spaceToBatchND = function (args) { var inputs = args.inputs, backend = args.backend, attrs = args.attrs; var x = inputs.x; var blockShape = attrs.blockShape, paddings = attrs.paddings; tf.util.assert(x.shape.length <= 4, function () { return 'spaceToBatchND for rank > 4 with a WebGL backend not ' + 'implemented yet'; }); var prod = blockShape.reduce(function (a, b) { return a * b; }); var completePaddings = [[0, 0]]; completePaddings.push.apply(completePaddings, __spreadArray([], __read(paddings), false)); 
for (var i = 1 + blockShape.length; i < x.shape.length; ++i) { completePaddings.push([0, 0]); } var toDispose = []; var paddedX = padV2({ inputs: { x: x }, backend: backend, attrs: { paddings: completePaddings, constantValue: 0 } }); var reshapedPaddedShape = tf.backend_util.getReshaped(paddedX.shape, blockShape, prod, false); var permutedReshapedPaddedPermutation = tf.backend_util.getPermuted(reshapedPaddedShape.length, blockShape.length, false); var flattenShape = tf.backend_util.getReshapedPermuted(paddedX.shape, blockShape, prod, false); var reshapedPaddedX = reshape({ inputs: { x: paddedX }, backend: backend, attrs: { shape: reshapedPaddedShape } }); var paddedXT = transpose({ inputs: { x: reshapedPaddedX }, backend: backend, attrs: { perm: permutedReshapedPaddedPermutation } }); var result = reshape({ inputs: { x: paddedXT }, backend: backend, attrs: { shape: flattenShape } }); toDispose.push(paddedX); toDispose.push(reshapedPaddedX); toDispose.push(paddedXT); toDispose.forEach(function (t) { return backend.disposeIntermediateTensorInfo(t); }); return result; }; var spaceToBatchNDConfig = { kernelName: tf.SpaceToBatchND, backendName: 'webgl', kernelFunc: spaceToBatchND }; function sparseFillEmptyRows(args) { var inputs = args.inputs, backend = args.backend; var indices = inputs.indices, values = inputs.values, denseShape = inputs.denseShape, defaultValue = inputs.defaultValue; if (denseShape.shape.length !== 1) { throw new Error("Dense shape must be a vector, saw:\n ".concat(denseShape.shape)); } if (indices.shape.length !== 2) { throw new Error("Indices must be a matrix, saw:\n ".concat(indices.shape)); } if (values.shape.length !== 1) { throw new Error("Values must be a vector, saw:\n ".concat(values.shape)); } if (defaultValue.shape.length !== 0) { throw new Error("Default value must be a scalar, saw:\n ".concat(defaultValue.shape)); } var $indices = backend.readSync(indices.dataId); var $values = backend.readSync(values.dataId); var $denseShape = 
backend.readSync(denseShape.dataId); var $defaultValue = backend.readSync(defaultValue.dataId)[0]; var _a = __read(sparseFillEmptyRowsImplCPU($indices, indices.shape, indices.dtype, $values, values.dtype, $denseShape, $defaultValue), 5), outputIndices = _a[0], outputIndicesShape = _a[1], outputValues = _a[2], emptyRowIndicator = _a[3], reverseIndexMap = _a[4]; return [ backend.makeTensorInfo(outputIndicesShape, indices.dtype, outputIndices), backend.makeTensorInfo([outputIndicesShape[0]], values.dtype, outputValues), backend.makeTensorInfo([emptyRowIndicator.length], 'bool', new Uint8Array(emptyRowIndicator.map(function (value) { return Number(value); }))), backend.makeTensorInfo([reverseIndexMap.length], indices.dtype, new Int32Array(reverseIndexMap)), ]; } var sparseFillEmptyRowsConfig = { kernelName: tf.SparseFillEmptyRows, backendName: 'webgl', kernelFunc: sparseFillEmptyRows, }; function sparseReshape(args) { var inputs = args.inputs, backend = args.backend; var inputIndices = inputs.inputIndices, inputShape = inputs.inputShape, newShape = inputs.newShape; if (inputIndices.shape.length !== 2) { throw new Error("Input indices should be a matrix but received shape ".concat(inputIndices.shape)); } if (inputShape.shape.length !== 1) { throw new Error("Input shape should be a vector but received shape ".concat(inputShape.shape)); } if (newShape.shape.length !== 1) { throw new Error("Target shape should be a vector but received shape ".concat(newShape.shape)); } var $inputShape = Array.from(backend.readSync(inputShape.dataId)); var $inputIndices = backend.readSync(inputIndices.dataId); var targetShape = Array.from(backend.readSync(newShape.dataId)); var _a = __read(sparseReshapeImplCPU($inputIndices, inputIndices.shape, inputIndices.dtype, $inputShape, targetShape), 3), newIndices = _a[0], indicesShape = _a[1], outputShape = _a[2]; return [ backend.makeTensorInfo(indicesShape, inputIndices.dtype, newIndices), backend.makeTensorInfo([outputShape.length], 
newShape.dtype, new Int32Array(outputShape)), ]; } var sparseReshapeConfig = { kernelName: tf.SparseReshape, backendName: 'webgl', kernelFunc: sparseReshape, }; function sparseSegmentMean(args) { var inputs = args.inputs, backend = args.backend; var data = inputs.data, indices = inputs.indices, segmentIds = inputs.segmentIds; if (data.shape.length < 1) { throw new Error("Data should be at least 1 dimensional but received scalar"); } if (indices.shape.length !== 1) { throw new Error("Indices should be a vector but received shape\n ".concat(indices.shape)); } if (segmentIds.shape.length !== 1) { throw new Error("Segment ids should be a vector but received shape\n ".concat(segmentIds.shape)); } var $data = backend.readSync(data.dataId); var $indices = backend.readSync(indices.dataId); var $segmentIds = backend.readSync(segmentIds.dataId); var _a = __read(sparseSegmentReductionImplCPU($data, data.shape, data.dtype, $indices, $segmentIds, true), 2), outputData = _a[0], outputDataShape = _a[1]; return backend.makeTensorInfo(outputDataShape, data.dtype, outputData); } var sparseSegmentMeanConfig = { kernelName: tf.SparseSegmentMean, backendName: 'webgl', kernelFunc: sparseSegmentMean, }; function sparseSegmentSum(args) { var inputs = args.inputs, backend = args.backend; var data = inputs.data, indices = inputs.indices, segmentIds = inputs.segmentIds; if (data.shape.length < 1) { throw new Error("Data should be at least 1 dimensional but received scalar"); } if (indices.shape.length !== 1) { throw new Error("Indices should be a vector but received shape\n ".concat(indices.shape)); } if (segmentIds.shape.length !== 1) { throw new Error("Segment ids should be a vector but received shape\n ".concat(segmentIds.shape)); } var $data = backend.readSync(data.dataId); var $indices = backend.readSync(indices.dataId); var $segmentIds = backend.readSync(segmentIds.dataId); var _a = __read(sparseSegmentReductionImplCPU($data, data.shape, data.dtype, $indices, $segmentIds), 2), 
outputData = _a[0], outputDataShape = _a[1]; return backend.makeTensorInfo(outputDataShape, data.dtype, outputData); } var sparseSegmentSumConfig = { kernelName: tf.SparseSegmentSum, backendName: 'webgl', kernelFunc: sparseSegmentSum, }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function sparseToDense(args) { var inputs = args.inputs, backend = args.backend, attrs = args.attrs; var sparseIndices = inputs.sparseIndices, sparseValues = inputs.sparseValues, defaultValue = inputs.defaultValue; var outputShape = attrs.outputShape; var _a = tf.backend_util.calculateShapes(sparseValues, sparseIndices, outputShape), sliceRank = _a.sliceRank, numUpdates = _a.numUpdates, sliceSize = _a.sliceSize, strides = _a.strides, outputSize = _a.outputSize; var sumDupeIndices = false; if (sparseValues.dtype === 'string') { var indicesBuf = backend.bufferSync(sparseIndices); var updatesBuf = backend.bufferSync(sparseValues); var $defaultValue = tf.util.decodeString(backend.readSync(defaultValue.dataId)[0]); var outBuf = scatterImplCPU(indicesBuf, updatesBuf, outputShape, outputSize, sliceSize, numUpdates, sliceRank, strides, $defaultValue, sumDupeIndices); return backend.makeTensorInfo(outputShape, outBuf.dtype, outBuf.values); } var program = new ScatterProgram(numUpdates, sliceRank, sparseIndices.shape.length, sparseValues.shape.length, strides, 
[outputSize, 1], sumDupeIndices); var res = backend.runWebGLProgram(program, [sparseValues, sparseIndices, defaultValue], sparseValues.dtype); var reshaped = reshape({ inputs: { x: res }, backend: backend, attrs: { shape: outputShape } }); backend.disposeIntermediateTensorInfo(res); return reshaped; } var sparseToDenseConfig = { kernelName: tf.SparseToDense, backendName: 'webgl', kernelFunc: sparseToDense }; function splitV(args) { var inputs = args.inputs, backend = args.backend, attrs = args.attrs; var x = inputs.x; var numOrSizeSplits = attrs.numOrSizeSplits, axis = attrs.axis; var $axis = tf.util.parseAxisParam(axis, x.shape)[0]; var splitSizes = tf.backend_util.prepareSplitSize(x, numOrSizeSplits, $axis); var xRank = x.shape.length; var begin = new Array(xRank).fill(0); var size = x.shape.slice(); return splitSizes.map(function (s) { var sliceSize = __spreadArray([], __read(size), false); sliceSize[$axis] = s; var sliceT = slice({ inputs: { x: x }, backend: backend, attrs: { begin: begin, size: sliceSize } }); begin[$axis] += s; return sliceT; }); } var splitVConfig = { kernelName: tf.SplitV, backendName: 'webgl', kernelFunc: splitV }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ var SQRT = "return sqrt(x);"; var sqrt = unaryKernelFunc({ opSnippet: SQRT, packedOpSnippet: SQRT, cpuKernelImpl: sqrtImplCPU }); var sqrtConfig = { kernelName: tf.Sqrt, backendName: 'webgl', kernelFunc: sqrt }; /** * @license * Copyright 2019 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ var SQUARE = "return x * x;"; var square = unaryKernelFunc({ opSnippet: SQUARE }); var squareConfig = { kernelName: tf.Square, backendName: 'webgl', kernelFunc: square, }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ var SQUARED_DIFFERENCE = 'return (a - b) * (a - b);'; var squaredDifference = binaryKernelFunc({ opSnippet: SQUARED_DIFFERENCE, packedOpSnippet: SQUARED_DIFFERENCE }); var squaredDifferenceConfig = { kernelName: tf.SquaredDifference, backendName: 'webgl', kernelFunc: squaredDifference, }; /** * @license * Copyright 2023 Google LLC. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function staticRegexReplace(args) { var inputs = args.inputs, backend = args.backend, attrs = args.attrs; var x = inputs.x; if (x.dtype !== 'string') { throw new Error('Input must be of datatype string'); } var $x = backend.readSync(x.dataId); var stringInput = tf.backend_util.fromUint8ToStringArray($x); var output = staticRegexReplaceImplCPU(stringInput, 'string', attrs); return backend.makeTensorInfo(x.shape, 'string', output); } var staticRegexReplaceConfig = { kernelName: tf.StaticRegexReplace, backendName: 'webgl', kernelFunc: staticRegexReplace, }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function step(_a) { var inputs = _a.inputs, attrs = _a.attrs, backend = _a.backend; var x = inputs.x; var opSnippet = CHECK_NAN_SNIPPET$1 + "\n return x > 0.0 ? 1.0 : float(".concat(attrs.alpha, ");\n "); var program = new UnaryOpProgram(x.shape, opSnippet); return backend.runWebGLProgram(program, [x], x.dtype); } var stepConfig = { kernelName: tf.Step, backendName: 'webgl', kernelFunc: step, }; /** * @license * Copyright 2017 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ var StridedSliceProgram = /** @class */ (function () { function StridedSliceProgram(begin, strides, size) { this.variableNames = ['x']; this.outputShape = size; var rank = size.length; var inputDtype = getCoordsDataType(size.length); var dtype = getCoordsDataType(size.length); var newCoords = ''; if (rank === 1) { newCoords = 'coords * strides + begin'; } else { var outputAxis_1 = 0; newCoords = size.map(function (_, i) { outputAxis_1++; return size.length === 1 ? "coords * strides[".concat(i, "] + begin[").concat(i, "]") : "coords[".concat(outputAxis_1 - 1, "] * strides[").concat(i, "] + begin[").concat(i, "]"); }) .join(','); } this.userCode = "\n ".concat(inputDtype, " begin = ").concat(inputDtype, "(").concat(begin, ");\n ").concat(inputDtype, " strides = ").concat(inputDtype, "(").concat(strides, ");\n\n void main() {\n ").concat(dtype, " coords = getOutputCoords();\n setOutput(getX(").concat(newCoords, "));\n }\n "); } return StridedSliceProgram; }()); /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ function stridedSlice(args) { var inputs = args.inputs, backend = args.backend, attrs = args.attrs; var x = inputs.x; var begin = attrs.begin, end = attrs.end, strides = attrs.strides, beginMask = attrs.beginMask, endMask = attrs.endMask, ellipsisMask = attrs.ellipsisMask, newAxisMask = attrs.newAxisMask, shrinkAxisMask = attrs.shrinkAxisMask; var _a = tf.slice_util.sliceInfo(x.shape, begin, end, strides, beginMask, endMask, ellipsisMask, newAxisMask, shrinkAxisMask), finalShapeSparse = _a.finalShapeSparse, finalShape = _a.finalShape, isIdentity = _a.isIdentity, sliceDim0 = _a.sliceDim0, isSimpleSlice = _a.isSimpleSlice, $begin = _a.begin, $end = _a.end, $strides = _a.strides; var result; if (isIdentity) { // Optimization #1, slice is a no-op plus reshape result = reshape({ inputs: { x: x }, backend: backend, attrs: { shape: finalShape } }); } else if (sliceDim0 || isSimpleSlice) { // Optimization #2, slice is memory contiguous (only occurs in dim 0) tf.util.assert(x.shape.length >= 1, function () { return "Input must have rank at least 1, got: ".concat(x.shape.length); }); var size = tf.slice_util.computeOutShape($begin, $end, $strides); // To tolerate begin[0] > end[0] (a 0-output slice), we min(begin, end). 
var sliced = slice({ inputs: { x: x }, backend: backend, attrs: { begin: $begin, size: size } }); result = reshape({ inputs: { x: sliced }, backend: backend, attrs: { shape: finalShape } }); backend.disposeIntermediateTensorInfo(sliced); } else { var shouldExecuteOnCPU = backend.shouldExecuteOnCPU([x]); if (shouldExecuteOnCPU) { // tslint:disable-next-line: no-unnecessary-type-assertion var values = backend.readSync(x.dataId); // tslint:disable-next-line: no-unnecessary-type-assertion var xBuf = tf.buffer(x.shape, x.dtype, values); var resultValues = stridedSliceImplCPU(finalShapeSparse, xBuf, $strides, $begin); result = backend.makeTensorInfo(finalShape, x.dtype, resultValues.values); } else { var program = new StridedSliceProgram($begin, $strides, finalShapeSparse); result = backend.runWebGLProgram(program, [x], x.dtype); } } var resultReshaped = reshape({ inputs: { x: result }, backend: backend, attrs: { shape: finalShape } }); backend.disposeIntermediateTensorInfo(result); return resultReshaped; } var stridedSliceConfig = { kernelName: tf.StridedSlice, backendName: 'webgl', kernelFunc: stridedSlice }; function stringNGrams(args) { var inputs = args.inputs, backend = args.backend, attrs = args.attrs; var separator = attrs.separator, nGramWidths = attrs.nGramWidths, leftPad = attrs.leftPad, rightPad = attrs.rightPad, padWidth = attrs.padWidth, preserveShortSequences = attrs.preserveShortSequences; var data = inputs.data, dataSplits = inputs.dataSplits; var $data = backend.readSync(data.dataId); var $dataSplits = backend.readSync(dataSplits.dataId); var _a = __read(stringNGramsImplCPU($data, $dataSplits, separator, nGramWidths, leftPad, rightPad, padWidth, preserveShortSequences), 2), nGrams = _a[0], nGramsSplits = _a[1]; return [ backend.makeTensorInfo([nGrams.length], 'string', nGrams), backend.makeTensorInfo(dataSplits.shape, 'int32', nGramsSplits), ]; } var stringNGramsConfig = { kernelName: tf.StringNGrams, backendName: 'webgl', kernelFunc: stringNGrams, }; 
function stringSplit(args) { var inputs = args.inputs, backend = args.backend, attrs = args.attrs; var skipEmpty = attrs.skipEmpty; var input = inputs.input, delimiter = inputs.delimiter; if (input.dtype !== 'string') { throw new Error('Input must be of datatype string'); } if (input.shape.length !== 1) { throw new Error("Input must be a vector, got shape: ".concat(input.shape)); } if (delimiter.shape.length !== 0) { throw new Error("Delimiter must be a scalar, got shape: ".concat(delimiter.shape)); } var $input = backend.readSync(input.dataId); var $delimiter = backend.readSync(delimiter.dataId)[0]; var _a = __read(stringSplitImplCPU($input, $delimiter, skipEmpty), 3), indices = _a[0], values = _a[1], shape = _a[2]; var outputSize = values.length; return [ backend.makeTensorInfo([outputSize, 2], 'int32', indices), backend.makeTensorInfo([outputSize], 'string', values), backend.makeTensorInfo([2], 'int32', new Int32Array(shape)) ]; } var stringSplitConfig = { kernelName: tf.StringSplit, backendName: 'webgl', kernelFunc: stringSplit, }; /** * @license * Copyright 2021 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ function stringToHashBucketFast(args) { var inputs = args.inputs, backend = args.backend, attrs = args.attrs; var numBuckets = attrs.numBuckets; var input = inputs.input; if (input.dtype !== 'string') { throw new Error('Input must be of datatype string'); } if (numBuckets <= 0) { throw new Error("Number of buckets must be at least 1"); } var $input = backend.readSync(input.dataId); var output = stringToHashBucketFastImplCPU($input, numBuckets); return backend.makeTensorInfo(input.shape, 'int32', output); } var stringToHashBucketFastConfig = { kernelName: tf.StringToHashBucketFast, backendName: 'webgl', kernelFunc: stringToHashBucketFast, }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ var TAN = "return tan(x);"; var tan = unaryKernelFunc({ opSnippet: TAN }); var tanConfig = { kernelName: tf.Tan, backendName: 'webgl', kernelFunc: tan, }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ var TANH = "\n float e2x = exp(-2.0 * abs(x));\n return sign(x) * (1.0 - e2x) / (1.0 + e2x);\n"; var tanh = unaryKernelFunc({ opSnippet: TANH }); var tanhConfig = { kernelName: tf.Tanh, backendName: 'webgl', kernelFunc: tanh, }; /** * @license * Copyright 2022 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ function tensorScatterUpdate(args) { var inputs = args.inputs, backend = args.backend; args.attrs; var tensor = inputs.tensor, indices = inputs.indices, updates = inputs.updates; var _b = tf.backend_util.calculateShapes(updates, indices, tensor.shape), sliceRank = _b.sliceRank, numUpdates = _b.numUpdates, sliceSize = _b.sliceSize, strides = _b.strides, outputSize = _b.outputSize; var flattenShape = [outputSize / sliceSize, sliceSize]; if (outputSize === 0) { return backend.makeTensorInfo(tensor.shape, indices.dtype); } var flattenIndices = reshape({ inputs: { x: indices }, backend: backend, attrs: { shape: [numUpdates, sliceRank] } }); var flattenX = reshape({ inputs: { x: updates }, backend: backend, attrs: { shape: [numUpdates, sliceSize] } }); var flattenTensor = reshape({ inputs: { x: tensor }, backend: backend, attrs: { shape: flattenShape } }); var program = new ScatterProgram(numUpdates, sliceRank, flattenIndices.shape.length, flattenX.shape.length, strides, flattenShape, false, true); var res = backend.runWebGLProgram(program, [flattenX, flattenIndices, flattenTensor], flattenTensor.dtype); var reshaped = reshape({ inputs: { x: res }, backend: backend, attrs: { shape: tensor.shape } }); backend.disposeIntermediateTensorInfo(flattenIndices); backend.disposeIntermediateTensorInfo(flattenX); backend.disposeIntermediateTensorInfo(flattenTensor); backend.disposeIntermediateTensorInfo(res); return reshaped; } var tensorScatterUpdateConfig = { kernelName: tf.TensorScatterUpdate, backendName: 'webgl', kernelFunc: tensorScatterUpdate }; /** * @license * Copyright 2017 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ var TileProgram = /** @class */ (function () { function TileProgram(aShape, reps) { this.variableNames = ['A']; var outputShape = new Array(aShape.length); for (var i = 0; i < outputShape.length; i++) { outputShape[i] = aShape[i] * reps[i]; } this.outputShape = outputShape; this.rank = outputShape.length; var dtype = getCoordsDataType(this.rank); var sourceCoords = getSourceCoords(aShape); this.userCode = "\n void main() {\n ".concat(dtype, " resRC = getOutputCoords();\n setOutput(getA(").concat(sourceCoords, "));\n }\n "); } return TileProgram; }()); function getSourceCoords(aShape) { var rank = aShape.length; if (rank > 5) { throw Error("Tile for rank ".concat(rank, " is not yet supported")); } if (rank === 1) { return "imod(resRC, ".concat(aShape[0], ")"); } var currentCoords = ['resRC.x', 'resRC.y', 'resRC.z', 'resRC.w', 'resRC.u']; var sourceCoords = []; for (var i = 0; i < aShape.length; i++) { sourceCoords.push("imod(".concat(currentCoords[i], ", ").concat(aShape[i], ")")); } return sourceCoords.join(); } /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function tile(params) { var inputs = params.inputs, backend = params.backend, attrs = params.attrs; var x = inputs.x; var reps = attrs.reps; // tile gpu program cannot handle rank > 5 case. if (x.dtype === 'string' || x.shape.length > 5) { // Even thought string tensor is always on CPU, just to be consistent on how // to access tensor data. var data = backend.readSync(x.dataId); var value = x.dtype === 'string' ? data.map(function (d) { return tf.util.decodeString(d); }) : data; var buf = tf.buffer(x.shape, x.dtype, value); var outBuf = tileImplCPU(buf, reps); return backend.makeTensorInfo(outBuf.shape, outBuf.dtype, outBuf.values); } var program = new TileProgram(x.shape, reps); var output = backend.runWebGLProgram(program, [x], x.dtype); return output; } var tileConfig = { kernelName: tf.Tile, backendName: 'webgl', kernelFunc: tile, }; // Based on Algorithm 2 of Bitonic Top K, ref: // https://anilshanbhag.in/static/papers/gputopk_sigmod18.pdf // The original algorithm is based on computing the top K only, however // since for TFJS we require the indices of the top K values as well then the // algorithm found here is a bit modified. Rather than producing the values // at each step, the indices containing the top K are generated instead. // The output values are not generated to reduce the number of outputs in the // GPU, the values can easily be retrieved from the indices using a gather // op. 
var SwapProgram = /** @class */ (function () { /** * @param shape desired output shape (can be larger than input shape, output * will be padded with -Infinity) */ function SwapProgram(shape) { this.variableNames = ['x', 'indices']; // |n| Size of the original input of TopK. // |firstPass|indicates if this is the first time swap is being used which // means no indices input containing the top K is present yet. // |inc| Swaps pairs of indices (0, inc), (1, inc + 1), (2, inc + 2) ... this.customUniforms = [ { name: 'n', type: 'int' }, { name: 'firstPass', type: 'int' }, { name: 'negativeInf', type: 'float' }, { name: 'dir', type: 'int' }, { name: 'inc', type: 'int' } ]; this.outputShape = shape; this.userCode = "\n void main() {\n ivec2 coords = getOutputCoords();\n int batch = coords[0];\n int elemIdx = coords[1];\n\n // We compare elements pair-wise within a group of size 2 * inc.\n // The comparing rule for each group alternates between ascending\n // and descending. Within each group, we compare each pair at\n // positions i and i+inc. To decide whether an element at position i\n // is x0 or x1, we mod it by 2 * inc, if the result is smaller than\n // inc, it is in the first half of the group, we denote it as x0,\n // otherwise we denote it as x1.\n // For example, as shown in the Bitonic top K paper referenced above,\n // Figure5(a) shows that element[1] is in the\n // second half of the group when group size is 2, but it is in the\n // first half of the group when group size is 4.\n\n bool isFirstInPair = imod(elemIdx, 2 * inc) < inc;\n int i = isFirstInPair ? elemIdx : elemIdx - inc;\n\n int i0 = firstPass == 1 ? i : int(getIndices(batch, i));\n int i1 = firstPass == 1 ? i + inc : int(getIndices(batch, i + inc));\n float x0 = i0 < n ? getX(batch, i0) : negativeInf;\n float x1 = i1 < n ? 
getX(batch, i1) : negativeInf;\n\n // Denotes which direction indices are in (ascending or descending).\n bool reverse = imod(elemIdx, 2 * dir) >= dir;\n bool isGreater = x0 > x1 || (x0 == x1 && i1 > i0);\n if (reverse == isGreater) { // Elements in opposite order of direction\n int iTemp = i0;\n i0 = i1;\n i1 = iTemp;\n }\n if (isFirstInPair) {\n setOutput(float(i0));\n } else {\n setOutput(float(i1));\n }\n }\n "; } return SwapProgram; }()); var MergeProgram = /** @class */ (function () { /** * @param shape desired output shape (must be half of the input size) */ function MergeProgram(shape) { this.variableNames = ['x', 'indices']; // |n| Size of the original input of TopK // |firstPass| indicates if this is the first time swap is being used which // means no indices input containing the top K is present yet. // |k| Top k elements desired this.customUniforms = [ { name: 'n', type: 'int' }, { name: 'firstPass', type: 'int' }, { name: 'k', type: 'int' } ]; this.outputShape = shape; this.userCode = "\n void main() {\n // Takes max of indices (0, k), (1, k + 1), (2, k + 2) ...\n ivec2 coords = getOutputCoords();\n int batch = coords[0];\n int elemIdx = coords[1];\n\n // The output size is half of the previous size.\n // If the previous sequence is | | | | _ _ _ _ | | | | _ _ _ _ (k=4),\n // we only need to output the indices at positions |, the indices at\n // positions _ can be thrown away, see Figure5(b) After Phase 2\n // (Merge phase) in the Bitonic Top K paper referenced above.\n // For example, the paper shows we only need to output the orange bars.\n // The output sequence should look like this | | | | | | | |.\n // Because the sequence is halved, to map the output index back\n // to the previous sequence to find the corresponding value,\n // we need to double the index. When we double the index,\n // we basically interpolate a position, so 2i looks like\n // | _ | _ | _ | _ | _ | _ | _. 
We move the | to the first k position\n // of each 2k positions by - elemIdx % k. E.g. for output at\n // index 4,5,6,7, we want to get the corresponding element at\n // original index 8,9,10,11, for output at index 8,9,10,11,\n // we want to get the corresponding element at original index\n // 16,17,18,19, so on and so forth.\n\n int i = elemIdx < k ? elemIdx : (elemIdx * 2 - imod(elemIdx, k));\n int i0 = firstPass == 1 ? i : int(getIndices(batch, i));\n int i1 = firstPass == 1 ? i + k : int(getIndices(batch, i + k));\n\n float x0 = getX(batch, i0);\n float x1 = i1 < n ? getX(batch, i1) : x0;\n\n setOutput(x0 >= x1 ? float(i0) : float(i1));\n }\n "; } return MergeProgram; }()); function disposeIntermediateTensorInfoOrNull(backend, tensorInfo) { if (tensorInfo !== null) { backend.disposeIntermediateTensorInfo(tensorInfo); } } function roundUpToPow2(num) { var pow2 = 1; while (pow2 < num) { pow2 *= 2; } return pow2; } // Based on Algorithm 2 of Bitonic Top K, ref: // https://anilshanbhag.in/static/papers/gputopk_sigmod18.pdf function topK(args) { var inputs = args.inputs, backend = args.backend, attrs = args.attrs; var x = inputs.x; var k = attrs.k, sorted = attrs.sorted; // Empirically determined constant used to determine last dim threshold for // handing off execution to the CPU. var TOPK_LAST_DIM_CPU_HANDOFF_SIZE_THRESHOLD = tf.env().getNumber('TOPK_LAST_DIM_CPU_HANDOFF_SIZE_THRESHOLD'); // Empirically determined constant used to determine k threshold for handing // off execution to the CPU. 
var TOPK_K_CPU_HANDOFF_THRESHOLD = tf.env().getNumber('TOPK_K_CPU_HANDOFF_THRESHOLD'); var xShape = x.shape; var lastDim = xShape[xShape.length - 1]; if (backend.shouldExecuteOnCPU([x]) || lastDim < TOPK_LAST_DIM_CPU_HANDOFF_SIZE_THRESHOLD || k > TOPK_K_CPU_HANDOFF_THRESHOLD) { var xVals = backend.readSync(x.dataId); var _a = __read(topKImplCPU(xVals, xShape, x.dtype, k, sorted), 2), allTopKVals = _a[0], allTopKIndices = _a[1]; return [ backend.makeTensorInfo(allTopKVals.shape, allTopKVals.dtype, allTopKVals.values), backend.makeTensorInfo(allTopKIndices.shape, allTopKIndices.dtype, allTopKIndices.values) ]; } if (k === 0) { xShape[xShape.length - 1] = 0; return [ backend.makeTensorInfo(xShape, x.dtype, []), backend.makeTensorInfo(xShape, 'int32', []) ]; } if (lastDim === 1 /* firstPass */) { return [ x, fill({ attrs: { shape: xShape, dtype: 'int32', value: 0 }, backend: backend }) ]; } // Eagerly unpack x input since it is passed in to all the shaders which // require unpacked inputs. var xtexData = backend.texData.get(x.dataId); var xIsPacked = xtexData !== null && xtexData.isPacked; var xUnPacked = xIsPacked ? backend.unpackTensor(x) : x; // Reshape into a 2d tensor [batch, lastDim] and compute topk along lastDim. var xSize = tf.util.sizeFromShape(xShape); var batch = xSize / lastDim; var x2D = reshape({ inputs: { x: xUnPacked }, attrs: { shape: [batch, lastDim] }, backend: backend }); if (xIsPacked) { disposeIntermediateTensorInfoOrNull(backend, xUnPacked); } var kPow2 = roundUpToPow2(k); var lastDimPow2 = roundUpToPow2(lastDim); // Only the indices containing the top K are kept at every step to reduce // number of outputs in the GPU algorithms, so once the final set of indices // is computed then gather is used to grab the corresponding values // from the original input. 
var indices = null; // GPU algorithm always takes in an indices input but this input is not used // on the first run of a GPU algorithm, therefore if indices is null we simply // pass in x2D instead of it but the value will not actually be used var getInputs = function () { return indices === null ? [x2D, x2D] : [x2D, indices]; }; var runSwap = function (dir, inc, shape) { var inputs = getInputs(); var program = new SwapProgram(shape); var fistPass = indices === null ? 1 : 0; var customValues = [[lastDim], [fistPass], [Number.NEGATIVE_INFINITY], [dir], [inc]]; var prevIndices = indices; indices = backend.runWebGLProgram(program, inputs, 'int32', customValues); disposeIntermediateTensorInfoOrNull(backend, prevIndices); }; // Step 1: local sort for (var len = 1; len < kPow2; len *= 2) { var dir = len * 2; for (var inc = len; inc >= 1; inc /= 2) { runSwap(dir, inc, [batch, lastDimPow2]); } } // Step 2: merge for (var indicesSize = lastDimPow2; indicesSize > kPow2; indicesSize /= 2) { var inputs_1 = getInputs(); var mergeProgram = new MergeProgram([batch, indicesSize / 2]); var firstPass = indices === null ? 
1 : 0; var customValues = [[lastDim], [firstPass], [kPow2]]; var prevIndices_1 = indices; indices = backend.runWebGLProgram(mergeProgram, inputs_1, 'int32', customValues); disposeIntermediateTensorInfoOrNull(backend, prevIndices_1); // Step 3: rebuild var len = kPow2 / 2; var dir = len * 2; for (var inc = len; inc >= 1; inc /= 2) { runSwap(dir, inc, indices.shape); } } // Keep only the requested top K results instead of kPow2 var prevIndices = indices; indices = slice({ inputs: { x: indices }, backend: backend, attrs: { begin: 0, size: [batch, k] } }); disposeIntermediateTensorInfoOrNull(backend, prevIndices); // Gather values on last dimension var values = gatherV2({ inputs: { x: x2D, indices: indices }, backend: backend, attrs: { axis: 1, batchDims: 1 } }); disposeIntermediateTensorInfoOrNull(backend, x2D); // Reshape back to the original input shape, except that the last // dimension is k. var newShape = xShape.slice(0, -1); newShape.push(k); prevIndices = indices; indices = reshape({ inputs: { x: indices }, attrs: { shape: newShape }, backend: backend }); disposeIntermediateTensorInfoOrNull(backend, prevIndices); var prevValues = values; values = reshape({ inputs: { x: values }, attrs: { shape: newShape }, backend: backend }); disposeIntermediateTensorInfoOrNull(backend, prevValues); return [values, indices]; } var topKConfig = { kernelName: tf.TopK, backendName: 'webgl', kernelFunc: topK }; /** * @license * Copyright 2021 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
* See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
var TransformProgram = /** @class */ (function () {
    /**
     * GPGPU program applying a per-batch projective transform to `Image`.
     *
     * For output pixel (x, y) of batch b, with the 8 coefficients
     * [a1, a2, a3, b1, b2, b3, c1, c2] read from `Transforms` row b:
     *   projection = c1 * x + c2 * y + 1
     *   inX = (a1 * x + a2 * y + a3) / projection
     *   inY = (b1 * x + b2 * y + b3) / projection
     * If projection == 0 the pixel is `fillValue`. Otherwise the input is
     * sampled at (mapCoord(inY), mapCoord(inX)); reads outside the image
     * return `fillValue`.
     *
     * @param imageHeight input height, baked into the shader source.
     * @param imageWidth input width, baked into the shader source.
     * @param interpolation 'nearest' rounds to the closest pixel; any other
     *     value blends the four surrounding pixels (bilinear).
     * @param fillMode 'constant' | 'reflect' | 'wrap' | 'nearest'; unknown
     *     values behave like 'constant'.
     * @param fillValue spliced verbatim into the GLSL as `float(fillValue)`.
     * @param outShape output shape; coords are read as [batch, y, x, channel].
     */
    function TransformProgram(imageHeight, imageWidth, interpolation, fillMode, fillValue, outShape) {
        this.variableNames = ['Image', 'Transforms'];
        this.outputShape = outShape;
        // 1 = nearest-neighbour sampling, 2 = bilinear (compared in main()).
        var interpolationModeId = interpolation === 'nearest' ? 1 : 2;
        // Integer mode ids compared inside the shader's mapCoord():
        // 1 = constant (coordinate passed through), 2 = reflect, 3 = wrap,
        // 4 = nearest (clamp to the edge).
        var fillModeId;
        switch (fillMode) {
            case 'constant':
                fillModeId = 1;
                break;
            case 'reflect':
                fillModeId = 2;
                break;
            case 'wrap':
                fillModeId = 3;
                break;
            case 'nearest':
                fillModeId = 4;
                break;
            default:
                fillModeId = 1;
                break;
        }
        // Shader source. Image dimensions, mode ids and fillValue are
        // interpolated as literals, so they are fixed per program instance.
        this.userCode = "\n float mapCoord(float outCoord, float len) {\n float inCoord = outCoord;\n if("
            .concat(fillModeId, " == 2) {\n if (inCoord < 0.0) {\n if (len <= 1.0) {\n inCoord = 0.0;\n } else {\n float sz2 = 2.0 * len;\n if (inCoord < sz2) {\n inCoord = sz2 * float(int(float(-inCoord / sz2))) +\n inCoord;\n }\n inCoord = inCoord < -len ? inCoord + sz2 : -inCoord - 1.0;\n }\n } else if (inCoord > len - 1.0) {\n if (len <= 1.0) {\n inCoord = 0.0;\n } else {\n float sz2 = 2.0 * len;\n inCoord -= sz2 * float(int(float(inCoord / sz2)));\n if (inCoord >= len) {\n inCoord = sz2 - inCoord - 1.0;\n }\n }\n }\n return clamp(inCoord, 0.0, len - 1.0);\n } else if (")
            .concat(fillModeId, " == 3) {\n if (inCoord < 0.0) {\n if (len <= 1.0) {\n inCoord = 0.0;\n } else {\n float sz = len - 1.0;\n inCoord += len * (float(int(float(-inCoord / sz))) + 1.0);\n }\n } else if (inCoord > len - 1.0) {\n if (len <= 1.0) {\n inCoord = 0.0;\n } else {\n float sz = len - 1.0;\n inCoord -= len * float(int(float(inCoord / sz)));\n }\n }\n return clamp(inCoord, 0.0, len - 1.0);\n } else if (")
            .concat(fillModeId, " == 4) {\n return clamp(outCoord, 0.0, len - 1.0);\n } else {\n return outCoord;\n }\n }\n\n float readWithFillValue(int batch, int coordY, int coordX,\n int channel) {\n float outputValue;\n if (0 <= coordY && coordY < ")
            .concat(imageHeight, " && 0 <= coordX && coordX < ")
            .concat(imageWidth, ") {\n outputValue = getImage(batch, coordY, coordX, channel);\n } else {\n outputValue = float(")
            .concat(fillValue, ");\n }\n return outputValue;\n }\n\n void main() {\n ivec4 coords = getOutputCoords();\n float outputValue;\n int batch = coords[0];\n int x = coords[2];\n int y = coords[1];\n int channel = coords[3];\n float xf = float(x);\n float yf = float(y);\n float a1 = getTransforms(batch, 0);\n float a2 = getTransforms(batch, 1);\n float a3 = getTransforms(batch, 2);\n float b1 = getTransforms(batch, 3);\n float b2 = getTransforms(batch, 4);\n float b3 = getTransforms(batch, 5);\n float c1 = getTransforms(batch, 6);\n float c2 = getTransforms(batch, 7);\n float projection = c1 * xf + c2 * yf + 1.0;\n if (projection == 0.0) {\n outputValue = float(")
            .concat(fillValue, ");\n } else {\n float inX = (a1 * xf + a2 * yf + a3) / projection;\n float inY = (b1 * xf + b2 * yf + b3) / projection;\n float mapX = mapCoord(inX, float(")
            .concat(imageWidth, "));\n float mapY = mapCoord(inY, float(")
            .concat(imageHeight, "));\n\n if (")
            .concat(interpolationModeId, " == 1) {\n int coordY = int(round(mapY));\n int coordX = int(round(mapX));\n outputValue = readWithFillValue(batch, coordY, coordX,\n channel);\n } else {\n float yFloor = floor(mapY);\n float xFloor = floor(mapX);\n float yCeil = yFloor + 1.0;\n float xCeil = xFloor + 1.0;\n float valueYFloor = (xCeil - mapX) *\n readWithFillValue(batch, int(yFloor), int(xFloor), channel) +\n (mapX - xFloor) *\n readWithFillValue(batch, int(yFloor), int(xCeil), channel);\n float valueYCeil = (xCeil - mapX) *\n readWithFillValue(batch, int(yCeil), int(xFloor), channel) +\n (mapX - xFloor) *\n readWithFillValue(batch, int(yCeil), int(xCeil), channel);\n outputValue = (yCeil - mapY) * valueYFloor +\n (mapY - yFloor) * valueYCeil;\n }\n }\n setOutput(outputValue);\n }\n ");
    }
    return TransformProgram;
}());
/**
 * WebGL kernel for Transform (image projective transform).
 *
 * `inputs.image` is destructured as [batch, imageHeight, imageWidth,
 * numChannels]. The output spatial size is `attrs.outputShape` ([height,
 * width]) when given, otherwise the input's height and width. The result is
 * always float32.
 */
function transform(args) {
    var inputs = args.inputs, backend = args.backend, attrs = args.attrs;
    var image = inputs.image, transforms = inputs.transforms;
    var interpolation = attrs.interpolation, fillMode = attrs.fillMode, fillValue = attrs.fillValue, outputShape = attrs.outputShape;
    var _a = __read(image.shape, 4), batch = _a[0], imageHeight = _a[1], imageWidth = _a[2], numChannels = _a[3];
    // `!= null` deliberately treats both null and undefined as "not provided".
    var _b = __read(outputShape != null ? outputShape : [imageHeight, imageWidth], 2), outHeight = _b[0], outWidth = _b[1];
    var outShape = [batch, outHeight, outWidth, numChannels];
    var program = new TransformProgram(imageHeight, imageWidth, interpolation, fillMode, fillValue, outShape);
    return backend.runWebGLProgram(program, [image, transforms], 'float32');
}
var transformConfig = {
    kernelName: tf.Transform,
    backendName: 'webgl',
    kernelFunc: transform
};
/**
 * @license
 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the License);
 * you may not use this file except in compliance with the License.
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an AS IS BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function unique(args) { var inputs = args.inputs, attrs = args.attrs, backend = args.backend; var axis = attrs.axis; var x = inputs.x; assertNotComplex(x, 'unique'); // For now, always forward calculation to the CPU backend. console.warn('WARNING: ', 'UI might be locked temporarily as data is being downloaded'); var values = backend.readSync(x.dataId); var _a = uniqueImplCPU(values, axis, x.shape, x.dtype), outputValues = _a.outputValues, outputShape = _a.outputShape, indices = _a.indices; return [ backend.makeTensorInfo(outputShape, x.dtype, outputValues), backend.makeTensorInfo([indices.length], 'int32', indices), ]; } var uniqueConfig = { kernelName: tf.Unique, backendName: 'webgl', kernelFunc: unique, }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ function unpack(args) { var inputs = args.inputs, backend = args.backend, attrs = args.attrs; var value = inputs.value; var axis = attrs.axis; if (axis < 0) { axis += value.shape.length; } var x = value; var xRank = x.shape.length; var num = value.shape[axis]; var outShape = new Array(xRank - 1); var outIndex = 0; for (var i = 0; i < xRank; i++) { if (i !== axis) { outShape[outIndex++] = x.shape[i]; } } var toDispose = []; var begin = new Array(xRank).fill(0); var size = x.shape.slice(); size[axis] = 1; var res = new Array(num); for (var i = 0; i < res.length; i++) { begin[axis] = i; var sliced = slice({ inputs: { x: x }, backend: backend, attrs: { begin: begin, size: size } }); var reshaped = reshape({ inputs: { x: sliced }, backend: backend, attrs: { shape: outShape } }); res[i] = reshaped; toDispose.push(sliced); } toDispose.forEach(function (t) { return backend.disposeIntermediateTensorInfo(t); }); return res; } var unpackConfig = { kernelName: tf.Unpack, backendName: 'webgl', kernelFunc: unpack }; /** * @license * Copyright 2018 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* =============================================================================
 */
var SegmentOpProgram = /** @class */ (function () {
    /**
     * GPGPU program performing one pass of a windowed segment sum.
     *
     * Input `x` is read as [batch, inSize]. The output shape is
     * [batchSize, numSegments * ceil(inSize / windowSize)]. Output column
     * outIdx sums the entries of window floor(outIdx / numSegments) (the
     * windowSize inputs starting at inOffset) whose segment id equals
     * outIdx mod numSegments.
     *
     * @param segOpInfo {windowSize, batchSize, inSize, numSegments}.
     * @param segOpType not referenced by this constructor.
     */
    function SegmentOpProgram(segOpInfo, segOpType) {
        this.variableNames = ['x', 'segmentIds'];
        var windowSize = segOpInfo.windowSize;
        var batchSize = segOpInfo.batchSize;
        var inSize = segOpInfo.inSize;
        var numSegments = segOpInfo.numSegments;
        var outSize = numSegments * Math.ceil(inSize / windowSize);
        this.outputShape = [batchSize, outSize];
        var initializationValue = '0.0';
        var returnValue = "sumValue";
        // Values are consumed 4 at a time with dot(values, segFilter); the
        // remaining windowSize % 4 values are handled after the loop.
        var windowSizeNearestVec4 = Math.floor(windowSize / 4) * 4;
        var windowSizeVec4Remainder = windowSize % 4;
        var updateSnippet = "\n sumValue += dot(values, segFilter);\n ";
        // Bounds checks are only emitted when the last window is partial
        // (inSize not a multiple of windowSize). Out-of-range values read as
        // initializationValue and out-of-range segment ids as -1.0.
        var checkValueOutOfBounds = '';
        if (inSize % windowSize > 0) {
            checkValueOutOfBounds = "\n if (inIdx < 0 || inIdx >= ".concat(inSize, ") {\n return initializationValue;\n }\n ");
        }
        var checkSegmentIdOutOfBounds = '';
        if (inSize % windowSize > 0) {
            checkSegmentIdOutOfBounds = "\n if (inIdx < 0 || inIdx >= ".concat(inSize, ") {\n return -1.0;\n }\n ");
        }
        // The remainder branches are selected by JS booleans spliced into the
        // GLSL as the literals `true` / `false`. The remainder-1 branch also
        // declares an `inIdxSeg` that the shader never reads.
        this.userCode = "\n const float initializationValue = "
            .concat(initializationValue, ";\n\n float getValue(int batch, int inIdx) {\n ")
            .concat(checkValueOutOfBounds, "\n return getX(batch, inIdx);\n }\n\n float getSegmentIdAtIndex(int inIdx) {\n ")
            .concat(checkSegmentIdOutOfBounds, "\n return getSegmentIds(inIdx);\n }\n\n void main() {\n ivec2 coords = getOutputCoords();\n int batch = coords[0];\n int outIdx = coords[1];\n int inOffset = int(floor(float(outIdx) / float(\n ")
            .concat(numSegments, ")) * float(")
            .concat(windowSize, "));\n int currentSeg = int(mod(float(outIdx), float(")
            .concat(numSegments, ")));\n\n float sumValue = 0.0;\n\n for (int i = 0; i < ")
            .concat(windowSizeNearestVec4, "; i += 4) {\n int inIdx = inOffset + i;\n vec4 values = vec4(\n getValue(batch, inIdx),\n getValue(batch, inIdx + 1),\n getValue(batch, inIdx + 2),\n getValue(batch, inIdx + 3)\n );\n\n vec4 segFilter = vec4(\n int(getSegmentIdAtIndex(inIdx)) == currentSeg ? 1 : 0,\n int(getSegmentIdAtIndex(inIdx + 1)) == currentSeg ? 1 : 0,\n int(getSegmentIdAtIndex(inIdx + 2)) == currentSeg ? 1 : 0,\n int(getSegmentIdAtIndex(inIdx + 3)) == currentSeg ? 1 : 0\n );\n\n ")
            .concat(updateSnippet, "\n }\n\n int inIdx = inOffset + ")
            .concat(windowSizeNearestVec4, ";\n if (")
            .concat(windowSizeVec4Remainder === 1, ") {\n vec4 values = vec4(\n getValue(batch, inIdx),\n initializationValue,\n initializationValue,\n initializationValue\n );\n\n int inIdxSeg = int(getSegmentIdAtIndex(inIdx));\n\n vec4 segFilter = vec4(\n int(getSegmentIdAtIndex(inIdx)) == currentSeg ? 1 : 0,\n 0,\n 0,\n 0\n );\n\n ")
            .concat(updateSnippet, "\n } else if (")
            .concat(windowSizeVec4Remainder === 2, ") {\n vec4 values = vec4(\n getValue(batch, inIdx),\n getValue(batch, inIdx + 1),\n initializationValue,\n initializationValue\n );\n\n vec4 segFilter = vec4(\n int(getSegmentIdAtIndex(inIdx)) == currentSeg ? 1 : 0,\n int(getSegmentIdAtIndex(inIdx + 1)) == currentSeg ? 1 : 0,\n 0,\n 0\n );\n\n ")
            .concat(updateSnippet, "\n } else if (")
            .concat(windowSizeVec4Remainder === 3, ") {\n vec4 values = vec4(\n getValue(batch, inIdx),\n getValue(batch, inIdx + 1),\n getValue(batch, inIdx + 2),\n initializationValue\n );\n\n vec4 segFilter = vec4(\n int(getSegmentIdAtIndex(inIdx)) == currentSeg ? 1 : 0,\n int(getSegmentIdAtIndex(inIdx + 1)) == currentSeg ? 1 : 0,\n int(getSegmentIdAtIndex(inIdx + 2)) == currentSeg ? 1 : 0,\n 0\n );\n\n ")
            .concat(updateSnippet, "\n }\n setOutput(")
            .concat(returnValue, ");\n }\n ");
    }
    return SegmentOpProgram;
}());
/**
 * @license
 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function unsortedSegmentSum(args) { var inputs = args.inputs, backend = args.backend, attrs = args.attrs; var x = inputs.x, segmentIds = inputs.segmentIds; var numSegments = attrs.numSegments; var xRank = x.shape.length; var toDispose = []; var axis = 0; var permutation = tf.backend_util.getAxesPermutation([axis], xRank); var permutedX = x; if (permutation != null) { permutedX = transpose({ inputs: { x: x }, backend: backend, attrs: { perm: permutation } }); toDispose.push(permutedX); axis = tf.backend_util.getInnerMostAxes(1, xRank)[0]; } var outShape = tf.backend_util.segment_util.computeOutShape(permutedX.shape, axis, numSegments); var inSize = tf.util.sizeFromShape([permutedX.shape[axis]]); var a2D = reshape({ inputs: { x: permutedX }, backend: backend, attrs: { shape: [-1, inSize] } }); toDispose.push(a2D); var outputDType = tf.sumOutType(x.dtype); var segOpCompute = function (x, segOpType, segmentIds, dtype, numSegments) { var batchSize = x.shape[0]; var inSize = x.shape[1]; var windowSize = tf.backend_util.segment_util.segOpComputeOptimalWindowSize(inSize, numSegments); var segOpInfo = { windowSize: windowSize, inSize: inSize, batchSize: batchSize, numSegments: numSegments }; var program = new SegmentOpProgram(segOpInfo, segOpType); var output = backend.compileAndRun(program, [x, segmentIds], dtype); toDispose.push(output); // No need to run another GPGPU program. 
if (output.shape[1] === numSegments) { return output; } var rangeInfo = range({ backend: backend, attrs: { start: 0, stop: numSegments, step: 1, dtype: 'float32' } }); var tileInfo = tile({ inputs: { x: rangeInfo }, backend: backend, attrs: { reps: [inSize / windowSize] } }); toDispose.push(rangeInfo); toDispose.push(tileInfo); var result = segOpCompute(output, segOpType, tileInfo, dtype, numSegments); return result; }; var segOpResult = segOpCompute(a2D, 'unsortedSegmentSum', segmentIds, outputDType, numSegments); var reshaped = reshape({ inputs: { x: segOpResult }, backend: backend, attrs: { shape: outShape } }); var result = reshaped; if (permutation != null) { toDispose.push(reshaped); var perm = tf.backend_util.getUndoAxesPermutation(permutation); result = transpose({ inputs: { x: result }, backend: backend, attrs: { perm: perm } }); } toDispose.forEach(function (t) { return backend.disposeIntermediateTensorInfo(t); }); return result; } var unsortedSegmentSumConfig = { kernelName: tf.UnsortedSegmentSum, backendName: 'webgl', kernelFunc: unsortedSegmentSum }; var e_1, _a; // List all kernel configs here var kernelConfigs = [ _fusedMatMulConfig, absConfig, acosConfig, acoshConfig, addConfig, addNConfig, allConfig, anyConfig, argMaxConfig, argMinConfig, asinConfig, asinhConfig, atanConfig, atan2Config, atanhConfig, avgPoolConfig, avgPool3DConfig, avgPool3DGradConfig, avgPoolGradConfig, batchMatMulConfig, batchNormConfig, batchToSpaceNDConfig, bincountConfig, bitwiseAndConfig, broadcastArgsConfig, castConfig, ceilConfig, clipByValueConfig, complexConfig, complexAbsConfig, concatConfig, conv2DConfig, conv2DBackpropFilterConfig, conv2DBackpropInputConfig, conv3DConfig, conv3DBackpropFilterV2Config, conv3DBackpropInputConfig, cosConfig, coshConfig, cropAndResizeConfig, cumprodConfig, cumsumConfig, denseBincountConfig, depthToSpaceConfig, depthwiseConv2dNativeConfig, depthwiseConv2dNativeBackpropFilterConfig, depthwiseConv2dNativeBackpropInputConfig, diagConfig, 
dilation2DConfig, einsumConfig, eluConfig, eluGradConfig, equalConfig, erfConfig, expConfig, expandDimsConfig, expm1Config, fftConfig, fillConfig, flipLeftRightConfig, floorConfig, floorDivConfig, fromPixelsConfig, fusedConv2DConfig, fusedDepthwiseConv2DConfig, gatherNdConfig, gatherV2Config, greaterConfig, greaterEqualConfig, identityConfig, ifftConfig, imagConfig, isFiniteConfig, isInfConfig, isNaNConfig, leakyReluConfig, lessConfig, lessEqualConfig, linSpaceConfig, logConfig, log1pConfig, logicalAndConfig, logicalNotConfig, logicalOrConfig, LRNConfig, LRNGradConfig, maxConfig, maximumConfig, maxPoolConfig, maxPool3DConfig, maxPool3DGradConfig, maxPoolGradConfig, maxPoolWithArgmaxConfig, meanConfig, minConfig, minimumConfig, mirrorPadConfig, modConfig, multinomialConfig, multiplyConfig, negConfig, nonMaxSuppressionV3Config, nonMaxSuppressionV4Config, nonMaxSuppressionV5Config, notEqualConfig, oneHotConfig, onesLikeConfig, packConfig, padV2Config, powConfig, preluConfig, prodConfig, raggedGatherConfig, raggedRangeConfig, raggedTensorToTensorConfig, rangeConfig, realConfig, realDivConfig, reciprocalConfig, reluConfig, relu6Config, reshapeConfig, resizeBilinearConfig, resizeBilinearGradConfig, resizeNearestNeighborConfig, resizeNearestNeighborGradConfig, reverseConfig, rotateWithOffsetConfig, roundConfig, rsqrtConfig, scatterNdConfig, searchSortedConfig, selectConfig, seluConfig, sigmoidConfig, signConfig, sinConfig, sinhConfig, sliceConfig, softmaxConfig, softplusConfig, spaceToBatchNDConfig, sparseFillEmptyRowsConfig, sparseReshapeConfig, sparseSegmentMeanConfig, sparseSegmentSumConfig, sparseToDenseConfig, splitVConfig, sqrtConfig, squareConfig, squaredDifferenceConfig, staticRegexReplaceConfig, stepConfig, stridedSliceConfig, stringNGramsConfig, stringSplitConfig, stringToHashBucketFastConfig, subConfig, sumConfig, tanConfig, tanhConfig, tensorScatterUpdateConfig, tileConfig, topKConfig, transformConfig, transposeConfig, uniqueConfig, unpackConfig, 
unsortedSegmentSumConfig, zerosLikeConfig ]; try { for (var kernelConfigs_1 = __values(kernelConfigs), kernelConfigs_1_1 = kernelConfigs_1.next(); !kernelConfigs_1_1.done; kernelConfigs_1_1 = kernelConfigs_1.next()) { var kernelConfig = kernelConfigs_1_1.value; tf.registerKernel(kernelConfig); } } catch (e_1_1) { e_1 = { error: e_1_1 }; } finally { try { if (kernelConfigs_1_1 && !kernelConfigs_1_1.done && (_a = kernelConfigs_1.return)) _a.call(kernelConfigs_1); } finally { if (e_1) throw e_1.error; } } exports.GPGPUContext = GPGPUContext; exports.MathBackendWebGL = MathBackendWebGL; exports.forceHalfFloat = forceHalfFloat; exports.gpgpu_util = gpgpu_util; exports.setWebGLContext = setWebGLContext; exports.version_webgl = version; exports.webgl = webgl; exports.webgl_util = webgl_util; //# sourceMappingURL=tf-backend-webgl.node.js.map