/**
 * @license
 * Copyright 2023 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
// UMD prologue. The factory is invoked in one of three ways:
//  - CommonJS (`exports` and `module` exist): factory(exports, require(core)).
//  - AMD (`define.amd` exists): define(['exports', core], factory).
//  - Otherwise: the exports are merged into `global.tf` (created when missing)
//    and that same `global.tf` object is handed to the factory as `tf`.
(function (global, factory) {
    typeof exports === 'object' && typeof module !== 'undefined' ? factory(exports, require('@tensorflow/tfjs-core')) :
    typeof define === 'function' && define.amd ? define(['exports', '@tensorflow/tfjs-core'], factory) :
    (global = typeof globalThis !== 'undefined' ? globalThis : global || self, factory(global.tf = global.tf || {}, global.tf));
})(this, (function (exports, tf) { 'use strict';

/**
 * Builds a namespace-like wrapper around module `e`.
 *
 * The wrapper has a null prototype. Every own enumerable key of `e` except
 * 'default' is mirrored onto it: an existing accessor descriptor is reused
 * as-is, anything else becomes an enumerable getter that reads `e[k]` at
 * access time. `default` is then set to `e` itself.
 *
 * @param e Module object to wrap; a falsy value yields a wrapper that only
 *     carries `default`.
 * @returns The wrapper object.
 */
function _interopNamespaceDefault(e) {
    var n = Object.create(null);
    if (e) {
        Object.keys(e).forEach(function (k) {
            if (k !== 'default') {
                var d = Object.getOwnPropertyDescriptor(e, k);
                Object.defineProperty(n, k, d.get ? d : {
                    enumerable: true,
                    get: function () { return e[k]; }
                });
            }
        });
    }
    n.default = e;
    return n;
}

// Namespace view of the `tf` core module passed to the factory.
var tf__namespace = /*#__PURE__*/_interopNamespaceDefault(tf);

/**
 * @license
 * Copyright 2018 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ const contexts = {}; const WEBGL_ATTRIBUTES = { alpha: false, antialias: false, premultipliedAlpha: false, preserveDrawingBuffer: false, depth: false, stencil: false, failIfMajorPerformanceCaveat: true }; function setWebGLContext(webGLVersion, gl) { contexts[webGLVersion] = gl; } function getWebGLContext(webGLVersion, customCanvas) { if (!(webGLVersion in contexts) || customCanvas != null) { const newCtx = getWebGLRenderingContext(webGLVersion, customCanvas); if (newCtx !== null) { contexts[webGLVersion] = newCtx; } else { console.log('Could not get context for WebGL version', webGLVersion); return null; } } const gl = contexts[webGLVersion]; if (gl == null || gl.isContextLost()) { delete contexts[webGLVersion]; return getWebGLContext(webGLVersion); } gl.disable(gl.DEPTH_TEST); gl.disable(gl.STENCIL_TEST); gl.disable(gl.BLEND); gl.disable(gl.DITHER); gl.disable(gl.POLYGON_OFFSET_FILL); gl.disable(gl.SAMPLE_COVERAGE); gl.enable(gl.SCISSOR_TEST); gl.enable(gl.CULL_FACE); gl.cullFace(gl.BACK); return contexts[webGLVersion]; } function createCanvas(webGLVersion) { // Use canvas element for Safari, since its offscreen canvas does not support // fencing. if (!tf.env().getBool('IS_SAFARI') && typeof OffscreenCanvas !== 'undefined' && webGLVersion === 2) { return new OffscreenCanvas(300, 150); } else if (typeof document !== 'undefined') { return document.createElement('canvas'); } else { throw new Error('Cannot create a canvas in this context'); } } function getWebGLRenderingContext(webGLVersion, customCanvas) { if (webGLVersion !== 1 && webGLVersion !== 2) { throw new Error('Cannot get WebGL rendering context, WebGL is disabled.'); } const canvas = customCanvas == null ? 
createCanvas(webGLVersion) : customCanvas; canvas.addEventListener('webglcontextlost', (ev) => { ev.preventDefault(); delete contexts[webGLVersion]; }, false); if (tf.env().getBool('SOFTWARE_WEBGL_ENABLED')) { WEBGL_ATTRIBUTES.failIfMajorPerformanceCaveat = false; } if (webGLVersion === 1) { return ( // tslint:disable-next-line canvas.getContext('webgl', WEBGL_ATTRIBUTES) || canvas .getContext('experimental-webgl', WEBGL_ATTRIBUTES)); } return canvas.getContext('webgl2', WEBGL_ATTRIBUTES); } /** * @license * Copyright 2017 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ var PackingScheme; (function (PackingScheme) { /** * All values in a single texel are densely packed without any constraints. * * This is how the shader encodes a tensor with shape = [2, 3, 4] * (indices are [batch, row, col]). * * 000|001 010|011 020|021 * ------- ------- ------- * 002|003 012|013 022|023 * * 100|101 110|111 120|121 * ------- ------- ------- * 102|103 112|113 122|123 * */ PackingScheme[PackingScheme["DENSE"] = 0] = "DENSE"; /** * Single texels contain only values from the same batch, and from adjacent * rows and columns. * * This is how the shader encodes a tensor with shape = [2, 3, 5] * (indices are [batch, row, col]). 
* * 000|001 002|003 004|xxx 020|021 022|023 024|xxx * ------- ------- ------- ------- ------- ------- * 010|011 012|013 014|xxx xxx|xxx xxx|xxx xxx|xxx * * 100|101 102|103 104|xxx 120|121 122|123 124|xxx * ------- ------- ------- ------- ------- ------- * 110|111 112|113 114|xxx xxx|xxx xxx|xxx xxx|xxx * */ PackingScheme[PackingScheme["SHARED_BATCH"] = 1] = "SHARED_BATCH"; })(PackingScheme || (PackingScheme = {})); var TextureUsage; (function (TextureUsage) { TextureUsage[TextureUsage["RENDER"] = 0] = "RENDER"; TextureUsage[TextureUsage["UPLOAD"] = 1] = "UPLOAD"; TextureUsage[TextureUsage["PIXELS"] = 2] = "PIXELS"; TextureUsage[TextureUsage["DOWNLOAD"] = 3] = "DOWNLOAD"; })(TextureUsage || (TextureUsage = {})); var PhysicalTextureType; (function (PhysicalTextureType) { PhysicalTextureType[PhysicalTextureType["UNPACKED_FLOAT16"] = 0] = "UNPACKED_FLOAT16"; PhysicalTextureType[PhysicalTextureType["UNPACKED_FLOAT32"] = 1] = "UNPACKED_FLOAT32"; PhysicalTextureType[PhysicalTextureType["PACKED_4X1_UNSIGNED_BYTE"] = 2] = "PACKED_4X1_UNSIGNED_BYTE"; PhysicalTextureType[PhysicalTextureType["PACKED_2X2_FLOAT32"] = 3] = "PACKED_2X2_FLOAT32"; PhysicalTextureType[PhysicalTextureType["PACKED_2X2_FLOAT16"] = 4] = "PACKED_2X2_FLOAT16"; })(PhysicalTextureType || (PhysicalTextureType = {})); function getUnpackedMatrixTextureShapeWidthHeight(rows, columns) { return [columns, rows]; } function getUnpackedArraySizeFromMatrixSize(matrixSize, channelsPerTexture) { return matrixSize * channelsPerTexture; } /** * Get shape for densely packed RGBA texture. 
*/ function getDenseTexShape(shape) { const size = tf.util.sizeFromShape(shape); const texelsNeeded = Math.ceil(size / 4); return tf.util.sizeToSquarishShape(texelsNeeded); } function getPackedMatrixTextureShapeWidthHeight(rows, columns) { return [ Math.max(1, Math.ceil(columns / 2)), Math.max(1, Math.ceil(rows / 2)) ]; } function getPackedRGBAArraySizeFromMatrixShape(rows, columns) { const [w, h] = getPackedMatrixTextureShapeWidthHeight(rows, columns); return w * h * 4; } function getTextureConfig( // tslint:disable-next-line:no-any gl, textureHalfFloatExtension) { // tslint:disable-next-line:no-any const glany = gl; let internalFormatFloat; let internalFormatHalfFloat; let internalFormatPackedHalfFloat; let internalFormatPackedFloat; let textureFormatFloat; let downloadTextureFormat; let downloadUnpackNumChannels; let defaultNumChannels; let textureTypeHalfFloat; let textureTypeFloat; if (tf.env().getNumber('WEBGL_VERSION') === 2) { internalFormatFloat = glany.R32F; internalFormatHalfFloat = glany.R16F; internalFormatPackedHalfFloat = glany.RGBA16F; internalFormatPackedFloat = glany.RGBA32F; textureFormatFloat = glany.RED; downloadUnpackNumChannels = 4; defaultNumChannels = 1; textureTypeHalfFloat = glany.HALF_FLOAT; textureTypeFloat = glany.FLOAT; downloadTextureFormat = glany.RGBA8; } else { internalFormatFloat = gl.RGBA; internalFormatHalfFloat = gl.RGBA; internalFormatPackedHalfFloat = gl.RGBA; internalFormatPackedFloat = glany.RGBA; textureFormatFloat = gl.RGBA; downloadUnpackNumChannels = 4; defaultNumChannels = 4; textureTypeHalfFloat = textureHalfFloatExtension != null ? 
textureHalfFloatExtension.HALF_FLOAT_OES : null; textureTypeFloat = gl.FLOAT; downloadTextureFormat = gl.RGBA; } return { internalFormatFloat, internalFormatHalfFloat, internalFormatPackedHalfFloat, internalFormatPackedFloat, textureFormatFloat, downloadTextureFormat, downloadUnpackNumChannels, defaultNumChannels, textureTypeHalfFloat, textureTypeFloat }; } /** * @license * Copyright 2017 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function callAndCheck(gl, func) { const returnValue = func(); if (tf.env().getBool('DEBUG')) { checkWebGLError(gl); } return returnValue; } function checkWebGLError(gl) { const error = gl.getError(); if (error !== gl.NO_ERROR) { throw new Error('WebGL Error: ' + getWebGLErrorMessage(gl, error)); } } // https://en.wikipedia.org/wiki/Half-precision_floating-point_format const MIN_FLOAT16 = 5.96e-8; const MAX_FLOAT16 = 65504; function canBeRepresented(num) { if (tf.env().getBool('WEBGL_RENDER_FLOAT32_ENABLED') || num === 0 || (MIN_FLOAT16 < Math.abs(num) && Math.abs(num) < MAX_FLOAT16)) { return true; } return false; } function getWebGLErrorMessage(gl, status) { switch (status) { case gl.NO_ERROR: return 'NO_ERROR'; case gl.INVALID_ENUM: return 'INVALID_ENUM'; case gl.INVALID_VALUE: return 'INVALID_VALUE'; case gl.INVALID_OPERATION: return 'INVALID_OPERATION'; case gl.INVALID_FRAMEBUFFER_OPERATION: return 'INVALID_FRAMEBUFFER_OPERATION'; 
case gl.OUT_OF_MEMORY: return 'OUT_OF_MEMORY'; case gl.CONTEXT_LOST_WEBGL: return 'CONTEXT_LOST_WEBGL'; default: return `Unknown error code ${status}`; } } function getExtensionOrThrow(gl, extensionName) { return throwIfNull(gl, () => gl.getExtension(extensionName), 'Extension "' + extensionName + '" not supported on this browser.'); } function createVertexShader$1(gl, vertexShaderSource) { const vertexShader = throwIfNull(gl, () => gl.createShader(gl.VERTEX_SHADER), 'Unable to create vertex WebGLShader.'); callAndCheck(gl, () => gl.shaderSource(vertexShader, vertexShaderSource)); callAndCheck(gl, () => gl.compileShader(vertexShader)); if (gl.getShaderParameter(vertexShader, gl.COMPILE_STATUS) === false) { console.log(gl.getShaderInfoLog(vertexShader)); throw new Error('Failed to compile vertex shader.'); } return vertexShader; } function createFragmentShader(gl, fragmentShaderSource) { const fragmentShader = throwIfNull(gl, () => gl.createShader(gl.FRAGMENT_SHADER), 'Unable to create fragment WebGLShader.'); callAndCheck(gl, () => gl.shaderSource(fragmentShader, fragmentShaderSource)); callAndCheck(gl, () => gl.compileShader(fragmentShader)); if (tf.env().get('ENGINE_COMPILE_ONLY')) { return fragmentShader; } if (gl.getShaderParameter(fragmentShader, gl.COMPILE_STATUS) === false) { logShaderSourceAndInfoLog(fragmentShaderSource, gl.getShaderInfoLog(fragmentShader)); throw new Error('Failed to compile fragment shader.'); } return fragmentShader; } const lineNumberRegex = /ERROR: [0-9]+:([0-9]+):/g; function logShaderSourceAndInfoLog(shaderSource, shaderInfoLog) { const lineNumberRegexResult = lineNumberRegex.exec(shaderInfoLog); if (lineNumberRegexResult == null) { console.log(`Couldn't parse line number in error: ${shaderInfoLog}`); console.log(shaderSource); return; } const lineNumber = +lineNumberRegexResult[1]; const shaderLines = shaderSource.split('\n'); const pad = shaderLines.length.toString().length + 2; const linesWithLineNumbers = shaderLines.map((line, 
lineNumber) => tf.util.rightPad((lineNumber + 1).toString(), pad) + line); let maxLineLength = 0; for (let i = 0; i < linesWithLineNumbers.length; i++) { maxLineLength = Math.max(linesWithLineNumbers[i].length, maxLineLength); } const beforeErrorLines = linesWithLineNumbers.slice(0, lineNumber - 1); const errorLine = linesWithLineNumbers.slice(lineNumber - 1, lineNumber); const afterErrorLines = linesWithLineNumbers.slice(lineNumber); console.log(beforeErrorLines.join('\n')); console.log(shaderInfoLog.split('\n')[0]); console.log(`%c ${tf.util.rightPad(errorLine[0], maxLineLength)}`, 'border:1px solid red; background-color:#e3d2d2; color:#a61717'); console.log(afterErrorLines.join('\n')); } function createProgram(gl) { return throwIfNull(gl, () => gl.createProgram(), 'Unable to create WebGLProgram.'); } function linkProgram(gl, program) { callAndCheck(gl, () => gl.linkProgram(program)); if (tf.env().get('ENGINE_COMPILE_ONLY')) { return; } if (gl.getProgramParameter(program, gl.LINK_STATUS) === false) { console.log(gl.getProgramInfoLog(program)); throw new Error('Failed to link vertex and fragment shaders.'); } } /// validateProgram is effectively "If we `useProgram(program); drawArrays();`, /// give feedback in log about perf/correctness warnings or errors that would /// occur." /// So make sure we set up all vertex/texture/sampler/uniform data before /// calling validateProgram! 
/**
 * Runs GL program validation.
 *
 * @param gl Rendering context.
 * @param program Program to validate.
 * @throws Error when VALIDATE_STATUS is false (the program info log is
 *     printed first).
 */
function validateProgram(gl, program) {
    callAndCheck(gl, () => gl.validateProgram(program));
    if (gl.getProgramParameter(program, gl.VALIDATE_STATUS) === false) {
        console.log(gl.getProgramInfoLog(program));
        throw new Error('Shader program validation failed.');
    }
}
/**
 * Creates an ARRAY_BUFFER, leaves it bound, and uploads `data` with
 * STATIC_DRAW usage.
 *
 * @returns The new buffer.
 * @throws Error when the buffer cannot be created.
 */
function createStaticVertexBuffer(gl, data) {
    const buffer = throwIfNull(gl, () => gl.createBuffer(), 'Unable to create WebGLBuffer');
    callAndCheck(gl, () => gl.bindBuffer(gl.ARRAY_BUFFER, buffer));
    callAndCheck(gl, () => gl.bufferData(gl.ARRAY_BUFFER, data, gl.STATIC_DRAW));
    return buffer;
}
/**
 * Creates an ELEMENT_ARRAY_BUFFER, leaves it bound, and uploads `data` with
 * STATIC_DRAW usage.
 *
 * @returns The new buffer.
 * @throws Error when the buffer cannot be created.
 */
function createStaticIndexBuffer(gl, data) {
    const buffer = throwIfNull(gl, () => gl.createBuffer(), 'Unable to create WebGLBuffer');
    callAndCheck(gl, () => gl.bindBuffer(gl.ELEMENT_ARRAY_BUFFER, buffer));
    callAndCheck(gl, () => gl.bufferData(gl.ELEMENT_ARRAY_BUFFER, data, gl.STATIC_DRAW));
    return buffer;
}
/**
 * @returns 1 when the WEBGL_VERSION flag is 2, otherwise 4.
 */
function getNumChannels() {
    if (tf.env().getNumber('WEBGL_VERSION') === 2) {
        return 1;
    }
    return 4;
}
/**
 * Creates a GL texture.
 *
 * @throws Error when `gl.createTexture()` yields null/undefined.
 */
function createTexture(gl) {
    return throwIfNull(gl, () => gl.createTexture(), 'Unable to create WebGLTexture.');
}
/**
 * Checks a requested texture size against the WEBGL_MAX_TEXTURE_SIZE flag.
 *
 * @throws Error when either dimension is <= 0 or exceeds the maximum.
 */
function validateTextureSize(width, height) {
    const maxTextureSize = tf.env().getNumber('WEBGL_MAX_TEXTURE_SIZE');
    if ((width <= 0) || (height <= 0)) {
        const requested = `[${width}x${height}]`;
        throw new Error('Requested texture size ' + requested + ' is invalid.');
    }
    if ((width > maxTextureSize) || (height > maxTextureSize)) {
        const requested = `[${width}x${height}]`;
        const max = `[${maxTextureSize}x${maxTextureSize}]`;
        throw new Error('Requested texture size ' + requested +
            ' greater than WebGL maximum on this browser / GPU ' + max + '.');
    }
}
/**
 * Creates a GL framebuffer.
 *
 * @throws Error when `gl.createFramebuffer()` yields null/undefined.
 */
function createFramebuffer(gl) {
    return throwIfNull(gl, () => gl.createFramebuffer(), 'Unable to create WebGLFramebuffer.');
}
/**
 * Binds `buffer` as the FLOAT vertex-attribute source for `attribute`.
 *
 * @param gl Rendering context.
 * @param program Program that declares the attribute.
 * @param attribute Attribute name.
 * @param buffer Buffer to bind to ARRAY_BUFFER.
 * @param arrayEntriesPerItem Component count passed to vertexAttribPointer.
 * @param itemStrideInBytes Stride passed to vertexAttribPointer.
 * @param itemOffsetInBytes Offset passed to vertexAttribPointer.
 * @returns false when the attribute has no location in the program, true
 *     after a successful bind.
 */
function bindVertexBufferToProgramAttribute(gl, program, attribute, buffer, arrayEntriesPerItem, itemStrideInBytes, itemOffsetInBytes) {
    const loc = gl.getAttribLocation(program, attribute);
    if (loc === -1) {
        // The GPU compiler decided to strip out this attribute because it's
        // unused, thus no need to bind.
        return false;
    }
    callAndCheck(gl, () => gl.bindBuffer(gl.ARRAY_BUFFER, buffer));
    callAndCheck(gl, () => gl.vertexAttribPointer(loc, arrayEntriesPerItem, gl.FLOAT, false, itemStrideInBytes, itemOffsetInBytes));
    callAndCheck(gl, () => gl.enableVertexAttribArray(loc));
    return true;
}
/**
 * Activates texture unit `textureUnit` and binds `texture` to TEXTURE_2D.
 *
 * @throws Error when `textureUnit` is out of range.
 */
function bindTextureUnit(gl, texture, textureUnit) {
    validateTextureUnit(gl, textureUnit);
    callAndCheck(gl, () => gl.activeTexture(gl.TEXTURE0 + textureUnit));
    callAndCheck(gl, () => gl.bindTexture(gl.TEXTURE_2D, texture));
}
/**
 * Activates texture unit `textureUnit` and unbinds its TEXTURE_2D target.
 *
 * @throws Error when `textureUnit` is out of range.
 */
function unbindTextureUnit(gl, textureUnit) {
    validateTextureUnit(gl, textureUnit);
    callAndCheck(gl, () => gl.activeTexture(gl.TEXTURE0 + textureUnit));
    callAndCheck(gl, () => gl.bindTexture(gl.TEXTURE_2D, null));
}
/**
 * Looks up a uniform location.
 *
 * @throws Error when the uniform is not present in `program`.
 */
function getProgramUniformLocationOrThrow(gl, program, uniformName) {
    return throwIfNull(gl, () => gl.getUniformLocation(program, uniformName), 'uniform "' + uniformName + '" not present in program.');
}
/**
 * Looks up a uniform location; returns whatever `gl.getUniformLocation`
 * returns, without checking it.
 */
function getProgramUniformLocation(gl, program, uniformName) {
    return gl.getUniformLocation(program, uniformName);
}
/**
 * Binds `texture` to `textureUnit` and points the sampler uniform at that
 * unit.
 */
function bindTextureToProgramUniformSampler(gl, texture, uniformSamplerLocation, textureUnit) {
    callAndCheck(gl, () => bindTextureUnit(gl, texture, textureUnit));
    callAndCheck(gl, () => gl.uniform1i(uniformSamplerLocation, textureUnit));
}
/**
 * Unbinds any framebuffer (so rendering targets the canvas) and sets the
 * viewport and scissor box to the full canvas.
 */
function bindCanvasToFramebuffer(gl) {
    callAndCheck(gl, () => gl.bindFramebuffer(gl.FRAMEBUFFER, null));
    callAndCheck(gl, () => gl.viewport(0, 0, gl.canvas.width, gl.canvas.height));
    callAndCheck(gl, () => gl.scissor(0, 0, gl.canvas.width, gl.canvas.height));
}
/**
 * Binds `framebuffer` and attaches `texture` as its COLOR_ATTACHMENT0.
 */
function bindColorTextureToFramebuffer(gl, texture, framebuffer) {
    callAndCheck(gl, () => gl.bindFramebuffer(gl.FRAMEBUFFER, framebuffer));
    callAndCheck(gl, () => gl.framebufferTexture2D(gl.FRAMEBUFFER, gl.COLOR_ATTACHMENT0, gl.TEXTURE_2D, texture, 0));
}
/**
 * Binds `framebuffer` and detaches whatever texture is on COLOR_ATTACHMENT0.
 */
function unbindColorTextureFromFramebuffer(gl, framebuffer) {
    callAndCheck(gl, () => gl.bindFramebuffer(gl.FRAMEBUFFER, framebuffer));
    callAndCheck(gl, () => gl.framebufferTexture2D(gl.FRAMEBUFFER, gl.COLOR_ATTACHMENT0, gl.TEXTURE_2D, null, 0));
}
/**
 * Checks that the currently bound framebuffer is complete.
 *
 * @throws Error naming the incomplete status otherwise.
 */
function validateFramebuffer(gl) {
    const status = gl.checkFramebufferStatus(gl.FRAMEBUFFER);
    if (status !== gl.FRAMEBUFFER_COMPLETE) {
        throw new Error('Error binding framebuffer: ' + getFramebufferErrorMessage(gl, status));
    }
}
/**
 * Maps a framebuffer status code to the name of its enum.
 *
 * @param gl Context providing the enum values.
 * @param status Value returned by `gl.checkFramebufferStatus`.
 * @returns The enum name, or `unknown error <status>`.
 */
function getFramebufferErrorMessage(gl, status) {
    switch (status) {
        case gl.FRAMEBUFFER_INCOMPLETE_ATTACHMENT:
            return 'FRAMEBUFFER_INCOMPLETE_ATTACHMENT';
        case gl.FRAMEBUFFER_INCOMPLETE_MISSING_ATTACHMENT:
            return 'FRAMEBUFFER_INCOMPLETE_MISSING_ATTACHMENT';
        case gl.FRAMEBUFFER_INCOMPLETE_DIMENSIONS:
            return 'FRAMEBUFFER_INCOMPLETE_DIMENSIONS';
        case gl.FRAMEBUFFER_UNSUPPORTED:
            return 'FRAMEBUFFER_UNSUPPORTED';
        default:
            return `unknown error ${status}`;
    }
}
/**
 * Runs `returnTOrNull` through `callAndCheck` and rejects a null/undefined
 * result.
 *
 * @param gl Rendering context (used for the debug error check).
 * @param returnTOrNull Zero-argument callback producing the value.
 * @param failureMessage Message of the Error thrown on null/undefined.
 * @returns The non-null value.
 * @throws Error with `failureMessage`.
 */
function throwIfNull(gl, returnTOrNull, failureMessage) {
    const tOrNull = callAndCheck(gl, () => returnTOrNull());
    if (tOrNull == null) {
        throw new Error(failureMessage);
    }
    return tOrNull;
}
/** Queried MAX_COMBINED_TEXTURE_IMAGE_UNITS value, remembered per context. */
const maxCombinedTextureUnitsByContext = new WeakMap();
/**
 * Returns the number of combined texture image units `gl` reports, or null
 * when the context cannot be queried or does not report a positive integer
 * (only positive integer answers are cached).
 */
function getMaxCombinedTextureUnits(gl) {
    const cached = maxCombinedTextureUnitsByContext.get(gl);
    if (cached != null) {
        return cached;
    }
    const queried = typeof gl.getParameter === 'function' ?
        gl.getParameter(gl.MAX_COMBINED_TEXTURE_IMAGE_UNITS) :
        null;
    if (!Number.isInteger(queried) || queried <= 0) {
        return null;
    }
    maxCombinedTextureUnitsByContext.set(gl, queried);
    return queried;
}
/**
 * Checks that `textureUnit` is a usable zero-based texture unit index.
 *
 * `gl.MAX_COMBINED_TEXTURE_IMAGE_UNITS` is the enum used to *ask* for the
 * limit, not the limit itself, so the real limit is obtained through
 * `gl.getParameter`. When the limit cannot be queried only negative units are
 * rejected.
 *
 * @param gl Rendering context.
 * @param textureUnit Zero-based texture unit index.
 * @throws Error when the unit is negative or not below the queried limit.
 */
function validateTextureUnit(gl, textureUnit) {
    const unitCount = getMaxCombinedTextureUnits(gl);
    const aboveLimit = unitCount != null && textureUnit >= unitCount;
    if (textureUnit < 0 || aboveLimit) {
        const maxTextureUnit = unitCount == null ? '<max>' : unitCount - 1;
        const textureUnitRange = `[gl.TEXTURE0, gl.TEXTURE${maxTextureUnit}]`;
        throw new Error(`textureUnit must be in ${textureUnitRange}.`);
    }
}
/**
 * Product of all dimensions of `shape` except the last `dimsToSkip`.
 *
 * @param shape Logical shape.
 * @param dimsToSkip Number of trailing dimensions to leave out (default 2).
 */
function getBatchDim(shape, dimsToSkip = 2) {
    return tf.util.sizeFromShape(shape.slice(0, shape.length - dimsToSkip));
}
/**
 * Returns `[rows, cols]` taken from the last two dimensions of `shape`; a
 * rank-1 shape has 1 row.
 *
 * @throws Error when `shape` is empty.
 */
function getRowsCols(shape) {
    if (shape.length === 0) {
        throw Error('Cannot get rows and columns of an empty shape array.');
    }
    return [
        shape.length > 1 ? shape[shape.length - 2] : 1, shape[shape.length - 1]
    ];
}
/**
 * Collapses `shape` to `[batch, rows, cols]`. An empty shape and `[1]` both
 * map to `[1, 1, 1]`.
 */
function getShapeAs3D(shape) {
    let shapeAs3D = [1, 1, 1];
    const isScalar = shape.length === 0 || (shape.length === 1 && shape[0] === 1);
    if (!isScalar) {
        shapeAs3D =
            [getBatchDim(shape), ...getRowsCols(shape)];
    }
    return shapeAs3D;
}
/**
 * Chooses a 2D texture shape for a tensor of logical shape `logShape`.
 *
 * Shapes of rank <= 4 are mapped directly onto rows/columns when they fit
 * within the WEBGL_MAX_TEXTURE_SIZE flag; otherwise, or when the result would
 * be a long narrow texture, a squarish shape is used.
 *
 * @param logShape Logical tensor shape.
 * @param isPacked Whether the texture uses 2x2 packing (default false).
 * @returns The texture shape as a two-element array.
 */
function getTextureShapeFromLogicalShape(logShape, isPacked = false) {
    let maxTexSize = tf.env().getNumber('WEBGL_MAX_TEXTURE_SIZE');
    let maxSizeForNarrowTex = tf.env().getNumber('WEBGL_MAX_SIZE_FOR_NARROW_TEXTURE');
    if (maxSizeForNarrowTex === Infinity &&
        tf.env().getBool('WEBGL_AUTO_SQUARIFY_NARROW_TEXTURE_SHAPE')) {
        maxSizeForNarrowTex = maxTexSize / 2;
    }
    if (isPacked) {
        maxTexSize = maxTexSize * 2;
        maxSizeForNarrowTex = maxSizeForNarrowTex * 2;
        // This logic ensures we accurately count the number of packed texels
        // needed to accommodate the tensor. We can only pack values in the same
        // texel if they are from adjacent pairs of rows/cols within the same
        // batch. So if a tensor has 3 rows, we pretend it has 4 rows in order to
        // account for the fact that the texels containing the third row are half
        // empty.
        logShape = logShape.map((d, i) => i >= logShape.length - 2 ?
            tf.util.nearestLargerEven(logShape[i]) :
            logShape[i]);
        // Packed texture height is at least 2 (the channel height of a single
        // texel).
        if (logShape.length === 1) {
            logShape = [2, logShape[0]];
        }
    }
    // If logical shape is 2, we don't squeeze, since we want to match physical.
    if (logShape.length !== 2) {
        const squeezeResult = tf.util.squeezeShape(logShape);
        logShape = squeezeResult.newShape;
    }
    let size = tf.util.sizeFromShape(logShape);
    let textureShape = null;
    if (logShape.length <= 1 && size <= maxTexSize) {
        textureShape = [1, size];
    }
    else if (logShape.length === 2 && logShape[0] <= maxTexSize &&
        logShape[1] <= maxTexSize) {
        textureShape = logShape;
    }
    else if (logShape.length === 3 && logShape[0] * logShape[1] <= maxTexSize &&
        logShape[2] <= maxTexSize) {
        textureShape = [logShape[0] * logShape[1], logShape[2]];
    }
    else if (logShape.length === 3 && logShape[0] <= maxTexSize &&
        logShape[1] * logShape[2] <= maxTexSize) {
        textureShape = [logShape[0], logShape[1] * logShape[2]];
    }
    else if (logShape.length === 4 &&
        logShape[0] * logShape[1] * logShape[2] <= maxTexSize &&
        logShape[3] <= maxTexSize) {
        textureShape = [logShape[0] * logShape[1] * logShape[2], logShape[3]];
    }
    else if (logShape.length === 4 && logShape[0] <= maxTexSize &&
        logShape[1] * logShape[2] * logShape[3] <= maxTexSize) {
        textureShape = [logShape[0], logShape[1] * logShape[2] * logShape[3]];
    }
    // true if one edge length is 1 (1 or 2, if packed), while another edge
    // length exceeds maxSizeForNarrowTex.
    const isLongNarrowTex = textureShape != null &&
        Math.max(...textureShape) > maxSizeForNarrowTex &&
        Math.min(...textureShape) <= (isPacked ? 2 : 1) &&
        Math.min(...textureShape) > 0;
    if (textureShape == null || isLongNarrowTex) {
        if (isPacked) {
            // For packed textures size equals the number of channels required to
            // accommodate the texture data. However in order to squarify such that
            // inner dimensions stay even, we rewrite size to equal the number of
            // texels. Then in the return statement we rehydrate the squarified
            // dimensions to channel units.
            const batchDim = getBatchDim(logShape);
            let rows = 2, cols = 2;
            if (logShape.length) {
                [rows, cols] = getRowsCols(logShape);
            }
            size = batchDim * (rows / 2) * (cols / 2);
            textureShape =
                tf.util.sizeToSquarishShape(size).map(d => d * 2);
        }
        else {
            textureShape = tf.util.sizeToSquarishShape(size);
        }
    }
    return textureShape;
}
/** @returns true when `n` is divisible by 2. */
function isEven(n) {
    return n % 2 === 0;
}
/**
 * This determines whether reshaping a packed texture requires rearranging
 * the data within the texture, assuming 2x2 packing.
 *
 * Only the last two dimensions of each shape are compared.
 *
 * @param shape1 Shape before the reshape.
 * @param shape2 Shape after the reshape.
 * @returns true when no data movement is needed.
 */
function isReshapeFree(shape1, shape2) {
    shape1 = shape1.slice(-2);
    shape2 = shape2.slice(-2);
    if (tf.util.arraysEqual(shape1, shape2)) {
        return true;
    }
    if (!shape1.length || !shape2.length) { // One of the shapes is a scalar.
        return true;
    }
    if (shape1[0] === 0 || shape1[1] === 0 || shape2[0] === 0 ||
        shape2[1] === 0) {
        return true;
    }
    if (shape1.length !== shape2.length) { // One of the shapes is a vector.
        const shape1Cols = shape1[shape1.length - 1];
        const shape2Cols = shape2[shape2.length - 1];
        if (shape1Cols === shape2Cols) {
            return true;
        }
        if (isEven(shape1Cols) && isEven(shape2Cols) &&
            (shape1[0] === 1 || shape2[0] === 1)) {
            return true;
        }
    }
    return shape1[1] === shape2[1] && isEven(shape1[0]) && isEven(shape2[0]);
}
// We cache webgl params because the environment gets reset between
// unit tests and we don't want to constantly query the WebGLContext for
// MAX_TEXTURE_SIZE.
let MAX_TEXTURE_SIZE;
let MAX_TEXTURES_IN_SHADER;
/**
 * Returns the context's MAX_TEXTURE_SIZE, querying it only while the cached
 * value is unset.
 *
 * @param webGLVersion WebGL version used to obtain the context.
 */
function getWebGLMaxTextureSize(webGLVersion) {
    if (MAX_TEXTURE_SIZE != null) {
        return MAX_TEXTURE_SIZE;
    }
    const context = getWebGLContext(webGLVersion);
    MAX_TEXTURE_SIZE = context.getParameter(context.MAX_TEXTURE_SIZE);
    return MAX_TEXTURE_SIZE;
}
/** Clears the cached MAX_TEXTURE_SIZE so the next read queries it again. */
function resetMaxTextureSize() {
    MAX_TEXTURE_SIZE = null;
}
/** Clears the cached MAX_TEXTURE_IMAGE_UNITS value. */
function resetMaxTexturesInShader() {
    MAX_TEXTURES_IN_SHADER = null;
}
/**
 * Returns how many textures a shader may use: the context's
 * MAX_TEXTURE_IMAGE_UNITS (queried only while the cache is unset), capped
 * at 16.
 *
 * @param webGLVersion WebGL version used to obtain the context.
 */
function getMaxTexturesInShader(webGLVersion) {
    if (MAX_TEXTURES_IN_SHADER == null) {
        const context = getWebGLContext(webGLVersion);
        MAX_TEXTURES_IN_SHADER =
            context.getParameter(context.MAX_TEXTURE_IMAGE_UNITS);
    }
    // We cap at 16 to avoid spurious runtime "memory exhausted" error.
    const shaderTextureCap = 16;
    return Math.min(shaderTextureCap, MAX_TEXTURES_IN_SHADER);
}
/**
 * Reports which disjoint timer query extension is available.
 *
 * @param webGLVersion WebGL version; 0 short-circuits to 0.
 * @returns 2 for EXT_disjoint_timer_query_webgl2 on WebGL 2, 1 for
 *     EXT_disjoint_timer_query, 0 when neither is present.
 */
function getWebGLDisjointQueryTimerVersion(webGLVersion) {
    if (webGLVersion === 0) {
        return 0;
    }
    const context = getWebGLContext(webGLVersion);
    if (hasExtension(context, 'EXT_disjoint_timer_query_webgl2') &&
        webGLVersion === 2) {
        return 2;
    }
    if (hasExtension(context, 'EXT_disjoint_timer_query')) {
        return 1;
    }
    return 0;
}
/**
 * @returns true when `gl.getExtension(extensionName)` is neither null nor
 *     undefined.
 */
function hasExtension(gl, extensionName) {
    return gl.getExtension(extensionName) != null;
}
/**
 * Tells whether a context of the given version can be obtained. Any error
 * raised while obtaining it is logged and treated as "not enabled".
 */
function isWebGLVersionEnabled(webGLVersion) {
    try {
        return getWebGLContext(webGLVersion) != null;
    }
    catch (e) {
        console.log('Error when getting WebGL context: ', e);
        return false;
    }
}
/**
 * Tells whether a float texture can be attached to a framebuffer.
 *
 * Requires OES_texture_float on WebGL 1 and EXT_color_buffer_float otherwise,
 * then probes with a 1x1 float texture.
 *
 * @param webGLVersion WebGL version; 0 yields false.
 */
function isCapableOfRenderingToFloatTexture(webGLVersion) {
    if (webGLVersion === 0) {
        return false;
    }
    const context = getWebGLContext(webGLVersion);
    const requiredExtension = webGLVersion === 1 ? 'OES_texture_float' : 'EXT_color_buffer_float';
    if (!hasExtension(context, requiredExtension)) {
        return false;
    }
    return createFloatTextureAndBindToFramebuffer(context);
}
/**
 * Check if we can download values from a float/half-float texture.
 *
 * Note that for performance reasons we use binding a texture to a framebuffer
 * as a proxy for ability to download float values later using readPixels. The
 * texture params of this texture will not match those in readPixels exactly
 * but if we are unable to bind some kind of float texture to the frameBuffer
 * then we definitely will not be able to read float values from it.
 */
function isDownloadFloatTextureEnabled(webGLVersion) {
    if (webGLVersion === 0) {
        return false;
    }
    const context = getWebGLContext(webGLVersion);
    if (webGLVersion === 1) {
        // Both extensions are required; the second is only looked up when the
        // first one is present.
        if (!hasExtension(context, 'OES_texture_float') ||
            !hasExtension(context, 'WEBGL_color_buffer_float')) {
            return false;
        }
        return createFloatTextureAndBindToFramebuffer(context);
    }
    if (hasExtension(context, 'EXT_color_buffer_float')) {
        return createFloatTextureAndBindToFramebuffer(context);
    }
    const COLOR_BUFFER_HALF_FLOAT = 'EXT_color_buffer_half_float';
    if (!hasExtension(context, COLOR_BUFFER_HALF_FLOAT)) {
        return false;
    }
    const halfFloatExtension = context.getExtension(COLOR_BUFFER_HALF_FLOAT);
    return createHalfFloatTextureAndBindToFramebuffer(context, halfFloatExtension);
}
/**
 * Probes float render-target support: allocates a 1x1 float texture, attaches
 * it to a temporary framebuffer, records whether the framebuffer is complete,
 * then unbinds and deletes both objects.
 *
 * @param gl Rendering context.
 * @returns true when the framebuffer status is FRAMEBUFFER_COMPLETE.
 */
function createFloatTextureAndBindToFramebuffer(gl) {
    const config = getTextureConfig(gl);
    const probeTexture = gl.createTexture();
    gl.bindTexture(gl.TEXTURE_2D, probeTexture);
    gl.texImage2D(gl.TEXTURE_2D, 0, config.internalFormatFloat, 1, 1, 0, config.textureFormatFloat, config.textureTypeFloat, null);
    const probeFramebuffer = gl.createFramebuffer();
    gl.bindFramebuffer(gl.FRAMEBUFFER, probeFramebuffer);
    gl.framebufferTexture2D(gl.FRAMEBUFFER, gl.COLOR_ATTACHMENT0, gl.TEXTURE_2D, probeTexture, 0);
    const status = gl.checkFramebufferStatus(gl.FRAMEBUFFER);
    const isComplete = status === gl.FRAMEBUFFER_COMPLETE;
    gl.bindTexture(gl.TEXTURE_2D, null);
    gl.bindFramebuffer(gl.FRAMEBUFFER, null);
    gl.deleteTexture(probeTexture);
    gl.deleteFramebuffer(probeFramebuffer);
    return isComplete;
}
function createHalfFloatTextureAndBindToFramebuffer( // tslint:disable-next-line:no-any gl, textureHalfFloatExtension) { const texConfig = getTextureConfig(gl, textureHalfFloatExtension); const texture = gl.createTexture(); gl.bindTexture(gl.TEXTURE_2D, texture); const width = 1; const height = 1; gl.texImage2D(gl.TEXTURE_2D, 0, texConfig.internalFormatHalfFloat, width, height, 0, texConfig.textureFormatFloat, texConfig.textureTypeHalfFloat, null); const frameBuffer = gl.createFramebuffer(); gl.bindFramebuffer(gl.FRAMEBUFFER, frameBuffer); gl.framebufferTexture2D(gl.FRAMEBUFFER, gl.COLOR_ATTACHMENT0, gl.TEXTURE_2D, texture, 0); const isFrameBufferComplete = gl.checkFramebufferStatus(gl.FRAMEBUFFER) === gl.FRAMEBUFFER_COMPLETE; gl.bindTexture(gl.TEXTURE_2D, null); gl.bindFramebuffer(gl.FRAMEBUFFER, null); gl.deleteTexture(texture); gl.deleteFramebuffer(frameBuffer); return isFrameBufferComplete; } function isWebGLFenceEnabled(webGLVersion) { if (webGLVersion !== 2) { return false; } const gl = getWebGLContext(webGLVersion); // tslint:disable-next-line:no-any const isEnabled = gl.fenceSync != null; return isEnabled; } function assertNotComplex(tensor, opName) { if (!Array.isArray(tensor)) { tensor = [tensor]; } tensor.forEach(t => { if (t != null) { tf.util.assert(t.dtype !== 'complex64', () => `${opName} does not support complex64 tensors ` + 'in the WebGL backend.'); } }); } var webgl_util = { __proto__: null, assertNotComplex: assertNotComplex, bindCanvasToFramebuffer: bindCanvasToFramebuffer, bindColorTextureToFramebuffer: bindColorTextureToFramebuffer, bindTextureToProgramUniformSampler: bindTextureToProgramUniformSampler, bindTextureUnit: bindTextureUnit, bindVertexBufferToProgramAttribute: bindVertexBufferToProgramAttribute, callAndCheck: callAndCheck, canBeRepresented: canBeRepresented, createFragmentShader: createFragmentShader, createFramebuffer: createFramebuffer, createProgram: createProgram, createStaticIndexBuffer: createStaticIndexBuffer, 
createStaticVertexBuffer: createStaticVertexBuffer, createTexture: createTexture, createVertexShader: createVertexShader$1, getBatchDim: getBatchDim, getExtensionOrThrow: getExtensionOrThrow, getFramebufferErrorMessage: getFramebufferErrorMessage, getMaxTexturesInShader: getMaxTexturesInShader, getNumChannels: getNumChannels, getProgramUniformLocation: getProgramUniformLocation, getProgramUniformLocationOrThrow: getProgramUniformLocationOrThrow, getRowsCols: getRowsCols, getShapeAs3D: getShapeAs3D, getTextureShapeFromLogicalShape: getTextureShapeFromLogicalShape, getWebGLDisjointQueryTimerVersion: getWebGLDisjointQueryTimerVersion, getWebGLErrorMessage: getWebGLErrorMessage, getWebGLMaxTextureSize: getWebGLMaxTextureSize, hasExtension: hasExtension, isCapableOfRenderingToFloatTexture: isCapableOfRenderingToFloatTexture, isDownloadFloatTextureEnabled: isDownloadFloatTextureEnabled, isReshapeFree: isReshapeFree, isWebGLFenceEnabled: isWebGLFenceEnabled, isWebGLVersionEnabled: isWebGLVersionEnabled, linkProgram: linkProgram, logShaderSourceAndInfoLog: logShaderSourceAndInfoLog, resetMaxTextureSize: resetMaxTextureSize, resetMaxTexturesInShader: resetMaxTexturesInShader, unbindColorTextureFromFramebuffer: unbindColorTextureFromFramebuffer, unbindTextureUnit: unbindTextureUnit, validateFramebuffer: validateFramebuffer, validateProgram: validateProgram, validateTextureSize: validateTextureSize }; /** * @license * Copyright 2019 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
* See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
// Environment object on which all flags below are registered.
const ENV = tf.env();
/**
 * This file contains WebGL-specific flag registrations.
 */
/**
 * True if WebGL is supported.
 */
ENV.registerFlag('HAS_WEBGL', () => ENV.getNumber('WEBGL_VERSION') > 0);
/**
 * 0: No WebGL, 1: WebGL 1.0, 2: WebGL 2.0.
 * WebGL 2.0 is probed first; WebGL 1.0 is only tried when that fails.
 */
ENV.registerFlag('WEBGL_VERSION', () => {
    if (isWebGLVersionEnabled(2)) {
        return 2;
    }
    else if (isWebGLVersionEnabled(1)) {
        return 1;
    }
    return 0;
});
/** Whether to check for numerical representation problems. */
ENV.registerFlag('WEBGL_CHECK_NUMERICAL_PROBLEMS', () => false);
/** True only when WEBGL_VERSION is 2. */
ENV.registerFlag('WEBGL_BUFFER_SUPPORTED', () => ENV.get('WEBGL_VERSION') === 2);
/** Whether the WebGL backend will sometimes forward ops to the CPU. */
ENV.registerFlag('WEBGL_CPU_FORWARD', () => true);
/** Whether the WebGL backend will always use f16 textures for rendering. */
ENV.registerFlag('WEBGL_FORCE_F16_TEXTURES', () => false);
/**
 * Whether to turn all packing related flags on. The WEBGL_PACK_* flags below
 * all default to this flag's value.
 */
ENV.registerFlag('WEBGL_PACK', () => ENV.getBool('HAS_WEBGL'));
/** Whether we will pack the batchnormalization op. */
ENV.registerFlag('WEBGL_PACK_NORMALIZATION', () => ENV.getBool('WEBGL_PACK'));
/** Whether we will pack the clip op. */
ENV.registerFlag('WEBGL_PACK_CLIP', () => ENV.getBool('WEBGL_PACK'));
/** Whether we will pack the depthwise conv op. */
ENV.registerFlag('WEBGL_PACK_DEPTHWISECONV', () => ENV.getBool('WEBGL_PACK'));
/** Whether we will pack binary ops. */
ENV.registerFlag('WEBGL_PACK_BINARY_OPERATIONS', () => ENV.getBool('WEBGL_PACK'));
/** Whether we will pack unary ops. */
ENV.registerFlag('WEBGL_PACK_UNARY_OPERATIONS', () => ENV.getBool('WEBGL_PACK'));
/** Whether we will pack array ops. */
ENV.registerFlag('WEBGL_PACK_ARRAY_OPERATIONS', () => ENV.getBool('WEBGL_PACK'));
/** Whether we will pack image ops.
*/
ENV.registerFlag('WEBGL_PACK_IMAGE_OPERATIONS', () => ENV.getBool('WEBGL_PACK'));
/** Whether we will pack reduce ops. */
ENV.registerFlag('WEBGL_PACK_REDUCE', () => ENV.getBool('WEBGL_PACK'));
/** Whether packed WebGL kernels lazily unpack their outputs. */
ENV.registerFlag('WEBGL_LAZILY_UNPACK', () => ENV.getBool('WEBGL_PACK'));
/** Whether we will use the im2col algorithm to speed up convolutions. */
ENV.registerFlag('WEBGL_CONV_IM2COL', () => ENV.getBool('WEBGL_PACK'));
/** Whether we will pack conv2dTranspose op. */
ENV.registerFlag('WEBGL_PACK_CONV2DTRANSPOSE', () => ENV.getBool('WEBGL_PACK'));
/** The maximum texture dimension. */
ENV.registerFlag('WEBGL_MAX_TEXTURE_SIZE', () => getWebGLMaxTextureSize(ENV.getNumber('WEBGL_VERSION')));
/**
 * The maximum number of textures that can be used in a single shader, as
 * reported by getMaxTexturesInShader (which caps the value at 16).
 */
ENV.registerFlag('WEBGL_MAX_TEXTURES_IN_SHADER', () => getMaxTexturesInShader(ENV.getNumber('WEBGL_VERSION')));
/**
 * The disjoint_query_timer extension version.
 * 0: disabled, 1: EXT_disjoint_timer_query, 2:
 * EXT_disjoint_timer_query_webgl2.
 * In Firefox with WebGL 2.0,
 * EXT_disjoint_timer_query_webgl2 is not available, so we must use the
 * WebGL 1.0 extension.
 */
ENV.registerFlag('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION', () => {
    const webGLVersion = ENV.getNumber('WEBGL_VERSION');
    // Without WebGL there is no context to query for extensions.
    if (webGLVersion === 0) {
        return 0;
    }
    return getWebGLDisjointQueryTimerVersion(webGLVersion);
});
/**
 * Whether the timer object from the disjoint_query_timer extension gives
 * timing information that is reliable. Defaults to false on mobile devices
 * and whenever no timer extension is available.
 */
ENV.registerFlag('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_RELIABLE', () => ENV.getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION') > 0 &&
    !tf.device_util.isMobile());
/**
 * Whether the device is physically capable of rendering to float32 textures.
 */
ENV.registerFlag('WEBGL_RENDER_FLOAT32_CAPABLE', () => isCapableOfRenderingToFloatTexture(ENV.getNumber('WEBGL_VERSION')));
/**
 * Whether rendering to float32 textures is enabled. If disabled, renders to
 * float16 textures.
*/
ENV.registerFlag('WEBGL_RENDER_FLOAT32_ENABLED', () => {
    // Forcing f16 textures overrides the device capability.
    return ENV.getBool('WEBGL_FORCE_F16_TEXTURES') ?
        false :
        ENV.getBool('WEBGL_RENDER_FLOAT32_CAPABLE');
});
/**
 * Whether downloading float textures is enabled (16 or 32 bit). If disabled,
 * uses IEEE 754 encoding of the float32 values to 4 uint8 when downloading.
 */
ENV.registerFlag('WEBGL_DOWNLOAD_FLOAT_ENABLED', () => isDownloadFloatTextureEnabled(ENV.getNumber('WEBGL_VERSION')));
/** Whether the fence API is available. */
ENV.registerFlag('WEBGL_FENCE_API_ENABLED', () => isWebGLFenceEnabled(ENV.getNumber('WEBGL_VERSION')));
/**
 * Tensors with size <= this value will be uploaded as uniforms, not textures.
 */
ENV.registerFlag('WEBGL_SIZE_UPLOAD_UNIFORM', () => {
    // Use uniform uploads only when 32-bit floats are supported. In 16-bit
    // environments there are problems with comparing a 16-bit texture value
    // with a 32-bit uniform value.
    const useUniforms = ENV.getBool('WEBGL_RENDER_FLOAT32_ENABLED');
    return useUniforms ? 4 : 0;
});
/**
 * If the total number of bytes allocated on the GPU is greater than this
 * number, we will aggressively delete textures upon disposal with
 * gl.deleteMatrixTexture, rather than making them available for reuse.
 *
 * Default value -1 indicates that we will never aggressively delete textures.
 * The setter hook rejects non-numbers and negative values other than -1.
 */
ENV.registerFlag('WEBGL_DELETE_TEXTURE_THRESHOLD', () => {
    return -1;
}, threshold => {
    if (!(typeof threshold === 'number')) {
        throw new Error('WEBGL_DELETE_TEXTURE_THRESHOLD must be a number but ' +
            `got ${threshold}.`);
    }
    if (threshold < 0 && threshold !== -1) {
        throw new Error(`WEBGL_DELETE_TEXTURE_THRESHOLD must be -1 (indicating never ` +
            `delete) or at least 0, but got ${threshold}.`);
    }
});
/**
 * Trigger a manual GL command flush if the threshold of time has passed since
 * the previous kernel execution. This can be useful for Android devices, where
 * GL command flushes are delayed until the end of the JavaScript task. This
 * value is measured in milliseconds.
Typically you want to set this value close to 1.
 *
 * Default value is 1 when tf.device_util.isMobile() reports a mobile device,
 * and -1 otherwise. -1 indicates that we will not enforce a manual flush and
 * instead depend on the system's default flush schedule.
 * The setter hook rejects non-numbers and negative values other than -1.
 */
ENV.registerFlag('WEBGL_FLUSH_THRESHOLD', () => {
    return tf.device_util.isMobile() ? 1 : -1;
}, threshold => {
    if (!(typeof threshold === 'number')) {
        throw new Error('WEBGL_FLUSH_THRESHOLD must be a number but got ' +
            `${threshold}.`);
    }
    if (threshold < 0 && threshold !== -1) {
        throw new Error(`WEBGL_FLUSH_THRESHOLD must be -1 (indicating never ` +
            `manual flush) or at least 0, but got ${threshold}.`);
    }
});
/**
 * Threshold for input tensor size that determines whether WebGL backend will
 * delegate computation to CPU.
 *
 * Default value is 128.
 */
ENV.registerFlag('CPU_HANDOFF_SIZE_THRESHOLD', () => 128);
/** Whether we will use shapes uniforms. */
ENV.registerFlag('WEBGL_USE_SHAPES_UNIFORMS', () => false);
/**
 * Threshold for last dimension of input tensor that determines whether
 * WebGL backend for the Top K op will delegate computation to CPU. If input
 * is smaller than threshold then CPU will be used
 *
 * Default value is 100000.
 */
ENV.registerFlag('TOPK_LAST_DIM_CPU_HANDOFF_SIZE_THRESHOLD', () => 100000);
/**
 * Threshold for K that determines whether
 * WebGL backend for the Top K op will delegate computation to CPU. If k
 * is larger than threshold then CPU will be used
 *
 * Default value is 128.
 */
ENV.registerFlag('TOPK_K_CPU_HANDOFF_THRESHOLD', () => 128);
/** Whether we will use the experimental conv op. */
ENV.registerFlag('WEBGL_EXP_CONV', () => false);
/**
 * If the device performance is low or if no hardware GPU is available, whether
 * software WebGL will be used. Defaults to the value of the IS_TEST flag.
 */
ENV.registerFlag('SOFTWARE_WEBGL_ENABLED', () => ENV.getBool('IS_TEST'));
/**
 * For narrow texture (physical height or physical width is 1), if the length of
 * any texture edges exceed the threshold, the texture will be reshaped to be
 * more squarish.
*
 * This flag is used to help some GPUs that could not provide correct
 * interpolations for long skinny triangles. We found Mali GPU probably has this
 * problem: https://github.com/tensorflow/tfjs/issues/6775.
 *
 * The default of Infinity means no edge length ever exceeds the threshold.
 */
ENV.registerFlag('WEBGL_MAX_SIZE_FOR_NARROW_TEXTURE', () => Infinity);
/**
 * If the flag is set to true, the max size of the narrow texture will be auto
 * computed and it will be considered as a threshold to reshape the narrow
 * texture to be more squarish.
 *
 * This flag is used to help some GPUs that could not provide correct
 * interpolations for long skinny triangles. We found Mali GPU probably has this
 * problem: https://github.com/tensorflow/tfjs/issues/6775.
 */
ENV.registerFlag('WEBGL_AUTO_SQUARIFY_NARROW_TEXTURE_SHAPE', () => false);
/**
 * Whether to use the customized isnan. It's only useful for webgl2 since webgl1
 * doesn't have the builtin isnan.
 */
ENV.registerFlag('WEBGL2_ISNAN_CUSTOM', () => false);
/** Experimental flag: whether to enter the compile-only phase. */
ENV.registerFlag('ENGINE_COMPILE_ONLY', () => false);
/**
 * @license
 * Copyright 2018 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
* ============================================================================= */ function getGlslDifferences() { let version; let attribute; let varyingVs; let varyingFs; let texture2D; let output; let defineOutput; let defineSpecialNaN; let defineSpecialInf; let defineRound; if (tf.env().getNumber('WEBGL_VERSION') === 2) { version = '#version 300 es'; attribute = 'in'; varyingVs = 'out'; varyingFs = 'in'; texture2D = 'texture'; output = 'outputColor'; defineOutput = 'out vec4 outputColor;'; // Use custom isnan definition to work across differences between // implementations on various platforms. While this should happen in ANGLE // we still see differences between android and windows (on chrome) when // using isnan directly. Since WebGL2 supports uint type and // floatBitsToUinT built-in function, we could implment isnan following // IEEE 754 rules. // NaN defination in IEEE 754-1985 is : // - sign = either 0 or 1. // - biased exponent = all 1 bits. // - fraction = anything except all 0 bits (since all 0 bits represents // infinity). // https://en.wikipedia.org/wiki/IEEE_754-1985#Representation_of_non-numbers defineSpecialNaN = tf.env().getBool('WEBGL2_ISNAN_CUSTOM') ? ` bool isnan_custom(float val) { uint floatToUint = floatBitsToUint(val); return (floatToUint & 0x7fffffffu) > 0x7f800000u; } bvec4 isnan_custom(vec4 val) { return bvec4(isnan_custom(val.x), isnan_custom(val.y), isnan_custom(val.z), isnan_custom(val.w)); } #define isnan(value) isnan_custom(value) ` : ''; // In webgl 2 we do not need to specify a custom isinf so there is no // need for a special INFINITY constant. 
defineSpecialInf = ``; defineRound = ` #define round(value) newRound(value) int newRound(float value) { return int(floor(value + 0.5)); } ivec4 newRound(vec4 value) { return ivec4(floor(value + vec4(0.5))); } `; } else { version = ''; attribute = 'attribute'; varyingVs = 'varying'; varyingFs = 'varying'; texture2D = 'texture2D'; output = 'gl_FragColor'; defineOutput = ''; // WebGL1 has no built in isnan so we define one here. defineSpecialNaN = ` #define isnan(value) isnan_custom(value) bool isnan_custom(float val) { return (val > 0. || val < 1. || val == 0.) ? false : true; } bvec4 isnan_custom(vec4 val) { return bvec4(isnan(val.x), isnan(val.y), isnan(val.z), isnan(val.w)); } `; defineSpecialInf = ` uniform float INFINITY; bool isinf(float val) { return abs(val) == INFINITY; } bvec4 isinf(vec4 val) { return equal(abs(val), vec4(INFINITY)); } `; defineRound = ` int round(float value) { return int(floor(value + 0.5)); } ivec4 round(vec4 value) { return ivec4(floor(value + vec4(0.5))); } `; } return { version, attribute, varyingVs, varyingFs, texture2D, output, defineOutput, defineSpecialNaN, defineSpecialInf, defineRound }; } /** * @license * Copyright 2018 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ /** * Produces GLSL code that derives logical coordinates from a flat * index. 
The code performs integer division with each stride and decrements * the index until the index equals the final dimension coordinate. */ function getLogicalCoordinatesFromFlatIndex(coords, shape, index = 'index') { const strides = tf.util.computeStrides(shape); return strides .map((stride, i) => { const line1 = `int ${coords[i]} = ${index} / ${stride}`; const line2 = i === strides.length - 1 ? `int ${coords[i + 1]} = ${index} - ${coords[i]} * ${stride}` : `index -= ${coords[i]} * ${stride}`; return `${line1}; ${line2};`; }) .join(''); } function getOutputLogicalCoordinatesFromFlatIndexByUniform(coords, shape, index = 'index') { const strides = tf.util.computeStrides(shape); return strides .map((_, i) => { const line1 = `int ${coords[i]} = ${index} / outShapeStrides[${i}]`; const line2 = i === strides.length - 1 ? `int ${coords[i + 1]} = ${index} - ${coords[i]} * outShapeStrides[${i}]` : `index -= ${coords[i]} * outShapeStrides[${i}]`; return `${line1}; ${line2};`; }) .join(''); } // Produces GLSL code that computes strides. function symbolicallyComputeStrides(indicesArr, variableName) { const numCoords = indicesArr.length; const shape = indicesArr.map(d => `${variableName}[${d}]`); const strides = new Array(numCoords - 1); strides[numCoords - 2] = shape[numCoords - 1]; for (let i = numCoords - 3; i >= 0; --i) { strides[i] = `(${strides[i + 1]} * ${shape[i + 1]})`; } return strides; } function getLogicalCoordinatesFromFlatIndexByUniform(coords, variableName, index = 'index') { const indicesArray = coords.map((_, i) => i); const strides = symbolicallyComputeStrides(indicesArray, variableName); return strides .map((_, i) => { const line1 = `int ${coords[i]} = ${index} / ${strides[i]}`; const line2 = i === strides.length - 1 ? `int ${coords[i + 1]} = ${index} - ${coords[i]} * ${strides[i]}` : `index -= ${coords[i]} * ${strides[i]}`; return `${line1}; ${line2};`; }) .join(''); } /** * Produces GLSL that computes the flat index from 3D coordinates. 
* The first two strides of `shape` are baked into the generated source as
 * literals; the innermost coordinate is added unscaled.
 */
function getFlatIndexFrom3D(shape) {
    const strides = tf.util.computeStrides(shape).map(d => d.toString());
    return `
  int getFlatIndex(ivec3 coords) {
    return coords.x * ${strides[0]} + coords.y * ${strides[1]} + coords.z;
  }
`;
}
/**
 * Variant of getFlatIndexFrom3D whose generated GLSL reads the strides from
 * `outShapeStrides` at shader run time instead of embedding literals.
 */
function getFlatIndexFrom3DOutput() {
    return `
  int getFlatIndex(ivec3 coords) {
    return coords.x * outShapeStrides[0] + coords.y * outShapeStrides[1] + coords.z;
  }
`;
}
// GLSL source for `encode_float`, which turns a highp float into four
// channel values, each a byte value divided by 255:
//   - NaN yields 255 in every channel (not divided by 255).
//   - Magnitudes below FLOAT_MIN yield all zeros.
//   - Values beyond +/-FLOAT_MAX yield the fixed bytes (0, 0, 128, 127) and
//     (0, 0, 128, 255), i.e. 0x7F800000 / 0xFF800000 read with c[0] as the
//     lowest byte.
//   - Otherwise c[0..2] carry the mantissa bits, the biased exponent is split
//     between c[3] and the top bit of c[2], and the sign goes into the top
//     bit of c[3].
// The snippet calls isnan(), so it presumably has to be combined with a
// shader prefix that provides one where the GLSL dialect lacks the built-in
// (TODO confirm against the programs that embed it).
const ENCODE_FLOAT_SNIPPET = `
  const float FLOAT_MAX = 1.70141184e38;
  const float FLOAT_MIN = 1.17549435e-38;

  lowp vec4 encode_float(highp float v) {
    if (isnan(v)) {
      return vec4(255, 255, 255, 255);
    }

    highp float av = abs(v);

    if(av < FLOAT_MIN) {
      return vec4(0.0, 0.0, 0.0, 0.0);
    } else if(v > FLOAT_MAX) {
      return vec4(0.0, 0.0, 128.0, 127.0) / 255.0;
    } else if(v < -FLOAT_MAX) {
      return vec4(0.0, 0.0, 128.0, 255.0) / 255.0;
    }

    highp vec4 c = vec4(0,0,0,0);

    highp float e = floor(log2(av));
    highp float m = exp2(fract(log2(av))) - 1.0;

    c[2] = floor(128.0 * m);
    m -= c[2] / 128.0;
    c[1] = floor(32768.0 * m);
    m -= c[1] / 32768.0;
    c[0] = floor(8388608.0 * m);

    highp float ebias = e + 127.0;
    c[3] = floor(ebias / 2.0);
    ebias -= c[3] * 2.0;
    c[2] += floor(ebias) * 128.0;

    c[3] += 128.0 * step(0.0, -v);

    return c / 255.0;
  }
`;
/**
 * @license
 * Copyright 2017 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
* ============================================================================= */ const { getBroadcastDims } = tf.backend_util; function makeShader(inputsInfo, outputShape, program) { const prefixSnippets = []; inputsInfo.forEach(x => { const size = tf.util.sizeFromShape(x.shapeInfo.logicalShape); // Snippet when we decided to upload the values as uniform. if (x.shapeInfo.isUniform) { prefixSnippets.push(`uniform float ${x.name}${size > 1 ? `[${size}]` : ''};`); } else { prefixSnippets.push(`uniform sampler2D ${x.name};`); prefixSnippets.push(`uniform int offset${x.name};`); } if (program.enableShapeUniforms) { const { uniformShape } = getUniformInfoFromShape(program.packedInputs, x.shapeInfo.logicalShape, x.shapeInfo.texShape); switch (uniformShape.length) { case 1: prefixSnippets.push(`uniform int ${x.name}Shape;`); break; case 2: prefixSnippets.push(`uniform ivec2 ${x.name}Shape;`); break; case 3: prefixSnippets.push(`uniform ivec3 ${x.name}Shape;`); break; case 4: prefixSnippets.push(`uniform ivec4 ${x.name}Shape;`); break; } prefixSnippets.push(`uniform ivec2 ${x.name}TexShape;`); } }); if (program.enableShapeUniforms) { switch (outputShape.logicalShape.length) { case 1: prefixSnippets.push(`uniform int outShape;`); break; case 2: prefixSnippets.push(`uniform ivec2 outShape;`); prefixSnippets.push(`uniform int outShapeStrides;`); break; case 3: prefixSnippets.push(`uniform ivec3 outShape;`); prefixSnippets.push(`uniform ivec2 outShapeStrides;`); break; case 4: prefixSnippets.push(`uniform ivec4 outShape;`); prefixSnippets.push(`uniform ivec3 outShapeStrides;`); break; } prefixSnippets.push(`uniform ivec2 outTexShape;`); } if (program.customUniforms) { program.customUniforms.forEach((d) => { prefixSnippets.push(`uniform ${d.type} ${d.name}${d.arrayIndex ? 
`[${d.arrayIndex}]` : ''};`); }); } const inputPrefixSnippet = prefixSnippets.join('\n'); const inputSamplingSnippet = inputsInfo .map(x => getInputSamplingSnippet(x, outputShape, program.packedInputs, program.enableShapeUniforms)) .join('\n'); const outTexShape = outputShape.texShape; const glsl = getGlslDifferences(); const floatTextureSampleSnippet = getFloatTextureSampleSnippet(glsl); let outputSamplingSnippet; let floatTextureSetOutputSnippet; let shaderPrefix = getShaderPrefix(glsl); if (outputShape.isPacked) { outputSamplingSnippet = getPackedOutputSamplingSnippet(outputShape.logicalShape, outTexShape, program.enableShapeUniforms); floatTextureSetOutputSnippet = getFloatTextureSetRGBASnippet(glsl); } else { outputSamplingSnippet = getOutputSamplingSnippet(outputShape.logicalShape, outTexShape, program.enableShapeUniforms); floatTextureSetOutputSnippet = getFloatTextureSetRSnippet(glsl); } if (program.packedInputs) { shaderPrefix += SHADER_PACKED_PREFIX; } const source = [ shaderPrefix, floatTextureSampleSnippet, floatTextureSetOutputSnippet, inputPrefixSnippet, outputSamplingSnippet, inputSamplingSnippet, program.userCode ].join('\n'); return source; } function getSamplerFromInInfo(inInfo, enableShapeUniforms = false) { const shape = inInfo.shapeInfo.logicalShape; switch (shape.length) { case 0: return getSamplerScalar(inInfo, enableShapeUniforms); case 1: return getSampler1D(inInfo, enableShapeUniforms); case 2: return getSampler2D(inInfo, enableShapeUniforms); case 3: return getSampler3D(inInfo, enableShapeUniforms); case 4: return getSampler4D(inInfo, enableShapeUniforms); case 5: return getSampler5D(inInfo); case 6: return getSampler6D(inInfo); default: throw new Error(`${shape.length}-D input sampling` + ` is not yet supported`); } } function getPackedSamplerFromInInfo(inInfo, enableShapeUniforms) { const shape = inInfo.shapeInfo.logicalShape; switch (shape.length) { case 0: return getPackedSamplerScalar(inInfo); case 1: return 
getPackedSampler1D(inInfo, enableShapeUniforms); case 2: return getPackedSampler2D(inInfo, enableShapeUniforms); case 3: return getPackedSampler3D(inInfo, enableShapeUniforms); default: return getPackedSamplerND(inInfo, enableShapeUniforms); } } function getInputSamplingSnippet(inInfo, outShapeInfo, usesPackedTextures = false, enableShapeUniforms) { let res = ''; if (usesPackedTextures) { res += getPackedSamplerFromInInfo(inInfo, enableShapeUniforms); } else { res += getSamplerFromInInfo(inInfo, enableShapeUniforms); } const inShape = inInfo.shapeInfo.logicalShape; const outShape = outShapeInfo.logicalShape; if (inShape.length <= outShape.length) { if (usesPackedTextures) { res += getPackedSamplerAtOutputCoords(inInfo, outShapeInfo); } else { res += getSamplerAtOutputCoords(inInfo, outShapeInfo); } } return res; } function getPackedOutputSamplingSnippet(outShape, outTexShape, enableShapeUniforms) { switch (outShape.length) { case 0: return getOutputScalarCoords(); case 1: return getOutputPacked1DCoords(outShape, outTexShape, enableShapeUniforms); case 2: return getOutputPacked2DCoords(outShape, outTexShape, enableShapeUniforms); case 3: return getOutputPacked3DCoords(outShape, outTexShape, enableShapeUniforms); default: return getOutputPackedNDCoords(outShape, outTexShape, enableShapeUniforms); } } function getOutputSamplingSnippet(outShape, outTexShape, enableShapeUniforms) { switch (outShape.length) { case 0: return getOutputScalarCoords(); case 1: return getOutput1DCoords(outShape, outTexShape, enableShapeUniforms); case 2: return getOutput2DCoords(outShape, outTexShape, enableShapeUniforms); case 3: return getOutput3DCoords(outShape, outTexShape, enableShapeUniforms); case 4: return getOutput4DCoords(outShape, outTexShape, enableShapeUniforms); case 5: return getOutput5DCoords(outShape, outTexShape); case 6: return getOutput6DCoords(outShape, outTexShape); default: throw new Error(`${outShape.length}-D output sampling is not yet supported`); } } function 
getFloatTextureSampleSnippet(glsl) { return ` float sampleTexture(sampler2D textureSampler, vec2 uv) { return ${glsl.texture2D}(textureSampler, uv).r; } `; } function getFloatTextureSetRSnippet(glsl) { return ` void setOutput(float val) { ${glsl.output} = vec4(val, 0, 0, 0); } `; } function getFloatTextureSetRGBASnippet(glsl) { return ` void setOutput(vec4 val) { ${glsl.output} = val; } `; } function getShaderPrefix(glsl) { const SHADER_PREFIX = `${glsl.version} precision highp float; precision highp int; precision highp sampler2D; ${glsl.varyingFs} vec2 resultUV; ${glsl.defineOutput} const vec2 halfCR = vec2(0.5, 0.5); struct ivec5 { int x; int y; int z; int w; int u; }; struct ivec6 { int x; int y; int z; int w; int u; int v; }; uniform float NAN; ${glsl.defineSpecialNaN} ${glsl.defineSpecialInf} ${glsl.defineRound} int imod(int x, int y) { return x - y * (x / y); } int idiv(int a, int b, float sign) { int res = a / b; int mod = imod(a, b); if (sign < 0. && mod != 0) { res -= 1; } return res; } //Based on the work of Dave Hoskins //https://www.shadertoy.com/view/4djSRW #define HASHSCALE1 443.8975 float random(float seed){ vec2 p = resultUV * seed; vec3 p3 = fract(vec3(p.xyx) * HASHSCALE1); p3 += dot(p3, p3.yzx + 19.19); return fract((p3.x + p3.y) * p3.z); } ${SAMPLE_1D_SNIPPET} ${SAMPLE_2D_SNIPPET} ${SAMPLE_3D_SNIPPET} `; return SHADER_PREFIX; } const SAMPLE_1D_SNIPPET = ` vec2 uvFromFlat(int texNumR, int texNumC, int index) { int texR = index / texNumC; int texC = index - texR * texNumC; return (vec2(texC, texR) + halfCR) / vec2(texNumC, texNumR); } vec2 packedUVfrom1D(int texNumR, int texNumC, int index) { int texelIndex = index / 2; int texR = texelIndex / texNumC; int texC = texelIndex - texR * texNumC; return (vec2(texC, texR) + halfCR) / vec2(texNumC, texNumR); } `; const SAMPLE_2D_SNIPPET = ` vec2 packedUVfrom2D(int texelsInLogicalRow, int texNumR, int texNumC, int row, int col) { int texelIndex = (row / 2) * texelsInLogicalRow + (col / 2); int texR = 
texelIndex / texNumC; int texC = texelIndex - texR * texNumC; return (vec2(texC, texR) + halfCR) / vec2(texNumC, texNumR); } `; const SAMPLE_3D_SNIPPET = ` vec2 packedUVfrom3D(int texNumR, int texNumC, int texelsInBatch, int texelsInLogicalRow, int b, int row, int col) { int index = b * texelsInBatch + (row / 2) * texelsInLogicalRow + (col / 2); int texR = index / texNumC; int texC = index - texR * texNumC; return (vec2(texC, texR) + halfCR) / vec2(texNumC, texNumR); } `; const SHADER_PACKED_PREFIX = ` float getChannel(vec4 frag, vec2 innerDims) { vec2 modCoord = mod(innerDims, 2.); return modCoord.x == 0. ? (modCoord.y == 0. ? frag.r : frag.g) : (modCoord.y == 0. ? frag.b : frag.a); } float getChannel(vec4 frag, int dim) { float modCoord = mod(float(dim), 2.); return modCoord == 0. ? frag.r : frag.g; } `; function getOutputScalarCoords() { return ` int getOutputCoords() { return 0; } `; } function getOutputPacked1DCoords(shape, texShape, enableShapeUniforms) { const packedTexShape = [Math.ceil(texShape[0] / 2), Math.ceil(texShape[1] / 2)]; if (packedTexShape[0] === 1) { if (enableShapeUniforms) { return ` int getOutputCoords() { return 2 * int(resultUV.x * ceil(float(outTexShape[1]) / 2.0)); } `; } return ` int getOutputCoords() { return 2 * int(resultUV.x * ${packedTexShape[1]}.0); } `; } if (packedTexShape[1] === 1) { if (enableShapeUniforms) { return ` int getOutputCoords() { return 2 * int(resultUV.y * ceil(float(outTexShape[0]) / 2.0)); } `; } return ` int getOutputCoords() { return 2 * int(resultUV.y * ${packedTexShape[0]}.0); } `; } if (enableShapeUniforms) { return ` int getOutputCoords() { ivec2 packedTexShape = ivec2(ceil(float(outTexShape[0]) / 2.0), ceil(float(outTexShape[1]) / 2.0)); ivec2 resTexRC = ivec2(resultUV.yx * vec2(packedTexShape[0], packedTexShape[1])); return 2 * (resTexRC.x * packedTexShape[1] + resTexRC.y); } `; } return ` int getOutputCoords() { ivec2 resTexRC = ivec2(resultUV.yx * vec2(${packedTexShape[0]}, ${packedTexShape[1]})); 
return 2 * (resTexRC.x * ${packedTexShape[1]} + resTexRC.y); } `; } function getOutput1DCoords(shape, texShape, enableShapeUniforms) { if (texShape[0] === 1) { if (enableShapeUniforms) { return ` int getOutputCoords() { return int(resultUV.x * float(outTexShape[1])); } `; } return ` int getOutputCoords() { return int(resultUV.x * ${texShape[1]}.0); } `; } if (texShape[1] === 1) { if (enableShapeUniforms) { return ` int getOutputCoords() { return int(resultUV.y * float(outTexShape[0])); } `; } return ` int getOutputCoords() { return int(resultUV.y * ${texShape[0]}.0); } `; } if (enableShapeUniforms) { return ` int getOutputCoords() { ivec2 resTexRC = ivec2(resultUV.yx * vec2(outTexShape[0], outTexShape[1])); return resTexRC.x * outTexShape[1] + resTexRC.y; } `; } return ` int getOutputCoords() { ivec2 resTexRC = ivec2(resultUV.yx * vec2(${texShape[0]}, ${texShape[1]})); return resTexRC.x * ${texShape[1]} + resTexRC.y; } `; } function getOutputPacked3DCoords(shape, texShape, enableShapeUniforms) { if (enableShapeUniforms) { return ` ivec3 getOutputCoords() { ivec2 packedTexShape = ivec2(ceil(float(outTexShape[0]) / 2.0), ceil(float(outTexShape[1]) / 2.0)); int texelsInLogicalRow = int(ceil(float(outShape[2]) / 2.0)); int texelsInBatch = texelsInLogicalRow * int(ceil(float(outShape[1]) / 2.0)); ivec2 resTexRC = ivec2(resultUV.yx * vec2(packedTexShape[0], packedTexShape[1])); int index = resTexRC.x * packedTexShape[1] + resTexRC.y; int b = index / texelsInBatch; index -= b * texelsInBatch; int r = 2 * (index / texelsInLogicalRow); int c = imod(index, texelsInLogicalRow) * 2; return ivec3(b, r, c); } `; } const packedTexShape = [Math.ceil(texShape[0] / 2), Math.ceil(texShape[1] / 2)]; const texelsInLogicalRow = Math.ceil(shape[2] / 2); const texelsInBatch = texelsInLogicalRow * Math.ceil(shape[1] / 2); return ` ivec3 getOutputCoords() { ivec2 resTexRC = ivec2(resultUV.yx * vec2(${packedTexShape[0]}, ${packedTexShape[1]})); int index = resTexRC.x * ${packedTexShape[1]} + 
resTexRC.y; int b = index / ${texelsInBatch}; index -= b * ${texelsInBatch}; int r = 2 * (index / ${texelsInLogicalRow}); int c = imod(index, ${texelsInLogicalRow}) * 2; return ivec3(b, r, c); } `; } function getOutput3DCoords(shape, texShape, enableShapeUniforms) { if (enableShapeUniforms) { const coordsFromIndexSnippet = getOutputLogicalCoordinatesFromFlatIndexByUniform(['r', 'c', 'd'], shape); return ` ivec3 getOutputCoords() { ivec2 resTexRC = ivec2(resultUV.yx * vec2(outTexShape[0], outTexShape[1])); int index = resTexRC.x * outTexShape[1] + resTexRC.y; ${coordsFromIndexSnippet} return ivec3(r, c, d); } `; } const coordsFromIndexSnippet = getLogicalCoordinatesFromFlatIndex(['r', 'c', 'd'], shape); return ` ivec3 getOutputCoords() { ivec2 resTexRC = ivec2(resultUV.yx * vec2(${texShape[0]}, ${texShape[1]})); int index = resTexRC.x * ${texShape[1]} + resTexRC.y; ${coordsFromIndexSnippet} return ivec3(r, c, d); } `; } function getOutputPackedNDCoords(shape, texShape, enableShapeUniforms) { if (enableShapeUniforms) { // TODO: support 5d and 6d return ` ivec4 getOutputCoords() { ivec2 packedTexShape = ivec2(ceil(float(outTexShape[0]) / 2.0), ceil(float(outTexShape[1]) / 2.0)); ivec2 resTexRC = ivec2(resultUV.yx * vec2(packedTexShape[0], packedTexShape[1])); int index = resTexRC.x * packedTexShape[1] + resTexRC.y; int texelsInLogicalRow = int(ceil(float(outShape[3]) / 2.0)); int texelsInBatch = texelsInLogicalRow * int(ceil(float(outShape[2]) / 2.0)); int texelsInBatchN = texelsInBatch * outShape[1]; int b2 = index / texelsInBatchN; index -= b2 * texelsInBatchN; int b = index / texelsInBatch; index -= b * texelsInBatch; int r = 2 * (index / texelsInLogicalRow); int c = imod(index, texelsInLogicalRow) * 2; return ivec4(b2, b, r, c); } `; } const packedTexShape = [Math.ceil(texShape[0] / 2), Math.ceil(texShape[1] / 2)]; const texelsInLogicalRow = Math.ceil(shape[shape.length - 1] / 2); const texelsInBatch = texelsInLogicalRow * Math.ceil(shape[shape.length - 2] / 2); 
/**
 * Generates the GLSL `getOutputCoords()` function for a rank-4, unpacked
 * output. The generated code derives a flat texel index from `resultUV` and
 * expands it into the coordinates `r`, `c`, `d` and `d2` through a snippet
 * supplied by a sibling helper.
 *
 * @param shape Output shape, forwarded to the index-to-coordinates helper.
 * @param texShape Two-element [rows, cols] shape of the output texture; only
 *     read when shape uniforms are disabled.
 * @param enableShapeUniforms When truthy, the texture shape comes from the
 *     `outTexShape` uniform and the by-uniform coordinate helper is used.
 * @return GLSL source defining `ivec4 getOutputCoords()`.
 */
function getOutput4DCoords(shape, texShape, enableShapeUniforms) {
  const coordNames = ['r', 'c', 'd', 'd2'];

  // Pick the source of the texture extents and the matching snippet builder.
  const pieces = enableShapeUniforms ?
      {
        rows: 'outTexShape[0]',
        cols: 'outTexShape[1]',
        unflatten: getOutputLogicalCoordinatesFromFlatIndexByUniform(
            coordNames, shape),
      } :
      {
        rows: `${texShape[0]}`,
        cols: `${texShape[1]}`,
        unflatten: getLogicalCoordinatesFromFlatIndex(coordNames, shape),
      };

  return `
    ivec4 getOutputCoords() {
      ivec2 resTexRC = ivec2(resultUV.yx * vec2(${pieces.rows}, ${pieces.cols}));
      int index = resTexRC.x * ${pieces.cols} + resTexRC.y;
      ${pieces.unflatten}
      return ivec4(r, c, d, d2);
    }
  `;
}
/**
 * Generates the GLSL `getOutputCoords()` function for a rank-2, packed
 * output.
 *
 * In the generated code `resTexRC` holds the row and column of the texel:
 * moving one texel to the right in the packed texture moves one column (not
 * two). `index` is the flat texel index.
 *
 * @param shape Two-element output shape; compared with `texShape` and used
 *     for the number of texels per logical row.
 * @param texShape Two-element [rows, cols] texture shape; each extent is
 *     halved (rounded up) to obtain the packed texture shape.
 * @param enableShapeUniforms When truthy, the generated code derives all
 *     extents from the `outTexShape` / `outShape` uniforms.
 * @return GLSL source defining `ivec2 getOutputCoords()`.
 */
function getOutputPacked2DCoords(shape, texShape, enableShapeUniforms) {
  const packedRows = Math.ceil(texShape[0] / 2);
  const packedCols = Math.ceil(texShape[1] / 2);
  const shapeMatchesTexture = tf.util.arraysEqual(shape, texShape);

  if (enableShapeUniforms) {
    // Statement shared by both uniform-driven variants.
    const packedShapeDecl =
        'ivec2 packedTexShape = ivec2(ceil(float(outTexShape[0]) / 2.0), ' +
        'ceil(float(outTexShape[1]) / 2.0));';
    if (shapeMatchesTexture) {
      return `
    ivec2 getOutputCoords() {
      ${packedShapeDecl}
      return 2 * ivec2(resultUV.yx * vec2(packedTexShape[0], packedTexShape[1]));
    }
  `;
    }
    return `
    ivec2 getOutputCoords() {
      ${packedShapeDecl}
      int texelsInLogicalRow = int(ceil(float(outShape[1]) / 2.0));
      ivec2 resTexRC = ivec2(resultUV.yx * vec2(packedTexShape[0], packedTexShape[1]));
      int index = resTexRC.x * packedTexShape[1] + resTexRC.y;
      int r = 2 * (index / texelsInLogicalRow);
      int c = imod(index, texelsInLogicalRow) * 2;
      return ivec2(r, c);
    }
  `;
  }

  if (shapeMatchesTexture) {
    return `
    ivec2 getOutputCoords() {
      return 2 * ivec2(resultUV.yx * vec2(${packedRows}, ${packedCols}));
    }
  `;
  }

  // Texels needed to accommodate a logical row.
  const texelsPerRow = Math.ceil(shape[1] / 2);
  return `
    ivec2 getOutputCoords() {
      ivec2 resTexRC = ivec2(resultUV.yx * vec2(${packedRows}, ${packedCols}));
      int index = resTexRC.x * ${packedCols} + resTexRC.y;
      int r = 2 * (index / ${texelsPerRow});
      int c = imod(index, ${texelsPerRow}) * 2;
      return ivec2(r, c);
    }
  `;
}
/**
 * Generates the zero-argument GLSL getter (`get<Name>()`) that reads the
 * single value of an unpacked input.
 *
 * @param inputInfo Input descriptor; reads `name`, `shapeInfo.isUniform` and
 *     `shapeInfo.texShape`.
 * @param enableShapeUniforms When truthy, the generated code reads the
 *     texture shape from the `<name>TexShape` uniform.
 * @return GLSL source defining `float get<Name>()`.
 */
function getSamplerScalar(inputInfo, enableShapeUniforms) {
  const { name: texName, shapeInfo } = inputInfo;
  const funcName = `get${texName.charAt(0).toUpperCase()}${texName.slice(1)}`;

  // Inputs passed as uniforms are returned directly, without sampling.
  if (shapeInfo.isUniform) {
    return `float ${funcName}() {return ${texName};}`;
  }

  const [texRows, texCols] = shapeInfo.texShape;
  const isSingleTexel = texRows === 1 && texCols === 1;
  if (isSingleTexel) {
    return `
    float ${funcName}() {
      return sampleTexture(${texName}, halfCR);
    }
  `;
  }

  // Otherwise locate the value through the input's flat-offset uniform.
  const offset = getFlatOffsetUniformName(texName);
  const rowsExpr = enableShapeUniforms ? `${texName}TexShape[0]` : `${texRows}`;
  const colsExpr = enableShapeUniforms ? `${texName}TexShape[1]` : `${texCols}`;
  return `
    float ${funcName}() {
      vec2 uv = uvFromFlat(${rowsExpr}, ${colsExpr}, ${offset});
      return sampleTexture(${texName}, uv);
    }
  `;
}
/**
 * Generates the GLSL getter `get<Name>(int index)` for an unpacked rank-1
 * input.
 *
 * @param inputInfo Input descriptor; reads `name`, `shapeInfo.isUniform` and
 *     `shapeInfo.texShape`.
 * @param enableShapeUniforms When truthy, the generated code reads the
 *     texture shape from the `<name>TexShape` uniform.
 * @return GLSL source defining `float get<Name>(int index)`.
 */
function getSampler1D(inputInfo, enableShapeUniforms) {
  const { name: texName, shapeInfo } = inputInfo;
  const funcName = `get${texName.charAt(0).toUpperCase()}${texName.slice(1)}`;

  if (shapeInfo.isUniform) {
    // Uniform arrays will be less than 65505 (no risk of float16 overflow).
    return `
    float ${funcName}(int index) {
      ${getUniformSampler(inputInfo)}
    }
  `;
  }

  const [texRows, texCols] = shapeInfo.texShape;
  if (texCols === 1 && texRows === 1) {
    return `
    float ${funcName}(int index) {
      return sampleTexture(${texName}, halfCR);
    }
  `;
  }

  const offset = getFlatOffsetUniformName(texName);
  // Texel-centre position of the requested element along the long axis.
  const texelCentre = `(float(index + ${offset}) + 0.5)`;

  let uvExpr;
  if (texCols === 1) {
    // Single column: only the vertical coordinate varies.
    const rowsDenom =
        enableShapeUniforms ? `float(${texName}TexShape[0])` : `${texRows}.0`;
    uvExpr = `vec2(0.5, ${texelCentre} / ${rowsDenom})`;
  } else if (texRows === 1) {
    // Single row: only the horizontal coordinate varies.
    const colsDenom =
        enableShapeUniforms ? `float(${texName}TexShape[1])` : `${texCols}.0`;
    uvExpr = `vec2(${texelCentre} / ${colsDenom}, 0.5)`;
  } else {
    const rowsArg = enableShapeUniforms ? `${texName}TexShape[0]` : `${texRows}`;
    const colsArg = enableShapeUniforms ? `${texName}TexShape[1]` : `${texCols}`;
    uvExpr = `uvFromFlat(${rowsArg}, ${colsArg}, index + ${offset})`;
  }

  return `
    float ${funcName}(int index) {
      vec2 uv = ${uvExpr};
      return sampleTexture(${texName}, uv);
    }
  `;
}
funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1); const texShape = inputInfo.shapeInfo.texShape; if (texShape != null && tf.util.arraysEqual(shape, texShape)) { if (enableShapeUniforms) { return ` float ${funcName}(int row, int col) { vec2 uv = (vec2(col, row) + halfCR) / vec2(${texName}TexShape[1], ${texName}TexShape[0]); return sampleTexture(${texName}, uv); } `; } const texNumR = texShape[0]; const texNumC = texShape[1]; return ` float ${funcName}(int row, int col) { vec2 uv = (vec2(col, row) + halfCR) / vec2(${texNumC}.0, ${texNumR}.0); return sampleTexture(${texName}, uv); } `; } const { newShape, keptDims } = tf.util.squeezeShape(shape); const squeezedShape = newShape; if (squeezedShape.length < shape.length) { const newInputInfo = squeezeInputInfo(inputInfo, squeezedShape); const params = ['row', 'col']; return ` ${getSamplerFromInInfo(newInputInfo, enableShapeUniforms)} float ${funcName}(int row, int col) { return ${funcName}(${getSqueezedParams(params, keptDims)}); } `; } if (inputInfo.shapeInfo.isUniform) { // Uniform arrays will be less than 65505 (no risk of float16 overflow). return ` float ${funcName}(int row, int col) { int index = round(dot(vec2(row, col), vec2(${shape[1]}, 1))); ${getUniformSampler(inputInfo)} } `; } const texNumR = texShape[0]; const texNumC = texShape[1]; const offset = getFlatOffsetUniformName(texName); if (texNumC === 1) { // index is used directly as physical (no risk of float16 overflow). 
if (enableShapeUniforms) { return ` float ${funcName}(int row, int col) { float index = dot(vec3(row, col, ${offset}), vec3(${texName}Shape[1], 1, 1)); vec2 uv = vec2(0.5, (index + 0.5) / float(${texName}TexShape[0])); return sampleTexture(${texName}, uv); } `; } return ` float ${funcName}(int row, int col) { float index = dot(vec3(row, col, ${offset}), vec3(${shape[1]}, 1, 1)); vec2 uv = vec2(0.5, (index + 0.5) / ${texNumR}.0); return sampleTexture(${texName}, uv); } `; } if (texNumR === 1) { // index is used directly as physical (no risk of float16 overflow). if (enableShapeUniforms) { return ` float ${funcName}(int row, int col) { float index = dot(vec3(row, col, ${offset}), vec3(${texName}Shape[1], 1, 1)); vec2 uv = vec2((index + 0.5) / float(${texName}TexShape[1]), 0.5); return sampleTexture(${texName}, uv); } `; } return ` float ${funcName}(int row, int col) { float index = dot(vec3(row, col, ${offset}), vec3(${shape[1]}, 1, 1)); vec2 uv = vec2((index + 0.5) / ${texNumC}.0, 0.5); return sampleTexture(${texName}, uv); } `; } if (enableShapeUniforms) { return ` float ${funcName}(int row, int col) { // Explicitly use integer operations as dot() only works on floats. int index = row * ${texName}Shape[1] + col + ${offset}; vec2 uv = uvFromFlat(${texName}TexShape[0], ${texName}TexShape[1], index); return sampleTexture(${texName}, uv); } `; } return ` float ${funcName}(int row, int col) { // Explicitly use integer operations as dot() only works on floats. 
/**
 * Generates the GLSL getter `get<Name>(int b, int row, int col)` for a
 * packed rank-3 input.
 *
 * When the leading dimension is 1 the input is re-described without it and
 * the generated code forwards to the sampler for the remaining two
 * dimensions.
 *
 * @param inputInfo Input descriptor; reads `name`, `shapeInfo.logicalShape`
 *     and `shapeInfo.texShape`.
 * @param enableShapeUniforms When truthy, the generated code derives the
 *     packed extents from the `<name>TexShape` / `<name>Shape` uniforms.
 * @return GLSL source defining `vec4 get<Name>(int b, int row, int col)`.
 */
function getPackedSampler3D(inputInfo, enableShapeUniforms) {
  const logicalShape = inputInfo.shapeInfo.logicalShape;
  const texName = inputInfo.name;
  const funcName = `get${texName.charAt(0).toUpperCase()}${texName.slice(1)}`;
  const texShape = inputInfo.shapeInfo.texShape;
  // Each packed texel covers a 2x2 patch, so both extents are halved.
  const packedRows = Math.ceil(texShape[0] / 2);
  const packedCols = Math.ceil(texShape[1] / 2);

  if (logicalShape[0] === 1) {
    const squeezedInfo = squeezeInputInfo(inputInfo, logicalShape.slice(1));
    const innerSampler =
        getPackedSamplerFromInInfo(squeezedInfo, enableShapeUniforms);
    // Only `row` and `col` (positions 1 and 2) survive the squeeze.
    const forwardedArgs = getSqueezedParams(['b', 'row', 'col'], [1, 2]);
    return `
    ${innerSampler}
    vec4 ${funcName}(int b, int row, int col) {
      return ${funcName}(${forwardedArgs});
    }
  `;
  }

  const glsl = getGlslDifferences();
  if (enableShapeUniforms) {
    return `
    vec4 ${funcName}(int b, int row, int col) {
      ivec2 packedTexShape = ivec2(ceil(float(${texName}TexShape[0]) / 2.0), ceil(float(${texName}TexShape[1]) / 2.0));
      int valuesPerRow = int(ceil(float(${texName}Shape[2]) / 2.0));
      int texelsInBatch = valuesPerRow * int(ceil(float(${texName}Shape[1]) / 2.0));
      vec2 uv = packedUVfrom3D(
        packedTexShape[0], packedTexShape[1], texelsInBatch, valuesPerRow, b, row, col);
      return ${glsl.texture2D}(${texName}, uv);
    }
  `;
  }

  const valuesPerRow = Math.ceil(logicalShape[2] / 2);
  const texelsInBatch = valuesPerRow * Math.ceil(logicalShape[1] / 2);
  return `
    vec4 ${funcName}(int b, int row, int col) {
      vec2 uv = packedUVfrom3D(
        ${packedRows}, ${packedCols}, ${texelsInBatch}, ${valuesPerRow}, b, row, col);
      return ${glsl.texture2D}(${texName}, uv);
    }
  `;
}
shape[1] * shape[2]; const stride1 = shape[2]; const { newShape, keptDims } = tf.util.squeezeShape(shape); const squeezedShape = newShape; if (squeezedShape.length < shape.length) { const newInputInfo = squeezeInputInfo(inputInfo, squeezedShape); const params = ['row', 'col', 'depth']; return ` ${getSamplerFromInInfo(newInputInfo, enableShapeUniforms)} float ${funcName}(int row, int col, int depth) { return ${funcName}(${getSqueezedParams(params, keptDims)}); } `; } if (inputInfo.shapeInfo.isUniform) { // Uniform arrays will be less than 65505 (no risk of float16 overflow). return ` float ${funcName}(int row, int col, int depth) { int index = round(dot(vec3(row, col, depth), vec3(${stride0}, ${stride1}, 1))); ${getUniformSampler(inputInfo)} } `; } const texShape = inputInfo.shapeInfo.texShape; const texNumR = texShape[0]; const texNumC = texShape[1]; const flatOffset = inputInfo.shapeInfo.flatOffset; if (texNumC === stride0 && flatOffset == null) { // texC is used directly as physical (no risk of float16 overflow). if (enableShapeUniforms) { return ` float ${funcName}(int row, int col, int depth) { int stride1 = ${texName}Shape[2]; float texR = float(row); float texC = dot(vec2(col, depth), vec2(stride1, 1)); vec2 uv = (vec2(texC, texR) + halfCR) / vec2(${texName}TexShape[1], ${texName}TexShape[0]); return sampleTexture(${texName}, uv); } `; } return ` float ${funcName}(int row, int col, int depth) { float texR = float(row); float texC = dot(vec2(col, depth), vec2(${stride1}, 1)); vec2 uv = (vec2(texC, texR) + halfCR) / vec2(${texNumC}.0, ${texNumR}.0); return sampleTexture(${texName}, uv); } `; } if (texNumC === stride1 && flatOffset == null) { // texR is used directly as physical (no risk of float16 overflow). 
if (enableShapeUniforms) { return ` float ${funcName}(int row, int col, int depth) { float texR = dot(vec2(row, col), vec2(${texName}Shape[1], 1)); float texC = float(depth); vec2 uv = (vec2(texC, texR) + halfCR) / vec2(${texName}TexShape[1], ${texName}TexShape[0]); return sampleTexture(${texName}, uv); } `; } return ` float ${funcName}(int row, int col, int depth) { float texR = dot(vec2(row, col), vec2(${shape[1]}, 1)); float texC = float(depth); vec2 uv = (vec2(texC, texR) + halfCR) / vec2(${texNumC}.0, ${texNumR}.0); return sampleTexture(${texName}, uv); } `; } const offset = getFlatOffsetUniformName(texName); if (enableShapeUniforms) { return ` float ${funcName}(int row, int col, int depth) { // Explicitly use integer operations as dot() only works on floats. int stride0 = ${texName}Shape[1] * ${texName}Shape[2]; int stride1 = ${texName}Shape[2]; int index = row * stride0 + col * stride1 + depth + ${offset}; vec2 uv = uvFromFlat(${texName}TexShape[0], ${texName}TexShape[1], index); return sampleTexture(${texName}, uv); } `; } return ` float ${funcName}(int row, int col, int depth) { // Explicitly use integer operations as dot() only works on floats. 
/**
 * Generates the GLSL getter for a packed input whose sampler takes one or
 * more leading batch indices followed by `row` and `col`.
 *
 * @param inputInfo Input descriptor; reads `name` and, when shape uniforms
 *     are disabled, `shapeInfo.logicalShape` and `shapeInfo.texShape`.
 * @param enableShapeUniforms When truthy, a fixed four-argument getter
 *     (`b2, b, row, col`) driven by the `<name>Shape` / `<name>TexShape`
 *     uniforms is emitted.
 * @return GLSL source defining `vec4 get<Name>(...)`.
 */
function getPackedSamplerND(inputInfo, enableShapeUniforms) {
  const texName = inputInfo.name;
  const funcName = `get${texName.charAt(0).toUpperCase()}${texName.slice(1)}`;
  const glsl = getGlslDifferences();

  if (enableShapeUniforms) {
    // TODO: support 5d and 6d
    return `
    vec4 ${funcName}(int b2, int b, int row, int col) {
      int valuesPerRow = int(ceil(float(${texName}Shape[3]) / 2.0));
      int texelsInBatch = valuesPerRow * int(ceil(float(${texName}Shape[2]) / 2.0));
      int index = b * texelsInBatch + (row / 2) * valuesPerRow + (col / 2);
      texelsInBatch *= ${texName}Shape[1];
      index = b2 * texelsInBatch + index;
      ivec2 packedTexShape = ivec2(ceil(float(${texName}TexShape[0]) / 2.0), ceil(float(${texName}TexShape[1]) / 2.0));
      int texR = index / packedTexShape[1];
      int texC = index - texR * packedTexShape[1];
      vec2 uv = (vec2(texC, texR) + halfCR) / vec2(packedTexShape[1], packedTexShape[0]);
      return ${glsl.texture2D}(${texName}, uv);
    }
  `;
  }

  const logicalShape = inputInfo.shapeInfo.logicalShape;
  const rank = logicalShape.length;
  const texShape = inputInfo.shapeInfo.texShape;
  const packedRows = Math.ceil(texShape[0] / 2);
  const packedCols = Math.ceil(texShape[1] / 2);
  const valuesPerRow = Math.ceil(logicalShape[rank - 1] / 2);
  let texelsInBatch = valuesPerRow * Math.ceil(logicalShape[rank - 2] / 2);

  // Innermost parameters and index terms first; each additional batch
  // dimension is prepended, scaled by the running texel count.
  const paramDecls = ['int b', 'int row', 'int col'];
  const indexTerms =
      [`b * ${texelsInBatch}`, `(row / 2) * ${valuesPerRow}`, '(col / 2)'];
  for (let batchDim = 2; batchDim < rank - 1; batchDim++) {
    texelsInBatch *= logicalShape[rank - batchDim - 1];
    paramDecls.unshift(`int b${batchDim}`);
    indexTerms.unshift(`b${batchDim} * ${texelsInBatch}`);
  }
  const params = paramDecls.join(', ');
  const index = indexTerms.join(' + ');

  return `
    vec4 ${funcName}(${params}) {
      int index = ${index};
      int texR = index / ${packedCols};
      int texC = index - texR * ${packedCols};
      vec2 uv = (vec2(texC, texR) + halfCR) / vec2(${packedCols}, ${packedRows});
      return ${glsl.texture2D}(${texName}, uv);
    }
  `;
}
if (enableShapeUniforms) { return ` float ${funcName}(int row, int col, int depth, int depth2) { ${stride2Str} ${stride1Str} float texR = float(row); float texC = dot(vec3(col, depth, depth2), vec3(stride1, stride2, 1)); vec2 uv = (vec2(texC, texR) + halfCR) / vec2(${texName}TexShape[1], ${texName}TexShape[0]); return sampleTexture(${texName}, uv); } `; } return ` float ${funcName}(int row, int col, int depth, int depth2) { float texR = float(row); float texC = dot(vec3(col, depth, depth2), vec3(${stride1}, ${stride2}, 1)); vec2 uv = (vec2(texC, texR) + halfCR) / vec2(${texNumC}.0, ${texNumR}.0); return sampleTexture(${texName}, uv); } `; } if (texNumC === stride2 && flatOffset == null) { // texR is used directly as physical (no risk of float16 overflow). if (enableShapeUniforms) { return ` float ${funcName}(int row, int col, int depth, int depth2) { float texR = dot(vec3(row, col, depth), vec3(${texName}Shape[1] * ${texName}Shape[2], ${texName}Shape[2], 1)); float texC = float(depth2); vec2 uv = (vec2(texC, texR) + halfCR) / vec2(${texName}TexShape[1], ${texName}TexShape[0]); return sampleTexture(${texName}, uv); } `; } return ` float ${funcName}(int row, int col, int depth, int depth2) { float texR = dot(vec3(row, col, depth), vec3(${shape[1] * shape[2]}, ${shape[2]}, 1)); float texC = float(depth2); vec2 uv = (vec2(texC, texR) + halfCR) / vec2(${texNumC}.0, ${texNumR}.0); return sampleTexture(${texName}, uv); } `; } const offset = getFlatOffsetUniformName(texName); if (enableShapeUniforms) { return ` float ${funcName}(int row, int col, int depth, int depth2) { // Explicitly use integer operations as dot() only works on floats. 
/**
 * Generates the GLSL getter
 * `get<Name>(int row, int col, int depth, int depth2, int depth3)` for an
 * unpacked rank-5 input.
 *
 * Depending on the input, the generated code:
 *  - forwards to a lower-rank sampler when the shape has size-1 dimensions;
 *  - reads from a uniform array when `shapeInfo.isUniform` is set;
 *  - computes the texture row or column directly when the texture width
 *    equals the outermost or innermost stride and there is no flat offset;
 *  - otherwise flattens the coordinates and converts the flat index to UV.
 *
 * @param inputInfo Input descriptor; reads `name`, `shapeInfo.logicalShape`,
 *     `shapeInfo.isUniform`, `shapeInfo.flatOffset` and `shapeInfo.texShape`.
 * @return GLSL source defining the rank-5 `float get<Name>(...)` getter.
 */
function getSampler5D(inputInfo) {
  const shape = inputInfo.shapeInfo.logicalShape;
  const texName = inputInfo.name;
  const funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1);
  // Row-major strides of the four outer dimensions.
  const stride3 = shape[4];
  const stride2 = shape[3] * stride3;
  const stride1 = shape[2] * stride2;
  const stride0 = shape[1] * stride1;

  const { newShape, keptDims } = tf.util.squeezeShape(shape);
  if (newShape.length < shape.length) {
    // Size-1 dimensions present: delegate to the lower-rank sampler and
    // forward only the coordinates of the dimensions that were kept.
    const newInputInfo = squeezeInputInfo(inputInfo, newShape);
    const params = ['row', 'col', 'depth', 'depth2', 'depth3'];
    return `
      ${getSamplerFromInInfo(newInputInfo)}
      float ${funcName}(int row, int col, int depth, int depth2, int depth3) {
        return ${funcName}(${getSqueezedParams(params, keptDims)});
      }
    `;
  }

  if (inputInfo.shapeInfo.isUniform) {
    // Uniform arrays will be less than 65505 (no risk of float16 overflow).
    // The flat index must be an int: the uniform-sampler loop compares it
    // with an int counter, and GLSL ES does not convert int <-> float
    // implicitly, so `depth3` is cast and the sum is rounded, matching the
    // 4-D and 6-D samplers.
    return `
      float ${funcName}(int row, int col, int depth, int depth2, int depth3) {
        int index = round(dot(
          vec4(row, col, depth, depth2),
          vec4(${stride0}, ${stride1}, ${stride2}, ${stride3})) +
          float(depth3));
        ${getUniformSampler(inputInfo)}
      }
    `;
  }

  const flatOffset = inputInfo.shapeInfo.flatOffset;
  const texShape = inputInfo.shapeInfo.texShape;
  const texNumR = texShape[0];
  const texNumC = texShape[1];

  if (texNumC === stride0 && flatOffset == null) {
    // texC is used directly as physical (no risk of float16 overflow).
    return `
      float ${funcName}(int row, int col, int depth, int depth2, int depth3) {
        int texR = row;
        float texC = dot(vec4(col, depth, depth2, depth3),
                         vec4(${stride1}, ${stride2}, ${stride3}, 1));
        vec2 uv = (vec2(texC, texR) + halfCR) /
                   vec2(${texNumC}.0, ${texNumR}.0);
        return sampleTexture(${texName}, uv);
      }
    `;
  }

  if (texNumC === stride3 && flatOffset == null) {
    // texR is used directly as physical (no risk of float16 overflow).
    return `
      float ${funcName}(int row, int col, int depth, int depth2, int depth3) {
        float texR = dot(
          vec4(row, col, depth, depth2),
          vec4(${shape[1] * shape[2] * shape[3]},
               ${shape[2] * shape[3]}, ${shape[3]}, 1));
        int texC = depth3;
        vec2 uv = (vec2(texC, texR) + halfCR) /
                  vec2(${texNumC}.0, ${texNumR}.0);
        return sampleTexture(${texName}, uv);
      }
    `;
  }

  const offset = getFlatOffsetUniformName(texName);
  return `
    float ${funcName}(int row, int col, int depth, int depth2, int depth3) {
      // Explicitly use integer operations as dot() only works on floats.
      int index = row * ${stride0} + col * ${stride1} + depth * ${stride2} +
          depth2 * ${stride3} + depth3 + ${offset};
      vec2 uv = uvFromFlat(${texNumR}, ${texNumC}, index);
      return sampleTexture(${texName}, uv);
    }
  `;
}
(inputInfo.shapeInfo.isUniform) { // Uniform arrays will be less than 65505 (no risk of float16 overflow). return ` float ${funcName}(int row, int col, int depth, int depth2, int depth3, int depth4) { int index = round(dot( vec4(row, col, depth, depth2), vec4(${stride0}, ${stride1}, ${stride2}, ${stride3})) + dot( vec2(depth3, depth4), vec2(${stride4}, 1))); ${getUniformSampler(inputInfo)} } `; } const flatOffset = inputInfo.shapeInfo.flatOffset; const texShape = inputInfo.shapeInfo.texShape; const texNumR = texShape[0]; const texNumC = texShape[1]; if (texNumC === stride0 && flatOffset == null) { // texC is used directly as physical (no risk of float16 overflow). return ` float ${funcName}(int row, int col, int depth, int depth2, int depth3, int depth4) { int texR = row; float texC = dot(vec4(col, depth, depth2, depth3), vec4(${stride1}, ${stride2}, ${stride3}, ${stride4})) + float(depth4); vec2 uv = (vec2(texC, texR) + halfCR) / vec2(${texNumC}.0, ${texNumR}.0); return sampleTexture(${texName}, uv); } `; } if (texNumC === stride4 && flatOffset == null) { // texR is used directly as physical (no risk of float16 overflow). return ` float ${funcName}(int row, int col, int depth, int depth2, int depth3, int depth4) { float texR = dot(vec4(row, col, depth, depth2), vec4(${shape[1] * shape[2] * shape[3] * shape[4]}, ${shape[2] * shape[3] * shape[4]}, ${shape[3] * shape[4]}, ${shape[4]})) + float(depth3); int texC = depth4; vec2 uv = (vec2(texC, texR) + halfCR) / vec2(${texNumC}.0, ${texNumR}.0); return sampleTexture(${texName}, uv); } `; } const offset = getFlatOffsetUniformName(texName); return ` float ${funcName}(int row, int col, int depth, int depth2, int depth3, int depth4) { // Explicitly use integer operations as dot() only works on floats. 
int index = row * ${stride0} + col * ${stride1} + depth * ${stride2} + depth2 * ${stride3} + depth3 * ${stride4} + depth4 + ${offset}; vec2 uv = uvFromFlat(${texNumR}, ${texNumC}, index); return sampleTexture(${texName}, uv); } `; } function getUniformSampler(inputInfo) { const texName = inputInfo.name; const inSize = tf.util.sizeFromShape(inputInfo.shapeInfo.logicalShape); if (inSize < 2) { return `return ${texName};`; } return ` for (int i = 0; i < ${inSize}; i++) { if (i == index) { return ${texName}[i]; } } `; } function getPackedSamplerAtOutputCoords(inputInfo, outShapeInfo) { const texName = inputInfo.name; const texFuncSnippet = texName.charAt(0).toUpperCase() + texName.slice(1); const funcName = 'get' + texFuncSnippet + 'AtOutCoords'; const inRank = inputInfo.shapeInfo.logicalShape.length; const outRank = outShapeInfo.logicalShape.length; const broadcastDims = getBroadcastDims(inputInfo.shapeInfo.logicalShape, outShapeInfo.logicalShape); const type = getCoordsDataType(outRank); const rankDiff = outRank - inRank; let coordsSnippet; const fields = ['x', 'y', 'z', 'w', 'u', 'v']; if (inRank === 0) { coordsSnippet = ''; } else if (outRank < 2 && broadcastDims.length >= 1) { coordsSnippet = 'coords = 0;'; } else { coordsSnippet = broadcastDims.map(d => `coords.${fields[d + rankDiff]} = 0;`) .join('\n'); } let unpackedCoordsSnippet = ''; if (outRank < 2 && inRank > 0) { unpackedCoordsSnippet = 'coords'; } else { unpackedCoordsSnippet = inputInfo.shapeInfo.logicalShape .map((s, i) => `coords.${fields[i + rankDiff]}`) .join(', '); } let output = `return outputValue;`; const inSize = tf.util.sizeFromShape(inputInfo.shapeInfo.logicalShape); const isInputScalar = inSize === 1; const outSize = tf.util.sizeFromShape(outShapeInfo.logicalShape); const isOutputScalar = outSize === 1; if (inRank === 1 && !isInputScalar && !isOutputScalar) { output = ` return vec4(outputValue.xy, outputValue.xy); `; } else if (isInputScalar && !isOutputScalar) { if (outRank === 1) { output = 
` return vec4(outputValue.x, outputValue.x, 0., 0.); `; } else { output = ` return vec4(outputValue.x); `; } } else if (broadcastDims.length) { const rows = inRank - 2; const cols = inRank - 1; if (broadcastDims.indexOf(rows) > -1 && broadcastDims.indexOf(cols) > -1) { output = `return vec4(outputValue.x);`; } else if (broadcastDims.indexOf(rows) > -1) { output = `return vec4(outputValue.x, outputValue.y, ` + `outputValue.x, outputValue.y);`; } else if (broadcastDims.indexOf(cols) > -1) { output = `return vec4(outputValue.xx, outputValue.zz);`; } } return ` vec4 ${funcName}() { ${type} coords = getOutputCoords(); ${coordsSnippet} vec4 outputValue = get${texFuncSnippet}(${unpackedCoordsSnippet}); ${output} } `; } function getSamplerAtOutputCoords(inputInfo, outShapeInfo) { const texName = inputInfo.name; const texFuncSnippet = texName.charAt(0).toUpperCase() + texName.slice(1); const funcName = 'get' + texFuncSnippet + 'AtOutCoords'; const outTexShape = outShapeInfo.texShape; const inTexShape = inputInfo.shapeInfo.texShape; const inRank = inputInfo.shapeInfo.logicalShape.length; const outRank = outShapeInfo.logicalShape.length; if (!inputInfo.shapeInfo.isUniform && inRank === outRank && inputInfo.shapeInfo.flatOffset == null && tf.util.arraysEqual(inTexShape, outTexShape)) { return ` float ${funcName}() { return sampleTexture(${texName}, resultUV); } `; } const type = getCoordsDataType(outRank); const broadcastDims = getBroadcastDims(inputInfo.shapeInfo.logicalShape, outShapeInfo.logicalShape); const rankDiff = outRank - inRank; let coordsSnippet; const fields = ['x', 'y', 'z', 'w', 'u', 'v']; if (inRank === 0) { coordsSnippet = ''; } else if (outRank < 2 && broadcastDims.length >= 1) { coordsSnippet = 'coords = 0;'; } else { coordsSnippet = broadcastDims.map(d => `coords.${fields[d + rankDiff]} = 0;`) .join('\n'); } let unpackedCoordsSnippet = ''; if (outRank < 2 && inRank > 0) { unpackedCoordsSnippet = 'coords'; } else { unpackedCoordsSnippet = 
inputInfo.shapeInfo.logicalShape .map((s, i) => `coords.${fields[i + rankDiff]}`) .join(', '); } return ` float ${funcName}() { ${type} coords = getOutputCoords(); ${coordsSnippet} return get${texFuncSnippet}(${unpackedCoordsSnippet}); } `; } function getCoordsDataType(rank) { if (rank <= 1) { return 'int'; } else if (rank === 2) { return 'ivec2'; } else if (rank === 3) { return 'ivec3'; } else if (rank === 4) { return 'ivec4'; } else if (rank === 5) { return 'ivec5'; } else if (rank === 6) { return 'ivec6'; } else { throw Error(`GPU for rank ${rank} is not yet supported`); } } function getUniformInfoFromShape(isPacked, shape, texShape) { const { newShape, keptDims } = tf.util.squeezeShape(shape); const rank = shape.length; const useSqueezePackedShape = isPacked && rank === 3 && shape[0] === 1; const squeezeShape = useSqueezePackedShape ? shape.slice(1) : newShape; const useSqueezeShape = (!isPacked && rank > 1 && !tf.util.arraysEqual(shape, texShape) && newShape.length < rank) || useSqueezePackedShape; const uniformShape = useSqueezeShape ? squeezeShape : shape; return { useSqueezeShape, uniformShape, keptDims }; } /** Returns a new input info (a copy) that has a squeezed logical shape. */ function squeezeInputInfo(inInfo, squeezedShape) { // Deep copy. const newInputInfo = JSON.parse(JSON.stringify(inInfo)); newInputInfo.shapeInfo.logicalShape = squeezedShape; return newInputInfo; } function getSqueezedParams(params, keptDims) { return keptDims.map(d => params[d]).join(', '); } /** * @license * Copyright 2017 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function compileProgram(gpgpu, program, inputs, output) { const inputInfos = inputs.map((input, i) => { const shapeInfo = { logicalShape: input.shape, texShape: input.isUniform ? null : input.texData.texShape, isUniform: input.isUniform, isPacked: input.isUniform ? false : input.texData.isPacked, flatOffset: null }; if (input.texData != null && input.texData.slice != null && input.texData.slice.flatOffset > 0) { shapeInfo.flatOffset = input.texData.slice.flatOffset; } return { name: program.variableNames[i], shapeInfo }; }); const inShapeInfos = inputInfos.map(x => x.shapeInfo); const outShapeInfo = { logicalShape: output.shape, texShape: output.texData.texShape, isUniform: false, isPacked: output.texData.isPacked, flatOffset: null }; const source = makeShader(inputInfos, outShapeInfo, program); const fragmentShader = createFragmentShader(gpgpu.gl, source); const webGLProgram = gpgpu.createProgram(fragmentShader); if (!tf.env().get('ENGINE_COMPILE_ONLY')) { gpgpu.buildVao(webGLProgram); return Object.assign({ program, fragmentShader, source, webGLProgram, inShapeInfos, outShapeInfo }, getUniformLocations(gpgpu, program, webGLProgram)); } else { return { program, fragmentShader, source, webGLProgram, inShapeInfos, outShapeInfo, variablesLocations: null, customUniformLocations: null, infLoc: null, nanLoc: null, outShapeLocation: null, outShapeStridesLocation: null, outTexShapeLocation: null }; } } function getUniformLocations(gpgpu, program, webGLProgram) { const variablesLocations = 
[]; const customUniformLocations = []; let outShapeLocation; let outTexShapeLocation; let outShapeStridesLocation; let infLoc = null; let nanLoc = null; // Add special uniforms (NAN, INFINITY) nanLoc = gpgpu.getUniformLocation(webGLProgram, 'NAN', false); if (tf.env().getNumber('WEBGL_VERSION') === 1) { infLoc = gpgpu.getUniformLocation(webGLProgram, 'INFINITY', false); } // Add user-defined uniforms const shouldThrow = false; for (const varName of program.variableNames) { const varLocs = { name: varName, uniform: gpgpu.getUniformLocation(webGLProgram, varName, shouldThrow), offset: gpgpu.getUniformLocation(webGLProgram, `offset${varName}`, shouldThrow), }; if (program.enableShapeUniforms) { varLocs.shape = gpgpu.getUniformLocation(webGLProgram, `${varName}Shape`, shouldThrow); varLocs.texShape = gpgpu.getUniformLocation(webGLProgram, `${varName}TexShape`, shouldThrow); } variablesLocations.push(varLocs); } if (program.enableShapeUniforms) { outShapeLocation = gpgpu.getUniformLocation(webGLProgram, 'outShape', shouldThrow); outShapeStridesLocation = gpgpu.getUniformLocation(webGLProgram, 'outShapeStrides', shouldThrow); outTexShapeLocation = gpgpu.getUniformLocation(webGLProgram, 'outTexShape', shouldThrow); } if (program.customUniforms) { for (const d of program.customUniforms) { customUniformLocations.push(gpgpu.getUniformLocation(webGLProgram, d.name, shouldThrow)); } } return { variablesLocations, customUniformLocations, infLoc, nanLoc, outShapeLocation, outShapeStridesLocation, outTexShapeLocation }; } function validateBinaryAndProgram(shapeInfos, inputs) { if (shapeInfos.length !== inputs.length) { throw Error(`Binary was compiled with ${shapeInfos.length} inputs, but ` + `was executed with ${inputs.length} inputs`); } shapeInfos.forEach((s, i) => { const shapeA = s.logicalShape; const input = inputs[i]; const shapeB = input.shape; if (!tf.util.arraysEqual(shapeA, shapeB)) { throw Error(`Binary was compiled with different shapes than ` + `the current args. 
Shapes ${shapeA} and ${shapeB} must match`); } // The input is uploaded as uniform. if (s.isUniform && input.isUniform) { return; } const texShapeA = s.texShape; const texShapeB = input.isUniform ? null : input.texData.texShape; if (!tf.util.arraysEqual(texShapeA, texShapeB)) { throw Error(`Binary was compiled with different texture shapes than the` + ` current args. Shape ${texShapeA} and ${texShapeB} must match`); } }); } function runProgram(gpgpu, binary, inputs, output, customUniformValues) { if (!binary.program.enableShapeUniforms) { validateBinaryAndProgram(binary.inShapeInfos, inputs); validateBinaryAndProgram([binary.outShapeInfo], [output]); } const outTex = output.texData.texture; const outTexShape = output.texData.texShape; if (output.texData.isPacked) { gpgpu.setOutputPackedMatrixTexture(outTex.texture, outTexShape[0], outTexShape[1]); } else { gpgpu.setOutputMatrixTexture(outTex.texture, outTexShape[0], outTexShape[1]); } gpgpu.setProgram(binary.webGLProgram); gpgpu.bindVertexArray(binary.webGLProgram.vao); // Set special uniforms (NAN, INFINITY) if (tf.env().getNumber('WEBGL_VERSION') === 1) { if (binary.infLoc !== null) { gpgpu.gl.uniform1f(binary.infLoc, Infinity); } } if (binary.nanLoc !== null) { gpgpu.gl.uniform1f(binary.nanLoc, NaN); } // Set user-defined inputs for (let i = 0; i < inputs.length; ++i) { const input = inputs[i]; const { uniform: varLoc, offset: varOffsetLoc, shape: varShapeLoc, texShape: varTexShapeLoc, } = binary.variablesLocations[i]; if (varShapeLoc) { const { uniformShape } = getUniformInfoFromShape(binary.program.packedInputs, input.shape, input.texData.texShape); switch (uniformShape.length) { case 1: gpgpu.gl.uniform1iv(varShapeLoc, new Int32Array(uniformShape)); break; case 2: gpgpu.gl.uniform2iv(varShapeLoc, new Int32Array(uniformShape)); break; case 3: gpgpu.gl.uniform3iv(varShapeLoc, new Int32Array(uniformShape)); break; case 4: gpgpu.gl.uniform4iv(varShapeLoc, new Int32Array(uniformShape)); break; } } if 
(varTexShapeLoc) { gpgpu.gl.uniform2i(varTexShapeLoc, input.texData.texShape[0], input.texData.texShape[1]); } if (varLoc == null) { // The compiler inferred that this variable is not used in this shader. continue; } if (input.isUniform) { // Upload the values of the tensor as uniform. if (tf.util.sizeFromShape(input.shape) < 2) { gpgpu.gl.uniform1f(varLoc, input.uniformValues[0]); } else { let vals = input.uniformValues; if (!(vals instanceof Float32Array)) { vals = new Float32Array(vals); } gpgpu.gl.uniform1fv(varLoc, vals); } continue; } // If the input was sliced, upload the flat offset index. if (input.texData.slice != null && varOffsetLoc != null) { gpgpu.gl.uniform1i(varOffsetLoc, input.texData.slice.flatOffset); } gpgpu.setInputMatrixTexture(input.texData.texture.texture, varLoc, i); } const outShapeLoc = binary.outShapeLocation; if (outShapeLoc) { switch (output.shape.length) { case 1: gpgpu.gl.uniform1iv(outShapeLoc, new Int32Array(output.shape)); break; case 2: gpgpu.gl.uniform2iv(outShapeLoc, new Int32Array(output.shape)); break; case 3: gpgpu.gl.uniform3iv(outShapeLoc, new Int32Array(output.shape)); break; case 4: gpgpu.gl.uniform4iv(outShapeLoc, new Int32Array(output.shape)); break; } } if (binary.outShapeStridesLocation) { const strides = tf.util.computeStrides(output.shape); switch (output.shape.length) { case 2: gpgpu.gl.uniform1iv(binary.outShapeStridesLocation, new Int32Array(strides)); break; case 3: gpgpu.gl.uniform2iv(binary.outShapeStridesLocation, new Int32Array(strides)); break; case 4: gpgpu.gl.uniform3iv(binary.outShapeStridesLocation, new Int32Array(strides)); break; } } if (binary.outTexShapeLocation) { gpgpu.gl.uniform2i(binary.outTexShapeLocation, output.texData.texShape[0], output.texData.texShape[1]); } if (binary.program.customUniforms && customUniformValues) { for (let i = 0; i < binary.program.customUniforms.length; ++i) { const d = binary.program.customUniforms[i]; const customLoc = binary.customUniformLocations[i]; const 
customValue = customUniformValues[i]; if (d.type === 'float') { gpgpu.gl.uniform1fv(customLoc, customValue); } else if (d.type === 'vec2') { gpgpu.gl.uniform2fv(customLoc, customValue); } else if (d.type === 'vec3') { gpgpu.gl.uniform3fv(customLoc, customValue); } else if (d.type === 'vec4') { gpgpu.gl.uniform4fv(customLoc, customValue); } else if (d.type === 'int') { gpgpu.gl.uniform1iv(customLoc, customValue); } else if (d.type === 'ivec2') { gpgpu.gl.uniform2iv(customLoc, customValue); } else if (d.type === 'ivec3') { gpgpu.gl.uniform3iv(customLoc, customValue); } else if (d.type === 'ivec4') { gpgpu.gl.uniform4iv(customLoc, customValue); } else { throw Error(`uniform type ${d.type} is not supported yet.`); } } } gpgpu.executeProgram(); } function makeShaderKey(program, inputs, output) { let keyInputs = ''; inputs.concat(output).forEach(x => { const hasOffset = x.texData != null && x.texData.slice != null && x.texData.slice.flatOffset > 0; // TODO: Remove the condition of !x.isUniform. 
if (program.enableShapeUniforms && !x.isUniform) { const xTexShape = x.texData.texShape; const { useSqueezeShape, uniformShape, keptDims } = getUniformInfoFromShape(program.packedInputs, x.shape, xTexShape); let rank1 = '', rank2 = '', rank34 = ''; if (uniformShape.length === 1 && program.packedInputs) { const packedTexShape = [Math.ceil(xTexShape[0] / 2), Math.ceil(xTexShape[1] / 2)]; rank1 = `${packedTexShape[0] > 1}_${packedTexShape[1] > 1}`; } else if (uniformShape.length === 2 && !program.packedInputs) { rank2 = `${uniformShape[0] > 1}_${uniformShape[1] > 1}`; } else if (uniformShape.length > 2 && !program.packedInputs) { const strides = tf.util.computeStrides(uniformShape); rank34 = `${strides[0] === xTexShape[1]}_${strides[strides.length - 1] === xTexShape[1]}`; } const xRank = x.shape.length; const isLogicalShapTexShapeEqual = uniformShape.length === 2 && tf.util.arraysEqual(x.shape, xTexShape); const isScalar = tf.util.sizeFromShape(x.shape) === 1; const broadcastDims = tf.backend_util.getBroadcastDims(x.shape, output.shape); const isInOutTexShapeEqual = !program.packedInputs && xRank === output.shape.length && tf.util.arraysEqual(xTexShape, output.texData.texShape); const isTexShapeGreaterThanOne = program.packedInputs || uniformShape.length > 2 ? '' : `${xTexShape[0] > 1}_${xTexShape[1] > 1}`; // These key components are needed due to shader_compiler is embedding // them in the shader. // |xRank| is used to determine the coords length. See // get[Packed]SamplerAtOutputCoords. // |isInOutTexShapeEqual| is used to determine whether going to an // optimization path in getSamplerAtOutputCoords. // |useSqueezeShape| is extracted from squeezeInputInfo of // getSampler[2|3|4]D/getPackedSampler3D. // |isScalar| is extracted from isInputScalar/isOutputScalar in // getPackedSamplerAtOutputCoords. // |broadcastDims| is extracted from get[Packed]SamplerAtOutputCoords. // |isLogicalShapTexShapeEqual| is used in // getOutput[Packed]2DCoords/get[Packed]Sampler2D. 
// |rank1| is used in getOutputPacked1DCoords. // |rank2| is used in getOutput2DCoords. // |rank34| is used in getSampler3D/getSampler4D. // |isTexShapeGreaterThanOne| are used in // getSampler[Scalar|1D|2D]/getOutput1DCoords. keyInputs += `${xRank}_${isInOutTexShapeEqual}_${useSqueezeShape ? keptDims : ''}_${uniformShape.length}_${isScalar}_${broadcastDims}_${isLogicalShapTexShapeEqual}_${rank1}_${rank2}_${rank34}_${isTexShapeGreaterThanOne}_${hasOffset}`; } else { const texShape = x.isUniform ? 'uniform' : x.texData.texShape; keyInputs += `${x.shape}_${texShape}_${hasOffset}`; } }); const keyUserCode = program.userCode; let key = program.constructor.name; // Fast string concat. See https://jsperf.com/string-concatenation/14. key += '_' + keyInputs + '_' + keyUserCode + `${tf.env().getNumber('WEBGL_VERSION')}`; return key; } function useShapeUniforms(rank) { // TODO: Remove the limitaion of rank <= 4. return tf.env().getBool('WEBGL_USE_SHAPES_UNIFORMS') && rank <= 4; } /** * @license * Copyright 2019 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ class DecodeMatrixProgram { constructor(outputShape) { this.variableNames = ['A']; this.packedInputs = false; this.packedOutput = true; this.outPackingScheme = PackingScheme.DENSE; this.customUniforms = [{ name: 'texShape', type: 'ivec2' }]; const glsl = getGlslDifferences(); this.outputShape = outputShape; this.enableShapeUniforms = useShapeUniforms(this.outputShape.length); this.userCode = ` ivec3 outCoordsFromFlatIndex(int index) { ${this.enableShapeUniforms ? getOutputLogicalCoordinatesFromFlatIndexByUniform(['r', 'c', 'd'], outputShape) : getLogicalCoordinatesFromFlatIndex(['r', 'c', 'd'], outputShape)} return ivec3(r, c, d); } void main() { ivec2 resTexRC = ivec2(resultUV.yx * vec2(texShape[0], texShape[1])); int index = 4 * (resTexRC.x * texShape[1] + resTexRC.y); vec4 result = vec4(0.); for (int i=0; i<4; i++) { int flatIndex = index + i; ivec3 rc = outCoordsFromFlatIndex(flatIndex); result[i] = getA(rc.x, rc.y, rc.z); } ${glsl.output} = result; } `; } } /** * @license * Copyright 2019 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ class DecodeMatrixPackedProgram { constructor(outputShape) { this.variableNames = ['A']; this.packedInputs = true; this.packedOutput = true; this.outPackingScheme = PackingScheme.DENSE; this.customUniforms = [{ name: 'texShape', type: 'ivec2' }]; const glsl = getGlslDifferences(); this.outputShape = outputShape; this.enableShapeUniforms = useShapeUniforms(this.outputShape.length); this.userCode = ` ivec3 outCoordsFromFlatIndex(int index) { ${this.enableShapeUniforms ? getOutputLogicalCoordinatesFromFlatIndexByUniform(['r', 'c', 'd'], outputShape) : getLogicalCoordinatesFromFlatIndex(['r', 'c', 'd'], outputShape)} return ivec3(r, c, d); } void main() { ivec2 resTexRC = ivec2(resultUV.yx * vec2(texShape[0], texShape[1])); int index = 4 * (resTexRC.x * texShape[1] + resTexRC.y); vec4 result = vec4(0.); for (int i=0; i<4; i++) { int flatIndex = index + i; ivec3 rc = outCoordsFromFlatIndex(flatIndex); result[i] = getChannel(getA(rc.x, rc.y, rc.z), vec2(rc.y, rc.z)); } ${glsl.output} = result; } `; } } /** * @license * Copyright 2018 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ class EncodeFloatProgram { constructor(outputShape) { this.variableNames = ['A']; this.outTexUsage = TextureUsage.DOWNLOAD; const glsl = getGlslDifferences(); this.outputShape = outputShape; this.userCode = ` ${ENCODE_FLOAT_SNIPPET} void main() { float x = getAAtOutCoords(); ${glsl.output} = encode_float(x); } `; } } /** * @license * Copyright 2018 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ class EncodeFloatPackedProgram { constructor(outputShape) { this.variableNames = ['A']; this.packedInputs = true; this.packedOutput = false; this.outTexUsage = TextureUsage.DOWNLOAD; const glsl = getGlslDifferences(); this.outputShape = outputShape; this.userCode = ` ${ENCODE_FLOAT_SNIPPET} void main() { ivec3 coords = getOutputCoords(); float x = getChannel(getAAtOutCoords(), vec2(coords.y, coords.z)); ${glsl.output} = encode_float(x); } `; } } /** * @license * Copyright 2018 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ const CHANNEL_CHAR_TO_INDEX_MAP = { 'R': 0, 'G': 1, 'B': 2, 'A': 3 }; class EncodeMatrixProgram { constructor(outputShape, inputIsUnsignedByte = false, usedChannels = 'RGBA') { this.variableNames = ['A']; this.customUniforms = [{ name: 'texShape', type: 'ivec2' }]; const glsl = getGlslDifferences(); this.outputShape = outputShape; this.enableShapeUniforms = useShapeUniforms(this.outputShape.length); let output = `result`; if (inputIsUnsignedByte) { output = `floor(result * 255. + 0.5)`; } let mainLoop = ''; for (let usedChannelIndex = 0; usedChannelIndex < usedChannels.length; usedChannelIndex++) { const curChannel = usedChannels[usedChannelIndex]; mainLoop += ` if(offset == ${usedChannelIndex}) { result = values[${CHANNEL_CHAR_TO_INDEX_MAP[curChannel]}]; }`; } this.userCode = ` ${this.enableShapeUniforms ? getFlatIndexFrom3DOutput() : getFlatIndexFrom3D(outputShape)} void main() { ivec3 coords = getOutputCoords(); int flatIndex = getFlatIndex(coords); float result = 0.; int offset = imod(flatIndex, ${usedChannels.length}); flatIndex = idiv(flatIndex, ${usedChannels.length}, 1.); int r = flatIndex / texShape[1]; if (r < texShape[0]) { int c = imod(flatIndex, texShape[1]); vec2 uv = (vec2(c, r) + halfCR) / vec2(texShape[1], texShape[0]); vec4 values = ${glsl.texture2D}(A, uv); ${mainLoop} } ${glsl.output} = vec4(${output}, 0., 0., 0.); } `; } } /** * @license * Copyright 2018 Google LLC. All Rights Reserved. 
* Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ /* This is how the shader encodes a tensor with shape = [2, 3, 5] (indices are [batch, row, col]). 000|001 002|003 004|xxx 020|021 022|023 024|xxx ------- ------- ------- ------- ------- ------- 010|011 012|013 014|xxx xxx|xxx xxx|xxx xxx|xxx 100|101 102|103 104|xxx 120|121 122|123 124|xxx ------- ------- ------- ------- ------- ------- 110|111 112|113 114|xxx xxx|xxx xxx|xxx xxx|xxx Single texels contain only values from the same batch, and from adjacent rows and columns. */ class EncodeMatrixPackedProgram { constructor(outputShape, inputIsUnsignedByte = false) { this.variableNames = ['A']; this.packedInputs = false; this.packedOutput = true; this.customUniforms = [{ name: 'texShape', type: 'ivec2' }]; const glsl = getGlslDifferences(); this.outputShape = outputShape; this.enableShapeUniforms = useShapeUniforms(this.outputShape.length); let mainLoop = ''; let output = 'result'; if (inputIsUnsignedByte) { output = 'floor(result * 255. + 0.5)'; } for (let row = 0; row <= 1; row++) { for (let col = 0; col <= 1; col++) { const channel = row * 2 + col; mainLoop += ` localCoords = coords; if(localCoords[2] + ${col} < ${this.enableShapeUniforms ? 'outShape[2]' : `${outputShape[2]}`}) { localCoords[2] += ${col}; if (localCoords[1] + ${row} < ${this.enableShapeUniforms ? 
'outShape[1]' : `${outputShape[1]}`}) { localCoords[1] += ${row}; flatIndex = getFlatIndex(localCoords); offset = imod(flatIndex, 4); flatIndex = idiv(flatIndex, 4, 1.); int r = flatIndex / texShape[1]; int c = imod(flatIndex, texShape[1]); vec2 uv = (vec2(c, r) + halfCR) / vec2(texShape[1], texShape[0]); values = ${glsl.texture2D}(A, uv); if (offset == 0) { result[${channel}] = values[0]; } else if (offset == 1) { result[${channel}] = values[1]; } else if (offset == 2) { result[${channel}] = values[2]; } else { result[${channel}] = values[3]; } } } `; } } this.userCode = ` ${this.enableShapeUniforms ? getFlatIndexFrom3DOutput() : getFlatIndexFrom3D(outputShape)} void main() { ivec3 coords = getOutputCoords(); vec4 result = vec4(0.); int flatIndex, r, c, offset; ivec3 localCoords; vec2 uv; vec4 values; ${mainLoop} ${glsl.output} = ${output}; } `; } } /** * @license * Copyright 2017 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ function createVertexShader(gl) { const glsl = getGlslDifferences(); const vertexShaderSource = `${glsl.version} precision highp float; ${glsl.attribute} vec3 clipSpacePos; ${glsl.attribute} vec2 uv; ${glsl.varyingVs} vec2 resultUV; void main() { gl_Position = vec4(clipSpacePos, 1); resultUV = uv; }`; return createVertexShader$1(gl, vertexShaderSource); } function createVertexBuffer(gl) { // [x y z u v] * [upper-left, lower-left, upper-right, lower-right] const vertexArray = new Float32Array([-1, 1, 0, 0, 1, -1, -1, 0, 0, 0, 1, 1, 0, 1, 1, 1, -1, 0, 1, 0]); return createStaticVertexBuffer(gl, vertexArray); } function createIndexBuffer(gl) { // OpenGL (and WebGL) have "CCW == front" winding const triangleVertexIndices = new Uint16Array([0, 1, 2, 2, 1, 3]); return createStaticIndexBuffer(gl, triangleVertexIndices); } function createAndConfigureTexture(gl, width, height, internalFormat, textureFormat, textureType) { validateTextureSize(width, height); const texture = createTexture(gl); const tex2d = gl.TEXTURE_2D; callAndCheck(gl, () => gl.bindTexture(tex2d, texture)); callAndCheck(gl, () => gl.texParameteri(tex2d, gl.TEXTURE_WRAP_S, gl.CLAMP_TO_EDGE)); callAndCheck(gl, () => gl.texParameteri(tex2d, gl.TEXTURE_WRAP_T, gl.CLAMP_TO_EDGE)); callAndCheck(gl, () => gl.texParameteri(tex2d, gl.TEXTURE_MIN_FILTER, gl.NEAREST)); callAndCheck(gl, () => gl.texParameteri(tex2d, gl.TEXTURE_MAG_FILTER, gl.NEAREST)); if (tf.env().getNumber('WEBGL_VERSION') === 1) { callAndCheck(gl, () => gl.texImage2D(tex2d, 0, internalFormat, width, height, 0, textureFormat, textureType, null)); } else { callAndCheck(gl, () => gl .texStorage2D(tex2d, 1, internalFormat, width, height)); } callAndCheck(gl, () => gl.bindTexture(gl.TEXTURE_2D, null)); return { texture, texShape: [height, width] }; } function getInternalFormatForFloat32MatrixTexture(textureConfig) { return textureConfig.internalFormatFloat; } function 
createFloat32MatrixTexture(gl, rows, columns, textureConfig) { const [width, height] = getUnpackedMatrixTextureShapeWidthHeight(rows, columns); return createAndConfigureTexture(gl, width, height, getInternalFormatForFloat32MatrixTexture(textureConfig), textureConfig.textureFormatFloat, gl.FLOAT); } function getInternalFormatForFloat16MatrixTexture(textureConfig) { return textureConfig.internalFormatHalfFloat; } function createFloat16MatrixTexture(gl, rows, columns, textureConfig) { const [width, height] = getUnpackedMatrixTextureShapeWidthHeight(rows, columns); return createAndConfigureTexture(gl, width, height, getInternalFormatForFloat16MatrixTexture(textureConfig), textureConfig.textureFormatFloat, textureConfig.textureTypeHalfFloat); } function getInternalFormatForUnsignedBytesMatrixTexture(textureConfig) { return textureConfig.downloadTextureFormat; } function createUnsignedBytesMatrixTexture(gl, rows, columns, textureConfig) { const [width, height] = getUnpackedMatrixTextureShapeWidthHeight(rows, columns); return createAndConfigureTexture(gl, width, height, getInternalFormatForUnsignedBytesMatrixTexture(textureConfig), gl.RGBA, gl.UNSIGNED_BYTE); } function getInternalFormatForPackedMatrixTexture(textureConfig) { return textureConfig.internalFormatPackedFloat; } function createPackedMatrixTexture(gl, rows, columns, textureConfig) { const [width, height] = getPackedMatrixTextureShapeWidthHeight(rows, columns); return createAndConfigureTexture(gl, width, height, getInternalFormatForPackedMatrixTexture(textureConfig), gl.RGBA, gl.FLOAT); } function getInternalFormatForFloat16PackedMatrixTexture(textureConfig) { return textureConfig.internalFormatPackedHalfFloat; } function createFloat16PackedMatrixTexture(gl, rows, columns, textureConfig) { const [width, height] = getPackedMatrixTextureShapeWidthHeight(rows, columns); return createAndConfigureTexture(gl, width, height, getInternalFormatForFloat16PackedMatrixTexture(textureConfig), gl.RGBA, 
/**
 * Binds the quad vertex buffer and wires its interleaved position / uv data
 * to the `clipSpacePos` and `uv` attributes of `program`.
 *
 * @param gl The WebGL rendering context.
 * @param program Program whose attributes are bound.
 * @param vertexBuffer Buffer holding `[x y z u v]` per vertex.
 * @returns The result of binding `clipSpacePos` when it is falsy, otherwise
 *     the result of binding `uv`.
 */
function bindVertexProgramAttributeStreams(gl, program, vertexBuffer) {
    const bytesPerFloat = 4;
    const positionComponents = 3; // x, y, z
    const uvComponents = 2; // u, v
    const positionByteOffset = 0; // position leads each vertex record
    const uvByteOffset = positionComponents * bytesPerFloat; // uv follows xyz
    const strideBytes = (positionComponents * bytesPerFloat) + (uvComponents * bytesPerFloat);
    callAndCheck(gl, () => gl.bindBuffer(gl.ARRAY_BUFFER, vertexBuffer));
    const positionBound = bindVertexBufferToProgramAttribute(gl, program, 'clipSpacePos', vertexBuffer, positionComponents, strideBytes, positionByteOffset);
    if (!positionBound) {
        return positionBound;
    }
    return bindVertexBufferToProgramAttribute(gl, program, 'uv', vertexBuffer, uvComponents, strideBytes, uvByteOffset);
}
/**
 * Copies `data` into a freshly allocated `width * height * 4` element array
 * and uploads that array to `texture` as RGBA texels.
 *
 * A `Uint8Array` input is uploaded as UNSIGNED_BYTE; anything else is uploaded
 * as FLOAT using `textureConfig.internalFormatPackedFloat` on the
 * `texImage2D` path.
 *
 * @param gl The WebGL rendering context.
 * @param texture Destination texture.
 * @param width Texture width in texels.
 * @param height Texture height in texels.
 * @param data Values to upload.
 * @param textureConfig Texture configuration object (only read for non-byte
 *     data).
 */
function uploadDenseMatrixToTexture(gl, texture, width, height, data, textureConfig) {
    callAndCheck(gl, () => gl.bindTexture(gl.TEXTURE_2D, texture));
    const isByteData = data instanceof Uint8Array;
    const channelCount = width * height * 4;
    const staging = isByteData ? new Uint8Array(channelCount) : new Float32Array(channelCount);
    const texelType = isByteData ? gl.UNSIGNED_BYTE : gl.FLOAT;
    const internalFormat = isByteData ? gl.RGBA : textureConfig.internalFormatPackedFloat;
    staging.set(data);
    const useSubImage = tf.env().getNumber('WEBGL_VERSION') === 2;
    if (useSubImage) {
        callAndCheck(gl, () => gl.texSubImage2D(gl.TEXTURE_2D, 0, 0, 0, width, height, gl.RGBA, texelType, staging));
    }
    else {
        callAndCheck(gl, () => gl.texImage2D(gl.TEXTURE_2D, 0, internalFormat, width, height, 0, gl.RGBA, texelType, staging));
    }
    callAndCheck(gl, () => gl.bindTexture(gl.TEXTURE_2D, null));
}
/**
 * Uploads pixel data to `texture` as RGBA / UNSIGNED_BYTE.
 *
 * When `pixels.data` is a `Uint8Array` the raw bytes are uploaded together
 * with `pixels.width` / `pixels.height`; otherwise `pixels` itself is handed
 * to WebGL as the image source.
 *
 * @param gl The WebGL rendering context.
 * @param texture Destination texture.
 * @param pixels Pixel source.
 */
function uploadPixelDataToTexture(gl, texture, pixels) {
    callAndCheck(gl, () => gl.bindTexture(gl.TEXTURE_2D, texture));
    const hasRawBytes = pixels.data instanceof Uint8Array;
    const useSubImage = tf.env().getNumber('WEBGL_VERSION') === 2;
    if (hasRawBytes && useSubImage) {
        callAndCheck(gl, () => gl.texSubImage2D(gl.TEXTURE_2D, 0, 0, 0, pixels.width, pixels.height, gl.RGBA, gl.UNSIGNED_BYTE, pixels.data));
    }
    else if (hasRawBytes) {
        callAndCheck(gl, () => gl.texImage2D(gl.TEXTURE_2D, 0, gl.RGBA, pixels.width, pixels.height, 0, gl.RGBA, gl.UNSIGNED_BYTE, pixels.data));
    }
    else if (useSubImage) {
        callAndCheck(gl, () => gl.texSubImage2D(gl.TEXTURE_2D, 0, 0, 0, gl.RGBA, gl.UNSIGNED_BYTE, pixels));
    }
    else {
        callAndCheck(gl, () => gl.texImage2D(gl.TEXTURE_2D, 0, gl.RGBA, gl.RGBA, gl.UNSIGNED_BYTE, pixels));
    }
    callAndCheck(gl, () => gl.bindTexture(gl.TEXTURE_2D, null));
}
/**
 * Creates a PIXEL_PACK_BUFFER sized for `rows * columns` RGBA float texels and
 * enqueues a `readPixels` of that region into it.
 *
 * @param gl2 The rendering context (uses PIXEL_PACK_BUFFER, so presumably a
 *     WebGL2 context - confirm against callers).
 * @param rows Number of texel rows to read.
 * @param columns Number of texel columns to read.
 * @param textureConfig Unused by this function.
 * @returns The buffer returned by `gl2.createBuffer()`.
 */
function createBufferFromOutputTexture(gl2, rows, columns, textureConfig) {
    // Create and bind the destination buffer.
    const packBuffer = gl2.createBuffer();
    callAndCheck(gl2, () => gl2.bindBuffer(gl2.PIXEL_PACK_BUFFER, packBuffer));
    // Size the buffer to hold every texel: 4 channels of 4-byte floats each.
    const floatBytes = 4;
    const channelsPerTexel = 4;
    const byteLength = floatBytes * channelsPerTexel * rows * columns;
    callAndCheck(gl2, () => gl2.bufferData(gl2.PIXEL_PACK_BUFFER, byteLength, gl2.STREAM_READ));
    // Enqueue the copy into the bound buffer, starting at byte offset 0.
    callAndCheck(gl2, () => gl2.readPixels(0, 0, columns, rows, gl2.RGBA, gl2.FLOAT, 0));
    callAndCheck(gl2, () => gl2.bindBuffer(gl2.PIXEL_PACK_BUFFER, null));
    return packBuffer;
}
/**
 * Reads `size` floats out of a PIXEL_PACK_BUFFER.
 *
 * @param gl The rendering context.
 * @param buffer Buffer to read from.
 * @param size Number of float32 values to read.
 * @returns A new `Float32Array` of length `size` filled from the buffer.
 */
function downloadFloat32MatrixFromBuffer(gl, buffer, size) {
    const floats = new Float32Array(size);
    const target = gl.PIXEL_PACK_BUFFER;
    gl.bindBuffer(target, buffer);
    gl.getBufferSubData(target, 0, floats);
    gl.bindBuffer(target, null);
    return floats;
}
/**
 * Reads back a byte-encoded float matrix via `readPixels` and reinterprets
 * the bytes as float32 values.
 *
 * @param gl The rendering context.
 * @param rows Logical matrix rows.
 * @param columns Logical matrix columns.
 * @param textureConfig Supplies `downloadTextureFormat` for `readPixels`.
 * @returns A `Float32Array` view over the downloaded bytes.
 */
function downloadByteEncodedFloatMatrixFromOutputTexture(gl, rows, columns, textureConfig) {
    const texSize = getUnpackedMatrixTextureShapeWidthHeight(rows, columns);
    const texWidth = texSize[0];
    const texHeight = texSize[1];
    const channels = 4;
    const byteCount = getUnpackedArraySizeFromMatrixSize(rows * columns, channels);
    const rawBytes = new Uint8Array(byteCount);
    callAndCheck(gl, () => gl.readPixels(0, 0, texWidth, texHeight, textureConfig.downloadTextureFormat, gl.UNSIGNED_BYTE, rawBytes));
    // Viewing the same bytes through a Float32Array lets the browser do the
    // IEEE 754 decoding of the 4 bytes behind each 32 bit float.
    return new Float32Array(rawBytes.buffer);
}
/**
 * Reads a packed matrix out of a PIXEL_PACK_BUFFER.
 *
 * The download size is taken from
 * `getPackedRGBAArraySizeFromMatrixShape(physicalRows, physicalCols)`; the
 * `batch`, `rows`, `cols` and `textureConfig` arguments are not used here.
 *
 * @param gl The rendering context.
 * @param buffer Buffer to read from.
 * @returns A new `Float32Array` filled from the buffer.
 */
function downloadPackedMatrixFromBuffer(gl, buffer, batch, rows, cols, physicalRows, physicalCols, textureConfig) {
    const floatCount = getPackedRGBAArraySizeFromMatrixShape(physicalRows, physicalCols);
    const packedValues = new Float32Array(floatCount);
    const target = gl.PIXEL_PACK_BUFFER;
    gl.bindBuffer(target, buffer);
    gl.getBufferSubData(target, 0, packedValues);
    gl.bindBuffer(target, null);
    return packedValues;
}
/**
 * Reads `physicalRows * physicalCols` RGBA float texels back via
 * `readPixels`.
 *
 * @param gl The rendering context.
 * @param physicalRows Number of texel rows to read.
 * @param physicalCols Number of texel columns to read.
 * @returns A `Float32Array` holding four floats per texel.
 */
function downloadMatrixFromPackedOutputTexture(gl, physicalRows, physicalCols) {
    const channelsPerTexel = 4;
    const texels = new Float32Array(physicalRows * physicalCols * channelsPerTexel);
    callAndCheck(gl, () => gl.readPixels(0, 0, physicalCols, physicalRows, gl.RGBA, gl.FLOAT, texels));
    return texels;
}
// Prototype-less namespace object exposing the GPGPU helper functions under
// their own names.
var gpgpu_util = {
    __proto__: null,
    bindVertexProgramAttributeStreams,
    createBufferFromOutputTexture,
    createFloat16MatrixTexture,
    createFloat16PackedMatrixTexture,
    createFloat32MatrixTexture,
    createIndexBuffer,
    createPackedMatrixTexture,
    createUnsignedBytesMatrixTexture,
    createVertexBuffer,
    createVertexShader,
    downloadByteEncodedFloatMatrixFromOutputTexture,
    downloadFloat32MatrixFromBuffer,
    downloadMatrixFromPackedOutputTexture,
    downloadPackedMatrixFromBuffer,
    getInternalFormatForFloat16MatrixTexture,
    getInternalFormatForFloat16PackedMatrixTexture,
    getInternalFormatForFloat32MatrixTexture,
    getInternalFormatForPackedMatrixTexture,
    getInternalFormatForUnsignedBytesMatrixTexture,
    uploadDenseMatrixToTexture,
    uploadPixelDataToTexture
};
/**
 * @license
 * Copyright 2017 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
/**
 * Owns the WebGL state used to run GPGPU programs: the rendering context,
 * the shared quad vertex / index buffers, one framebuffer, the currently
 * selected program and output texture, and a queue of fences being polled.
 */
class GPGPUContext {
    /**
     * @param gl Optional rendering context. When given it is registered via
     *     `setWebGLContext`; otherwise one is obtained from `getWebGLContext`
     *     for the version in the `WEBGL_VERSION` flag.
     * @throws Error when a required extension is unavailable.
     */
    constructor(gl) {
        this.outputTexture = null;
        this.program = null;
        this.disposed = false;
        this.itemsToPoll = [];
        const glVersion = tf.env().getNumber('WEBGL_VERSION');
        if (gl != null) {
            this.gl = gl;
            setWebGLContext(glVersion, gl);
        }
        else {
            this.gl = getWebGLContext(glVersion);
        }
        gl = this.gl;
        // Install version-specific vertex array object (VAO) helpers: native
        // calls for WebGL2, the OES_vertex_array_object extension otherwise.
        if (tf.env().getNumber('WEBGL_VERSION') === 2) {
            const gl2 = gl;
            this.createVertexArray = () => {
                return callAndCheck(gl2, () => gl2.createVertexArray());
            };
            this.bindVertexArray = (vao) => {
                return callAndCheck(gl2, () => gl2.bindVertexArray(vao));
            };
            this.deleteVertexArray = (vao) => {
                return callAndCheck(gl2, () => gl2.deleteVertexArray(vao));
            };
            this.getVertexArray = () => {
                return callAndCheck(gl2, () => gl2.getParameter(gl2.VERTEX_ARRAY_BINDING));
            };
        }
        else if (gl != null) {
            const ext = gl.getExtension('OES_vertex_array_object');
            if (ext == null) {
                throw new Error('All WebGL1 implementations are expected to offer' +
                    ' OES_vertex_array_object.');
            }
            this.createVertexArray = () => {
                return callAndCheck(gl, () => ext.createVertexArrayOES());
            };
            this.bindVertexArray = (vao) => {
                return callAndCheck(gl, () => ext.bindVertexArrayOES(vao));
            };
            this.deleteVertexArray = (vao) => {
                return callAndCheck(gl, () => ext.deleteVertexArrayOES(vao));
            };
            this.getVertexArray = () => {
                return callAndCheck(gl, () => gl.getParameter(ext.VERTEX_ARRAY_BINDING_OES));
            };
        }
        // WebGL 2.0 enables texture floats without an extension.
        let COLOR_BUFFER_FLOAT = 'WEBGL_color_buffer_float';
        const COLOR_BUFFER_HALF_FLOAT = 'EXT_color_buffer_half_float';
        this.parallelCompilationExtension =
            this.gl.getExtension('KHR_parallel_shader_compile');
        if (tf.env().getNumber('WEBGL_VERSION') === 1) {
            const TEXTURE_FLOAT = 'OES_texture_float';
            const TEXTURE_HALF_FLOAT = 'OES_texture_half_float';
            this.textureFloatExtension = getExtensionOrThrow(this.gl, TEXTURE_FLOAT);
            if (hasExtension(this.gl, TEXTURE_HALF_FLOAT)) {
                this.textureHalfFloatExtension = getExtensionOrThrow(this.gl, TEXTURE_HALF_FLOAT);
            }
            else if (tf.env().get('WEBGL_FORCE_F16_TEXTURES')) {
                throw new Error('GL context does not support half float textures, yet the ' +
                    'environment flag WEBGL_FORCE_F16_TEXTURES is set to true.');
            }
            this.colorBufferFloatExtension = this.gl.getExtension(COLOR_BUFFER_FLOAT);
            if (hasExtension(this.gl, COLOR_BUFFER_HALF_FLOAT)) {
                this.colorBufferHalfFloatExtension = getExtensionOrThrow(this.gl, COLOR_BUFFER_HALF_FLOAT);
            }
            else if (tf.env().get('WEBGL_FORCE_F16_TEXTURES')) {
                throw new Error('GL context does not support color renderable half floats, yet ' +
                    'the environment flag WEBGL_FORCE_F16_TEXTURES is set to true.');
            }
        }
        else {
            COLOR_BUFFER_FLOAT = 'EXT_color_buffer_float';
            if (hasExtension(this.gl, COLOR_BUFFER_FLOAT)) {
                this.colorBufferFloatExtension = this.gl.getExtension(COLOR_BUFFER_FLOAT);
            }
            else if (hasExtension(this.gl, COLOR_BUFFER_HALF_FLOAT)) {
                this.colorBufferHalfFloatExtension = this.gl.getExtension(COLOR_BUFFER_HALF_FLOAT);
            }
            else {
                throw new Error('GL context does not support color renderable floats');
            }
        }
        this.vertexBuffer = createVertexBuffer(this.gl);
        this.indexBuffer = createIndexBuffer(this.gl);
        this.framebuffer = createFramebuffer(this.gl);
        this.textureConfig = getTextureConfig(this.gl, this.textureHalfFloatExtension);
    }
    /** Whether the `DEBUG` environment flag is set. */
    get debug() {
        return tf.env().getBool('DEBUG');
    }
    /**
     * Releases the GL objects created by the constructor (framebuffer, vertex
     * buffer and index buffer) and marks the context as disposed. Calling it
     * again is a no-op. Warns when a program or output texture is still set.
     */
    dispose() {
        if (this.disposed) {
            return;
        }
        if (this.program != null) {
            console.warn('Disposing a GPGPUContext that still has a bound WebGLProgram.' +
                ' This is probably a resource leak, delete the program with ' +
                'GPGPUContext.deleteProgram before disposing.');
        }
        if (this.outputTexture != null) {
            console.warn('Disposing a GPGPUContext that still has a bound output matrix ' +
                'texture. This is probably a resource leak, delete the output ' +
                'matrix texture with GPGPUContext.deleteMatrixTexture before ' +
                'disposing.');
        }
        const gl = this.gl;
        callAndCheck(gl, () => gl.finish());
        callAndCheck(gl, () => gl.bindFramebuffer(gl.FRAMEBUFFER, null));
        callAndCheck(gl, () => gl.deleteFramebuffer(this.framebuffer));
        callAndCheck(gl, () => gl.bindBuffer(gl.ARRAY_BUFFER, null));
        // The vertex buffer is allocated in the constructor alongside the
        // index buffer, so it has to be released here as well; otherwise it
        // stays allocated for as long as the GL context lives.
        callAndCheck(gl, () => gl.deleteBuffer(this.vertexBuffer));
        callAndCheck(gl, () => gl.bindBuffer(gl.ELEMENT_ARRAY_BUFFER, null));
        callAndCheck(gl, () => gl.deleteBuffer(this.indexBuffer));
        this.disposed = true;
    }
    /** Creates an unpacked float32 matrix texture on this context. */
    createFloat32MatrixTexture(rows, columns) {
        this.throwIfDisposed();
        return createFloat32MatrixTexture(this.gl, rows, columns, this.textureConfig);
    }
    /** Creates an unpacked half-float matrix texture on this context. */
    createFloat16MatrixTexture(rows, columns) {
        this.throwIfDisposed();
        return createFloat16MatrixTexture(this.gl, rows, columns, this.textureConfig);
    }
    /** Creates an unsigned-byte matrix texture on this context. */
    createUnsignedBytesMatrixTexture(rows, columns) {
        this.throwIfDisposed();
        return createUnsignedBytesMatrixTexture(this.gl, rows, columns, this.textureConfig);
    }
    /** Uploads `pixels` into `texture`. */
    uploadPixelDataToTexture(texture, pixels) {
        this.throwIfDisposed();
        uploadPixelDataToTexture(this.gl, texture, pixels);
    }
    /** Uploads dense matrix `data` into a `width` x `height` `texture`. */
    uploadDenseMatrixToTexture(texture, width, height, data) {
        this.throwIfDisposed();
        uploadDenseMatrixToTexture(this.gl, texture, width, height, data, this.textureConfig);
    }
    /** Creates a packed half-float matrix texture on this context. */
    createFloat16PackedMatrixTexture(rows, columns) {
        this.throwIfDisposed();
        return createFloat16PackedMatrixTexture(this.gl, rows, columns, this.textureConfig);
    }
    /** Creates a packed float matrix texture on this context. */
    createPackedMatrixTexture(rows, columns) {
        this.throwIfDisposed();
        return createPackedMatrixTexture(this.gl, rows, columns, this.textureConfig);
    }
    /**
     * Deletes `texture`. If it is the current output texture it is first
     * detached from the framebuffer and forgotten.
     */
    deleteMatrixTexture(texture) {
        this.throwIfDisposed();
        if (this.outputTexture === texture) {
            unbindColorTextureFromFramebuffer(this.gl, this.framebuffer);
            this.outputTexture = null;
        }
        callAndCheck(this.gl, () => this.gl.deleteTexture(texture));
    }
    /** Downloads a byte-encoded float matrix from `texture`. */
    downloadByteEncodedFloatMatrixFromOutputTexture(texture, rows, columns) {
        return this.downloadMatrixDriver(texture, () => downloadByteEncodedFloatMatrixFromOutputTexture(this.gl, rows, columns, this.textureConfig));
    }
    /** Downloads a packed matrix from a pixel-pack `buffer`. */
    downloadPackedMatrixFromBuffer(buffer, batch, rows, columns, physicalRows, physicalCols) {
        return downloadPackedMatrixFromBuffer(this.gl, buffer, batch, rows, columns, physicalRows, physicalCols, this.textureConfig);
    }
    /** Downloads `size` float32 values from a pixel-pack `buffer`. */
    downloadFloat32MatrixFromBuffer(buffer, size) {
        return downloadFloat32MatrixFromBuffer(this.gl, buffer, size);
    }
    /**
     * Attaches `texture` to the framebuffer, copies it into a new pixel-pack
     * buffer, then restores the previous framebuffer attachment.
     */
    createBufferFromTexture(texture, rows, columns) {
        this.bindTextureToFrameBuffer(texture);
        const result = createBufferFromOutputTexture(this.gl, rows, columns, this.textureConfig);
        this.unbindTextureToFrameBuffer();
        return result;
    }
    /** Creates a fence and returns a promise that resolves once it passes. */
    createAndWaitForFence() {
        const fenceContext = this.createFence(this.gl);
        return this.pollFence(fenceContext);
    }
    /**
     * Creates a fence using the best mechanism enabled in the environment.
     *
     * @returns `{ query, isFencePassed }`; `query` is undefined when no
     *     fencing mechanism is available.
     */
    createFence(gl) {
        let query;
        let isFencePassed;
        if (tf.env().getBool('WEBGL_FENCE_API_ENABLED')) {
            const gl2 = gl;
            const sync = gl2.fenceSync(gl2.SYNC_GPU_COMMANDS_COMPLETE, 0);
            gl.flush();
            isFencePassed = () => {
                const status = gl2.clientWaitSync(sync, 0, 0);
                return status === gl2.ALREADY_SIGNALED ||
                    status === gl2.CONDITION_SATISFIED;
            };
            query = sync;
        }
        else if (tf.env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION') > 0) {
            // Use an (empty) timer query as the fence.
            query = this.beginQuery();
            this.endQuery();
            isFencePassed = () => this.isQueryAvailable(query, tf.env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION'));
        }
        else {
            // If we have no way to fence, return true immediately. This will fire in
            // WebGL 1.0 when there is no disjoint query timer. In this case, because
            // the fence passes immediately, we'll immediately ask for a download of
            // the texture, which will cause the UI thread to hang.
            isFencePassed = () => true;
        }
        return { query, isFencePassed };
    }
    /** Downloads the raw RGBA floats of a packed `texture`. */
    downloadMatrixFromPackedTexture(texture, physicalRows, physicalCols) {
        return this.downloadMatrixDriver(texture, () => downloadMatrixFromPackedOutputTexture(this.gl, physicalRows, physicalCols));
    }
    /**
     * Links `fragmentShader` with the shared vertex shader (compiled on first
     * use) into a new program and attaches a fresh VAO to it as `vao`.
     */
    createProgram(fragmentShader) {
        this.throwIfDisposed();
        const gl = this.gl;
        if (this.vertexShader == null) {
            this.vertexShader = createVertexShader(gl);
        }
        const program = createProgram(gl);
        callAndCheck(gl, () => gl.attachShader(program, this.vertexShader));
        callAndCheck(gl, () => gl.attachShader(program, fragmentShader));
        linkProgram(gl, program);
        const program2 = Object.assign(program, { vao: this.createVertexArray() });
        if (this.debug) {
            validateProgram(gl, program2);
        }
        return program2;
    }
    /** Selects `program`, binds its VAO and records the buffer bindings. */
    buildVao(program) {
        this.setProgram(program);
        this.bindVertexArray(program.vao);
        const gl = this.gl;
        // Bind index buffer, and vertex buffers based on program attrib
        // locations.
        callAndCheck(gl, () => gl.bindBuffer(gl.ELEMENT_ARRAY_BUFFER, this.indexBuffer));
        bindVertexProgramAttributeStreams(gl, program, this.vertexBuffer);
    }
    /** Deletes `program` and its VAO; clears it if it is the current one. */
    deleteProgram(program) {
        this.throwIfDisposed();
        if (program === this.program) {
            this.program = null;
        }
        if (program != null) {
            callAndCheck(this.gl, () => this.gl.deleteProgram(program));
            this.deleteVertexArray(program.vao);
        }
    }
    /** Makes `program` the current program (validating it in debug mode). */
    setProgram(program) {
        this.throwIfDisposed();
        this.program = program;
        if (this.program != null) {
            if (this.debug) {
                validateProgram(this.gl, this.program);
            }
        }
        callAndCheck(this.gl, () => this.gl.useProgram(program));
    }
    /**
     * Looks up a uniform location, using the throwing lookup helper unless
     * `shouldThrow` is false.
     */
    getUniformLocation(program, uniformName, shouldThrow = true) {
        this.throwIfDisposed();
        if (shouldThrow) {
            return getProgramUniformLocationOrThrow(this.gl, program, uniformName);
        }
        else {
            return getProgramUniformLocation(this.gl, program, uniformName);
        }
    }
    /** Looks up the location of `attribute` in `program`. */
    getAttributeLocation(program, attribute) {
        this.throwIfDisposed();
        return callAndCheck(this.gl, () => this.gl.getAttribLocation(program, attribute));
    }
    /** Looks up a uniform location directly on the GL context. */
    getUniformLocationNoThrow(program, uniformName) {
        this.throwIfDisposed();
        return this.gl.getUniformLocation(program, uniformName);
    }
    /** Binds an input texture to a sampler uniform of the current program. */
    setInputMatrixTexture(inputMatrixTexture, uniformLocation, textureUnit) {
        this.throwIfDisposed();
        this.throwIfNoProgram();
        bindTextureToProgramUniformSampler(this.gl, inputMatrixTexture, uniformLocation, textureUnit);
    }
    /** Sets an unpacked output texture; width is `columns`, height `rows`. */
    setOutputMatrixTexture(outputMatrixTexture, rows, columns) {
        this.setOutputMatrixTextureDriver(outputMatrixTexture, columns, rows);
    }
    /** Sets a packed output texture sized from the packed texture shape. */
    setOutputPackedMatrixTexture(outputPackedMatrixTexture, rows, columns) {
        this.throwIfDisposed();
        const [width, height] = getPackedMatrixTextureShapeWidthHeight(rows, columns);
        this.setOutputMatrixTextureDriver(outputPackedMatrixTexture, width, height);
    }
    /** Restricts writes to a row / column region of the output matrix. */
    setOutputMatrixWriteRegion(startRow, numRows, startColumn, numColumns) {
        this.setOutputMatrixWriteRegionDriver(startColumn, startRow, numColumns, numRows);
    }
    /** Not implemented; always throws. */
    setOutputPackedMatrixWriteRegion(startRow, numRows, startColumn, numColumns) {
        throw new Error('setOutputPackedMatrixWriteRegion not implemented.');
    }
    /** Validates the current program (if any) and the framebuffer. */
    debugValidate() {
        if (this.program != null) {
            validateProgram(this.gl, this.program);
        }
        validateFramebuffer(this.gl);
    }
    /** Draws the quad (two triangles) with the current program. */
    executeProgram() {
        this.throwIfDisposed();
        this.throwIfNoProgram();
        const gl = this.gl;
        if (this.debug) {
            const boundVao = this.getVertexArray();
            console.assert(boundVao === this.program.vao, 'VAO changed between setProgram and executeProgram!');
            this.debugValidate();
        }
        callAndCheck(gl, () => gl.drawElements(gl.TRIANGLES, 6, gl.UNSIGNED_SHORT, 0));
    }
    /** Blocks until the GL command queue has finished (`gl.finish`). */
    blockUntilAllProgramsCompleted() {
        this.throwIfDisposed();
        callAndCheck(this.gl, () => this.gl.finish());
    }
    /** Returns the disjoint timer query extension, fetching it on first use. */
    getQueryTimerExtension() {
        if (this.disjointQueryTimerExtension == null) {
            this.disjointQueryTimerExtension =
                getExtensionOrThrow(this.gl, tf.env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION') === 2 ?
                    'EXT_disjoint_timer_query_webgl2' :
                    'EXT_disjoint_timer_query');
        }
        return this.disjointQueryTimerExtension;
    }
    /** Alias of `getQueryTimerExtension`. */
    getQueryTimerExtensionWebGL2() {
        return this.getQueryTimerExtension();
    }
    /** Alias of `getQueryTimerExtension`. */
    getQueryTimerExtensionWebGL1() {
        return this.getQueryTimerExtension();
    }
    /** Starts a TIME_ELAPSED query and returns the query object. */
    beginQuery() {
        if (tf.env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION') === 2) {
            const gl2 = this.gl;
            const ext = this.getQueryTimerExtensionWebGL2();
            const query = gl2.createQuery();
            gl2.beginQuery(ext.TIME_ELAPSED_EXT, query);
            return query;
        }
        const ext = this.getQueryTimerExtensionWebGL1();
        const query = ext.createQueryEXT();
        ext.beginQueryEXT(ext.TIME_ELAPSED_EXT, query);
        return query;
    }
    /** Ends the TIME_ELAPSED query started by `beginQuery`. */
    endQuery() {
        if (tf.env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION') === 2) {
            const gl2 = this.gl;
            const ext = this.getQueryTimerExtensionWebGL2();
            gl2.endQuery(ext.TIME_ELAPSED_EXT);
            return;
        }
        const ext = this.getQueryTimerExtensionWebGL1();
        ext.endQueryEXT(ext.TIME_ELAPSED_EXT);
    }
    /** Waits until `query` is available and returns its time. */
    async waitForQueryAndGetTime(query) {
        await tf.util.repeatedTry(() => this.disposed || // while testing contexts are created / disposed
            // in rapid succession, so without this check we
            // may poll for the query timer indefinitely
            this.isQueryAvailable(query, tf.env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION')));
        return this.getQueryTime(query, tf.env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION'));
    }
    /**
     * Reads the result of a timer query.
     *
     * @returns The elapsed time in milliseconds, or null when
     *     `queryTimerVersion` is 0.
     */
    getQueryTime(query, queryTimerVersion) {
        if (queryTimerVersion === 0) {
            return null;
        }
        if (queryTimerVersion === 2) {
            const gl2 = this.gl;
            const timeElapsedNanos = gl2.getQueryParameter(query, gl2.QUERY_RESULT);
            // Return milliseconds.
            return timeElapsedNanos / 1000000;
        }
        else {
            const ext = this.getQueryTimerExtensionWebGL1();
            const timeElapsedNanos = ext.getQueryObjectEXT(query, ext.QUERY_RESULT_EXT);
            // Return milliseconds.
            return timeElapsedNanos / 1000000;
        }
    }
    /**
     * Reports whether the result of `query` can be read. Always true when
     * `queryTimerVersion` is 0. The GPU_DISJOINT_EXT flag is read once and
     * cached in `this.disjoint`.
     */
    isQueryAvailable(query, queryTimerVersion) {
        if (queryTimerVersion === 0) {
            return true;
        }
        if (queryTimerVersion === 2) {
            const gl2 = this.gl;
            const ext = this.getQueryTimerExtensionWebGL2();
            const available = gl2.getQueryParameter(query, gl2.QUERY_RESULT_AVAILABLE);
            if (this.disjoint == null) {
                this.disjoint = this.gl.getParameter(ext.GPU_DISJOINT_EXT);
            }
            return available && !this.disjoint;
        }
        else {
            const ext = this.getQueryTimerExtensionWebGL1();
            const available = ext.getQueryObjectEXT(query, ext.QUERY_RESULT_AVAILABLE_EXT);
            if (this.disjoint == null) {
                this.disjoint = this.gl.getParameter(ext.GPU_DISJOINT_EXT);
            }
            return available && !this.disjoint;
        }
    }
    /** Returns a promise that resolves once `fenceContext` has passed. */
    pollFence(fenceContext) {
        return new Promise(resolve => {
            this.addItemToPoll(() => fenceContext.isFencePassed(), () => resolve());
        });
    }
    /** Resolves every queued item up to the last one that has finished. */
    pollItems() {
        // Find the last query that has finished.
        const index = linearSearchLastTrue(this.itemsToPoll.map(x => x.isDoneFn));
        for (let i = 0; i <= index; ++i) {
            const { resolveFn } = this.itemsToPoll[i];
            resolveFn();
        }
        this.itemsToPoll = this.itemsToPoll.slice(index + 1);
    }
    /**
     * Queues an item for polling and starts the polling loop if this is the
     * only queued item.
     */
    addItemToPoll(isDoneFn, resolveFn) {
        this.itemsToPoll.push({ isDoneFn, resolveFn });
        if (this.itemsToPoll.length > 1) {
            // We already have a running loop that polls.
            return;
        }
        // Start a new loop that polls.
        let scheduleFn = undefined;
        if ('setTimeoutCustom' in tf.env().platform) {
            scheduleFn = tf.env().platform.setTimeoutCustom.bind(tf.env().platform);
        }
        tf.util.repeatedTry(() => {
            this.pollItems();
            // End the loop if no more items to poll.
            return this.itemsToPoll.length === 0;
        }, () => 0, null, scheduleFn);
    }
    /** Attaches `texture` as the framebuffer's color texture. */
    bindTextureToFrameBuffer(texture) {
        this.throwIfDisposed();
        bindColorTextureToFramebuffer(this.gl, texture, this.framebuffer);
        if (this.debug) {
            validateFramebuffer(this.gl);
        }
    }
    /**
     * Restores the framebuffer attachment: re-attaches the current output
     * texture if there is one, otherwise detaches the color texture.
     */
    unbindTextureToFrameBuffer() {
        if (this.outputTexture != null) {
            bindColorTextureToFramebuffer(this.gl, this.outputTexture, this.framebuffer);
            if (this.debug) {
                validateFramebuffer(this.gl);
            }
        }
        else {
            unbindColorTextureFromFramebuffer(this.gl, this.framebuffer);
        }
    }
    /**
     * Runs `downloadAndDecode` while `texture` is attached to the
     * framebuffer and returns its result.
     */
    downloadMatrixDriver(texture, downloadAndDecode) {
        this.bindTextureToFrameBuffer(texture);
        const result = downloadAndDecode();
        this.unbindTextureToFrameBuffer();
        return result;
    }
    /**
     * Attaches the output texture to the framebuffer, remembers it, and sets
     * viewport and scissor to the full `width` x `height` area.
     */
    setOutputMatrixTextureDriver(outputMatrixTextureMaybePacked, width, height) {
        this.throwIfDisposed();
        const gl = this.gl;
        bindColorTextureToFramebuffer(gl, outputMatrixTextureMaybePacked, this.framebuffer);
        if (this.debug) {
            validateFramebuffer(gl);
        }
        this.outputTexture = outputMatrixTextureMaybePacked;
        callAndCheck(gl, () => gl.viewport(0, 0, width, height));
        callAndCheck(gl, () => gl.scissor(0, 0, width, height));
    }
    /** Sets the scissor rectangle that limits where output is written. */
    setOutputMatrixWriteRegionDriver(x, y, width, height) {
        this.throwIfDisposed();
        callAndCheck(this.gl, () => this.gl.scissor(x, y, width, height));
    }
    /** @throws Error when this context has already been disposed. */
    throwIfDisposed() {
        if (this.disposed) {
            throw new Error('Attempted to use disposed GPGPUContext.');
        }
    }
    /** @throws Error when no program is currently set. */
    throwIfNoProgram() {
        if (this.program == null) {
            throw new Error('No GPU program is currently set.');
        }
    }
}
/**
 * Returns the index of the last leading entry whose callback reports true,
 * scanning from the front and stopping at the first callback that does not.
 * Returns -1 when the first callback already fails or the list is empty.
 *
 * Note: a binary search is not an option here because Chrome expects every
 * fence to be tested explicitly before download:
 * https://github.com/tensorflow/tfjs/issues/1145
 */
function linearSearchLastTrue(arr) {
    // findIndex invokes the callbacks in order and stops at the first one
    // that is not done yet.
    const firstPending = arr.findIndex((isDoneFn) => !isDoneFn());
    const doneCount = firstPending === -1 ? arr.length : firstPending;
    return doneCount - 1;
}
/**
 * @license
 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the License);
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an AS IS BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
/**
 * Element-wise absolute value.
 *
 * @param vals Indexable list of numbers.
 * @returns A new `Float32Array` with `Math.abs` applied to every entry.
 */
function simpleAbsImpl(vals) {
    return Float32Array.from(vals, (value) => Math.abs(value));
}
/**
 * @license
 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
/**
 * Builds the CPU implementation of an element-wise binary op with
 * broadcasting support.
 *
 * @param op Function combining one value of `a` with one value of `b`.
 * @returns A function `(aShape, bShape, aVals, bVals, dtype)` that returns
 *     `[resultValues, resultShape]`.
 */
function createSimpleBinaryKernelImpl(op) {
    // Maps a location in the result to the flat index of the matching
    // operand value: keep the operand's trailing dims and pin every
    // broadcast dim to 0.
    const toOperandIndex = (resultLoc, rank, strides, broadcastDims) => {
        const operandLoc = resultLoc.slice(-rank);
        for (const dim of broadcastDims) {
            operandLoc[dim] = 0;
        }
        return tf.util.locToIndex(operandLoc, rank, strides);
    };
    return (aShape, bShape, aVals, bVals, dtype) => {
        const outShape = tf.backend_util.assertAndGetBroadcastShape(aShape, bShape);
        const outRank = outShape.length;
        const outStrides = tf.util.computeStrides(outShape);
        const outSize = tf.util.sizeFromShape(outShape);
        const outVals = tf.util.getTypedArrayFromDType(dtype, outSize);
        const aRank = aShape.length;
        const bRank = bShape.length;
        const aStrides = tf.util.computeStrides(aShape);
        const bStrides = tf.util.computeStrides(bShape);
        const aBroadcastDims = tf.backend_util.getBroadcastDims(aShape, outShape);
        const bBroadcastDims = tf.backend_util.getBroadcastDims(bShape, outShape);
        const needsBroadcastIndexing = aBroadcastDims.length + bBroadcastDims.length !== 0;
        if (!needsBroadcastIndexing) {
            // No broadcast dims: walk both inputs in step, wrapping around
            // the shorter one.
            for (let flat = 0; flat < outVals.length; ++flat) {
                outVals[flat] = op(aVals[flat % aVals.length], bVals[flat % bVals.length]);
            }
            return [outVals, outShape];
        }
        for (let flat = 0; flat < outVals.length; ++flat) {
            const outLoc = tf.util.indexToLoc(flat, outRank, outStrides);
            const aIndex = toOperandIndex(outLoc, aRank, aStrides, aBroadcastDims);
            const bIndex = toOperandIndex(outLoc, bRank, bStrides, bBroadcastDims);
            outVals[flat] = op(aVals[aIndex], bVals[bIndex]);
        }
        return [outVals, outShape];
    };
}
/**
 * @license
 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
/**
 * CPU implementation of cast for the `int32` and `bool` target dtypes.
 *
 * @param values Input values.
 * @param shape Shape of the input.
 * @param inputType Dtype of the input values.
 * @param dtype Target dtype.
 * @returns `[resultShape, resultDtype, resultValues]`.
 * @throws Error for any target dtype other than `int32` or `bool`.
 */
function castImpl(values, shape, inputType, dtype) {
    switch (dtype) {
        case 'int32':
            return [shape, 'int32', Int32Array.from(values)];
        case 'bool': {
            // Equivalent to notEqual(x, 0). The comparison is built inline
            // instead of calling the notEqual kernel, which would create the
            // circular dependency binary_utils -> cast -> notEqual ->
            // binary_utils.
            const zero = tf.util.toTypedArray([0], inputType);
            const differsFrom = createSimpleBinaryKernelImpl((value, other) => Number(value !== other));
            const [resultData, resultShape] = differsFrom(shape, [], values, zero, 'bool');
            return [resultShape, 'bool', resultData];
        }
        default:
            throw new Error(`Error in Cast: failed to cast ${inputType} to ${dtype}`);
    }
}
/**
 * @license
 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
/** Broadcasting element-wise addition. */
const addImpl = createSimpleBinaryKernelImpl((lhs, rhs) => lhs + rhs);
/**
 * @license
 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
/**
 * Counts (or, with weights, sums the weights of) the occurrences of each
 * value of `xVals` in the range `[0, size)`.
 *
 * @param xVals Bin indices; values `>= size` are ignored.
 * @param weightsVals Per-element weights, used when `weightsShape` is
 *     non-empty.
 * @param weightsDtype Dtype of the output array.
 * @param weightsShape Shape of the weights; size 0 means "no weights".
 * @param size Number of bins.
 * @returns A zero-initialised typed array of length `size` with the counts.
 * @throws Error when a value of `xVals` is negative.
 */
function bincountImpl(xVals, weightsVals, weightsDtype, weightsShape, size) {
    const hasWeights = tf.util.sizeFromShape(weightsShape) > 0;
    const bins = tf.util.makeZerosTypedArray(size, weightsDtype);
    for (let idx = 0; idx < xVals.length; idx++) {
        const bin = xVals[idx];
        if (bin < 0) {
            throw new Error('Input x must be non-negative!');
        }
        if (bin >= size) {
            continue;
        }
        bins[bin] += hasWeights ? weightsVals[idx] : 1;
    }
    return bins;
}
/**
 * Row-wise bincount over a 2D buffer.
 *
 * @param xBuf 2D buffer of bin indices (read via `shape` and `get`).
 * @param weightsBuf Weights buffer; its `dtype` is used for the output and a
 *     `size` of 0 means "no weights".
 * @param size Number of bins per row.
 * @param binaryOutput When true each hit bin is set to 1 instead of counted.
 * @returns A `[numRows, size]` buffer created with `tf.buffer`.
 * @throws Error when a value of `xBuf` is negative.
 */
function bincountReduceImpl(xBuf, weightsBuf, size, binaryOutput = false) {
    const [numRows, numCols] = xBuf.shape;
    const outBuf = tf.buffer([numRows, size], weightsBuf.dtype);
    for (let row = 0; row < numRows; row++) {
        for (let col = 0; col < numCols; col++) {
            const bin = xBuf.get(row, col);
            if (bin < 0) {
                throw new Error('Input x must be non-negative!');
            }
            if (bin >= size) {
                continue;
            }
            if (binaryOutput) {
                outBuf.set(1, row, bin);
                continue;
            }
            const current = outBuf.get(row, bin);
            const increment = weightsBuf.size > 0 ? weightsBuf.get(row, col) : 1;
            outBuf.set(current + increment, row, bin);
        }
    }
    return outBuf;
}
/**
 * @license
 * Copyright 2023 Google LLC.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
/** Broadcasting element-wise bitwise AND. */
const bitwiseAndImpl = createSimpleBinaryKernelImpl((lhs, rhs) => lhs & rhs);
/**
 * @license
 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
/**
 * Builds the CPU implementation of an element-wise unary op.
 *
 * @param op Function called as `op(value, attrs)` for every input value.
 * @returns A function `(values, dtype, attrs)` returning a new array
 *     (allocated with `tf.util.getArrayFromDType`) of the mapped values.
 */
function createSimpleUnaryImpl(op) {
    return (values, dtype, attrs) => {
        const count = values.length;
        const mapped = tf.util.getArrayFromDType(dtype, count);
        let idx = 0;
        while (idx < count) {
            mapped[idx] = op(values[idx], attrs);
            ++idx;
        }
        return mapped;
    };
}
/**
 * @license
 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the License);
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an AS IS BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
/** Element-wise ceiling. */
const ceilImpl = createSimpleUnaryImpl((value) => Math.ceil(value));
* Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function concatImpl$1(inputs, outShape, dtype, simplyConcat) { const outVals = tf.util.getArrayFromDType(dtype, tf.util.sizeFromShape(outShape)); if (simplyConcat && dtype !== 'string') { // Use built-in TypedArray.set() method for speed. let offset = 0; inputs.forEach(input => { const size = tf.util.sizeFromShape(input.shape); outVals.set(input.vals, offset); offset += size; }); } else { let colOffset = 0; inputs.forEach(input => { const decodedData = dtype === 'string' ? tf.backend_util.fromUint8ToStringArray(input.vals) : input.vals; let tIdx = 0; for (let row = 0; row < input.shape[0]; ++row) { const resIdx = row * outShape[1] + colOffset; for (let col = 0; col < input.shape[1]; ++col) { outVals[resIdx + col] = decodedData[tIdx++]; } } colOffset += input.shape[1]; }); } return outVals; } /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
* See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ const equalImpl = createSimpleBinaryKernelImpl((a, b) => (a === b) ? 1 : 0); /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the License); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an AS IS BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ const expImpl = createSimpleUnaryImpl((xi) => Math.exp(xi)); /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the License); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an AS IS BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ const expm1Impl = createSimpleUnaryImpl((xi) => Math.expm1(xi)); /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the License); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an AS IS BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ const floorImpl = createSimpleUnaryImpl((xi) => Math.floor(xi)); /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ const floorDivImpl = createSimpleBinaryKernelImpl((a, b) => Math.floor(a / b)); /** * @license * Copyright 2021 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ function gatherNdImpl(indicesData, paramsBuf, dtype, numSlices, sliceRank, sliceSize, strides, paramsShape, paramsSize) { const outBuf = tf.buffer([numSlices, sliceSize], dtype); for (let i = 0; i < numSlices; i++) { const index = []; let flattenIndex = 0; for (let j = 0; j < sliceRank; j++) { const dim = indicesData[i * sliceRank + j]; flattenIndex += dim * strides[j]; index.push(dim); } if (flattenIndex < 0 || flattenIndex >= paramsSize / sliceSize) { throw new Error(`Invalid indices: ${index} does not index into ${paramsShape}`); } for (let k = 0; k < sliceSize; k++) { outBuf.values[i * sliceSize + k] = paramsBuf.get(...paramsBuf.indexToLoc(flattenIndex * sliceSize + k)); } } return outBuf; } /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ function gatherV2Impl(xBuf, indicesBuf, flattenOutputShape) { const outBuf = tf.buffer(flattenOutputShape, xBuf.dtype); for (let i = 0; i < outBuf.size; ++i) { const newLoc = outBuf.indexToLoc(i); const originalLoc = newLoc.slice(); const batchIdx = originalLoc[0]; const indicesIdx = originalLoc[2]; const indicesIndex = indicesBuf.locToIndex([batchIdx, indicesIdx]); originalLoc[2] = indicesBuf.values[indicesIndex]; const originalIndex = xBuf.locToIndex(originalLoc); if (0 <= originalIndex && originalIndex < xBuf.values.length) { outBuf.values[i] = xBuf.values[originalIndex]; } // Else, index is out of bounds, so leave the default zero val in outBuf. } return outBuf; } /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ const greaterImpl = createSimpleBinaryKernelImpl((a, b) => (a > b) ? 1 : 0); /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ const greaterEqualImpl = createSimpleBinaryKernelImpl((a, b) => (a >= b) ? 1 : 0); /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ const lessImpl = createSimpleBinaryKernelImpl((a, b) => (a < b) ? 1 : 0); /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ const lessEqualImpl = createSimpleBinaryKernelImpl((a, b) => (a <= b) ? 1 : 0); /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function linSpaceImpl(start, stop, num) { const step = (stop - start) / (num - 1); const values = tf.util.makeZerosTypedArray(num, 'float32'); values[0] = start; for (let i = 1; i < values.length; i++) { values[i] = values[i - 1] + step; } return values; } /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the License); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an AS IS BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ const logImpl = createSimpleUnaryImpl((xi) => Math.log(xi)); /** * @license * Copyright 2020 Google LLC. All Rights Reserved. 
* Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function maxImpl$1(aVals, reduceSize, outShape, dtype) { const vals = tf.util.getTypedArrayFromDType(dtype, tf.util.sizeFromShape(outShape)); for (let i = 0; i < vals.length; ++i) { const offset = i * reduceSize; let max = aVals[offset]; for (let j = 0; j < reduceSize; ++j) { const value = aVals[offset + j]; if (Number.isNaN(value) || value > max) { // comparison with NaN always return false max = value; } } vals[i] = max; } return vals; } /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ const maximumImpl = createSimpleBinaryKernelImpl(((aValue, bValue) => Math.max(aValue, bValue))); /** * @license * Copyright 2020 Google LLC. All Rights Reserved. 
* Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ const minimumImpl = createSimpleBinaryKernelImpl(((aValue, bValue) => Math.min(aValue, bValue))); /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ const multiplyImpl = createSimpleBinaryKernelImpl(((aValue, bValue) => aValue * bValue)); /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function negImpl(xVals, xShape, xDtype) { const minusOne = tf.util.createScalarValue(-1, xDtype); return multiplyImpl([], xShape, minusOne, xVals, xDtype); } /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ const notEqualImpl = createSimpleBinaryKernelImpl(((a, b) => (a !== b) ? 1 : 0)); /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
* See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function transposeImpl$1(xVals, xShape, dtype, perm, newShape) { const xRank = xShape.length; const xSize = tf.util.sizeFromShape(xShape); const xStrides = tf.util.computeStrides(xShape); const newStrides = tf.util.computeStrides(newShape); const result = tf.util.getTypedArrayFromDType(dtype, tf.util.sizeFromShape(newShape)); for (let i = 0; i < xSize; ++i) { const loc = tf.util.indexToLoc(i, xRank, xStrides); // Permute location. const newLoc = new Array(loc.length); for (let i = 0; i < newLoc.length; i++) { newLoc[i] = loc[perm[i]]; } const newIndex = tf.util.locToIndex(newLoc, xRank, newStrides); result[newIndex] = xVals[i]; } return result; } /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ function prodImpl(xShape, xDtype, xVals, reductionAxes) { const [outShape, reduceShape] = tf.backend_util.computeOutAndReduceShapes(xShape, reductionAxes); const outDtype = tf.upcastType(xDtype, 'int32'); const outVals = tf.util.makeZerosTypedArray(tf.util.sizeFromShape(outShape), outDtype); const reduceSize = tf.util.sizeFromShape(reduceShape); for (let i = 0; i < outVals.length; ++i) { const offset = i * reduceSize; let prod = 1; for (let j = 0; j < reduceSize; ++j) { prod *= xVals[offset + j]; } outVals[i] = prod; } return { outVals, outShape, outDtype }; } /** * @license * Copyright 2022 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function validateIndices(indices, indicesShape, numParams) { indices.forEach((index, i) => { if (index < 0 || index >= numParams) { const locString = tf.util.indexToLoc(i, indicesShape.length, tf.util.computeStrides(indicesShape)) .join(','); throw new Error(`indices[${locString}] = ${index} is not in [0, ${numParams})`); } }); } function validateSplits(paramsNestedSplits, numParamsDenseValues) { // Validate for (let dim = 0; dim < paramsNestedSplits.length; ++dim) { const splits = paramsNestedSplits[dim]; const lastSplit = (dim === paramsNestedSplits.length - 1) ? 
numParamsDenseValues : paramsNestedSplits[dim + 1].length; if (splits.length === 0) { throw new Error('Ragged splits may not be empty'); } if (splits[0] < 0) { throw new Error('Ragged splits must be non-negative'); } if (splits[splits.length - 1] > lastSplit) { throw new Error('Ragged splits must not point past values'); } for (let i = 1; i < splits.length; ++i) { if (splits[i - 1] > splits[i]) { throw new Error('Ragged splits must be sorted in ascending order'); } } } } // Construct the `splits` output tensors, encoded using a nested vector. // Also find the slices of values that need to be copied, and store them // in `valueSlices`. The total number of values that will be copied (which // we need for allocating the output values tensor) is stored in `numValues`. function makeSplits(indices, indicesShape, paramsNestedSplits, numParamsDenseValues) { const valueSlices = []; let numValues = 0; const numSplits = indicesShape.length - 1 + paramsNestedSplits.length; const outSplits = new Array(numSplits).fill(null).map(() => [0]); validateSplits(paramsNestedSplits, numParamsDenseValues); // Add `splits` that come from all but the last dimension of the dense // Tensor `indices`. In particular, for each dimension D, we add a // splits tensor whose values are: // range(reduceProd(splits.shape[:D]) + 1) * splits.shape[D+1] // E.g., if indices.shape=[2, 3, 4] then we will add splits tensors: // [0, 3, 6] # length=2+1, stride=3 // [0, 4, 8, 12, 16, 20, 24] # length=2*3+1, stride=4 let nrows = 1; for (let dim = 0; dim < indicesShape.length - 1; ++dim) { nrows *= indicesShape[dim]; const rowLength = indicesShape[dim + 1]; for (let i = 1; i < nrows + 1; ++i) { outSplits[dim].push(i * rowLength); } } // Add `splits` that come from `paramsNestedSplits`. Starting with the // outermost ragged dimension (i.e., the first `splits` tensor), we work // our way in, finding the range of values that should be copied. 
As we // go, we update the output `splits` for each dimension with the appropriate // values. In particular, the *lengths* of the slices from `param_splits` // should be copied to generate corresponding slice lengths in the output // splits. E.g., if we are copying a ragged row with length 4, then we // should add a new split point to outSplits that is 4 greater than the // previous split point in outSplits. for (let i = 0; i < indices.length; ++i) { let start = indices[i]; let limit = indices[i] + 1; // Copy splits. for (let dim = 0; dim < paramsNestedSplits.length; ++dim) { const splits = paramsNestedSplits[dim]; const outDim = dim + indicesShape.length - 1; if (outDim >= 0) { const outSplitsOutDim = outSplits[outDim]; const delta = outSplitsOutDim[outSplitsOutDim.length - 1] - splits[start]; for (let j = start; j < limit; ++j) { outSplits[outDim].push(splits[j + 1] + delta); } } start = splits[start]; limit = splits[limit]; } if (limit !== start) { valueSlices.push([start, limit]); numValues += limit - start; } } return { outSplits, valueSlices, numValues }; } function getSplits(outSplits) { const splitsOut = []; for (let i = 0; i < outSplits.length; ++i) { const numSplits = outSplits[i].length; const splits = tf.util.getArrayFromDType('int32', numSplits); splitsOut.push(splits); outSplits[i].forEach((value, j) => splits[j] = value); } return splitsOut; } function computeFlatOuterDims(orig, numOutDims) { const outDims = orig.slice(0, numOutDims); while (outDims.length < numOutDims) { outDims.push(1); } for (let inDim = numOutDims; inDim < orig.length; inDim++) { outDims[numOutDims - 1] *= orig[inDim]; } return outDims; } // For each slice in `(start, limit)` in `valueSlices`, append // `paramsDenseValues[start,...,limit] to `values`. `valueSize` indicates // the number of scalars contained in each value paramsDenseValues[i]. 
/**
 * Copies the rows selected by `valueSlices` out of the flat
 * `paramsDenseValues` buffer and packs them back to back into `values`.
 *
 * Each entry of `valueSlices` is read as a `[firstRow, endRow)` pair of row
 * numbers. Row widths of the source and destination buffers come from
 * `computeFlatOuterDims(shape, 2)[1]`.
 *
 * @param paramsDenseValues Flat source buffer.
 * @param paramsDenseValuesShape Shape describing `paramsDenseValues`.
 * @param valueSlices List of `[firstRow, endRow)` pairs to copy.
 * @param valueSize Number of elements copied per row.
 * @param values Flat destination buffer (written in place).
 * @param valuesShape Shape describing `values`.
 */
function writeValueSlices(paramsDenseValues, paramsDenseValuesShape, valueSlices, valueSize, values, valuesShape) {
    const srcRowWidth = computeFlatOuterDims(paramsDenseValuesShape, 2)[1];
    const dstRowWidth = computeFlatOuterDims(valuesShape, 2)[1];
    let dstRow = 0;
    for (const slice of valueSlices) {
        const firstRow = slice[0];
        const endRow = slice[1];
        for (let srcRow = firstRow; srcRow < endRow; ++srcRow, ++dstRow) {
            const srcBase = srcRow * srcRowWidth;
            const dstBase = dstRow * dstRowWidth;
            for (let col = 0; col < valueSize; ++col) {
                values[dstBase + col] = paramsDenseValues[srcBase + col];
            }
        }
    }
}
/**
 * Allocates the gathered dense-values buffer and fills it from the
 * requested slices.
 *
 * The output shape is the input shape with its first entry replaced by
 * `numValues`.
 *
 * @returns `[valuesOut, valuesShape]`.
 */
function getValues(paramsDenseValues, paramsDenseValuesShape, paramsDenseValuesDType, valueSlices, numValues) {
    const outShape = paramsDenseValuesShape.slice();
    outShape[0] = numValues;
    const outData = tf.util.getArrayFromDType(paramsDenseValuesDType, tf.util.sizeFromShape(outShape));
    const totalElements = paramsDenseValues.length;
    // Elements per outer row; an empty source buffer has nothing to copy.
    let rowWidth = 0;
    if (totalElements !== 0) {
        rowWidth = totalElements / paramsDenseValuesShape[0];
    }
    writeValueSlices(paramsDenseValues, paramsDenseValuesShape, valueSlices, rowWidth, outData, outShape);
    return [outData, outShape];
}
/**
 * Gathers rows of a ragged tensor.
 *
 * Validates the inputs, delegates the split computation to `makeSplits`,
 * then materialises the output splits (`getSplits`) and the dense values
 * (`getValues`).
 *
 * @returns `[outputNestedSplits, outputDenseValues, outputDenseValuesShape]`.
 * @throws {Error} If `paramsNestedSplits` is empty, the first split tensor
 *     is a scalar, or the dense values are rank 0.
 */
function raggedGatherImpl(paramsNestedSplits, paramsNestedSplitsShapes, paramsDenseValues, paramsDenseValuesShape, paramsDenseValuesDType, indices, indicesShape, outputRaggedRank) {
    if (paramsNestedSplits.length === 0) {
        throw new Error('paramsNestedSplits must be non empty');
    }
    const outerSplitsShape = paramsNestedSplitsShapes[0];
    if (outerSplitsShape.length === 0) {
        throw new Error('Split tensors must not be scalars');
    }
    // A splits vector with k entries describes k - 1 rows.
    validateIndices(indices, indicesShape, outerSplitsShape[0] - 1);
    if (paramsDenseValuesShape.length === 0) {
        throw new Error('params.rank must be nonzero');
    }
    // Work out the output splits and remember which value rows to copy.
    const gathered = makeSplits(indices, indicesShape, paramsNestedSplits, paramsDenseValuesShape[0]);
    // Materialise the output tensors.
    const nestedSplitsOut = getSplits(gathered.outSplits);
    const densePair = getValues(paramsDenseValues, paramsDenseValuesShape, paramsDenseValuesDType, gathered.valueSlices, gathered.numValues);
    return [nestedSplitsOut, densePair[0], densePair[1]];
}

/**
 * @license
 * Copyright 2022 Google LLC.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
const INT32_MAX = 2147483647;
/**
 * Builds a ragged tensor whose rows are arithmetic ranges.
 *
 * Each of `starts`, `limits` and `deltas` is either a scalar (shape `[]`,
 * reused for every row) or a vector; all vectors must have the same length.
 *
 * @returns `[rtNestedSplits, rtDenseValues]`: an int32 splits array with
 *     `nRows + 1` entries and the concatenated row values in `startsDType`.
 * @throws {Error} On rank > 1 inputs, mismatched vector lengths, a zero
 *     delta, or a row longer than INT32_MAX.
 */
function raggedRangeImpl(starts, startsShape, startsDType, limits, limitsShape, deltas, deltasShape) {
    // Every input must be rank 0 or rank 1.
    const rankChecks = [
        [startsShape, 'starts must be a scalar or vector'],
        [limitsShape, 'limits must be a scalar or vector'],
        [deltasShape, 'deltas must be a scalar or vector']
    ];
    for (const check of rankChecks) {
        if (check[0].length > 1) {
            throw new Error(check[1]);
        }
    }
    // Scalars are broadcast across all rows.
    const startsIsScalar = startsShape.length === 0;
    const limitsIsScalar = limitsShape.length === 0;
    const deltasIsScalar = deltasShape.length === 0;
    // The row count is the common length of the vector inputs, or 1 when
    // everything is a scalar.
    const vectorLengths = rankChecks
        .map((check) => check[0])
        .filter((inputShape) => inputShape.length !== 0)
        .map((inputShape) => inputShape[0]);
    const lengthsDisagree = vectorLengths.some((len, pos) => pos > 0 && len !== vectorLengths[pos - 1]);
    if (lengthsDisagree) {
        throw new Error('starts, limits, and deltas must have the same shape');
    }
    const rowCount = vectorLengths.length === 0 ? 1 : vectorLengths[0];
    const pick = (data, isScalar, row) => (isScalar ? data[0] : data[row]);
    // First pass: row lengths, accumulated into the splits array.
    const rtNestedSplits = tf.util.getArrayFromDType('int32', rowCount + 1);
    rtNestedSplits[0] = 0;
    for (let row = 0; row < rowCount; ++row) {
        const rowStart = pick(starts, startsIsScalar, row);
        const rowLimit = pick(limits, limitsIsScalar, row);
        const rowDelta = pick(deltas, deltasIsScalar, row);
        if (rowDelta === 0) {
            throw new Error('Requires delta != 0');
        }
        // A range that steps away from its limit is empty.
        const stepsAwayFromLimit = ((rowDelta > 0) && (rowLimit < rowStart)) ||
            ((rowDelta < 0) && (rowLimit > rowStart));
        let rowLength = 0;
        if (!stepsAwayFromLimit) {
            rowLength = Math.ceil(Math.abs((rowLimit - rowStart) / rowDelta));
            if (rowLength > INT32_MAX) {
                throw new Error(`Requires ((limit - start) / delta) <= ${INT32_MAX}`);
            }
        }
        rtNestedSplits[row + 1] = rtNestedSplits[row] + rowLength;
    }
    // Second pass: emit the values row by row.
    const totalValues = rtNestedSplits[rowCount];
    const rtDenseValues = tf.util.getArrayFromDType(startsDType, totalValues);
    let writePos = 0;
    for (let row = 0; row < rowCount; ++row) {
        const rowLength = rtNestedSplits[row + 1] - rtNestedSplits[row];
        const rowDelta = pick(deltas, deltasIsScalar, row);
        let current = pick(starts, startsIsScalar, row);
        for (let k = 0; k < rowLength; ++k) {
            rtDenseValues[writePos] = current;
            writePos += 1;
            current += rowDelta;
        }
    }
    return [rtNestedSplits, rtDenseValues];
}
/**
 * @license
 * Copyright 2022 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
var RowPartitionType = tf.backend_util.RowPartitionType;
// Based on
// https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/kernels/ragged_tensor_to_tensor_op.cc
/**
 * Converts a ragged tensor (flat values plus row partition tensors) into a
 * dense tensor, padding missing positions with a default value.
 *
 * Usage: construct with the kernel inputs, then call `compute()`.
 */
class RaggedTensorToTensorOp {
    /**
     * Stores the kernel inputs and derives the partition types and ragged
     * rank from `rowPartitionTypeStrings`.
     */
    constructor(shape, shapeShape, values, valuesShape, valuesDType, defaultValue, defaultValueShape, rowPartitionValues, rowPartitionValuesShapes, rowPartitionTypeStrings) {
        this.shape = shape;
        this.shapeShape = shapeShape;
        this.values = values;
        this.valuesShape = valuesShape;
        this.valuesDType = valuesDType;
        this.defaultValue = defaultValue;
        this.defaultValueShape = defaultValueShape;
        this.rowPartitionValues = rowPartitionValues;
        this.rowPartitionValuesShapes = rowPartitionValuesShapes;
        this.rowPartitionTypes =
            tf.backend_util.getRowPartitionTypesHelper(rowPartitionTypeStrings);
        this.raggedRank = tf.backend_util.getRaggedRank(this.rowPartitionTypes);
    }
    /**
     * Returns the partition type for `dimension`. When the first partition
     * entry is FIRST_DIM_SIZE it is skipped, so lookups are shifted by one.
     */
    getRowPartitionTypeByDimension(dimension) {
        if (this.rowPartitionTypes[0] === RowPartitionType.FIRST_DIM_SIZE) {
            return this.rowPartitionTypes[dimension + 1];
        }
        else {
            return this.rowPartitionTypes[dimension];
        }
    }
    /**
     * Returns the partition tensor relating `dimension` to `dimension + 1`,
     * using the same FIRST_DIM_SIZE offset as
     * `getRowPartitionTypeByDimension`.
     */
    getRowPartitionTensor(dimension) {
        if (this.rowPartitionTypes[0] === RowPartitionType.FIRST_DIM_SIZE) {
            return this.rowPartitionValues[dimension + 1];
        }
        else {
            return this.rowPartitionValues[dimension];
        }
    }
    /**
     * Returns the widest row found in the partition feeding `dimension`.
     *
     * @throws {Error} If the partition type is neither VALUE_ROWIDS nor
     *     ROW_SPLITS.
     */
    getMaxWidth(dimension) {
        const rowPartitionTensor = this.getRowPartitionTensor(dimension - 1);
        switch (this.getRowPartitionTypeByDimension(dimension - 1)) {
            case RowPartitionType.VALUE_ROWIDS:
                return RaggedTensorToTensorOp.getMaxWidthValueRowID(rowPartitionTensor);
            case RowPartitionType.ROW_SPLITS:
                return RaggedTensorToTensorOp.getMaxWidthRowSplit(rowPartitionTensor);
            default:
                throw new Error(`Cannot handle partition type ${RowPartitionType[this.getRowPartitionTypeByDimension(dimension - 1)]}`);
        }
    }
    /**
     * Largest difference between consecutive entries of `rowSplit`
     * (0 when there are fewer than two entries).
     */
    static getMaxWidthRowSplit(rowSplit) {
        const tensorLength = rowSplit.length;
        if (tensorLength === 0 || tensorLength === 1) {
            return 0;
        }
        let maxWidth = 0;
        for (let i = 0; i < tensorLength - 1; ++i) {
            const currentWidth = rowSplit[i + 1] - rowSplit[i];
            if (currentWidth > maxWidth) {
                maxWidth = currentWidth;
            }
        }
        return maxWidth;
    }
    /**
     * Length of the longest run of equal consecutive ids in `valueRowIds`
     * (0 for an empty input).
     */
    static getMaxWidthValueRowID(valueRowIds) {
        const indexLength = valueRowIds.length;
        if (indexLength === 0) {
            return 0;
        }
        let firstEqualIndex = 0;
        let firstEqualIndexValue = valueRowIds[0];
        let maxWidth = 0;
        for (let i = 1; i < indexLength; ++i) {
            const value = valueRowIds[i];
            if (value !== firstEqualIndexValue) {
                firstEqualIndexValue = value;
                maxWidth = Math.max(i - firstEqualIndex, maxWidth);
                firstEqualIndex = i;
            }
        }
        // The final run is not closed inside the loop.
        return Math.max(indexLength - firstEqualIndex, maxWidth);
    }
    /**
     * Interprets a shape tensor. A scalar `-1` means "fully unknown" and
     * yields `[]`; any other scalar is rejected. Otherwise the entries are
     * validated by `makeShape`.
     */
    tensorShapeFromTensor(t, tShape, isPartial = true) {
        if (tShape.length === 0) {
            if (t[0] === -1) {
                return [];
            }
            throw new Error(`The only valid scalar shape tensor is the fully unknown shape specified as -1.`);
        }
        // MakePartialShape/MakeShapeHelper.
        return makeShape(t, isPartial);
    }
    /**
     * Resolves the dense output shape. Negative (unknown) entries are filled
     * in: dimension 0 from `firstDim`, ragged dimensions from the widest row
     * of the corresponding partition.
     */
    calculateOutputSize(firstDim) {
        const valueShape = this.valuesShape;
        const defaultValueShape = this.defaultValueShape;
        tf.backend_util.validateDefaultValueShape(defaultValueShape, valueShape);
        const shape = this.tensorShapeFromTensor(this.shape, this.shapeShape);
        const outputShape = tf.backend_util.combineRaggedTensorToTensorShapes(this.raggedRank, shape, valueShape);
        const result = outputShape;
        if (result[0] < 0) {
            result[0] = firstDim;
        }
        for (let i = 1; i <= this.raggedRank; ++i) {
            if (result[i] < 0) {
                result[i] = this.getMaxWidth(i);
            }
        }
        return result;
    }
    /**
     * The outputIndex represents the index in the output tensor
     * where the first element of a particular dimension would be written.
     * If it is -1, it indicates that the index is out of scope.
     * Example, given firstDimension = 10, firstDimensionOutput = 6,
     * and outputIndexMultiplier = 100:
     * result = [0 100 200 300 400 500 -1 -1 -1 -1]
     * If firstDimensionOutput = 11 instead, then:
     * result = [0 100 200 300 400 500 600 700 800 900]
     */
    calculateFirstParentOutputIndex(firstDimension, outputIndexMultiplier, firstDimensionOutput) {
        const minDimension = Math.min(firstDimension, firstDimensionOutput);
        const result = [];
        let currentOutputIndex = 0;
        for (let i = 0; i < minDimension; ++i, currentOutputIndex += outputIndexMultiplier) {
            result.push(currentOutputIndex);
        }
        for (let i = minDimension; i < firstDimension; ++i) {
            result.push(-1);
        }
        tf.util.assert(result.length === firstDimension, () => 'Final length of result must be equal to firstDimension.');
        return result;
    }
    /**
     * Computes child output indices for a ROW_SPLITS partition. Within each
     * row, the first `min(outputSize, rowLength)` children get consecutive
     * slots (stepping by `outputIndexMultiplier`); the rest, and every child
     * of a parent marked -1, get -1.
     *
     * @throws {Error} If the number of produced entries does not match the
     *     last split value.
     */
    calculateOutputIndexRowSplit(rowSplit, parentOutputIndex, outputIndexMultiplier, outputSize) {
        const rowSplitSize = rowSplit.length;
        const result = [];
        for (let i = 0; i < rowSplitSize - 1; ++i) {
            const rowLength = rowSplit[i + 1] - rowSplit[i];
            let realLength = Math.min(outputSize, rowLength);
            let parentOutputIndexCurrent = parentOutputIndex[i];
            if (parentOutputIndexCurrent === -1) {
                realLength = 0;
            }
            for (let j = 0; j < realLength; ++j) {
                result.push(parentOutputIndexCurrent);
                parentOutputIndexCurrent += outputIndexMultiplier;
            }
            for (let j = 0; j < rowLength - realLength; ++j) {
                result.push(-1);
            }
        }
        if (rowSplitSize > 0 && result.length !== rowSplit[rowSplitSize - 1]) {
            throw new Error('Invalid row split size.');
        }
        return result;
    }
    // Calculate the output index of the first element of a list.
    // The parentOutputIndex is the same computation for the previous list.
    // -1 indicates an element or list that is out of range.
    // The outputIndexMultiplier is the number of output indices one moves
    // forward for each column.
    // E.g., given:
    // valueRowIds:[0 1 2 2 2 3 4 4 5]
    // parentOutputIndex:[1000 1100 2000 2100 -1 3000 4000]
    // outputIndexMultiplier: 10
    // outputSize: 2
    // You get:
    // result = [1000 1100 2000 2010 -1 2100 -1 -1 3000]
    // result[0] = parentOutputIndex[valueRowIds[0]]
    // result[1] = parentOutputIndex[valueRowIds[1]]
    // result[2] = parentOutputIndex[valueRowIds[2]]
    // result[3] = parentOutputIndex[valueRowIds[2]] + 10
    // result[4] = -1 because it is the third element and the size is 2.
    // result[5] = parentOutputIndex[valueRowIds[5]]
    // result[6] = -1 because parentOutputIndex[valueRowIds[6]] == -1
    // result[7] = -1 because parentOutputIndex[valueRowIds[7]] == -1
    // result[8] = parentOutputIndex[valueRowIds[8]]
    calculateOutputIndexValueRowID(valueRowIds, parentOutputIndex, outputIndexMultiplier, outputSize) {
        const indexSize = valueRowIds.length;
        const result = [];
        if (indexSize === 0) {
            return [];
        }
        let currentOutputColumn = 0;
        let currentValueRowId = valueRowIds[0];
        if (currentValueRowId >= parentOutputIndex.length) {
            throw new Error(`Got currentValueRowId=${currentValueRowId}, which is not less than ${parentOutputIndex.length}`);
        }
        let currentOutputIndex = parentOutputIndex[currentValueRowId];
        result.push(currentOutputIndex);
        for (let i = 1; i < indexSize; ++i) {
            const nextValueRowId = valueRowIds[i];
            if (nextValueRowId === currentValueRowId) {
                // Same row: advance one column, or mark as out of range once
                // the row is wider than the output.
                if (currentOutputIndex >= 0) {
                    ++currentOutputColumn;
                    if (currentOutputColumn < outputSize) {
                        currentOutputIndex += outputIndexMultiplier;
                    }
                    else {
                        currentOutputIndex = -1;
                    }
                }
            }
            else {
                // New row: restart at the parent's slot.
                currentOutputColumn = 0;
                currentValueRowId = nextValueRowId;
                if (nextValueRowId >= parentOutputIndex.length) {
                    throw new Error(`Got nextValueRowId=${nextValueRowId} which is not less than ${parentOutputIndex.length}`);
                }
                currentOutputIndex = parentOutputIndex[nextValueRowId];
            }
            result.push(currentOutputIndex);
        }
        if (result.length !== valueRowIds.length) {
            throw new Error('Invalid row ids.');
        }
        return result;
    }
    /**
     * Dispatches to the VALUE_ROWIDS or ROW_SPLITS index computation for
     * `dimension`.
     *
     * @throws {Error} For unsupported partition types, or when a ROW_SPLITS
     *     partition describes more rows than `parentOutputIndex` has entries.
     */
    calculateOutputIndex(dimension, parentOutputIndex, outputIndexMultiplier, outputSize) {
        const rowPartitionTensor = this.getRowPartitionTensor(dimension);
        const partitionType = this.getRowPartitionTypeByDimension(dimension);
        switch (partitionType) {
            case RowPartitionType.VALUE_ROWIDS:
                return this.calculateOutputIndexValueRowID(rowPartitionTensor, parentOutputIndex, outputIndexMultiplier, outputSize);
            case RowPartitionType.ROW_SPLITS:
                if (rowPartitionTensor.length - 1 > parentOutputIndex.length) {
                    throw new Error(`Row partition size is greater than output size: ${rowPartitionTensor.length - 1} > ${parentOutputIndex.length}`);
                }
                return this.calculateOutputIndexRowSplit(rowPartitionTensor, parentOutputIndex, outputIndexMultiplier, outputSize);
            default:
                throw new Error(`Unsupported partition type: ${RowPartitionType[partitionType]}`);
        }
    }
    /**
     * Returns the size of the outermost dimension as described by the first
     * row partition entry.
     *
     * @throws {Error} If no partition types were given or the first one is
     *     VALUE_ROWIDS / unrecognised.
     */
    getFirstDimensionSize() {
        const firstPartitionTensor = this.rowPartitionValues[0];
        if (this.rowPartitionTypes.length === 0) {
            throw new Error('No row_partition_types given.');
        }
        const firstPartitionType = this.rowPartitionTypes[0];
        switch (firstPartitionType) {
            case RowPartitionType.FIRST_DIM_SIZE:
                return firstPartitionTensor[0];
            case RowPartitionType.VALUE_ROWIDS:
                throw new Error('Cannot handle VALUE_ROWIDS in first dimension.');
            case RowPartitionType.ROW_SPLITS:
                return this.rowPartitionValuesShapes[0][0] - 1;
            default:
                throw new Error(`Cannot handle type ${RowPartitionType[firstPartitionType]}`);
        }
    }
    /**
     * Runs the conversion.
     *
     * @returns `[outputShape, outputTensor]` where `outputTensor` is a flat
     *     array of `valuesDType`.
     * @throws {Error} If the first partition tensor is empty.
     */
    compute() {
        const firstPartitionTensor = this.rowPartitionValues[0];
        if (firstPartitionTensor.length <= 0) {
            throw new Error('Invalid first partition input. ' +
                'Tensor requires at least one element.');
        }
        const firstDimension = this.getFirstDimensionSize();
        const outputSize = this.calculateOutputSize(firstDimension);
        // multiplier[i] = number of output slots spanned by one step along
        // ragged dimension i.
        const multiplier = new Array(this.raggedRank + 1);
        multiplier[multiplier.length - 1] = 1;
        for (let i = multiplier.length - 2; i >= 0; --i) {
            multiplier[i] = multiplier[i + 1] * outputSize[i + 1];
        }
        // Full size of the tensor.
        const outputShape = makeShape(outputSize, false);
        const outputTensor = tf.util.getArrayFromDType(this.valuesDType, tf.util.sizeFromShape(outputShape));
        const fullSize = multiplier[0] * outputSize[0];
        if (fullSize > 0) {
            let outputIndex = this.calculateFirstParentOutputIndex(firstDimension, multiplier[0], outputSize[0]);
            for (let i = 1; i <= this.raggedRank; ++i) {
                const newOutputIndex = this.calculateOutputIndex(i - 1, outputIndex, multiplier[i], outputSize[i]);
                outputIndex = newOutputIndex;
            }
            this.setOutput(this.raggedRank, outputIndex, outputTensor, outputShape);
        }
        return [outputShape, outputTensor];
    }
    /**
     * Writes `this.values` into `outputTensor` at the slots listed in
     * `outputIndex` (-1 = drop the value) and pads every other slot with the
     * default value.
     */
    setOutput(raggedRank, outputIndex, outputTensor, outputShape) {
        if (outputTensor.length === 0) {
            return;
        }
        const valuesBase = this.values;
        const outputBase = outputTensor;
        // Shape of one value element: the dimensions after the ragged ones.
        const elementShape = outputShape.slice(raggedRank + 1);
        const valueElementSize = tf.util.sizeFromShape(elementShape);
        const outputIndexSize = outputIndex.length;
        // Broadcast the default value to value_element_size.  (We can skip this
        // if defaultValueTensor.size == 1, since we use fill when that's true.)
        let defaultValue = this.defaultValue;
        if (defaultValue.length !== valueElementSize && defaultValue.length !== 1) {
            const srcShape = this.defaultValueShape;
            tf.tidy(() => {
                const defaultValueTensor = tf.reshape(defaultValue, srcShape);
                const bCastDefault = tf.broadcastTo(defaultValueTensor, elementShape);
                defaultValue = bCastDefault.dataSync();
            });
        }
        // Loop through the outputIndex array, finding contiguous regions that
        // should be copied.  Once we find the end of a contiguous region, copy it
        // and add any necessary padding (with defaultValue).
        let srcStart = 0; // Start of contiguous region (in values)
        let dstStart = 0; // Start of contiguous region (in output)
        let dstEnd = 0; // End (exclusive) of contiguous region (in output)
        for (let srcI = 0; srcI <= outputIndexSize; ++srcI) {
            // dstI is the destination where the value at srcI should be copied.
            let dstI = srcI < outputIndexSize ? outputIndex[srcI] : -1;
            // If we're still in a contiguous region, then update dstEnd go to the
            // next srcI.
            if (dstI === dstEnd) {
                ++dstEnd;
                continue;
            }
            // We found the end of contiguous region.  This can be because we found
            // a gap (dstI > dstEnd), or a source value that shouldn't be copied
            // because it's out-of-bounds (dstI == -1), or the end of the tensor
            // (dstI === -1).
            if (dstStart < dstEnd) {
                // Copy the contiguous region.
                const src = valuesBase.subarray(srcStart * valueElementSize);
                const dst = outputBase.subarray(dstStart * valueElementSize);
                const nVals = (dstEnd - dstStart) * valueElementSize;
                copyArray(dst, src, nVals);
            }
            // Add any necessary padding (w/ defaultValue).
            if (srcI >= outputIndexSize) {
                // We reached the end of values: pad to the end of output.
                const outputSize = outputTensor.length;
                dstI = Math.floor(outputSize / valueElementSize);
            }
            if (dstI > dstEnd) {
                if (this.defaultValue.length === 1) {
                    outputBase
                        .subarray(dstEnd * valueElementSize, dstI * valueElementSize)
                        .fill(this.defaultValue[0]);
                    dstEnd = dstI;
                }
                else {
                    while (dstI > dstEnd) {
                        // Must be a view (subarray), not a copy (slice):
                        // writes through `dst` have to land in the output
                        // buffer itself.
                        const dst = outputBase.subarray(dstEnd * valueElementSize);
                        copyArray(dst, defaultValue, valueElementSize);
                        ++dstEnd;
                    }
                }
            }
            // Update indices.
            if (dstI < 0) {
                // srcI should be skipped -- leave it out of the contiguous region.
                srcStart = srcI + 1;
                dstStart = dstEnd;
            }
            else {
                // srcI should be copied -- include it in the contiguous region.
                srcStart = srcI;
                dstStart = dstEnd;
                dstEnd = dstStart + 1;
            }
        }
    }
}
/** Copies the first `size` elements of `src` into `dst`. */
function copyArray(dst, src, size) {
    for (let i = 0; i < size; i++) {
        dst[i] = src[i];
    }
}
/**
 * Validates the entries of `shape` and returns them as a plain array.
 *
 * Negative entries are only allowed when `isPartial` is true, and then only
 * `-1` (unknown).
 *
 * @throws {Error} On a negative entry in a non-partial shape, or an entry
 *     below -1.
 */
function makeShape(shape, isPartial) {
    const out = [];
    for (let dim of shape) {
        if (dim < 0) {
            if (!isPartial) {
                throw new Error(`Dimension ${dim} must be >= 0`);
            }
            if (dim < -1) {
                throw new Error(`Dimension ${dim} must be >= -1`);
            }
            dim = -1;
        }
        out.push(dim);
    }
    return out;
}
/**
 * Functional entry point: builds a `RaggedTensorToTensorOp` from the kernel
 * inputs and returns the result of its `compute()`.
 */
function raggedTensorToTensorImpl(shape, shapesShape, values, valuesShape, valuesDType, defaultValue, defaultValueShape, rowPartitionValues, rowPartitionValuesShapes, rowPartitionTypes) {
    return new RaggedTensorToTensorOp(shape, shapesShape, values, valuesShape, valuesDType, defaultValue, defaultValueShape, rowPartitionValues, rowPartitionValuesShapes, rowPartitionTypes)
        .compute();
}
/**
 * @license
 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
/**
 * Builds the values of a 1-D range `[start, stop)` advancing by `step`.
 *
 * An empty array is returned when `start === stop`, when the range is
 * increasing but `step` is negative, or when it is decreasing and
 * `step > 1`. A decreasing range with `step === 1` (the default) is
 * treated as if `step` were `-1`.
 *
 * @param start First value of the range.
 * @param stop Exclusive end of the range.
 * @param step Increment between consecutive values.
 * @param dtype Data type forwarded to `tf.util.makeZerosTypedArray`.
 * @returns A typed array holding the range values.
 */
function rangeImpl(start, stop, step, dtype) {
    const sameStartStop = start === stop;
    const increasingRangeNegativeStep = start < stop && step < 0;
    const decreasingRangePositiveStep = stop < start && step > 1;
    if (sameStartStop || increasingRangeNegativeStep ||
        decreasingRangePositiveStep) {
        return tf.util.makeZerosTypedArray(0, dtype);
    }
    // Auto adjust the step's sign if it hasn't been set (or was set to 1).
    // This has to happen BEFORE counting elements: with the unadjusted step
    // the quotient is negative, and Math.abs(Math.ceil(q)) rounds toward
    // zero, dropping the last element of fractional ranges such as
    // (5, 0.5).
    // NOTE(review): a decreasing range with 0 < step < 1 still falls through
    // with a positive step, as before; callers are assumed not to pass that.
    const effectiveStep = (stop < start && step === 1) ? -1 : step;
    const numElements = Math.abs(Math.ceil((stop - start) / effectiveStep));
    const values = tf.util.makeZerosTypedArray(numElements, dtype);
    values[0] = start;
    for (let i = 1; i < values.length; i++) {
        values[i] = values[i - 1] + effectiveStep;
    }
    return values;
}

/**
 * @license
 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the License);
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an AS IS BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
/** Element-wise reciprocal square root, `1 / sqrt(x)`. */
const rsqrtImpl = createSimpleUnaryImpl((xi) => 1 / Math.sqrt(xi));
/**
 * @license
 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
/**
 * Scatters `updates` into an output buffer at the positions named by
 * `indices`.
 *
 * The output is treated as `outputSize / sliceSize` slices of `sliceSize`
 * elements. Each update's slice index is the dot product of its
 * `sliceRank` index coordinates with `strides`.
 *
 * @param indices Object whose `values` hold the flattened index coordinates.
 * @param updates Object with `values`, `dtype` and `rank`.
 * @param shape Output shape (used for the empty result and error text).
 * @param defaultValue A `tf.TensorBuffer` to write into directly, or a
 *     string / number / boolean used to pre-fill a fresh buffer.
 * @param sumDupeIndices When true, updates are accumulated; otherwise the
 *     last write wins.
 * @returns The buffer that was written.
 * @throws {Error} If a computed slice index is out of range.
 */
function scatterImpl(indices, updates, shape, outputSize, sliceSize, numUpdates, sliceRank, strides, defaultValue, sumDupeIndices) {
    const sliceCount = outputSize / sliceSize;
    const indexValues = indices.values;
    const updateValues = updates.values;
    if (outputSize === 0) {
        return tf.buffer(shape, updates.dtype);
    }
    let result;
    if (defaultValue instanceof tf.TensorBuffer) {
        result = defaultValue;
    }
    else {
        result = tf.buffer([sliceCount, sliceSize], updates.dtype);
    }
    switch (typeof defaultValue) {
        case 'string':
        case 'number':
            result.values.fill(defaultValue);
            break;
        case 'boolean':
            // Booleans are stored numerically.
            result.values.fill(+defaultValue);
            break;
        default:
            break;
    }
    for (let update = 0; update < numUpdates; update++) {
        const coords = [];
        let sliceIndex = 0;
        for (let axis = 0; axis < sliceRank; axis++) {
            const coord = indexValues[update * sliceRank + axis];
            coords.push(coord);
            sliceIndex += coord * strides[axis];
        }
        if (sliceIndex < 0 || sliceIndex >= outputSize / sliceSize) {
            throw new Error(`Invalid indices: ${coords} does not index into ${shape}`);
        }
        const dstBase = sliceIndex * sliceSize;
        const srcBase = update * sliceSize;
        for (let k = 0; k < sliceSize; k++) {
            if (sumDupeIndices) {
                result.values[dstBase + k] += updateValues[srcBase + k];
            }
            else if (updates.rank === 0) {
                // A scalar update is reused for every target element.
                result.values[dstBase + k] = updateValues[0];
            }
            else {
                result.values[dstBase + k] = updateValues[srcBase + k];
            }
        }
    }
    return result;
}
/**
 * @license
 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the License);
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an AS IS BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
/** Element-wise logistic sigmoid, `1 / (1 + exp(-x))`. */
const sigmoidImpl = createSimpleUnaryImpl((value) => 1 / (1 + Math.exp(-value)));

/**
 * @license
 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
/**
 * Extracts the sub-tensor of extent `size` starting at `begin` from the
 * flat data `vals` of a tensor with the given `shape`.
 *
 * When `tf.slice_util.isSliceContinous` reports a contiguous slice, the
 * result is a single range of `vals` (`slice` for strings, `subarray`
 * otherwise). Otherwise elements are copied one by one through tensor
 * buffers; string data is decoded before and re-encoded after the copy.
 *
 * @returns The sliced flat data.
 */
function sliceImpl(vals, begin, size, shape, dtype) {
    const contiguous = tf.slice_util.isSliceContinous(shape, begin, size);
    const elementCount = tf.util.sizeFromShape(size);
    const inputStrides = tf.util.computeStrides(shape);
    const isString = dtype === 'string';
    if (contiguous) {
        const first = tf.slice_util.computeFlatOffset(begin, inputStrides);
        const last = first + elementCount;
        return isString ? vals.slice(first, last) : vals.subarray(first, last);
    }
    const sourceData = isString ? tf.backend_util.fromUint8ToStringArray(vals) : vals;
    const source = tf.buffer(shape, dtype, sourceData);
    const target = tf.buffer(size, dtype);
    for (let flat = 0; flat < target.size; ++flat) {
        const targetLoc = target.indexToLoc(flat);
        // Shift every coordinate by the slice origin to find the source.
        const sourceLoc = targetLoc.map((coord, axis) => coord + begin[axis]);
        const element = source.get(...sourceLoc);
        target.set(element, ...targetLoc);
    }
    return isString ?
        tf.backend_util.fromStringArrayToUint8(target.values) :
        target.values;
}
/**
 * @license
 * Copyright 2021 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
/**
 * Inserts one default-valued entry into every dense row of a sparse tensor
 * that has no entries.
 *
 * `indices` is read as a flattened `[N, rank]` array whose first column is
 * the row number.
 *
 * @returns `[outputIndices, outputIndicesShape, outputValues,
 *     emptyRowIndicator, reverseIndexMap]`. `emptyRowIndicator[r]` is true
 *     when row `r` had no entries; `reverseIndexMap[i]` is the output
 *     position of input entry `i`. When no row is empty and the rows are
 *     already in non-decreasing order, the input `indices` and `values`
 *     arrays themselves are returned.
 * @throws {Error} If a row index is negative or not below `denseShape[0]`,
 *     or if there are entries but zero dense rows.
 */
function sparseFillEmptyRowsImpl(indices, indicesShape, indicesDType, values, valuesDType, denseShape, defaultValue) {
    const entryCount = indicesShape[0];
    const rowCount = denseShape[0];
    const emptyRowIndicator = new Array(rowCount);
    const reverseIndexMap = new Array(entryCount);
    const rank = indicesShape[1];
    if (rowCount === 0) {
        if (entryCount !== 0) {
            throw new Error(tf.backend_util.getSparseFillEmptyRowsIndicesDenseShapeMismatch(entryCount));
        }
        return [
            tf.util.getArrayFromDType(indicesDType, 0),
            [0, rank],
            tf.util.getArrayFromDType(valuesDType, 0),
            emptyRowIndicator,
            reverseIndexMap
        ];
    }
    // Pass 1: count entries per row and note whether rows arrive in order.
    const rowEnd = new Array(rowCount).fill(0);
    let inRowOrder = true;
    let previousRow = 0;
    for (let entry = 0; entry < entryCount; ++entry) {
        // indices is a 2d tensor with shape of [N, rank]
        const row = indices[entry * rank];
        if (row < 0) {
            throw new Error(tf.backend_util.getSparseFillEmptyRowsNegativeIndexErrorMessage(entry, row));
        }
        if (row >= rowCount) {
            throw new Error(tf.backend_util.getSparseFillEmptyRowsOutOfRangeIndexErrorMessage(entry, row, rowCount));
        }
        rowEnd[row] += 1;
        if (!(row >= previousRow)) {
            inRowOrder = false;
        }
        previousRow = row;
    }
    // Pass 2: flag empty rows and turn the counts into running totals, with
    // every row occupying at least one output slot. Afterwards rowEnd[r] is
    // the output position just past row r.
    let sawEmptyRow = false;
    for (let row = 0; row < rowCount; ++row) {
        const isEmpty = (rowEnd[row] === 0);
        emptyRowIndicator[row] = isEmpty;
        if (isEmpty) {
            sawEmptyRow = true;
        }
        const slots = Math.max(rowEnd[row], 1);
        rowEnd[row] = row > 0 ? slots + rowEnd[row - 1] : slots;
    }
    if (!sawEmptyRow && inRowOrder) {
        // Nothing to insert or reorder: pass the inputs straight through.
        for (let entry = 0; entry < entryCount; ++entry) {
            reverseIndexMap[entry] = entry;
        }
        return [
            indices, [entryCount, rank], values, emptyRowIndicator,
            reverseIndexMap
        ];
    }
    const rowBegin = (row) => ((row === 0) ? 0 : rowEnd[row - 1]);
    const totalCount = rowEnd[rowCount - 1];
    const outputIndices = tf.util.getArrayFromDType(indicesDType, totalCount * rank);
    const outputValues = tf.util.getArrayFromDType(valuesDType, totalCount);
    const written = new Array(rowCount).fill(0);
    // Place the existing entries, grouped by row.
    for (let entry = 0; entry < entryCount; ++entry) {
        const row = indices[entry * rank];
        const target = rowBegin(row) + written[row];
        written[row]++;
        for (let col = 0; col < rank; ++col) {
            // indices and outputIndices are 2d tensors with shape of [N, rank]
            outputIndices[target * rank + col] = indices[entry * rank + col];
        }
        outputValues[target] = values[entry];
        // We'll need this reverse index map to backprop correctly.
        reverseIndexMap[entry] = target;
    }
    // Give every row that received nothing a single default entry at
    // coordinates [row, 0, ..., 0].
    for (let row = 0; row < rowCount; ++row) {
        if (written[row] !== 0) {
            continue;
        }
        const target = rowBegin(row);
        outputIndices[target * rank + 0] = row;
        for (let col = 1; col < rank; ++col) {
            outputIndices[target * rank + col] = 0;
        }
        outputValues[target] = defaultValue;
    }
    return [
        outputIndices, [totalCount, rank], outputValues, emptyRowIndicator,
        reverseIndexMap
    ];
}
/**
 * @license
 * Copyright 2021 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
/**
 * Re-expresses the indices of a sparse tensor for a new dense shape with
 * the same number of elements.
 *
 * `targetShape` may contain a single `-1`, which is inferred from the
 * remaining dimensions. `inputIndices` is read as a flattened
 * `[nnz, inputRank]` array.
 *
 * @returns `[newIndices, [nnz, outputRank], outputShape]`.
 * @throws {Error} If `targetShape` has more than one `-1`, any other
 *     negative dimension, or is incompatible with the input element count.
 */
function sparseReshapeImpl(inputIndices, inputIndicesShape, inputDType, inputShape, targetShape) {
    const denseSize = tf.util.sizeFromShape(inputShape);
    const nnz = inputIndicesShape[0];
    const outputRank = targetShape.length;
    // Resolve the output shape: multiply the specified dimensions together
    // and remember which axis (if any) was left as -1.
    const outputShape = [];
    let knownProduct = 1;
    let inferredAxis = -1;
    for (let axis = 0; axis < outputRank; ++axis) {
        const dim = targetShape[axis];
        if (dim === -1) {
            if (inferredAxis !== -1) {
                throw new Error(tf.backend_util
                    .getSparseReshapeMultipleNegativeOneOutputDimErrorMessage(inferredAxis, axis));
            }
            inferredAxis = axis;
            // Placeholder; replaced once the product is known.
            outputShape.push(1);
            continue;
        }
        if (dim < 0) {
            throw new Error(tf.backend_util.getSparseReshapeNegativeOutputDimErrorMessage(axis, dim));
        }
        knownProduct *= dim;
        outputShape.push(dim);
    }
    if (inferredAxis !== -1) {
        if (knownProduct <= 0) {
            throw new Error(tf.backend_util.getSparseReshapeEmptyTensorZeroOutputDimErrorMessage());
        }
        const inferred = Math.trunc(denseSize / knownProduct);
        if (knownProduct * inferred !== denseSize) {
            throw new Error(tf.backend_util.getSparseReshapeInputOutputMultipleErrorMessage(inputShape, outputShape));
        }
        outputShape[inferredAxis] = inferred;
    }
    if (tf.util.sizeFromShape(outputShape) !== denseSize) {
        throw new Error(tf.backend_util.getSparseReshapeInputOutputMismatchErrorMessage(inputShape, outputShape));
    }
    // Row-major strides: the last axis has stride 1.
    const rowMajorStrides = (dims) => {
        const strides = [];
        const rank = dims.length;
        if (rank > 0) {
            strides[rank - 1] = 1;
            for (let axis = rank - 2; axis >= 0; --axis) {
                strides[axis] = strides[axis + 1] * dims[axis + 1];
            }
        }
        return strides;
    };
    const inputRank = inputShape.length;
    const inputStrides = rowMajorStrides(inputShape);
    const outputStrides = rowMajorStrides(outputShape);
    const newIndices = tf.util.getArrayFromDType(inputDType, nnz * outputRank);
    for (let entry = 0; entry < nnz; ++entry) {
        // Flatten the old coordinates ([nnz, inputRank] layout) ...
        let flat = 0;
        for (let axis = 0; axis < inputRank; ++axis) {
            flat += inputIndices[entry * inputRank + axis] * inputStrides[axis];
        }
        // ... and unflatten into the new ones ([nnz, outputRank] layout).
        for (let axis = 0; axis < outputRank; ++axis) {
            const stride = outputStrides[axis];
            newIndices[entry * outputRank + axis] = Math.trunc(flat / stride);
            flat %= stride;
        }
    }
    return [newIndices, [nnz, outputRank], outputShape];
}
/**
 * @license
 * Copyright 2021 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
/**
 * Sums (or averages, when `isMean` is true) the rows of `input` selected by
 * `indices`, grouping them by `segmentIds`.
 *
 * The input is treated as `inputShape[0]` rows of equal width. Output row
 * `s` is the reduction of all selected rows whose segment id is `s`; output
 * rows that receive no input are filled with `defaultValue`. Segment ids
 * are assumed to be sorted; the number of output rows is the last id
 * plus one.
 *
 * @returns `[output, outputShape]`.
 * @throws {Error} On negative, decreasing or out-of-range segment ids, or
 *     on an index outside `[0, inputShape[0])`.
 */
function sparseSegmentReductionImpl(input, inputShape, inputDType, indices, segmentIds, isMean = false, defaultValue = 0) {
    const numIndices = indices.length;
    // View the input as [rows, columns].
    const inputRows = inputShape[0];
    const numCol = input.length / inputRows;
    // Relies on segmentIds being sorted: the last id is the largest.
    const outputRows = numIndices > 0 ? segmentIds[numIndices - 1] + 1 : 0;
    if (outputRows < 0) {
        throw new Error(tf.backend_util.getSparseSegmentReductionNegativeSegmentIdsErrorMessage());
    }
    const outputShape = inputShape.slice();
    outputShape[0] = outputRows;
    let outputLength = 1;
    for (const dim of outputShape) {
        outputLength *= dim;
    }
    // Freshly allocated, so every element starts at 0.
    const output = tf.util.getArrayFromDType(inputDType, outputLength);
    if (numIndices === 0) {
        if (outputRows > 0) {
            output.fill(defaultValue);
        }
        return [output, outputShape];
    }
    if (outputRows <= 0) {
        throw new Error(tf.backend_util.getSparseSegmentReductionNegativeSegmentIdsErrorMessage());
    }
    let runStart = 0;
    let runEnd = 1;
    // First output row that has been neither reduced into nor defaulted.
    let nextUnwrittenRow = 0;
    let segmentId = segmentIds[runStart];
    do {
        // Extend the run [runStart, runEnd) while the segment id repeats.
        while (runEnd < numIndices && segmentIds[runEnd] === segmentId) {
            ++runEnd;
        }
        // Peek at the id that follows the run (0 when the run is the last).
        let followingId = 0;
        if (runEnd < numIndices) {
            followingId = segmentIds[runEnd];
            // Ids must grow from one run to the next.
            if (segmentId >= followingId) {
                throw new Error(tf.backend_util
                    .getSparseSegmentReductionNonIncreasingSegmentIdsErrorMessage());
            }
        }
        if (segmentId < 0 || segmentId >= outputRows) {
            throw new Error(tf.backend_util.getSparseSegmentReductionSegmentIdOutOfRangeErrorMessage(segmentId, outputRows));
        }
        // Rows skipped between the previous run and this one get the
        // default value.
        if (segmentId > nextUnwrittenRow) {
            output.fill(defaultValue, nextUnwrittenRow * numCol, segmentId * numCol);
        }
        const outBase = segmentId * numCol;
        for (let i = runStart; i < runEnd; ++i) {
            const row = indices[i];
            if (row < 0 || row >= inputRows) {
                throw new Error(tf.backend_util.getSparseSegmentReductionIndicesOutOfRangeErrorMessage(i, indices[i], inputRows));
            }
            const inBase = row * numCol;
            for (let col = 0; col < numCol; col++) {
                output[outBase + col] += input[inBase + col];
            }
        }
        if (isMean) {
            const runLength = runEnd - runStart;
            for (let col = 0; col < numCol; col++) {
                output[outBase + col] /= runLength;
            }
        }
        runStart = runEnd;
        ++runEnd;
        nextUnwrittenRow = segmentId + 1;
        segmentId = followingId;
    } while (runEnd <= numIndices);
    // Default any rows left after the last run.
    if (nextUnwrittenRow < outputRows) {
        output.fill(defaultValue, nextUnwrittenRow * numCol, outputRows * numCol);
    }
    return [output, outputShape];
}
/** Element-wise square root, built on the shared unary-kernel factory. */
const sqrtImpl = createSimpleUnaryImpl((xi) => Math.sqrt(xi));

/** Element-wise `(a - b)^2`, built on the shared binary-kernel factory. */
const squaredDifferenceImpl = createSimpleBinaryKernelImpl(((a, b) => {
    const diff = a - b;
    return diff * diff;
}));

/**
 * Element-wise regular-expression replacement on strings.
 *
 * `attrs.pattern` is compiled to a RegExp (with the `g` flag when
 * `attrs.replaceGlobal` is truthy) and every match in `x` is replaced by
 * `attrs.rewrite`.
 *
 * The compiled RegExp is kept in a one-entry cache keyed on the pattern and
 * the global flag, so the regex is rebuilt only when those change rather than
 * on every call. Reusing a `g` regex is safe here because
 * `String.prototype.replace` resets `lastIndex` before matching.
 * NOTE(review): this presumes the factory invokes the callback once per
 * element with the same attrs; confirm against createSimpleUnaryImpl.
 */
const staticRegexReplaceImpl = createSimpleUnaryImpl((() => {
    let cachedPattern;
    let cachedIsGlobal;
    let cachedRegex = null;
    return (x, attrs) => {
        const { pattern, replaceGlobal, rewrite } = attrs;
        const isGlobal = Boolean(replaceGlobal);
        if (cachedRegex === null || cachedPattern !== pattern ||
            cachedIsGlobal !== isGlobal) {
            // Assign the regex first: if the pattern is invalid the
            // constructor throws and the previous cache entry stays intact.
            cachedRegex = new RegExp(pattern, isGlobal ? 'g' : '');
            cachedPattern = pattern;
            cachedIsGlobal = isGlobal;
        }
        return x.replace(cachedRegex, rewrite);
    };
})());
/**
 * Copies a strided selection of `xBuf` into a new buffer of shape `outShape`.
 *
 * Output location `loc` is read from input location
 * `loc[d] * strides[d] + begin[d]` along every axis `d`.
 *
 * @param outShape Shape of the buffer to create.
 * @param xBuf Source buffer; its dtype is reused for the result.
 * @param strides Per-axis step between consecutive selected elements.
 * @param begin Per-axis offset of the first selected element.
 * @returns The newly created buffer.
 */
function stridedSliceImpl(outShape, xBuf, strides, begin) {
    const sliced = tf.buffer(outShape, xBuf.dtype);
    for (let flatIndex = 0; flatIndex < sliced.size; flatIndex++) {
        const outLoc = sliced.indexToLoc(flatIndex);
        // Map the output coordinate back to the matching source coordinate.
        const srcLoc = outLoc.map((coord, axis) => coord * strides[axis] + begin[axis]);
        sliced.set(xBuf.get(...srcLoc), ...outLoc);
    }
    return sliced;
}
/**
 * Builds n-grams out of ragged string data.
 *
 * The constructor captures everything that is fixed for the operation
 * (separator, n-gram widths, padding strings and width, short-sequence
 * handling); `compute` can then be called with different ragged inputs.
 */
class StringNGramsOp {
    constructor(separator, nGramWidths, leftPad, rightPad, padWidth, preserveShortSequences) {
        this.separator = tf.util.encodeString(separator);
        this.nGramWidths = nGramWidths;
        this.leftPad = tf.util.encodeString(leftPad);
        this.rightPad = tf.util.encodeString(rightPad);
        this.padWidth = padWidth;
        this.preserveShort = preserveShortSequences;
    }
    /**
     * Effective number of pad tokens on each side for a given n-gram width.
     * A negative `padWidth` means "as much as possible"; in every case the
     * padding is capped at `nGramWidth - 1`.
     */
    getPadWidth(nGramWidth) {
        const widest = nGramWidth - 1;
        const requested = this.padWidth < 0 ? widest : this.padWidth;
        return Math.min(requested, widest);
    }
    /** Number of n-grams of `nGramWidth` produced by a row of `length` tokens. */
    getNumNGrams(length, nGramWidth) {
        const paddedLength = length + 2 * this.getPadWidth(nGramWidth);
        return Math.max(0, (paddedLength - nGramWidth) + 1);
    }
    /**
     * Writes `numNGrams` n-grams of width `nGramWidth` for the row that starts
     * at `data[splitIndex]` into `output`, beginning at `outputStartIndex`.
     * Every n-gram is a Uint8Array of pad strings and tokens joined by the
     * separator.
     */
    createNGrams(data, splitIndex, output, outputStartIndex, numNGrams, nGramWidth) {
        // Depends only on nGramWidth, so it is the same for every n-gram.
        const padWidth = this.getPadWidth(nGramWidth);
        for (let gram = 0; gram < numNGrams; ++gram) {
            const leftPadding = Math.max(0, padWidth - gram);
            const rightPadding = Math.max(0, padWidth - (numNGrams - (gram + 1)));
            const numTokens = nGramWidth - (leftPadding + rightPadding);
            const firstToken = splitIndex + (leftPadding > 0 ? 0 : gram - padWidth);
            // Work out the exact byte size up front so the n-gram can be
            // allocated once.
            let tokenBytes = 0;
            for (let t = 0; t < numTokens; ++t) {
                tokenBytes += data[firstToken + t].length;
            }
            const numSeparators = leftPadding + rightPadding + numTokens - 1;
            const nGramSize = leftPadding * this.leftPad.length + tokenBytes +
                rightPadding * this.rightPad.length +
                numSeparators * this.separator.length;
            const nGram = new Uint8Array(nGramSize);
            output[outputStartIndex + gram] = nGram;
            let cursor = 0;
            const append = (bytes) => {
                for (const byte of bytes) {
                    nGram[cursor++] = byte;
                }
            };
            for (let p = 0; p < leftPadding; ++p) {
                append(this.leftPad);
                append(this.separator);
            }
            // All tokens but the last are followed by a separator.
            for (let t = 0; t < numTokens - 1; ++t) {
                append(data[firstToken + t]);
                append(this.separator);
            }
            if (numTokens > 0) {
                // Emit the last token, then pair every right pad with a
                // leading separator so the n-gram ends on a token or a pad.
                append(data[firstToken + numTokens - 1]);
                for (let p = 0; p < rightPadding; ++p) {
                    append(this.separator);
                    append(this.rightPad);
                }
            }
            else {
                // No tokens: the left-padding loop already ended on a
                // separator, so emit pad+separator pairs and finish on a pad.
                for (let p = 0; p < rightPadding - 1; ++p) {
                    append(this.rightPad);
                    append(this.separator);
                }
                append(this.rightPad);
            }
        }
    }
    /**
     * Computes the n-grams of a ragged tensor given as flat `data` plus
     * `splits`, where `splits[i]` is the index in `data` at which row `i`
     * starts.
     *
     * @returns `[nGrams, nGramsSplits]`: the flat list of n-grams and the row
     *     splits describing how they are grouped.
     * @throws {Error} When `splits` is non-empty and is not a valid,
     *     non-decreasing partition of `data` starting at 0.
     */
    compute(data, splits) {
        const inputDataSize = data.length;
        const splitsSize = splits.length;
        // Splits are validated only when there are any.
        if (splitsSize > 0) {
            let prevSplit = splits[0];
            if (prevSplit !== 0) {
                throw new Error(`First split value must be 0, got ${prevSplit}`);
            }
            for (let i = 1; i < splitsSize; ++i) {
                const current = splits[i];
                const inRange = current >= prevSplit && current <= inputDataSize;
                if (!inRange) {
                    throw new Error(`Invalid split value ${current}, must be in [${prevSplit}, ${inputDataSize}]`);
                }
                prevSplit = current;
            }
            if (prevSplit !== inputDataSize) {
                throw new Error(`Last split value must be data size. Expected ${inputDataSize}, got ${prevSplit}`);
            }
        }
        const numBatchItems = splitsSize - 1;
        const nGramsSplits = tf.util.getArrayFromDType('int32', splitsSize);
        // Nothing to do without data or without splits: return an empty
        // ragged tensor.
        if (inputDataSize === 0 || splitsSize === 0) {
            const empty = new Array(inputDataSize);
            for (let i = 0; i <= numBatchItems; ++i) {
                nGramsSplits[i] = 0;
            }
            return [empty, nGramsSplits];
        }
        // First pass: count the n-grams of every row to build the output
        // splits.
        nGramsSplits[0] = 0;
        for (let row = 1; row <= numBatchItems; ++row) {
            const rowLength = splits[row] - splits[row - 1];
            let rowNGrams = this.nGramWidths.reduce((total, width) => total + this.getNumNGrams(rowLength, width), 0);
            if (this.preserveShort && rowLength > 0 && rowNGrams === 0) {
                rowNGrams = 1;
            }
            nGramsSplits[row] = nGramsSplits[row - 1] + rowNGrams;
        }
        // Second pass: materialise the n-grams.
        const nGrams = new Array(nGramsSplits[numBatchItems]);
        for (let row = 0; row < numBatchItems; ++row) {
            const splitIndex = splits[row];
            const rowLength = splits[row + 1] - splits[row];
            let outputStartIdx = nGramsSplits[row];
            this.nGramWidths.forEach((width) => {
                const count = this.getNumNGrams(rowLength, width);
                this.createNGrams(data, splitIndex, nGrams, outputStartIdx, count, width);
                outputStartIdx += count;
            });
            // outputStartIdx is unchanged only if no width produced an
            // n-gram. When short sequences are preserved and the row is not
            // empty, emit one n-gram covering the whole row. Dynamic padding
            // never reaches this point: it always yields at least one n-gram.
            const nothingEmitted = outputStartIdx === nGramsSplits[row];
            if (this.preserveShort && nothingEmitted && rowLength !== 0) {
                const wholeRowWidth = rowLength + 2 * this.padWidth;
                this.createNGrams(data, splitIndex, nGrams, outputStartIdx, 1, wholeRowWidth);
            }
        }
        return [nGrams, nGramsSplits];
    }
}

/**
 * Functional entry point: builds a StringNGramsOp from the attributes and
 * runs it on the ragged input (`data`, `dataSplits`).
 */
function stringNGramsImpl(data, dataSplits, separator, nGramWidths, leftPad, rightPad, padWidth, preserveShortSequences) {
    const op = new StringNGramsOp(separator, nGramWidths, leftPad, rightPad, padWidth, preserveShortSequences);
    return op.compute(data, dataSplits);
}

/**
 * Splits the byte string `str` at the bytes listed in `delimiters` and
 * appends the resulting tokens (sub-views of `str`) to `result`.
 *
 * With no delimiters the string is split into single bytes; empty tokens are
 * dropped when `skipEmpty` is set. An empty `str` produces nothing.
 */
function split(str, delimiters, skipEmpty, result) {
    if (!str.length) {
        return;
    }
    const emit = (token) => {
        if (!skipEmpty || token.length !== 0) {
            result.push(token);
        }
    };
    // No delimiter: one token per byte.
    if (delimiters.length === 0) {
        for (let i = 0; i < str.length; ++i) {
            result.push(str.subarray(i, i + 1));
        }
        return;
    }
    // Exactly one delimiter: hop from one occurrence to the next.
    if (delimiters.length === 1) {
        const delimiter = delimiters[0];
        let tokenStart = 0;
        let hit = str.indexOf(delimiter, tokenStart);
        while (hit !== -1) {
            emit(str.subarray(tokenStart, hit));
            tokenStart = hit + 1;
            hit = str.indexOf(delimiter, tokenStart);
        }
        emit(str.subarray(tokenStart));
        return;
    }
    // Several delimiters: cut wherever any of them appears, and once more at
    // the end of the string.
    let tokenStart = 0;
    for (let i = 0; i <= str.length; i++) {
        const atBoundary = (i === str.length) || (delimiters.indexOf(str[i]) !== -1);
        if (atBoundary) {
            emit(str.subarray(tokenStart, i));
            tokenStart = i + 1;
        }
    }
}

/**
 * Splits every string of `input` and returns the tokens in sparse form.
 *
 * @returns `[indices, values, shape]` where `indices` is a flattened
 *     `[numTokens, 2]` list of (row, position) pairs, `values` holds the
 *     tokens, and `shape` is `[batchSize, maxTokensPerRow]`.
 */
function stringSplitImpl(input, delimiter, skipEmpty) {
    const batchSize = input.length;
    const tokens = [];
    const tokensPerRow = new Array(batchSize);
    let outputSize = 0;
    let maxNumEntries = 0;
    for (let row = 0; row < batchSize; ++row) {
        const before = tokens.length;
        split(input[row], delimiter, skipEmpty, tokens);
        const added = tokens.length - before;
        tokensPerRow[row] = added;
        outputSize += added;
        maxNumEntries = Math.max(maxNumEntries, added);
    }
    const indices = tf.util.getArrayFromDType('int32', outputSize * 2);
    const values = new Array(outputSize);
    const shape = [batchSize, maxNumEntries];
    let flat = 0;
    for (let row = 0; row < batchSize; ++row) {
        for (let pos = 0; pos < tokensPerRow[row]; ++pos) {
            // indices is laid out as [outputSize, 2].
            indices[flat * 2] = row;
            indices[flat * 2 + 1] = pos;
            values[flat] = tokens[flat];
            ++flat;
        }
    }
    return [indices, values, shape];
}
/**
 * Maps every entry of `input` to a bucket id.
 *
 * Each entry is fingerprinted with `tf.util.fingerPrint64`; the fingerprint is
 * reduced modulo `numBuckets` and its low unsigned bits are stored.
 *
 * @param input Array-like of values accepted by `tf.util.fingerPrint64`.
 * @param numBuckets Modulus applied to every fingerprint.
 * @returns An 'int32' array (as allocated by `tf.util.getArrayFromDType`)
 *     with one bucket id per input entry.
 */
function stringToHashBucketFastImpl(input, numBuckets) {
    const count = input.length;
    const buckets = tf.util.getArrayFromDType('int32', count);
    for (let idx = 0; idx < count; ++idx) {
        const fingerprint = tf.util.fingerPrint64(input[idx]);
        buckets[idx] = fingerprint.modulo(numBuckets).getLowBitsUnsigned();
    }
    return buckets;
}

/** Element-wise subtraction `a - b`, built on the shared binary-kernel factory. */
const subImpl = createSimpleBinaryKernelImpl((minuend, subtrahend) => minuend - subtrahend);
/**
 * An implementation of the tile kernel shared between webgl and cpu for string
 * tensors only.
 *
 * @param xBuf Source buffer.
 * @param reps Number of repetitions along every axis of `xBuf`.
 * @returns A new buffer whose shape is `xBuf.shape[i] * reps[i]` per axis.
 */
function tileImpl(xBuf, reps) {
    const newShape = new Array(xBuf.rank);
    for (let i = 0; i < newShape.length; i++) {
        newShape[i] = xBuf.shape[i] * reps[i];
    }
    const result = tf.buffer(newShape, xBuf.dtype);
    for (let i = 0; i < result.values.length; ++i) {
        const newLoc = result.indexToLoc(i);
        // Wrap every output coordinate back into the source extent.
        const originalLoc = new Array(xBuf.rank);
        for (let j = 0; j < originalLoc.length; j++) {
            originalLoc[j] = newLoc[j] % xBuf.shape[j];
        }
        const originalIndex = xBuf.locToIndex(originalLoc);
        result.values[i] = xBuf.values[originalIndex];
    }
    return result;
}

/**
 * Comparator for `{value, index}` pairs: larger values first, and for equal
 * values the smaller index first.
 */
const comparePair = (a, b) => {
    const valueDiff = b.value - a.value;
    return valueDiff === 0 ? a.index - b.index : valueDiff;
};

/**
 * Partitions `array` in place (under the `comparePair` ordering) so that
 * `array[k]` ends up in its sorted position, with all elements ordered before
 * it on its left and all elements ordered after it on its right.
 * Based on the Floyd-Rivest Algorithm, ref:
 * https://en.wikipedia.org/wiki/Floyd%E2%80%93Rivest_algorithm
 * @param array: Array of `{value, index}` pairs to partition
 * @param k: Desired index value, where array[k] is the (k+1)th element in
 *     `comparePair` order when left = 0
 * @param left: Left index for the interval (defaults to 0)
 * @param right: Right index for the interval (defaults to the last index)
 */
function select$1(array, k, left = 0, right = array.length - 1) {
    while (right > left) {
        // Use select recursively to sample a smaller set of size s
        // the arbitrary constants 600 and 0.5 are used in the original
        // version to minimize execution time.
        if (right - left > 600) {
            const n = right - left + 1;
            const i = k - left + 1;
            const z = Math.log(n);
            const s = 0.5 * Math.exp(2 * z / 3);
            const sd = 0.5 * Math.sqrt(z * s * (n - s) / n) * Math.sign(i - n / 2);
            const newLeft = Math.max(left, Math.floor(k - i * s / n + sd));
            const newRight = Math.min(right, Math.floor(k + (n - i) * s / n + sd));
            select$1(array, k, newLeft, newRight);
        }
        // partition the elements between left and right around t
        const t = array[k];
        let i = left;
        let j = right;
        tf.util.swap(array, left, k);
        if (comparePair(array[right], t) > 0) {
            tf.util.swap(array, left, right);
        }
        while (i < j) {
            tf.util.swap(array, i, j);
            i++;
            j--;
            while (comparePair(array[i], t) < 0) {
                i = i + 1;
            }
            while (comparePair(array[j], t) > 0) {
                j = j - 1;
            }
        }
        if (comparePair(array[left], t) === 0) {
            tf.util.swap(array, left, j);
        }
        else {
            j = j + 1;
            tf.util.swap(array, j, right);
        }
        // Adjust left and right towards the boundaries of the subset
        // containing the (k - left + 1)th smallest element.
        if (j <= k) {
            left = j + 1;
        }
        if (k <= j) {
            right = j - 1;
        }
    }
}

/**
 * Top-k values and their indices along the last dimension.
 *
 * @param x Flat typed array with the input values.
 * @param xShape Shape of the input.
 * @param xDtype Dtype of the input (used for the values output).
 * @param k Number of entries to keep per row of the last dimension.
 * @param sorted When true, the k entries of every row are returned in
 *     `comparePair` order (largest value first).
 * @returns `[values, indices]` buffers shaped like the input with the last
 *     dimension replaced by `k`.
 */
function topKImpl(x, xShape, xDtype, k, sorted) {
    // Reshape into a 2d tensor [batch, lastDim] and compute topk along lastDim.
    const lastDim = xShape[xShape.length - 1];
    const [batch, size] = [x.length / lastDim, lastDim];
    const allTopKVals = tf.util.getTypedArrayFromDType(xDtype, batch * k);
    const allTopKIndices = tf.util.getTypedArrayFromDType('int32', batch * k);
    for (let b = 0; b < batch; b++) {
        const offset = b * size;
        const vals = x.subarray(offset, offset + size);
        let valAndInd = new Array(vals.length);
        vals.forEach((value, index) => valAndInd[index] = { value, index });
        if (k < valAndInd.length) {
            // Move the k best entries to the front, then drop the rest.
            select$1(valAndInd, k);
            valAndInd = valAndInd.slice(0, k);
        }
        if (sorted) {
            valAndInd.sort(comparePair);
        }
        const outOffset = b * k;
        const topKVals = allTopKVals.subarray(outOffset, outOffset + k);
        const topKIndices = allTopKIndices.subarray(outOffset, outOffset + k);
        for (let i = 0; i < k; i++) {
            topKVals[i] = valAndInd[i].value;
            topKIndices[i] = valAndInd[i].index;
        }
    }
    // Reshape back to the original input shape, except that the last
    // dimension is k.
    const outputShape = xShape.slice();
    outputShape[outputShape.length - 1] = k;
    return [
        tf.buffer(outputShape, xDtype, allTopKVals),
        tf.buffer(outputShape, 'int32', allTopKIndices)
    ];
}

/**
 * Finds the unique slices of a tensor along one axis.
 *
 * @param values Flat array with the tensor values.
 * @param axis Axis along which slices are compared.
 * @param shape Shape of the tensor.
 * @param dtype Dtype of the tensor.
 * @returns `{outputValues, outputShape, indices}`: the values of the unique
 *     slices in order of first appearance, the input shape with the axis
 *     size replaced by the number of unique slices, and for every position
 *     along the axis the index of its slice among the unique ones.
 */
function uniqueImpl(values, axis, shape, dtype) {
    // Normalize and validate axis.
    const $axis = tf.util.parseAxisParam(axis, shape)[0];
    // View the data as a rank-3 buffer [outer, axisSize, inner], where
    // `outer` is the product of the dimensions before the axis and `inner`
    // the product of the dimensions after it. For example shape=[2, 3, 5, 4]
    // with axis=2 gives [2*3, 5, 4]. This is only an intermediate shape: it
    // lets the slice at position i along the axis be read as all (m, i, n)
    // with m in [0, outer) and n in [0, inner).
    //
    // E.g. for the input [[[1,2,3],[4,5,6]]] (shape [1, 2, 3]) and axis 2 the
    // view has shape [2, 3, 1] and the three slices are [1,4], [2,5], [3,6].
    const newShape = [1, shape[0], 1];
    for (let i = 0; i < $axis; i++) {
        newShape[0] *= shape[i];
    }
    newShape[1] = shape[$axis];
    for (let i = $axis + 1; i < shape.length; i++) {
        newShape[2] *= shape[i];
    }
    // A map from unique elements (their string representations) to their values
    // in "indices" (below).
    const uniqueElements = new Map();
    // The indices of each unique element in the original tensor along the given
    // axis. It is 1D and has the same size as the given axis.
    const indices = new Int32Array(shape[$axis]);
    // Create a buffer so we can easily extract value at a given location.
    const inputBuffer = new tf.TensorBuffer(newShape, dtype, values);
    // The indices along the given axis that have unique elements. This is a
    // de-duped version of "indices" above.
    const uniqueIndices = [];
    const is1DTensor = newShape[0] === 1 && newShape[2] === 1;
    // Builds the map key of a multi-value slice. A plain join(',') is
    // unambiguous for numbers, but string values may themselves contain
    // commas (['a,b','c'] and ['a','b,c'] would both give "a,b,c"), so every
    // string part is prefixed with its length.
    const sliceKey = dtype === 'string' ?
        (parts) => parts
            .map((part) => {
                const text = String(part);
                return `${text.length}:${text}`;
            })
            .join(',') :
        (parts) => parts.join(',');
    for (let i = 0; i < shape[$axis]; i++) {
        // Extract values along the axis.
        let element;
        if (is1DTensor) {
            // Fast path for 1D tensor input.
            element = values[i].toString();
        }
        else {
            const axisValues = [];
            for (let m = 0; m < newShape[0]; m++) {
                for (let n = 0; n < newShape[2]; n++) {
                    axisValues.push(inputBuffer.get(m, i, n));
                }
            }
            element = sliceKey(axisValues);
        }
        // Dedup and update various indices.
        const existingIndex = uniqueElements.get(element);
        if (existingIndex != null) {
            indices[i] = existingIndex;
        }
        else {
            const uniqueIndex = uniqueElements.size;
            uniqueElements.set(element, uniqueIndex);
            indices[i] = uniqueIndex;
            uniqueIndices.push(i);
        }
    }
    // Now we know where each of the unique elements are located along the axis
    // (uniqueIndices). Extract them from input buffer and store them in the
    // output buffer.
    const outputTmpShape = newShape.slice();
    outputTmpShape[1] = uniqueElements.size;
    const outputBuffer = new tf.TensorBuffer(outputTmpShape, dtype);
    uniqueIndices.forEach((uniqueElementIndex, i) => {
        for (let m = 0; m < newShape[0]; m++) {
            for (let n = 0; n < newShape[2]; n++) {
                outputBuffer.set(inputBuffer.get(m, uniqueElementIndex, n), m, i, n);
            }
        }
    });
    // The output shape can be calculated from the input shape with the size of
    // the given axis replaced by the number of unique elements along that axis.
    const outputShape = shape.slice();
    outputShape[$axis] = outputTmpShape[1];
    return {
        outputValues: outputBuffer.values,
        outputShape,
        indices,
    };
}
/**
 * Name -> implementation table of the shared kernel implementations defined
 * in this file. The object is created with a null prototype, so it carries
 * nothing but the entries listed here.
 */
var shared = {
    __proto__: null,
    addImpl,
    bincountImpl,
    bincountReduceImpl,
    bitwiseAndImpl,
    castImpl,
    ceilImpl,
    concatImpl: concatImpl$1,
    equalImpl,
    expImpl,
    expm1Impl,
    floorDivImpl,
    floorImpl,
    gatherNdImpl,
    gatherV2Impl,
    greaterEqualImpl,
    greaterImpl,
    lessEqualImpl,
    lessImpl,
    linSpaceImpl,
    logImpl,
    maxImpl: maxImpl$1,
    maximumImpl,
    minimumImpl,
    multiplyImpl,
    negImpl,
    notEqualImpl,
    prodImpl,
    raggedGatherImpl,
    raggedRangeImpl,
    raggedTensorToTensorImpl,
    rangeImpl,
    rsqrtImpl,
    scatterImpl,
    sigmoidImpl,
    simpleAbsImpl,
    sliceImpl,
    sparseFillEmptyRowsImpl,
    sparseReshapeImpl,
    sparseSegmentReductionImpl,
    sqrtImpl,
    squaredDifferenceImpl,
    staticRegexReplaceImpl,
    stridedSliceImpl,
    stringNGramsImpl,
    stringSplitImpl,
    stringToHashBucketFastImpl,
    subImpl,
    tileImpl,
    topKImpl,
    transposeImpl: transposeImpl$1,
    uniqueImpl
};
// Local `*CPU` aliases for the shared implementations, one per line.
const {
    addImpl: addImplCPU,
    bincountImpl: bincountImplCPU,
    bincountReduceImpl: bincountReduceImplCPU,
    bitwiseAndImpl: bitwiseAndImplCPU,
    castImpl: castImplCPU,
    ceilImpl: ceilImplCPU,
    concatImpl: concatImplCPU,
    equalImpl: equalImplCPU,
    expImpl: expImplCPU,
    expm1Impl: expm1ImplCPU,
    floorImpl: floorImplCPU,
    gatherNdImpl: gatherNdImplCPU,
    gatherV2Impl: gatherV2ImplCPU,
    greaterImpl: greaterImplCPU,
    greaterEqualImpl: greaterEqualImplCPU,
    lessImpl: lessImplCPU,
    lessEqualImpl: lessEqualImplCPU,
    linSpaceImpl: linSpaceImplCPU,
    logImpl: logImplCPU,
    maxImpl: maxImplCPU,
    maximumImpl: maximumImplCPU,
    minimumImpl: minimumImplCPU,
    multiplyImpl: multiplyImplCPU,
    negImpl: negImplCPU,
    notEqualImpl: notEqualImplCPU,
    prodImpl: prodImplCPU,
    raggedGatherImpl: raggedGatherImplCPU,
    raggedRangeImpl: raggedRangeImplCPU,
    raggedTensorToTensorImpl: raggedTensorToTensorImplCPU,
    rangeImpl: rangeImplCPU,
    rsqrtImpl: rsqrtImplCPU,
    scatterImpl: scatterImplCPU,
    sigmoidImpl: sigmoidImplCPU,
    simpleAbsImpl: simpleAbsImplCPU,
    sliceImpl: sliceImplCPU,
    sparseFillEmptyRowsImpl: sparseFillEmptyRowsImplCPU,
    sparseReshapeImpl: sparseReshapeImplCPU,
    sparseSegmentReductionImpl: sparseSegmentReductionImplCPU,
    sqrtImpl: sqrtImplCPU,
    staticRegexReplaceImpl: staticRegexReplaceImplCPU,
    stridedSliceImpl: stridedSliceImplCPU,
    stringNGramsImpl: stringNGramsImplCPU,
    stringSplitImpl: stringSplitImplCPU,
    stringToHashBucketFastImpl: stringToHashBucketFastImplCPU,
    subImpl: subImplCPU,
    tileImpl: tileImplCPU,
    topKImpl: topKImplCPU,
    transposeImpl: transposeImplCPU,
    uniqueImpl: uniqueImplCPU,
} = shared;

/**
 * Returns the first `rank` component accessors of a vector variable, e.g.
 * `getVecChannels('rc', 3)` gives `['rc.x', 'rc.y', 'rc.z']`. At most six
 * components (x, y, z, w, u, v) are available.
 */
function getVecChannels(name, rank) {
    const components = ['x', 'y', 'z', 'w', 'u', 'v'];
    const used = components.slice(0, rank);
    return used.map((component) => `${name}.${component}`);
}

/**
 * Like `getVecChannels`, except that for rank 1 the bare variable name is
 * returned as the only channel.
 */
function getChannels(name, rank) {
    return rank === 1 ? [name] : getVecChannels(name, rank);
}

/**
 * Joins the first `rank` entries of `dims` into a comma-separated coordinate
 * list; for rank 1 the result is always the literal `'rc'`.
 */
function getSourceCoords$2(rank, dims) {
    if (rank === 1) {
        return 'rc';
    }
    const parts = [];
    for (let axis = 0; axis < rank; axis++) {
        parts.push(`${dims[axis]}`);
    }
    return parts.join(',');
}
/**
 * Shader program that reads an unpacked input `A` and writes a packed output:
 * every output texel holds the values at (r, c), (r, c + 1), (r + 1, c) and
 * (r + 1, c + 1) of the two innermost dimensions, with zeros past the edges.
 */
class PackProgram {
    /**
     * @param outputShape Logical shape of the output tensor.
     */
    constructor(outputShape) {
        this.variableNames = ['A'];
        this.packedInputs = false;
        this.packedOutput = true;
        this.outputShape = outputShape;
        this.rank = outputShape.length;
        this.enableShapeUniforms = useShapeUniforms(this.outputShape.length);
        if (this.rank === 0) {
            // A scalar only fills the first channel of the texel.
            this.userCode = `
        void main() {
          setOutput(vec4(getA(), 0., 0., 0.));
        }
      `;
            return;
        }
        const channels = getChannels('rc', this.rank);
        const dtype = getCoordsDataType(this.rank);
        const outOfBoundsCondition = this.getOutOfBoundsCondition(channels);
        const setup = this.getSetup(channels);
        const output = this.getOutput(channels);
        this.userCode = `
        void main() {
          ${dtype} rc = getOutputCoords();

          if(${outOfBoundsCondition}) {
            setOutput(vec4(0));
          } else {
            ${setup}

            setOutput(vec4(${output}));
          }
        }
      `;
    }
    /**
     * Builds the four source coordinate lists, in the order (r, c),
     * (r, cp1), (rp1, c), (rp1, cp1), each prefixed with the outer dims.
     *
     * @param dims Channel accessors of the output coordinates.
     * @returns Array of four coordinate strings.
     */
    getSourceCoordsArr(dims) {
        // Outer dimensions are emitted outermost first.
        const outerDims = [];
        for (let d = this.rank - 1; d >= 2; d--) {
            outerDims.push(`${dims[dims.length - 1 - d]},`);
        }
        const prefix = outerDims.join('');
        const coords = [];
        for (const rowName of ['r', 'rp1']) {
            for (const colName of ['c', 'cp1']) {
                coords.push(`${prefix}${rowName}, ${colName}`);
            }
        }
        return coords;
    }
    /**
     * Builds the GLSL condition that is true when the output coordinate lies
     * outside the logical output shape.
     *
     * @param dims Channel accessors of the output coordinates.
     * @returns GLSL boolean expression.
     */
    getOutOfBoundsCondition(dims) {
        const useUniforms = this.enableShapeUniforms;
        if (this.rank === 1) {
            const bound = useUniforms ? 'outShape' : this.outputShape[0];
            return `rc > ${bound}`;
        }
        // Only the two innermost dimensions are checked.
        const checks = [];
        for (let i = this.rank - 2; i < this.rank; i++) {
            const bound = useUniforms ? `outShape[${i}]` : this.outputShape[i];
            checks.push(`${dims[i]} >= ${bound}`);
        }
        return checks.join('||');
    }
    /**
     * Declares the row/column helpers (`r`, `c`, `rp1`, `cp1`) and the edge
     * flags used by `getOutput`. Rank 1 needs no setup.
     *
     * @param dims Channel accessors of the output coordinates.
     * @returns GLSL statements.
     */
    getSetup(dims) {
        if (this.rank === 1) {
            return '';
        }
        const innerDims = dims.slice(-2);
        let col;
        let row;
        if (this.enableShapeUniforms) {
            col = `outShape[${this.rank} - 1]`;
            row = `outShape[${this.rank} - 2]`;
        }
        else {
            col = this.outputShape[this.rank - 1];
            row = this.outputShape[this.rank - 2];
        }
        return `
      int r = ${innerDims[0]};
      int c = ${innerDims[1]};
      int rp1 = r + 1;
      int cp1 = c + 1;

      bool cEdge = cp1 >= ${col};
      bool rEdge = rp1 >= ${row};
    `;
    }
    /**
     * Builds the four comma separated values that fill the output texel.
     *
     * @param dims Channel accessors of the output coordinates.
     * @returns Argument list for the `vec4(...)` constructor.
     */
    getOutput(dims) {
        if (this.rank === 1) {
            const outShape = this.enableShapeUniforms ? 'outShape' : this.outputShape[0];
            return `getA(rc), (rc + 1 >= ${outShape} ? 0. : getA(rc + 1)), 0, 0`;
        }
        const [topLeft, topRight, bottomLeft, bottomRight] = this.getSourceCoordsArr(dims);
        return `getA(${topLeft}),
            cEdge ? 0. : getA(${topRight}),
            rEdge ? 0. : getA(${bottomLeft}),
            rEdge || cEdge ? 0. : getA(${bottomRight})`;
    }
}
/**
 * Shader program that reshapes a packed 3D input `A` into a packed 3D output.
 * Each of the four channels of an output texel is resolved through the flat
 * index of its output coordinate.
 */
class ReshapePackedProgram {
    /**
     * @param outputShape Shape of the output (indexed as [batch, rows, cols]).
     * @param inputShape Shape of the input, forwarded to
     *     `getReshapedInputCoords`.
     */
    constructor(outputShape, inputShape) {
        this.variableNames = ['A'];
        this.packedInputs = true;
        this.packedOutput = true;
        this.customUniforms = [{ name: 'inputShape', type: 'ivec3' }];
        this.outputShape = outputShape;
        this.enableShapeUniforms = useShapeUniforms(this.outputShape.length);
        // One snippet per texel channel; channel 0 is always in bounds, the
        // other three are guarded against running past the output edges.
        const channelSnippets = [0, 1, 2, 3].map((i) => {
            const steps = ['thisRC = rc;'];
            if (i % 2 === 1) {
                steps.push('thisRC.z += 1;');
            }
            if (i > 1) {
                steps.push('thisRC.y += 1;');
            }
            const thisRC = steps.join('');
            const guarded = i > 0;
            return `
        ${thisRC}
        ${guarded ? 'if(thisRC.y < rows && thisRC.z < cols){' : ''}
          int flatIndex = getFlatIndex(thisRC);

          ivec3 inputRC = inputCoordsFromReshapedOutCoords(flatIndex);
          vec2 inputRCInnerDims = vec2(float(inputRC.y),float(inputRC.z));

          result[${i}] =
            getChannel(getA(inputRC.x, inputRC.y, inputRC.z), inputRCInnerDims);
        ${guarded ? '}' : ''}
      `;
        });
        const mainLoop = channelSnippets.join('');
        const inputCoordsSnippet = getReshapedInputCoords(inputShape, this.enableShapeUniforms);
        const flatIndexSnippet = this.enableShapeUniforms ?
            getFlatIndexFrom3DOutput() :
            getFlatIndexFrom3D(outputShape);
        const rows = this.enableShapeUniforms ? 'outShape[1]' : outputShape[1];
        const cols = this.enableShapeUniforms ? 'outShape[2]' : outputShape[2];
        this.userCode = `
      ${inputCoordsSnippet}
      ${flatIndexSnippet}

      void main() {
        ivec3 rc = getOutputCoords();

        vec4 result = vec4(0.);

        ivec3 thisRC;
        int rows = ${rows};
        int cols = ${cols};

        ${mainLoop}

        setOutput(result);
      }
    `;
    }
}
/**
 * Emits the GLSL helper `inputCoordsFromReshapedOutCoords`, which converts a
 * flat index into (r, c, d) input coordinates.
 *
 * @param shape Input shape, used when shape uniforms are disabled.
 * @param enableShapeUniforms Whether to read the shape from the `inputShape`
 *     uniform instead of baking it into the shader.
 * @returns GLSL function source.
 */
function getReshapedInputCoords(shape, enableShapeUniforms) {
    const coordNames = ['r', 'c', 'd'];
    let coordsFromIndexSnippet;
    if (enableShapeUniforms) {
        coordsFromIndexSnippet =
            getLogicalCoordinatesFromFlatIndexByUniform(coordNames, 'inputShape');
    }
    else {
        coordsFromIndexSnippet = getLogicalCoordinatesFromFlatIndex(coordNames, shape);
    }
    return `
    ivec3 inputCoordsFromReshapedOutCoords(int index) {
      ${coordsFromIndexSnippet}
      return ivec3(r, c, d);
    }
  `;
}
/**
 * Pools textures by (shape, physical texture type, packedness) so that
 * released textures can be handed out again instead of being re-created, and
 * keeps running counts of used/free textures and allocated/free bytes.
 */
class TextureManager {
    /**
     * @param gpgpu Context used to create and delete the underlying textures.
     */
    constructor(gpgpu) {
        this.gpgpu = gpgpu;
        this.numUsedTextures = 0;
        this.numFreeTextures = 0;
        this._numBytesAllocated = 0;
        // Number of bytes that have been allocated and available for reuse.
        this._numBytesFree = 0;
        this.freeTextures = {};
        this.usedTextures = {};
        this.logEnabled = false;
    }
    /**
     * Returns a texture for the given shape and usage, reusing a free one with
     * the same key when available and creating a new one otherwise.
     *
     * @param shapeRC [rows, columns] of the texture.
     * @param usage Logical texture usage (a `TextureUsage` value).
     * @param isPacked Whether the texture stores packed data.
     * @returns The acquired texture.
     */
    acquireTexture(shapeRC, usage, isPacked) {
        const physicalTexType = getPhysicalFromLogicalTextureType(usage, isPacked);
        const shapeKey = getKeyFromTextureShape(shapeRC, physicalTexType, isPacked);
        if (!(shapeKey in this.freeTextures)) {
            this.freeTextures[shapeKey] = [];
        }
        if (!(shapeKey in this.usedTextures)) {
            this.usedTextures[shapeKey] = [];
        }
        const texBytes = computeBytes(shapeRC, physicalTexType, this.gpgpu.gl, this.gpgpu.textureConfig, isPacked);
        if (this.freeTextures[shapeKey].length > 0) {
            // Reuse path: move one texture from the free list to the used list.
            this.numFreeTextures--;
            this.numUsedTextures++;
            this._numBytesFree -= texBytes;
            this.log();
            const newTexture = this.freeTextures[shapeKey].pop();
            this.usedTextures[shapeKey].push(newTexture);
            return newTexture;
        }
        let newTexture;
        if (physicalTexType === PhysicalTextureType.PACKED_2X2_FLOAT32) {
            newTexture = this.gpgpu.createPackedMatrixTexture(shapeRC[0], shapeRC[1]);
        }
        else if (physicalTexType === PhysicalTextureType.PACKED_2X2_FLOAT16) {
            newTexture =
                this.gpgpu.createFloat16PackedMatrixTexture(shapeRC[0], shapeRC[1]);
        }
        else if (physicalTexType === PhysicalTextureType.UNPACKED_FLOAT32) {
            newTexture =
                this.gpgpu.createFloat32MatrixTexture(shapeRC[0], shapeRC[1]);
        }
        else if (physicalTexType === PhysicalTextureType.UNPACKED_FLOAT16) {
            newTexture =
                this.gpgpu.createFloat16MatrixTexture(shapeRC[0], shapeRC[1]);
        }
        else if (physicalTexType === PhysicalTextureType.PACKED_4X1_UNSIGNED_BYTE) {
            newTexture =
                this.gpgpu.createUnsignedBytesMatrixTexture(shapeRC[0], shapeRC[1]);
        }
        this.usedTextures[shapeKey].push(newTexture);
        this.numUsedTextures++;
        this._numBytesAllocated += texBytes;
        this.log();
        return newTexture;
    }
    /**
     * Returns a texture to the pool, or deletes it when the
     * `WEBGL_DELETE_TEXTURE_THRESHOLD` flag is not -1 and the allocated bytes
     * exceed it. Does nothing once the manager has been disposed.
     *
     * Ownership is validated before any bookkeeping is touched, so a failed
     * release leaves the counters and the free list exactly as they were.
     *
     * @param texture Texture previously returned by `acquireTexture`.
     * @param shape [rows, columns] the texture was acquired with.
     * @param logicalTexType Logical usage the texture was acquired with.
     * @param isPacked Packedness the texture was acquired with.
     * @throws Error if the texture is not in this manager's used list for the
     *     given key.
     */
    releaseTexture(texture, shape, logicalTexType, isPacked) {
        if (this.freeTextures == null) {
            // Already disposed.
            return;
        }
        const physicalTexType = getPhysicalFromLogicalTextureType(logicalTexType, isPacked);
        const shapeKey = getKeyFromTextureShape(shape, physicalTexType, isPacked);
        // Check ownership first: mutating state and then throwing would leave
        // a foreign texture in the free list (or deleted) and skew the counts.
        const texList = this.usedTextures[shapeKey];
        const texIndex = texList == null ? -1 : texList.indexOf(texture);
        if (texIndex < 0) {
            throw new Error('Cannot release a texture that was never provided by this ' +
                'texture manager');
        }
        if (!(shapeKey in this.freeTextures)) {
            this.freeTextures[shapeKey] = [];
        }
        const texBytes = computeBytes(shape, physicalTexType, this.gpgpu.gl, this.gpgpu.textureConfig, isPacked);
        const deleteTexThreshold = tf.env()
            .getNumber('WEBGL_DELETE_TEXTURE_THRESHOLD');
        if (deleteTexThreshold !== -1 &&
            this._numBytesAllocated > deleteTexThreshold) {
            this.gpgpu.deleteMatrixTexture(texture.texture);
            this._numBytesAllocated -= texBytes;
        }
        else {
            this.freeTextures[shapeKey].push(texture);
            this.numFreeTextures++;
            this._numBytesFree += texBytes;
        }
        this.numUsedTextures--;
        // Swap-remove: order of the used list does not matter.
        texList[texIndex] = texList[texList.length - 1];
        texList.pop();
        this.log();
    }
    /** Prints texture and byte statistics when `logEnabled` is set. */
    log() {
        if (!this.logEnabled) {
            return;
        }
        const total = this.numFreeTextures + this.numUsedTextures;
        console.log('Free/Used', `${this.numFreeTextures} / ${this.numUsedTextures}`, `(${total})`);
        // Avoid printing "NaN%" when nothing is currently allocated.
        const freeRatio = this._numBytesAllocated === 0 ?
            0 :
            this._numBytesFree / this._numBytesAllocated;
        console.log(`Bytes allocated: ${this._numBytesAllocated}`);
        console.log(`Bytes unused: ${this._numBytesFree} (${Math.round(100 * freeRatio)}%)`);
    }
    /** Total bytes of textures created by this manager and not yet deleted. */
    get numBytesAllocated() {
        return this._numBytesAllocated;
    }
    /** Bytes held by textures currently sitting in the free list. */
    get numBytesFree() {
        return this._numBytesFree;
    }
    /** Number of textures currently handed out. */
    getNumUsedTextures() {
        return this.numUsedTextures;
    }
    /** Number of textures currently available for reuse. */
    getNumFreeTextures() {
        return this.numFreeTextures;
    }
    /**
     * Deletes every free and used texture and resets all counters. Safe to
     * call more than once.
     */
    dispose() {
        if (this.freeTextures == null) {
            // Already disposed.
            return;
        }
        for (const texShape in this.freeTextures) {
            this.freeTextures[texShape].forEach(tex => {
                this.gpgpu.deleteMatrixTexture(tex.texture);
            });
        }
        for (const texShape in this.usedTextures) {
            this.usedTextures[texShape].forEach(tex => {
                this.gpgpu.deleteMatrixTexture(tex.texture);
            });
        }
        // TODO: Assign non-null value (empty object) to textures after disposed.
        this.freeTextures = null;
        this.usedTextures = null;
        this.numUsedTextures = 0;
        this.numFreeTextures = 0;
        this._numBytesAllocated = 0;
        this._numBytesFree = 0;
    }
}
/**
 * Maps a GL internal format to the number of bytes this file accounts per
 * texel for it.
 *
 * @param gl GL context providing the format constants.
 * @param internalFormat Internal format to look up.
 * @returns Bytes per texel.
 * @throws Error for an internal format that is not listed.
 */
function numBytesForInternalFormat(gl, internalFormat) {
    // tslint:disable-next-line:no-any
    const glany = gl;
    if (internalFormat === glany.R32F) {
        return 4;
    }
    else if (internalFormat === glany.R16F) {
        return 2;
    }
    else if (internalFormat === glany.RGBA32F) {
        return 16;
    }
    else if (internalFormat === gl.RGBA) {
        return 16;
    }
    else if (internalFormat === glany.RGBA16F) {
        return 8;
    }
    else if (internalFormat === glany.RGBA8) {
        return 4;
    }
    throw new Error(`Unknown internal format ${internalFormat}`);
}
/**
 * Computes the number of bytes accounted for a texture of the given logical
 * shape and physical type.
 *
 * @param shape [rows, columns] of the texture.
 * @param physicalTexType A `PhysicalTextureType` value.
 * @param gl GL context providing the format constants.
 * @param textureConfig Texture configuration used to resolve the format.
 * @param isPacked Whether the texture is packed.
 * @returns Size in bytes.
 */
function computeBytes(shape, physicalTexType, gl, textureConfig, isPacked) {
    // It is not possible to infer packed status from the texture type because
    // depending on the textureConfig, different texture types may resolve to the
    // same internal format (e.g. in WebGL1, the internal format for
    // UNPACKED_FLOAT16 textures is gl.RGBA). Therefore we pass in `isPacked`
    // explicitly.
    const internalFormat = internalFormatForPhysicalTexType(physicalTexType, textureConfig);
    let numElements;
    if (isPacked) {
        const [packedWidth, packedHeight] = getPackedMatrixTextureShapeWidthHeight(shape[0], shape[1]);
        numElements = packedWidth * packedHeight;
    }
    else {
        const [width, height] = getUnpackedMatrixTextureShapeWidthHeight(shape[0], shape[1]);
        numElements = width * height;
    }
    const bytesPerElement = numBytesForInternalFormat(gl, internalFormat);
    return numElements * bytesPerElement;
}
/**
 * Resolves the GL internal format used for a physical texture type.
 *
 * @param physicalTexType A `PhysicalTextureType` value.
 * @param textureConfig Texture configuration passed to the format getters.
 * @returns The internal format.
 * @throws Error for an unknown physical texture type.
 */
function internalFormatForPhysicalTexType(physicalTexType, textureConfig) {
    switch (physicalTexType) {
        case PhysicalTextureType.PACKED_2X2_FLOAT32:
            return getInternalFormatForPackedMatrixTexture(textureConfig);
        case PhysicalTextureType.PACKED_2X2_FLOAT16:
            return getInternalFormatForFloat16PackedMatrixTexture(textureConfig);
        case PhysicalTextureType.UNPACKED_FLOAT32:
            return getInternalFormatForFloat32MatrixTexture(textureConfig);
        case PhysicalTextureType.UNPACKED_FLOAT16:
            return getInternalFormatForFloat16MatrixTexture(textureConfig);
        case PhysicalTextureType.PACKED_4X1_UNSIGNED_BYTE:
            return getInternalFormatForUnsignedBytesMatrixTexture(textureConfig);
        default:
            throw new Error(`Unknown physical texture type ${physicalTexType}`);
    }
}
/**
 * Picks the physical texture type for render targets: float32 variants when
 * the `WEBGL_RENDER_FLOAT32_ENABLED` flag is set, float16 variants otherwise.
 *
 * @param isPacked Whether the texture is packed.
 * @returns A `PhysicalTextureType` value.
 */
function getPhysicalTextureForRendering(isPacked) {
    if (tf.env().getBool('WEBGL_RENDER_FLOAT32_ENABLED')) {
        if (isPacked) {
            return PhysicalTextureType.PACKED_2X2_FLOAT32;
        }
        return PhysicalTextureType.UNPACKED_FLOAT32;
    }
    if (isPacked) {
        return PhysicalTextureType.PACKED_2X2_FLOAT16;
    }
    return PhysicalTextureType.UNPACKED_FLOAT16;
}
/**
 * Maps a logical texture usage to the physical texture type backing it. A
 * missing usage (null/undefined) is treated like `TextureUsage.RENDER`.
 *
 * @param logicalTexType A `TextureUsage` value, or null/undefined.
 * @param isPacked Whether the texture is packed (only used for rendering).
 * @returns A `PhysicalTextureType` value.
 * @throws Error for an unknown logical texture type.
 */
function getPhysicalFromLogicalTextureType(logicalTexType, isPacked) {
    if (logicalTexType === TextureUsage.UPLOAD) {
        return PhysicalTextureType.PACKED_2X2_FLOAT32;
    }
    else if (logicalTexType === TextureUsage.RENDER || logicalTexType == null) {
        return getPhysicalTextureForRendering(isPacked);
    }
    else if (logicalTexType === TextureUsage.DOWNLOAD ||
        logicalTexType === TextureUsage.PIXELS) {
        return PhysicalTextureType.PACKED_4X1_UNSIGNED_BYTE;
    }
    throw new Error(`Unknown logical texture type ${logicalTexType}`);
}
/**
 * Builds the pool key `<rows>_<cols>_<physicalTexType>_<isPacked>`.
 *
 * @param shapeRowsCol [rows, columns] of the texture.
 * @param physicalTexType A `PhysicalTextureType` value.
 * @param isPacked Whether the texture is packed.
 * @returns The key string.
 */
function getKeyFromTextureShape(shapeRowsCol, physicalTexType, isPacked) {
    return `${shapeRowsCol[0]}_${shapeRowsCol[1]}_${physicalTexType}_${isPacked}`;
}
/**
 * Shader program applying a scalar GLSL snippet element-wise to input `A`.
 * The snippet becomes the body of `float unaryOperation(float x)`.
 */
class UnaryOpProgram {
    /**
     * @param aShape Shape of the input, also used as the output shape.
     * @param opSnippet GLSL statements forming the body of `unaryOperation`.
     */
    constructor(aShape, opSnippet) {
        this.variableNames = ['A'];
        this.outputShape = aShape;
        this.enableShapeUniforms = useShapeUniforms(this.outputShape.length);
        this.userCode = `
      float unaryOperation(float x) {
        ${opSnippet}
      }

      void main() {
        float x = getAAtOutCoords();
        float y = unaryOperation(x);

        setOutput(y);
      }
    `;
    }
}
// Scalar GLSL snippets usable as `opSnippet` of UnaryOpProgram.
const CHECK_NAN_SNIPPET$1 = `if (isnan(x)) return x;`;
const LINEAR$1 = `return x;`;
const ABS$1 = `return abs(x);`;
const ELU$2 = `return (x >= 0.0) ? x : (exp(x) - 1.0);`;
const RELU$2 = CHECK_NAN_SNIPPET$1 + `
  return (x < 0.0) ? 0.0 : x;
`;
const RELU6$2 = CHECK_NAN_SNIPPET$1 + `
  return (x < 0.0) ? 0.0 : min(6.0, x);
`;
const CLONE = 'return x;';
const SIGMOID$2 = `return 1.0 / (1.0 + exp(-1.0 * x));`;
// Packed GLSL snippets usable as `opSnippet` of UnaryOpPackedProgram. Each is
// the body of `vec4 unaryOperation(vec4 x)` and therefore operates on all
// four channels of a texel.
const LINEAR = `return x;`;
// ELU evaluated channel by channel.
const ELU$1 = `
  vec4 result;

  result.r = (x.r >= 0.0) ? x.r : (exp(x.r) - 1.0);
  result.g = (x.g >= 0.0) ? x.g : (exp(x.g) - 1.0);
  result.b = (x.b >= 0.0) ? x.b : (exp(x.b) - 1.0);
  result.a = (x.a >= 0.0) ? x.a : (exp(x.a) - 1.0);

  return result;
`;
// ReLU; channels whose input is NaN are copied through from the input.
const RELU$1 = `
  vec4 result = x * vec4(greaterThanEqual(x, vec4(0.0)));
  bvec4 isNaN = isnan(x);

  result.r = isNaN.r ? x.r : result.r;
  result.g = isNaN.g ? x.g : result.g;
  result.b = isNaN.b ? x.b : result.b;
  result.a = isNaN.a ? x.a : result.a;

  return result;
`;
// ReLU6 (ReLU clamped at 6); channels whose input is NaN are copied through.
const RELU6$1 = `
  vec4 result = min(x, vec4(6.)) * vec4(greaterThanEqual(x, vec4(0.0)));
  bvec4 isNaN = isnan(x);

  result.r = isNaN.r ? x.r : result.r;
  result.g = isNaN.g ? x.g : result.g;
  result.b = isNaN.b ? x.b : result.b;
  result.a = isNaN.a ? x.a : result.a;

  return result;
`;
const SIGMOID$1 = `return 1.0 / (1.0 + exp(-1.0 * x));`;
/**
 * Shader program applying a packed GLSL snippet to input `A`, one vec4 texel
 * at a time. Both input and output are flagged as packed.
 */
class UnaryOpPackedProgram {
    /**
     * @param aShape Shape of the input, also used as the output shape.
     * @param opSnippet GLSL statements forming the body of `unaryOperation`.
     */
    constructor(aShape, opSnippet) {
        this.variableNames = ['A'];
        this.packedInputs = true;
        this.packedOutput = true;
        this.outputShape = aShape;
        this.enableShapeUniforms = useShapeUniforms(this.outputShape.length);
        this.userCode = `
      vec4 unaryOperation(vec4 x) {
        ${opSnippet}
      }

      void main() {
        vec4 x = getAAtOutCoords();
        vec4 y = unaryOperation(x);

        setOutput(y);
      }
    `;
    }
}
/**
 * Shader program that reads a packed input `A` and writes an unpacked output
 * of the same logical shape, selecting one channel of the packed texel per
 * output element.
 */
class UnpackProgram {
    /**
     * @param outputShape Logical shape of the output tensor.
     */
    constructor(outputShape) {
        this.variableNames = ['A'];
        this.packedInputs = true;
        this.packedOutput = false;
        this.outputShape = outputShape;
        this.enableShapeUniforms = useShapeUniforms(this.outputShape.length);
        const numDims = outputShape.length;
        const rcChannels = getChannels('rc', numDims);
        const coordsType = getCoordsDataType(numDims);
        const packedCoords = getSourceCoords$2(numDims, rcChannels);
        const lastTwoChannels = rcChannels.slice(-2);
        // getChannel takes the two innermost coordinates; ranks 0 and 1 pass
        // the coordinate variable through as is.
        let channelSelector;
        if (numDims <= 1) {
            channelSelector = 'rc';
        }
        else {
            channelSelector = `vec2(${lastTwoChannels.join(',')})`;
        }
        this.userCode = `
      void main() {
        ${coordsType} rc = getOutputCoords();
        vec4 packedInput = getA(${packedCoords});

        setOutput(getChannel(packedInput, ${channelSelector}));
      }
    `;
    }
}
const whereImpl = tf.kernel_impls.whereImpl;
const EPSILON_FLOAT32 = 1e-7;
const EPSILON_FLOAT16 = 1e-4;
// Program caches, keyed by WebGL version.
const binaryCaches = {};
/**
 * Returns the cache object for the given WebGL version, creating an empty one
 * on first use. Repeated calls with the same version return the same object.
 *
 * @param webGLVersion WebGL version used as the cache key.
 * @returns The cache object for that version.
 */
function getBinaryCache(webGLVersion) {
    if (webGLVersion in binaryCaches) {
        return binaryCaches[webGLVersion];
    }
    binaryCaches[webGLVersion] = {};
    return binaryCaches[webGLVersion];
}
// Empirically determined constant used to determine size threshold for handing
// off execution to the CPU.
const CPU_HANDOFF_SIZE_THRESHOLD = tf.env().getNumber('CPU_HANDOFF_SIZE_THRESHOLD');
// Empirically determined constant used to decide the number of MB on GPU
// before we warn about high memory use. The MB are this constant * screen area
// * dpi / 1024 / 1024.
const BEFORE_PAGING_CONSTANT = 600;
/**
 * Computes the number of MB of GPU memory after which a high-memory warning
 * is issued: 1024 when the global namespace has no `screen`, otherwise
 * screen area * device pixel ratio * BEFORE_PAGING_CONSTANT / 1024 / 1024.
 *
 * The device pixel ratio is read from the same global namespace as `screen`
 * rather than from a bare `window`, so the function does not throw where
 * `window` is undefined, and a missing ratio falls back to 1 instead of
 * producing NaN.
 *
 * @returns Warning threshold in MB.
 */
function numMBBeforeWarning() {
    const globalNamespace = tf.env().global;
    const screen = globalNamespace.screen;
    if (screen == null) {
        return 1024; // 1 GB.
    }
    const devicePixelRatio = globalNamespace.devicePixelRatio != null ?
        globalNamespace.devicePixelRatio :
        1;
    return (screen.height * screen.width * devicePixelRatio) *
        BEFORE_PAGING_CONSTANT / 1024 / 1024;
}
this.lastGlFlushTime = 0; this.warnedAboutMemory = false; this.pendingDeletes = 0; this.disposed = false; if (!tf.env().getBool('HAS_WEBGL')) { throw new Error('WebGL is not supported on this device'); } let newGPGPU; if (gpuResource != null) { if (gpuResource instanceof GPGPUContext) { newGPGPU = gpuResource; } else { const gl = getWebGLContext(tf.env().getNumber('WEBGL_VERSION'), gpuResource); newGPGPU = new GPGPUContext(gl); } this.binaryCache = {}; this.gpgpuCreatedLocally = false; } else { const gl = getWebGLContext(tf.env().getNumber('WEBGL_VERSION')); newGPGPU = new GPGPUContext(gl); this.binaryCache = getBinaryCache(tf.env().getNumber('WEBGL_VERSION')); this.gpgpuCreatedLocally = true; } this.gpgpu = newGPGPU; this.canvas = this.gpgpu.gl.canvas; this.textureManager = new TextureManager(this.gpgpu); this.numMBBeforeWarning = numMBBeforeWarning(); this.texData = new tf.DataStorage(this, tf.engine()); } numDataIds() { return this.texData.numDataIds() - this.pendingDeletes; } // Writes a new entry to the data store with a WebGL texture, and registers it // to the texture manager. writeTexture(texture, shape, dtype, texHeight, texWidth, channels) { // Temporarily create an tensor info to make the texture compatible with // the runWebGLProgram's input. const input = this.makeTensorInfo(shape, dtype); const inData = this.texData.get(input.dataId); // Even though the input texture could be unpacked or dense packed, it is // always considered as unpacked for EncodeMatrixProgram. inData.isPacked = false; // Bind texture to the input tensor. inData.texture = { texture, texShape: [texHeight, texWidth] }; inData.texShape = [texHeight, texWidth]; const shapeAs3D = getShapeAs3D(shape); const program = new EncodeMatrixProgram(shapeAs3D, false /* isByteArray */, channels); const output = this.runWebGLProgram(program, [input], dtype, [[texHeight, texWidth]]); output.shape = shape; // Unbind the texture from the input tensor to avoid the texture being // released. 
inData.texture = null; this.disposeIntermediateTensorInfo(input); return output.dataId; } write(values, shape, dtype) { if (tf.env().getBool('WEBGL_CHECK_NUMERICAL_PROBLEMS') || tf.env().getBool('DEBUG')) { this.checkNumericalProblems(values); } if (dtype === 'complex64' && values != null) { throw new Error(`Cannot write to a complex64 dtype. ` + `Please use tf.complex(real, imag).`); } const dataId = { id: this.nextDataId() }; this.texData.set(dataId, { shape, dtype, values, usage: TextureUsage.UPLOAD, refCount: 1 }); return dataId; } /** Return refCount of a `TensorData`. */ refCount(dataId) { if (this.texData.has(dataId)) { const tensorData = this.texData.get(dataId); return tensorData.refCount; } return 0; } /** Increase refCount of a `TextureData`. */ incRef(dataId) { const texData = this.texData.get(dataId); texData.refCount++; } /** Decrease refCount of a `TextureData`. */ decRef(dataId) { if (this.texData.has(dataId)) { const texData = this.texData.get(dataId); texData.refCount--; } } move(dataId, values, shape, dtype, refCount) { if (tf.env().getBool('DEBUG')) { this.checkNumericalProblems(values); } if (dtype === 'complex64') { throw new Error(`Cannot write to a complex64 dtype. ` + `Please use tf.complex(real, imag).`); } this.texData.set(dataId, { shape, dtype, values, usage: TextureUsage.UPLOAD, refCount }); } disposeIntermediateTensorInfo(tensorInfo) { this.disposeData(tensorInfo.dataId); } readSync(dataId) { const texData = this.texData.get(dataId); const { values, dtype, complexTensorInfos, slice, shape, isPacked } = texData; // The presence of `slice` indicates this tensor is a shallow slice of a // different tensor, and is using that original tensor's texture. Run // `clone` in order to copy that texture and read from it. 
if (slice != null) { let program; if (isPacked) { program = new UnaryOpPackedProgram(shape, CLONE); } else { program = new UnaryOpProgram(shape, CLONE); } const res = this.runWebGLProgram(program, [{ dataId, shape, dtype }], dtype); const data = this.readSync(res.dataId); this.disposeIntermediateTensorInfo(res); return data; } if (values != null) { return this.convertAndCacheOnCPU(dataId); } if (dtype === 'string') { return values; } const shouldTimeProgram = this.activeTimers != null; let start; if (shouldTimeProgram) { start = tf.util.now(); } let result; if (dtype === 'complex64') { const realValues = this.readSync(complexTensorInfos.real.dataId); const imagValues = this.readSync(complexTensorInfos.imag.dataId); result = tf.backend_util.mergeRealAndImagArrays(realValues, imagValues); } else { result = this.getValuesFromTexture(dataId); } if (shouldTimeProgram) { this.downloadWaitMs += tf.util.now() - start; } return this.convertAndCacheOnCPU(dataId, result); } async read(dataId) { if (this.pendingRead.has(dataId)) { const subscribers = this.pendingRead.get(dataId); return new Promise(resolve => subscribers.push(resolve)); } const texData = this.texData.get(dataId); const { values, shape, slice, dtype, complexTensorInfos, isPacked } = texData; // The presence of `slice` indicates this tensor is a shallow slice of a // different tensor, and is using that original tensor's texture. Run // `clone` in order to copy that texture and read from it. if (slice != null) { let program; if (isPacked) { program = new UnaryOpPackedProgram(shape, CLONE); } else { program = new UnaryOpProgram(shape, CLONE); } const res = this.runWebGLProgram(program, [{ dataId, shape, dtype }], dtype); const data = this.read(res.dataId); this.disposeIntermediateTensorInfo(res); return data; } if (values != null) { return this.convertAndCacheOnCPU(dataId); } if (tf.env().getBool('DEBUG')) { // getBool('WEBGL_DOWNLOAD_FLOAT_ENABLED') caused a blocking GPU call. 
// For performance reason, only check it for debugging. In production, // it doesn't handle this use case anyway, so behavior is not changed. if (!tf.env().getBool('WEBGL_DOWNLOAD_FLOAT_ENABLED') && tf.env().getNumber('WEBGL_VERSION') === 2) { throw new Error(`tensor.data() with WEBGL_DOWNLOAD_FLOAT_ENABLED=false and ` + `WEBGL_VERSION=2 not yet supported.`); } } let buffer = null; let tmpDownloadTarget; if (dtype !== 'complex64' && tf.env().get('WEBGL_BUFFER_SUPPORTED')) { // Possibly copy the texture into a buffer before inserting a fence. tmpDownloadTarget = this.decode(dataId); const tmpData = this.texData.get(tmpDownloadTarget.dataId); buffer = this.gpgpu.createBufferFromTexture(tmpData.texture.texture, ...getDenseTexShape(shape)); } this.pendingRead.set(dataId, []); if (dtype !== 'complex64') { // Create a fence and wait for it to resolve. await this.gpgpu.createAndWaitForFence(); } // Download the values from the GPU. let vals; if (dtype === 'complex64') { const ps = await Promise.all([ this.read(complexTensorInfos.real.dataId), this.read(complexTensorInfos.imag.dataId) ]); const realValues = ps[0]; const imagValues = ps[1]; vals = tf.backend_util.mergeRealAndImagArrays(realValues, imagValues); } else if (buffer == null) { vals = this.getValuesFromTexture(dataId); } else { const size = tf.util.sizeFromShape(shape); vals = this.gpgpu.downloadFloat32MatrixFromBuffer(buffer, size); } if (tmpDownloadTarget != null) { this.disposeIntermediateTensorInfo(tmpDownloadTarget); } if (buffer != null) { const gl = this.gpgpu.gl; callAndCheck(gl, () => gl.deleteBuffer(buffer)); } const dTypeVals = this.convertAndCacheOnCPU(dataId, vals); const subscribers = this.pendingRead.get(dataId); this.pendingRead.delete(dataId); // Notify all pending reads. 
subscribers.forEach(resolve => resolve(dTypeVals)); if (this.pendingDisposal.has(dataId)) { this.pendingDisposal.delete(dataId); if (this.disposeData(dataId)) { tf.engine().removeDataId(dataId, this); } this.pendingDeletes--; } return dTypeVals; } /** * Read tensor to a new texture that is densely packed for ease of use. * @param dataId The source tensor. * @param options * customTexShape: Optional. If set, will use the user defined texture * shape to create the texture. */ readToGPU(dataId, options = {}) { const texData = this.texData.get(dataId); const { values, shape, slice, dtype, isPacked, texture } = texData; if (dtype === 'complex64') { throw new Error('Does not support reading texture for complex64 dtype.'); } // The presence of `slice` indicates this tensor is a shallow slice of a // different tensor, and is using that original tensor's texture. Run // `clone` in order to copy that texture and read from it. if (slice != null) { let program; if (isPacked) { program = new UnaryOpPackedProgram(shape, CLONE); } else { program = new UnaryOpProgram(shape, CLONE); } const res = this.runWebGLProgram(program, [{ dataId, shape, dtype }], dtype); const gpuResouorce = this.readToGPU(res, options); this.disposeIntermediateTensorInfo(res); return gpuResouorce; } if (texture == null) { if (values != null) { throw new Error('Data is not on GPU but on CPU.'); } else { throw new Error('There is no data on GPU or CPU.'); } } // Decode the texture so that it is stored densely (using four channels). const tmpTarget = this.decode(dataId, options.customTexShape); // Make engine track this tensor, so that we can dispose it later. const tensorRef = tf.engine().makeTensorFromTensorInfo(tmpTarget); const tmpData = this.texData.get(tmpTarget.dataId); return Object.assign({ tensorRef }, tmpData.texture); } bufferSync(t) { const data = this.readSync(t.dataId); if (t.dtype === 'string') { try { // Decode the bytes into string. 
const strings = data.map(d => tf.util.decodeString(d)); return tf.buffer(t.shape, t.dtype, strings); } catch (_a) { throw new Error('Failed to decode encoded string bytes into utf-8'); } } return tf.buffer(t.shape, t.dtype, data); } checkNumericalProblems(values) { if (values == null) { return; } for (let i = 0; i < values.length; i++) { const num = values[i]; if (!canBeRepresented(num)) { if (tf.env().getBool('WEBGL_RENDER_FLOAT32_CAPABLE')) { throw Error(`The value ${num} cannot be represented with your ` + `current settings. Consider enabling float32 rendering: ` + `'tf.env().set('WEBGL_RENDER_FLOAT32_ENABLED', true);'`); } throw Error(`The value ${num} cannot be represented on this device.`); } } } getValuesFromTexture(dataId) { const { shape, dtype, isPacked } = this.texData.get(dataId); const size = tf.util.sizeFromShape(shape); if (tf.env().getBool('WEBGL_DOWNLOAD_FLOAT_ENABLED')) { const tmpTarget = this.decode(dataId); const tmpData = this.texData.get(tmpTarget.dataId); const vals = this.gpgpu .downloadMatrixFromPackedTexture(tmpData.texture.texture, ...getDenseTexShape(shape)) .subarray(0, size); this.disposeIntermediateTensorInfo(tmpTarget); return vals; } const shouldUsePackedProgram = tf.env().getBool('WEBGL_PACK') && isPacked === true; const outputShape = shouldUsePackedProgram ? getShapeAs3D(shape) : shape; const program = shouldUsePackedProgram ? 
new EncodeFloatPackedProgram(outputShape) :
    new EncodeFloatProgram(outputShape);
  const output = this.runWebGLProgram(program, [{ shape: outputShape, dtype, dataId }], 'float32');
  const tmpData = this.texData.get(output.dataId);
  // Download the byte-encoded result and keep only the first `size` entries.
  const vals = this.gpgpu
    .downloadByteEncodedFloatMatrixFromOutputTexture(tmpData.texture.texture, tmpData.texShape[0], tmpData.texShape[1])
    .subarray(0, size);
  this.disposeIntermediateTensorInfo(output);
  return vals;
}
/**
 * Returns true when the WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_RELIABLE
 * environment number is greater than zero.
 */
timerAvailable() {
  return tf.env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_RELIABLE') > 0;
}
/**
 * Runs `f` synchronously while collecting the timer entries pushed onto
 * `this.activeTimers`, then returns a promise for the timing summary.
 * @param f Function to time; it is invoked with no arguments.
 * @returns Promise of an object with `uploadWaitMs`, `downloadWaitMs`,
 *     `kernelMs` and `wallMs` (left null here).
 */
time(f) {
  // Remember the enclosing timer list so it can be restored after `f` runs.
  const oldActiveTimers = this.activeTimers;
  const newActiveTimers = [];
  let outerMostTime = false;
  if (this.programTimersStack == null) {
    // No timing in progress: this call is the outermost one.
    this.programTimersStack = newActiveTimers;
    outerMostTime = true;
  }
  else {
    // Nested call: record the new list inside the enclosing one.
    this.activeTimers.push(newActiveTimers);
  }
  this.activeTimers = newActiveTimers;
  f();
  // needing to split these up because util.flatten only accepts certain types
  const flattenedActiveTimerQueries = tf.util.flatten(this.activeTimers.map((d) => d.query))
    .filter(d => d != null);
  const flattenedActiveTimerNames = tf.util.flatten(this.activeTimers.map((d) => d.name))
    .filter(d => d != null);
  this.activeTimers = oldActiveTimers;
  if (outerMostTime) {
    this.programTimersStack = null;
  }
  const res = {
    uploadWaitMs: this.uploadWaitMs,
    downloadWaitMs: this.downloadWaitMs,
    kernelMs: null,
    wallMs: null // will be filled by the engine
  };
  return (async () => {
    if (tf.env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_RELIABLE') > 0) {
      // Wait for every collected query, then report the total and a
      // per-program breakdown string.
      const kernelMs = await Promise.all(flattenedActiveTimerQueries);
      res['kernelMs'] = tf.util.sum(kernelMs);
      res['getExtraProfileInfo'] = () => kernelMs
        .map((d, i) => ({ name: flattenedActiveTimerNames[i], ms: d }))
        .map(d => `${d.name}: ${d.ms}`)
        .join(', ');
    }
    else {
      res['kernelMs'] = {
        error: 'WebGL query timers are not supported in this environment.'
}; } this.uploadWaitMs = 0; this.downloadWaitMs = 0; return res; })(); } memory() { return { unreliable: false, numBytesInGPU: this.numBytesInGPU, numBytesInGPUAllocated: this.textureManager.numBytesAllocated, numBytesInGPUFree: this.textureManager.numBytesFree }; } startTimer() { if (tf.env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_RELIABLE') > 0) { return this.gpgpu.beginQuery(); } return { startMs: tf.util.now(), endMs: null }; } endTimer(query) { if (tf.env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_RELIABLE') > 0) { this.gpgpu.endQuery(); return query; } query.endMs = tf.util.now(); return query; } async getQueryTime(query) { if (tf.env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_RELIABLE') > 0) { return this.gpgpu.waitForQueryAndGetTime(query); } const timerQuery = query; return timerQuery.endMs - timerQuery.startMs; } /** * Decrease the RefCount on the dataId and dispose the memory if the dataId * has 0 refCount. If there are pending read on the data, the disposal would * added to the pending delete queue. Return true if the dataId is removed * from backend or the backend does not contain the dataId, false if the * dataId is not removed. Memory may or may not be released even when dataId * is removed, which also depends on dataRefCount, see `releaseGPU`. * @param dataId * @oaram force Optional, remove the data regardless of refCount */ disposeData(dataId, force = false) { if (this.pendingDisposal.has(dataId)) { return false; } // No-op if already disposed. if (!this.texData.has(dataId)) { return true; } // if force flag is set, change refCount to 0, this would ensure disposal // when added to the pendingDisposal queue. Memory may or may not be // released, which also depends on dataRefCount, see `releaseGPU`. 
if (force) { this.texData.get(dataId).refCount = 0; } else { this.texData.get(dataId).refCount--; } if (!force && this.texData.get(dataId).refCount > 0) { return false; } if (this.pendingRead.has(dataId)) { this.pendingDisposal.add(dataId); this.pendingDeletes++; return false; } this.releaseGPUData(dataId); const { complexTensorInfos } = this.texData.get(dataId); if (complexTensorInfos != null) { this.disposeData(complexTensorInfos.real.dataId, force); this.disposeData(complexTensorInfos.imag.dataId, force); } this.texData.delete(dataId); return true; } releaseGPUData(dataId) { const { texture, dtype, texShape, usage, isPacked, slice } = this.texData.get(dataId); const key = slice && slice.origDataId || dataId; const refCount = this.dataRefCount.get(key); if (refCount > 1) { this.dataRefCount.set(key, refCount - 1); } else { this.dataRefCount.delete(key); if (texture != null) { this.numBytesInGPU -= this.computeBytes(texShape, dtype); this.textureManager.releaseTexture(texture, texShape, usage, isPacked); } } const texData = this.texData.get(dataId); texData.texture = null; texData.texShape = null; texData.isPacked = false; texData.slice = null; } getTexture(dataId) { this.uploadToGPU(dataId); return this.texData.get(dataId).texture.texture; } /** * Returns internal information for the specific data bucket. Used in unit * tests. */ getDataInfo(dataId) { return this.texData.get(dataId); } /* Tests whether all the inputs to an op are small and on the CPU. This heuristic determines when it would be faster to execute a kernel on the CPU. WebGL kernels opt into running this check and forwarding when appropriate. TODO(https://github.com/tensorflow/tfjs/issues/872): Develop a more sustainable strategy for optimizing backend execution of ops. 
*/ shouldExecuteOnCPU(inputs, sizeThreshold = CPU_HANDOFF_SIZE_THRESHOLD) { return tf.env().getBool('WEBGL_CPU_FORWARD') && inputs.every(input => this.texData.get(input.dataId).texture == null && tf.util.sizeFromShape(input.shape) < sizeThreshold); } getGPGPUContext() { return this.gpgpu; } where(condition) { tf.backend_util.warn('tf.where() in webgl locks the UI thread. ' + 'Call tf.whereAsync() instead'); const condVals = condition.dataSync(); return whereImpl(condition.shape, condVals); } packedUnaryOp(x, op, dtype) { const program = new UnaryOpPackedProgram(x.shape, op); const outInfo = this.compileAndRun(program, [x], dtype); return tf.engine().makeTensorFromTensorInfo(outInfo); } // TODO(msoulanille) remove this once the backend has been modularized // a copy is needed here to break a circular dependency. // Also remove the op from unary_op. abs(x) { // TODO: handle cases when x is complex. if (this.shouldExecuteOnCPU([x]) && x.dtype !== 'complex64') { const outValues = simpleAbsImplCPU(this.texData.get(x.dataId).values); return this.makeOutput(x.shape, x.dtype, outValues); } if (tf.env().getBool('WEBGL_PACK_UNARY_OPERATIONS')) { return this.packedUnaryOp(x, ABS$1, x.dtype); } const program = new UnaryOpProgram(x.shape, ABS$1); const outInfo = this.compileAndRun(program, [x]); return tf.engine().makeTensorFromTensorInfo(outInfo); } makeTensorInfo(shape, dtype, values) { let dataId; if (dtype === 'string' && values != null && values.length > 0 && tf.util.isString(values[0])) { const encodedValues = values.map(d => tf.util.encodeString(d)); dataId = this.write(encodedValues, shape, dtype); } else { dataId = this.write(values, shape, dtype); } this.texData.get(dataId).usage = null; return { dataId, shape, dtype }; } makeOutput(shape, dtype, values) { return tf.engine().makeTensorFromTensorInfo(this.makeTensorInfo(shape, dtype, values), this); } unpackTensor(input) { const program = new UnpackProgram(input.shape); return this.runWebGLProgram(program, [input], 
input.dtype);
}
/**
 * Runs `PackProgram` on `input` and returns the resulting tensor info.
 * Eager unpacking of the output is explicitly disabled for this run.
 */
packTensor(input) {
  const program = new PackProgram(input.shape);
  const preventEagerUnpackingOutput = true;
  return this.runWebGLProgram(program, [input], input.dtype, null /* customUniformValues */, preventEagerUnpackingOutput);
}
/**
 * Reshapes `input` to `afterShape` by running `ReshapePackedProgram` on
 * 3D views ([batch, rows, cols]) of the input and target shapes.
 * @returns Tensor info carrying `afterShape` and the output's dataId/dtype.
 */
packedReshape(input, afterShape) {
  const input3DShape = [
    getBatchDim(input.shape),
    ...getRowsCols(input.shape)
  ];
  // Same data as `input`, viewed with the 3D shape.
  const input3D = {
    dtype: input.dtype,
    shape: input3DShape,
    dataId: input.dataId
  };
  const afterShapeAs3D = [
    getBatchDim(afterShape),
    ...getRowsCols(afterShape)
  ];
  const program = new ReshapePackedProgram(afterShapeAs3D, input3DShape);
  const preventEagerUnpackingOfOutput = true;
  const customValues = [input3DShape];
  const output = this.runWebGLProgram(program, [input3D], input.dtype, customValues, preventEagerUnpackingOfOutput);
  return { dataId: output.dataId, shape: afterShape, dtype: output.dtype };
}
/**
 * Runs a decode-matrix program over the texture of `dataId` and returns tensor
 * info for the result with the original logical shape.
 * @param dataId The source data.
 * @param customTexShape Optional [rows, cols] texture shape; must satisfy
 *     rows * cols * 4 >= tensor size.
 */
decode(dataId, customTexShape) {
  const texData = this.texData.get(dataId);
  const { isPacked, shape, dtype } = texData;
  if (customTexShape != null) {
    const size = tf.util.sizeFromShape(shape);
    // The factor 4 matches the "Row * Column * 4" wording of the assertion.
    const texSize = customTexShape[0] * customTexShape[1] * 4;
    tf.util.assert(size <= texSize, () => 'customTexShape is too small. ' +
      'Row * Column * 4 should be equal or larger than the ' +
      'size of the tensor data.');
  }
  const shapeAs3D = getShapeAs3D(shape);
  let program;
  if (isPacked) {
    program = new DecodeMatrixPackedProgram(shapeAs3D);
  }
  else {
    program = new DecodeMatrixProgram(shapeAs3D);
  }
  const preventEagerUnpackingOfOutput = true;
  const customValues = [customTexShape != null ?
customTexShape :
    getDenseTexShape(shapeAs3D)];
  const out = this.runWebGLProgram(program, [{ shape: shapeAs3D, dtype, dataId }], dtype, customValues, preventEagerUnpackingOfOutput, customTexShape);
  return { dtype, shape, dataId: out.dataId };
}
/**
 * Prepares the inputs, compiles (or fetches from cache) and runs `program`,
 * and returns the output tensor info.
 * @param program The GPGPU program to run.
 * @param inputs Tensor infos of the program inputs; complex64 is rejected.
 * @param outputDtype Dtype of the output tensor info.
 * @param customUniformValues Optional values forwarded to `runProgram`.
 * @param preventEagerUnpackingOfOutput When exactly `false`, a packed output
 *     is unpacked before returning unless WEBGL_LAZILY_UNPACK is set.
 * @param customTexShape Optional texel shape used for DENSE-packed outputs.
 */
runWebGLProgram(program, inputs, outputDtype, customUniformValues, preventEagerUnpackingOfOutput = false, customTexShape) {
  const output = this.makeTensorInfo(program.outputShape, outputDtype);
  const outData = this.texData.get(output.dataId);
  if (program.packedOutput) {
    outData.isPacked = true;
  }
  if (program.outPackingScheme === PackingScheme.DENSE) {
    const texelShape = customTexShape != null ?
      customTexShape :
      getDenseTexShape(program.outputShape);
    // For a densely packed output, we explicitly set texShape
    // so it doesn't get assigned later according to our typical packing
    // scheme wherein a single texel can only contain values from adjacent
    // rows/cols.
    outData.texShape = texelShape.map(d => d * 2);
  }
  if (program.outTexUsage != null) {
    outData.usage = program.outTexUsage;
  }
  if (tf.util.sizeFromShape(output.shape) === 0) {
    // Short-circuit the computation since the result is empty (has 0 in its
    // shape).
    outData.values = tf.util.getTypedArrayFromDType(output.dtype, 0);
    return output;
  }
  // Intermediate pack/unpack/reshape results, disposed after the program runs.
  const dataToDispose = [];
  const inputsData = inputs.map(input => {
    if (input.dtype === 'complex64') {
      throw new Error(`GPGPUProgram does not support complex64 input. For complex64 ` +
        `dtypes, please separate the program into real and imaginary ` +
        `parts.`);
    }
    let texData = this.texData.get(input.dataId);
    if (texData.texture == null) {
      if (!program.packedInputs &&
        tf.util.sizeFromShape(input.shape) <=
          tf.env().getNumber('WEBGL_SIZE_UPLOAD_UNIFORM')) {
        // Upload small tensors that live on the CPU as uniforms, not as
        // textures. Do this only when the environment supports 32bit floats
        // due to problems when comparing 16bit floats with 32bit floats.
        // TODO(https://github.com/tensorflow/tfjs/issues/821): Make it
        // possible for packed shaders to sample from uniforms.
        return {
          shape: input.shape,
          texData: null,
          isUniform: true,
          uniformValues: texData.values
        };
      }
      // This ensures that if a packed program's inputs have not yet been
      // uploaded to the GPU, they get uploaded as packed right off the bat.
      if (program.packedInputs) {
        texData.isPacked = true;
        texData.shape = input.shape;
      }
    }
    this.uploadToGPU(input.dataId);
    if (!!texData.isPacked !== !!program.packedInputs) {
      // Packing of the stored texture disagrees with what the program expects:
      // convert and remember the temporary for disposal.
      input = texData.isPacked ? this.unpackTensor(input) :
        this.packTensor(input);
      dataToDispose.push(input);
      texData = this.texData.get(input.dataId);
    }
    else if (texData.isPacked &&
      !isReshapeFree(texData.shape, input.shape)) {
      // This is a special case where a texture exists for a tensor
      // but the shapes are incompatible (due to packing constraints) because
      // the tensor did not have a chance to go through the packed reshape
      // shader. This only happens when we reshape the *same* tensor to form
      // *distinct* inputs to an op, e.g. dotting a vector with itself. This
      // case will disappear once packed uploading is the default.
      const savedInput = input;
      const targetShape = input.shape;
      // Temporarily give the caller's input the stored shape for the reshape;
      // it is restored through `savedInput` below.
      input.shape = texData.shape;
      input = this.packedReshape(input, targetShape);
      dataToDispose.push(input);
      texData = this.texData.get(input.dataId);
      savedInput.shape = targetShape;
    }
    return { shape: input.shape, texData, isUniform: false };
  });
  this.uploadToGPU(output.dataId);
  const outputData = { shape: output.shape, texData: outData, isUniform: false };
  // Compiled binaries are cached under the shader key.
  const key = makeShaderKey(program, inputsData, outputData);
  const binary = this.getAndSaveBinary(key, () => {
    return compileProgram(this.gpgpu, program, inputsData, outputData);
  });
  const shouldTimeProgram = this.activeTimers != null;
  let query;
  if (shouldTimeProgram) {
    query = this.startTimer();
  }
  // With ENGINE_COMPILE_ONLY set, the program is compiled but not executed.
  if (!tf.env().get('ENGINE_COMPILE_ONLY')) {
    runProgram(this.gpgpu, binary, inputsData, outputData, customUniformValues);
  }
  dataToDispose.forEach(info => this.disposeIntermediateTensorInfo(info));
  if (shouldTimeProgram) {
    query = this.endTimer(query);
    this.activeTimers.push({ name: program.constructor.name, query: this.getQueryTime(query) });
  }
  const glFlushThreshold = tf.env().getNumber('WEBGL_FLUSH_THRESHOLD');
  // Manually GL flush requested
  if (glFlushThreshold > 0) {
    const time = tf.util.now();
    if ((time - this.lastGlFlushTime) > glFlushThreshold) {
      this.gpgpu.gl.flush();
      this.lastGlFlushTime = time;
    }
  }
  if (!tf.env().getBool('WEBGL_LAZILY_UNPACK') && outData.isPacked &&
    preventEagerUnpackingOfOutput === false) {
    const unpacked = this.unpackTensor(output);
    this.disposeIntermediateTensorInfo(output);
    return unpacked;
  }
  return output;
}
/**
 * Thin wrapper over `runWebGLProgram` that defaults `outputDtype` to the
 * dtype of the first input when it is not given.
 */
compileAndRun(program, inputs, outputDtype, customUniformValues, preventEagerUnpackingOfOutput = false) {
  outputDtype = outputDtype || inputs[0].dtype;
  const outInfo = this.runWebGLProgram(program, inputs, outputDtype, customUniformValues, preventEagerUnpackingOfOutput);
  return outInfo;
}
/**
 * Returns the cached binary for `key`, calling `getBinary()` to create and
 * store it on a cache miss.
 */
getAndSaveBinary(key, getBinary) {
  if (!(key in this.binaryCache)) {
    this.binaryCache[key] = getBinary();
  }
  return this.binaryCache[key];
}
getTextureManager() { return this.textureManager; } dispose() { if (this.disposed) { return; } // Avoid disposing the compiled webgl programs during unit testing because // it slows down test execution. if (!tf.env().getBool('IS_TEST')) { const allKeys = Object.keys(this.binaryCache); allKeys.forEach(key => { this.gpgpu.deleteProgram(this.binaryCache[key].webGLProgram); delete this.binaryCache[key]; }); } this.textureManager.dispose(); if (this.canvas != null && (typeof (HTMLCanvasElement) !== 'undefined' && this.canvas instanceof HTMLCanvasElement)) { this.canvas.remove(); } else { this.canvas = null; } if (this.gpgpuCreatedLocally) { this.gpgpu.program = null; this.gpgpu.dispose(); } this.disposed = true; } floatPrecision() { if (this.floatPrecisionValue == null) { this.floatPrecisionValue = tf.tidy(() => { if (!tf.env().get('WEBGL_RENDER_FLOAT32_ENABLED')) { // Momentarily switching DEBUG flag to false so we don't throw an // error trying to upload a small value. const debugFlag = tf.env().getBool('DEBUG'); tf.env().set('DEBUG', false); const underflowCheckValue = this.abs(tf.scalar(1e-8)).dataSync()[0]; tf.env().set('DEBUG', debugFlag); if (underflowCheckValue > 0) { return 32; } } return 16; }); } return this.floatPrecisionValue; } /** Returns the smallest representable number. */ epsilon() { return this.floatPrecision() === 32 ? EPSILON_FLOAT32 : EPSILON_FLOAT16; } uploadToGPU(dataId) { const texData = this.texData.get(dataId); const { shape, dtype, values, texture, usage, isPacked } = texData; if (texture != null) { // Array is already on GPU. No-op. return; } const shouldTimeProgram = this.activeTimers != null; let start; if (shouldTimeProgram) { start = tf.util.now(); } let texShape = texData.texShape; if (texShape == null) { // This texShape may not be the final texture shape. For packed or dense // textures, the texShape will be changed when textures are created. 
texShape = getTextureShapeFromLogicalShape(shape, isPacked); texData.texShape = texShape; } if (values != null) { const shapeAs3D = getShapeAs3D(shape); let program; let width = texShape[1], height = texShape[0]; const isByteArray = values instanceof Uint8Array || values instanceof Uint8ClampedArray; // texture for float array is PhysicalTextureType.PACKED_2X2_FLOAT32, we // need to make sure the upload uses the same packed size if (isPacked || !isByteArray) { [width, height] = getPackedMatrixTextureShapeWidthHeight(texShape[0], texShape[1]); } if (isPacked) { program = new EncodeMatrixPackedProgram(shapeAs3D, isByteArray); } else { program = new EncodeMatrixProgram(shapeAs3D, isByteArray); } // TexShape for float array needs to be the original shape, which byte // array needs to be packed size. This allow the data upload shape to be // matched with texture creation logic. const tempDenseInputTexShape = isByteArray ? [height, width] : texShape; const tempDenseInputHandle = this.makeTensorInfo(tempDenseInputTexShape, dtype); const tempDenseInputTexData = this.texData.get(tempDenseInputHandle.dataId); if (isByteArray) { tempDenseInputTexData.usage = TextureUsage.PIXELS; } else { tempDenseInputTexData.usage = TextureUsage.UPLOAD; } tempDenseInputTexData.texShape = tempDenseInputTexShape; this.gpgpu.uploadDenseMatrixToTexture(this.getTexture(tempDenseInputHandle.dataId), width, height, values); const customValues = [[height, width]]; // We want the output to remain packed regardless of the value of // WEBGL_PACK. const preventEagerUnpacking = true; const encodedOutputTarget = this.runWebGLProgram(program, [tempDenseInputHandle], dtype, customValues, preventEagerUnpacking); // Have the original texture assume the identity of the encoded output. 
const outputTexData = this.texData.get(encodedOutputTarget.dataId); texData.texShape = outputTexData.texShape; texData.isPacked = outputTexData.isPacked; texData.usage = outputTexData.usage; if (!tf.env().get('ENGINE_COMPILE_ONLY')) { texData.texture = outputTexData.texture; // Once uploaded, don't store the values on cpu. texData.values = null; this.texData.delete(encodedOutputTarget.dataId); } else { this.disposeData(encodedOutputTarget.dataId); } this.disposeIntermediateTensorInfo(tempDenseInputHandle); if (shouldTimeProgram) { this.uploadWaitMs += tf.util.now() - start; } } else { const newTexture = this.acquireTexture(texShape, usage, dtype, isPacked); texData.texture = newTexture; } } convertAndCacheOnCPU(dataId, float32Values) { const texData = this.texData.get(dataId); const { dtype } = texData; if (float32Values != null) { texData.values = float32ToTypedArray(float32Values, dtype); } return texData.values; } acquireTexture(texShape, texType, dtype, isPacked) { this.numBytesInGPU += this.computeBytes(texShape, dtype); if (!this.warnedAboutMemory && this.numBytesInGPU > this.numMBBeforeWarning * 1024 * 1024) { const mb = (this.numBytesInGPU / 1024 / 1024).toFixed(2); this.warnedAboutMemory = true; console.warn(`High memory usage in GPU: ${mb} MB, ` + `most likely due to a memory leak`); } return this.textureManager.acquireTexture(texShape, texType, isPacked); } computeBytes(shape, dtype) { return shape[0] * shape[1] * tf.util.bytesPerElement(dtype); } checkCompileCompletion() { for (const [, binary] of Object.entries(this.binaryCache)) { this.checkCompletion_(binary); } } async checkCompileCompletionAsync() { const ps = []; if (this.gpgpu.parallelCompilationExtension) { for (const [, binary] of Object.entries(this.binaryCache)) { ps.push(this.checkCompletionAsync_(binary)); } return Promise.all(ps); } else { for (const [, binary] of Object.entries(this.binaryCache)) { const p = new Promise((resolve) => { try { this.checkCompletion_(binary); resolve(true); } 
catch (error) { throw error; } }); ps.push(p); } return Promise.all(ps); } } async checkCompletionAsync_(binary) { if (this.gpgpu.gl.getProgramParameter(binary.webGLProgram, this.gpgpu.parallelCompilationExtension.COMPLETION_STATUS_KHR)) { return this.checkCompletion_(binary); } else { await tf.nextFrame(); return this.checkCompletionAsync_(binary); } } checkCompletion_(binary) { if (this.gpgpu.gl.getProgramParameter(binary.webGLProgram, this.gpgpu.gl.LINK_STATUS) === false) { console.log(this.gpgpu.gl.getProgramInfoLog(binary.webGLProgram)); if (this.gpgpu.gl.getShaderParameter(binary.fragmentShader, this.gpgpu.gl.COMPILE_STATUS) === false) { logShaderSourceAndInfoLog(binary.source, this.gpgpu.gl.getShaderInfoLog(binary.fragmentShader)); throw new Error('Failed to compile fragment shader.'); } throw new Error('Failed to link vertex and fragment shaders.'); } return true; } getUniformLocations() { for (const binary of Object.values(this.binaryCache)) { // TODO: Iterating through all binaries to build VAOs is supposed to be in // a seperate function, like 'setVaos'. However, to avoid breaking changes // for the users using parallel compile feature now, buildVao is silently // added here. this.gpgpu.buildVao(binary.webGLProgram); const { variablesLocations, customUniformLocations, infLoc, nanLoc, outShapeLocation, outShapeStridesLocation, outTexShapeLocation } = getUniformLocations(this.gpgpu, binary.program, binary.webGLProgram); binary.variablesLocations = variablesLocations; binary.customUniformLocations = customUniformLocations; binary.infLoc = infLoc; binary.nanLoc = nanLoc; binary.outShapeLocation = outShapeLocation; binary.outShapeStridesLocation = outShapeStridesLocation; binary.outTexShapeLocation = outTexShapeLocation; } } /** * Create a TF.js tensor out of an existing WebGL texture. A new texture will * be created. 
*/ createTensorFromGPUData(values, shape, dtype) { values.channels = values.channels || 'RGBA'; const { texture, height, width, channels } = values; const backend = tf.engine().backend; // Have to throw an error, otherwise WebGL just warns and returns wrong // values. if (!backend.gpgpu.gl.isTexture(texture)) { throw new Error(`The texture is invalid. Also, please make sure the texture and ` + `the TFJS WebGL backend are using the same canvas. If you want to ` + `use your own custom canvas, you have to create and use the custom ` + `TFJS WebGL backend created from the canvas through ` + `'new tf.MathBackendWebGL(customCanvas)'.`); } const dataId = backend.writeTexture(texture, shape, dtype, height, width, channels); return tf.engine().makeTensorFromDataId(dataId, shape, dtype, backend); } } MathBackendWebGL.nextDataId = 0; function float32ToTypedArray(a, dtype) { if (dtype === 'float32' || dtype === 'complex64') { return a; } else if (dtype === 'int32' || dtype === 'bool') { const result = (dtype === 'int32') ? new Int32Array(a.length) : new Uint8Array(a.length); for (let i = 0; i < result.length; ++i) { result[i] = Math.round(a[i]); } return result; } else { throw new Error(`Unknown dtype ${dtype}`); } } /** @license See the LICENSE file. */ // This code is auto-generated, do not modify this file! const version = '4.15.0'; /** * @license * Copyright 2019 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */
/**
 * Enforce use of half precision textures if available on the platform.
 *
 * Sets the `WEBGL_FORCE_F16_TEXTURES` environment flag to true.
 *
 * @doc {heading: 'Environment', namespace: 'webgl'}
 */
function forceHalfFloat() {
  tf.env().set('WEBGL_FORCE_F16_TEXTURES', true);
}

/**
 * @license
 * Copyright 2020 Google Inc. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
// Register the WebGL backend only when running in a browser; the factory
// creates a fresh MathBackendWebGL and the backend is registered with priority 2.
if (tf.device_util.isBrowser()) {
  tf.registerBackend('webgl', () => new MathBackendWebGL(), 2 /* priority */);
}
// Namespace object grouping the WebGL-specific helpers defined above.
const webgl = { forceHalfFloat };

/**
 * @license
 * Copyright 2017 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
* ============================================================================= */
// GLSL snippet for scalar binary ops: returns whichever operand is NaN.
const CHECK_NAN_SNIPPET = ` if (isnan(a)) return a; if (isnan(b)) return b; `;
/**
 * Unpacked element-wise binary program. `op` is a GLSL snippet that becomes
 * the body of `float binaryOperation(float a, float b)`.
 */
class BinaryOpProgram {
  /**
   * @param op GLSL body of the binary operation.
   * @param aShape Shape of input A.
   * @param bShape Shape of input B; the output shape is the broadcast of both.
   */
  constructor(op, aShape, bShape) {
    this.variableNames = ['A', 'B'];
    this.outputShape = tf.backend_util.assertAndGetBroadcastShape(aShape, bShape);
    this.enableShapeUniforms = useShapeUniforms(this.outputShape.length);
    this.userCode = ` float binaryOperation(float a, float b) { ${op} } void main() { float a = getAAtOutCoords(); float b = getBAtOutCoords(); setOutput(binaryOperation(a, b)); } `;
  }
}

/**
 * @license
 * Copyright 2018 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
// GLSL snippet for packed binary ops: writes NAN into each channel of
// `result` whose `isNaN` channel is set.
const CHECK_NAN_SNIPPET_PACKED = ` result.r = isNaN.r ? NAN : result.r; result.g = isNaN.g ? NAN : result.g; result.b = isNaN.b ? NAN : result.b; result.a = isNaN.a ?
NAN : result.a; `;
/**
 * Packed (vec4) element-wise binary program. `op` is a GLSL snippet that
 * becomes the body of `vec4 binaryOperation(vec4 a, vec4 b)`.
 */
class BinaryOpPackedProgram {
  /**
   * @param op GLSL body of the packed binary operation.
   * @param aShape Shape of input A.
   * @param bShape Shape of input B; the output shape is the broadcast of both.
   * @param checkOutOfBounds When true, extra GLSL is emitted that zeroes the
   *     result channels lying past the end of the output's last two dims.
   */
  constructor(op, aShape, bShape, checkOutOfBounds = false) {
    this.variableNames = ['A', 'B'];
    this.supportsBroadcasting = true;
    this.packedInputs = true;
    this.packedOutput = true;
    this.outputShape = tf.backend_util.assertAndGetBroadcastShape(aShape, bShape);
    const rank = this.outputShape.length;
    this.enableShapeUniforms = useShapeUniforms(rank);
    let checkOutOfBoundsString = '';
    if (checkOutOfBounds) {
      if (rank === 0 || tf.util.sizeFromShape(this.outputShape) === 1) {
        // Single-value output: only the first channel carries data.
        checkOutOfBoundsString = ` result.y = 0.; result.z = 0.; result.w = 0.; `;
      }
      else {
        const dtype = getCoordsDataType(rank);
        checkOutOfBoundsString = ` ${dtype} coords = getOutputCoords(); `;
        if (rank === 1) {
          // Rank 1: only the .y channel can hold a second valid value.
          if (this.enableShapeUniforms) {
            checkOutOfBoundsString += ` result.y = (coords + 1) >= outShape ? 0. : result.y; result.z = 0.; result.w = 0.; `;
          }
          else {
            checkOutOfBoundsString += ` result.y = (coords + 1) >= ${this.outputShape[0]} ? 0. : result.y; result.z = 0.; result.w = 0.; `;
          }
        }
        else {
          // Rank >= 2: test the next row and next column of the last two dims.
          const channels = getChannels('coords', rank);
          if (this.enableShapeUniforms) {
            checkOutOfBoundsString += ` bool nextRowOutOfBounds = (${channels[rank - 2]} + 1) >= outShape[${rank} - 2]; bool nextColOutOfBounds = (${channels[rank - 1]} + 1) >= outShape[${rank} - 1]; result.y = nextColOutOfBounds ? 0. : result.y; result.z = nextRowOutOfBounds ? 0. : result.z; result.w = nextColOutOfBounds || nextRowOutOfBounds ? 0. : result.w; `;
          }
          else {
            checkOutOfBoundsString += ` bool nextRowOutOfBounds = (${channels[rank - 2]} + 1) >= ${this.outputShape[rank - 2]}; bool nextColOutOfBounds = (${channels[rank - 1]} + 1) >= ${this.outputShape[rank - 1]}; result.y = nextColOutOfBounds ? 0. : result.y; result.z = nextRowOutOfBounds ? 0. : result.z; result.w = nextColOutOfBounds || nextRowOutOfBounds ? 0.
: result.w; `;
          }
        }
      }
    }
    this.userCode = ` vec4 binaryOperation(vec4 a, vec4 b) { ${op} } void main() { vec4 a = getAAtOutCoords(); vec4 b = getBAtOutCoords(); vec4 result = binaryOperation(a, b); ${checkOutOfBoundsString} setOutput(result); } `;
  }
}

/**
 * @license
 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
/**
 * Identity kernel: increments the reference count of `x`'s data on the
 * backend and returns a tensor info that shares the same dataId.
 */
function identity(args) {
  const { inputs, backend } = args;
  const { x } = inputs;
  backend.incRef(x.dataId);
  return { dataId: x.dataId, shape: x.shape, dtype: x.dtype };
}
// Kernel registration record for Identity on the 'webgl' backend.
const identityConfig = {
  kernelName: tf.Identity,
  backendName: 'webgl',
  kernelFunc: identity
};

/**
 * @license
 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
* ============================================================================= */ /** * In WebGL data is stored in GPU textures which can't be efficiently copied, so * complex tensors share data with their real and imaginary components. Complex * tensors' reference to the components is tracked by refCount on the individual * component. The refCounts are increased by the identity call. * * When a complex tensor is disposed, it will reduce the refCount on the * components by calling disposeData on each. */ function complex(args) { const { inputs, backend } = args; const { real, imag } = inputs; const complexInfo = backend.makeTensorInfo(real.shape, 'complex64'); const complex = backend.texData.get(complexInfo.dataId); const realTensorInfo = identity({ inputs: { x: real }, backend }); const imagTensorInfo = identity({ inputs: { x: imag }, backend }); complex.complexTensorInfos = { real: realTensorInfo, imag: imagTensorInfo }; return complexInfo; } const complexConfig = { kernelName: tf.Complex, backendName: 'webgl', kernelFunc: complex }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ const LEAKYRELU = `return (a < 0.) ? 
b * a : a;`; const LEAKYRELU_PACKED = ` vec4 aLessThanZero = vec4(lessThan(a, vec4(0.))); return (aLessThanZero * (b * a)) + ((vec4(1.0) - aLessThanZero) * a); `; function leakyRelu(args) { const { inputs, backend, attrs } = args; const { x } = inputs; const { alpha } = attrs; const $alpha = backend.makeTensorInfo([], 'float32', tf.util.createScalarValue(alpha, 'float32')); const program = tf.env().getBool('WEBGL_PACK_BINARY_OPERATIONS') ? new BinaryOpPackedProgram(LEAKYRELU_PACKED, x.shape, $alpha.shape) : new BinaryOpProgram(LEAKYRELU, x.shape, $alpha.shape); const result = backend.runWebGLProgram(program, [x, $alpha], 'float32'); backend.disposeIntermediateTensorInfo($alpha); return result; } const leakyReluConfig = { kernelName: tf.LeakyRelu, backendName: 'webgl', kernelFunc: leakyRelu }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ const PRELU = `return (a < 0.) ? b * a : a;`; const PRELU_PACKED = ` vec4 aLessThanZero = vec4(lessThan(a, vec4(0.))); return (aLessThanZero * (b * a)) + ((vec4(1.0) - aLessThanZero) * a); `; function prelu(args) { const { inputs, backend } = args; const { x, alpha } = inputs; const program = tf.env().getBool('WEBGL_PACK_BINARY_OPERATIONS') ? 
new BinaryOpPackedProgram(PRELU_PACKED, x.shape, alpha.shape) : new BinaryOpProgram(PRELU, x.shape, alpha.shape); return backend.runWebGLProgram(program, [x, alpha], 'float32'); } const preluConfig = { kernelName: tf.Prelu, backendName: 'webgl', kernelFunc: prelu }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ const CHECK_NAN_SNIPPET_UNARY = `if (isnan(x)) return x;`; /** * Template that creates a `KernelFunc` for unary ops. * @param opSnippet Op snippet to create `UnaryOpProgram`. * @param packedOpSnippet Op snippet to create `UnaryOpPackedProgram`. * @param dtype Optional. If set, the result has this dtype. Otherwise, the * result has the same dtype as the first input. This is mainly used in * comparison kernels, such as Equal, Less, Greater, etc. 
*/ function unaryKernelFunc({ opSnippet, packedOpSnippet, cpuKernelImpl, dtype }) { return ({ inputs, backend }) => { const { x } = inputs; const webglBackend = backend; const $dtype = dtype || x.dtype; if (webglBackend.shouldExecuteOnCPU([x]) && cpuKernelImpl != null) { const xData = webglBackend.texData.get(x.dataId); const outValues = cpuKernelImpl(xData.values, $dtype); return webglBackend.makeTensorInfo(x.shape, $dtype, outValues); } const shouldUsePackedProgram = tf.env().getBool('WEBGL_PACK_UNARY_OPERATIONS') && packedOpSnippet != null; let program; if (shouldUsePackedProgram) { program = new UnaryOpPackedProgram(x.shape, packedOpSnippet); } else { program = new UnaryOpProgram(x.shape, opSnippet); } return webglBackend.runWebGLProgram(program, [x], $dtype); }; } /** * Template that creates a `KernelFunc` for binary ops. * @param opSnippet Op snippet to create `BinaryOpProgram`. * @param packedOpSnippet Op snippet to create `BinaryOpPackedProgram`. * @param checkOutOfBoundsForPackedProgram Whether to set checkOutOfBounds=true * when creating BinaryOpPackedProgram. * @param dtype Optional. If set, the result has this dtype. Otherwise, the * result has the same dtype as the first input. This is mainly used in * comparison kernels, such as Equal, Less, Greater, etc. 
*/ function binaryKernelFunc({ opSnippet, packedOpSnippet, checkOutOfBounds = false, supportsComplex = false, cpuKernelImpl, dtype }) { return ({ inputs, backend }) => { const { a, b } = inputs; const webglBackend = backend; if (supportsComplex && a.dtype === 'complex64') { const aData = webglBackend.texData.get(a.dataId); const bData = webglBackend.texData.get(b.dataId); const [real, imag] = [ [aData.complexTensorInfos.real, bData.complexTensorInfos.real], [aData.complexTensorInfos.imag, bData.complexTensorInfos.imag] ].map(complexParts => { const [aPart, bPart] = complexParts; const aHandle = { dataId: aPart.dataId, dtype: aPart.dtype, shape: a.shape }; const bHandle = { dataId: bPart.dataId, dtype: bPart.dtype, shape: b.shape }; const program = new BinaryOpProgram(opSnippet, a.shape, b.shape); return webglBackend.runWebGLProgram(program, [aHandle, bHandle], tf.upcastType(aPart.dtype, bPart.dtype)); }); const complexOutput = complex({ inputs: { real, imag }, backend: webglBackend }); webglBackend.disposeIntermediateTensorInfo(real); webglBackend.disposeIntermediateTensorInfo(imag); // TODO(annxingyuan): Implement CPU forwarding for complex inputs. return complexOutput; } const $dtype = dtype || tf.upcastType(a.dtype, b.dtype); if ((a.dtype === 'string' || b.dtype === 'string' || webglBackend.shouldExecuteOnCPU([a, b])) && cpuKernelImpl != null) { const aVals = webglBackend.texData.get(a.dataId).values; const bVals = webglBackend.texData.get(b.dataId).values; const decodedAVals = a.dtype === 'string' ? // tslint:disable-next-line: no-any tf.backend_util.fromUint8ToStringArray(aVals) : aVals; const decodedBVals = a.dtype === 'string' ? 
// tslint:disable-next-line: no-any tf.backend_util.fromUint8ToStringArray(bVals) : bVals; const [outValues, outShape] = cpuKernelImpl(a.shape, b.shape, decodedAVals, decodedBVals, $dtype); const out = webglBackend.makeTensorInfo(outShape, $dtype); const outData = webglBackend.texData.get(out.dataId); outData.values = outValues; return out; } const shouldUsePackedProgram = tf.env().getBool('WEBGL_PACK_BINARY_OPERATIONS') && packedOpSnippet != null; let program; if (shouldUsePackedProgram) { program = new BinaryOpPackedProgram(packedOpSnippet, a.shape, b.shape, checkOutOfBounds); } else { program = new BinaryOpProgram(opSnippet, a.shape, b.shape); } return webglBackend.runWebGLProgram(program, [a, b], $dtype); }; } function mapActivationToShaderProgram(activation, packed = false) { if (activation === 'linear') { if (packed) { return LINEAR; } return LINEAR$1; } else if (activation === 'relu') { if (packed) { return RELU$1; } return RELU$2; } else if (activation === 'elu') { if (packed) { return ELU$1; } return ELU$2; } else if (activation === 'relu6') { if (packed) { return RELU6$1; } return RELU6$2; } else if (activation === 'prelu') { if (packed) { return PRELU_PACKED; } return PRELU; } else if (activation === 'leakyrelu') { if (packed) { return LEAKYRELU_PACKED; } return LEAKYRELU; } else if (activation === 'sigmoid') { if (packed) { return SIGMOID$1; } return SIGMOID$2; } throw new Error(`Activation ${activation} has not been implemented for the WebGL backend.`); } /** * @license * Copyright 2018 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
* See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ class MatMulPackedProgram { constructor(aShape, bShape, outputShape, transposeA = false, transposeB = false, addBias = false, activation = null, hasPreluActivation = false, hasLeakyreluActivation = false) { this.variableNames = ['matrixA', 'matrixB']; this.packedInputs = true; this.packedOutput = true; this.outputShape = outputShape; this.enableShapeUniforms = useShapeUniforms(this.outputShape.length); const sharedDim = transposeA ? aShape[1] : aShape[2]; const sharedDimensionPacked = Math.ceil(sharedDim / 2); const aSample = transposeA ? 'i * 2, rc.y' : 'rc.y, i * 2'; const bSample = transposeB ? 'rc.z, i * 2' : 'i * 2, rc.z'; const aSwizzle = transposeA ? ['a.xxyy', 'a.zzww'] : ['a.xxzz', 'a.yyww']; const bSwizzle = transposeB ? ['b.xzxz', 'b.ywyw'] : ['b.xyxy', 'b.zwzw']; let activationSnippet = '', applyActivationSnippet = ''; if (activation) { if (hasPreluActivation) { activationSnippet = `vec4 activation(vec4 a) { vec4 b = getPreluActivationWeightsAtOutCoords(); ${activation} }`; } else if (hasLeakyreluActivation) { activationSnippet = `vec4 activation(vec4 a) { vec4 b = getLeakyreluAlphaAtOutCoords(); ${activation} }`; } else { activationSnippet = `vec4 activation(vec4 x) { ${activation} }`; } applyActivationSnippet = `result = activation(result);`; } const addBiasSnippet = addBias ? 
'result += getBiasAtOutCoords();' : ''; if (addBias) { this.variableNames.push('bias'); } if (hasPreluActivation) { this.variableNames.push('preluActivationWeights'); } if (hasLeakyreluActivation) { this.variableNames.push('leakyreluAlpha'); } let batchASnippet = 'rc.x'; let batchBSnippet = 'rc.x'; if (aShape[0] < bShape[0]) { batchASnippet = `imod(rc.x, ${aShape[0]})`; } else if (bShape[0] < aShape[0]) { batchBSnippet = `imod(rc.x, ${bShape[0]})`; } this.userCode = ` ${activationSnippet} // Don't use uniform for sharedDimensionPacked for performance. const float sharedDimension = ${sharedDimensionPacked}.0; vec4 dot2x2ARowBCol(ivec3 rc) { vec4 result = vec4(0); int batchA = ${batchASnippet}; int batchB = ${batchBSnippet}; for (int i = 0; i < ${sharedDimensionPacked}; i++) { vec4 a = getMatrixA(batchA, ${aSample}); vec4 b = getMatrixB(batchB, ${bSample}); // These swizzled products need to be separately added. // See: https://github.com/tensorflow/tfjs/issues/1735 result += (${aSwizzle[0]} * ${bSwizzle[0]}); result += (${aSwizzle[1]} * ${bSwizzle[1]}); } return result; } void main() { ivec3 rc = getOutputCoords(); vec4 result = dot2x2ARowBCol(rc); ${addBiasSnippet} ${applyActivationSnippet} setOutput(result); } `; } } /** * @license * Copyright 2018 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ // (Ar + Ai)(Br + Bi) = // ArBr + ArBi + AiBr + AiBi = ArBr - AB + ArBi + AiBr // Yr = ArBr - AB // Yi = ArBi + AiBr const COMPLEX_MULTIPLY = { REAL: 'return areal * breal - aimag * bimag;', IMAG: 'return areal * bimag + aimag * breal;' }; class BinaryOpComplexProgram { constructor(op, aShape, bShape) { this.variableNames = ['AReal', 'AImag', 'BReal', 'BImag']; this.outputShape = tf.backend_util.assertAndGetBroadcastShape(aShape, bShape); this.userCode = ` float binaryOpComplex( float areal, float aimag, float breal, float bimag) { ${op} } void main() { float areal = getARealAtOutCoords(); float aimag = getAImagAtOutCoords(); float breal = getBRealAtOutCoords(); float bimag = getBImagAtOutCoords(); setOutput(binaryOpComplex(areal, aimag, breal, bimag)); } `; } } /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ const MUL = 'return a * b;'; function multiply(args) { const { inputs, backend } = args; const { a, b } = inputs; const dtype = tf.backend_util.upcastType(a.dtype, b.dtype); if (a.dtype === 'complex64') { const aData = backend.texData.get(a.dataId); const bData = backend.texData.get(b.dataId); const realProgram = new BinaryOpComplexProgram(COMPLEX_MULTIPLY.REAL, a.shape, b.shape); const imagProgram = new BinaryOpComplexProgram(COMPLEX_MULTIPLY.IMAG, a.shape, b.shape); const inputs = [ { dataId: aData.complexTensorInfos.real.dataId, dtype: aData.complexTensorInfos.real.dtype, shape: a.shape }, { dataId: aData.complexTensorInfos.imag.dataId, dtype: aData.complexTensorInfos.imag.dtype, shape: a.shape }, { dataId: bData.complexTensorInfos.real.dataId, dtype: bData.complexTensorInfos.real.dtype, shape: b.shape }, { dataId: bData.complexTensorInfos.imag.dataId, dtype: bData.complexTensorInfos.imag.dtype, shape: b.shape } ]; const realPart = backend.runWebGLProgram(realProgram, inputs, 'float32'); const imagPart = backend.runWebGLProgram(imagProgram, inputs, 'float32'); const complexOutput = complex({ inputs: { real: realPart, imag: imagPart }, backend }); backend.disposeIntermediateTensorInfo(realPart); backend.disposeIntermediateTensorInfo(imagPart); // TODO(annxingyuan): CPU forwarding for complex inputs. 
return complexOutput; } if (backend.shouldExecuteOnCPU([a, b])) { const aData = backend.texData.get(a.dataId); const bData = backend.texData.get(b.dataId); const [outValues, outShape] = multiplyImplCPU(a.shape, b.shape, aData.values, bData.values, dtype); const out = backend.makeTensorInfo(outShape, dtype); const outData = backend.texData.get(out.dataId); outData.values = outValues; return out; } let program; if (tf.env().getBool('WEBGL_PACK_BINARY_OPERATIONS')) { program = new BinaryOpPackedProgram(MUL, a.shape, b.shape); } else { program = new BinaryOpProgram(MUL, a.shape, b.shape); } return backend.runWebGLProgram(program, [a, b], dtype); } const multiplyConfig = { kernelName: tf.Multiply, backendName: 'webgl', kernelFunc: multiply }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ function packedReshape(input, afterShape, backend) { const input3DShape = [getBatchDim(input.shape), ...getRowsCols(input.shape)]; const input3D = { dtype: input.dtype, shape: input3DShape, dataId: input.dataId }; const afterShapeAs3D = [getBatchDim(afterShape), ...getRowsCols(afterShape)]; const program = new ReshapePackedProgram(afterShapeAs3D, input3DShape); const preventEagerUnpackingOfOutput = true; const customValues = [input3DShape]; const output = backend.runWebGLProgram(program, [input3D], input.dtype, customValues, preventEagerUnpackingOfOutput); return { dataId: output.dataId, shape: afterShape, dtype: output.dtype }; } /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function reshape(args) { const { inputs, backend, attrs } = args; const { x } = inputs; const { shape } = attrs; const webglBackend = backend; const xSize = tf.util.sizeFromShape(x.shape); const $shape = tf.util.inferFromImplicitShape(shape, xSize); const $xSize = tf.util.sizeFromShape($shape); tf.util.assert(xSize === $xSize, () => `The new shape (${$shape}) has ${$xSize} elements and the old ` + `shape (${x.shape}) has ${xSize} elements. 
The new shape and old ` + `shape must have the same number of elements.`); const xTexData = webglBackend.texData.get(x.dataId); if (xTexData.isPacked && !isReshapeFree(x.shape, $shape) && !(xTexData.texture !== null && isReshapeFree(xTexData.shape, $shape))) { return packedReshape(x, $shape, webglBackend); } webglBackend.incRef(x.dataId); return { dataId: x.dataId, shape: $shape, dtype: x.dtype }; } const reshapeConfig = { kernelName: tf.Reshape, backendName: 'webgl', kernelFunc: reshape }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ class MeanProgram { constructor(reduceInfo, divisor) { this.variableNames = ['x']; const { windowSize, batchSize, inSize, outSize } = reduceInfo; this.outputShape = [batchSize, outSize]; const windowSizeNearestVec4 = Math.floor(windowSize / 4) * 4; const windowSizeVec4Remainder = windowSize % 4; let updateSnippet = `sumValue += dot(values, ones);`; if (divisor != null) { const denominator = 1 / divisor; updateSnippet = `sumValue += dot(values * ${tf.util.isInt(denominator) ? 
denominator.toPrecision(2) : denominator}, ones);`; } let checkOutOfBounds = ''; if (inSize % windowSize > 0) { checkOutOfBounds = ` if (inIdx < 0 || inIdx >= ${inSize}) { return 0.0; } `; } this.userCode = ` const vec4 ones = vec4(1.0, 1.0, 1.0, 1.0); float getValue(int batch, int inIdx) { ${checkOutOfBounds} return getX(batch, inIdx); } void main() { ivec2 coords = getOutputCoords(); int batch = coords[0]; int outIdx = coords[1]; int inOffset = outIdx * ${windowSize}; float sumValue = 0.0; for (int i = 0; i < ${windowSizeNearestVec4}; i += 4) { int inIdx = inOffset + i; vec4 values = vec4( getValue(batch, inIdx), getValue(batch, inIdx + 1), getValue(batch, inIdx + 2), getValue(batch, inIdx + 3) ); ${updateSnippet} } int inIdx = inOffset + ${windowSizeNearestVec4}; if (${windowSizeVec4Remainder === 1}) { vec4 values = vec4(getValue(batch, inIdx), 0.0, 0.0, 0.0); ${updateSnippet} } else if (${windowSizeVec4Remainder === 2}) { vec4 values = vec4( getValue(batch, inIdx), getValue(batch, inIdx + 1), 0.0, 0.0); ${updateSnippet} } else if (${windowSizeVec4Remainder === 3}) { vec4 values = vec4( getValue(batch, inIdx), getValue(batch, inIdx + 1), getValue(batch, inIdx + 2), 0.0); ${updateSnippet} } setOutput(sumValue); } `; } } /** * @license * Copyright 2017 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ class ReduceProgram { constructor(reduceInfo, reduceType) { this.variableNames = ['x']; const { windowSize, batchSize, inSize, outSize } = reduceInfo; this.outputShape = [batchSize, outSize]; let initializationValue = '0.0'; let compareOp = ``; if (reduceType === 'prod') { initializationValue = '1.0'; } else if (reduceType === 'min') { // WebGL on Firefox Linux can't compile 1/0 so we do 1/eps. initializationValue = '1.0 / 1e-20'; compareOp = `min`; } else if (reduceType === 'max') { // WebGL on Firefox Linux can't compile 1/0 so we do 1/eps. initializationValue = '-1.0 / 1e-20'; compareOp = `max`; } let returnValue = `${reduceType}(${reduceType}(${reduceType}(` + 'minMaxValue[0], minMaxValue[1]), minMaxValue[2]), minMaxValue[3])'; if (reduceType === 'sum') { returnValue = `sumValue`; } else if (reduceType === 'prod') { returnValue = `prodValue`; } else if (reduceType === 'all') { returnValue = `allValue`; } else if (reduceType === 'any') { returnValue = `anyValue`; } const windowSizeNearestVec4 = Math.floor(windowSize / 4) * 4; const windowSizeVec4Remainder = windowSize % 4; let updateSnippet = ` if (${reduceType === 'sum'}) { sumValue += dot(values, ones); } else if (${reduceType === 'prod'}) { vec2 tmp = vec2(values[0], values[1]) * vec2(values[2], values[3]); prodValue *= tmp[0] * tmp[1]; } else { minMaxValue = ${compareOp}(values, minMaxValue); if (${reduceType === 'min'} || ${reduceType === 'max'}) { minMaxValue = ${compareOp}(values, minMaxValue); bvec4 isNaN = isnan(values); if (isNaN.r || isNaN.g || isNaN.b || isNaN.a) { minMaxValue = vec4(NAN); } } } `; let vecType = `vec4`; if (reduceType === 'all') { initializationValue = '1.0'; updateSnippet = ` bool reducedAllValue = all(values); float floatedReducedAllValue = float(reducedAllValue); allValue = float(allValue >= 1.0 && floatedReducedAllValue >= 1.0); `; vecType = `bvec4`; } else if (reduceType === 'any') { 
initializationValue = '0.0'; updateSnippet = ` bool reducedAnyValue = any(values); float floatedReducedAnyValue = float(reducedAnyValue); anyValue = float(anyValue >= 1.0 || floatedReducedAnyValue >= 1.0); `; vecType = `bvec4`; } let checkOutOfBounds = ''; if (inSize % windowSize > 0) { checkOutOfBounds = ` if (inIdx < 0 || inIdx >= ${inSize}) { return initializationValue; } `; } this.userCode = ` const float initializationValue = ${initializationValue}; const vec4 ones = vec4(1.0, 1.0, 1.0, 1.0); float getValue(int batch, int inIdx) { ${checkOutOfBounds} return getX(batch, inIdx); } void main() { ivec2 coords = getOutputCoords(); int batch = coords[0]; int outIdx = coords[1]; int inOffset = outIdx * ${windowSize}; vec4 minMaxValue = vec4(${initializationValue}); float prodValue = 1.0; float sumValue = 0.0; float allValue = 1.0; float anyValue = 0.0; for (int i = 0; i < ${windowSizeNearestVec4}; i += 4) { int inIdx = inOffset + i; ${vecType} values = ${vecType}( getValue(batch, inIdx), getValue(batch, inIdx + 1), getValue(batch, inIdx + 2), getValue(batch, inIdx + 3) ); ${updateSnippet} } int inIdx = inOffset + ${windowSizeNearestVec4}; if (${windowSizeVec4Remainder === 1}) { ${vecType} values = ${vecType}( getValue(batch, inIdx), initializationValue, initializationValue, initializationValue ); ${updateSnippet} } else if (${windowSizeVec4Remainder === 2}) { ${vecType} values = ${vecType}( getValue(batch, inIdx), getValue(batch, inIdx + 1), initializationValue, initializationValue ); ${updateSnippet} } else if (${windowSizeVec4Remainder === 3}) { ${vecType} values = ${vecType}( getValue(batch, inIdx), getValue(batch, inIdx + 1), getValue(batch, inIdx + 2), initializationValue ); ${updateSnippet} } setOutput(${returnValue}); } `; } } /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ // Returns an array of configuration objects that describe each stage of the // reduction. function getReductionStages(inShape) { const stages = []; while (stages.length === 0 || stages[stages.length - 1].outSize !== 1) { const outSize = stages.length ? stages[stages.length - 1].outSize : inShape[1]; const windowSize = tf.backend_util.computeOptimalWindowSize(outSize); stages.push({ inSize: outSize, windowSize, outSize: Math.ceil(outSize / windowSize) }); } return stages; } function reduce(x, dtype, reductionType, backend) { const reductionStages = getReductionStages(x.shape); let result = x; for (let i = 0; i < reductionStages.length; i++) { const { inSize, windowSize, outSize } = reductionStages[i]; let program; let previousResult; if (reductionType === 'mean') { program = i === 0 ? new MeanProgram({ windowSize, inSize, batchSize: x.shape[0], outSize }, inSize) : new MeanProgram({ windowSize, inSize, batchSize: x.shape[0], outSize }); } else { program = new ReduceProgram({ windowSize, inSize, batchSize: x.shape[0], outSize }, reductionType); } previousResult = result; result = backend.runWebGLProgram(program, [result], dtype); if (previousResult.dataId !== x.dataId) { backend.disposeIntermediateTensorInfo(previousResult); } } return result; } /** * @license * Copyright 2017 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ class TransposeProgram { constructor(aShape, newDim) { this.variableNames = ['A']; const outputShape = new Array(aShape.length); for (let i = 0; i < outputShape.length; i++) { outputShape[i] = aShape[newDim[i]]; } this.outputShape = outputShape; this.rank = outputShape.length; const dtype = getCoordsDataType(this.rank); const switched = getSwitchedCoords(newDim); this.userCode = ` void main() { ${dtype} resRC = getOutputCoords(); setOutput(getA(${switched})); } `; } } function getSwitchedCoords(newDim) { const rank = newDim.length; if (rank > 6) { throw Error(`Transpose for rank ${rank} is not yet supported`); } const originalOrder = ['resRC.x', 'resRC.y', 'resRC.z', 'resRC.w', 'resRC.u', 'resRC.v']; const switchedCoords = new Array(rank); for (let i = 0; i < newDim.length; i++) { switchedCoords[newDim[i]] = originalOrder[i]; } return switchedCoords.join(); } /** * @license * Copyright 2019 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ class TransposePackedProgram { constructor(aShape, newDim) { this.variableNames = ['A']; this.packedInputs = true; this.packedOutput = true; const outputShape = new Array(aShape.length); for (let i = 0; i < outputShape.length; i++) { outputShape[i] = aShape[newDim[i]]; } this.outputShape = outputShape; this.rank = outputShape.length; if (this.rank > 6) { throw Error(`Packed transpose for rank ${this.rank} is not yet supported.`); } const dtype = getCoordsDataType(this.rank); const outputOrder = getVecChannels('rc', this.rank); const switchedOrder = new Array(this.rank); for (let i = 0; i < newDim.length; i++) { switchedOrder[newDim[i]] = outputOrder[i]; } const innerDims = `vec2(${switchedOrder.slice(-2).join()})`; const nextColumn = `++${outputOrder[this.rank - 1]} < ${outputShape[this.rank - 1]}`; const getc = `getChannel(getA(${switchedOrder.join()}), ${innerDims})`; this.userCode = ` void main() { ${dtype} rc = getOutputCoords(); vec4 result = vec4(0.); result[0] = ${getc}; if(${nextColumn}) { result[1] = ${getc}; } --${outputOrder[this.rank - 1]}; if(++${outputOrder[this.rank - 2]} < ${outputShape[this.rank - 2]}) { result[2] = ${getc}; if(${nextColumn}) { result[3] = ${getc}; } } setOutput(result); } `; } } /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ function transposeImpl(x, perm, backend) { const program = tf.env().getBool('WEBGL_PACK_ARRAY_OPERATIONS') ? new TransposePackedProgram(x.shape, perm) : new TransposeProgram(x.shape, perm); return backend.runWebGLProgram(program, [x], x.dtype); } /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function sumImpl(x, axis, keepDims, backend) { const reductionIndices = axis; const xRank = x.shape.length; const origAxes = tf.util.parseAxisParam(reductionIndices, x.shape); let axes = origAxes; const permutedAxes = tf.backend_util.getAxesPermutation(axes, xRank); const sumInputIsTransposed = permutedAxes != null; let sumInput = x; if (sumInputIsTransposed) { sumInput = transposeImpl(x, permutedAxes, backend); axes = tf.backend_util.getInnerMostAxes(axes.length, xRank); } tf.backend_util.assertAxesAreInnerMostDims('sum', axes, xRank); const [sumOutShape, reduceShape] = tf.backend_util.computeOutAndReduceShapes(sumInput.shape, axes); let outShape = sumOutShape; if (keepDims) { // rather than reshape at the end, set the target shape here. 
outShape = tf.backend_util.expandShapeToKeepDim(sumOutShape, origAxes); } const inSize = tf.util.sizeFromShape(reduceShape); const xSize = tf.util.sizeFromShape(x.shape); const batchSize = xSize / inSize; const reshapedInput = reshape({ inputs: { x: sumInput }, attrs: { shape: [batchSize, inSize] }, backend }); const outType = tf.sumOutType(x.dtype); const reduced = reduce(reshapedInput, outType, 'sum', backend); const out = reshape({ inputs: { x: reduced }, attrs: { shape: outShape }, backend }); backend.disposeIntermediateTensorInfo(reshapedInput); backend.disposeIntermediateTensorInfo(reduced); if (sumInputIsTransposed) { backend.disposeIntermediateTensorInfo(sumInput); } return out; } /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function sum(args) { const { inputs, backend, attrs } = args; const { x } = inputs; const { axis, keepDims } = attrs; return sumImpl(x, axis, keepDims, backend); } const sumConfig = { kernelName: tf.Sum, backendName: 'webgl', kernelFunc: sum }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function transpose(args) { const { inputs, backend, attrs } = args; const { x } = inputs; const { perm } = attrs; const webglBackend = backend; const xRank = x.shape.length; const newShape = new Array(xRank); for (let i = 0; i < newShape.length; i++) { newShape[i] = x.shape[perm[i]]; } let out; if (webglBackend.shouldExecuteOnCPU([x])) { const xTexData = webglBackend.texData.get(x.dataId); const values = xTexData.values; const outValues = transposeImplCPU(values, x.shape, x.dtype, perm, newShape); out = webglBackend.makeTensorInfo(newShape, x.dtype); const outData = webglBackend.texData.get(out.dataId); outData.values = outValues; } else { out = transposeImpl(x, perm, webglBackend); } return out; } const transposeConfig = { kernelName: tf.Transpose, backendName: 'webgl', kernelFunc: transpose }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ // Empirically determined minimal shared dimension in matmul before we forward // to a.mul(b).sum() in order to take advantage of GPU parallelism. See // https://github.com/tensorflow/tfjs-core/pull/1379 for benchmarks. const MATMUL_SHARED_DIM_THRESHOLD = 1000; function batchMatMulImpl({ a, b, transposeA, transposeB, backend, bias = null, preluActivationWeights = null, leakyreluAlpha = 0, activation = null }) { const aRank = a.shape.length; const bRank = b.shape.length; const innerShapeA = transposeA ? a.shape[aRank - 2] : a.shape[aRank - 1]; const innerShapeB = transposeB ? b.shape[bRank - 1] : b.shape[bRank - 2]; const outerShapeA = transposeA ? a.shape[aRank - 1] : a.shape[aRank - 2]; const outerShapeB = transposeB ? b.shape[bRank - 2] : b.shape[bRank - 1]; const outerDimsA = a.shape.slice(0, -2); const outerDimsB = b.shape.slice(0, -2); const batchDimA = tf.util.sizeFromShape(outerDimsA); const batchDimB = tf.util.sizeFromShape(outerDimsB); const outShapeOuterDims = tf.broadcast_util.assertAndGetBroadcastShape(a.shape.slice(0, -2), b.shape.slice(0, -2)); const outShape = outShapeOuterDims.concat([outerShapeA, outerShapeB]); tf.util.assert(innerShapeA === innerShapeB, () => `Error in matMul: inner shapes (${innerShapeA}) and (` + `${innerShapeB}) of Tensors with shapes ${a.shape} and ` + `${b.shape} and transposeA=${transposeA}` + ` and transposeB=${transposeB} must match.`); const a3dShape = transposeA ? [batchDimA, innerShapeA, outerShapeA] : [batchDimA, outerShapeA, innerShapeA]; const b3dShape = transposeB ? 
[batchDimB, outerShapeB, innerShapeB] : [batchDimB, innerShapeB, outerShapeB]; // The rest of the implementation is designed to operate on rank-3 tensors const a3d = reshape({ inputs: { x: a }, backend, attrs: { shape: a3dShape } }); const b3d = reshape({ inputs: { x: b }, backend, attrs: { shape: b3dShape } }); const intermediates = [a3d, b3d]; const batchDim = Math.max(batchDimA, batchDimB); const sharedDim = transposeA ? a3d.shape[1] : a3d.shape[2]; const hasBias = bias != null; const hasPreluActivationWeights = preluActivationWeights != null; const hasLeakyreluAlpha = activation === 'leakyrelu'; const fusedActivation = activation != null ? mapActivationToShaderProgram(activation, true) : null; const containsFusedOps = hasBias || hasPreluActivationWeights || hasLeakyreluAlpha || fusedActivation != null; let out; // Since the matrices are vectors, it is faster to call mul().sum() // because sum() is O(sqrt(N)) due to divide-and-conquer. if ((outerShapeA === 1 || outerShapeB === 1) && sharedDim > MATMUL_SHARED_DIM_THRESHOLD && containsFusedOps === false) { let aVec = a3d; let bVec = b3d; if (transposeA) { aVec = transpose({ inputs: { x: a3d }, backend, attrs: { perm: [0, 2, 1] } }); intermediates.push(aVec); } if (transposeB) { bVec = transpose({ inputs: { x: b3d }, backend, attrs: { perm: [0, 2, 1] } }); intermediates.push(bVec); } const shouldReshapeA = outerShapeB !== 1; const shouldReshapeB = outerShapeB === 1; let aVec3d = aVec; if (shouldReshapeA) { aVec3d = reshape({ inputs: { x: aVec }, backend, attrs: { shape: [batchDim, sharedDim, 1] } }); intermediates.push(aVec3d); } const axis = outerShapeB === 1 ? 
2 : 1; let bVec3d = bVec; if (shouldReshapeB) { bVec3d = reshape({ inputs: { x: bVec }, backend, attrs: { shape: [batchDim, 1, sharedDim] } }); intermediates.push(bVec3d); } const product = multiply({ inputs: { a: aVec3d, b: bVec3d }, backend }); out = sum({ inputs: { x: product }, backend, attrs: { axis, keepDims: true } }); intermediates.push(product); } else { const dtype = tf.upcastType(a.dtype, b.dtype); const program = new MatMulPackedProgram(a3dShape, b3dShape, [batchDim, outerShapeA, outerShapeB], transposeA, transposeB, hasBias, fusedActivation, hasPreluActivationWeights, hasLeakyreluAlpha); const inputs = [a3d, b3d]; if (bias != null) { inputs.push(bias); } if (hasPreluActivationWeights) { inputs.push(preluActivationWeights); } if (hasLeakyreluAlpha) { const $leakyreluAlpha = backend.makeTensorInfo([], 'float32', tf.util.createScalarValue(leakyreluAlpha, 'float32')); inputs.push($leakyreluAlpha); intermediates.push($leakyreluAlpha); } out = backend.runWebGLProgram(program, inputs, dtype); } const outReshaped = reshape({ inputs: { x: out }, backend, attrs: { shape: outShape } }); intermediates.push(out); for (const i of intermediates) { backend.disposeIntermediateTensorInfo(i); } return outReshaped; } /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the License); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an AS IS BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ function _fusedMatMul(args) { const { inputs, backend, attrs } = args; const { a, b, bias, preluActivationWeights } = inputs; const { transposeA, transposeB, activation, leakyreluAlpha } = attrs; return batchMatMulImpl({ a, b, transposeA, transposeB, backend, bias, preluActivationWeights, leakyreluAlpha, activation }); } const _fusedMatMulConfig = { kernelName: tf._FusedMatMul, backendName: 'webgl', kernelFunc: _fusedMatMul, }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ const ABS = `return abs(x);`; function abs(args) { const { inputs, backend } = args; const { x } = inputs; // TODO: handle cases when x is complex. Once the cpu implementation // can handle complex values, refactor to use unaryKernelFunc. 
if (backend.shouldExecuteOnCPU([x]) && x.dtype !== 'complex64') { const xData = backend.texData.get(x.dataId); const outValues = simpleAbsImplCPU(xData.values); return backend.makeTensorInfo(x.shape, x.dtype, outValues); } let program; if (tf.env().getBool('WEBGL_PACK_UNARY_OPERATIONS')) { program = new UnaryOpPackedProgram(x.shape, ABS); } else { program = new UnaryOpProgram(x.shape, ABS); } return backend.runWebGLProgram(program, [x], x.dtype); } const absConfig = { kernelName: tf.Abs, backendName: 'webgl', kernelFunc: abs }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ const ACOS = CHECK_NAN_SNIPPET$1 + ` if (abs(x) > 1.) { return NAN; } return acos(x); `; const acos = unaryKernelFunc({ opSnippet: ACOS }); const acosConfig = { kernelName: tf.Acos, backendName: 'webgl', kernelFunc: acos, }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
* See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ const ACOSH = CHECK_NAN_SNIPPET$1 + ` if (x < 1.0) return NAN; return log(x + sqrt(x * x - 1.0));`; const acosh = unaryKernelFunc({ opSnippet: ACOSH }); const acoshConfig = { kernelName: tf.Acosh, backendName: 'webgl', kernelFunc: acosh, }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ const ADD = 'return a + b;'; const addKernelFunc = binaryKernelFunc({ opSnippet: ADD, packedOpSnippet: ADD, supportsComplex: true, cpuKernelImpl: addImplCPU }); const addConfig = { kernelName: tf.Add, backendName: 'webgl', kernelFunc: addKernelFunc }; /** * @license * Copyright 2019 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ class AddNProgram { constructor(outputShape, shapes) { this.outputShape = []; this.outputShape = outputShape; this.variableNames = shapes.map((_, i) => `T${i}`); const snippets = []; // Get target elements from every input tensor. this.variableNames.forEach(variable => { snippets.push(`float v${variable} = get${variable}AtOutCoords();`); }); // Calculate the sum of all elements. const operation = this.variableNames .map(variable => { return `v${variable}`; }) .join(' + '); this.userCode = ` void main() { ${snippets.join('\n ')} float result = ${operation}; setOutput(result); } `; } } /** * @license * Copyright 2019 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ class AddNPackedProgram { constructor(outputShape, shapes) { this.outputShape = []; this.packedInputs = true; this.packedOutput = true; this.outputShape = outputShape; this.variableNames = shapes.map((_, i) => `T${i}`); const snippets = []; // Get target elements from every input tensor. this.variableNames.forEach(variable => { snippets.push(`vec4 v${variable} = get${variable}AtOutCoords();`); }); // Calculate the sum of all elements. 
const operation = this.variableNames .map(variable => { return `v${variable}`; }) .join(' + '); this.userCode = ` void main() { ${snippets.join('\n ')} vec4 result = ${operation}; setOutput(result); } `; } } /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function addN(args) { const { inputs, backend } = args; const tensors = inputs; if (tensors.length === 1) { return identity({ inputs: { x: tensors[0] }, backend }); } // Limit the number of uploaded textures for optimization. if (tensors.length > tf.env().getNumber('WEBGL_MAX_TEXTURES_IN_SHADER')) { const midIndex = Math.floor(tensors.length / 2); const leftSide = addN({ inputs: tensors.slice(0, midIndex), backend }); const rightSide = addN({ inputs: tensors.slice(midIndex), backend }); return addN({ inputs: [leftSide, rightSide], backend }); } const dtype = tensors.map(t => t.dtype).reduce((d1, d2) => tf.upcastType(d1, d2)); const shapes = tensors.map(t => t.shape); // We can make sure shapes are identical in op level. const usePackedOp = tf.env().getBool('WEBGL_PACK'); const program = usePackedOp ? new AddNPackedProgram(tensors[0].shape, shapes) : new AddNProgram(tensors[0].shape, shapes); return backend.runWebGLProgram(program, tensors, dtype); } const addNConfig = { kernelName: tf.AddN, backendName: 'webgl', kernelFunc: addN }; /** * @license * Copyright 2020 Google LLC. 
All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function all(args) { const { inputs, backend, attrs } = args; const { x } = inputs; const { axis, keepDims } = attrs; const xRank = x.shape.length; const origAxes = tf.util.parseAxisParam(axis, x.shape); let axes = origAxes; const permutedAxes = tf.backend_util.getAxesPermutation(axes, xRank); let permutedX = x; if (permutedAxes != null) { permutedX = transpose({ inputs: { x }, backend, attrs: { perm: permutedAxes } }); axes = tf.backend_util.getInnerMostAxes(axes.length, xRank); } tf.backend_util.assertAxesAreInnerMostDims('all', axes, xRank); const [outShape, reduceShape] = tf.backend_util.computeOutAndReduceShapes(permutedX.shape, axes); const inSize = tf.util.sizeFromShape(reduceShape); const a2D = reshape({ inputs: { x: permutedX }, backend, attrs: { shape: [-1, inSize] } }); const reduced = reduce(a2D, a2D.dtype, 'all', backend); let res; if (keepDims) { const newShape = tf.backend_util.expandShapeToKeepDim(outShape, origAxes); res = reshape({ inputs: { x: reduced }, backend, attrs: { shape: newShape } }); } else { res = reshape({ inputs: { x: reduced }, backend, attrs: { shape: outShape } }); } backend.disposeIntermediateTensorInfo(a2D); backend.disposeIntermediateTensorInfo(reduced); if (permutedAxes != null) { backend.disposeIntermediateTensorInfo(permutedX); } return res; } const allConfig = { kernelName: tf.All, 
backendName: 'webgl', kernelFunc: all }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function any(args) { const { inputs, backend, attrs } = args; const { x } = inputs; const { axis, keepDims } = attrs; const xRank = x.shape.length; const origAxes = tf.util.parseAxisParam(axis, x.shape); let axes = origAxes; const permutedAxes = tf.backend_util.getAxesPermutation(axes, xRank); let permutedX = x; if (permutedAxes != null) { permutedX = transpose({ inputs: { x }, backend, attrs: { perm: permutedAxes } }); axes = tf.backend_util.getInnerMostAxes(axes.length, xRank); } tf.backend_util.assertAxesAreInnerMostDims('any', axes, xRank); const [outShape, reduceShape] = tf.backend_util.computeOutAndReduceShapes(permutedX.shape, axes); const inSize = tf.util.sizeFromShape(reduceShape); const a2D = reshape({ inputs: { x: permutedX }, backend, attrs: { shape: [-1, inSize] } }); const reduced = reduce(a2D, a2D.dtype, 'any', backend); let res; if (keepDims) { const newShape = tf.backend_util.expandShapeToKeepDim(outShape, origAxes); res = reshape({ inputs: { x: reduced }, backend, attrs: { shape: newShape } }); } else { res = reshape({ inputs: { x: reduced }, backend, attrs: { shape: outShape } }); } backend.disposeIntermediateTensorInfo(a2D); backend.disposeIntermediateTensorInfo(reduced); if (permutedAxes != null) { 
backend.disposeIntermediateTensorInfo(permutedX); } return res; } const anyConfig = { kernelName: tf.Any, backendName: 'webgl', kernelFunc: any }; /** * @license * Copyright 2017 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ class ArgMinMaxProgram { constructor(reduceInfo, op, firstPass) { this.variableNames = ['A']; const { windowSize, batchSize, outSize } = reduceInfo; if (!firstPass) { this.variableNames.push('bestIndicesA'); } this.outputShape = [batchSize, outSize]; const compOp = (op === 'max') ? '>' : '<'; const indexSnippet = firstPass ? 'inOffset + i;' : 'round(getBestIndicesA(batch, inOffset + i));'; this.userCode = ` void main() { ivec2 coords = getOutputCoords(); int batch = coords[0]; int outIdx = coords[1]; int inOffset = outIdx * ${windowSize}; int bestIndex = inOffset; float bestValue = getA(batch, bestIndex); for (int i = 0; i < ${windowSize}; i++) { int inIdx = ${indexSnippet}; float candidate = getA(batch, inIdx); if (candidate ${compOp} bestValue) { bestValue = candidate; bestIndex = inIdx; } } setOutput(float(bestIndex)); } `; } } /** * @license * Copyright 2019 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ class ArgMinMaxPackedProgram { constructor(shape, windowSize, op, firstPass) { this.variableNames = ['A']; this.packedInputs = true; this.packedOutput = true; tf.util.assert(shape.length > 2, () => `Packed arg${op.charAt(0).toUpperCase() + op.slice(1)} supports only inputs with rank above 2.`); const inSize = shape[shape.length - 1]; const outSize = Math.ceil(inSize / windowSize); this.outputShape = shape.slice(0, -1); if (outSize > 1) { this.outputShape.push(outSize); } if (!firstPass) { this.variableNames.push('bestIndicesA'); } const outShape = this.outputShape; const rank = outShape.length; const dtype = getCoordsDataType(rank); const coords = getChannels('coords', rank); let sourceLocSetup; let sourceRank; if (outSize === 1) { sourceRank = rank + 1; const sourceLocDType = getCoordsDataType(sourceRank); sourceLocSetup = ` ${sourceLocDType} sourceLocR = ${sourceLocDType}(${coords.join()}, 0); ++${coords[rank - 1]}; ${sourceLocDType} sourceLocG = ${sourceLocDType}(${coords.join()}, 0); ++${coords[rank - 2]}; ${sourceLocDType} sourceLocA = ${sourceLocDType}(${coords.join()}, 0); --${coords[rank - 1]}; ${sourceLocDType} sourceLocB = ${sourceLocDType}(${coords.join()}, 0); --${coords[rank - 2]};`; } else { sourceRank = rank; sourceLocSetup = ` ${dtype} sourceLocR = coords; ++${coords[rank - 1]}; ${dtype} sourceLocG = coords; ++${coords[rank - 2]}; ${dtype} sourceLocA = coords; --${coords[rank - 1]}; ${dtype} sourceLocB = coords; --${coords[rank - 2]};`; } const channels = ['x', 'y', 
'z', 'w', 'u', 'v'].slice(0, sourceRank); const inChannel = '.' + channels[sourceRank - 1]; // e.g. ".b" for rank 3. const intChannels = channels.map(x => 'int ' + x); const srcRCoords = getChannels('sourceLocR', sourceRank - 1).concat('inIdx.r'); const srcGCoords = getChannels('sourceLocG', sourceRank - 1).concat('inIdx.g'); const srcBCoords = getChannels('sourceLocB', sourceRank - 1).concat('inIdx.b'); const srcACoords = getChannels('sourceLocA', sourceRank - 1).concat('inIdx.a'); const compOp = (op === 'max') ? 'greaterThan' : 'lessThan'; const fetchCandidateIdx = firstPass ? '' : ` inIdx = round(vec4(getBestIndicesAChannel(${srcRCoords.join()}), getBestIndicesAChannel(${srcGCoords.join()}), getBestIndicesAChannel(${srcBCoords.join()}), getBestIndicesAChannel(${srcACoords.join()})));`; const fetchValue = `vec4( getAChannel(${srcRCoords.join()}), hasNextCol ? getAChannel(${srcGCoords.join()}) : 0., hasNextRow ? getAChannel(${srcBCoords.join()}) : 0., hasNextRow && hasNextCol ? getAChannel(${srcACoords.join()}) : 0.)`; const getBestIndicesAChannelSnippet = firstPass ? 
'' : ` float getBestIndicesAChannel(${intChannels.join()}) { return getChannel(getBestIndicesA(${channels.join()}), vec2(${channels.slice(-2).join()})); }`; this.userCode = ` float getAChannel(${intChannels.join()}) { return getChannel(getA(${channels.join()}), vec2(${channels.slice(-2).join()})); } ${getBestIndicesAChannelSnippet} void main() { ${dtype} coords = getOutputCoords(); bool hasNextCol = ${coords[rank - 1]} < ${outShape[rank - 1] - 1}; bool hasNextRow = ${coords[rank - 2]} < ${outShape[rank - 2] - 1}; ${sourceLocSetup} ivec4 srcIdx = ivec4(sourceLocR${inChannel}, sourceLocG${inChannel}, sourceLocB${inChannel}, sourceLocA${inChannel}) * ${windowSize}; ivec4 inIdx = srcIdx; vec4 bestIndex = vec4(inIdx); vec4 bestValue = ${fetchValue}; for (int i = 0; i < ${windowSize}; i++) { inIdx = srcIdx; ${fetchCandidateIdx} vec4 candidate = ${fetchValue}; bvec4 nan = isnan(candidate); bvec4 replace = bvec4( vec4(${compOp}(candidate, bestValue)) * (vec4(1.0) - vec4(nan))); bestValue = vec4(replace.x ? candidate.x : bestValue.x, replace.y ? candidate.y : bestValue.y, replace.z ? candidate.z : bestValue.z, replace.w ? candidate.w : bestValue.w); bestIndex = mix(bestIndex, vec4(inIdx), vec4(replace)); srcIdx++; } setOutput(bestIndex); } `; } } /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ function argReduce(backend, x, reduceType, bestIndicesA = null) { let batchSize = x.shape[0]; let inSize = x.shape[1]; if (bestIndicesA != null) { batchSize = bestIndicesA.shape[0]; inSize = bestIndicesA.shape[1]; } const windowSize = tf.backend_util.computeOptimalWindowSize(inSize); const reduceInfo = { windowSize, inSize, batchSize, outSize: Math.ceil(inSize / windowSize) }; const program = new ArgMinMaxProgram(reduceInfo, reduceType, bestIndicesA == null); const inputs = [x]; if (bestIndicesA != null) { inputs.push(bestIndicesA); } const output = backend.runWebGLProgram(program, inputs, 'int32'); // No need to run another GPGPU program. if (output.shape[1] === 1) { return output; } const result = argReduce(backend, x, reduceType, output); backend.disposeIntermediateTensorInfo(output); return result; } function argReducePacked(backend, x, reduceType, bestIndicesA = null) { const inShape = bestIndicesA != null ? bestIndicesA.shape : x.shape; const inSize = inShape[inShape.length - 1]; const windowSize = tf.backend_util.computeOptimalWindowSize(inSize); const program = new ArgMinMaxPackedProgram(inShape, windowSize, reduceType, bestIndicesA == null); const inputs = bestIndicesA == null ? [x] : [x, bestIndicesA]; const output = backend.runWebGLProgram(program, inputs, 'int32'); if (output.shape.length === x.shape.length) { const result = argReducePacked(backend, x, reduceType, output); backend.disposeIntermediateTensorInfo(output); return result; } return output; } function argMinMaxReduce(backend, x, axis, reduceType) { const axes = [axis]; tf.backend_util.assertAxesAreInnerMostDims('arg' + reduceType.charAt(0).toUpperCase() + reduceType.slice(1), axes, x.shape.length); if (!tf.env().getBool('WEBGL_PACK_REDUCE') || x.shape.length <= 2) { const intermediateTensorInfos = []; // Eagerly unpack x input since it is passed in to all the shaders which // require unpacked inputs. 
const xtexData = backend.texData.get(x.dataId); const xIsPacked = xtexData !== null && xtexData.isPacked; let xUnPacked = x; if (xIsPacked) { xUnPacked = backend.unpackTensor(x); intermediateTensorInfos.push(xUnPacked); } const [outShape, reduceShape] = tf.backend_util.computeOutAndReduceShapes(xUnPacked.shape, axes); const inSize = tf.util.sizeFromShape(reduceShape); const a2D = reshape({ inputs: { x: xUnPacked }, backend, attrs: { shape: [-1, inSize] } }); intermediateTensorInfos.push(a2D); const reduced = argReduce(backend, a2D, reduceType); intermediateTensorInfos.push(reduced); const reshaped = reshape({ inputs: { x: reduced }, backend, attrs: { shape: outShape } }); intermediateTensorInfos.forEach(t => backend.disposeIntermediateTensorInfo(t)); return reshaped; } return argReducePacked(backend, x, reduceType); } /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ function argMax(args) { const { inputs, backend, attrs } = args; const { x } = inputs; const { axis } = attrs; let axes = tf.util.parseAxisParam(axis, x.shape); const permutedAxes = tf.backend_util.getAxesPermutation(axes, x.shape.length); let $x = x; const intermediateTensorInfos = []; if (permutedAxes != null) { $x = transpose({ inputs: { x }, backend, attrs: { perm: permutedAxes } }); intermediateTensorInfos.push($x); axes = tf.backend_util.getInnerMostAxes(axes.length, $x.shape.length); } tf.backend_util.assertAxesAreInnerMostDims('argMax', [axes[0]], $x.shape.length); const out = argMinMaxReduce(backend, $x, axes[0], 'max'); intermediateTensorInfos.forEach(t => backend.disposeIntermediateTensorInfo(t)); return out; } const argMaxConfig = { kernelName: tf.ArgMax, backendName: 'webgl', kernelFunc: argMax }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ function argMin(args) { const { inputs, backend, attrs } = args; const { x } = inputs; const { axis } = attrs; let axes = tf.util.parseAxisParam(axis, x.shape); const permutedAxes = tf.backend_util.getAxesPermutation(axes, x.shape.length); let $x = x; const intermediateTensorInfos = []; if (permutedAxes != null) { $x = transpose({ inputs: { x }, backend, attrs: { perm: permutedAxes } }); intermediateTensorInfos.push($x); axes = tf.backend_util.getInnerMostAxes(axes.length, $x.shape.length); } tf.backend_util.assertAxesAreInnerMostDims('argMin', [axes[0]], $x.shape.length); const out = argMinMaxReduce(backend, $x, axes[0], 'min'); intermediateTensorInfos.forEach(t => backend.disposeIntermediateTensorInfo(t)); return out; } const argMinConfig = { kernelName: tf.ArgMin, backendName: 'webgl', kernelFunc: argMin }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ const ASIN = CHECK_NAN_SNIPPET$1 + ` if (abs(x) > 1.) { return NAN; } return asin(x); `; const asin = unaryKernelFunc({ opSnippet: ASIN }); const asinConfig = { kernelName: tf.Asin, backendName: 'webgl', kernelFunc: asin, }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. 
* Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ const ASINH = CHECK_NAN_SNIPPET$1 + `return log(x + sqrt(x * x + 1.0));`; const asinh = unaryKernelFunc({ opSnippet: ASINH }); const asinhConfig = { kernelName: tf.Asinh, backendName: 'webgl', kernelFunc: asinh, }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ const ATAN = CHECK_NAN_SNIPPET$1 + ` return atan(x); `; const atan = unaryKernelFunc({ opSnippet: ATAN }); const atanConfig = { kernelName: tf.Atan, backendName: 'webgl', kernelFunc: atan, }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ const ATAN2 = CHECK_NAN_SNIPPET + ` return atan(a, b); `; const ATAN2_PACKED = ` vec4 result = atan(a, b); bvec4 isNaNA = isnan(a); bvec4 isNaNB = isnan(b); bvec4 isNaN = bvec4(isNaNA.x || isNaNB.x, isNaNA.y || isNaNB.y, isNaNA.z || isNaNB.z, isNaNA.w || isNaNB.w); ` + CHECK_NAN_SNIPPET_PACKED + ` return result; `; const atan2 = binaryKernelFunc({ opSnippet: ATAN2, packedOpSnippet: ATAN2_PACKED }); const atan2Config = { kernelName: tf.Atan2, backendName: 'webgl', kernelFunc: atan2, }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ const ATANH = CHECK_NAN_SNIPPET$1 + ` if ((x < -1.0) || (x > 1.0)) return NAN; return (log(1.0 + x) - log(1.0 - x)) / 2.0;`; const atanh = unaryKernelFunc({ opSnippet: ATANH }); const atanhConfig = { kernelName: tf.Atanh, backendName: 'webgl', kernelFunc: atanh, }; /** * @license * Copyright 2017 Google LLC. All Rights Reserved. 
* Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ class Pool2DProgram { constructor(convInfo, poolType, computePositions, flattenPositions = false, includeBatchInIndex = false) { this.variableNames = ['x']; if (poolType === 'avg' && computePositions) { throw new Error('Cannot compute positions for average pool.'); } const filterWidth = convInfo.filterWidth; const strideHeight = convInfo.strideHeight; const strideWidth = convInfo.strideWidth; const dilationHeight = convInfo.dilationHeight; const dilationWidth = convInfo.dilationWidth; const effectiveFilterHeight = convInfo.effectiveFilterHeight; const effectiveFilterWidth = convInfo.effectiveFilterWidth; const padTop = convInfo.padInfo.top; const padLeft = convInfo.padInfo.left; this.outputShape = convInfo.outShape; const isAvgPool = poolType === 'avg'; const batchFlattenPositionStr = `((batch * ${convInfo.inHeight} + xR) * ${convInfo.inWidth} + xC) * ${convInfo.inChannels} + d`; const flattenPositionStr = `(xR * ${convInfo.inWidth} + xC) * ${convInfo.inChannels} + d`; let initializationValue = '0.0'; if (!isAvgPool) { // WebGL on Firefox Linux can't compile 1/0 so we do 1/eps. 
initializationValue = '-1.0 / 1e-20'; } if (computePositions) { const compareOp = '>='; this.userCode = ` const ivec2 strides = ivec2(${strideHeight}, ${strideWidth}); const ivec2 pads = ivec2(${padTop}, ${padLeft}); void main() { ivec4 coords = getOutputCoords(); int batch = coords[0]; int d = coords[3]; ivec2 xRCCorner = coords.yz * strides - pads; int xRCorner = xRCCorner.x; int xCCorner = xRCCorner.y; // max/min x(?, ?, d) to get y(yR, yC, d). // ? = to be determined float minMaxValue = 0.0; float minMaxValueFound = 0.0; int minMaxPosition = 0; float avgValue = 0.0; for (int wR = 0; wR < ${effectiveFilterHeight}; wR += ${dilationHeight}) { int xR = xRCorner + wR; if (xR < 0 || xR >= ${convInfo.inHeight}) { continue; } for (int wC = 0; wC < ${effectiveFilterWidth}; wC += ${dilationWidth}) { int xC = xCCorner + wC; if (xC < 0 || xC >= ${convInfo.inWidth}) { continue; } float value = getX(batch, xR, xC, d); // If a min / max value has already been found, use it. If not, // use the current value. float currMinMaxValue = mix( value, minMaxValue, minMaxValueFound); if (value ${compareOp} currMinMaxValue) { minMaxValue = value; minMaxValueFound = 1.0; minMaxPosition = ${flattenPositions ? (includeBatchInIndex ? 
batchFlattenPositionStr : flattenPositionStr) : `wR * ${effectiveFilterWidth} + wC`}; } } } setOutput(float(minMaxPosition)); } `; return; } const compareOp = 'max'; let returnValue = `${poolType}(${poolType}(${poolType}(` + 'minMaxValue[0], minMaxValue[1]), minMaxValue[2]), minMaxValue[3])'; if (poolType === 'avg') { returnValue = `avgValue / max(count, 1.0)`; } const filterWidthNearestVec4 = Math.floor(filterWidth / 4) * 4; const filterWidthVec4Remainder = filterWidth % 4; const updateSnippet = ` if (${isAvgPool}) { avgValue += dot(values, ones); } else { minMaxValue = ${compareOp}(values, minMaxValue); } `; this.userCode = ` const ivec2 strides = ivec2(${strideHeight}, ${strideWidth}); const ivec2 pads = ivec2(${padTop}, ${padLeft}); const float initializationValue = ${initializationValue}; const vec4 ones = vec4(1.0, 1.0, 1.0, 1.0); float count = 0.0; float getValue(int batch, int xR, int xC, int d) { if (xC < 0 || xC >= ${convInfo.inWidth}) { return initializationValue; } count += 1.0; return getX(batch, xR, xC, d); } void main() { ivec4 coords = getOutputCoords(); int batch = coords[0]; int d = coords[3]; ivec2 xRCCorner = coords.yz * strides - pads; int xRCorner = xRCCorner.x; int xCCorner = xRCCorner.y; // max/min x(?, ?, d) to get y(yR, yC, d). // ? 
= to be determined vec4 minMaxValue = vec4(${initializationValue}); float avgValue = 0.0; count = 0.0; for (int wR = 0; wR < ${effectiveFilterHeight}; wR += ${dilationHeight}) { int xR = xRCorner + wR; if (xR < 0 || xR >= ${convInfo.inHeight}) { continue; } for (int wC = 0; wC < ${filterWidthNearestVec4}; wC += 4) { int xC = xCCorner + wC * ${dilationWidth}; vec4 values = vec4( getValue(batch, xR, xC, d), getValue(batch, xR, xC + ${dilationWidth}, d), getValue(batch, xR, xC + 2 * ${dilationWidth}, d), getValue(batch, xR, xC + 3 * ${dilationWidth}, d) ); ${updateSnippet} } int xC = xCCorner + ${filterWidthNearestVec4}; if (${filterWidthVec4Remainder === 1}) { vec4 values = vec4( getValue(batch, xR, xC, d), initializationValue, initializationValue, initializationValue ); ${updateSnippet} } else if (${filterWidthVec4Remainder === 2}) { vec4 values = vec4( getValue(batch, xR, xC, d), getValue(batch, xR, xC + ${dilationWidth}, d), initializationValue, initializationValue ); ${updateSnippet} } else if (${filterWidthVec4Remainder === 3}) { vec4 values = vec4( getValue(batch, xR, xC, d), getValue(batch, xR, xC + ${dilationWidth}, d), getValue(batch, xR, xC + 2 * ${dilationWidth}, d), initializationValue ); ${updateSnippet} } } setOutput(${returnValue}); } `; } } class Pool3DProgram { constructor(convInfo, poolType, computePositions, flattenPositions = false, includeBatchInIndex = false) { this.variableNames = ['x']; if (poolType === 'avg' && computePositions) { throw new Error('Cannot compute positions for average pool.'); } const filterWidth = convInfo.filterWidth; const strideDepth = convInfo.strideDepth; const strideHeight = convInfo.strideHeight; const strideWidth = convInfo.strideWidth; const dilationDepth = convInfo.dilationDepth; const dilationHeight = convInfo.dilationHeight; const dilationWidth = convInfo.dilationWidth; const effectiveFilterDepth = convInfo.effectiveFilterDepth; const effectiveFilterHeight = convInfo.effectiveFilterHeight; const 
effectiveFilterWidth = convInfo.effectiveFilterWidth; const padFront = convInfo.padInfo.front; const padTop = convInfo.padInfo.top; const padLeft = convInfo.padInfo.left; this.outputShape = convInfo.outShape; const isAvgPool = poolType === 'avg'; let initializationValue = '0.0'; if (!isAvgPool) { // WebGL on Firefox Linux can't compile 1/0 so we do 1/eps. initializationValue = '-1.0 / 1e-20'; } if (computePositions) { const compareOp = '>='; this.userCode = ` const ivec3 strides = ivec3(${strideDepth}, ${strideHeight}, ${strideWidth}); const ivec3 pads = ivec3(${padFront}, ${padTop}, ${padLeft}); void main() { ivec5 coords = getOutputCoords(); int batch = coords.x; int ch = coords.u; ivec3 xCorner = ivec3(coords.y, coords.z, coords.w) * strides - pads; int xDCorner = xCorner.x; int xRCorner = xCorner.y; int xCCorner = xCorner.z; // max/min x(?, ?, ?, ch) to get y(yD, yR, yC, ch). // ? = to be determined float minMaxValue = 0.0; float minMaxValueFound = 0.0; int minMaxPosition = 0; for (int wD = 0; wD < ${effectiveFilterDepth}; wD += ${dilationDepth}) { int xD = xDCorner + wD; if (xD < 0 || xD >= ${convInfo.inDepth}) { continue; } for (int wR = 0; wR < ${effectiveFilterHeight}; wR += ${dilationHeight}) { int xR = xRCorner + wR; if (xR < 0 || xR >= ${convInfo.inHeight}) { continue; } for (int wC = 0; wC < ${effectiveFilterWidth}; wC += ${dilationWidth}) { int xC = xCCorner + wC; if (xC < 0 || xC >= ${convInfo.inWidth}) { continue; } float value = getX(batch, xD, xR, xC, ch); // If a min / max value has already been found, use it. If not, // use the current value. float currMinMaxValue = mix( value, minMaxValue, minMaxValueFound); if (value ${compareOp} currMinMaxValue) { minMaxValue = value; minMaxValueFound = 1.0; minMaxPosition = ${flattenPositions ? (includeBatchInIndex ? 
`(((batch * ${convInfo.inDepth} + xD) * ${convInfo.inHeight} + xR) * ${convInfo.inWidth} + xC) * ${convInfo.inChannels} + ch` : `((xD * ${convInfo.inHeight} + xR) * ${convInfo.inWidth} + xC) * ${convInfo.inChannels} + ch`) : `wD * ${effectiveFilterHeight} * ${effectiveFilterWidth} + wR * ${effectiveFilterWidth} + wC`}; } } } } setOutput(float(minMaxPosition)); } `; return; } const compareOp = 'max'; let returnValue = `${poolType}(${poolType}(${poolType}(` + 'minMaxValue[0], minMaxValue[1]), minMaxValue[2]), minMaxValue[3])'; if (poolType === 'avg') { // Use `max(count, 1.0)` instead of `count` in case count === 0.0. // If count === 0.0, `avgValue` is always 0.0 and we change `count`'s // value to avoid dividing zero. returnValue = `avgValue / max(count, 1.0)`; } const filterWidthNearestVec4 = Math.floor(filterWidth / 4) * 4; const filterWidthVec4Remainder = filterWidth % 4; const updateSnippet = ` if (${isAvgPool}) { avgValue += dot(values, ones); } else { minMaxValue = ${compareOp}(values, minMaxValue); } `; this.userCode = ` const ivec3 strides = ivec3(${strideDepth}, ${strideHeight}, ${strideWidth}); const ivec3 pads = ivec3(${padFront}, ${padTop}, ${padLeft}); const float initializationValue = ${initializationValue}; const vec4 ones = vec4(1.0, 1.0, 1.0, 1.0); float count = 0.0; float getValue(int batch, int xD, int xR, int xC, int ch) { if (xC < 0 || xC >= ${convInfo.inWidth}) { return initializationValue; } count += 1.0; return getX(batch, xD, xR, xC, ch); } void main() { ivec5 coords = getOutputCoords(); int batch = coords.x; int ch = coords.u; ivec3 xCorner = ivec3(coords.y, coords.z, coords.w) * strides - pads; int xDCorner = xCorner.x; int xRCorner = xCorner.y; int xCCorner = xCorner.z; // max/min x(?, ?, ?, d) to get y(yD, yR, yC, ch). // ? 
= to be determined vec4 minMaxValue = vec4(${initializationValue}); float avgValue = 0.0; count = 0.0; for (int wD = 0; wD < ${effectiveFilterDepth}; wD += ${dilationDepth}) { int xD = xDCorner + wD; if (xD < 0 || xD >= ${convInfo.inDepth}) { continue; } for (int wR = 0; wR < ${effectiveFilterHeight}; wR += ${dilationHeight}) { int xR = xRCorner + wR; if (xR < 0 || xR >= ${convInfo.inHeight}) { continue; } for (int wC = 0; wC < ${filterWidthNearestVec4}; wC += 4) { int xC = xCCorner + wC * ${dilationWidth}; vec4 values = vec4( getValue(batch, xD, xR, xC, ch), getValue(batch, xD, xR, xC + ${dilationWidth}, ch), getValue(batch, xD, xR, xC + 2 * ${dilationWidth}, ch), getValue(batch, xD, xR, xC + 3 * ${dilationWidth}, ch) ); ${updateSnippet} } int xC = xCCorner + ${filterWidthNearestVec4}; if (${filterWidthVec4Remainder === 1}) { vec4 values = vec4( getValue(batch, xD, xR, xC, ch), initializationValue, initializationValue, initializationValue ); ${updateSnippet} } else if (${filterWidthVec4Remainder === 2}) { vec4 values = vec4( getValue(batch, xD, xR, xC, ch), getValue(batch, xD, xR, xC + ${dilationWidth}, ch), initializationValue, initializationValue ); ${updateSnippet} } else if (${filterWidthVec4Remainder === 3}) { vec4 values = vec4( getValue(batch, xD, xR, xC, ch), getValue(batch, xD, xR, xC + ${dilationWidth}, ch), getValue(batch, xD, xR, xC + 2 * ${dilationWidth}, ch), initializationValue ); ${updateSnippet} } } } setOutput(${returnValue}); } `; } } /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
* See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function avgPool(args) { const { inputs, backend, attrs } = args; const { x } = inputs; assertNotComplex(x, 'avgPool'); const { filterSize, strides, pad, dimRoundingMode } = attrs; const dilations = 1; tf.util.assert(tf.backend_util.eitherStridesOrDilationsAreOne(strides, dilations), () => 'Error in avgPool: Either strides or dilations must be 1. ' + `Got strides ${strides} and dilations '${dilations}'`); const convInfo = tf.backend_util.computePool2DInfo(x.shape, filterSize, strides, dilations, pad, dimRoundingMode); if (convInfo.filterWidth === 1 && convInfo.filterHeight === 1 && tf.util.arraysEqual(convInfo.inShape, convInfo.outShape)) { return identity({ inputs: { x }, backend }); } const avgPoolProgram = new Pool2DProgram(convInfo, 'avg', false); return backend.runWebGLProgram(avgPoolProgram, [x], 'float32'); } const avgPoolConfig = { kernelName: tf.AvgPool, backendName: 'webgl', kernelFunc: avgPool }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ function avgPool3D(args) { const { inputs, backend, attrs } = args; const { x } = inputs; const { filterSize, strides, pad, dimRoundingMode, dataFormat } = attrs; const dilations = [1, 1, 1]; const convInfo = tf.backend_util.computePool3DInfo(x.shape, filterSize, strides, dilations, pad, dimRoundingMode, dataFormat); const avgPoolProgram = new Pool3DProgram(convInfo, 'avg', false); return backend.runWebGLProgram(avgPoolProgram, [x], 'float32'); } const avgPool3DConfig = { kernelName: tf.AvgPool3D, backendName: 'webgl', kernelFunc: avgPool3D }; /** * @license * Copyright 2017 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ class AvgPool2DBackpropProgram { constructor(convInfo) { this.variableNames = ['dy']; this.outputShape = convInfo.inShape; const filterHeight = convInfo.filterHeight; const filterWidth = convInfo.filterWidth; const strideHeight = convInfo.strideHeight; const strideWidth = convInfo.strideWidth; const dilationHeight = convInfo.dilationHeight; const dilationWidth = convInfo.dilationWidth; const effectiveFilterHeight = convInfo.effectiveFilterHeight; const effectiveFilterWidth = convInfo.effectiveFilterWidth; const padTop = effectiveFilterHeight - 1 - convInfo.padInfo.top; const padLeft = effectiveFilterWidth - 1 - convInfo.padInfo.left; const avgMultiplier = 1 / (filterHeight * filterWidth); this.userCode = ` const ivec2 pads = ivec2(${padTop}, ${padLeft}); const float avgMultiplier = float(${avgMultiplier}); void main() { ivec4 coords = getOutputCoords(); int b = coords[0]; int d = coords[3]; ivec2 dyRCCorner = coords.yz - pads; int dyRCorner = dyRCCorner.x; int dyCCorner = dyRCCorner.y; // Convolve dy(?, ?, d) with pos mask(:, :, d) to get dx(xR, xC, d). // ? = to be determined. : = across all values in that axis. 
float dotProd = 0.0; for (int wR = 0; wR < ${effectiveFilterHeight}; wR += ${dilationHeight}) { float dyR = float(dyRCorner + wR) / ${strideHeight}.0; if (dyR < 0.0 || dyR >= ${convInfo.outHeight}.0 || fract(dyR) > 0.0) { continue; } int idyR = int(dyR); for (int wC = 0; wC < ${effectiveFilterWidth}; wC+= ${dilationWidth}) { float dyC = float(dyCCorner + wC) / ${strideWidth}.0; if (dyC < 0.0 || dyC >= ${convInfo.outWidth}.0 || fract(dyC) > 0.0) { continue; } int idyC = int(dyC); float dyValue = getDy(b, idyR, idyC, d); dotProd += dyValue * avgMultiplier; } } setOutput(dotProd); } `; } } class AvgPool3DBackpropProgram { constructor(convInfo) { this.variableNames = ['dy']; this.outputShape = convInfo.inShape; const filterDepth = convInfo.filterDepth; const filterHeight = convInfo.filterHeight; const filterWidth = convInfo.filterWidth; const strideDepth = convInfo.strideDepth; const strideHeight = convInfo.strideHeight; const strideWidth = convInfo.strideWidth; const dilationDepth = convInfo.dilationDepth; const dilationHeight = convInfo.dilationHeight; const dilationWidth = convInfo.dilationWidth; const effectiveFilterDepth = convInfo.effectiveFilterDepth; const effectiveFilterHeight = convInfo.effectiveFilterHeight; const effectiveFilterWidth = convInfo.effectiveFilterWidth; const padFront = effectiveFilterDepth - 1 - convInfo.padInfo.front; const padTop = effectiveFilterHeight - 1 - convInfo.padInfo.top; const padLeft = effectiveFilterWidth - 1 - convInfo.padInfo.left; const avgMultiplier = 1 / (filterDepth * filterHeight * filterWidth); this.userCode = ` const ivec3 pads = ivec3(${padFront}, ${padTop}, ${padLeft}); const float avgMultiplier = float(${avgMultiplier}); void main() { ivec5 coords = getOutputCoords(); int batch = coords.x; int ch = coords.u; ivec3 dyCorner = ivec3(coords.y, coords.z, coords.w) - pads; int dyDCorner = dyCorner.x; int dyRCorner = dyCorner.y; int dyCCorner = dyCorner.z; // Convolve dy(?, ?, ?, d) with pos mask(:, :, :, ch) to get // 
dx(xD, xR, xC, ch). // ? = to be determined. : = across all values in that axis. float dotProd = 0.0; for (int wD = 0; wD < ${effectiveFilterDepth}; wD += ${dilationDepth}) { float dyD = float(dyDCorner + wD) / ${strideDepth}.0; if (dyD < 0.0 || dyD >= ${convInfo.outDepth}.0 || fract(dyD) > 0.0) { continue; } int idyD = int(dyD); for (int wR = 0; wR < ${effectiveFilterHeight}; wR += ${dilationHeight}) { float dyR = float(dyRCorner + wR) / ${strideHeight}.0; if (dyR < 0.0 || dyR >= ${convInfo.outHeight}.0 || fract(dyR) > 0.0) { continue; } int idyR = int(dyR); for (int wC = 0; wC < ${effectiveFilterWidth}; wC += ${dilationWidth}) { float dyC = float(dyCCorner + wC) / ${strideWidth}.0; if (dyC < 0.0 || dyC >= ${convInfo.outWidth}.0 || fract(dyC) > 0.0) { continue; } int idyC = int(dyC); float dyValue = getDy(batch, idyD, idyR, idyC, ch); dotProd += dyValue * avgMultiplier; } } } setOutput(dotProd); } `; } } /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ function avgPool3DGrad(args) { const { inputs, backend, attrs } = args; const { dy, input } = inputs; const x = input; const { filterSize, strides, pad, dimRoundingMode } = attrs; const dilations = [1, 1, 1]; const convInfo = tf.backend_util.computePool3DInfo(x.shape, filterSize, strides, dilations, pad, dimRoundingMode); const avgPoolBackpropProgram = new AvgPool3DBackpropProgram(convInfo); return backend.runWebGLProgram(avgPoolBackpropProgram, [dy], x.dtype); } const avgPool3DGradConfig = { kernelName: tf.AvgPool3DGrad, backendName: 'webgl', kernelFunc: avgPool3DGrad }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function avgPoolGrad(args) { const { inputs, backend, attrs } = args; const { dy, input } = inputs; const x = input; assertNotComplex([dy, input], 'avgPoolGrad'); const { filterSize, strides, pad } = attrs; const convInfo = tf.backend_util.computePool2DInfo(x.shape, filterSize, strides, 1 /* dilations */, pad); const avgPoolBackpropProgram = new AvgPool2DBackpropProgram(convInfo); return backend.runWebGLProgram(avgPoolBackpropProgram, [dy], x.dtype); } const avgPoolGradConfig = { kernelName: tf.AvgPoolGrad, backendName: 'webgl', kernelFunc: avgPoolGrad }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. 
* Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function batchMatMul(args) { const { inputs, backend, attrs } = args; const { a, b } = inputs; const { transposeA, transposeB } = attrs; return batchMatMulImpl({ a, b, transposeA, transposeB, backend }); } const batchMatMulConfig = { kernelName: tf.BatchMatMul, backendName: 'webgl', kernelFunc: batchMatMul, }; /** * @license * Copyright 2017 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ class BatchNormProgram { constructor(xShape, meanShape, varianceShape, offsetShape, scaleShape, varianceEpsilon) { this.outputShape = []; this.variableNames = ['x', 'mean', 'variance']; tf.backend_util.assertAndGetBroadcastShape(xShape, meanShape); tf.backend_util.assertAndGetBroadcastShape(xShape, varianceShape); let offsetSnippet = '0.0'; if (offsetShape != null) { tf.backend_util.assertAndGetBroadcastShape(xShape, offsetShape); this.variableNames.push('offset'); offsetSnippet = 'getOffsetAtOutCoords()'; } let scaleSnippet = '1.0'; if (scaleShape != null) { tf.backend_util.assertAndGetBroadcastShape(xShape, scaleShape); this.variableNames.push('scale'); scaleSnippet = 'getScaleAtOutCoords()'; } this.outputShape = xShape; this.userCode = ` void main() { float x = getXAtOutCoords(); float mean = getMeanAtOutCoords(); float variance = getVarianceAtOutCoords(); float offset = ${offsetSnippet}; float scale = ${scaleSnippet}; float inv = scale * inversesqrt(variance + float(${varianceEpsilon})); setOutput(dot(vec3(x, -mean, offset), vec3(inv, inv, 1))); } `; } } /** * @license * Copyright 2018 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
/**
 * Packed (vec4) variant of the batch-normalization shader program.
 *
 * Emits `(x - mean) * inv + offset` with
 * `inv = scale * inversesqrt(variance + vec4(varianceEpsilon))`. When
 * `offsetShape` / `scaleShape` is null, `vec4(0.0)` / `vec4(1.0)` is used
 * instead of sampling an input. All provided shapes are validated against
 * `xShape` with `assertAndGetBroadcastShape`.
 */
class BatchNormPackedProgram {
  constructor(xShape, meanShape, varianceShape, offsetShape, scaleShape, varianceEpsilon) {
    const checkAgainstX = (shape) => tf.backend_util.assertAndGetBroadcastShape(xShape, shape);
    this.packedInputs = true;
    this.packedOutput = true;
    this.variableNames = ['x', 'mean', 'variance'];
    checkAgainstX(meanShape);
    checkAgainstX(varianceShape);
    // Optional inputs are appended in a fixed order: offset first, then scale.
    const hasOffset = offsetShape != null;
    if (hasOffset) {
      checkAgainstX(offsetShape);
      this.variableNames.push('offset');
    }
    const hasScale = scaleShape != null;
    if (hasScale) {
      checkAgainstX(scaleShape);
      this.variableNames.push('scale');
    }
    const offsetExpr = hasOffset ? 'getOffsetAtOutCoords()' : 'vec4(0.0)';
    const scaleExpr = hasScale ? 'getScaleAtOutCoords()' : 'vec4(1.0)';
    this.outputShape = xShape;
    this.userCode = `
      void main() {
        vec4 offset = ${offsetExpr};
        vec4 scale = ${scaleExpr};

        vec4 x = getXAtOutCoords();
        vec4 mean = getMeanAtOutCoords();
        vec4 variance = getVarianceAtOutCoords();

        vec4 inv = scale * inversesqrt(variance + vec4(${varianceEpsilon}));

        setOutput((x - mean) * inv + offset);
      }
    `;
  }
}
/**
 * WebGL kernel function for `FusedBatchNorm`.
 *
 * Asserts that `variance`, and `offset` / `scale` when present, have the same
 * rank as `mean`, defaults `varianceEpsilon` to 0.001 when it is null or
 * undefined, and runs either the packed or the unpacked batch-norm program
 * depending on the `WEBGL_PACK_NORMALIZATION` flag. The result uses the dtype
 * of `x`.
 *
 * @param inputs `{x, mean, variance, offset?, scale?}`.
 * @param backend Backend exposing `runWebGLProgram`.
 * @param attrs `{varianceEpsilon?}`.
 * @returns The value returned by `backend.runWebGLProgram`.
 */
const batchNorm = ({ inputs, backend, attrs }) => {
  const { x, mean, variance, offset, scale } = inputs;
  const hasOffset = offset != null;
  const hasScale = scale != null;
  tf.util.assert(mean.shape.length === variance.shape.length, () => 'Batch normalization gradient requires mean and variance to have ' +
    'equal ranks.');
  tf.util.assert(!hasOffset || mean.shape.length === offset.shape.length, () => 'Batch normalization gradient requires mean and offset to have ' +
    'equal ranks.');
  tf.util.assert(!hasScale || mean.shape.length === scale.shape.length, () => 'Batch normalization gradient requires mean and scale to have ' +
    'equal ranks.');
  const epsilon = attrs.varianceEpsilon == null ? 0.001 : attrs.varianceEpsilon;
  // The program samples its inputs in this order: x, mean, variance,
  // then offset and scale if they were supplied.
  const programInputs = [x, mean, variance];
  if (hasOffset) {
    programInputs.push(offset);
  }
  if (hasScale) {
    programInputs.push(scale);
  }
  const offsetShape = hasOffset ? offset.shape : null;
  const scaleShape = hasScale ? scale.shape : null;
  const ProgramClass = tf.env().getBool('WEBGL_PACK_NORMALIZATION') ?
    BatchNormPackedProgram :
    BatchNormProgram;
  const program = new ProgramClass(x.shape, mean.shape, variance.shape, offsetShape, scaleShape, epsilon);
  return backend.runWebGLProgram(program, programInputs, programInputs[0].dtype);
};
/** Kernel registration record for `FusedBatchNorm` on the 'webgl' backend. */
const batchNormConfig = {
  kernelName: tf.FusedBatchNorm,
  backendName: 'webgl',
  kernelFunc: batchNorm,
};
/** Component names used to address up to six coordinates in the shader. */
const coords = ['x', 'y', 'z', 'w', 'u', 'v'];
/**
 * Builds the argument list passed to `getSource(...)` in the slice shader.
 *
 * @param rank Rank of the sliced tensor.
 * @returns `'sourceLoc'` for rank 1, otherwise a comma-separated list of
 *     `sourceLoc.<component>` entries, one per dimension.
 * @throws Error when `rank` is not at most 6.
 */
function getCoords$1(rank) {
  if (rank === 1) {
    return 'sourceLoc';
  }
  if (!(rank <= 6)) {
    throw Error(`Slicing for rank ${rank} is not yet supported`);
  }
  const components = coords.slice(0, rank);
  return components.map(component => 'sourceLoc.' + component).join(',');
}
/**
 * Unpacked shader program that copies a slice of `source`.
 *
 * The slice origin is supplied at run time through the `start` int-array
 * uniform (one entry per dimension); the output shape is `destSize`. For each
 * output coordinate the shader adds `start[i]` per dimension and reads
 * `source` at the resulting location.
 */
class SliceProgram {
  constructor(destSize) {
    const rank = destSize.length;
    this.variableNames = ['source'];
    this.outputShape = destSize;
    this.rank = rank;
    const coordsType = getCoordsDataType(rank);
    this.customUniforms = [{ name: 'start', arrayIndex: rank, type: 'int' }];
    const sourceArgs = getCoords$1(rank);
    // One assignment per dimension: sourceLoc.<c> = start[i] + coords.<c>;
    const offsetLines = [];
    for (let i = 0; i < rank; i++) {
      offsetLines.push(`sourceLoc.${coords[i]} = start[${i}] + coords.${coords[i]};`);
    }
    const body = `
        ${coordsType} sourceLoc;
        ${coordsType} coords = getOutputCoords();
        ${offsetLines.join('\n')}
      `;
    this.userCode = `
      void main() {
        ${body}
        setOutput(getSource(${sourceArgs}));
      }
    `;
  }
}
/**
 * Packed (vec4) shader program that copies a slice of `source`.
 *
 * The slice origin comes from the `start` int-array uniform. Each invocation
 * fills up to four result channels: x/y from the current row (y only if the
 * next column is still inside `destSize`), and, for rank > 1, z/w from the
 * next row under the same bounds checks. Channels outside the destination
 * keep the initial value 0.
 */
class SlicePackedProgram {
  constructor(destSize) {
    const rank = destSize.length;
    this.variableNames = ['source'];
    this.packedInputs = true;
    this.packedOutput = true;
    this.outputShape = destSize;
    this.rank = rank;
    this.customUniforms = [{ name: 'start', arrayIndex: rank, type: 'int' }];
    const coordsType = getCoordsDataType(rank);
    const outChannels = getChannels('coords', rank);
    const srcChannels = getChannels('sourceLoc', rank);
    const col = rank - 1;
    const row = rank - 2;
    let innerDims;
    if (rank === 1) {
      innerDims = 'sourceLoc';
    }
    else {
      innerDims = `vec2(${srcChannels.slice(-2).join()})`;
    }
    const readChannel = `getChannel(getSource(${srcChannels.join()}), ${innerDims})`;
    // First row of the texel: current element and its right-hand neighbour.
    const upperRow = `
      result.x = ${readChannel};
      if (++${outChannels[col]} < ${destSize[col]}) {
        ++${srcChannels[col]};
        result.y = ${readChannel};
        --${srcChannels[col]};
      }
    `;
    // Second row of the texel; there is none for rank-1 tensors.
    let lowerRow = '';
    if (rank !== 1) {
      lowerRow = `
      --${outChannels[col]};
      if (++${outChannels[row]} < ${destSize[row]}) {
        ++${srcChannels[row]};
        result.z = ${readChannel};
        if (++${outChannels[col]} < ${destSize[col]}) {
          ++${srcChannels[col]};
          result.w = ${readChannel};
        }
      }
    `;
    }
    // Up to rank 4 the offset is added with one vector expression; higher
    // ranks add the offset one component at a time.
    let sourceLocSetup;
    if (rank <= 4) {
      const startList = destSize.map((_, i) => `start[${i}]`).join();
      sourceLocSetup = `sourceLoc = coords + ${coordsType}(${startList});`;
    }
    else {
      sourceLocSetup = destSize
        .map((_, i) => `${srcChannels[i]} = ${outChannels[i]} + start[${i}];`)
        .join('\n');
    }
    this.userCode = `
      void main() {
        ${coordsType} coords = getOutputCoords();
        ${coordsType} sourceLoc;
        ${sourceLocSetup}
        vec4 result = vec4(0.);
        ${upperRow}
        ${lowerRow}
        setOutput(result);
      }
    `;
  }
}
/**
 * Creates a tensor that views a region of `x` without running a shader.
 *
 * A new tensor info of shape `size` is created and its texture-data entry is
 * filled with a copy of `x`'s entry, with `refCount`, `shape` and `dtype`
 * overridden. The entry also gets a `slice` record holding the flat offset of
 * `begin` inside `x` (accumulated with `x`'s own offset when `x` is itself a
 * slice) and the data id of the original, un-sliced tensor. The ref count
 * stored in `backend.dataRefCount` for that original data id is incremented.
 *
 * @returns The new tensor info.
 */
function shallowSlice(x, begin, size, backend) {
  const xTexData = backend.texData.get(x.dataId);
  const t = backend.makeTensorInfo(size, x.dtype);
  const newTexData = backend.texData.get(t.dataId);
  // Copy texture data from the original tensor, then override the fields
  // that are specific to the new tensor.
  Object.assign(newTexData, xTexData, { refCount: 1, shape: size, dtype: x.dtype });
  const parentSlice = xTexData.slice;
  const ownOffset = tf.slice_util.computeFlatOffset(begin, tf.util.computeStrides(x.shape));
  // Slicing an already sliced tensor accumulates the offsets.
  const flatOffset = parentSlice ? ownOffset + parentSlice.flatOffset : ownOffset;
  // Always point at the original dataId, which is the one used for ref
  // counting.
  const origDataId = parentSlice && parentSlice.origDataId || x.dataId;
  newTexData.slice = { flatOffset, origDataId };
  // Increase the ref count for that data bucket.
  const previousRefs = backend.dataRefCount.get(origDataId) || 1;
  backend.dataRefCount.set(origDataId, previousRefs + 1);
  return t;
}
/**
 * WebGL kernel function for `Slice`.
 *
 * After normalizing and validating `begin` / `size`, one of four paths runs:
 *  1. an empty result is returned directly when the slice has zero elements;
 *  2. the CPU implementation is used when the backend asks for CPU execution
 *     or the dtype is 'string';
 *  3. a slice shader program runs when the input is packed or the slice is
 *     not contiguous;
 *  4. otherwise the data is uploaded and a shallow (offset-only) slice is
 *     created with `shallowSlice`.
 */
function slice(args) {
  const { inputs: { x }, backend, attrs: { begin, size } } = args;
  const [$begin, $size] = tf.slice_util.parseSliceParams(x, begin, size);
  tf.slice_util.assertParamsValid(x, $begin, $size);
  if (tf.util.sizeFromShape($size) === 0) {
    return backend.makeTensorInfo($size, x.dtype, []);
  }
  // Run on cpu if dtype is string. For string, the backend represents it
  // as Uint8Array[], where each Uint8Array is a character. Given that the
  // computation is only on the outer array, uploading the whole data onto
  // gpu is wasteful. Also, currently webgl doesn't have a design to
  // upload and retrieve Uint8Array[] between cpu and gpu. Therefore, we
  // just run the kernel on cpu if dtype is string.
  if (backend.shouldExecuteOnCPU([x]) || x.dtype === 'string') {
    const { values } = backend.texData.get(x.dataId);
    const outValues = sliceImplCPU(values, $begin, $size, x.shape, x.dtype);
    return backend.makeTensorInfo($size, x.dtype, outValues);
  }
  const { isPacked } = backend.texData.get(x.dataId);
  const isContinous = tf.slice_util.isSliceContinous(x.shape, $begin, $size);
  if (!isPacked && isContinous) {
    backend.uploadToGPU(x.dataId);
    return shallowSlice(x, $begin, $size, backend);
  }
  const usePackedProgram = tf.env().getBool('WEBGL_PACK_ARRAY_OPERATIONS');
  const program = usePackedProgram ? new SlicePackedProgram($size) : new SliceProgram($size);
  // The slice origin is fed to the program's `start` uniform.
  const customValues = [$begin];
  return backend.runWebGLProgram(program, [x], x.dtype, customValues);
}
/** Kernel registration record for `Slice` on the 'webgl' backend. */
const sliceConfig = {
  kernelName: tf.Slice,
  backendName: 'webgl',
  kernelFunc: slice
};
/**
 * WebGL kernel function for `BatchToSpaceND`.
 *
 * Implemented as reshape -> transpose -> reshape -> slice, with the shapes,
 * permutation and crop window computed by the `tf.backend_util` helpers.
 * Inputs of rank greater than 4 are rejected.
 *
 * The three intermediate tensors are disposed in a `finally` block, so they
 * are released even when a later step throws (for example when `slice`
 * rejects the crop window). The block-size product is seeded with 1 so that
 * the reduction itself cannot throw on an empty `blockShape`.
 *
 * @param args Kernel arguments: `inputs` ({x}), `attrs`
 *     ({blockShape, crops}) and `backend`.
 * @returns The tensor info produced by the final `slice`.
 */
const batchToSpaceND = (args) => {
  const { inputs, backend, attrs } = args;
  const { x } = inputs;
  const { blockShape, crops } = attrs;
  tf.util.assert(x.shape.length <= 4, () => 'batchToSpaceND for rank > 4 with a WebGL backend not ' +
    'implemented yet');
  const prod = blockShape.reduce((a, b) => a * b, 1);
  const reshaped = tf.backend_util.getReshaped(x.shape, blockShape, prod);
  const permuted = tf.backend_util.getPermuted(reshaped.length, blockShape.length);
  const reshapedPermuted = tf.backend_util.getReshapedPermuted(x.shape, blockShape, prod);
  const sliceBeginCoords = tf.backend_util.getSliceBeginCoords(crops, blockShape.length);
  const sliceSize = tf.backend_util.getSliceSize(reshapedPermuted, crops, blockShape.length);
  // Each intermediate is registered as soon as it exists so that the finally
  // block releases exactly the tensors that were actually created.
  const toDispose = [];
  try {
    const reshapedIntermediate = reshape({ inputs: { x }, backend, attrs: { shape: reshaped } });
    toDispose.push(reshapedIntermediate);
    const transposedIntermediate = transpose({ inputs: { x: reshapedIntermediate }, backend, attrs: { perm: permuted } });
    toDispose.push(transposedIntermediate);
    const reshapedIntermediate2 = reshape({
      inputs: { x: transposedIntermediate },
      backend,
      attrs: { shape: reshapedPermuted }
    });
    toDispose.push(reshapedIntermediate2);
    return slice({
      inputs: { x: reshapedIntermediate2 },
      backend,
      attrs: { begin: sliceBeginCoords, size: sliceSize }
    });
  }
  finally {
    toDispose.forEach(t => backend.disposeIntermediateTensorInfo(t));
  }
};
/** Kernel registration record for `BatchToSpaceND` on the 'webgl' backend. */
const batchToSpaceNDConfig = {
  kernelName: tf.BatchToSpaceND,
  backendName: 'webgl',
  kernelFunc: batchToSpaceND
};
/**
 * WebGL kernel function for `Bincount`.
 *
 * Reads both `x` and `weights` synchronously through `backend.readSync`,
 * computes the result with the CPU implementation `bincountImplCPU`, and
 * wraps it in a new tensor info of shape `[size]` with the dtype of `weights`.
 *
 * @param args Kernel arguments: `inputs` ({x, weights}), `attrs` ({size})
 *     and `backend`.
 */
function bincount(args) {
  const { inputs: { x, weights }, backend, attrs: { size } } = args;
  const xVals = backend.readSync(x.dataId);
  const weightsVals = backend.readSync(weights.dataId);
  const { dtype, shape } = weights;
  const counts = bincountImplCPU(xVals, weightsVals, dtype, shape, size);
  return backend.makeTensorInfo([size], dtype, counts);
}
/** Kernel registration record for `Bincount` on the 'webgl' backend. */
const bincountConfig = {
  kernelName: tf.Bincount,
  backendName: 'webgl',
  kernelFunc: bincount
};
/** Packed shader body: bitwise AND applied to each of the four channels. */
const BITWISEAND = `
  int r = int(a.r) & int(b.r);
  int g = int(a.g) & int(b.g);
  int rb = int(a.b) & int(b.b);
  int ra = int(a.a) & int(b.a);
  return vec4(r, g, rb, ra);
`;
/** Unpacked shader body: bitwise AND of the first channel only. */
const BITWISEAND_UNPACKED = `
  return float(int(a.r) & int(b.r));
`;
/**
 * WebGL kernel function for `BitwiseAnd`.
 *
 * Falls back to the CPU implementation when the backend requests CPU
 * execution or when `WEBGL_VERSION` is 1; otherwise runs the packed or
 * unpacked binary-op program depending on `WEBGL_PACK_BINARY_OPERATIONS`.
 *
 * On the CPU path the operand values are taken from the texture-data entry
 * when present. The WebGL1 fallback can be reached for operands that are not
 * CPU-resident, in which case `values` is empty, so it falls back to
 * `backend.readSync` instead of passing a missing array to the CPU
 * implementation.
 * NOTE(review): assumes `readSync` returns the same flat values array that
 * `texData.values` would hold — confirm against the backend.
 *
 * @param args Kernel arguments: `inputs` ({a, b}) and `backend`.
 * @returns A tensor info with the dtype of `a`.
 */
function bitwiseAnd(args) {
  const { inputs, backend } = args;
  const { a, b } = inputs;
  const shouldUsePackedProgram = tf.env().getBool('WEBGL_PACK_BINARY_OPERATIONS');
  const versionNumber = tf.env().getNumber('WEBGL_VERSION');
  // The type of a and b are ensured to be `int32` in core, therefore no need to
  // consider other type situations.
  if ((backend.shouldExecuteOnCPU([a, b])) || versionNumber === 1) {
    const readValues = (tensorInfo) => {
      const { values } = backend.texData.get(tensorInfo.dataId);
      return values != null ? values : backend.readSync(tensorInfo.dataId);
    };
    const aVals = readValues(a);
    const bVals = readValues(b);
    const [outValues, outShape] = bitwiseAndImplCPU(a.shape, b.shape, aVals, bVals, a.dtype);
    const out = backend.makeTensorInfo(outShape, a.dtype);
    const outData = backend.texData.get(out.dataId);
    outData.values = outValues;
    return out;
  }
  const program = shouldUsePackedProgram ?
    new BinaryOpPackedProgram(BITWISEAND, a.shape, b.shape, false) :
    new BinaryOpProgram(BITWISEAND_UNPACKED, a.shape, b.shape);
  return backend.runWebGLProgram(program, [a, b], a.dtype);
}
/** Kernel registration record for `BitwiseAnd` on the 'webgl' backend. */
const bitwiseAndConfig = {
  kernelName: tf.BitwiseAnd,
  backendName: 'webgl',
  kernelFunc: bitwiseAnd
};
/**
 * WebGL kernel function for `BroadcastArgs`.
 *
 * Reads the two shape tensors synchronously, computes their broadcast shape
 * with `assertAndGetBroadcastShape`, and returns it as a 1-D 'int32' tensor
 * info.
 *
 * @param args Kernel arguments: `inputs` ({s0, s1}) and `backend`.
 */
function broadcastArgs(args) {
  const { inputs: { s0, s1 }, backend } = args;
  const [shape0, shape1] = [s0, s1].map(s => Array.from(backend.readSync(s.dataId)));
  const resultShape = tf.backend_util.assertAndGetBroadcastShape(shape0, shape1);
  return backend.makeTensorInfo([resultShape.length], 'int32', Int32Array.from(resultShape));
}
/** Kernel registration record for `BroadcastArgs` on the 'webgl' backend. */
const broadcastArgsConfig = {
  kernelName: tf.BroadcastArgs,
  backendName: 'webgl',
  kernelFunc: broadcastArgs
};

/**
 * @license
 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
/** Shader body for element-wise inequality: yields 1.0 when a != b, else 0.0. */
const NOT_EQUAL = `return float(a != b);`;
/**
 * `NotEqual` kernel function, built by `binaryKernelFunc` from the shader
 * snippet above and the CPU implementation; the result dtype is 'bool'.
 */
const notEqual = binaryKernelFunc({
  opSnippet: NOT_EQUAL,
  cpuKernelImpl: notEqualImplCPU,
  dtype: 'bool',
});
/** Kernel registration record for `NotEqual` on the 'webgl' backend. */
const notEqualConfig = {
  kernelName: tf.NotEqual,
  backendName: 'webgl',
  kernelFunc: notEqual,
};
/**
 * WebGL kernel function for `Real`.
 *
 * Looks up the texture-data entry of the complex input and returns
 * `identity` applied to its stored real component.
 *
 * @param args Kernel arguments: `inputs` ({input}) and `backend`.
 */
function real(args) {
  const { inputs: { input }, backend } = args;
  const { complexTensorInfos } = backend.texData.get(input.dataId);
  return identity({ inputs: { x: complexTensorInfos.real }, backend });
}
/** Kernel registration record for `Real` on the 'webgl' backend. */
const realConfig = {
  kernelName: tf.Real,
  backendName: 'webgl',
  kernelFunc: real
};

/**
 * @license
 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
/** Shader body that converts a value to int and back to float. */
const TO_INT = `return float(int(x));`;
/**
 * Runs the `TO_INT` unary program over `input` and returns a plain
 * `{dataId, shape, dtype}` record for the 'int32' result.
 *
 * @param input Tensor info to convert.
 * @param backend Backend exposing `runWebGLProgram`.
 */
function int(input, backend) {
  const toIntProgram = new UnaryOpProgram(input.shape, TO_INT);
  const { dataId, shape, dtype } = backend.runWebGLProgram(toIntProgram, [input], 'int32');
  return { dataId, shape, dtype };
}
/**
 * WebGL kernel function for `Cast`.
 *
 * Paths, in order:
 *  - to 'complex64': identity for complex input; otherwise the input is cast
 *    to 'float32' and combined with a zeros tensor as the imaginary part;
 *  - from 'complex64': the real part is extracted and cast recursively;
 *  - no encoding loss: the data is reused through `identity` and only the
 *    dtype on the returned record changes;
 *  - CPU execution when the backend requests it;
 *  - 'int32' through the `int` helper, 'bool' through `notEqual` against a
 *    zero-initialized scalar.
 *
 * @param args Kernel arguments: `inputs` ({x}), `attrs` ({dtype}) and
 *     `backend`.
 * @throws Error when no path handles the requested conversion.
 */
function cast(args) {
  const { inputs: { x }, backend, attrs: { dtype } } = args;
  const sourceIsComplex = x.dtype === 'complex64';
  // Casting to complex64.
  if (dtype === 'complex64') {
    if (sourceIsComplex) {
      return identity({ inputs: { x }, backend });
    }
    // TODO(annxingyuan): Import kernel function once zeros is modularized.
    const zerosTensor = tf__namespace.zeros(x.shape);
    const floatX = cast({ inputs: { x }, backend, attrs: { dtype: 'float32' } });
    const complexResult = complex({ inputs: { real: floatX, imag: zerosTensor }, backend });
    zerosTensor.dispose();
    backend.disposeIntermediateTensorInfo(floatX);
    return complexResult;
  }
  // Casting from complex64
  if (sourceIsComplex) {
    const realPart = real({ inputs: { input: x }, backend });
    const castReal = cast({ inputs: { x: realPart }, backend, attrs: { dtype } });
    backend.disposeIntermediateTensorInfo(realPart);
    return castReal;
  }
  if (!tf.util.hasEncodingLoss(x.dtype, dtype)) {
    // We don't change the underlying data, since we cast to higher
    // precision.
    const { dataId, shape } = identity({ inputs: { x }, backend });
    return { dataId, shape, dtype };
  }
  if (backend.shouldExecuteOnCPU([x])) {
    const { values } = backend.texData.get(x.dataId);
    const [resultShape, resultType, resultData] = castImplCPU(values, x.shape, x.dtype, dtype);
    return backend.makeTensorInfo(resultShape, resultType, resultData);
  }
  switch (dtype) {
    case 'int32':
      return int(x, backend);
    case 'bool': {
      // x != 0, computed against a scalar backed by a fresh 1-element array.
      const zerosTensorInfo = backend.makeTensorInfo([], 'bool', tf.util.getTypedArrayFromDType('bool', 1));
      const boolResult = notEqual({ inputs: { a: x, b: zerosTensorInfo }, backend });
      backend.disposeIntermediateTensorInfo(zerosTensorInfo);
      return boolResult;
    }
    default:
      throw new Error(`Error in Cast: failed to cast ${x.dtype} to ${dtype}`);
  }
}
/** Kernel registration record for `Cast` on the 'webgl' backend. */
const castConfig = {
  kernelName: tf.Cast,
  backendName: 'webgl',
  kernelFunc: cast
};

/**
 * @license
 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
/** Shader body for `Ceil`; the same snippet serves the packed program. */
const CEIL = `return ceil(x);`;
/** `Ceil` kernel function built by `unaryKernelFunc`, with a CPU fallback. */
const ceil = unaryKernelFunc({
  opSnippet: CEIL,
  packedOpSnippet: CEIL,
  cpuKernelImpl: ceilImplCPU,
});
/** Kernel registration record for `Ceil` on the 'webgl' backend. */
const ceilConfig = {
  kernelName: tf.Ceil,
  backendName: 'webgl',
  kernelFunc: ceil
};
/**
 * Unpacked shader program that clamps every element of `A` to
 * `[minVal, maxVal]`.
 *
 * The bounds are float uniforms supplied at run time. NaN inputs are written
 * to the output unchanged instead of being clamped.
 */
class ClipProgram {
  constructor(aShape) {
    this.variableNames = ['A'];
    this.customUniforms = ['minVal', 'maxVal'].map(name => ({ name, type: 'float' }));
    this.outputShape = aShape;
    const shaderSource = `

      void main() {
        float value = getAAtOutCoords();
        if (isnan(value)) {
          setOutput(value);
          return;
        }

        setOutput(clamp(value, minVal, maxVal));
      }
    `;
    this.userCode = shaderSource;
  }
}
/**
 * Packed (vec4) shader program that clamps every element of `A` to
 * `[minVal, maxVal]`.
 *
 * The bounds are float uniforms supplied at run time. NaN handling is done
 * per lane: a lane holding NaN is written through unchanged, while all other
 * lanes of the same texel are still clamped. An early return for the whole
 * texel whenever any lane is NaN would leave the finite neighbours of a NaN
 * unclamped, which the unpacked `ClipProgram` never does.
 */
class ClipPackedProgram {
  constructor(aShape) {
    this.variableNames = ['A'];
    this.packedInputs = true;
    this.packedOutput = true;
    this.customUniforms = [
      { name: 'minVal', type: 'float' },
      { name: 'maxVal', type: 'float' }
    ];
    this.outputShape = aShape;
    this.userCode = `
      void main() {
        vec4 value = getAAtOutCoords();
        vec4 clamped = clamp(value, vec4(minVal), vec4(maxVal));
        bvec4 nanLanes = isnan(value);

        setOutput(vec4(
          nanLanes.r ? value.r : clamped.r,
          nanLanes.g ? value.g : clamped.g,
          nanLanes.b ? value.b : clamped.b,
          nanLanes.a ? value.a : clamped.a));
      }
    `;
  }
}

/**
 * @license
 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
/**
 * WebGL kernel function for `ClipByValue`.
 *
 * Chooses the packed or unpacked clip program based on the `WEBGL_PACK_CLIP`
 * flag and passes the two bounds as the program's custom uniform values
 * (`minVal` first, then `maxVal`). The result keeps the dtype of `x`.
 *
 * @param args Kernel arguments: `inputs` ({x}), `attrs`
 *     ({clipValueMin, clipValueMax}) and `backend`.
 */
function clipByValue(args) {
  const { inputs, backend, attrs } = args;
  const { x } = inputs;
  const { clipValueMin, clipValueMax } = attrs;
  const program = tf.env().getBool('WEBGL_PACK_CLIP') ?
    new ClipPackedProgram(x.shape) :
    new ClipProgram(x.shape);
  const customValues = [[clipValueMin], [clipValueMax]];
  return backend.runWebGLProgram(program, [x], x.dtype, customValues);
}
/** Kernel registration record for `ClipByValue` on the 'webgl' backend. */
const clipByValueConfig = {
  kernelName: tf.ClipByValue,
  backendName: 'webgl',
  kernelFunc: clipByValue
};
/**
 * Shader program computing the magnitude of a complex tensor from its
 * separate `real` and `imag` inputs.
 *
 * The larger of |re| and |im| is factored out before calling `length`, and a
 * zero maximum short-circuits to 0.0; the shader's own comments give the
 * reason (underflow safety of `length`).
 */
class ComplexAbsProgram {
  constructor(shape) {
    this.variableNames = ['real', 'imag'];
    this.outputShape = shape;
    const shaderSource = `
      void main() {
        float re = abs(getRealAtOutCoords());
        float im = abs(getImagAtOutCoords());
        float mx = max(re, im);

        // sadly the length function in glsl is not underflow-safe
        // (at least not on Intel GPUs). So the safe solution is
        // to ensure underflow-safety in all cases.
        setOutput(
          mx == 0.0 ? 0.0 : mx * length(vec2(1, min(re, im)/mx))
        );
      }
    `;
    this.userCode = shaderSource;
  }
}

/**
 * @license
 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
// Returns a TensorInfo with the complex shape and the dataId of the
// underlying part. We need to do this because a reshaped complex tensor is
// not reflected in its parts.
function makeComplexComponentTensorInfo(complexTensor, complexPart) { return { dataId: complexPart.dataId, dtype: complexPart.dtype, shape: complexTensor.shape }; } function complexAbs(args) { const { inputs, backend } = args; const { x } = inputs; const xData = backend.texData.get(x.dataId); const program = new ComplexAbsProgram(x.shape); const programInputs = [ makeComplexComponentTensorInfo(x, xData.complexTensorInfos.real), makeComplexComponentTensorInfo(x, xData.complexTensorInfos.imag), ]; return backend.runWebGLProgram(program, programInputs, programInputs[0].dtype); } const complexAbsConfig = { kernelName: tf.ComplexAbs, backendName: 'webgl', kernelFunc: complexAbs }; /** * @license * Copyright 2017 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ class ConcatProgram { // Concats 2d tensors along axis=1. See comments in MathBackendWebGL.concat(). 
constructor(shapes) { this.outputShape = []; this.outputShape = tf.backend_util.computeOutShape(shapes, 1 /* axis */); this.variableNames = shapes.map((_, i) => `T${i}`); const offsets = new Array(shapes.length - 1); offsets[0] = shapes[0][1]; for (let i = 1; i < offsets.length; i++) { offsets[i] = offsets[i - 1] + shapes[i][1]; } const snippets = [`if (yC < ${offsets[0]}) setOutput(getT0(yR, yC));`]; for (let i = 1; i < offsets.length; i++) { const shift = offsets[i - 1]; snippets.push(`else if (yC < ${offsets[i]}) ` + `setOutput(getT${i}(yR, yC-${shift}));`); } const lastIndex = offsets.length; const lastShift = offsets[offsets.length - 1]; snippets.push(`else setOutput(getT${lastIndex}(yR, yC-${lastShift}));`); this.userCode = ` void main() { ivec2 coords = getOutputCoords(); int yR = coords.x; int yC = coords.y; ${snippets.join('\n ')} } `; } } /** * @license * Copyright 2019 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ class ConcatPackedProgram { constructor(shapes, axis) { this.packedInputs = true; this.packedOutput = true; this.outputShape = []; this.outputShape = tf.backend_util.computeOutShape(shapes, axis); const shape = this.outputShape; const rank = shape.length; const dtype = getCoordsDataType(rank); const coords = getChannels('coords', rank); const channels = ['x', 'y', 'z', 'w', 'u', 'v'].slice(0, rank); this.variableNames = shapes.map((_, i) => `T${i}`); const offsets = new Array(shapes.length - 1); offsets[0] = shapes[0][axis]; for (let i = 1; i < offsets.length; i++) { offsets[i] = offsets[i - 1] + shapes[i][axis]; } const channel = channels[axis]; const lastChannels = channels.slice(-2); const allChannels = channels.join(); let getValueSnippet = `if (${channel} < ${offsets[0]}) { return getChannel( getT0(${allChannels}), vec2(${lastChannels.join()})); }`; for (let i = 1; i < offsets.length; i++) { const shift = offsets[i - 1]; // Note: the >= comparison below may seem unnecessary given the check // above but is needed to workaround branch execution issues on some // devices. It makes all the conditions exclusive without relying on // execution order. 
getValueSnippet += ` if (${channel} < ${offsets[i]} && ${channel} >= ${offsets[i - 1]}) { return getChannel( getT${i}(${shiftedChannels(channels, channel, shift)}), vec2(${shiftedChannels(lastChannels, channel, shift)})); }`; } const lastIndex = offsets.length; const shift = offsets[offsets.length - 1]; getValueSnippet += ` return getChannel( getT${lastIndex}(${shiftedChannels(channels, channel, shift)}), vec2(${shiftedChannels(lastChannels, channel, shift)}));`; this.userCode = ` float getValue(${channels.map(x => 'int ' + x)}) { ${getValueSnippet} } void main() { ${dtype} coords = getOutputCoords(); vec4 result = vec4(getValue(${coords}), 0., 0., 0.); ${coords[rank - 1]} = ${coords[rank - 1]} + 1; if (${coords[rank - 1]} < ${shape[rank - 1]}) { result.g = getValue(${coords}); } ${coords[rank - 2]} = ${coords[rank - 2]} + 1; if (${coords[rank - 2]} < ${shape[rank - 2]}) { result.a = getValue(${coords}); } ${coords[rank - 1]} = ${coords[rank - 1]} - 1; if (${coords[rank - 2]} < ${shape[rank - 2]} && ${coords[rank - 1]} < ${shape[rank - 1]}) { result.b = getValue(${coords}); } setOutput(result); } `; } } /** * Return an expression for coordinates into a vector where a given channel * will be offset by [shift]. * * @param channels the channels to consider * @param channel the channel we want shifted * @param shift the amount to subtract from the channel. * * @returns a string of the form 'x, y-[shift], z' where any one channel can * have the shift applied. */ function shiftedChannels(channels, channel, shift) { const channelIdx = channels.indexOf(channel); const res = channels.map((c, idx) => { if (idx === channelIdx) { return `${c} - ${shift}`; } else { return c; } }); return res.join(); } /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function imag(args) { const { inputs, backend } = args; const { input } = inputs; const inputData = backend.texData.get(input.dataId); return identity({ inputs: { x: inputData.complexTensorInfos.imag }, backend }); } const imagConfig = { kernelName: tf.Imag, backendName: 'webgl', kernelFunc: imag }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ function concatImpl(inputs, axis, backend) { const dtype = inputs[0].dtype; if (dtype === 'complex64') { const reals = inputs.map((t) => real({ inputs: { input: t }, backend })); const imags = inputs.map((t) => imag({ inputs: { input: t }, backend })); const realConcated = concatImpl(reals, axis, backend); const imagConcated = concatImpl(imags, axis, backend); const result = complex({ inputs: { real: realConcated, imag: imagConcated }, backend }); reals.forEach(r => backend.disposeIntermediateTensorInfo(r)); imags.forEach(i => backend.disposeIntermediateTensorInfo(i)); backend.disposeIntermediateTensorInfo(realConcated); backend.disposeIntermediateTensorInfo(imagConcated); return result; } let runOnCpu = backend.shouldExecuteOnCPU(inputs); // Run on cpu if dtype is string. For string, the backend represents it // as Uint8Array[], where each Uint8Array is a character. Given that the // computation is only on the outer array, uploading the whole data onto // gpu is wasteful. Also, currently webgl doesn't have a design to // upload and retrieve Uint8Array[] between cpu and gpu. Therefore, we // just run the kernel on cpu if dtype is string. if (dtype === 'string') { runOnCpu = true; } if (runOnCpu) { // Any concat of n-dimensional tensors across any axis can be reduced to // a concatenation of two-dimensional tensors across the axis 1 by first // partitioning the axes of the original tensors into those less than the // axis to be concatenated and the rest. Then reshape the tensors // into a two-dimensional tensor by collapsing these two sets of axes and // concatenate the resulting matrices across the axis 1, finally reshaping // the result to have the proper shape. 
const tensors2D = inputs.map(t => { const innerSize = tf.util.sizeFromShape(t.shape.slice(axis)); const shape = [-1, innerSize]; return reshape({ inputs: { x: t }, backend, attrs: { shape } }); }); const inputsValShapes = tensors2D.map(t => { return { vals: backend.readSync(t.dataId), shape: t.shape }; }); // Concats 2d tensors along axis=1. const outShape = tf.backend_util.computeOutShape(tensors2D.map(t => t.shape), 1 /* axis */); const simplyConcat = tensors2D[0].shape[0] === 1; const outVals = concatImplCPU(inputsValShapes, outShape, dtype, simplyConcat); const finalOutShape = tf.backend_util.computeOutShape(inputs.map(t => t.shape), axis); const outInfo = backend.makeTensorInfo(finalOutShape, dtype, outVals); tensors2D.forEach(t => backend.disposeIntermediateTensorInfo(t)); return outInfo; } // Keep only non-empty tensors (ignore tensors with 0 in their shape). const $inputs = inputs.filter(t => tf.util.sizeFromShape(t.shape) > 0); const shouldPack = tf.env().getBool('WEBGL_PACK_ARRAY_OPERATIONS') && $inputs[0].shape.length > 1; if ($inputs.length === 1) { // Clone tensor. const program = shouldPack ? 
new UnaryOpProgram(inputs[0].shape, CLONE) : new UnaryOpPackedProgram(inputs[0].shape, CLONE); return backend.runWebGLProgram(program, inputs, dtype); } const maxTexturesInShader = tf.env().getNumber('WEBGL_MAX_TEXTURES_IN_SHADER'); if ($inputs.length > maxTexturesInShader) { const reducedInputs = []; for (let i = 0; i < $inputs.length; i += maxTexturesInShader) { const subArray = $inputs.slice(i, i + maxTexturesInShader); reducedInputs.push(concatImpl(subArray, axis, backend)); } const result = concatImpl(reducedInputs, axis, backend); for (const i of reducedInputs) { backend.disposeIntermediateTensorInfo(i); } return result; } if (shouldPack) { const program = new ConcatPackedProgram($inputs.map(t => t.shape), axis); return backend.runWebGLProgram(program, $inputs, dtype); } const { tensors2D, outShape } = computeTensors2D($inputs, axis, backend); const program = new ConcatProgram(tensors2D.map(t => t.shape)); const result = backend.runWebGLProgram(program, tensors2D, dtype); tensors2D.forEach(r => backend.disposeIntermediateTensorInfo(r)); const reshapedResult = reshape({ inputs: { x: result }, attrs: { shape: outShape }, backend }); backend.disposeIntermediateTensorInfo(result); return reshapedResult; } function computeTensors2D(inputs, axis, backend) { // Any concat of n-dimensional tensors across any axis can be reduced to // a concatenation of two-dimensional tensors across the axis 1 by first // partitioning the axes of the original tensors into those less than the // axis to be concatenated and the rest. Then reshape the tensors // into a two-dimensional tensor by collapsing these two sets of axes and // concatenate the resulting matrices across the axis 1, finally reshaping // the result to have the proper shape. 
const outShape = tf.backend_util.computeOutShape(inputs.map(t => t.shape), axis); const tensors2D = inputs.map(x => reshape({ inputs: { x }, attrs: { shape: [-1, tf.util.sizeFromShape(x.shape.slice(axis))] }, backend })); return { tensors2D, outShape }; } /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function concat(args) { const { inputs, backend, attrs } = args; const { axis } = attrs; const $axis = tf.util.parseAxisParam(axis, inputs[0].shape)[0]; const shapes = inputs.map(t => t.shape); tf.backend_util.assertParamsConsistent(shapes, $axis); const outShape = tf.backend_util.computeOutShape(inputs.map(t => t.shape), $axis); if (tf.util.sizeFromShape(outShape) === 0) { return backend.makeTensorInfo(outShape, inputs[0].dtype, []); } // Keep only non-empty tensors (ignore tensors with 0 in their shape). const $inputs = inputs.filter(t => tf.util.sizeFromShape(t.shape) > 0); if ($inputs.length === 1) { return identity({ inputs: { x: $inputs[0] }, backend }); } return concatImpl($inputs, $axis, backend); } const concatConfig = { kernelName: tf.Concat, backendName: 'webgl', kernelFunc: concat }; /** * @license * Copyright 2017 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
/**
 * Generates the GLSL source for a 2D convolution of input `x` with filter
 * `W`, optionally followed by a bias add and an activation.
 *
 * All convolution parameters (strides, pads, dilations, filter and input
 * sizes) are read from `convInfo` and baked into the shader source as
 * literals. Both the 'channelsLast' data format and the alternative layout
 * are handled.
 */
class Conv2DProgram {
  /**
   * @param convInfo Convolution description; the fields read here are
   *     `outShape`, `padInfo`, strides, dilations, filter sizes,
   *     `inChannels`, `inHeight`, `inWidth` and `dataFormat`.
   * @param addBias When true, a `bias` input is added to the result.
   * @param activation GLSL body of the activation function, or null.
   * @param hasPreluActivationWeights When true, a `preluActivationWeights`
   *     input is registered and made available to the activation as `b`.
   * @param hasLeakyreluAlpha When true, a `leakyreluAlpha` input is
   *     registered and made available to the activation as `b`.
   */
  constructor(convInfo, addBias = false, activation = null, hasPreluActivationWeights = false, hasLeakyreluAlpha = false) {
    this.variableNames = ['x', 'W'];
    this.outputShape = convInfo.outShape;
    const padTop = convInfo.padInfo.top;
    const padLeft = convInfo.padInfo.left;
    const strideHeight = convInfo.strideHeight;
    const strideWidth = convInfo.strideWidth;
    const dilationHeight = convInfo.dilationHeight;
    const dilationWidth = convInfo.dilationWidth;
    const filterHeight = convInfo.filterHeight;
    const filterWidth = convInfo.filterWidth;
    // Input channels are consumed four at a time; the 1-3 leftover channels
    // are handled by dedicated branches in the shader below.
    const inputDepthNearestVec4 = Math.floor(convInfo.inChannels / 4) * 4;
    const inputDepthVec4Remainder = convInfo.inChannels % 4;
    // Positions of the row, column and channel dimensions within the output
    // coordinates for the two supported data formats.
    const isChannelsLast = convInfo.dataFormat === 'channelsLast';
    const rowDim = isChannelsLast ? 1 : 2;
    const colDim = isChannelsLast ? 2 : 3;
    const channelDim = isChannelsLast ? 3 : 1;
    let activationSnippet = '', applyActivationSnippet = '';
    if (activation) {
      // With PReLU weights or a leaky-ReLU alpha, the activation function
      // takes its argument as `a` and has the extra value in scope as `b`;
      // otherwise the argument is named `x`.
      if (hasPreluActivationWeights) {
        activationSnippet = `float activation(float a) {
          float b = getPreluActivationWeightsAtOutCoords();
          ${activation}
        }`;
      }
      else if (hasLeakyreluAlpha) {
        activationSnippet = `float activation(float a) {
          float b = getLeakyreluAlphaAtOutCoords();
          ${activation}
        }`;
      }
      else {
        activationSnippet = `
          float activation(float x) {
            ${activation}
          }
        `;
      }
      applyActivationSnippet = `result = activation(result);`;
    }
    const addBiasSnippet = addBias ? 'result += getBiasAtOutCoords();' : '';
    // Optional inputs are appended after 'x' and 'W' in this fixed order.
    if (addBias) {
      this.variableNames.push('bias');
    }
    if (hasPreluActivationWeights) {
      this.variableNames.push('preluActivationWeights');
    }
    if (hasLeakyreluAlpha) {
      this.variableNames.push('leakyreluAlpha');
    }
    this.userCode = `
      ${activationSnippet}

      const ivec2 strides = ivec2(${strideHeight}, ${strideWidth});
      const ivec2 pads = ivec2(${padTop}, ${padLeft});

      void main() {
        ivec4 coords = getOutputCoords();
        int batch = coords[0];
        int d2 = coords[${channelDim}];

        ivec2 xRCCorner =
            ivec2(coords[${rowDim}], coords[${colDim}]) * strides - pads;
        int xRCorner = xRCCorner.x;
        int xCCorner = xRCCorner.y;

        // Convolve x(?, ?, d1) with w(:, :, d1, d2) to get y(yR, yC, d2).
        // ? = to be determined. : = across all values in that axis.
        float dotProd = 0.0;
        for (int wR = 0; wR < ${filterHeight}; wR++) {
          int xR = xRCorner + wR * ${dilationHeight};

          if (xR < 0 || xR >= ${convInfo.inHeight}) {
            continue;
          }

          for (int wC = 0; wC < ${filterWidth}; wC++) {
            int xC = xCCorner + wC * ${dilationWidth};

            if (xC < 0 || xC >= ${convInfo.inWidth}) {
              continue;
            }

            for (int d1 = 0; d1 < ${inputDepthNearestVec4}; d1 += 4) {
              vec4 wValues = vec4(
                getW(wR, wC, d1, d2),
                getW(wR, wC, d1 + 1, d2),
                getW(wR, wC, d1 + 2, d2),
                getW(wR, wC, d1 + 3, d2)
              );

              if (${isChannelsLast}) {
                vec4 xValues = vec4(
                  getX(batch, xR, xC, d1),
                  getX(batch, xR, xC, d1 + 1),
                  getX(batch, xR, xC, d1 + 2),
                  getX(batch, xR, xC, d1 + 3)
                );
                dotProd += dot(xValues, wValues);
              } else {
                vec4 xValues = vec4(
                  getX(batch, d1, xR, xC),
                  getX(batch, d1 + 1, xR, xC),
                  getX(batch, d1 + 2, xR, xC),
                  getX(batch, d1 + 3, xR, xC)
                );
                dotProd += dot(xValues, wValues);
              }
            }

            if (${inputDepthVec4Remainder === 1}) {

              if (${isChannelsLast}) {
                dotProd +=
                    getX(batch, xR, xC, ${inputDepthNearestVec4}) *
                    getW(wR, wC, ${inputDepthNearestVec4}, d2);
              } else {
                dotProd +=
                    getX(batch, ${inputDepthNearestVec4}, xR, xC) *
                    getW(wR, wC, ${inputDepthNearestVec4}, d2);
              }

            } else if (${inputDepthVec4Remainder === 2}) {
              vec2 wValues = vec2(
                getW(wR, wC, ${inputDepthNearestVec4}, d2),
                getW(wR, wC, ${inputDepthNearestVec4} + 1, d2)
              );

              if (${isChannelsLast}) {
                vec2 xValues = vec2(
                  getX(batch, xR, xC, ${inputDepthNearestVec4}),
                  getX(batch, xR, xC, ${inputDepthNearestVec4} + 1)
                );
                dotProd += dot(xValues, wValues);
              } else {
                vec2 xValues = vec2(
                  getX(batch, ${inputDepthNearestVec4}, xR, xC),
                  getX(batch, ${inputDepthNearestVec4} + 1, xR, xC)
                );
                dotProd += dot(xValues, wValues);
              }

            } else if (${inputDepthVec4Remainder === 3}) {
              vec3 wValues = vec3(
                getW(wR, wC, ${inputDepthNearestVec4}, d2),
                getW(wR, wC, ${inputDepthNearestVec4} + 1, d2),
                getW(wR, wC, ${inputDepthNearestVec4} + 2, d2)
              );

              if (${isChannelsLast}) {
                vec3 xValues = vec3(
                  getX(batch, xR, xC, ${inputDepthNearestVec4}),
                  getX(batch, xR, xC, ${inputDepthNearestVec4} + 1),
                  getX(batch, xR, xC, ${inputDepthNearestVec4} + 2)
                );
                dotProd += dot(xValues, wValues);
              } else {
                vec3 xValues = vec3(
                  getX(batch, ${inputDepthNearestVec4}, xR, xC),
                  getX(batch, ${inputDepthNearestVec4} + 1, xR, xC),
                  getX(batch, ${inputDepthNearestVec4} + 2, xR, xC)
                );
                dotProd += dot(xValues, wValues);
              }

            }
          }
        }

        float result = dotProd;
        ${addBiasSnippet}
        ${applyActivationSnippet}
        setOutput(result);
      }
    `;
  }
}
/**
 * Generates the GLSL source for a 3D convolution of input `x` with filter
 * `W`. The shader indexes the input as (batch, depth, row, col, channel)
 * and takes the output channel from the last output coordinate. No bias or
 * activation is applied.
 */
class Conv3DProgram {
  /**
   * @param convInfo Convolution description; the fields read here are
   *     `outShape`, `padInfo`, strides, dilations, filter sizes,
   *     `inChannels`, `inDepth`, `inHeight` and `inWidth`.
   */
  constructor(convInfo) {
    this.variableNames = ['x', 'W'];
    this.outputShape = convInfo.outShape;
    const padFront = convInfo.padInfo.front;
    const padTop = convInfo.padInfo.top;
    const padLeft = convInfo.padInfo.left;
    const strideDepth = convInfo.strideDepth;
    const strideHeight = convInfo.strideHeight;
    const strideWidth = convInfo.strideWidth;
    const dilationDepth = convInfo.dilationDepth;
    const dilationHeight = convInfo.dilationHeight;
    const dilationWidth = convInfo.dilationWidth;
    const filterDepth = convInfo.filterDepth;
    const filterHeight = convInfo.filterHeight;
    const filterWidth = convInfo.filterWidth;
    // Input channels are consumed four at a time; the 1-3 leftover channels
    // are handled by dedicated branches in the shader below.
    const inputDepthNearestVec4 = Math.floor(convInfo.inChannels / 4) * 4;
    const inputDepthVec4Remainder = convInfo.inChannels % 4;
    this.userCode = `
      const ivec3 strides = ivec3(${strideDepth}, ${strideHeight}, ${strideWidth});
      const ivec3 pads = ivec3(${padFront}, ${padTop}, ${padLeft});

      void main() {
        ivec5 coords = getOutputCoords();
        int batch = coords.x;
        int d2 = coords.u;

        ivec3 xFRCCorner = ivec3(coords.y, coords.z, coords.w) * strides - pads;
        int xFCorner = xFRCCorner.x;
        int xRCorner = xFRCCorner.y;
        int xCCorner = xFRCCorner.z;

        // Convolve x(?, ?, ?, d1) with w(:, :, :, d1, d2) to get
        // y(yF, yR, yC, d2). ? = to be determined. : = across all
        // values in that axis.
        float dotProd = 0.0;
        for (int wF = 0; wF < ${filterDepth}; wF++) {
          int xF = xFCorner + wF * ${dilationDepth};

          if (xF < 0 || xF >= ${convInfo.inDepth}) {
            continue;
          }

          for (int wR = 0; wR < ${filterHeight}; wR++) {
            int xR = xRCorner + wR * ${dilationHeight};

            if (xR < 0 || xR >= ${convInfo.inHeight}) {
              continue;
            }

            for (int wC = 0; wC < ${filterWidth}; wC++) {
              int xC = xCCorner + wC * ${dilationWidth};

              if (xC < 0 || xC >= ${convInfo.inWidth}) {
                continue;
              }

              for (int d1 = 0; d1 < ${inputDepthNearestVec4}; d1 += 4) {
                vec4 xValues = vec4(
                  getX(batch, xF, xR, xC, d1),
                  getX(batch, xF, xR, xC, d1 + 1),
                  getX(batch, xF, xR, xC, d1 + 2),
                  getX(batch, xF, xR, xC, d1 + 3)
                );
                vec4 wValues = vec4(
                  getW(wF, wR, wC, d1, d2),
                  getW(wF, wR, wC, d1 + 1, d2),
                  getW(wF, wR, wC, d1 + 2, d2),
                  getW(wF, wR, wC, d1 + 3, d2)
                );

                dotProd += dot(xValues, wValues);
              }

              if (${inputDepthVec4Remainder === 1}) {
                dotProd +=
                  getX(batch, xF, xR, xC, ${inputDepthNearestVec4}) *
                  getW(wF, wR, wC, ${inputDepthNearestVec4}, d2);
              } else if (${inputDepthVec4Remainder === 2}) {
                vec2 xValues = vec2(
                  getX(batch, xF, xR, xC, ${inputDepthNearestVec4}),
                  getX(batch, xF, xR, xC, ${inputDepthNearestVec4} + 1)
                );
                vec2 wValues = vec2(
                  getW(wF, wR, wC, ${inputDepthNearestVec4}, d2),
                  getW(wF, wR, wC, ${inputDepthNearestVec4} + 1, d2)
                );
                dotProd += dot(xValues, wValues);
              } else if (${inputDepthVec4Remainder === 3}) {
                vec3 xValues = vec3(
                  getX(batch, xF, xR, xC, ${inputDepthNearestVec4}),
                  getX(batch, xF, xR, xC, ${inputDepthNearestVec4} + 1),
                  getX(batch, xF, xR, xC, ${inputDepthNearestVec4} + 2)
                );
                vec3 wValues = vec3(
                  getW(wF, wR, wC, ${inputDepthNearestVec4}, d2),
                  getW(wF, wR, wC, ${inputDepthNearestVec4} + 1, d2),
                  getW(wF, wR, wC, ${inputDepthNearestVec4} + 2, d2)
                );
                dotProd += dot(xValues, wValues);
              }
            }
          }
        }
        setOutput(dotProd);
      }
    `;
  }
}
/**
 * @license
 * Copyright 2022 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
/**
 * Generates the GLSL source for a packed (vec4) 2D convolution of input `x`
 * with filter `W`, optionally followed by a bias add and an activation.
 *
 * Pads, strides, dilations and input dimensions are supplied as the custom
 * uniforms `pads`, `strides`, `dilations` and `inDims`; the filter size and
 * the number of input channels are baked into the shader source. The
 * generated code is specialized on `strideWidth`, `dilationWidth` and the
 * parity of `padInfo.left`.
 */
class Conv2DPackedProgram {
  /**
   * @param convInfo Convolution description; the fields read here are
   *     `outShape`, `padInfo.left`, `strideWidth`, `dilationWidth`,
   *     `filterHeight`, `filterWidth` and `inChannels`.
   * @param addBias When true, a `bias` input is added to the result.
   * @param activation GLSL body of the activation function, or null.
   * @param hasPreluActivation When true, a `preluActivationWeights` input is
   *     registered and made available to the activation as `b`.
   * @param hasLeakyReluAlpha When true, a `leakyreluAlpha` input is
   *     registered and made available to the activation as `b`.
   */
  constructor(convInfo, addBias = false, activation = null, hasPreluActivation = false, hasLeakyReluAlpha = false) {
    this.variableNames = ['x', 'W'];
    this.packedInputs = true;
    this.packedOutput = true;
    this.customUniforms = [
      { name: 'pads', type: 'ivec2' },
      { name: 'strides', type: 'ivec2' },
      { name: 'dilations', type: 'ivec2' },
      { name: 'inDims', type: 'ivec2' },
    ];
    this.outputShape = convInfo.outShape;
    this.enableShapeUniforms = useShapeUniforms(this.outputShape.length);
    const padLeft = convInfo.padInfo.left;
    const strideWidth = convInfo.strideWidth;
    const dilationWidth = convInfo.dilationWidth;
    const filterHeight = convInfo.filterHeight;
    const filterWidth = convInfo.filterWidth;
    const texelsAcross = filterWidth;
    // Declarations shared by the whole main loop.
    let mainLoop = `
       int xR; int xC; int xCOffset;
       vec4 wTexel; vec4 previous; vec4 final;`;
    // Per filter column: two texel caches, each with a "Ready" flag, and the
    // composed vec4 `xC<c>` that is multiplied with the weights.
    for (let c = 0; c < filterWidth; c++) {
      mainLoop += `
           vec4 xTexelC${c * 2};
           int xTexelC${c * 2}Ready;
           vec4 xTexelC${c * 2 + 1};
           int xTexelC${c * 2 + 1}Ready;
           vec4 xC${c};`;
    }
    /**
     * This vectorized implementation works by gathering the values needed for
     * each output channel's dot product into vec4's and then multiplying them
     * all together (this happens in the final double for-loop below). Most of
     * the main loop consists of constructing these vec4's with the minimum
     * number of texture2D calls, which means making use of all four returned
     * values from a texture2D call at once.
     */
    mainLoop += `
     for (int r = 0; r < ${filterHeight}; r++) {
      for (int d1 = 0; d1 < ${convInfo.inChannels}; d1 += 2) {
       `;
    // Reset the caches at the start of every (r, d1) iteration.
    for (let c = 0; c < filterWidth; c++) {
      mainLoop += `
          xTexelC${c * 2} = vec4(0.0);
          xTexelC${c * 2}Ready = 0;
          xTexelC${c * 2 + 1} = vec4(0.0);
          xTexelC${c * 2 + 1}Ready = 0;
          xC${c} = vec4(0.0);`;
    }
    mainLoop += `
        xR = xRCorner + r * dilations[0];
        if (xR >=0 && xR < inDims[0]) {
      `;
    // Each iteration emits code for filter columns `colIndex` and
    // `colIndex + 1`.
    for (let texelC = 0; texelC < (texelsAcross + 1) / 2; texelC++) {
      const colIndex = texelC * 2;
      mainLoop += `
          xC = xCCorner + ${colIndex * dilationWidth};
          `;
      if (strideWidth === 1) {
        if (colIndex < filterWidth) {
          // If padding is odd, the outer texels have to be composed.
          if (padLeft % 2 === 1) {
            // TODO: Ensure vec4 previous does not result in redundant sample,
            // and avoid setting xTexelRC's that exceed the boundary in the
            // first place rather than resetting them to vec4(0)).
            // To compute xCOffset:
            // - If padding is odd, we must add 1 to ensure we ask for an
            // even-numbered row.
            // - We subtract 2 to access the previous texel.
            mainLoop += `
                xCOffset = xC + 1;
                if (xCOffset >= 0 && xCOffset < inDims[1] && xTexelC${colIndex}Ready == 0) {
                  xTexelC${colIndex} = getX(batch, xR, xCOffset, d1);

                  // Need to manually clear unused channels in case
                  // we're reading from recycled texture.
                  if (xCOffset + 1 >= inDims[1]) {
                    xTexelC${colIndex}.zw = vec2(0.0);
                  }
                  xTexelC${colIndex}Ready = 1;
                }
              `;
            // This texel has been read in previous iteration if the dilation
            // is 1.
            if (dilationWidth === 1 && colIndex > 0) {
              mainLoop += `
                xC${colIndex} = vec4(xTexelC${colIndex - 2}.zw, xTexelC${colIndex}.xy);
                `;
            }
            else {
              mainLoop += `
                  xCOffset = xC + 1 - 2;

                  if (xCOffset >= 0 && xCOffset < inDims[1]) {
                    previous = getX(batch, xR, xCOffset, d1);

                    // Need to manually clear unused channels in case
                    // we're reading from recycled texture.
                    if (xCOffset + 1 >= inDims[1]) {
                      previous.zw = vec2(0.0);
                    }

                    xC${colIndex} = vec4(previous.zw, xTexelC${colIndex}.xy);
                  } else {
                    xC${colIndex} = vec4(0.0, 0.0, xTexelC${colIndex}.xy);
                  }
                  `;
            }
          }
          else {
            // Padding is even, so xRC corresponds to a single texel.
            mainLoop += `
                if (xC >= 0 && xC < inDims[1] && xTexelC${colIndex}Ready == 0) {
                  xTexelC${colIndex} = getX(batch, xR, xC, d1);
                  if (xC + 1 >= inDims[1]) {
                    xTexelC${colIndex}.zw = vec2(0.0);
                  }
                  xTexelC${colIndex}Ready = 1;
                }

                xC${colIndex} = xTexelC${colIndex};
                `;
          }
          if (colIndex + 1 < filterWidth) {
            // If dilation is even, the second entry should match the first
            // (either both are composed or both are single samples). But if
            // dilation is odd, then the second entry should be the opposite
            // of the first (if the first is composed, the second is a single
            // sample, and vice versa.)
            const nextTexelOffset = padLeft % 2 === 0 ?
              tf.util.nearestLargerEven(dilationWidth) :
              dilationWidth;
            if ((dilationWidth % 2 === 0 && padLeft % 2 === 1) ||
              (dilationWidth % 2 !== 0 && padLeft % 2 !== 1)) {
              mainLoop += `
                  xCOffset = xC + imod(pads[1], 2) + ${nextTexelOffset};

                  if (xCOffset >= 0 && xCOffset < inDims[1] && xTexelC${colIndex + 1}Ready == 0) {
                    xTexelC${colIndex + 1} = getX(batch, xR, xCOffset, d1);

                    // Need to manually clear unused channels in case
                    // we're reading from recycled texture.
                    if (xCOffset + 1 >= inDims[1]) {
                      xTexelC${colIndex + 1}.zw = vec2(0.0);
                    }
                    xTexelC${colIndex + 1}Ready = 1;
                  }
                  `;
              // If dilation > 1 then the xRC's will not be able to share any
              // values, so each xRC will require two unique calls to getX.
              if (dilationWidth > 1) {
                mainLoop += `
                    xCOffset -= 2;
                    if (xCOffset >= 0 && xCOffset < inDims[1]) {
                     previous = getX(batch, xR, xCOffset, d1);
                     xC${colIndex + 1} = vec4(previous.zw, xTexelC${colIndex + 1}.xy);
                    } else {
                     xC${colIndex + 1} = vec4(0.0, 0.0, xTexelC${colIndex + 1}.xy);
                    }
                    `;
              }
              else {
                mainLoop += `
                    xC${colIndex + 1} = vec4(xTexelC${colIndex}.zw, xTexelC${colIndex + 1}.xy);
                    `;
              }
            }
            else {
              // If dilation is 1 and padding is odd, we have already read the
              // texel when constructing the previous x value. Here we can
              // simply skip the texture read.
              if (nextTexelOffset === 1) {
                mainLoop += `
                    xC${colIndex + 1} = xTexelC${colIndex};
                    `;
              }
              else {
                mainLoop += `
                    xCOffset = xC + ${nextTexelOffset};

                    if (xCOffset >= 0 && xCOffset < inDims[1] && xTexelC${colIndex + 1}Ready == 0) {
                      xTexelC${colIndex + 1} = getX(batch, xR, xCOffset, d1);
                      if (xCOffset + 1 >= inDims[1]) {
                        xTexelC${colIndex + 1}.zw = vec2(0.0);
                      }
                      xTexelC${colIndex + 1}Ready = 1;
                    }

                    xC${colIndex + 1} = xTexelC${colIndex + 1};
                    `;
              }
            }
          }
        }
      }
      else { // stride === 2
        if (colIndex < filterWidth) {
          // Depending on whether padLeft is even or odd, we want either the
          // xy or zw channels from X texels for xC${colIndex}. If padLeft is
          // even, xC${colIndex + 1} is simply the zw channels of texels we've
          // already sampled. But if padLeft is odd, xC${colIndex + 1}.zw will
          // need to come from the xy channels of a new texel, hence the
          // `vec4 final` initialized below.
          if (padLeft % 2 === 1) {
            mainLoop += `
                xCOffset = xC + 1 - strides[1];
                if(xCOffset >= 0 && xCOffset < inDims[1] && xTexelC${colIndex}Ready == 0) {
                  xTexelC${colIndex} = getX(batch, xR, xCOffset, d1);
                  // Need to manually clear unused channels in case
                  // we're reading from recycled texture.
                  if (xCOffset + 1 >= inDims[1]) {
                    xTexelC${colIndex}.zw = vec2(0.0);
                  }
                  xTexelC${colIndex}Ready = 1;
                }

                if(xC + 1 >= 0 && xC + 1 < inDims[1] && xTexelC${colIndex + 1}Ready == 0) {
                  xTexelC${colIndex + 1} = getX(batch, xR, xC + 1, d1);
                  // Need to manually clear unused channels in case
                  // we're reading from recycled texture.
                  if (xC + 2 >= inDims[1]) {
                    xTexelC${colIndex + 1}.zw = vec2(0.0);
                  }
                  xTexelC${colIndex + 1}Ready = 1;
                }

                xC${colIndex} = vec4(xTexelC${colIndex}.zw, xTexelC${colIndex + 1}.zw);
              `;
            if (colIndex + 1 < filterWidth) {
              mainLoop += `
                  final = vec4(0.0);
                  xCOffset = xC + 1 + strides[1];
                  if(xCOffset >= 0 && xCOffset < inDims[1]) {
                    final = getX(batch, xR, xCOffset, d1);
                  }
                  xC${colIndex + 1} = vec4(xTexelC${colIndex + 1}.xy, final.xy);
                `;
            }
          }
          else {
            mainLoop += `
                if(xC >= 0 && xC < inDims[1] && xTexelC${colIndex}Ready == 0) {
                  xTexelC${colIndex} = getX(batch, xR, xC, d1);
                  if (xC + 1 >= inDims[1]) {
                    xTexelC${colIndex}.zw = vec2(0.0);
                  }
                  xTexelC${colIndex}Ready = 1;
                }

                xCOffset = xC + strides[1];
                if(xCOffset >= 0 && xCOffset < inDims[1] && xTexelC${colIndex + 1}Ready == 0) {
                  xTexelC${colIndex + 1} = getX(batch, xR, xCOffset, d1);
                  if (xCOffset + 1 >= inDims[1]) {
                    xTexelC${colIndex + 1}.zw = vec2(0.);
                  }
                  xTexelC${colIndex + 1}Ready = 1;
                }

                xC${colIndex} = vec4(
                  xTexelC${colIndex}.xy, xTexelC${colIndex + 1}.xy);
              `;
            if (colIndex + 1 < filterWidth) {
              mainLoop += `
                  xC${colIndex + 1} = vec4(xTexelC${colIndex}.zw, xTexelC${colIndex + 1}.zw);
                `;
            }
          }
        }
      }
      // localize the dotProd accumulation within the loop, the theory is for
      // GPU with limited cache, accumulate sum across large amount of
      // variables will cause lots of cache misses. (i.e. 5x5 filter will have
      // 50 variables)
      if (colIndex < filterWidth) {
        mainLoop += `
            wTexel = getW(r, ${colIndex}, d1, d2);
            dotProd += xC${colIndex}.xxzz * vec4(wTexel.xy, wTexel.xy);
            if(d1 + 1 < ${convInfo.inChannels}) {
              dotProd += xC${colIndex}.yyww * vec4(wTexel.zw, wTexel.zw);
            }
          `;
        if (colIndex + 1 < filterWidth) {
          mainLoop += `
              wTexel = getW(r, ${colIndex + 1}, d1, d2);
              dotProd += xC${colIndex + 1}.xxzz * vec4(wTexel.xy, wTexel.xy);
              if(d1 + 1 < ${convInfo.inChannels}) {
                dotProd += xC${colIndex + 1}.yyww * vec4(wTexel.zw, wTexel.zw);
              }
            `;
        }
      }
    }
    // Close the `if (xR ...)` block, the `d1` loop and the `r` loop opened
    // above.
    mainLoop += `
    }
  `;
    mainLoop += `
      }
    `;
    mainLoop += `
      }
    `;
    let activationSnippet = '', applyActivationSnippet = '';
    if (activation) {
      // With PReLU weights or a leaky-ReLU alpha, the activation function
      // takes its argument as `a` and has the extra value in scope as `b`;
      // otherwise the argument is named `x`.
      if (hasPreluActivation) {
        activationSnippet = `vec4 activation(vec4 a) {
           vec4 b = getPreluActivationWeightsAtOutCoords();
           ${activation}
        }`;
      }
      else if (hasLeakyReluAlpha) {
        activationSnippet = `vec4 activation(vec4 a) {
           vec4 b = getLeakyreluAlphaAtOutCoords();
           ${activation}
        }`;
      }
      else {
        activationSnippet = `vec4 activation(vec4 x) {
           ${activation}
        }`;
      }
      applyActivationSnippet = `result = activation(result);`;
    }
    const addBiasSnippet = addBias ? 'result += getBiasAtOutCoords();' : '';
    // Optional inputs are appended after 'x' and 'W' in this fixed order.
    if (addBias) {
      this.variableNames.push('bias');
    }
    if (hasPreluActivation) {
      this.variableNames.push('preluActivationWeights');
    }
    if (hasLeakyReluAlpha) {
      this.variableNames.push('leakyreluAlpha');
    }
    this.userCode = `
       ${activationSnippet}

       void main() {
         ivec4 coords = getOutputCoords();
         int batch = coords.x;
         ivec2 xRCCorner = coords.yz * strides - pads;
         int d2 = coords.w;
         int xRCorner = xRCCorner.x;
         int xCCorner = xRCCorner.y;

         //intialize dotProd with a small epsilon seems to reduce GPU accuracy loss.
         vec4 dotProd = vec4(0.000000000000001);

         ${mainLoop}

         vec4 result = dotProd - vec4(0.000000000000001);
         ${addBiasSnippet}
         ${applyActivationSnippet}
         setOutput(result);
       }
     `;
  }
}
/**
 * @license
 * Copyright 2019 Google LLC. All Rights Reserved.
/**
 * @license
 * Copyright 2019 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
/**
 * Packed shader program that rearranges input `A` into the im2col matrix.
 *
 * Each output texel holds a 2x2 patch (two `pos` rows by two `blockIndex`
 * columns). For every component the shader maps (blockIndex, pos) back to an
 * input row `d0`, column `d1` and channel `ch`; components that fall outside
 * the output or input bounds keep the initial value 0.
 */
class Im2ColPackedProgram {
    /**
     * @param outputShape Shape of the generated matrix; indices 1 and 2 are
     *     used for bounds checks when shape uniforms are disabled.
     * @param convInfo Convolution description; only `dataFormat` is read here.
     */
    constructor(outputShape, convInfo) {
        this.variableNames = ['A'];
        this.packedInputs = true;
        this.packedOutput = true;
        this.customUniforms = [
            ['inputShape', 'ivec4'],
            ['pad', 'ivec2'],
            ['stride', 'ivec2'],
            ['dilation', 'ivec2'],
            ['inChannels', 'int'],
            ['itemsPerBlockRow', 'int'],
            ['outWidth', 'int'],
        ].map(([name, type]) => ({ name, type }));
        this.outputShape = outputShape;
        this.enableShapeUniforms = useShapeUniforms(this.outputShape.length);
        const glsl = getGlslDifferences();
        const channelsLast = convInfo.dataFormat === 'channelsLast';
        // Positions of the spatial axes inside the 4-D input shape.
        const [rowDim, colDim] = channelsLast ? [1, 2] : [2, 3];
        let boundsCheckingSnippet;
        if (this.enableShapeUniforms) {
            boundsCheckingSnippet =
                'if(blockIndex < outShape[2] && pos < outShape[1]) {';
        }
        else {
            boundsCheckingSnippet =
                `if(blockIndex < ${outputShape[2]} && pos < ${outputShape[1]}) {`;
        }
        // GLSL for one component of the packed output texel.
        const texelSnippet = (row, col) => `
          blockIndex = rc.z + ${col};
          pos = rc.y + ${row};

          ${boundsCheckingSnippet}
            offsetY = int(blockIndex / outWidth) * stride[0] - pad[0];
            d0 = offsetY + dilation[0] * (pos / itemsPerBlockRow);

            if(d0 < inputShape[${rowDim}] && d0 >= 0) {
              // Use custom imod instead mod. On Intel GPU, mod may generate
              // unexpected value.
              // https://github.com/tensorflow/tfjs/issues/5447
              offsetX = imod(blockIndex, outWidth) * stride[1] - pad[1];
              d1 = offsetX + dilation[1] * (imod(pos, itemsPerBlockRow) /
                  inChannels);

              if(d1 < inputShape[${colDim}] && d1 >= 0) {

                ch = imod(pos, inChannels);

                if (${channelsLast}) {
                  innerDims = vec2(d1, ch);
                  result[${row * 2 + col}] = getChannel(
                    getA(rc.x, d0, int(innerDims.x),
                    int(innerDims.y)), innerDims);
                } else {
                  innerDims = vec2(d0, d1);
                  result[${row * 2 + col}] = getChannel(
                    getA(rc.x, ch, int(innerDims.x),
                    int(innerDims.y)), innerDims);
                }
              }
            }
          }
        `;
        // Unroll the 2x2 patch in row-major order (result[0]..result[3]).
        let unrolled = ``;
        for (const row of [0, 1]) {
            for (const col of [0, 1]) {
                unrolled += texelSnippet(row, col);
            }
        }
        this.userCode = `
      void main() {
        ivec3 rc = getOutputCoords();

        vec4 result = vec4(0);

        int blockIndex, pos, offsetY, d0, offsetX, d1, ch;
        vec2 innerDims;

        ${unrolled}

        ${glsl.output} = result;
      }
    `;
    }
}
/**
 * @license
 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
// Both conv2dByMatMul and conv2dWithIm2Row fuse height and width into one
// dimension to compute batchMatMul, so bias and activation weights are also
// supposed to fuse the two dimensions into one.
//
// This function computes the target shape for fusing height and width
// dimensions. Returning null means the shape is already compatible.
//
// Even though the bias is not supposed to be a 3-D or a 4-D (including
// batch) tensor and PReLU activation weights is not supposed to be a 4-D
// tensor, we still need to support them, because we haven't disabled
// them for NHWC format.
// https://github.com/tensorflow/tfjs/blob/b53bd47e880367ae57493f0ea628abaf08db2d5d/tfjs-core/src/ops/fused/conv2d.ts#L181-L196
/**
 * Computes the shape a bias / activation-weights tensor must be reshaped to
 * so that its height and width axes are fused into one.
 *
 * @param shape Current shape of the tensor.
 * @param isChannelsLast Whether the trailing axes are [height, width, channel]
 *     (true) or [channel, height, width] (false).
 * @returns The fused shape, `[shape[0], 1]` for a 1-D channels-first tensor
 *     with more than one element, or null when no reshape is needed.
 */
function getShapeForBatchMatMul(shape, isChannelsLast) {
    const length = shape.length;
    if (length >= 3) {
        return isChannelsLast ?
            [
                ...shape.slice(0, -3) /* batch */,
                shape[length - 3] * shape[length - 2] /* height * width */,
                shape[length - 1] /* channel */
            ] :
            [
                ...shape.slice(0, -3) /* batch */,
                shape[length - 3] /* channel */,
                shape[length - 2] * shape[length - 1] /* height * width */
            ];
    }
    else if (!isChannelsLast && length === 1 && shape[0] > 1) {
        return [shape[0], 1];
    }
    else {
        return null;
    }
}
// For 1x1 kernels that iterate through every point in the input, convolution
// can be expressed as matrix multiplication (without need for memory
// remapping).
/**
 * Runs a convolution as a (batched) matrix multiplication.
 *
 * @param x Input tensor info.
 * @param filter Filter tensor info; reshaped to
 *     [1, inChannels, outChannels].
 * @param convInfo Convolution description (channels, output shape, format).
 * @param backend WebGL backend used to look up texture data and dispose
 *     intermediates.
 * @param bias Optional bias, fused into the matMul.
 * @param preluActivationWeights Optional PReLU weights, fused into the matMul.
 * @param leakyreluAlpha Alpha forwarded to the fused matMul.
 * @param activation Optional fused activation name.
 * @returns Tensor info with shape `convInfo.outShape`.
 */
function conv2dByMatMul({ x, filter, convInfo, backend, bias = null, preluActivationWeights = null, leakyreluAlpha = 0, activation = null }) {
    // Reshapes conv2D input to 2D tensors, uses matMul and then reshape the
    // result from 2D to 4D.
    const xShape = x.shape;
    const xTexData = backend.texData.get(x.dataId);
    const sharedMatMulDim = convInfo.inChannels;
    const outerShapeX = xShape[0] * xShape[1] * xShape[2];
    const outerShapeFilter = convInfo.outChannels;
    const isChannelsLast = convInfo.dataFormat === 'channelsLast';
    const transposeA = false;
    const transposeB = false;
    let out;
    const intermediates = [];
    if (preluActivationWeights != null) {
        const targetShape = getShapeForBatchMatMul(preluActivationWeights.shape, isChannelsLast);
        if (targetShape != null) {
            preluActivationWeights = reshape({
                inputs: { x: preluActivationWeights },
                backend,
                attrs: { shape: targetShape }
            });
            intermediates.push(preluActivationWeights);
        }
    }
    if (bias != null) {
        const targetShape = getShapeForBatchMatMul(bias.shape, isChannelsLast);
        if (targetShape != null) {
            bias = reshape({ inputs: { x: bias }, backend, attrs: { shape: targetShape } });
            intermediates.push(bias);
        }
    }
    // TODO: Once reduction ops are packed, batchMatMul will always be packed
    // and we can remove this condition.
    const batchMatMulWillBeUnpacked = (outerShapeX === 1 || outerShapeFilter === 1) &&
        sharedMatMulDim > MATMUL_SHARED_DIM_THRESHOLD;
    // The algorithm in the if condition assumes (1) the output will be packed,
    // (2) x is packed, (3) x isChannelsLast, (4) x's packed texture is already
    // on GPU, (5) col is odd, (6) the width, height and inChannels are the same
    // for xTexData.shape and xShape.
    const canOptimize = !batchMatMulWillBeUnpacked && xTexData.isPacked &&
        isChannelsLast && xTexData.texture != null && xShape[2] % 2 !== 0 &&
        tf.util.arraysEqual(xTexData.shape.slice(-3), xShape.slice(-3));
    if (canOptimize) {
        // We avoid expensive packed 2x2 reshape by padding col count to next,
        // even number. When col is odd, the result of packed batchMatMul is
        // the same (has the same texture layout and values in the texture) as
        // it is for next even col. We make the odd-cols tensor to look like
        // even-cols tensor before the operation and, after the batchMatMul,
        // fix the even-cols result to have odd number of cols.
        const targetShape = xShape[0] * xShape[1] * (xShape[2] + 1);
        const xReshaped = {
            dataId: x.dataId,
            shape: [1, targetShape, convInfo.inChannels],
            dtype: x.dtype
        };
        // xTexData.shape gets referenced from GPGPUBinary.inShapeInfos.
        // Decrementing col count, after batchMatMul->...->compileProgram leads to
        // invalid col count within the reference in GPGPUBinary.inShapeInfos.
        // Alternative fix would be to provide a copy to GPGPUBinary.inShapeInfos
        // in compileProgram method, but that would affect compilation of all
        // programs - instead, provide a copy here, with even col count, before
        // calling batchMatMul->...->compileProgram and after that, the original
        // xTexData.shape is restored.
        const originalXTexDataShape = xTexData.shape;
        xTexData.shape = xTexData.shape.slice();
        xTexData.shape[xTexData.shape.length - 2]++;
        let pointwiseConv;
        let pointwiseConvTexData;
        try {
            tf.util.assert(isReshapeFree(xTexData.shape, xReshaped.shape), () => `packed reshape ${xTexData.shape} to ${xReshaped.shape} isn't free`);
            const filterReshaped = reshape({
                inputs: { x: filter },
                backend,
                attrs: { shape: [1, convInfo.inChannels, convInfo.outChannels] }
            });
            intermediates.push(filterReshaped);
            pointwiseConv = batchMatMulImpl({
                a: xReshaped,
                b: filterReshaped,
                backend,
                transposeA,
                transposeB,
                bias,
                activation,
                preluActivationWeights,
                leakyreluAlpha
            });
            pointwiseConvTexData = backend.texData.get(pointwiseConv.dataId);
            tf.util.assert(pointwiseConvTexData.isPacked, () => 'batchMatMul result is expected to be packed');
        }
        finally {
            // Restore the input shape to original, even when the matMul or an
            // assertion throws: xTexData is shared backend state, and leaving
            // the padded shape behind would corrupt later uses of x.
            xTexData.shape = originalXTexDataShape;
        }
        // Set the output shape - there is no need for expensive reshape as data
        // layout is already correct.
        pointwiseConvTexData.shape = convInfo.outShape;
        out = identity({ inputs: { x: pointwiseConv }, backend });
        out.shape = convInfo.outShape;
        intermediates.push(pointwiseConv);
    }
    else {
        const numCols = convInfo.outHeight * convInfo.outWidth;
        const xReshaped = reshape({
            inputs: { x },
            backend,
            attrs: {
                shape: isChannelsLast ?
                    [convInfo.batchSize, numCols, convInfo.inChannels] :
                    [convInfo.batchSize, convInfo.inChannels, numCols]
            }
        });
        const filterReshaped = reshape({
            inputs: { x: filter },
            backend,
            attrs: { shape: [1, convInfo.inChannels, convInfo.outChannels] }
        });
        const result = batchMatMulImpl({
            a: isChannelsLast ? xReshaped : filterReshaped,
            b: isChannelsLast ? filterReshaped : xReshaped,
            transposeA: !isChannelsLast,
            transposeB,
            backend,
            bias,
            activation,
            preluActivationWeights,
            leakyreluAlpha
        });
        out = reshape({ inputs: { x: result }, backend, attrs: { shape: convInfo.outShape } });
        intermediates.push(xReshaped);
        intermediates.push(filterReshaped);
        intermediates.push(result);
    }
    for (const i of intermediates) {
        backend.disposeIntermediateTensorInfo(i);
    }
    return out;
}
// Implements the im2row algorithm as outlined in "High Performance
// Convolutional Neural Networks for Document Processing" (Suvisoft, 2006)
/**
 * Runs a convolution by building an im2col matrix on the GPU and multiplying
 * it with the reshaped filter using a packed matMul program.
 *
 * Takes the same argument object as conv2dByMatMul and returns tensor info
 * reshaped to `convInfo.outShape`. All temporaries created here are disposed
 * before returning.
 */
function conv2dWithIm2Row({ x, filter, convInfo, backend, bias = null, preluActivationWeights = null, leakyreluAlpha = 0, activation = null }) {
    // Rearranges conv2d input so each block to be convolved over forms the
    // column of a new matrix with shape [filterWidth * filterHeight *
    // inChannels, outHeight * outWidth]. The filter is also rearranged so each
    // output channel forms a row of a new matrix with shape [outChannels,
    // filterWidth * filterHeight * inChannels]. The convolution is then
    // computed by multiplying these matrices and reshaping the result.
    const { filterWidth, filterHeight, inChannels, outWidth, outHeight, dataFormat } = convInfo;
    const isChannelsLast = dataFormat === 'channelsLast';
    const sharedDim = filterWidth * filterHeight * inChannels;
    const numCols = outHeight * outWidth;
    const x2ColShape = [convInfo.batchSize, sharedDim, numCols];
    const transposeA = true;
    const transposeB = false;
    const intermediates = [];
    if (preluActivationWeights != null) {
        const targetShape = getShapeForBatchMatMul(preluActivationWeights.shape, isChannelsLast);
        if (targetShape != null) {
            preluActivationWeights = reshape({
                inputs: { x: preluActivationWeights },
                backend,
                attrs: { shape: targetShape }
            });
            intermediates.push(preluActivationWeights);
        }
    }
    if (bias != null) {
        const targetShape = getShapeForBatchMatMul(bias.shape, isChannelsLast);
        if (targetShape != null) {
            bias = reshape({ inputs: { x: bias }, backend, attrs: { shape: targetShape } });
            intermediates.push(bias);
        }
    }
    const w2Row = reshape({
        inputs: { x: filter },
        backend,
        attrs: { shape: [1, sharedDim, tf.util.sizeFromShape(filter.shape) / sharedDim] }
    });
    intermediates.push(w2Row);
    const im2ColProgram = new Im2ColPackedProgram(x2ColShape, convInfo);
    // Values for the program's custom uniforms, in declaration order.
    const customValues = [
        x.shape, [convInfo.padInfo.top, convInfo.padInfo.left],
        [convInfo.strideHeight, convInfo.strideWidth],
        [convInfo.dilationHeight, convInfo.dilationWidth], [convInfo.inChannels],
        [convInfo.filterWidth * convInfo.inChannels], [convInfo.outWidth]
    ];
    const im2Col = backend.runWebGLProgram(im2ColProgram, [x], 'float32', customValues);
    const im2ColReshaped = reshape({ inputs: { x: im2Col }, backend, attrs: { shape: x2ColShape } });
    intermediates.push(im2Col);
    intermediates.push(im2ColReshaped);
    const hasBias = bias != null;
    const hasPreluActivationWeights = preluActivationWeights != null;
    const hasLeakyreluAlpha = activation === 'leakyrelu';
    const fusedActivation = activation ? mapActivationToShaderProgram(activation, true) : null;
    const matmulProgram = new MatMulPackedProgram(isChannelsLast ? im2ColReshaped.shape :
        w2Row.shape, isChannelsLast ? w2Row.shape :
        im2ColReshaped.shape, isChannelsLast ? [convInfo.batchSize, numCols, convInfo.outChannels] :
        [convInfo.batchSize, convInfo.outChannels, numCols], transposeA, transposeB, hasBias, fusedActivation, hasPreluActivationWeights, hasLeakyreluAlpha);
    const inputs = isChannelsLast ? [im2ColReshaped, w2Row] : [w2Row, im2ColReshaped];
    if (bias) {
        inputs.push(bias);
    }
    if (hasPreluActivationWeights) {
        inputs.push(preluActivationWeights);
    }
    if (hasLeakyreluAlpha) {
        const $leakyreluAlpha = backend.makeTensorInfo([], 'float32', tf.util.createScalarValue(leakyreluAlpha, 'float32'));
        inputs.push($leakyreluAlpha);
        intermediates.push($leakyreluAlpha);
    }
    const product = backend.runWebGLProgram(matmulProgram, inputs, 'float32');
    const out = reshape({ inputs: { x: product }, backend, attrs: { shape: convInfo.outShape } });
    intermediates.push(product);
    for (const i of intermediates) {
        backend.disposeIntermediateTensorInfo(i);
    }
    return out;
}
/**
 * @license
 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
/**
 * WebGL kernel for Conv2D.
 *
 * Picks one of four implementations, in this order: matMul for 1x1 filters
 * with unit stride/dilation and SAME/VALID padding; the packed conv program
 * when `WEBGL_EXP_CONV` is set (channelsLast, stride width <= 2); im2row when
 * `WEBGL_CONV_IM2COL` is set; otherwise the plain Conv2DProgram. The result is
 * reshaped to `convInfo.outShape` and the un-reshaped tensor is disposed.
 */
function conv2d(args) {
    const { inputs, backend, attrs } = args;
    const { x, filter } = inputs;
    const { strides, pad, dataFormat, dilations, dimRoundingMode } = attrs;
    const $dataFormat = tf.backend_util.convertConv2DDataFormat(dataFormat);
    const convInfo = tf.backend_util.computeConv2DInfo(x.shape, filter.shape, strides, dilations, pad, dimRoundingMode, false /* depthwise */, $dataFormat);
    // Environment flags are only queried when the earlier branches decline.
    const runConvolution = () => {
        const hasUnitGeometry = [
            convInfo.filterHeight, convInfo.filterWidth,
            convInfo.dilationHeight, convInfo.dilationWidth,
            convInfo.strideHeight, convInfo.strideWidth
        ].every((value) => value === 1);
        if (hasUnitGeometry &&
            (convInfo.padInfo.type === 'SAME' || convInfo.padInfo.type === 'VALID')) {
            return conv2dByMatMul({ x, filter, convInfo, backend });
        }
        if (convInfo.strideWidth <= 2 && $dataFormat === 'channelsLast' &&
            tf.env().getBool('WEBGL_EXP_CONV')) {
            const packedProgram = new Conv2DPackedProgram(convInfo);
            const customValues = [
                [convInfo.padInfo.top, convInfo.padInfo.left],
                [convInfo.strideHeight, convInfo.strideWidth],
                [convInfo.dilationHeight, convInfo.dilationWidth],
                [convInfo.inHeight, convInfo.inWidth]
            ];
            return backend.runWebGLProgram(packedProgram, [x, filter], 'float32', customValues);
        }
        if (tf.env().getBool('WEBGL_CONV_IM2COL')) {
            return conv2dWithIm2Row({ x, filter, convInfo, backend });
        }
        return backend.runWebGLProgram(new Conv2DProgram(convInfo), [x, filter], 'float32');
    };
    const out = runConvolution();
    const outReshaped = reshape({ inputs: { x: out }, backend, attrs: { shape: convInfo.outShape } });
    backend.disposeIntermediateTensorInfo(out);
    return outReshaped;
}
// Kernel registration entry for Conv2D on the WebGL backend.
const conv2DConfig = {
    kernelName: tf.Conv2D,
    backendName: 'webgl',
    kernelFunc: conv2d,
};
/**
 * @license
 * Copyright 2017 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
/**
 * Shader program computing the gradient of a 2D convolution with respect to
 * its filter. Output shape is `convInfo.filterShape`; inputs are `x` and `dy`.
 */
class Conv2DDerFilterProgram {
    constructor(convInfo) {
        this.variableNames = ['x', 'dy'];
        this.outputShape = convInfo.filterShape;
        const { strideHeight, strideWidth, batchSize, outHeight, outWidth, inHeight, inWidth } = convInfo;
        const { top: padTop, left: padLeft } = convInfo.padInfo;
        // Only the argument order of the samplers depends on the data format.
        let accumulateSnippet;
        if (convInfo.dataFormat === 'channelsLast') {
            accumulateSnippet = `float dyValue = getDy(b, yR, yC, d2);
              float xValue = getX(b, xR, xC, d1);
              dotProd += (xValue * dyValue);`;
        }
        else {
            accumulateSnippet = `float dyValue = getDy(b, d2, yR, yC);
              float xValue = getX(b, d1, xR, xC);
              dotProd += (xValue * dyValue);`;
        }
        this.userCode = `
      void main() {
        ivec4 coords = getOutputCoords();
        int wR = coords.x;
        int wC = coords.y;
        int d1 = coords.z;
        int d2 = coords.w;

        // Convolve x(?, ?, d1) with dy(:, :, d2) to get dw(wR, wC, d1, d2).
        // ? = to be determined. : = across all values in that axis.
        float dotProd = 0.0;

        for (int b = 0; b < ${batchSize}; b++) {
          for (int yR = 0; yR < ${outHeight}; yR++) {
            int xR = wR + yR * ${strideHeight} - ${padTop};

            if (xR < 0 || xR >= ${inHeight}) {
              continue;
            }

            for (int yC = 0; yC < ${outWidth}; yC++) {
              int xC = wC + yC * ${strideWidth} - ${padLeft};

              if (xC < 0 || xC >= ${inWidth}) {
                continue;
              }

              ${accumulateSnippet}
            }
          }
        }
        setOutput(dotProd);
      }
    `;
    }
}
/**
 * Shader program computing the gradient of a 2D convolution with respect to
 * its input. Output shape is `convInfo.inShape`; inputs are `dy` and `W`.
 */
class Conv2DDerInputProgram {
    constructor(convInfo) {
        this.variableNames = ['dy', 'W'];
        this.outputShape = convInfo.inShape;
        const { filterHeight, filterWidth, strideHeight, strideWidth, outHeight, outWidth, outChannels } = convInfo;
        const isChannelsLast = convInfo.dataFormat === 'channelsLast';
        // Padding applied to dy, derived from the forward padding.
        const padTop = filterHeight - 1 - convInfo.padInfo.top;
        const padLeft = filterWidth - 1 - convInfo.padInfo.left;
        // Positions of the row / column / channel axes in the output coords.
        const [rowDim, colDim, channelDim] = isChannelsLast ? [1, 2, 3] : [2, 3, 1];
        this.userCode = `
      const ivec2 pads = ivec2(${padTop}, ${padLeft});

      void main() {
        ivec4 coords = getOutputCoords();
        int batch = coords[0];
        int d1 = coords[${channelDim}];

        ivec2 dyCorner = ivec2(coords[${rowDim}], coords[${colDim}]) - pads;
        int dyRCorner = dyCorner.x;
        int dyCCorner = dyCorner.y;

        // Convolve dy(?, ?, d2) with w(:, :, d1, d2) to compute dx(xR, xC, d1).
        // ? = to be determined. : = across all values in that axis.
        float dotProd = 0.0;
        for (int wR = 0; wR < ${filterHeight}; wR++) {
          float dyR = float(dyRCorner + wR) / ${strideHeight}.0;

          if (dyR < 0.0 || dyR >= ${outHeight}.0 || fract(dyR) > 0.0) {
            continue;
          }
          int idyR = int(dyR);

          int wRPerm = ${filterHeight} - 1 - wR;

          for (int wC = 0; wC < ${filterWidth}; wC++) {
            float dyC = float(dyCCorner + wC) / ${strideWidth}.0;

            if (dyC < 0.0 || dyC >= ${outWidth}.0 ||
                fract(dyC) > 0.0) {
              continue;
            }
            int idyC = int(dyC);

            int wCPerm = ${filterWidth} - 1 - wC;

            for (int d2 = 0; d2 < ${outChannels}; d2++) {

              if (${isChannelsLast}) {
                float xValue = getDy(batch, idyR, idyC, d2);
                float wValue = getW(wRPerm, wCPerm, d1, d2);
                dotProd += xValue * wValue;
              } else {
                float xValue = getDy(batch, d2, idyR, idyC);
                float wValue = getW(wRPerm, wCPerm, d1, d2);
                dotProd += xValue * wValue;
              }

            }
          }
        }
        setOutput(dotProd);
      }
    `;
    }
}
/**
 * Shader program computing the gradient of a 3D convolution with respect to
 * its filter. Output shape is `convInfo.filterShape`; inputs are `x` and `dy`.
 */
class Conv3DDerFilterProgram {
    constructor(convInfo) {
        this.variableNames = ['x', 'dy'];
        this.outputShape = convInfo.filterShape;
        const { strideDepth, strideHeight, strideWidth, batchSize, outDepth, outHeight, outWidth, inDepth, inHeight, inWidth } = convInfo;
        const { front: padFront, top: padTop, left: padLeft } = convInfo.padInfo;
        this.userCode = `
      void main() {
        ivec5 coords = getOutputCoords();
        int wF = coords.x;
        int wR = coords.y;
        int wC = coords.z;
        int d1 = coords.w;
        int d2 = coords.u;

        float dotProd = 0.0;

        for (int b = 0; b < ${batchSize}; b++) {
          for (int yF = 0; yF < ${outDepth}; yF++) {
            int xF = wF + yF * ${strideDepth} - ${padFront};

            if (xF < 0 || xF >= ${inDepth}) {
              continue;
            }

            for (int yR = 0; yR < ${outHeight}; yR++) {
              int xR = wR + yR * ${strideHeight} - ${padTop};

              if (xR < 0 || xR >= ${inHeight}) {
                continue;
              }

              for (int yC = 0; yC < ${outWidth}; yC++) {
                int xC = wC + yC * ${strideWidth} - ${padLeft};

                if (xC < 0 || xC >= ${inWidth}) {
                  continue;
                }

                float dyValue = getDy(b, yF, yR, yC, d2);
                float xValue = getX(b, xF, xR, xC, d1);
                dotProd += (xValue * dyValue);
              }
            }
          }
        }
        setOutput(dotProd);
      }
    `;
    }
}
/**
 * Shader program computing the gradient of a 3D convolution with respect to
 * its input. Output shape is `convInfo.inShape`; inputs are `dy` and `W`.
 */
class Conv3DDerInputProgram {
    constructor(convInfo) {
        this.variableNames = ['dy', 'W'];
        this.outputShape = convInfo.inShape;
        const { filterDepth, filterHeight, filterWidth, strideDepth, strideHeight, strideWidth, outDepth, outHeight, outWidth, outChannels } = convInfo;
        // Padding applied to dy, derived from the forward padding.
        const padFront = filterDepth - 1 - convInfo.padInfo.front;
        const padTop = filterHeight - 1 - convInfo.padInfo.top;
        const padLeft = filterWidth - 1 - convInfo.padInfo.left;
        this.userCode = `
      const ivec3 pads = ivec3(${padFront}, ${padTop}, ${padLeft});

      void main() {
        ivec5 coords = getOutputCoords();
        int batch = coords.x;
        int d1 = coords.u;


        ivec3 dyCorner = ivec3(coords.y, coords.z, coords.w) - pads;
        int dyFCorner = dyCorner.x;
        int dyRCorner = dyCorner.y;
        int dyCCorner = dyCorner.z;

        float dotProd = 0.0;
        for (int wF = 0; wF < ${filterDepth}; wF++) {
          float dyF = float(dyFCorner + wF) / ${strideDepth}.0;

          if (dyF < 0.0 || dyF >= ${outDepth}.0 || fract(dyF) > 0.0) {
            continue;
          }
          int idyF = int(dyF);

          int wFPerm = ${filterDepth} - 1 - wF;

          for (int wR = 0; wR < ${filterHeight}; wR++) {
            float dyR = float(dyRCorner + wR) / ${strideHeight}.0;

            if (dyR < 0.0 || dyR >= ${outHeight}.0 ||
              fract(dyR) > 0.0) {
              continue;
            }
            int idyR = int(dyR);

            int wRPerm = ${filterHeight} - 1 - wR;

            for (int wC = 0; wC < ${filterWidth}; wC++) {
              float dyC = float(dyCCorner + wC) / ${strideWidth}.0;

              if (dyC < 0.0 || dyC >= ${outWidth}.0 ||
                  fract(dyC) > 0.0) {
                continue;
              }
              int idyC = int(dyC);

              int wCPerm = ${filterWidth} - 1 - wC;

              for (int d2 = 0; d2 < ${outChannels}; d2++) {
                float xValue = getDy(batch, idyF, idyR, idyC, d2);
                float wValue = getW(wFPerm, wRPerm, wCPerm, d1, d2);
                dotProd += xValue * wValue;
              }
            }
          }
        }
        setOutput(dotProd);
      }
    `;
    }
}
/**
 * @license
 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
/**
 * WebGL kernel for Conv2DBackpropFilter: runs Conv2DDerFilterProgram on
 * `x` and `dy` with conv info computed for `attrs.filterShape` and a fixed
 * dilation of 1.
 */
function conv2DBackpropFilter(args) {
    const { backend } = args;
    const { x, dy } = args.inputs;
    const { strides, pad, dataFormat, dimRoundingMode, filterShape } = args.attrs;
    const $dataFormat = tf.backend_util.convertConv2DDataFormat(dataFormat);
    const convInfo = tf.backend_util.computeConv2DInfo(x.shape, filterShape, strides, 1 /* dilations */, pad, dimRoundingMode, false /* depthwise */, $dataFormat);
    return backend.runWebGLProgram(new Conv2DDerFilterProgram(convInfo), [x, dy], 'float32');
}
// Kernel registration entry for Conv2DBackpropFilter on the WebGL backend.
const conv2DBackpropFilterConfig = {
    kernelName: tf.Conv2DBackpropFilter,
    backendName: 'webgl',
    kernelFunc: conv2DBackpropFilter,
};
/**
 * @license
 * Copyright 2023 Google LLC.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
/**
 * Packed shader program computing the gradient of a 2D convolution with
 * respect to its input. The generated code indexes coords as
 * [batch, row, col, channel] and reads strides from the `strides` uniform.
 *
 * For every filter tap the shader looks at two neighbouring dy columns
 * (`idyC`, `idyC2`) and accumulates into `result.xy` / `result.zw`.
 */
class Conv2DDerInputPackedProgram {
    constructor(convInfo) {
        this.variableNames = ['dy', 'W'];
        this.packedInputs = true;
        this.packedOutput = true;
        this.customUniforms = [{ name: 'strides', type: 'vec2' }];
        this.outputShape = convInfo.inShape;
        this.enableShapeUniforms = useShapeUniforms(this.outputShape.length);
        const { filterHeight, filterWidth, outHeight, outWidth, outChannels } = convInfo;
        // Padding applied to dy, derived from the forward padding.
        const padTop = filterHeight - 1 - convInfo.padInfo.top;
        const padLeft = filterWidth - 1 - convInfo.padInfo.left;
        this.userCode = `
      const ivec2 pads = ivec2(${padTop}, ${padLeft});

      void main() {
        ivec4 coords = getOutputCoords();
        int batch = coords[0];
        int d1 = coords[3];

        ivec2 dyCorner = ivec2(coords[1], coords[2]) - pads;
        int dyRCorner = dyCorner.x;
        int dyCCorner = dyCorner.y;

        vec4 result = vec4(0.);
        for (int wR = 0; wR < ${filterHeight}; wR++) {
          float dyR = float(dyRCorner + wR) / strides[0];
          if (dyR < 0.0 || dyR >= ${outHeight}.0 || fract(dyR) > 0.0) {
            continue;
          }
          int idyR = int(dyR);
          int wRPerm = ${filterHeight} - 1 - wR;

          for (int wC = 0; wC < ${filterWidth}; wC++) {
            int wCPerm = ${filterWidth} - 1 - wC;

            float dyC = float(dyCCorner + wC) / strides[1];
            bool idyCVal = (dyC >= 0.0) && (dyC < ${outWidth}.0)
              && (fract(dyC) == 0.0);
            int idyC = int(dyC);

            float dyC2 = float(dyCCorner + wC + 1) / strides[1];
            bool idyCVal2 = (dyC2 >= 0.0) && (dyC2 < ${outWidth}.0)
              && (fract(dyC2) == 0.0);
            int idyC2 = int(dyC2);

            if (idyCVal && idyCVal2) {
              for (int d2 = 0; d2 < ${outChannels}; d2 += 2) {
                vec4 wValue = getW(wRPerm, wCPerm, d1, d2);
                vec4 dySample = getDy(batch, idyR, idyC, d2);
                vec4 dySample2 = (idyC / 2 == idyC2 / 2) ?
                  dySample : getDy(batch, idyR, idyC2, d2);

                vec2 dyValue = mod(float(idyC), 2.) == 0. ?
                  dySample.xy : dySample.zw;
                result.xy += vec2(dot(dyValue, wValue.xy),
                  dot(dyValue, wValue.zw));

                dyValue = mod(float(idyC2), 2.) == 0. ?
                  dySample2.xy : dySample2.zw;
                result.zw += vec2(dot(dyValue, wValue.xy),
                  dot(dyValue, wValue.zw));
              }
            } else if (idyCVal) {
              for (int d2 = 0; d2 < ${outChannels}; d2 += 2) {
                vec4 wValue = getW(wRPerm, wCPerm, d1, d2);
                vec4 dySample = getDy(batch, idyR, idyC, d2);
                vec2 dyValue = mod(float(idyC), 2.) == 0. ?
                  dySample.xy : dySample.zw;
                result.xy += vec2(dot(dyValue, wValue.xy),
                  dot(dyValue, wValue.zw));
              }
            } else if (idyCVal2) {
              for (int d2 = 0; d2 < ${outChannels}; d2 += 2) {
                vec4 wValue = getW(wRPerm, wCPerm, d1, d2);
                vec4 dySample = getDy(batch, idyR, idyC2, d2);
                vec2 dyValue = mod(float(idyC2), 2.) == 0. ?
                  dySample.xy : dySample.zw;
                result.zw += vec2(dot(dyValue, wValue.xy),
                  dot(dyValue, wValue.zw));
              }
            }
          }
        }
        setOutput(result);
      }
    `;
    }
}
/**
 * @license
 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
/**
 * WebGL kernel for Conv2DBackpropInput.
 *
 * Uses the packed program when `WEBGL_PACK_CONV2DTRANSPOSE` is enabled and
 * the data format is channelsLast; otherwise falls back to the unpacked
 * Conv2DDerInputProgram.
 */
function conv2DBackpropInput(args) {
    const { backend } = args;
    const { dy, filter } = args.inputs;
    const { inputShape, strides, pad, dataFormat, dimRoundingMode } = args.attrs;
    const $dataFormat = tf.backend_util.convertConv2DDataFormat(dataFormat);
    const convInfo = tf.backend_util.computeConv2DInfo(inputShape, filter.shape, strides, 1 /* dilations */, pad, dimRoundingMode, false, $dataFormat);
    const usePackedProgram = tf.env().getBool('WEBGL_PACK_CONV2DTRANSPOSE') &&
        $dataFormat === 'channelsLast';
    if (!usePackedProgram) {
        return backend.runWebGLProgram(new Conv2DDerInputProgram(convInfo), [dy, filter], 'float32');
    }
    // Value for the packed program's `strides` uniform.
    const customValues = [
        [convInfo.strideHeight, convInfo.strideWidth],
    ];
    const packedProgram = new Conv2DDerInputPackedProgram(convInfo);
    return backend.runWebGLProgram(packedProgram, [dy, filter], 'float32', customValues);
}
// Kernel registration entry for Conv2DBackpropInput on the WebGL backend.
const conv2DBackpropInputConfig = {
    kernelName: tf.Conv2DBackpropInput,
    backendName: 'webgl',
    kernelFunc: conv2DBackpropInput,
};
/**
 * @license
 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
/**
 * WebGL kernel for Conv3D: computes the conv info from the input and filter
 * shapes and runs Conv3DProgram on `x` and `filter`.
 */
function conv3D(args) {
    const { backend } = args;
    const { x, filter } = args.inputs;
    const { strides, pad, dilations } = args.attrs;
    const convInfo = tf.backend_util.computeConv3DInfo(x.shape, filter.shape, strides, dilations, pad);
    return backend.runWebGLProgram(new Conv3DProgram(convInfo), [x, filter], 'float32');
}
// Kernel registration entry for Conv3D on the WebGL backend.
const conv3DConfig = {
    kernelName: tf.Conv3D,
    backendName: 'webgl',
    kernelFunc: conv3D,
};
/**
 * @license
 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
/**
 * WebGL kernel for Conv3DBackpropFilterV2: runs Conv3DDerFilterProgram on
 * `x` and `dy` with conv info computed for `attrs.filterShape` and a fixed
 * dilation of 1.
 */
function conv3DBackpropFilterV2(args) {
    const { backend } = args;
    const { x, dy } = args.inputs;
    const { strides, pad, filterShape } = args.attrs;
    const convInfo = tf.backend_util.computeConv3DInfo(x.shape, filterShape, strides, 1 /* dilations */, pad);
    return backend.runWebGLProgram(new Conv3DDerFilterProgram(convInfo), [x, dy], 'float32');
}
// Kernel registration entry for Conv3DBackpropFilterV2 on the WebGL backend.
const conv3DBackpropFilterV2Config = {
    kernelName: tf.Conv3DBackpropFilterV2,
    backendName: 'webgl',
    kernelFunc: conv3DBackpropFilterV2
};
/**
 * @license
 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
/**
 * WebGL kernel for Conv3DBackpropInputV2: runs Conv3DDerInputProgram on
 * `dy` and `filter` with conv info computed for `attrs.inputShape` and a
 * fixed dilation of 1.
 */
function conv3DBackpropInput(args) {
    const { backend } = args;
    const { dy, filter } = args.inputs;
    const { pad, strides, inputShape } = args.attrs;
    const convInfo = tf.backend_util.computeConv3DInfo(inputShape, filter.shape, strides, 1 /* dilations */, pad);
    return backend.runWebGLProgram(new Conv3DDerInputProgram(convInfo), [dy, filter], 'float32');
}
// Kernel registration entry for Conv3DBackpropInputV2 on the WebGL backend.
const conv3DBackpropInputConfig = {
    kernelName: tf.Conv3DBackpropInputV2,
    backendName: 'webgl',
    kernelFunc: conv3DBackpropInput,
};
/**
 * @license
 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
// Unpacked cosine snippet: the shared unary NaN check followed by cos(x).
const COS = `${CHECK_NAN_SNIPPET_UNARY}
  return cos(x);
`;
// Packed cosine snippet: computes cos on the vec4, then applies the shared
// packed NaN check, which uses the `isNaN` mask declared here.
const COS_PACKED = `
  vec4 result = cos(x);
  bvec4 isNaN = isnan(x);
  ${CHECK_NAN_SNIPPET_PACKED}
  return result;
`;
const cos = unaryKernelFunc({ packedOpSnippet: COS_PACKED, opSnippet: COS });
// Kernel registration entry for Cos on the WebGL backend.
const cosConfig = {
    kernelName: tf.Cos,
    backendName: 'webgl',
    kernelFunc: cos,
};
/**
 * @license
 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
// Hyperbolic cosine snippet. Note that `e2x` holds exp(-x), so the
// expression evaluates (exp(-x) + exp(x)) / 2.
const COSH = `
  float e2x = exp(-x);
  return (e2x + 1.0 / e2x) / 2.0;
`;
const cosh = unaryKernelFunc({ opSnippet: COSH });
// Kernel registration entry for Cosh on the WebGL backend.
const coshConfig = {
    kernelName: tf.Cosh,
    backendName: 'webgl',
    kernelFunc: cosh,
};
* ============================================================================= */ class CropAndResizeProgram { constructor(imageShape, boxShape, cropSize, method, extrapolationValue) { this.variableNames = ['Image', 'Boxes', 'BoxInd']; this.outputShape = []; const [batch, imageHeight, imageWidth, depth] = imageShape; const [numBoxes,] = boxShape; const [cropHeight, cropWidth] = cropSize; this.outputShape = [numBoxes, cropHeight, cropWidth, depth]; const methodId = method === 'bilinear' ? 1 : 0; const [inputHeightFloat, inputWidthFloat] = [`${imageHeight - 1}.0`, `${imageWidth - 1}.0`]; const [heightRatio, heightScale, inY] = cropHeight > 1 ? [ `${(imageHeight - 1) / (cropHeight - 1)}`, '(y2-y1) * height_ratio', `y1*${inputHeightFloat} + float(y)*(height_scale)`, ] : [ '0.0', '0.0', `0.5 * (y1+y2) * ${inputHeightFloat}`, ]; const [widthRatio, widthScale, inX] = cropWidth > 1 ? [ `${(imageWidth - 1) / (cropWidth - 1)}`, '(x2-x1) * width_ratio', `x1*${inputWidthFloat} + float(x)*(width_scale)`, ] : [ '0.0', '0.0', `0.5 * (x1+x2) * ${inputWidthFloat}`, ]; // Reference implementation // tslint:disable-next-line:max-line-length // https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/kernels/crop_and_resize_op_gpu.cu.cc this.userCode = ` const float height_ratio = float(${heightRatio}); const float width_ratio = float(${widthRatio}); void main() { ivec4 coords = getOutputCoords(); int b = coords[0]; int y = coords[1]; int x = coords[2]; int d = coords[3]; // get box vals float y1 = getBoxes(b,0); float x1 = getBoxes(b,1); float y2 = getBoxes(b,2); float x2 = getBoxes(b,3); // get image in batch index int bInd = round(getBoxInd(b)); if(bInd < 0 || bInd >= ${batch}) { return; } float height_scale = ${heightScale}; float width_scale = ${widthScale}; float in_y = ${inY}; if( in_y < 0.0 || in_y > ${inputHeightFloat} ) { setOutput(float(${extrapolationValue})); return; } float in_x = ${inX}; if( in_x < 0.0 || in_x > ${inputWidthFloat} ) { 
setOutput(float(${extrapolationValue})); return; } vec2 sourceFracIndexCR = vec2(in_x,in_y); if(${methodId} == 1) { // Compute the four integer indices. ivec2 sourceFloorCR = ivec2(sourceFracIndexCR); ivec2 sourceCeilCR = ivec2(ceil(sourceFracIndexCR)); float topLeft = getImage(b, sourceFloorCR.y, sourceFloorCR.x, d); float bottomLeft = getImage(b, sourceCeilCR.y, sourceFloorCR.x, d); float topRight = getImage(b, sourceFloorCR.y, sourceCeilCR.x, d); float bottomRight = getImage(b, sourceCeilCR.y, sourceCeilCR.x, d); vec2 fracCR = sourceFracIndexCR - vec2(sourceFloorCR); float top = topLeft + (topRight - topLeft) * fracCR.x; float bottom = bottomLeft + (bottomRight - bottomLeft) * fracCR.x; float newValue = top + (bottom - top) * fracCR.y; setOutput(newValue); } else { // Compute the coordinators of nearest neighbor point. ivec2 sourceNearestCR = ivec2(floor( sourceFracIndexCR + vec2(0.5,0.5))); float newValue = getImage(b, sourceNearestCR.y, sourceNearestCR.x, d); setOutput(newValue); } } `; } } /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ const cropAndResize = (args) => { const { inputs, backend, attrs } = args; const { image, boxes, boxInd } = inputs; const { cropSize, method, extrapolationValue } = attrs; const program = new CropAndResizeProgram(image.shape, boxes.shape, cropSize, method, extrapolationValue); return backend.runWebGLProgram(program, [image, boxes, boxInd], 'float32'); }; const cropAndResizeConfig = { kernelName: tf.CropAndResize, backendName: 'webgl', kernelFunc: cropAndResize }; var CumOpType; (function (CumOpType) { CumOpType["Prod"] = "*"; CumOpType["Sum"] = "+"; })(CumOpType || (CumOpType = {})); class CumProgram { constructor(op, outputShape, exclusive, reverse) { this.op = op; this.outputShape = outputShape; this.variableNames = ['x']; this.customUniforms = [{ name: 'index', type: 'float' }]; const rank = this.outputShape.length; const initVal = this.op === CumOpType.Prod ? '1.0' : '0.0'; const val = exclusive ? initVal : `getX(${getCoords(rank, 'coords', this.op)})`; const length = this.outputShape[this.outputShape.length - 1]; let condition = ''; let idxString = ''; // When exclusive is set, the cum op becomes roll op that copies the // value from the previous index based on the direction specified by the // reverse flag. if (exclusive) { condition = reverse ? `end != ${length - 1}` : 'end != 0'; idxString = reverse ? 'end + 1' : 'end - 1'; } else { condition = reverse ? `end + pow2 < ${length}` : 'end >= pow2'; idxString = (reverse ? 
'end + pow2' : 'end - pow2'); } this.userCode = ` void main() { ${getCoordsDataType(rank)} coords = getOutputCoords(); int end = ${getFinalCoord(rank, 'coords', this.op)}; float val = ${val}; int pow2 = int(pow(2.0, index)); if (${condition}) { int idx = ${idxString}; ${getFinalCoord(rank, 'coords', this.op)} = idx; val ${this.op}= getX(${getCoords(rank, 'coords', this.op)}); } setOutput(val); } `; } } function getCoords(rank, name, op) { if (rank === 1) { return `${name}`; } else if (rank === 2) { return `${name}.x, ${name}.y`; } else if (rank === 3) { return `${name}.x, ${name}.y, ${name}.z`; } else if (rank === 4) { return `${name}.x, ${name}.y, ${name}.z, ${name}.w`; } else { throw new Error(`Cumulative ${op} for rank ${rank} is not yet supported`); } } function getFinalCoord(rank, name, op) { if (rank === 1) { return `${name}`; } else if (rank === 2) { return `${name}.y`; } else if (rank === 3) { return `${name}.z`; } else if (rank === 4) { return `${name}.w`; } else { throw new Error(`Cumulative ${op} for rank ${rank} is not yet supported`); } } /** * @license * Copyright 2022 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ function cumImpl(op, x, backend, axis, exclusive, reverse) { const xRank = x.shape.length; const permutation = tf.backend_util.getAxesPermutation([axis], xRank); let permutedX = x; if (permutation != null) { permutedX = transpose({ inputs: { x }, backend, attrs: { perm: permutation } }); } const permutedAxis = tf.backend_util.getInnerMostAxes(1, xRank)[0]; if (permutedAxis !== xRank - 1) { throw new Error(`WebGL cumprod shader expects an inner-most axis=${x.shape.length - 1} ` + `but got axis=${axis}`); } const size = permutedX.shape[permutedAxis]; let result = identity({ inputs: { x: permutedX }, backend }); // Use cum parallel algorithm, inspired by: // https://developer.nvidia.com/gpugems/gpugems3/part-vi-gpu-computing/chapter-39-parallel-prefix-sum-scan-cuda // Note: although the algorithm is called sum, it works for any associtative // operator with an identity. for (let i = 0; i <= Math.ceil(Math.log2(size)) - 1; i++) { const program = new CumProgram(op, permutedX.shape, false, reverse); const customValues = [[i]]; const prevResult = result; result = backend.runWebGLProgram(program, [result], result.dtype, customValues); backend.disposeIntermediateTensorInfo(prevResult); } // For exclusive cum, shift the end result in the direction of product or sum // and add 1 for product or 0 for sum to the front index. 
if (exclusive) { const program = new CumProgram(op, permutedX.shape, exclusive, reverse); const prevResult = result; result = backend.runWebGLProgram(program, [result], result.dtype); backend.disposeIntermediateTensorInfo(prevResult); } if (permutation != null) { const reversePermutation = tf.backend_util.getUndoAxesPermutation(permutation); const reverseTransposedResult = transpose({ inputs: { x: result }, backend, attrs: { perm: reversePermutation } }); backend.disposeIntermediateTensorInfo(result); backend.disposeIntermediateTensorInfo(permutedX); return reverseTransposedResult; } return result; } /** * @license * Copyright 2022 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function cumprod(args) { const { inputs, backend, attrs } = args; const { x } = inputs; const { axis, exclusive, reverse } = attrs; return cumImpl(CumOpType.Prod, x, backend, axis, exclusive, reverse); } const cumprodConfig = { kernelName: tf.Cumprod, backendName: 'webgl', kernelFunc: cumprod }; /** * @license * Copyright 2022 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function cumsum(args) { const { inputs, backend, attrs } = args; const { x } = inputs; const { axis, exclusive, reverse } = attrs; return cumImpl(CumOpType.Sum, x, backend, axis, exclusive, reverse); } const cumsumConfig = { kernelName: tf.Cumsum, backendName: 'webgl', kernelFunc: cumsum }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ function denseBincount(args) { const { inputs, backend, attrs } = args; const { x, weights } = inputs; const { size, binaryOutput } = attrs; if (x.shape.length === 1) { const xVals = backend.readSync(x.dataId); const weightsVals = backend.readSync(weights.dataId); const outVals = bincountImplCPU(xVals, weightsVals, weights.dtype, weights.shape, size); return backend.makeTensorInfo([size], weights.dtype, outVals); } else if (x.shape.length === 2) { const xBuf = backend.bufferSync(x); const weightsBuf = backend.bufferSync(weights); const outBuf = bincountReduceImplCPU(xBuf, weightsBuf, size, binaryOutput); return backend.makeTensorInfo(outBuf.shape, weights.dtype, outBuf.values); } throw new Error(`Error in denseBincount: input must be at most rank 2, but got rank` + `${x.shape.length}.`); } const denseBincountConfig = { kernelName: tf.DenseBincount, backendName: 'webgl', kernelFunc: denseBincount }; /** * @license * Copyright 2018 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ class DepthToSpaceProgram { constructor(outputShape, blockSize, dataFormat) { this.variableNames = ['x']; this.outputShape = []; this.outputShape = outputShape; this.blockSize = blockSize; this.dataFormat = dataFormat; this.userCode = ` void main() { ivec4 coords = getOutputCoords(); int b = coords[0]; int h = ${this.getHeightCoordString()}; int w = ${this.getWidthCoordString()}; int d = ${this.getDepthCoordString()}; int in_h = h / ${blockSize}; int offset_h = imod(h, ${blockSize}); int in_w = w / ${blockSize}; int offset_w = imod(w, ${blockSize}); int offset_d = (offset_h * ${blockSize} + offset_w) * ${this.getOutputDepthSize()}; int in_d = d + offset_d; float result = ${this.getInputSamplingString()}; setOutput(result); } `; } getHeightCoordString() { if (this.dataFormat === 'NHWC') { return `coords[1]`; } else { return `coords[2]`; } } getWidthCoordString() { if (this.dataFormat === 'NHWC') { return `coords[2]`; } else { return `coords[3]`; } } getDepthCoordString() { if (this.dataFormat === 'NHWC') { return `coords[3]`; } else { return `coords[1]`; } } getOutputDepthSize() { if (this.dataFormat === 'NHWC') { return this.outputShape[3]; } else { return this.outputShape[1]; } } getInputSamplingString() { if (this.dataFormat === 'NHWC') { return `getX(b, in_h, in_w, in_d)`; } else { return `getX(b, in_d, in_h, in_w)`; } } } /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
* See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function depthToSpace(args) { const { inputs, backend, attrs } = args; const { x } = inputs; const { blockSize, dataFormat } = attrs; const batchSize = x.shape[0]; const inputHeight = (dataFormat === 'NHWC') ? x.shape[1] : x.shape[2]; const inputWidth = (dataFormat === 'NHWC') ? x.shape[2] : x.shape[3]; const inputDepth = (dataFormat === 'NHWC') ? x.shape[3] : x.shape[1]; const outputHeight = inputHeight * blockSize; const outputWidth = inputWidth * blockSize; const outputDepth = inputDepth / (blockSize * blockSize); const outputShape = (dataFormat === 'NHWC') ? [batchSize, outputHeight, outputWidth, outputDepth] : [batchSize, outputDepth, outputHeight, outputWidth]; const program = new DepthToSpaceProgram(outputShape, blockSize, dataFormat); return backend.runWebGLProgram(program, [x], x.dtype); } const depthToSpaceConfig = { kernelName: tf.DepthToSpace, backendName: 'webgl', kernelFunc: depthToSpace }; /** * @license * Copyright 2017 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ class DepthwiseConv2DProgram { constructor(convInfo, addBias = false, activation = null, hasPreluActivation = false, hasLeakyReluAlpha = false) { this.variableNames = ['x', 'W']; this.customUniforms = [ { name: 'pads', type: 'ivec2' }, { name: 'strides', type: 'ivec2' }, { name: 'dilations', type: 'ivec2' }, { name: 'inDims', type: 'ivec2' }, ]; this.outputShape = convInfo.outShape; this.enableShapeUniforms = useShapeUniforms(this.outputShape.length); const filterHeight = convInfo.filterHeight; const filterWidth = convInfo.filterWidth; const channelMul = convInfo.outChannels / convInfo.inChannels; let activationSnippet = '', applyActivationSnippet = ''; if (activation) { if (hasPreluActivation) { activationSnippet = `float activation(float a) { float b = getPreluActivationWeightsAtOutCoords(); ${activation} }`; } else if (hasLeakyReluAlpha) { activationSnippet = `float activation(float a) { float b = getLeakyreluAlphaAtOutCoords(); ${activation} }`; } else { activationSnippet = ` float activation(float x) { ${activation} } `; } applyActivationSnippet = `result = activation(result);`; } const addBiasSnippet = addBias ? 'result += getBiasAtOutCoords();' : ''; if (addBias) { this.variableNames.push('bias'); } if (hasPreluActivation) { this.variableNames.push('preluActivationWeights'); } if (hasLeakyReluAlpha) { this.variableNames.push('leakyreluAlpha'); } this.userCode = ` ${activationSnippet} void main() { ivec4 coords = getOutputCoords(); int batch = coords.x; ivec2 xRCCorner = coords.yz * strides - pads; int d2 = coords.w; int d1 = d2 / ${channelMul}; int q = d2 - d1 * ${channelMul}; int xRCorner = xRCCorner.x; int xCCorner = xRCCorner.y; // Convolve x(?, ?, d1) with w(:, :, d1, q) to get y(yR, yC, d2). // ? = to be determined. : = across all values in that axis. float dotProd = 0.0; // TO DO(dsmilkov): Flatten the two for loops and vec4 the operations. 
for (int wR = 0; wR < ${filterHeight}; wR++) { int xR = xRCorner + wR * dilations[0]; if (xR < 0 || xR >= inDims[0]) { continue; } for (int wC = 0; wC < ${filterWidth}; wC++) { int xC = xCCorner + wC * dilations[1]; if (xC < 0 || xC >= inDims[1]) { continue; } float xVal = getX(batch, xR, xC, d1); float wVal = getW(wR, wC, d1, q); dotProd += xVal * wVal; } } float result = dotProd; ${addBiasSnippet} ${applyActivationSnippet} setOutput(result); } `; } } /** * @license * Copyright 2018 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */
    /**
     * Packed (vec4) shader program for depthwise 2D convolution, optionally
     * fused with a bias add and an activation.
     *
     * Both inputs and the output are marked as packed. Inputs start as `x` and
     * `W`; `bias`, `preluActivationWeights` and `leakyreluAlpha` are appended
     * (in that order) when the matching flag is set. Padding, strides, dilations
     * and input dimensions arrive through the `pads`, `strides`, `dilations` and
     * `inDims` ivec2 uniforms, but the *shape* of the generated code also
     * depends on `convInfo.padInfo.left`, `strideWidth`, `dilationWidth` and the
     * filter size, which are inspected while the GLSL source is assembled.
     *
     * The generator emits, per filter row, a sequence of texel reads
     * (`xTexelC<n>`, guarded by `xTexelC<n>Ready` so a texel is fetched once per
     * row) and composes them into `xC<n>` vectors that are multiplied with the
     * filter texel. Code generation has one branch for `strideWidth === 1` and
     * one for every other stride, annotated below as "stride === 2".
     */
    class DepthwiseConvPacked2DProgram {
        /**
         * @param convInfo Convolution info; reads `outShape`, `outChannels`,
         *     `inChannels`, `padInfo.left`, `strideWidth`, `dilationWidth`,
         *     `filterHeight` and `filterWidth`.
         * @param addBias Whether to add `getBiasAtOutCoords()` to the result.
         * @param activation GLSL body of the activation function, or a falsy
         *     value for no activation.
         * @param hasPreluActivation Whether the activation reads PReLU weights.
         * @param hasLeakyReluAlpha Whether the activation reads a leaky-ReLU
         *     alpha.
         */
        constructor(convInfo, addBias = false, activation = null, hasPreluActivation = false, hasLeakyReluAlpha = false) {
            this.variableNames = ['x', 'W'];
            this.packedInputs = true;
            this.packedOutput = true;
            this.customUniforms = [
                { name: 'pads', type: 'ivec2' },
                { name: 'strides', type: 'ivec2' },
                { name: 'dilations', type: 'ivec2' },
                { name: 'inDims', type: 'ivec2' },
            ];
            this.outputShape = convInfo.outShape;
            this.enableShapeUniforms = useShapeUniforms(this.outputShape.length);
            const channelMul = convInfo.outChannels / convInfo.inChannels;
            const padLeft = convInfo.padInfo.left;
            const strideWidth = convInfo.strideWidth;
            const dilationWidth = convInfo.dilationWidth;
            const filterHeight = convInfo.filterHeight;
            const filterWidth = convInfo.filterWidth;
            const texelsAcross = filterWidth;
            // Declarations shared by every generated row iteration.
            let mainLoop = `
      int xR; int xC; int xCOffset;
      vec4 wTexel; vec4 previous; vec4 final;`;
            // Two texel slots (plus their "already fetched" flags) and one
            // composed vector per filter column.
            for (let c = 0; c < filterWidth; c++) {
                mainLoop += `
          vec4 xTexelC${c * 2};
          int xTexelC${c * 2}Ready;
          vec4 xTexelC${c * 2 + 1};
          int xTexelC${c * 2 + 1}Ready;
          vec4 xC${c};`;
            }
            /**
             * This vectorized implementation works by gathering the values needed for
             * each output channel's dot product into vec4's and then multiplying them
             * all together (this happens in the final double for-loop below). Most of
             * the main loop consists of constructing these vec4's with the minimum
             * number of texture2D calls, which means making use of all four returned
             * values from a texture2D call at once.
             */
            mainLoop += `
    for (int r = 0; r < ${filterHeight}; r++) {
      `;
            // Reset every texel slot and composed vector at the start of a row.
            for (let c = 0; c < filterWidth; c++) {
                mainLoop += `
          xTexelC${c * 2} = vec4(0.0);
          xTexelC${c * 2}Ready = 0;
          xTexelC${c * 2 + 1} = vec4(0.0);
          xTexelC${c * 2 + 1}Ready = 0;
          xC${c} = vec4(0.0);`;
            }
            mainLoop += `
        xR = xRCorner + r * dilations[0];
        if (xR >=0 && xR < inDims[0]) {
      `;
            // Each iteration handles a pair of filter columns: colIndex and
            // colIndex + 1.
            for (let texelC = 0; texelC < (texelsAcross + 1) / 2; texelC++) {
                const colIndex = texelC * 2;
                mainLoop += `
          xC = xCCorner + ${colIndex * dilationWidth};
          `;
                if (strideWidth === 1) {
                    if (colIndex < filterWidth) {
                        // If padding is odd, the outer texels have to be composed.
                        if (padLeft % 2 === 1) {
                            // TODO: Ensure vec4 previous does not result in redundant sample,
                            // and avoid setting xTexelRC's that exceed the boundary in the
                            // first place rather than resetting them to vec4(0)).
                            // To compute xCOffset:
                            // - If padding is odd, we must add 1 to ensure we ask for an
                            // even-numbered row.
                            // - We subtract 2 to access the previous texel.
                            mainLoop += `
                xCOffset = xC + 1;
                if (xCOffset >= 0 && xCOffset < inDims[1] && xTexelC${colIndex}Ready == 0) {
                  xTexelC${colIndex} = getX(batch, xR, xCOffset, d1);

                  // Need to manually clear unused channels in case
                  // we're reading from recycled texture.
                  if (xCOffset + 1 >= inDims[1]) {
                    xTexelC${colIndex}.zw = vec2(0.0);
                  }
                  xTexelC${colIndex}Ready = 1;
                }
              `;
                            // This texel has been read in previous iteration if the dilation
                            // is 1.
                            if (dilationWidth === 1 && colIndex > 0) {
                                mainLoop += `
                xC${colIndex} = vec4(xTexelC${colIndex - 2}.zw, xTexelC${colIndex}.xy);
                `;
                            }
                            else {
                                mainLoop += `
                  xCOffset = xC + 1 - 2;

                  if (xCOffset >= 0 && xCOffset < inDims[1]) {
                    previous = getX(batch, xR, xCOffset, d1);

                    // Need to manually clear unused channels in case
                    // we're reading from recycled texture.
                    if (xCOffset + 1 >= inDims[1]) {
                      previous.zw = vec2(0.0);
                    }

                    xC${colIndex} = vec4(previous.zw, xTexelC${colIndex}.xy);
                  } else {
                    xC${colIndex} = vec4(0.0, 0.0, xTexelC${colIndex}.xy);
                  }
                  `;
                            }
                        }
                        else {
                            // Padding is even, so xRC corresponds to a single texel.
                            mainLoop += `
                if (xC >= 0 && xC < inDims[1] && xTexelC${colIndex}Ready == 0) {
                  xTexelC${colIndex} = getX(batch, xR, xC, d1);
                  if (xC + 1 >= inDims[1]) {
                    xTexelC${colIndex}.zw = vec2(0.0);
                  }
                  xTexelC${colIndex}Ready = 1;
                }

                xC${colIndex} = xTexelC${colIndex};
                `;
                        }
                        if (colIndex + 1 < filterWidth) {
                            // If dilation is even, the second entry should match the first
                            // (either both are composed or both are single samples). But if
                            // dilation is odd, then the second entry should be the opposite
                            // of the first (if the first is composed, the second is a single
                            // sample, and vice versa.)
                            const nextTexelOffset = padLeft % 2 === 0 ?
                                tf.util.nearestLargerEven(dilationWidth) :
                                dilationWidth;
                            if ((dilationWidth % 2 === 0 && padLeft % 2 === 1) ||
                                (dilationWidth % 2 !== 0 && padLeft % 2 !== 1)) {
                                mainLoop += `
                  xCOffset = xC + imod(pads[1], 2) + ${nextTexelOffset};

                  if (xCOffset >= 0 && xCOffset < inDims[1] && xTexelC${colIndex + 1}Ready == 0) {
                    xTexelC${colIndex + 1} = getX(batch, xR, xCOffset, d1);

                    // Need to manually clear unused channels in case
                    // we're reading from recycled texture.
                    if (xCOffset + 1 >= inDims[1]) {
                      xTexelC${colIndex + 1}.zw = vec2(0.0);
                    }
                    xTexelC${colIndex + 1}Ready = 1;
                  }
                  `;
                                // If dilation > 1 then the xRC's will not be able to share any
                                // values, so each xRC will require two unique calls to getX.
                                if (dilationWidth > 1) {
                                    mainLoop += `
                    xCOffset -= 2;
                    if (xCOffset >= 0 && xCOffset < inDims[1]) {
                     previous = getX(batch, xR, xCOffset, d1);
                     xC${colIndex + 1} = vec4(previous.zw, xTexelC${colIndex + 1}.xy);
                    } else {
                     xC${colIndex + 1} = vec4(0.0, 0.0, xTexelC${colIndex + 1}.xy);
                    }
                    `;
                                }
                                else {
                                    mainLoop += `
                    xC${colIndex + 1} = vec4(xTexelC${colIndex}.zw, xTexelC${colIndex + 1}.xy);
                    `;
                                }
                            }
                            else {
                                // If dilation is 1 and padding is odd, we have already read the
                                // texel when constructing the previous x value. Here we can
                                // simply skip the texture read.
                                if (nextTexelOffset === 1) {
                                    mainLoop += `
                    xC${colIndex + 1} = xTexelC${colIndex};
                    `;
                                }
                                else {
                                    mainLoop += `
                    xCOffset = xC + ${nextTexelOffset};

                    if (xCOffset >= 0 && xCOffset < inDims[1] && xTexelC${colIndex + 1}Ready == 0) {
                      xTexelC${colIndex + 1} = getX(batch, xR, xCOffset, d1);
                      if (xCOffset + 1 >= inDims[1]) {
                        xTexelC${colIndex + 1}.zw = vec2(0.0);
                      }
                      xTexelC${colIndex + 1}Ready = 1;
                    }

                    xC${colIndex + 1} = xTexelC${colIndex + 1};
                    `;
                                }
                            }
                        }
                    }
                }
                else { // stride === 2
                    if (colIndex < filterWidth) {
                        // Depending on whether padLeft is even or odd, we want either the
                        // xy or zw channels from X texels for xC${colIndex}. If padLeft is
                        // even, xC${colIndex +1} is simply the zw channels of texels we've
                        // already sampled. But if padLeft is odd, xC{$c + 1}.zw will
                        // need to come from the xy channels of a new texel, hence the `
                        // vec4
                        // final` initialized below.
                        if (padLeft % 2 === 1) {
                            mainLoop += `
                xCOffset = xC + 1 - strides[1];
                if(xCOffset >= 0 && xCOffset < inDims[1] && xTexelC${colIndex}Ready == 0) {
                  xTexelC${colIndex} = getX(batch, xR, xCOffset, d1);
                  // Need to manually clear unused channels in case
                  // we're reading from recycled texture.
                  if (xCOffset + 1 >= inDims[1]) {
                    xTexelC${colIndex}.zw = vec2(0.0);
                  }
                  xTexelC${colIndex}Ready = 1;
                }

                if(xC + 1 >= 0 && xC + 1 < inDims[1] && xTexelC${colIndex + 1}Ready == 0) {
                  xTexelC${colIndex + 1} = getX(batch, xR, xC + 1, d1);
                  // Need to manually clear unused channels in case
                  // we're reading from recycled texture.
                  if (xC + 2 >= inDims[1]) {
                    xTexelC${colIndex + 1}.zw = vec2(0.0);
                  }
                  xTexelC${colIndex + 1}Ready = 1;
                }

                xC${colIndex} = vec4(xTexelC${colIndex}.zw, xTexelC${colIndex + 1}.zw);
              `;
                            if (colIndex + 1 < filterWidth) {
                                mainLoop += `
                  final = vec4(0.0);
                  xCOffset = xC + 1 + strides[1];
                  if(xCOffset >= 0 && xCOffset < inDims[1]) {
                    final = getX(batch, xR, xCOffset, d1);
                  }
                  xC${colIndex + 1} = vec4(xTexelC${colIndex + 1}.xy, final.xy);
                `;
                            }
                        }
                        else {
                            mainLoop += `
                if(xC >= 0 && xC < inDims[1] && xTexelC${colIndex}Ready == 0) {
                  xTexelC${colIndex} = getX(batch, xR, xC, d1);
                  if (xC + 1 >= inDims[1]) {
                    xTexelC${colIndex}.zw = vec2(0.0);
                  }
                  xTexelC${colIndex}Ready = 1;
                }

                xCOffset = xC + strides[1];
                if(xCOffset >= 0 && xCOffset < inDims[1] && xTexelC${colIndex + 1}Ready == 0) {
                  xTexelC${colIndex + 1} = getX(batch, xR, xCOffset, d1);
                  if (xCOffset + 1 >= inDims[1]) {
                    xTexelC${colIndex + 1}.zw = vec2(0.);
                  }
                  xTexelC${colIndex + 1}Ready = 1;
                }

                xC${colIndex} = vec4(
                  xTexelC${colIndex}.xy, xTexelC${colIndex + 1}.xy);
              `;
                            if (colIndex + 1 < filterWidth) {
                                mainLoop += `
                  xC${colIndex + 1} = vec4(xTexelC${colIndex}.zw, xTexelC${colIndex + 1}.zw);
                `;
                            }
                        }
                    }
                }
                // localize the dotProd accumulation within the loop, the theory is for
                // GPU with limited cache, accumulate sum across large amount of
                // variables will cause lots of cache misses. (i.e. 5x5 filter will have
                // 50 variables)
                if (colIndex < filterWidth) {
                    mainLoop += `
            wTexel = getW(r, ${colIndex}, d1, q);
            dotProd += xC${colIndex} * vec4(wTexel.xz, wTexel.xz);
          `;
                    if (colIndex + 1 < filterWidth) {
                        mainLoop += `
              wTexel = getW(r, ${colIndex + 1}, d1, q);
              dotProd += xC${colIndex + 1} * vec4(wTexel.xz, wTexel.xz);
            `;
                    }
                }
            }
            // Close the GLSL `if (xR >=0 && xR < inDims[0])` block ...
            mainLoop += `
        }
      `;
            // ... and the GLSL `for (int r ...)` loop.
            mainLoop += `
      }
    `;
            // GLSL definition of `activation(...)` and the statement applying it;
            // both stay empty when no activation is requested.
            let activationSnippet = '', applyActivationSnippet = '';
            if (activation) {
                if (hasPreluActivation) {
                    activationSnippet = `vec4 activation(vec4 a) {
          vec4 b = getPreluActivationWeightsAtOutCoords();
          ${activation}
        }`;
                }
                else if (hasLeakyReluAlpha) {
                    activationSnippet = `vec4 activation(vec4 a) {
          vec4 b = getLeakyreluAlphaAtOutCoords();
          ${activation}
        }`;
                }
                else {
                    activationSnippet = `vec4 activation(vec4 x) {
          ${activation}
        }`;
                }
                applyActivationSnippet = `result = activation(result);`;
            }
            const addBiasSnippet = addBias ? 'result += getBiasAtOutCoords();' : '';
            // Extra inputs are appended in a fixed order: bias, PReLU weights,
            // leaky-ReLU alpha.
            if (addBias) {
                this.variableNames.push('bias');
            }
            if (hasPreluActivation) {
                this.variableNames.push('preluActivationWeights');
            }
            if (hasLeakyReluAlpha) {
                this.variableNames.push('leakyreluAlpha');
            }
            this.userCode = `
      ${activationSnippet}

      void main() {
        ivec4 coords = getOutputCoords();
        int batch = coords.x;
        ivec2 xRCCorner = coords.yz * strides - pads;
        int d2 = coords.w;
        int d1 = d2 / ${channelMul};
        int q = d2 - d1 * ${channelMul};
        int xRCorner = xRCCorner.x;
        int xCCorner = xRCCorner.y;

        //intialize dotProd with a small epsilon seems to reduce GPU accuracy loss.
        vec4 dotProd = vec4(0.000000000000001);

        ${mainLoop}

        vec4 result = dotProd - vec4(0.000000000000001);
        ${addBiasSnippet}
        ${applyActivationSnippet}
        setOutput(result);
      }
    `;
        }
    }
    /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License.
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function depthwiseConv2dNative(args) { const { inputs, backend, attrs } = args; const { x, filter } = inputs; const { strides, pad, dilations, dimRoundingMode } = attrs; let $dilations = dilations; if ($dilations == null) { $dilations = [1, 1]; } tf.util.assert(tf.backend_util.eitherStridesOrDilationsAreOne(strides, $dilations), () => 'Error in depthwiseConv2d: Either strides or dilations must be ' + `1. Got strides ${strides} and dilations '${$dilations}'`); const convInfo = tf.backend_util.computeConv2DInfo(x.shape, filter.shape, strides, $dilations, pad, dimRoundingMode, true /* depthwise */); let program; if (tf.env().getBool('WEBGL_PACK_DEPTHWISECONV') && convInfo.strideWidth <= 2 && convInfo.outChannels / convInfo.inChannels === 1) { program = new DepthwiseConvPacked2DProgram(convInfo); } else { program = new DepthwiseConv2DProgram(convInfo); } const customValues = [ [convInfo.padInfo.top, convInfo.padInfo.left], [convInfo.strideHeight, convInfo.strideWidth], [convInfo.dilationHeight, convInfo.dilationWidth], [convInfo.inHeight, convInfo.inWidth] ]; return backend.runWebGLProgram(program, [x, filter], 'float32', customValues); } const depthwiseConv2dNativeConfig = { kernelName: tf.DepthwiseConv2dNative, backendName: 'webgl', kernelFunc: depthwiseConv2dNative, }; /** * @license * Copyright 2018 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ class DepthwiseConv2DDerFilterProgram { constructor(convInfo) { this.variableNames = ['x', 'dy']; this.outputShape = convInfo.filterShape; const strideHeight = convInfo.strideHeight; const strideWidth = convInfo.strideWidth; const padTop = convInfo.padInfo.top; const padLeft = convInfo.padInfo.left; const channelMul = convInfo.outChannels / convInfo.inChannels; this.userCode = ` void main() { ivec4 coords = getOutputCoords(); int wR = coords.x; int wC = coords.y; int d1 = coords.z; int dm = coords.w; int d2 = d1 * ${channelMul} + dm; float dotProd = 0.0; // TO DO: Vec4 over the batch size for (int b = 0; b < ${convInfo.batchSize}; b++) { for (int yR = 0; yR < ${convInfo.outHeight}; yR++) { int xR = wR + yR * ${strideHeight} - ${padTop}; if (xR < 0 || xR >= ${convInfo.inHeight}) { continue; } for (int yC = 0; yC < ${convInfo.outWidth}; yC++) { int xC = wC + yC * ${strideWidth} - ${padLeft}; if (xC < 0 || xC >= ${convInfo.inWidth}) { continue; } float dyValue = getDy(b, yR, yC, d2); float xValue = getX(b, xR, xC, d1); dotProd += (xValue * dyValue); } } } setOutput(dotProd); } `; } } class DepthwiseConv2DDerInputProgram { constructor(convInfo) { this.variableNames = ['dy', 'W']; this.outputShape = convInfo.inShape; const filterHeight = convInfo.filterHeight; const filterWidth = convInfo.filterWidth; const strideHeight = convInfo.strideHeight; const strideWidth = convInfo.strideWidth; const padTop = filterHeight - 1 - convInfo.padInfo.top; const padLeft = filterWidth - 1 - 
convInfo.padInfo.left; const channelMul = convInfo.outChannels / convInfo.inChannels; this.userCode = ` const ivec2 pads = ivec2(${padTop}, ${padLeft}); void main() { ivec4 coords = getOutputCoords(); int batch = coords[0]; int d1 = coords[3]; ivec2 dyCorner = coords.yz - pads; int dyRCorner = dyCorner.x; int dyCCorner = dyCorner.y; float dotProd = 0.0; for (int wR = 0; wR < ${filterHeight}; wR++) { float dyR = float(dyRCorner + wR) / ${strideHeight}.0; if (dyR < 0.0 || dyR >= ${convInfo.outHeight}.0 || fract(dyR) > 0.0) { continue; } int idyR = int(dyR); int wRPerm = ${filterHeight} - 1 - wR; for (int wC = 0; wC < ${filterWidth}; wC++) { float dyC = float(dyCCorner + wC) / ${strideWidth}.0; if (dyC < 0.0 || dyC >= ${convInfo.outWidth}.0 || fract(dyC) > 0.0) { continue; } int idyC = int(dyC); int wCPerm = ${filterWidth} - 1 - wC; // TO DO: Vec4 over the channelMul for (int dm = 0; dm < ${channelMul}; dm++) { int d2 = d1 * ${channelMul} + dm; float xValue = getDy(batch, idyR, idyC, d2); float wValue = getW(wRPerm, wCPerm, d1, dm); dotProd += xValue * wValue; } } } setOutput(dotProd); } `; } } /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ function depthwiseConv2dNativeBackpropFilter(args) { const { inputs, backend, attrs } = args; const { x, dy } = inputs; const { strides, dilations, pad, dimRoundingMode, filterShape } = attrs; const convInfo = tf.backend_util.computeConv2DInfo(x.shape, filterShape, strides, dilations, pad, dimRoundingMode, true /* depthwise */); const program = new DepthwiseConv2DDerFilterProgram(convInfo); return backend.runWebGLProgram(program, [x, dy], 'float32'); } const depthwiseConv2dNativeBackpropFilterConfig = { kernelName: tf.DepthwiseConv2dNativeBackpropFilter, backendName: 'webgl', kernelFunc: depthwiseConv2dNativeBackpropFilter }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ function depthwiseConv2dNativeBackpropInput(args) { const { inputs, backend, attrs } = args; const { dy, filter } = inputs; const { strides, dilations, pad, dimRoundingMode, inputShape } = attrs; const convInfo = tf.backend_util.computeConv2DInfo(inputShape, filter.shape, strides, dilations, pad, dimRoundingMode, true /* depthwise */); const program = new DepthwiseConv2DDerInputProgram(convInfo); return backend.runWebGLProgram(program, [dy, filter], 'float32'); } const depthwiseConv2dNativeBackpropInputConfig = { kernelName: tf.DepthwiseConv2dNativeBackpropInput, backendName: 'webgl', kernelFunc: depthwiseConv2dNativeBackpropInput }; /** * @license * Copyright 2019 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ class DiagProgram { constructor(size) { this.variableNames = ['X']; this.outputShape = [size, size]; this.userCode = ` void main() { ivec2 coords = getOutputCoords(); float val = coords[0] == coords[1] ? getX(coords[0]) : 0.0; setOutput(val); } `; } } /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function diag(args) { const { inputs, backend } = args; const { x } = inputs; const outShape = [...x.shape, ...x.shape]; const xSize = tf.util.sizeFromShape(x.shape); const flat = reshape({ inputs: { x }, backend, attrs: { shape: [xSize] } }); const program = new DiagProgram(xSize); const res = backend.runWebGLProgram(program, [flat], flat.dtype); const out = reshape({ inputs: { x: res }, backend, attrs: { shape: outShape } }); backend.disposeIntermediateTensorInfo(flat); backend.disposeIntermediateTensorInfo(res); return out; } const diagConfig = { kernelName: tf.Diag, backendName: 'webgl', kernelFunc: diag }; /** * @license * Copyright 2017 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ class Dilation2DProgram { constructor(convInfo) { this.variableNames = ['x', 'W']; this.outputShape = convInfo.outShape; const { inHeight, inWidth, padInfo, strideHeight, strideWidth, filterHeight, filterWidth, dilationHeight, dilationWidth } = convInfo; const { top: padTop, left: padLeft } = padInfo; this.userCode = ` const ivec2 strides = ivec2(${strideHeight}, ${strideWidth}); const ivec2 pads = ivec2(${padTop}, ${padLeft}); const float neg_infinity = -3.4e38; void main() { ivec4 coords = getOutputCoords(); int batch = coords.x; int d1 = coords.w; ivec2 outTopLeftCorner = coords.yz * strides - pads; int hBeg = outTopLeftCorner.x; int wBeg = outTopLeftCorner.y; float curVal = neg_infinity; for (int h = 0; h < ${filterHeight}; h++) { int hIn = hBeg + h * ${dilationHeight}; if (hIn >= 0 && hIn < ${inHeight}) { for (int w = 0; w < ${filterWidth}; w++) { int wIn = wBeg + w * ${dilationWidth}; if (wIn >= 0 && wIn < ${inWidth}) { float xVal = getX(batch, hIn, wIn, d1); float wVal = getW(h, w, d1); float val = xVal + wVal; if (val > curVal) { curVal = val; } } } } } float result = curVal; setOutput(result); } `; } } /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ function dilation2D(args) { const { inputs, backend, attrs } = args; const { x, filter } = inputs; const { strides, pad, dilations } = attrs; const convInfo = tf.backend_util.computeDilation2DInfo(x.shape, filter.shape, strides, pad, 'NHWC' /* dataFormat */, dilations); let out; const program = new Dilation2DProgram(convInfo); out = backend.runWebGLProgram(program, [x, filter], 'float32'); const outReshaped = reshape({ inputs: { x: out }, backend, attrs: { shape: convInfo.outShape } }); backend.disposeIntermediateTensorInfo(out); return outReshaped; } const dilation2DConfig = { kernelName: tf.Dilation2D, backendName: 'webgl', kernelFunc: dilation2D, }; /** * @license * Copyright 2021 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ function einsum(args) { const { inputs, backend, attrs } = args; const { equation } = attrs; const tensors = inputs; const { allDims, summedDims, idDims } = tf.backend_util.decodeEinsumEquation(equation, tensors.length); tf.backend_util.checkEinsumDimSizes(allDims.length, idDims, tensors); const { path, steps } = tf.backend_util.getEinsumComputePath(summedDims, idDims); const nSteps = steps.length; let out = null; let numDimsRemaining = allDims.length; const tensorsToDispose = []; for (let i = 0; i < nSteps; ++i) { for (const idTerm of steps[i]) { const { permutationIndices: perm, expandDims: dimsToExpand } = tf.backend_util.getEinsumPermutation(numDimsRemaining, idDims[idTerm]); let x; if (tf.backend_util.isIdentityPermutation(perm)) { x = tensors[idTerm]; } else { x = transpose({ inputs: { x: tensors[idTerm] }, backend, attrs: { perm } }); tensorsToDispose.push(x); } const targetShape = x.shape.slice(); for (let k = 0; k < dimsToExpand.length; ++k) { targetShape.splice(dimsToExpand[k], 0, 1); } if (!tf.util.arraysEqual(x.shape, targetShape)) { x = reshape({ inputs: { x }, backend, attrs: { shape: targetShape } }); tensorsToDispose.push(x); } if (out === null) { out = x; } else { // tslint:disable-next-line: no-unnecessary-type-assertion out = multiply({ inputs: { a: x, b: out }, backend }); tensorsToDispose.push(out); } } if (i < nSteps - 1) { if (path[i] >= 0) { out = sum({ inputs: { x: out }, backend, attrs: { axis: path[i] - (allDims.length - numDimsRemaining), keepDims: false } }); tensorsToDispose.push(out); } numDimsRemaining--; } } // Clean up intermediate tensors. for (const tensorInfo of tensorsToDispose) { if (tensorInfo === out) { continue; } backend.disposeIntermediateTensorInfo(tensorInfo); } return out; } const einsumConfig = { kernelName: tf.Einsum, backendName: 'webgl', kernelFunc: einsum }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. 
* Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ const ELU = `return (x >= 0.0) ? x : (exp(x) - 1.0);`; const ELU_PACKED = ` vec4 result; result.r = (x.r >= 0.0) ? x.r : (exp(x.r) - 1.0); result.g = (x.g >= 0.0) ? x.g : (exp(x.g) - 1.0); result.b = (x.b >= 0.0) ? x.b : (exp(x.b) - 1.0); result.a = (x.a >= 0.0) ? x.a : (exp(x.a) - 1.0); return result; `; const elu = unaryKernelFunc({ opSnippet: ELU, packedOpSnippet: ELU_PACKED }); const eluConfig = { kernelName: tf.Elu, backendName: 'webgl', kernelFunc: elu }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ const ELU_DER = `return (b >= 0.0) ? 
a : a * (b + 1.0);`; const ELU_DER_PACKED = ` vec4 bGTEZero = vec4(greaterThanEqual(b, vec4(0.))); return (bGTEZero * a) + ((vec4(1.0) - bGTEZero) * (a * (b + vec4(1.0)))); `; const eluGrad = (args) => { const { inputs, backend } = args; const { dy, y } = inputs; const program = tf.env().getBool('WEBGL_PACK_BINARY_OPERATIONS') ? new BinaryOpPackedProgram(ELU_DER_PACKED, dy.shape, y.shape) : new BinaryOpProgram(ELU_DER, dy.shape, y.shape); return backend.runWebGLProgram(program, [dy, y], dy.dtype); }; const eluGradConfig = { kernelName: tf.EluGrad, backendName: 'webgl', kernelFunc: eluGrad }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ const PACKED_EQUAL = ` return vec4(equal(a, b)); `; const EQUAL = `return float(a == b);`; const equal = binaryKernelFunc({ opSnippet: EQUAL, packedOpSnippet: PACKED_EQUAL, dtype: 'bool', cpuKernelImpl: equalImplCPU, }); const equalConfig = { kernelName: tf.Equal, backendName: 'webgl', kernelFunc: equal }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
/**
 * Scalar shader snippet approximating erf(x).
 *
 * The coefficients are interpolated from `tf.backend_util` (ERF_P and
 * ERF_A1..ERF_A5) when this module is evaluated. The snippet works on |x| and
 * multiplies the result by the sign of the original x.
 */
const ERF = `
  // Error function is calculated approximately with elementary function.
  // See "Handbook of Mathematical Functions with Formulas,
  // Graphs, and Mathematical Tables", Abramowitz and Stegun.
  float p = ${tf.backend_util.ERF_P};
  float a1 = ${tf.backend_util.ERF_A1};
  float a2 = ${tf.backend_util.ERF_A2};
  float a3 = ${tf.backend_util.ERF_A3};
  float a4 = ${tf.backend_util.ERF_A4};
  float a5 = ${tf.backend_util.ERF_A5};

  float sign = sign(x);
  x = abs(x);
  float t = 1.0 / (1.0 + p * x);
  return sign * (1.0 - (((((a5*t + a4)*t) + a3)*t + a2)*t + a1)*t*exp(-x*x));
`;
/** Erf kernel function; only a scalar snippet is supplied (no packed variant). */
const erf = unaryKernelFunc({ opSnippet: ERF });
/** Kernel registration entry for Erf on the 'webgl' backend. */
const erfConfig = {
    kernelName: tf.Erf,
    backendName: 'webgl',
    kernelFunc: erf,
};
/**
 * @license
 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
* ============================================================================= */ const EXP = CHECK_NAN_SNIPPET_UNARY + ` return exp(x); `; const EXP_PACKED = ` vec4 result = exp(x); bvec4 isNaN = isnan(x); result.r = isNaN.r ? x.r : result.r; result.g = isNaN.g ? x.g : result.g; result.b = isNaN.b ? x.b : result.b; result.a = isNaN.a ? x.a : result.a; return result; `; const exp = unaryKernelFunc({ opSnippet: EXP, packedOpSnippet: EXP_PACKED, cpuKernelImpl: expImplCPU, dtype: 'float32', }); const expConfig = { kernelName: tf.Exp, backendName: 'webgl', kernelFunc: exp }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the License); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an AS IS BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function expandDims(args) { const { inputs, attrs, backend } = args; const { dim } = attrs; const { input } = inputs; const inputRank = input.shape.length; const newShape = input.shape.slice(); let $dim = dim; if (dim < 0) { // Negative value is counted from the tail of rank. tf.util.assert(-(inputRank + 1) <= dim, () => `Axis must be in the interval [${-(inputRank + 1)}, ${inputRank}]`); $dim = inputRank + dim + 1; } newShape.splice($dim, 0, 1); return reshape({ inputs: { x: input }, backend, attrs: { shape: newShape } }); } const expandDimsConfig = { kernelName: tf.ExpandDims, backendName: 'webgl', kernelFunc: expandDims, }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. 
* Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ const EXPM1 = `return exp(x) - 1.0;`; const expm1 = unaryKernelFunc({ opSnippet: EXPM1, packedOpSnippet: EXPM1, cpuKernelImpl: expm1ImplCPU }); const expm1Config = { kernelName: tf.Expm1, backendName: 'webgl', kernelFunc: expm1 }; /** * @license * Copyright 2018 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ class FFTProgram { constructor(component, inputShape, inverse) { this.variableNames = ['real', 'imag']; const innerDim = inputShape[1]; this.outputShape = inputShape; const exponentMultiplierSnippet = inverse ? `2.0 * ${Math.PI}` : `-2.0 * ${Math.PI}`; const resultDenominator = inverse ? 
`${innerDim}.0` : '1.0'; let opString; if (component === 'real') { opString = 'return real * expR - imag * expI;'; } else if (component === 'imag') { opString = 'return real * expI + imag * expR;'; } else { throw new Error(`FFT component must be either "real" or "imag", got ${component}.`); } this.userCode = ` const float exponentMultiplier = ${exponentMultiplierSnippet}; float unaryOpComplex(float real, float expR, float imag, float expI) { ${opString} } float mulMatDFT(int batch, int index) { float indexRatio = float(index) / float(${innerDim}); float exponentMultiplierTimesIndexRatio = exponentMultiplier * indexRatio; float result = 0.0; for (int i = 0; i < ${innerDim}; i++) { // x = (-2|2 * PI / N) * index * i; float x = exponentMultiplierTimesIndexRatio * float(i); float expR = cos(x); float expI = sin(x); float real = getReal(batch, i); float imag = getImag(batch, i); result += unaryOpComplex(real, expR, imag, expI) / ${resultDenominator}; } return result; } void main() { ivec2 coords = getOutputCoords(); setOutput(mulMatDFT(coords[0], coords[1])); } `; } } /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function fftImpl(x, inverse, backend) { const xData = backend.texData.get(x.dataId); const inputSize = tf.util.sizeFromShape(x.shape); // Collapse all outer dimensions to a single batch dimension. 
const innerDimensionSize = x.shape[x.shape.length - 1]; const batch = inputSize / innerDimensionSize; const input2D = reshape({ inputs: { x }, backend, attrs: { shape: [batch, innerDimensionSize] } }); const xShape = input2D.shape; const realProgram = new FFTProgram('real', xShape, inverse); const imagProgram = new FFTProgram('imag', xShape, inverse); const inputs = [ { dataId: xData.complexTensorInfos.real.dataId, dtype: xData.complexTensorInfos.real.dtype, shape: xShape }, { dataId: xData.complexTensorInfos.imag.dataId, dtype: xData.complexTensorInfos.imag.dtype, shape: xShape } ]; const realPart = backend.runWebGLProgram(realProgram, inputs, 'float32'); const imagPart = backend.runWebGLProgram(imagProgram, inputs, 'float32'); const complexOutput = complex({ inputs: { real: realPart, imag: imagPart }, backend }); backend.disposeIntermediateTensorInfo(realPart); backend.disposeIntermediateTensorInfo(imagPart); const complexOutputReshaped = reshape({ inputs: { x: complexOutput }, backend, attrs: { shape: x.shape } }); backend.disposeIntermediateTensorInfo(input2D); backend.disposeIntermediateTensorInfo(complexOutput); return complexOutputReshaped; } /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ function fft(args) { const { inputs, backend } = args; const { input } = inputs; return fftImpl(input, false /* inverse */, backend); } const fftConfig = { kernelName: tf.FFT, backendName: 'webgl', kernelFunc: fft }; /** * @license * Copyright 2019 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ class FillProgram { constructor(shape, value) { this.outputShape = []; this.customUniforms = [{ name: 'value', type: 'float' }]; this.variableNames = ['x']; this.outputShape = shape; this.userCode = ` void main() { // Input can be obtained from uniform value. setOutput(value); } `; } } /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ function fill(args) { const { backend, attrs } = args; const { shape, value } = attrs; let { dtype } = attrs; dtype = dtype || tf.util.inferDtype(value); if (dtype === 'string') { // String type should be handled in CPU memory. const values = tf.util.getArrayFromDType(dtype, tf.util.sizeFromShape(shape)); values.fill(value); return backend.makeTensorInfo(shape, dtype, values); } else { const program = new FillProgram(shape, value); const customValues = [[value]]; return backend.runWebGLProgram(program, [], dtype, customValues); } } const fillConfig = { kernelName: tf.Fill, backendName: 'webgl', kernelFunc: fill }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ class FlipLeftRightProgram { constructor(imageShape) { this.variableNames = ['Image']; this.outputShape = []; const imageWidth = imageShape[2]; this.outputShape = imageShape; this.userCode = ` void main() { ivec4 coords = getOutputCoords(); int x = coords[2]; int coordX = ${imageWidth} - x - 1; float outputValue; if(coordX >= 0 && coordX < ${imageWidth}) { outputValue = getImage(coords[0], coords[1], coordX, coords[3]); } else { outputValue = getImage(coords[0], coords[1], coords[2], coords[3]); } setOutput(outputValue); } `; } } /** * @license * Copyright 2020 Google LLC. All Rights Reserved. 
* Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ const flipLeftRightConfig = { kernelName: tf.FlipLeftRight, backendName: 'webgl', kernelFunc: ({ inputs, backend }) => { const { image } = inputs; const webglBackend = backend; const program = new FlipLeftRightProgram(image.shape); const output = webglBackend.runWebGLProgram(program, [image], image.dtype); return output; } }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ const FLOOR = `return floor(x);`; const floor = unaryKernelFunc({ opSnippet: FLOOR, packedOpSnippet: FLOOR, cpuKernelImpl: floorImplCPU }); const floorConfig = { kernelName: tf.Floor, backendName: 'webgl', kernelFunc: floor, }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. 
* Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
// We use native integer division to deal with floating point imprecision. Since
// we implement floor division and glsl implements truncated division, we
// correct for this by subtracting 1 from result when the result is negative and
// there is a remainder.
/**
 * Scalar shader snippet for floor division. Both operands are rounded to
 * integers; a zero divisor yields NAN, otherwise the result comes from the
 * `idiv` helper, which also receives the product of the operand signs.
 */
const INT_DIV = `
  float s = sign(a) * sign(b);
  int ia = round(a);
  int ib = round(b);
  if (ib != 0) {
    // Windows (D3D) wants guaranteed non-zero int division at compile-time.
    return float(idiv(ia, ib, s));
  } else {
    return NAN;
  }
`;
/**
 * Packed shader snippet for floor division. Each of the four lanes is divided
 * separately, and lanes with a zero divisor keep the initial value 0.
 */
const INT_DIV_PACKED = `
  ivec4 ia = round(a);
  ivec4 ib = round(b);
  bvec4 cond = notEqual(ib, ivec4(0));
  ivec4 result = ivec4(0);
  vec4 s = sign(a) * sign(b);

  // Windows (D3D) wants guaranteed non-zero int division at compile-time.
  if (cond[0]) {
    result[0] = idiv(ia[0], ib[0], s[0]);
  }
  if (cond[1]) {
    result[1] = idiv(ia[1], ib[1], s[1]);
  }
  if (cond[2]) {
    result[2] = idiv(ia[2], ib[2], s[2]);
  }
  if (cond[3]) {
    result[3] = idiv(ia[3], ib[3], s[3]);
  }
  return vec4(result);
`;
/** FloorDiv kernel function; the output dtype is 'int32'. */
const floorDiv = binaryKernelFunc({ opSnippet: INT_DIV, packedOpSnippet: INT_DIV_PACKED, dtype: 'int32' });
/** Kernel registration entry for FloorDiv on the 'webgl' backend. */
const floorDivConfig = {
    kernelName: tf.FloorDiv,
    backendName: 'webgl',
    kernelFunc: floorDiv
};
/**
 * @license
 * Copyright 2018 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ class FromPixelsProgram { constructor(outputShape) { this.variableNames = ['A']; const glsl = getGlslDifferences(); const [height, width,] = outputShape; this.outputShape = outputShape; this.userCode = ` void main() { ivec3 coords = getOutputCoords(); int texR = coords[0]; int texC = coords[1]; int depth = coords[2]; vec2 uv = (vec2(texC, texR) + halfCR) / vec2(${width}.0, ${height}.0); vec4 values = ${glsl.texture2D}(A, uv); float value; if (depth == 0) { value = values.r; } else if (depth == 1) { value = values.g; } else if (depth == 2) { value = values.b; } else if (depth == 3) { value = values.a; } setOutput(floor(value * 255.0 + 0.5)); } `; } } /** * @license * Copyright 2018 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ class FromPixelsPackedProgram { constructor(outputShape) { this.variableNames = ['A']; this.packedInputs = false; this.packedOutput = true; const glsl = getGlslDifferences(); const [height, width,] = outputShape; this.outputShape = outputShape; this.userCode = ` void main() { ivec3 coords = getOutputCoords(); int texR = coords[0]; int texC = coords[1]; int depth = coords[2]; vec4 result = vec4(0.); for(int row=0; row<=1; row++) { for(int col=0; col<=1; col++) { texC = coords[1] + row; depth = coords[2] + col; vec2 uv = (vec2(texC, texR) + halfCR) / vec2(${width}.0, ${height}.0); vec4 values = ${glsl.texture2D}(A, uv); float value; if (depth == 0) { value = values.r; } else if (depth == 1) { value = values.g; } else if (depth == 2) { value = values.b; } else if (depth == 3) { value = values.a; } result[row * 2 + col] = floor(value * 255.0 + 0.5); } } ${glsl.output} = result; } `; } } /** * @license * Copyright 2019 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ const fromPixelsConfig = { kernelName: tf.FromPixels, backendName: 'webgl', kernelFunc: fromPixels, }; let fromPixels2DContext; let willReadFrequently = tf.env().getBool('CANVAS2D_WILL_READ_FREQUENTLY_FOR_GPU'); function fromPixels(args) { const { inputs, backend, attrs } = args; let { pixels } = inputs; const { numChannels } = attrs; const isVideo = typeof (HTMLVideoElement) !== 'undefined' && pixels instanceof HTMLVideoElement; const isImage = typeof (HTMLImageElement) !== 'undefined' && pixels instanceof HTMLImageElement; const [width, height] = isVideo ? [ pixels.videoWidth, pixels.videoHeight ] : [pixels.width, pixels.height]; const texShape = [height, width]; const outShape = [height, width, numChannels]; if (isImage || isVideo) { const newWillReadFrequently = tf.env().getBool('CANVAS2D_WILL_READ_FREQUENTLY_FOR_GPU'); if (fromPixels2DContext == null || newWillReadFrequently !== willReadFrequently) { willReadFrequently = newWillReadFrequently; fromPixels2DContext = document.createElement('canvas').getContext('2d', { willReadFrequently }); } fromPixels2DContext.canvas.width = width; fromPixels2DContext.canvas.height = height; fromPixels2DContext.drawImage(pixels, 0, 0, width, height); pixels = fromPixels2DContext.canvas; } const tempPixelHandle = backend.makeTensorInfo(texShape, 'int32'); // This is a byte texture with pixels. backend.texData.get(tempPixelHandle.dataId).usage = TextureUsage.PIXELS; backend.gpgpu.uploadPixelDataToTexture(backend.getTexture(tempPixelHandle.dataId), pixels); const program = tf.env().getBool('WEBGL_PACK') ? new FromPixelsPackedProgram(outShape) : new FromPixelsProgram(outShape); const res = backend.runWebGLProgram(program, [tempPixelHandle], 'int32'); backend.disposeData(tempPixelHandle.dataId); return res; } /** * @license * Copyright 2020 Google LLC. All Rights Reserved. 
* Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
/**
 * Fused 2D convolution kernel (convolution with optional bias and
 * activation).
 *
 * One of four execution paths is selected from `convInfo` and environment
 * flags:
 *   1. 1x1 filter, unit stride and dilation, 'SAME' or 'VALID' padding:
 *      `conv2dByMatMul`.
 *   2. stride width <= 2, 'channelsLast' and WEBGL_EXP_CONV enabled:
 *      `Conv2DPackedProgram`.
 *   3. WEBGL_CONV_IM2COL enabled: `conv2dWithIm2Row`.
 *   4. otherwise: `Conv2DProgram`.
 *
 * @param args Kernel arguments: `inputs` (x, filter, optional bias and
 *     preluActivationWeights), `attrs` (strides, pad, dataFormat, dilations,
 *     dimRoundingMode, activation, leakyreluAlpha) and the `backend`.
 * @returns The path output reshaped to `convInfo.outShape`.
 */
function fusedConv2d(args) {
    const { inputs, backend, attrs } = args;
    const { x, filter, bias, preluActivationWeights } = inputs;
    const { strides, pad, dataFormat, dilations, dimRoundingMode, activation, leakyreluAlpha } = attrs;
    const $dataFormat = tf.backend_util.convertConv2DDataFormat(dataFormat);
    const convInfo = tf.backend_util.computeConv2DInfo(x.shape, filter.shape, strides, dilations, pad, dimRoundingMode, false /* depthwise */, $dataFormat);
    let out;
    // Tensor infos created inside this kernel; all are disposed before returning.
    const intermediates = [];
    const hasBias = bias != null;
    const hasPreluActivationWeights = preluActivationWeights != null;
    const hasLeakyreluAlpha = activation === 'leakyrelu';
    // Builds the shader input list [x, filter, bias?, preluWeights?, alpha?].
    // Only the two shader-program paths below call it.
    const prepareInputs = () => {
        const inputs = [x, filter];
        // If the input is a 1-D tensor, align it with the channels.
        //
        // For fusedConv2d, the inputs (x, W, bias, preluActivationWeights) are
        // supposed to be aligned with the dataFormat. The 4-D tensor inputs or
        // scalar inputs are originally aligned, but the 1-D tensor inputs are
        // supposed to be aligned with the channels (only bias and PReLU activation
        // weights could be a 1-D tensor).
        const alignInputWithDataFormat = (input, dataFormat) => {
            if (dataFormat === 'NCHW' && input.shape.length === 1 &&
                input.shape[0] !== 1) {
                // Reshape [C] to [C, 1, 1]; the reshaped view is tracked for disposal.
                const alignedInput = reshape({
                    inputs: { x: input },
                    backend,
                    attrs: { shape: [input.shape[0], 1, 1] }
                });
                intermediates.push(alignedInput);
                return alignedInput;
            }
            return input;
        };
        if (hasBias) {
            inputs.push(alignInputWithDataFormat(bias, dataFormat));
        }
        if (hasPreluActivationWeights) {
            inputs.push(alignInputWithDataFormat(preluActivationWeights, dataFormat));
        }
        if (hasLeakyreluAlpha) {
            // The leaky-relu slope is passed to the shader as a float32 scalar.
            const $leakyreluAlpha = backend.makeTensorInfo([], 'float32', tf.util.createScalarValue(leakyreluAlpha, 'float32'));
            inputs.push($leakyreluAlpha);
            intermediates.push($leakyreluAlpha);
        }
        return inputs;
    };
    if (convInfo.filterHeight === 1 && convInfo.filterWidth === 1 &&
        convInfo.dilationHeight === 1 && convInfo.dilationWidth === 1 &&
        convInfo.strideHeight === 1 && convInfo.strideWidth === 1 &&
        (convInfo.padInfo.type === 'SAME' || convInfo.padInfo.type === 'VALID')) {
        // Path 1: delegate to the matMul-based helper.
        out = conv2dByMatMul({
            x,
            filter,
            convInfo,
            backend,
            bias,
            activation,
            preluActivationWeights,
            leakyreluAlpha
        });
    }
    else if (convInfo.strideWidth <= 2 && $dataFormat === 'channelsLast'
        && tf.env().getBool('WEBGL_EXP_CONV')) {
        // Path 2: packed convolution program, enabled by the WEBGL_EXP_CONV flag.
        const fusedActivation = activation ? mapActivationToShaderProgram(activation, true) : null;
        const program = new Conv2DPackedProgram(convInfo, hasBias, fusedActivation, hasPreluActivationWeights, hasLeakyreluAlpha);
        // Pad, stride, dilation and input size are handed to runWebGLProgram
        // separately from the tensor inputs.
        const customValues = [
            [convInfo.padInfo.top, convInfo.padInfo.left],
            [convInfo.strideHeight, convInfo.strideWidth],
            [convInfo.dilationHeight, convInfo.dilationWidth],
            [convInfo.inHeight, convInfo.inWidth]
        ];
        const inputs = prepareInputs();
        out = backend.runWebGLProgram(program, inputs, 'float32', customValues);
    }
    else if (tf.env().getBool('WEBGL_CONV_IM2COL')) {
        // Path 3: delegate to the im2row-based helper.
        out = conv2dWithIm2Row({
            x,
            filter,
            convInfo,
            backend,
            bias,
            activation,
            preluActivationWeights,
            leakyreluAlpha
        });
    }
    else {
        // Path 4: plain (unpacked) convolution program.
        const fusedActivation = activation ? mapActivationToShaderProgram(activation, false) : null;
        const program = new Conv2DProgram(convInfo, hasBias, fusedActivation, hasPreluActivationWeights, hasLeakyreluAlpha);
        const inputs = prepareInputs();
        out = backend.runWebGLProgram(program, inputs, 'float32');
    }
    const outReshaped = reshape({ inputs: { x: out }, backend, attrs: { shape: convInfo.outShape } });
    // `out` itself is an intermediate: only the reshaped view is returned.
    intermediates.push(out);
    intermediates.forEach(t => backend.disposeIntermediateTensorInfo(t));
    return outReshaped;
}
/** Kernel registration for `tf.FusedConv2D` on the 'webgl' backend. */
const fusedConv2DConfig = {
    kernelName: tf.FusedConv2D,
    backendName: 'webgl',
    kernelFunc: fusedConv2d,
};
/**
 * @license
 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
/**
 * Fused depthwise 2D convolution kernel (depthwise convolution with optional
 * bias and activation).
 *
 * Runs `DepthwiseConvPacked2DProgram` when WEBGL_PACK_DEPTHWISECONV is
 * enabled, the stride width is <= 2 and outChannels equals inChannels;
 * otherwise runs `DepthwiseConv2DProgram`.
 *
 * @param args Kernel arguments: `inputs` (x, filter, optional bias and
 *     preluActivationWeights), `attrs` (strides, pad, dilations,
 *     dimRoundingMode, activation, leakyreluAlpha) and the `backend`.
 * @returns The 'float32' output of the selected program.
 */
function fusedDepthwiseConv2D(args) {
    const { inputs, backend, attrs } = args;
    const { x, filter, bias, preluActivationWeights } = inputs;
    const { strides, pad, dilations, dimRoundingMode, activation, leakyreluAlpha } = attrs;
    // Tensor infos created inside this kernel; disposed before returning.
    const intermediates = [];
    // Missing dilations default to [1, 1].
    let $dilations = dilations;
    if ($dilations == null) {
        $dilations = [1, 1];
    }
    tf.util.assert(tf.backend_util.eitherStridesOrDilationsAreOne(strides, $dilations), () => 'Error in depthwiseConv2d: Either strides or dilations must be ' +
        `1. Got strides ${strides} and dilations '${$dilations}'`);
    const convInfo = tf.backend_util.computeConv2DInfo(x.shape, filter.shape, strides, $dilations, pad, dimRoundingMode, true /* depthwise */);
    const shouldPackDepthwiseConv = tf.env().getBool('WEBGL_PACK_DEPTHWISECONV') &&
        convInfo.strideWidth <= 2 &&
        convInfo.outChannels / convInfo.inChannels === 1;
    // The activation snippet is requested in packed form iff the packed
    // program is going to be used.
    const fusedActivation = activation ?
        mapActivationToShaderProgram(activation, shouldPackDepthwiseConv) :
        null;
    // Shader input list: [x, filter, bias?, preluWeights?, alpha?].
    const programInputs = [x, filter];
    const hasBias = bias != null;
    const hasPreluActivationWeights = preluActivationWeights != null;
    const hasLeakyreluAlpha = activation === 'leakyrelu';
    if (hasBias) {
        programInputs.push(bias);
    }
    if (hasPreluActivationWeights) {
        programInputs.push(preluActivationWeights);
    }
    if (hasLeakyreluAlpha) {
        // The leaky-relu slope is passed to the shader as a float32 scalar.
        const $leakyreluAlpha = backend.makeTensorInfo([], 'float32', tf.util.createScalarValue(leakyreluAlpha, 'float32'));
        programInputs.push($leakyreluAlpha);
        intermediates.push($leakyreluAlpha);
    }
    let program;
    if (shouldPackDepthwiseConv) {
        program = new DepthwiseConvPacked2DProgram(convInfo, hasBias, fusedActivation, hasPreluActivationWeights, hasLeakyreluAlpha);
    }
    else {
        program = new DepthwiseConv2DProgram(convInfo, hasBias, fusedActivation, hasPreluActivationWeights, hasLeakyreluAlpha);
    }
    // Pad, stride, dilation and input size are handed to runWebGLProgram
    // separately from the tensor inputs.
    const customValues = [
        [convInfo.padInfo.top, convInfo.padInfo.left],
        [convInfo.strideHeight, convInfo.strideWidth],
        [convInfo.dilationHeight, convInfo.dilationWidth],
        [convInfo.inHeight, convInfo.inWidth]
    ];
    const result = backend.runWebGLProgram(program, programInputs, 'float32', customValues);
    intermediates.forEach(t => backend.disposeIntermediateTensorInfo(t));
    return result;
}
/** Kernel registration for `tf.FusedDepthwiseConv2D` on the 'webgl' backend. */
const fusedDepthwiseConv2DConfig = {
    kernelName: tf.FusedDepthwiseConv2D,
    backendName: 'webgl',
    kernelFunc: fusedDepthwiseConv2D,
};
/**
 * Shader program for gatherNd over inputs `x` and `indices`.
 *
 * For each output row the generated shader reads `sliceDim` index components
 * from `indices`, accumulates `index * strides[j]` into a flat row index of
 * `x`, and writes 0.0 instead of reading `x` when any component is negative
 * or >= `paramsShape[j]`.
 */
class GatherNDProgram {
    /**
     * @param sliceDim Number of index components read per output row.
     * @param strides Multiplier applied to each index component.
     * @param shape Output shape of the program.
     * @param paramsShape Exclusive upper bound for each index component.
     */
    constructor(sliceDim, strides, shape, paramsShape) {
        this.sliceDim = sliceDim;
        this.strides = strides;
        this.paramsShape = paramsShape;
        this.variableNames = ['x', 'indices'];
        this.outputShape = shape;
        const dtype = getCoordsDataType(shape.length);
        // One unrolled bounds-check-and-accumulate step is emitted per index
        // component.
        let mainLoop = ` int index;`;
        for (let j = 0; j < this.sliceDim; j++) {
            mainLoop += ` index = round(getIndices(coords[0], ${j})); out_of_bounds = out_of_bounds || index < 0; out_of_bounds = out_of_bounds || index >= ${this.paramsShape[j]}; flattenIndex += index * ${this.strides[j]};`;
        }
        this.userCode = ` void main() { ${dtype} coords = getOutputCoords(); int flattenIndex = 0; bool out_of_bounds = false; ${mainLoop} setOutput(out_of_bounds ? 0.0 : getX(flattenIndex, coords[1])); } `;
    }
}
/**
 * @license
 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
* ============================================================================= */ function gatherNd(args) { const { inputs, backend } = args; const { params, indices } = inputs; const indicesShape = indices.shape; const sliceRank = indicesShape[indicesShape.length - 1]; const paramsSize = tf.util.sizeFromShape(params.shape); const [resultShape, numSlices, sliceSize, strides] = tf.backend_util.prepareAndValidate(params, indices); const flattenIndices = reshape({ inputs: { x: indices }, backend, attrs: { shape: [numSlices, sliceRank] } }); const flattenX = reshape({ inputs: { x: params }, backend, attrs: { shape: [(tf.util.sizeFromShape(params.shape) / sliceSize), sliceSize] } }); if (backend.shouldExecuteOnCPU([params, indices]) || params.dtype === 'string') { const indicesData = backend.readSync(indices.dataId); const paramsBuf = backend.bufferSync(params); const outValue = gatherNdImplCPU(indicesData, paramsBuf, params.dtype, numSlices, sliceRank, sliceSize, strides, params.shape, paramsSize); return backend.makeTensorInfo(resultShape, params.dtype, outValue.values); } const program = new GatherNDProgram(sliceRank, strides, [numSlices, sliceSize], params.shape); const res = backend.runWebGLProgram(program, [flattenX, flattenIndices], flattenX.dtype); const reshaped = reshape({ inputs: { x: res }, backend, attrs: { shape: resultShape } }); backend.disposeIntermediateTensorInfo(flattenIndices); backend.disposeIntermediateTensorInfo(flattenX); backend.disposeIntermediateTensorInfo(res); return reshaped; } const gatherNdConfig = { kernelName: tf.GatherNd, backendName: 'webgl', kernelFunc: gatherNd }; /** * @license * Copyright 2017 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ class GatherProgram { constructor(aShape, outputShape) { this.variableNames = ['A', 'indices']; this.outputShape = outputShape; this.rank = outputShape.length; const dtype = getCoordsDataType(this.rank); const sourceCoords = getSourceCoords$1(aShape); this.userCode = ` void main() { ${dtype} resRC = getOutputCoords(); int index = int(getIndices(resRC.x, resRC.z)); float inBounds = (index >= 0) && (index < ${aShape[2]}) ? 1.0 : 0.0; setOutput(inBounds * getA(${sourceCoords})); } `; } } // The input and output are always flattened into rank 4 tensors. function getSourceCoords$1(aShape, axis) { const currentCoords = ['resRC.x', 'resRC.y', 'resRC.z', 'resRC.w']; const sourceCoords = []; for (let i = 0; i < aShape.length; i++) { if (i === 2) { sourceCoords.push('index'); } else { sourceCoords.push(`${currentCoords[i]}`); } } return sourceCoords.join(); } /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ function gatherV2(args) { const { inputs, backend, attrs } = args; const { x, indices } = inputs; const { axis, batchDims } = attrs; const parsedAxis = tf.util.parseAxisParam(axis, x.shape)[0]; if (tf.env().get('DEBUG')) { // In debug mode, throw error when any index is out of bound. // Otherwise, just fill out of bounds with zeroes. const indicesVals = backend.readSync(indices.dataId); const axisDim = x.shape[parsedAxis]; for (let i = 0; i < indicesVals.length; ++i) { const index = indicesVals[i]; tf.util.assert(index <= axisDim - 1 && index >= 0, () => `GatherV2: the index value ${index} is not in [0, ${axisDim - 1}]`); } } const shapeInfo = tf.backend_util.segment_util.collectGatherOpShapeInfo(x, indices, parsedAxis, batchDims); const indicesSize = tf.util.sizeFromShape(indices.shape); const toDispose = []; const flattenX = reshape({ inputs: { x }, backend, attrs: { shape: [ shapeInfo.batchSize, shapeInfo.outerSize, shapeInfo.dimSize, shapeInfo.sliceSize ] } }); const flattenIndex = reshape({ inputs: { x: indices }, backend, attrs: { shape: [shapeInfo.batchSize, indicesSize / shapeInfo.batchSize] } }); toDispose.push(flattenX); toDispose.push(flattenIndex); const flattenOutputShape = [ shapeInfo.batchSize, shapeInfo.outerSize, indicesSize / shapeInfo.batchSize, shapeInfo.sliceSize ]; if (backend.shouldExecuteOnCPU([x, indices]) || x.dtype === 'string') { const indicesBuf = backend.bufferSync(flattenIndex); const xBuf = backend.bufferSync(flattenX); const outBuf = gatherV2ImplCPU(xBuf, indicesBuf, flattenOutputShape); toDispose.forEach(t => backend.disposeIntermediateTensorInfo(t)); return backend.makeTensorInfo(shapeInfo.outputShape, outBuf.dtype, outBuf.values); } const program = new GatherProgram(flattenX.shape, flattenOutputShape); const res = backend.runWebGLProgram(program, [flattenX, flattenIndex], flattenX.dtype); toDispose.push(res); const reshaped = reshape({ inputs: { x: 
res }, backend, attrs: { shape: shapeInfo.outputShape } }); toDispose.forEach(t => backend.disposeIntermediateTensorInfo(t)); return reshaped; } const gatherV2Config = { kernelName: tf.GatherV2, backendName: 'webgl', kernelFunc: gatherV2 }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ const GREATER = `return float(a > b);`; const GREATER_PACKED = ` return vec4(greaterThan(a, b)); `; const greater = binaryKernelFunc({ opSnippet: GREATER, packedOpSnippet: GREATER_PACKED, cpuKernelImpl: greaterImplCPU, dtype: 'bool' }); const greaterConfig = { kernelName: tf.Greater, backendName: 'webgl', kernelFunc: greater }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ const GREATER_EQUAL = `return float(a >= b);`; const GREATER_EQUAL_PACKED = ` return vec4(greaterThanEqual(a, b)); `; const greaterEqual = binaryKernelFunc({ opSnippet: GREATER_EQUAL, packedOpSnippet: GREATER_EQUAL_PACKED, dtype: 'bool', cpuKernelImpl: greaterEqualImplCPU }); const greaterEqualConfig = { kernelName: tf.GreaterEqual, backendName: 'webgl', kernelFunc: greaterEqual }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function ifft(args) { const { inputs, backend } = args; const { input } = inputs; return fftImpl(input, true /* inverse */, backend); } const ifftConfig = { kernelName: tf.IFFT, backendName: 'webgl', kernelFunc: ifft }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ const IS_FINITE = `return float(!isnan(x) && !isinf(x));`; const isFinite = unaryKernelFunc({ opSnippet: IS_FINITE, dtype: 'bool' }); const isFiniteConfig = { kernelName: tf.IsFinite, backendName: 'webgl', kernelFunc: isFinite, }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ const IS_INF = `return float(isinf(x));`; const isInf = unaryKernelFunc({ opSnippet: IS_INF, dtype: 'bool' }); const isInfConfig = { kernelName: tf.IsInf, backendName: 'webgl', kernelFunc: isInf, }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ const IS_NAN = `return float(isnan(x));`; const isNaN = unaryKernelFunc({ opSnippet: IS_NAN, dtype: 'bool' }); const isNaNConfig = { kernelName: tf.IsNan, backendName: 'webgl', kernelFunc: isNaN, }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ const LESS = `return float(a < b);`; const LESS_PACKED = ` return vec4(lessThan(a, b)); `; const less = binaryKernelFunc({ opSnippet: LESS, packedOpSnippet: LESS_PACKED, cpuKernelImpl: lessImplCPU, dtype: 'bool' }); const lessConfig = { kernelName: tf.Less, backendName: 'webgl', kernelFunc: less }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ const LESS_EQUAL = `return float(a <= b);`; const LESS_EQUAL_PACKED = ` return vec4(lessThanEqual(a, b)); `; const lessEqual = binaryKernelFunc({ opSnippet: LESS_EQUAL, packedOpSnippet: LESS_EQUAL_PACKED, cpuKernelImpl: lessEqualImplCPU, dtype: 'bool' }); const lessEqualConfig = { kernelName: tf.LessEqual, backendName: 'webgl', kernelFunc: lessEqual }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function linSpace(args) { const { backend, attrs } = args; const { start, stop, num } = attrs; // TODO: Use CPU implementation due to the precision problem in Safari. const outVals = linSpaceImplCPU(start, stop, num); return backend.makeTensorInfo([outVals.length], 'float32', outVals); } const linSpaceConfig = { kernelName: tf.LinSpace, backendName: 'webgl', kernelFunc: linSpace }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ // Windows chrome return 0 if the input is negative value. We will specifically // return NaN if the input is 0 to solve compatiblity issue. const LOG = CHECK_NAN_SNIPPET_UNARY + ` return x < 0.0 ? 0./0. : log(x); `; const LOG_PACKED = ` vec4 result = log(x); bvec4 isNaN = isnan(x); result.r = isNaN.r ? x.r : (x.r < 0.0 ? 0./0. : result.r); result.g = isNaN.g ? x.g : (x.g < 0.0 ? 0./0. : result.g); result.b = isNaN.b ? x.b : (x.b < 0.0 ? 0./0. : result.b); result.a = isNaN.a ? x.a : (x.a < 0.0 ? 0./0. : result.a); return result; `; const log = unaryKernelFunc({ opSnippet: LOG, packedOpSnippet: LOG_PACKED, cpuKernelImpl: logImplCPU }); const logConfig = { kernelName: tf.Log, backendName: 'webgl', kernelFunc: log }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ const LOG1P = CHECK_NAN_SNIPPET_UNARY + ` return log(1.0 + x); `; const log1p = unaryKernelFunc({ opSnippet: LOG1P }); const log1pConfig = { kernelName: tf.Log1p, backendName: 'webgl', kernelFunc: log1p, }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ const LOGICAL_AND = `return float(a >= 1.0 && b >= 1.0);`; const LOGICAL_AND_PACKED = ` return vec4( vec4(greaterThanEqual(a, vec4(1.0))) * vec4(greaterThanEqual(b, vec4(1.0)))); `; const logicalAnd = binaryKernelFunc({ opSnippet: LOGICAL_AND, packedOpSnippet: LOGICAL_AND_PACKED, dtype: 'bool' }); const logicalAndConfig = { kernelName: tf.LogicalAnd, backendName: 'webgl', kernelFunc: logicalAnd }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ const LOGICAL_NOT = `return float(!(x >= 1.0));`; const logicalNot = unaryKernelFunc({ opSnippet: LOGICAL_NOT }); const logicalNotConfig = { kernelName: tf.LogicalNot, backendName: 'webgl', kernelFunc: logicalNot, }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ const LOGICAL_OR = `return float(a >= 1.0 || b >= 1.0);`; const LOGICAL_OR_PACKED = ` return min( vec4(greaterThanEqual(a, vec4(1.0))) + vec4(greaterThanEqual(b, vec4(1.0))), vec4(1.0)); `; const logicalOr = binaryKernelFunc({ opSnippet: LOGICAL_OR, packedOpSnippet: LOGICAL_OR_PACKED, dtype: 'bool' }); const logicalOrConfig = { kernelName: tf.LogicalOr, backendName: 'webgl', kernelFunc: logicalOr }; /** * @license * Copyright 2017 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ class LRNProgram { constructor(xShape, radius, bias, alpha, beta) { this.variableNames = ['x']; this.outputShape = []; const rad = radius; const maxD = xShape[3] - 1; this.outputShape = xShape; // optimize pow(bias + alpha * sum, -beta) // src: https://github.com/tensorflow/tensorflow/.. // blob/26033a1644a9c4a5fbe3170ab2e864b6a4ccd4ca/.. // tensorflow/core/kernels/mkl_lrn_op.cc#L320 let powOperator; const basis = `float(${bias}) + float(${alpha}) * sum`; if (beta === 0.5) { powOperator = `inversesqrt(${basis})`; } else if (beta === 1.0) { powOperator = `1.0/(${basis})`; } else { powOperator = `exp(log(${basis}) * float(-${beta}));`; } this.userCode = ` void main() { ivec4 coords = getOutputCoords(); int b = coords[0]; int r = coords[1]; int c = coords[2]; int d = coords[3]; float x = getX(b, r, c, d); float sum = 0.0; for (int j = -${rad}; j <= ${rad}; j++) { int idx = d + j; if (idx >= 0 && idx <= ${maxD}) { float z = getX(b, r, c, idx); sum += z * z; } } float val = x * ${powOperator}; setOutput(val); } `; } } /** * @license * Copyright 2019 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ class LRNPackedProgram { constructor(xShape, radius, bias, alpha, beta) { this.variableNames = ['x']; this.outputShape = []; this.packedInputs = true; this.packedOutput = true; const rad = radius; const maxD = xShape[3] - 1; this.outputShape = xShape; // optimize pow(bias + alpha * sum, -beta) // src: https://github.com/tensorflow/tensorflow/.. // blob/26033a1644a9c4a5fbe3170ab2e864b6a4ccd4ca/.. // tensorflow/core/kernels/mkl_lrn_op.cc#L320 let powOperator; const basis = `float(${bias}) + float(${alpha}) * sum`; if (beta === 0.5) { powOperator = `inversesqrt(${basis})`; } else if (beta === 1.0) { powOperator = `1.0/(${basis})`; } else { powOperator = `exp(log(${basis}) * float(-${beta}));`; } this.userCode = ` void main() { ivec4 coords = getOutputCoords(); int b = coords.x; int r = coords.y; int c = coords.z; int d = coords.w; bool hasNextCol = d < ${this.outputShape[3]}; bool hasNextRow = c < ${this.outputShape[2]}; vec4 sum = vec4(0.); vec4 xFragAtOutputCoords = getX(b, r, c, d); vec4 xAtOutputCoords = vec4( getChannel(xFragAtOutputCoords, vec2(c, d)), hasNextCol ? getChannel(xFragAtOutputCoords, vec2(c, d + 1)) : 0.0, hasNextRow ? getChannel(xFragAtOutputCoords , vec2(c + 1, d)) : 0.0, (hasNextRow && hasNextCol) ? 
getChannel(xFragAtOutputCoords, vec2(c + 1, d + 1)) : 0.0 ); int firstChannel = d - ${rad}; vec2 cache = vec2(0.); if(firstChannel >= 0){ vec4 firstChannelFrag = getX(b, r, c, firstChannel); cache.x = getChannel(firstChannelFrag, vec2(c, firstChannel)); if(hasNextRow){ cache.y = getChannel(firstChannelFrag, vec2(c + 1, firstChannel)); } } ivec2 depth = ivec2(d, d + 1); for (int j = - ${rad}; j <= ${rad}; j++) { ivec2 idx = depth + j; bvec2 aboveLowerBound = greaterThanEqual(idx, ivec2(0)); bvec2 belowUpperBound = lessThanEqual(idx, ivec2(${maxD})); bool depthInRange = aboveLowerBound.x && belowUpperBound.x; bool depthPlusOneInRange = aboveLowerBound.y && belowUpperBound.y; if(depthInRange || depthPlusOneInRange){ vec4 z = vec4(0.); vec4 xFragAtCurrentDepth; z.xz = cache.xy; if(depthPlusOneInRange && hasNextCol){ xFragAtCurrentDepth = idx.y != d ? getX(b, r, c, idx.y) : xFragAtOutputCoords; z.y = getChannel(xFragAtCurrentDepth, vec2(c, idx.y)); if(hasNextRow){ z.w = getChannel(xFragAtCurrentDepth, vec2(c + 1, idx.y)); } } cache.xy = z.yw; sum += z * z; } } vec4 result = xAtOutputCoords * ${powOperator}; setOutput(result); } `; } } /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ const lrn = (args) => { const { inputs, backend, attrs } = args; const { x } = inputs; const { depthRadius, bias, alpha, beta } = attrs; const program = tf.env().getBool('WEBGL_PACK_NORMALIZATION') ? new LRNPackedProgram(x.shape, depthRadius, bias, alpha, beta) : new LRNProgram(x.shape, depthRadius, bias, alpha, beta); return backend.runWebGLProgram(program, [x], x.dtype); }; // tslint:disable-next-line: variable-name const LRNConfig = { kernelName: tf.LRN, backendName: 'webgl', kernelFunc: lrn }; /** * @license * Copyright 2018 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ class LRNGradProgram { constructor(inputShape, depthRadius, bias, alpha, beta) { this.variableNames = ['inputImage', 'outputImage', 'dy']; this.outputShape = []; this.outputShape = inputShape; this.depth = inputShape[3]; this.depthRadius = depthRadius; this.bias = bias; this.alpha = alpha; this.beta = beta; this.userCode = ` void main() { ivec4 coords = getOutputCoords(); int b = coords[0]; int r = coords[1]; int c = coords[2]; float result = 0.0; for (int d = 0; d < ${this.depth}; ++d) { int depthBegin = int(max(0.0, float(d - ${depthRadius}))); int depthEnd = int(min(float(${this.depth}), float(d + ${depthRadius} + 1))); const int MIN_DEPTH_BEGIN = 0; const int MAX_DEPTH_END = ${this.depth}; float norm = 0.0; for (int k = MIN_DEPTH_BEGIN; k < MAX_DEPTH_END; ++k) { if (k < depthBegin){ continue; } else if (k >= depthBegin && k < depthEnd) { norm += getInputImage(b, r, c, k) * getInputImage(b, r, c, k); } else { break; } } norm = float(${alpha}) * norm + float(${bias}); for(int k = MIN_DEPTH_BEGIN; k < MAX_DEPTH_END; ++k){ if (k < depthBegin){ continue; } else if (k >= depthBegin && k < depthEnd){ float dyi = -2.0 * float(${alpha}) * float(${beta}) * getInputImage(b, r, c, k) * getOutputImage(b, r, c, d) / norm; if (k == d) { dyi += pow(norm, -1.0 * ${beta}); } if (k == coords[3]) { dyi *= getDy(b, r, c, d); result += dyi; } } else { break; } } } setOutput(result); } `; } } /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
* See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ const lrnGrad = (args) => { const { inputs, backend, attrs } = args; const { x, y, dy } = inputs; const { depthRadius, bias, alpha, beta } = attrs; const program = new LRNGradProgram(x.shape, depthRadius, bias, alpha, beta); return backend.runWebGLProgram(program, [x, y, dy], x.dtype); }; // tslint:disable-next-line: variable-name const LRNGradConfig = { kernelName: tf.LRNGrad, backendName: 'webgl', kernelFunc: lrnGrad }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function maxImpl(x, reduceShape, outShape, backend) { const inSize = tf.util.sizeFromShape(reduceShape); const xSize = tf.util.sizeFromShape(x.shape); const batchSize = xSize / inSize; const reshapedInput = reshape({ inputs: { x }, attrs: { shape: [batchSize, inSize] }, backend }); const reduced = reduce(reshapedInput, x.dtype, 'max', backend); const reshapedOutput = reshape({ inputs: { x: reduced }, attrs: { shape: outShape }, backend }); backend.disposeIntermediateTensorInfo(reshapedInput); backend.disposeIntermediateTensorInfo(reduced); return reshapedOutput; } /** * @license * Copyright 2020 Google LLC. All Rights Reserved. 
* Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function max(args) { const { inputs, backend, attrs } = args; const { x } = inputs; const { reductionIndices, keepDims } = attrs; const xRank = x.shape.length; const origAxes = tf.util.parseAxisParam(reductionIndices, x.shape); let axes = origAxes; const permutedAxes = tf.backend_util.getAxesPermutation(axes, xRank); const maxInputIsTransposed = permutedAxes != null; const shouldExecuteOnCPU = backend.shouldExecuteOnCPU([x]); let maxInput = x; if (maxInputIsTransposed) { if (shouldExecuteOnCPU) { const xTexData = backend.texData.get(maxInput.dataId); const values = xTexData.values; const newShape = new Array(xRank); for (let i = 0; i < newShape.length; i++) { newShape[i] = x.shape[permutedAxes[i]]; } const maxInputValues = transposeImplCPU(values, x.shape, x.dtype, permutedAxes, newShape); maxInput = backend.makeTensorInfo(newShape, x.dtype); const maxInputData = backend.texData.get(maxInput.dataId); maxInputData.values = maxInputValues; } else { maxInput = transposeImpl(x, permutedAxes, backend); } axes = tf.backend_util.getInnerMostAxes(axes.length, xRank); } tf.backend_util.assertAxesAreInnerMostDims('max', axes, xRank); const [maxOutShape, reduceShape] = tf.backend_util.computeOutAndReduceShapes(maxInput.shape, axes); let outShape = maxOutShape; if (keepDims) { // rather than reshape at the end, set the target shape here. 
outShape = tf.backend_util.expandShapeToKeepDim(maxOutShape, origAxes); } let out; if (shouldExecuteOnCPU) { const xTexData = backend.texData.get(maxInput.dataId); const values = xTexData.values; const outValues = maxImplCPU(values, tf.util.sizeFromShape(reduceShape), outShape, x.dtype); out = backend.makeTensorInfo(outShape, x.dtype); const outData = backend.texData.get(out.dataId); outData.values = outValues; } else { out = maxImpl(maxInput, reduceShape, outShape, backend); } if (maxInputIsTransposed) { backend.disposeIntermediateTensorInfo(maxInput); } return out; } const maxConfig = { kernelName: tf.Max, backendName: 'webgl', kernelFunc: max }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ const MAXIMUM = CHECK_NAN_SNIPPET + ` return max(a, b); `; const MAXIMUM_PACKED = ` vec4 result = vec4(max(a, b)); bvec4 isNaNA = isnan(a); bvec4 isNaNB = isnan(b); bvec4 isNaN = bvec4(isNaNA.x || isNaNB.x, isNaNA.y || isNaNB.y, isNaNA.z || isNaNB.z, isNaNA.w || isNaNB.w); ` + CHECK_NAN_SNIPPET_PACKED + ` return result; `; const maximum = binaryKernelFunc({ opSnippet: MAXIMUM, packedOpSnippet: MAXIMUM_PACKED, cpuKernelImpl: maximumImplCPU }); const maximumConfig = { kernelName: tf.Maximum, backendName: 'webgl', kernelFunc: maximum }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. 
* Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function maxPool(args) { const { inputs, backend, attrs } = args; const { x } = inputs; assertNotComplex(x, 'maxPool'); const { filterSize, strides, pad, dimRoundingMode } = attrs; const dilations = 1; tf.util.assert(tf.backend_util.eitherStridesOrDilationsAreOne(strides, dilations), () => 'Error in maxPool: Either strides or dilations must be 1. ' + `Got strides ${strides} and dilations '${dilations}'`); const convInfo = tf.backend_util.computePool2DInfo(x.shape, filterSize, strides, dilations, pad, dimRoundingMode); if (convInfo.filterWidth === 1 && convInfo.filterHeight === 1 && tf.util.arraysEqual(convInfo.inShape, convInfo.outShape)) { return identity({ inputs: { x }, backend }); } const maxPoolProgram = new Pool2DProgram(convInfo, 'max', false); return backend.runWebGLProgram(maxPoolProgram, [x], x.dtype); } const maxPoolConfig = { kernelName: tf.MaxPool, backendName: 'webgl', kernelFunc: maxPool }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function maxPool3d(args) { const { inputs, backend, attrs } = args; const { x } = inputs; const { filterSize, strides, pad, dataFormat, dimRoundingMode } = attrs; const dilations = [1, 1, 1]; const convInfo = tf.backend_util.computePool3DInfo(x.shape, filterSize, strides, dilations, pad, dimRoundingMode, dataFormat); const maxPoolProgram = new Pool3DProgram(convInfo, 'max', false); return backend.runWebGLProgram(maxPoolProgram, [x], x.dtype); } const maxPool3DConfig = { kernelName: tf.MaxPool3D, backendName: 'webgl', kernelFunc: maxPool3d }; /** * @license * Copyright 2017 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ class MaxPool2DBackpropProgram { constructor(convInfo) { this.variableNames = ['dy', 'maxPos']; this.outputShape = convInfo.inShape; const strideHeight = convInfo.strideHeight; const strideWidth = convInfo.strideWidth; const dilationHeight = convInfo.dilationHeight; const effectiveFilterHeight = convInfo.effectiveFilterHeight; const effectiveFilterWidth = convInfo.effectiveFilterWidth; const padTop = effectiveFilterHeight - 1 - convInfo.padInfo.top; const padLeft = effectiveFilterWidth - 1 - convInfo.padInfo.left; const lastIndex = effectiveFilterHeight * effectiveFilterWidth - 1; this.userCode = ` const ivec2 pads = ivec2(${padTop}, ${padLeft}); void main() { ivec4 coords = getOutputCoords(); int b = coords[0]; int d = coords[3]; ivec2 dyRCCorner = coords.yz - pads; int dyRCorner = dyRCCorner.x; int dyCCorner = dyRCCorner.y; // Convolve dy(?, ?, d) with pos mask(:, :, d) to get dx(xR, xC, d). // ? = to be determined. : = across all values in that axis. float dotProd = 0.0; for (int wR = 0; wR < ${effectiveFilterHeight}; wR += ${dilationHeight}) { float dyR = float(dyRCorner + wR) / ${strideHeight}.0; if (dyR < 0.0 || dyR >= ${convInfo.outHeight}.0 || fract(dyR) > 0.0) { continue; } int idyR = int(dyR); for (int wC = 0; wC < ${effectiveFilterWidth}; wC++) { float dyC = float(dyCCorner + wC) / ${strideWidth}.0; if (dyC < 0.0 || dyC >= ${convInfo.outWidth}.0 || fract(dyC) > 0.0) { continue; } int idyC = int(dyC); float dyValue = getDy(b, idyR, idyC, d); int maxPosValue = ${lastIndex} - int(getMaxPos(b, idyR, idyC, d)); // Get the current value, check it against the value from the // position matrix. int curPosValue = wR * ${effectiveFilterWidth} + wC; float mask = float(maxPosValue == curPosValue ? 
1.0 : 0.0); dotProd += dyValue * mask; } } setOutput(dotProd); } `; } } class MaxPool3DBackpropProgram { constructor(convInfo) { this.variableNames = ['dy', 'maxPos']; this.outputShape = convInfo.inShape; const strideDepth = convInfo.strideDepth; const strideHeight = convInfo.strideHeight; const strideWidth = convInfo.strideWidth; const dilationDepth = convInfo.dilationDepth; const dilationHeight = convInfo.dilationHeight; const dilationWidth = convInfo.dilationWidth; const effectiveFilterDepth = convInfo.effectiveFilterDepth; const effectiveFilterHeight = convInfo.effectiveFilterHeight; const effectiveFilterWidth = convInfo.effectiveFilterWidth; const padFront = effectiveFilterDepth - 1 - convInfo.padInfo.front; const padTop = effectiveFilterHeight - 1 - convInfo.padInfo.top; const padLeft = effectiveFilterWidth - 1 - convInfo.padInfo.left; const lastIndex = effectiveFilterDepth * effectiveFilterHeight * effectiveFilterWidth - 1; this.userCode = ` const ivec3 pads = ivec3(${padFront}, ${padTop}, ${padLeft}); void main() { ivec5 coords = getOutputCoords(); int batch = coords.x; int ch = coords.u; ivec3 dyCorner = ivec3(coords.y, coords.z, coords.w) - pads; int dyDCorner = dyCorner.x; int dyRCorner = dyCorner.y; int dyCCorner = dyCorner.z; // Convolve dy(?, ?, ?, ch) with pos mask(:, :, :, d) to get // dx(xD, xR, xC, ch). // ? = to be determined. : = across all values in that axis. 
float dotProd = 0.0; for (int wD = 0; wD < ${effectiveFilterDepth}; wD += ${dilationDepth}) { float dyD = float(dyDCorner + wD) / ${strideDepth}.0; if (dyD < 0.0 || dyD >= ${convInfo.outDepth}.0 || fract(dyD) > 0.0) { continue; } int idyD = int(dyD); for (int wR = 0; wR < ${effectiveFilterHeight}; wR += ${dilationHeight}) { float dyR = float(dyRCorner + wR) / ${strideHeight}.0; if (dyR < 0.0 || dyR >= ${convInfo.outHeight}.0 || fract(dyR) > 0.0) { continue; } int idyR = int(dyR); for (int wC = 0; wC < ${effectiveFilterWidth}; wC += ${dilationWidth}) { float dyC = float(dyCCorner + wC) / ${strideWidth}.0; if (dyC < 0.0 || dyC >= ${convInfo.outWidth}.0 || fract(dyC) > 0.0) { continue; } int idyC = int(dyC); float dyValue = getDy(batch, idyD, idyR, idyC, ch); int maxPosValue = ${lastIndex} - int(getMaxPos(batch, idyD, idyR, idyC, ch)); // Get the current value, check it against the value from the // position matrix. int curPosValue = wD * ${effectiveFilterHeight} * ${effectiveFilterWidth} + wR * ${effectiveFilterWidth} + wC; float mask = float(maxPosValue == curPosValue ? 1.0 : 0.0); dotProd += dyValue * mask; } } } setOutput(dotProd); } `; } } /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ function maxPool3DGrad(args) { const { inputs, backend, attrs } = args; const { dy, input } = inputs; const x = input; const { filterSize, strides, pad, dimRoundingMode } = attrs; const dilations = [1, 1, 1]; const convInfo = tf.backend_util.computePool3DInfo(x.shape, filterSize, strides, dilations, pad, dimRoundingMode); const maxPool3dPositionsProgram = new Pool3DProgram(convInfo, 'max', true /* get positions */); const maxPool3dPositions = backend.runWebGLProgram(maxPool3dPositionsProgram, [x], x.dtype); const maxPoolBackpropProgram = new MaxPool3DBackpropProgram(convInfo); const result = backend.runWebGLProgram(maxPoolBackpropProgram, [dy, maxPool3dPositions], x.dtype); backend.disposeIntermediateTensorInfo(maxPool3dPositions); return result; } const maxPool3DGradConfig = { kernelName: tf.MaxPool3DGrad, backendName: 'webgl', kernelFunc: maxPool3DGrad }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ function maxPoolGrad(args) { const { inputs, backend, attrs } = args; const { dy, input, output } = inputs; const x = input; assertNotComplex([input, output], 'maxPoolGrad'); const { filterSize, strides, pad, dimRoundingMode } = attrs; const convInfo = tf.backend_util.computePool2DInfo(x.shape, filterSize, strides, 1 /* dilations */, pad, dimRoundingMode); const getPositions = true; const maxPoolPositionsProgram = new Pool2DProgram(convInfo, 'max', getPositions); const maxPoolPositions = backend.runWebGLProgram(maxPoolPositionsProgram, [x], x.dtype); const maxPoolBackPropProgram = new MaxPool2DBackpropProgram(convInfo); const result = backend.runWebGLProgram(maxPoolBackPropProgram, [dy, maxPoolPositions], x.dtype); backend.disposeIntermediateTensorInfo(maxPoolPositions); return result; } const maxPoolGradConfig = { kernelName: tf.MaxPoolGrad, backendName: 'webgl', kernelFunc: maxPoolGrad }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ function maxPoolWithArgmaxImpl(x, includeBatchInIndex, convInfo, backend) { let program = new Pool2DProgram(convInfo, 'max', false); const poolOutput = backend.runWebGLProgram(program, [x], 'float32'); program = new Pool2DProgram(convInfo, 'max', true, true, includeBatchInIndex); const indexOutput = backend.runWebGLProgram(program, [x], 'float32'); return [poolOutput, indexOutput]; } /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ const maxPoolWithArgmaxConfig = { kernelName: tf.MaxPoolWithArgmax, backendName: 'webgl', kernelFunc: ({ inputs, attrs, backend }) => { const { x } = inputs; const { filterSize, strides, pad, includeBatchInIndex } = attrs; const webglBackend = backend; tf.util.assert(x.shape.length === 4, () => `Error in maxPool: input must be rank 4 but got rank ${x.shape.length}.`); const dilations = [1, 1]; tf.util.assert(tf.backend_util.eitherStridesOrDilationsAreOne(strides, dilations), () => 'Error in maxPool: Either strides or dilations must be 1. 
' + `Got strides ${strides} and dilations '${dilations}'`); const convInfo = tf.backend_util.computePool2DInfo(x.shape, filterSize, strides, dilations, pad); const [result, indexes] = maxPoolWithArgmaxImpl(x, includeBatchInIndex, convInfo, webglBackend); return [result, indexes]; } }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function meanImpl(x, reduceShape, outShape, backend) { const inSize = tf.util.sizeFromShape(reduceShape); const xSize = tf.util.sizeFromShape(x.shape); const batchSize = xSize / inSize; const reshapedInput = reshape({ inputs: { x }, attrs: { shape: [batchSize, inSize] }, backend }); const reduced = reduce(reshapedInput, 'float32', 'mean', backend); const reshapedOutput = reshape({ inputs: { x: reduced }, attrs: { shape: outShape }, backend }); backend.disposeIntermediateTensorInfo(reshapedInput); backend.disposeIntermediateTensorInfo(reduced); return reshapedOutput; } /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ const meanConfig = { kernelName: tf.Mean, backendName: 'webgl', kernelFunc: ({ inputs, attrs, backend }) => { const { x } = inputs; const { keepDims, axis } = attrs; const webglBackend = backend; const xRank = x.shape.length; const origAxes = tf.util.parseAxisParam(axis, x.shape); let axes = origAxes; const permutedAxes = tf.backend_util.getAxesPermutation(axes, xRank); const meanInputIsTransposed = permutedAxes != null; const shouldExecuteOnCPU = webglBackend.shouldExecuteOnCPU([x]); const intermediates = []; let meanInput = x; if (meanInputIsTransposed) { if (shouldExecuteOnCPU) { const xTexData = webglBackend.texData.get(meanInput.dataId); const values = xTexData.values; const newShape = new Array(xRank); for (let i = 0; i < newShape.length; i++) { newShape[i] = x.shape[permutedAxes[i]]; } const meanInputValues = transposeImplCPU(values, x.shape, x.dtype, permutedAxes, newShape); meanInput = webglBackend.makeTensorInfo(newShape, x.dtype); const meanInputData = webglBackend.texData.get(meanInput.dataId); meanInputData.values = meanInputValues; } else { meanInput = transposeImpl(x, permutedAxes, webglBackend); } intermediates.push(meanInput); axes = tf.backend_util.getInnerMostAxes(axes.length, xRank); } tf.backend_util.assertAxesAreInnerMostDims('sum', axes, xRank); const [meanOutShape, reduceShape] = tf.backend_util.computeOutAndReduceShapes(meanInput.shape, axes); let outShape = meanOutShape; if (keepDims) { // rather than reshape at the end, set the target shape here. 
outShape = tf.backend_util.expandShapeToKeepDim(meanOutShape, origAxes); } const out = meanImpl(meanInput, reduceShape, outShape, webglBackend); for (const i of intermediates) { webglBackend.disposeIntermediateTensorInfo(i); } return out; } }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function min(args) { const { inputs, backend, attrs } = args; const { x } = inputs; const { axis, keepDims } = attrs; const xRank = x.shape.length; const origAxes = tf.util.parseAxisParam(axis, x.shape); let axes = origAxes; const permutedAxes = tf.backend_util.getAxesPermutation(axes, xRank); let permutedX = x; if (permutedAxes != null) { permutedX = transpose({ inputs: { x }, backend, attrs: { perm: permutedAxes } }); axes = tf.backend_util.getInnerMostAxes(axes.length, x.shape.length); } tf.backend_util.assertAxesAreInnerMostDims('min', axes, xRank); const [outShape, reduceShape] = tf.backend_util.computeOutAndReduceShapes(permutedX.shape, axes); const inSize = tf.util.sizeFromShape(reduceShape); const a2D = reshape({ inputs: { x: permutedX }, backend, attrs: { shape: [-1, inSize] } }); const reduced = reduce(a2D, a2D.dtype, 'min', backend); let res; if (keepDims) { const newShape = tf.backend_util.expandShapeToKeepDim(outShape, origAxes); res = reshape({ inputs: { x: reduced }, backend, attrs: { shape: newShape } }); } else { res = reshape({ 
inputs: { x: reduced }, backend, attrs: { shape: outShape } }); } backend.disposeIntermediateTensorInfo(a2D); backend.disposeIntermediateTensorInfo(reduced); if (permutedAxes != null) { backend.disposeIntermediateTensorInfo(permutedX); } return res; } const minConfig = { kernelName: tf.Min, backendName: 'webgl', kernelFunc: min }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ const MINIMUM = CHECK_NAN_SNIPPET + ` return min(a, b); `; const MINIMUM_PACKED = ` vec4 result = vec4(min(a, b)); bvec4 isNaNA = isnan(a); bvec4 isNaNB = isnan(b); bvec4 isNaN = bvec4(isNaNA.x || isNaNB.x, isNaNA.y || isNaNB.y, isNaNA.z || isNaNB.z, isNaNA.w || isNaNB.w); ` + CHECK_NAN_SNIPPET_PACKED + ` return result; `; const minimum = binaryKernelFunc({ opSnippet: MINIMUM, packedOpSnippet: MINIMUM_PACKED, cpuKernelImpl: minimumImplCPU }); const minimumConfig = { kernelName: tf.Minimum, backendName: 'webgl', kernelFunc: minimum }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ class MirrorPadProgram { constructor(xShape, paddings, mode) { this.variableNames = ['x']; this.outputShape = paddings.map((p, i) => p[0] /* beforePad */ + xShape[i] + p[1] /* afterPad */); const rank = xShape.length; const dtype = getCoordsDataType(rank); const start = paddings.map(p => p[0]).join(','); const end = paddings.map((p, i) => p[0] + xShape[i]).join(','); const unpackedCoords = ['coords[0]', 'coords[1]', 'coords[2]', 'coords[3]'].slice(0, rank); const offset = mode === 'reflect' ? 0 : 1; if (rank === 1) { this.userCode = ` int start = ${start}; int end = ${end}; void main() { int outC = getOutputCoords(); if (outC < start) { outC = start * 2 - outC - ${offset}; } else if(outC >= end) { outC = (end - 1) * 2 - outC + ${offset}; } setOutput(getX(outC - start)); } `; return; } this.userCode = ` ${dtype} start = ${dtype}(${start}); ${dtype} end = ${dtype}(${end}); void main() { ${dtype} outC = getOutputCoords(); for (int i = 0; i < ${rank}; i++) { if (outC[i] < start[i]) { outC[i] = start[i] * 2 - outC[i] - ${offset}; } else if(outC[i] >= end[i]) { outC[i] = (end[i] - 1) * 2 - outC[i] + ${offset}; } } ${dtype} coords = outC - start; setOutput(getX(${unpackedCoords})); } `; } } /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ /** * Example shader code for * `mirrorPad(tf.tensor1d([1, 2, 3], 'int32'), [[2, 2]], 'reflect')` * ``` * const int start = int(2); * const int end = int(5); * * void main() { * int outputLoc = getOutputCoords(); * vec4 result = vec4(0.); * * int rc = outputLoc; * * int source = rc; * if (source < start) { * source = start * 2 - source - 0; * } else if (source >= end) { * source = (end - 1) * 2 - source + 0; * } * source -= start; * * result[0] = getChannel(getX(source), source); * rc += 1; * if(rc < 6) { * int source = rc; * if (source < start) { * source = start * 2 - source - 0; * } else if (source >= end) { * source = (end - 1) * 2 - source + 0; * } * source -= start; * * result[1] = getChannel(getX(source), source); * } * * setOutput(result); * } * ``` */ class MirrorPadPackedProgram { constructor(xShape, paddings, mode) { this.variableNames = ['x']; this.packedInputs = true; this.packedOutput = true; this.outputShape = paddings.map((p, i) => p[0] /* beforePad */ + xShape[i] + p[1] /* afterPad */); const rank = xShape.length; const dtype = getCoordsDataType(rank); const start = paddings.map(p => p[0]).join(','); const end = paddings.map((p, i) => p[0] + xShape[i]).join(','); const coords = getChannels('rc', rank); const source = getChannels('source', rank); const cLimit = `${coords[rank - 1]} < ${this.outputShape[rank - 1]}`; const innerDims = rank === 1 ? 'source' : `vec2(${source.slice(-2).join()})`; const offset = mode === 'reflect' ? 
0 : 1; let mainLoop = ''; if (rank === 1) { const padSetup = ` ${dtype} source = rc; if (source < start) { source = start * 2 - source - ${offset}; } else if (source >= end) { source = (end - 1) * 2 - source + ${offset}; } source -= start; `; mainLoop = ` ${dtype} rc = outputLoc; ${padSetup} result[0] = getChannel(getX(${source.join()}), ${innerDims}); ${coords[rank - 1]} += 1; if(${cLimit}) { ${padSetup} result[1] = getChannel(getX(${source.join()}), ${innerDims}); } `; } else { const padSetup = ` ${dtype} source = rc; ${dtype} lt = ${dtype}(lessThan(source, start)); ${dtype} gte = ${dtype}(greaterThanEqual(source, end)); ${dtype} orig = 1 - (lt + gte); source = orig * source + lt * (start * 2 - source - ${offset}) + gte * ((end - 1) * 2 - source + ${offset}); source -= start; `; mainLoop = ` ${dtype} rc = outputLoc; ${padSetup} result[0] = getChannel(getX(${source.join()}), ${innerDims}); ${coords[rank - 1]} += 1; if(${cLimit}) { ${padSetup} result[1] = getChannel(getX(${source.join()}), ${innerDims}); } rc = outputLoc; ${coords[rank - 2]} += 1; if(${coords[rank - 2]} < ${this.outputShape[rank - 2]}) { ${padSetup} result[2] = getChannel(getX(${source.join()}), ${innerDims}); ${coords[rank - 1]} += 1; if(${cLimit}) { ${padSetup} result[3] = getChannel(getX(${source.join()}), ${innerDims}); } } `; } this.userCode = ` const ${dtype} start = ${dtype}(${start}); const ${dtype} end = ${dtype}(${end}); void main() { ${dtype} outputLoc = getOutputCoords(); vec4 result = vec4(0.); ${mainLoop} setOutput(result); } `; } } /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ const mirrorPadKernelFunc = ({ inputs, backend, attrs }) => { const { x } = inputs; const { paddings, mode } = attrs; const program = tf.env().getBool('WEBGL_PACK_ARRAY_OPERATIONS') ? new MirrorPadPackedProgram(x.shape, paddings, mode) : new MirrorPadProgram(x.shape, paddings, mode); const output = backend.runWebGLProgram(program, [x], x.dtype); return output; }; const mirrorPadConfig = { kernelName: tf.MirrorPad, backendName: 'webgl', kernelFunc: mirrorPadKernelFunc, }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ const MOD = `if (b == 0.0) return NAN; return mod(a, b);`; const MOD_PACKED = ` vec4 result = mod(a, b); bvec4 isNaN = equal(b, vec4(0.0)); ` + CHECK_NAN_SNIPPET_PACKED + ` return result; `; const mod = binaryKernelFunc({ opSnippet: MOD, packedOpSnippet: MOD_PACKED, }); const modConfig = { kernelName: tf.Mod, backendName: 'webgl', kernelFunc: mod }; /** * @license * Copyright 2017 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ class MultinomialProgram { constructor(batchSize, numOutcomes, numSamples) { this.variableNames = ['probs']; this.customUniforms = [{ name: 'seed', type: 'float' }]; this.outputShape = [batchSize, numSamples]; this.userCode = ` void main() { ivec2 coords = getOutputCoords(); int batch = coords[0]; float r = random(seed); float cdf = 0.0; for (int i = 0; i < ${numOutcomes - 1}; i++) { cdf += getProbs(batch, i); if (r < cdf) { setOutput(float(i)); return; } } // If no other event happened, last event happened. setOutput(float(${numOutcomes - 1})); } `; } } /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ // Without the equality check div produces 0.9999 for a = b, which when // floored can cause errors. const DIV = ` if (a == b) { return 1.0; }; return a / b;`; // We do the same as in ./binaryop_gpu, with vec4 and ivec4. // On Linux, the vectorized implementation produces NaNs when a and b are 0. const DIV_PACKED = ` // vec4 one = vec4(equal(a, b)); // return one + (vec4(1.0) - one) * a / b; vec4 result = a / b; if(a.x == b.x) { result.x = 1.; } if(a.y == b.y) { result.y = 1.; } if(a.z == b.z) { result.z = 1.; } if(a.w == b.w) { result.w = 1.; } return result; `; const realDiv = binaryKernelFunc({ opSnippet: DIV, packedOpSnippet: DIV_PACKED, checkOutOfBounds: true }); const realDivConfig = { kernelName: tf.RealDiv, backendName: 'webgl', kernelFunc: realDiv, }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ const SUB = 'return a - b;'; const sub = binaryKernelFunc({ opSnippet: SUB, packedOpSnippet: SUB, supportsComplex: true, cpuKernelImpl: subImplCPU }); const subConfig = { kernelName: tf.Sub, backendName: 'webgl', kernelFunc: sub }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function softmax(args) { const { inputs, backend, attrs } = args; const { logits } = inputs; const { dim } = attrs; const axes = tf.util.parseAxisParam([dim], logits.shape); const maxLogit = max({ inputs: { x: logits }, backend, attrs: { reductionIndices: axes, keepDims: false } }); const expandedShape = tf.backend_util.expandShapeToKeepDim(maxLogit.shape, axes); const maxLogitsReshaped = reshape({ inputs: { x: maxLogit }, backend, attrs: { shape: expandedShape } }); const a = sub({ inputs: { a: logits, b: maxLogitsReshaped }, backend }); const b = exp({ inputs: { x: a }, backend }); const sumExp = sum({ inputs: { x: b }, backend, attrs: { axis: axes, keepDims: false } }); const sumExpReshaped = reshape({ inputs: { x: sumExp }, backend, attrs: { shape: expandedShape } }); const res = realDiv({ inputs: { a: b, b: sumExpReshaped }, backend }); backend.disposeIntermediateTensorInfo(maxLogit); backend.disposeIntermediateTensorInfo(maxLogitsReshaped); 
backend.disposeIntermediateTensorInfo(a); backend.disposeIntermediateTensorInfo(b); backend.disposeIntermediateTensorInfo(sumExp); backend.disposeIntermediateTensorInfo(sumExpReshaped); return res; } const softmaxConfig = { kernelName: tf.Softmax, backendName: 'webgl', kernelFunc: softmax }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function multinomial(args) { const { inputs, backend, attrs } = args; const { logits } = inputs; const { numSamples, seed, normalized } = attrs; const probs = normalized ? logits : softmax({ inputs: { logits }, backend, attrs: { dim: logits.shape.length - 1 } }); const batchSize = probs.shape[0]; const numOutcomes = probs.shape[1]; const program = new MultinomialProgram(batchSize, numOutcomes, numSamples); const customValues = [[seed]]; const res = backend.runWebGLProgram(program, [probs], 'int32', customValues); if (!normalized) { backend.disposeIntermediateTensorInfo(probs); } return res; } const multinomialConfig = { kernelName: tf.Multinomial, backendName: 'webgl', kernelFunc: multinomial }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ const NEG = CHECK_NAN_SNIPPET$1 + ` return -x; `; const NEG_PACKED = ` vec4 result = -x; bvec4 isNaN = isnan(x); result.r = isNaN.r ? x.r : result.r; result.g = isNaN.g ? x.g : result.g; result.b = isNaN.b ? x.b : result.b; result.a = isNaN.a ? x.a : result.a; return result; `; // This doesn't use unaryKernelFunc because negImplCPU is not of type // SimpleUnaryKernelImplCPU. function neg(args) { const { inputs, backend } = args; const { x } = inputs; if (backend.shouldExecuteOnCPU([x])) { const xData = backend.texData.get(x.dataId); const [outValues, newShape] = negImplCPU(xData.values, x.shape, x.dtype); return backend.makeTensorInfo(newShape, x.dtype, outValues); } let program; if (tf.env().getBool('WEBGL_PACK_UNARY_OPERATIONS')) { program = new UnaryOpPackedProgram(x.shape, NEG_PACKED); } else { program = new UnaryOpProgram(x.shape, NEG); } return backend.runWebGLProgram(program, [x], x.dtype); } const negConfig = { kernelName: tf.Neg, backendName: 'webgl', kernelFunc: neg }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
* See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ const nonMaxSuppressionV3Impl = tf.kernel_impls.nonMaxSuppressionV3Impl; function nonMaxSuppressionV3(args) { tf.backend_util.warn('tf.nonMaxSuppression() in webgl locks the UI thread. ' + 'Call tf.nonMaxSuppressionAsync() instead'); const { inputs, backend, attrs } = args; const { boxes, scores } = inputs; const { maxOutputSize, iouThreshold, scoreThreshold } = attrs; const boxesVals = backend.readSync(boxes.dataId); const scoresVals = backend.readSync(scores.dataId); const { selectedIndices } = nonMaxSuppressionV3Impl(boxesVals, scoresVals, maxOutputSize, iouThreshold, scoreThreshold); return backend.makeTensorInfo([selectedIndices.length], 'int32', new Int32Array(selectedIndices)); } const nonMaxSuppressionV3Config = { kernelName: tf.NonMaxSuppressionV3, backendName: 'webgl', kernelFunc: nonMaxSuppressionV3 }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ const nonMaxSuppressionV4Impl = tf.kernel_impls.nonMaxSuppressionV4Impl; function nonMaxSuppressionV4(args) { tf.backend_util.warn('tf.nonMaxSuppression() in webgl locks the UI thread. 
' + 'Call tf.nonMaxSuppressionAsync() instead'); const { inputs, backend, attrs } = args; const { boxes, scores } = inputs; const { maxOutputSize, iouThreshold, scoreThreshold, padToMaxOutputSize } = attrs; const boxesVals = backend.readSync(boxes.dataId); const scoresVals = backend.readSync(scores.dataId); const { selectedIndices, validOutputs } = nonMaxSuppressionV4Impl(boxesVals, scoresVals, maxOutputSize, iouThreshold, scoreThreshold, padToMaxOutputSize); return [ backend.makeTensorInfo([selectedIndices.length], 'int32', new Int32Array(selectedIndices)), backend.makeTensorInfo([], 'int32', new Int32Array([validOutputs])) ]; } const nonMaxSuppressionV4Config = { kernelName: tf.NonMaxSuppressionV4, backendName: 'webgl', kernelFunc: nonMaxSuppressionV4 }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ const nonMaxSuppressionV5Impl = tf.kernel_impls.nonMaxSuppressionV5Impl; function nonMaxSuppressionV5(args) { tf.backend_util.warn('tf.nonMaxSuppression() in webgl locks the UI thread. 
' + 'Call tf.nonMaxSuppressionAsync() instead'); const { inputs, backend, attrs } = args; const { boxes, scores } = inputs; const { maxOutputSize, iouThreshold, scoreThreshold, softNmsSigma } = attrs; const boxesVals = backend.readSync(boxes.dataId); const scoresVals = backend.readSync(scores.dataId); const maxOutputSizeVal = maxOutputSize; const iouThresholdVal = iouThreshold; const scoreThresholdVal = scoreThreshold; const softNmsSigmaVal = softNmsSigma; const { selectedIndices, selectedScores } = nonMaxSuppressionV5Impl(boxesVals, scoresVals, maxOutputSizeVal, iouThresholdVal, scoreThresholdVal, softNmsSigmaVal); return [ backend.makeTensorInfo([selectedIndices.length], 'int32', new Int32Array(selectedIndices)), backend.makeTensorInfo([selectedScores.length], 'float32', new Float32Array(selectedScores)) ]; } const nonMaxSuppressionV5Config = { kernelName: tf.NonMaxSuppressionV5, backendName: 'webgl', kernelFunc: nonMaxSuppressionV5 }; /** * @license * Copyright 2017 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ class OneHotProgram { constructor(numIndices, depth, onValue, offValue) { this.variableNames = ['indices']; this.outputShape = [numIndices, depth]; this.userCode = ` void main() { ivec2 coords = getOutputCoords(); int index = round(getIndices(coords.x)); setOutput(mix(float(${offValue}), float(${onValue}), float(index == coords.y))); } `; } } /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ const oneHot = (args) => { const { inputs, backend, attrs } = args; const { indices } = inputs; const { dtype, depth, onValue, offValue } = attrs; const indicesSize = tf.util.sizeFromShape(indices.shape); const program = new OneHotProgram(indicesSize, depth, onValue, offValue); const reshaped = reshape({ inputs: { x: indices }, backend, attrs: { shape: [indicesSize] } }); const result = backend.runWebGLProgram(program, [reshaped], dtype); backend.disposeIntermediateTensorInfo(reshaped); const outShape = [...indices.shape, depth]; const out = reshape({ inputs: { x: result }, backend, attrs: { shape: outShape } }); backend.disposeIntermediateTensorInfo(result); return out; }; const oneHotConfig = { kernelName: tf.OneHot, backendName: 'webgl', kernelFunc: oneHot }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. 
* Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function zerosLike(args) { const { inputs, backend } = args; const { x } = inputs; if (x.dtype === 'complex64') { const realPart = real({ inputs: { input: x }, backend }); const r = zerosLike({ inputs: { x: realPart }, backend }); const imagPart = imag({ inputs: { input: x }, backend }); const i = zerosLike({ inputs: { x: imagPart }, backend }); const result = complex({ inputs: { real: r, imag: i }, backend }); backend.disposeIntermediateTensorInfo(realPart); backend.disposeIntermediateTensorInfo(r); backend.disposeIntermediateTensorInfo(imagPart); backend.disposeIntermediateTensorInfo(i); return result; } else { return fill({ attrs: { shape: x.shape, dtype: x.dtype, value: x.dtype === 'string' ? '' : 0 }, backend }); } } const zerosLikeConfig = { kernelName: tf.ZerosLike, backendName: 'webgl', kernelFunc: zerosLike }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
* See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function onesLike(args) { const { inputs, backend } = args; const { x } = inputs; if (x.dtype === 'string') { throw new Error('onesLike is not supported under string dtype'); } else if (x.dtype === 'complex64') { const realPart = real({ inputs: { input: x }, backend }); const r = onesLike({ inputs: { x: realPart }, backend }); const imagPart = imag({ inputs: { input: x }, backend }); const i = zerosLike({ inputs: { x: imagPart }, backend }); const result = complex({ inputs: { real: r, imag: i }, backend }); backend.disposeIntermediateTensorInfo(realPart); backend.disposeIntermediateTensorInfo(r); backend.disposeIntermediateTensorInfo(imagPart); backend.disposeIntermediateTensorInfo(i); return result; } else { // TODO(cais, smilkov): Add WebGL shader for onesLike: // https://github.com/tensorflow/tfjs/issues/1293 return fill({ attrs: { shape: x.shape, dtype: x.dtype, value: 1 }, backend }); } } const onesLikeConfig = { kernelName: tf.OnesLike, backendName: 'webgl', kernelFunc: onesLike }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ function pack(args) { const { inputs, backend, attrs } = args; const { axis } = attrs; if (inputs.length === 1) { return expandDims({ inputs: { input: inputs[0] }, backend, attrs: { dim: axis } }); } const shape = inputs[0].shape; const dtype = inputs[0].dtype; inputs.forEach(t => { tf.util.assertShapesMatch(shape, t.shape, 'All tensors passed to stack must have matching shapes'); tf.util.assert(dtype === t.dtype, () => 'All tensors passed to stack must have matching dtypes'); }); const intermediateTensorInfos = []; const expandedTensors = inputs.map(t => { const expandedT = expandDims({ inputs: { input: t }, backend, attrs: { dim: axis } }); intermediateTensorInfos.push(expandedT); return expandedT; }); const result = concat({ inputs: expandedTensors, backend, attrs: { axis } }); intermediateTensorInfos.forEach(t => backend.disposeIntermediateTensorInfo(t)); return result; } const packConfig = { kernelName: tf.Pack, backendName: 'webgl', kernelFunc: pack }; /** * @license * Copyright 2017 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ class PadProgram { constructor(xShape, paddings, constantValue) { this.variableNames = ['x']; this.customUniforms = [{ name: 'value', type: 'float' }]; this.outputShape = paddings.map((p, i) => p[0] /* beforePad */ + xShape[i] + p[1] /* afterPad */); const rank = xShape.length; const type = getCoordsDataType(rank); const start = paddings.map(p => p[0]).join(','); const end = paddings.map((p, i) => p[0] + xShape[i]).join(','); const unpackedCoords = ['coords[0]', 'coords[1]', 'coords[2]', 'coords[3]'].slice(0, rank); if (rank === 1) { this.userCode = ` int start = ${start}; int end = ${end}; void main() { int outC = getOutputCoords(); if (outC < start || outC >= end) { setOutput(value); } else { setOutput(getX(outC - start)); } } `; return; } this.userCode = ` ${type} start = ${type}(${start}); ${type} end = ${type}(${end}); void main() { ${type} outC = getOutputCoords(); if (any(lessThan(outC, start)) || any(greaterThanEqual(outC, end))) { setOutput(value); } else { ${type} coords = outC - start; setOutput(getX(${unpackedCoords})); } } `; } } /** * @license * Copyright 2019 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ class PadPackedProgram { constructor(xShape, paddings, constantValue) { this.variableNames = ['x']; this.packedInputs = true; this.packedOutput = true; this.customUniforms = [{ name: 'value', type: 'float' }]; this.outputShape = paddings.map((p, i) => p[0] /* beforePad */ + xShape[i] + p[1] /* afterPad */); const rank = xShape.length; const dtype = getCoordsDataType(rank); const start = paddings.map(p => p[0]).join(','); const end = paddings.map((p, i) => p[0] + xShape[i]).join(','); const coords = getChannels('rc', rank); const source = getChannels('source', rank); const cLimit = `${coords[rank - 1]} < ${this.outputShape[rank - 1]}`; const innerDims = rank === 1 ? 'source' : `vec2(${source.slice(-2).join()})`; const componentSetup = [ `${dtype} rc = outputLoc;`, `${coords[rank - 1]} += 1; if(${cLimit}) { `, rank === 1 ? '' : `} rc = outputLoc; ${coords[rank - 2]} += 1; if(${coords[rank - 2]} < ${this.outputShape[rank - 2]}) {`, rank === 1 ? '' : ` ${coords[rank - 1]} += 1; if(${cLimit}) {` ]; const paddingArea = rank === 1 ? 'rc < start || rc >= end' : 'any(lessThan(rc, start)) || any(greaterThanEqual(rc, end))'; let mainLoop = ''; for (let i = 0, j = rank === 1 ? 2 : 4; i < j; i++) { mainLoop += ` ${componentSetup[i]} if (${paddingArea}) { result[${i}] = float(value); } else { ${dtype} source = rc - start; result[${i}] = getChannel(getX(${source.join()}), ${innerDims}); } `; } mainLoop += (rank === 1 ? `} ` : `}}`); this.userCode = ` const ${dtype} start = ${dtype}(${start}); const ${dtype} end = ${dtype}(${end}); void main() { ${dtype} outputLoc = getOutputCoords(); vec4 result = vec4(0.); ${mainLoop} setOutput(result); } `; } } /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ const padV2 = (args) => { const { inputs, backend, attrs } = args; const { x } = inputs; const { paddings, constantValue } = attrs; if (tf.util.sizeFromShape(x.shape) === 0) { // Short-circuit the computation, since x doesn't have value, only // the shape is used to compute output shape to pad. const outputShape = paddings.map((p, i) => p[0] /* beforePad */ + x.shape[i] + p[1] /* afterPad */); return fill({ backend, attrs: { shape: outputShape, value: constantValue, dtype: x.dtype } }); } const program = tf.env().getBool('WEBGL_PACK_ARRAY_OPERATIONS') ? new PadPackedProgram(x.shape, paddings, constantValue) : new PadProgram(x.shape, paddings, constantValue); const customValues = [[constantValue]]; return backend.runWebGLProgram(program, [x], x.dtype, customValues); }; const padV2Config = { kernelName: tf.PadV2, backendName: 'webgl', kernelFunc: padV2 }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ const POW = ` if(a < 0.0 && floor(b) < b){ return NAN; } if (b == 0.0) { return 1.0; } return (round(mod(b, 2.0)) != 1) ? pow(abs(a), b) : sign(a) * pow(abs(a), b); `; const POW_PACKED = ` // isModRound1 has 1 for components with round(mod(b, 2.0)) == 1, 0 otherwise. vec4 isModRound1 = vec4(equal(round(mod(b, 2.0)), ivec4(1))); vec4 multiplier = sign(a) * isModRound1 + (vec4(1.0) - isModRound1); vec4 result = multiplier * pow(abs(a), b); // Ensure that a^0 = 1, including 0^0 = 1 as this correspond to TF and JS bvec4 isExpZero = equal(b, vec4(0.0)); result.r = isExpZero.r ? 1.0 : result.r; result.g = isExpZero.g ? 1.0 : result.g; result.b = isExpZero.b ? 1.0 : result.b; result.a = isExpZero.a ? 1.0 : result.a; bvec4 isNaN1 = lessThan(a, vec4(0.0)); bvec4 isNaN2 = lessThan(floor(b), b); bvec4 isNaN = bvec4(isNaN1.x && isNaN2.x, isNaN1.y && isNaN2.y, isNaN1.z && isNaN2.z, isNaN1.w && isNaN2.w); ` + CHECK_NAN_SNIPPET_PACKED + ` return result; `; const pow = binaryKernelFunc({ opSnippet: POW, packedOpSnippet: POW_PACKED }); const powConfig = { kernelName: tf.Pow, backendName: 'webgl', kernelFunc: pow }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ function prod(args) { const { inputs, backend, attrs } = args; const { x } = inputs; const { axis, keepDims } = attrs; const xRank = x.shape.length; const toDispose = []; const origAxes = tf.util.parseAxisParam(axis, x.shape); let axes = origAxes; const permutedAxes = tf.backend_util.getAxesPermutation(axes, xRank); let permutedX = x; if (permutedAxes != null) { permutedX = transpose({ inputs: { x }, backend, attrs: { perm: permutedAxes } }); axes = tf.backend_util.getInnerMostAxes(axes.length, xRank); toDispose.push(permutedX); } tf.backend_util.assertAxesAreInnerMostDims('prod', axes, xRank); let res; if (backend.shouldExecuteOnCPU([permutedX])) { const xVals = backend.texData.get(permutedX.dataId).values; const { outVals, outShape, outDtype } = prodImplCPU(permutedX.shape, permutedX.dtype, xVals, axes); res = backend.makeTensorInfo(outShape, outDtype, outVals); } else { const [outShape, reduceShape] = tf.backend_util.computeOutAndReduceShapes(permutedX.shape, axes); const inSize = tf.util.sizeFromShape(reduceShape); const a2D = reshape({ inputs: { x: permutedX }, backend, attrs: { shape: [-1, inSize] } }); const outputDType = tf.sumOutType(x.dtype); const reduced = reduce(a2D, outputDType, 'prod', backend); res = reshape({ inputs: { x: reduced }, backend, attrs: { shape: outShape } }); toDispose.push(a2D); toDispose.push(reduced); } if (keepDims) { toDispose.push(res); const newShape = tf.backend_util.expandShapeToKeepDim(res.shape, origAxes); res = reshape({ inputs: { x: res }, backend, attrs: { shape: newShape } }); } toDispose.forEach(t => backend.disposeIntermediateTensorInfo(t)); return res; } const prodConfig = { kernelName: tf.Prod, backendName: 'webgl', kernelFunc: prod }; /** * @license * Copyright 2022 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function raggedGather(args) { const { inputs, backend, attrs } = args; const { paramsNestedSplits, paramsDenseValues, indices } = inputs; const { outputRaggedRank } = attrs; const $paramsNestedSplits = paramsNestedSplits.map(t => backend.readSync(t.dataId)); const $paramsNestedSplitsShapes = paramsNestedSplits.map(t => t.shape); const $paramsDenseValues = backend.readSync(paramsDenseValues.dataId); const $indices = backend.readSync(indices.dataId); const [outputNestedSplits, outputDenseValues, outputDenseValuesShape] = raggedGatherImplCPU($paramsNestedSplits, $paramsNestedSplitsShapes, $paramsDenseValues, paramsDenseValues.shape, paramsDenseValues.dtype, $indices, indices.shape, outputRaggedRank); const outputNestedSplitsTensors = outputNestedSplits.map((splits) => backend.makeTensorInfo([splits.length], 'int32', splits)); const outputDenseValuesTensor = backend.makeTensorInfo(outputDenseValuesShape, paramsDenseValues.dtype, outputDenseValues); return outputNestedSplitsTensors.concat([outputDenseValuesTensor]); } const raggedGatherConfig = { kernelName: tf.RaggedGather, backendName: 'webgl', kernelFunc: raggedGather, }; /** * @license * Copyright 2022 Google LLC. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function raggedRange(args) { const { inputs, backend } = args; const { starts, limits, deltas } = inputs; const $starts = backend.readSync(starts.dataId); const $limits = backend.readSync(limits.dataId); const $deltas = backend.readSync(deltas.dataId); const [rtNestedSplitsData, rtDenseValuesData] = raggedRangeImplCPU($starts, starts.shape, starts.dtype, $limits, limits.shape, $deltas, deltas.shape); const rtNestedSplits = backend.makeTensorInfo([rtNestedSplitsData.length], 'int32', rtNestedSplitsData); const rtDenseValues = backend.makeTensorInfo([rtDenseValuesData.length], starts.dtype, rtDenseValuesData); return [rtNestedSplits, rtDenseValues]; } const raggedRangeConfig = { kernelName: tf.RaggedRange, backendName: 'webgl', kernelFunc: raggedRange, }; /** * @license * Copyright 2022 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ function raggedTensorToTensor(args) { const { inputs, backend, attrs } = args; const { shape, values, defaultValue, rowPartitionTensors } = inputs; const { rowPartitionTypes } = attrs; const $shape = backend.readSync(shape.dataId); const $values = backend.readSync(values.dataId); const $defaultValue = backend.readSync(defaultValue.dataId); const $rowPartitionValues = rowPartitionTensors.map(t => backend.readSync(t.dataId)); const rowPartitionValuesShapes = rowPartitionTensors.map(t => t.shape); const [outputShape, output] = raggedTensorToTensorImplCPU($shape, shape.shape, $values, values.shape, values.dtype, $defaultValue, defaultValue.shape, $rowPartitionValues, rowPartitionValuesShapes, rowPartitionTypes); return backend.makeTensorInfo(outputShape, values.dtype, output); } const raggedTensorToTensorConfig = { kernelName: tf.RaggedTensorToTensor, backendName: 'webgl', kernelFunc: raggedTensorToTensor, }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ const range = (args) => { const { backend, attrs } = args; const { start, stop, step, dtype } = attrs; const values = rangeImplCPU(start, stop, step, dtype); return backend.makeTensorInfo([values.length], dtype, values); }; const rangeConfig = { kernelName: tf.Range, backendName: 'webgl', kernelFunc: range }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ const RECIPROCAL = `return 1.0 / x;`; const reciprocal = unaryKernelFunc({ opSnippet: RECIPROCAL }); const reciprocalConfig = { kernelName: tf.Reciprocal, backendName: 'webgl', kernelFunc: reciprocal, }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ const RELU = CHECK_NAN_SNIPPET$1 + ` return (x < 0.0) ? 0.0 : x; `; const RELU_PACKED = ` vec4 result = x * vec4(greaterThanEqual(x, vec4(0.0))); bvec4 isNaN = isnan(x); result.r = isNaN.r ? x.r : result.r; result.g = isNaN.g ? x.g : result.g; result.b = isNaN.b ? x.b : result.b; result.a = isNaN.a ? x.a : result.a; return result; `; const relu = unaryKernelFunc({ opSnippet: RELU, packedOpSnippet: RELU_PACKED }); const reluConfig = { kernelName: tf.Relu, backendName: 'webgl', kernelFunc: relu }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ const RELU6 = CHECK_NAN_SNIPPET$1 + ` return (x < 0.0) ? 0.0 : min(6.0, x); `; const RELU6_PACKED = ` vec4 result = min(x, vec4(6.)) * vec4(greaterThanEqual(x, vec4(0.0))); bvec4 isNaN = isnan(x); result.r = isNaN.r ? x.r : result.r; result.g = isNaN.g ? x.g : result.g; result.b = isNaN.b ? x.b : result.b; result.a = isNaN.a ? x.a : result.a; return result; `; const relu6 = unaryKernelFunc({ opSnippet: RELU6, packedOpSnippet: RELU6_PACKED }); const relu6Config = { kernelName: tf.Relu6, backendName: 'webgl', kernelFunc: relu6 }; /** * @license * Copyright 2017 Google LLC. All Rights Reserved. 
* Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ class ResizeBilinearProgram { constructor(inputShape, newHeight, newWidth, alignCorners, halfPixelCenters) { this.variableNames = ['A']; this.outputShape = []; const [batch, oldHeight, oldWidth, depth] = inputShape; this.outputShape = [batch, newHeight, newWidth, depth]; const effectiveInSize = [ (alignCorners && newHeight > 1) ? oldHeight - 1 : oldHeight, (alignCorners && newWidth > 1) ? oldWidth - 1 : oldWidth ]; const effectiveOutSize = [ (alignCorners && newHeight > 1) ? newHeight - 1 : newHeight, (alignCorners && newWidth > 1) ? newWidth - 1 : newWidth ]; let sourceFracIndexRC; if (halfPixelCenters) { sourceFracIndexRC = `(vec2(yRC) + vec2(0.5)) * effectiveInputOverOutputRatioRC` + ` - vec2(0.5)`; } else { sourceFracIndexRC = `vec2(yRC) * effectiveInputOverOutputRatioRC`; } this.userCode = ` const vec2 effectiveInputOverOutputRatioRC = vec2( ${effectiveInSize[0] / effectiveOutSize[0]}, ${effectiveInSize[1] / effectiveOutSize[1]}); const vec2 inputShapeRC = vec2(${oldHeight}.0, ${oldWidth}.0); void main() { ivec4 coords = getOutputCoords(); int b = coords[0]; int d = coords[3]; ivec2 yRC = coords.yz; // Fractional source index. vec2 sourceFracIndexRC = ${sourceFracIndexRC}; // Compute the four integer indices. 
ivec2 sourceFloorRC = ivec2(max(sourceFracIndexRC, vec2(0.0))); ivec2 sourceCeilRC = ivec2( min(inputShapeRC - 1.0, ceil(sourceFracIndexRC))); float topLeft = getA(b, sourceFloorRC.x, sourceFloorRC.y, d); float bottomLeft = getA(b, sourceCeilRC.x, sourceFloorRC.y, d); float topRight = getA(b, sourceFloorRC.x, sourceCeilRC.y, d); float bottomRight = getA(b, sourceCeilRC.x, sourceCeilRC.y, d); vec2 fracRC = sourceFracIndexRC - vec2(sourceFloorRC); float top = topLeft + (topRight - topLeft) * fracRC.y; float bottom = bottomLeft + (bottomRight - bottomLeft) * fracRC.y; float newValue = top + (bottom - top) * fracRC.x; setOutput(newValue); } `; } } /** * @license * Copyright 2019 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ class ResizeBilinearPackedProgram { constructor(inputShape, newHeight, newWidth, alignCorners, halfPixelCenters) { this.variableNames = ['A']; this.packedInputs = true; this.packedOutput = true; this.outputShape = []; const [batch, oldHeight, oldWidth, depth] = inputShape; this.outputShape = [batch, newHeight, newWidth, depth]; const effectiveInSize = [ (alignCorners && newHeight > 1) ? oldHeight - 1 : oldHeight, (alignCorners && newWidth > 1) ? oldWidth - 1 : oldWidth ]; const effectiveOutSize = [ (alignCorners && newHeight > 1) ? newHeight - 1 : newHeight, (alignCorners && newWidth > 1) ? 
newWidth - 1 : newWidth ]; let sourceFracIndexRC; if (halfPixelCenters) { sourceFracIndexRC = `(vec3(yRC) + vec3(0.5)) * ` + `effectiveInputOverOutputRatioRC - vec3(0.5)`; } else { sourceFracIndexRC = `vec3(yRC) * effectiveInputOverOutputRatioRC`; } this.userCode = ` const vec3 effectiveInputOverOutputRatioRC = vec3( ${effectiveInSize[0] / effectiveOutSize[0]}, ${effectiveInSize[1] / effectiveOutSize[1]}, ${effectiveInSize[1] / effectiveOutSize[1]}); const vec3 inputShapeRC = vec3(${oldHeight}.0, ${oldWidth}.0, ${oldWidth}.0); float getAValue(int b, int r, int c, int d) { return getChannel(getA(b, r, c, d), vec2(c, d)); } void main() { ivec4 coords = getOutputCoords(); int b = coords[0]; int d = coords[3]; // Calculate values for next column in yRC.z. ivec3 yRC = coords.yzz + ivec3(0, 0, 1); // Fractional source index. vec3 sourceFracIndexRC = ${sourceFracIndexRC}; // Compute the four integer indices. ivec3 sourceFloorRC = ivec3(max(sourceFracIndexRC, vec3(0.0))); ivec3 sourceCeilRC = ivec3( min(inputShapeRC - 1.0, ceil(sourceFracIndexRC))); // Should we calculate next column and row elements in 2x2 packed cell. bool hasNextCol = d < ${depth - 1}; bool hasNextRow = coords.z < ${newWidth - 1}; // In parallel, construct four corners for all four components in // packed 2x2 cell. vec4 topLeft = vec4( getAValue(b, sourceFloorRC.x, sourceFloorRC.y, d), hasNextCol ? getAValue(b, sourceFloorRC.x, sourceFloorRC.y, d + 1) : 0.0, hasNextRow ? getAValue(b, sourceFloorRC.x, sourceFloorRC.z, d) : 0.0, (hasNextRow && hasNextCol) ? getAValue(b, sourceFloorRC.x, sourceFloorRC.z, d + 1) : 0.0); vec4 bottomLeft = vec4( getAValue(b, sourceCeilRC.x, sourceFloorRC.y, d), hasNextCol ? getAValue(b, sourceCeilRC.x, sourceFloorRC.y, d + 1) : 0.0, hasNextRow ? getAValue(b, sourceCeilRC.x, sourceFloorRC.z, d) : 0.0, (hasNextRow && hasNextCol) ? getAValue(b, sourceCeilRC.x, sourceFloorRC.z, d + 1) : 0.0); vec4 topRight = vec4( getAValue(b, sourceFloorRC.x, sourceCeilRC.y, d), hasNextCol ? 
getAValue(b, sourceFloorRC.x, sourceCeilRC.y, d + 1) : 0.0, hasNextRow ? getAValue(b, sourceFloorRC.x, sourceCeilRC.z, d) : 0.0, (hasNextRow && hasNextCol) ? getAValue(b, sourceFloorRC.x, sourceCeilRC.z, d + 1) : 0.0); vec4 bottomRight = vec4( getAValue(b, sourceCeilRC.x, sourceCeilRC.y, d), hasNextCol ? getAValue(b, sourceCeilRC.x, sourceCeilRC.y, d + 1) : 0.0, hasNextRow ? getAValue(b, sourceCeilRC.x, sourceCeilRC.z, d) : 0.0, (hasNextRow && hasNextCol) ? getAValue(b, sourceCeilRC.x, sourceCeilRC.z, d + 1) : 0.0); vec3 fracRC = sourceFracIndexRC - vec3(sourceFloorRC); vec4 top = mix(topLeft, topRight, fracRC.yyzz); vec4 bottom = mix(bottomLeft, bottomRight, fracRC.yyzz); vec4 newValue = mix(top, bottom, fracRC.x); setOutput(newValue); } `; } } /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function resizeBilinear(args) { const { inputs, backend, attrs } = args; const { images } = inputs; const { alignCorners, halfPixelCenters, size } = attrs; const [newHeight, newWidth] = size; const program = tf.env().getBool('WEBGL_PACK_IMAGE_OPERATIONS') ? 
new ResizeBilinearPackedProgram(images.shape, newHeight, newWidth, alignCorners, halfPixelCenters) : new ResizeBilinearProgram(images.shape, newHeight, newWidth, alignCorners, halfPixelCenters); return backend.runWebGLProgram(program, [images], 'float32'); } const resizeBilinearConfig = { kernelName: tf.ResizeBilinear, backendName: 'webgl', kernelFunc: resizeBilinear }; /** * @license * Copyright 2018 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ class ResizeBilinearBackpropProgram { constructor(dyShape, inputShape, alignCorners) { this.variableNames = ['dy']; this.outputShape = []; this.outputShape = inputShape; const [, xHeight, xWidth,] = inputShape; const [, yHeight, yWidth] = dyShape; // In the backwards pass, we want to find the pixels that were generated for // each pixel in the input image the forward pass and add the corresponding // coefficient from dy to the gradient (with some interpolation). const effectiveXSize = [ (alignCorners && yHeight > 1) ? xHeight - 1 : xHeight, (alignCorners && yWidth > 1) ? xWidth - 1 : xWidth ]; const effectiveYSize = [ (alignCorners && yHeight > 1) ? yHeight - 1 : yHeight, (alignCorners && yWidth > 1) ? 
yWidth - 1 : yWidth ]; const heightScale = effectiveXSize[0] / effectiveYSize[0]; const widthScale = effectiveXSize[1] / effectiveYSize[1]; const invHeightScale = 1 / heightScale; const invWidthScale = 1 / widthScale; // This defines the size of the window of values around a particular // index in dy that we want to search for contributions to dx. const winHeight = (Math.ceil(invHeightScale) * 2) + 2; const winWidth = (Math.ceil(invWidthScale) * 2) + 2; this.userCode = ` void main() { ivec4 coords = getOutputCoords(); int b = coords[0]; int d = coords[3]; int r = coords[1]; int c = coords[2]; float accumulator = 0.0; const float heightScale = float(${heightScale}); const float widthScale = float(${widthScale}); const float invHeightScale = float(${invHeightScale}); const float invWidthScale = float(${invWidthScale}); const int winHeight = int(${winHeight}); const int winWidth = int(${winWidth}); // Compute bounds for where in dy we will look float startRLerp = floor(float(r) * invHeightScale); int startDyR = int(startRLerp - float(winHeight / 2)); float startCLerp = floor(float(c) * invWidthScale); int startDyC = int(startCLerp - float(winWidth / 2)); // Loop over dy for (int dyROffset = 0; dyROffset < winHeight; dyROffset++) { int dyR = dyROffset + startDyR; // Guard against the window exceeding the bounds of dy if (dyR < 0 || dyR >= ${yHeight}) { continue; } for (int dyCOffset = 0; dyCOffset < winWidth; dyCOffset++) { int dyC = dyCOffset + startDyC; // Guard against the window exceeding the bounds of dy if (dyC < 0 || dyC >= ${yWidth}) { continue; } float dxR = float(dyR) * heightScale; int topDxRIndex = int(floor(dxR)); int bottomDxRIndex = int(min(ceil(dxR), ${xHeight - 1}.0)); float dxRLerp = dxR - float(topDxRIndex); float inverseDxRLerp = 1.0 - dxRLerp; float dxC = float(dyC) * widthScale; int leftDxCIndex = int(floor(dxC)); int rightDxCIndex = int(min(ceil(dxC), ${xWidth - 1}.0)); float dxCLerp = dxC - float(leftDxCIndex); float inverseDxCLerp = 1.0 - 
dxCLerp; if (r == topDxRIndex && c == leftDxCIndex) { // topLeft accumulator += getDy(b, dyR, dyC, d) * inverseDxRLerp * inverseDxCLerp; } if (r == topDxRIndex && c == rightDxCIndex) { // topRight accumulator += getDy(b, dyR, dyC, d) * inverseDxRLerp * dxCLerp; } if (r == bottomDxRIndex && c == leftDxCIndex) { // bottomLeft accumulator += getDy(b, dyR, dyC, d) * dxRLerp * inverseDxCLerp; } if (r == bottomDxRIndex && c == rightDxCIndex) { // bottomRight accumulator += getDy(b, dyR, dyC, d) * dxRLerp * dxCLerp; } } } // End loop over dy setOutput(accumulator); } `; } } /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function resizeBilinearGrad(args) { const { inputs, backend, attrs } = args; const { images, dy } = inputs; const { alignCorners } = attrs; const program = new ResizeBilinearBackpropProgram(dy.shape, images.shape, alignCorners); return backend.runWebGLProgram(program, [dy], dy.dtype); } const resizeBilinearGradConfig = { kernelName: tf.ResizeBilinearGrad, backendName: 'webgl', kernelFunc: resizeBilinearGrad }; /** * @license * Copyright 2018 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ class ResizeNearestNeighborProgram { constructor(inputShape, newHeight, newWidth, alignCorners, halfPixelCenters) { this.variableNames = ['A']; this.outputShape = []; const [batch, oldHeight, oldWidth, depth] = inputShape; this.outputShape = [batch, newHeight, newWidth, depth]; const effectiveInSize = [ (alignCorners && newHeight > 1) ? oldHeight - 1 : oldHeight, (alignCorners && newWidth > 1) ? oldWidth - 1 : oldWidth ]; const effectiveOutSize = [ (alignCorners && newHeight > 1) ? newHeight - 1 : newHeight, (alignCorners && newWidth > 1) ? newWidth - 1 : newWidth ]; // When align corners is false, we rounds the value with floor. const roundBase = alignCorners ? '0.5' : '0.0'; let sourceFracIndexRC; if (halfPixelCenters) { sourceFracIndexRC = `max((vec2(yRC) + vec2(0.5)) * effectiveInputOverOutputRatioRC` + `, vec2(0.0))`; } else { sourceFracIndexRC = `vec2(yRC) * effectiveInputOverOutputRatioRC`; } this.userCode = ` const vec2 effectiveInputOverOutputRatioRC = vec2( ${effectiveInSize[0] / effectiveOutSize[0]}, ${effectiveInSize[1] / effectiveOutSize[1]}); const vec2 inputShapeRC = vec2(${oldHeight}.0, ${oldWidth}.0); void main() { ivec4 coords = getOutputCoords(); int b = coords[0]; int d = coords[3]; ivec2 yRC = coords.yz; // Fractional source index. vec2 sourceFracIndexRC = ${sourceFracIndexRC}; // Compute the coordinators of nearest neighbor point. 
ivec2 sourceNearestRC = ivec2( min(inputShapeRC - 1.0, floor(sourceFracIndexRC + ${roundBase}))); float newValue = getA(b, sourceNearestRC.x, sourceNearestRC.y, d); setOutput(newValue); } `; } } /** * @license * Copyright 2019 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ class ResizeNearestNeighborPackedProgram { constructor(inputShape, newHeight, newWidth, alignCorners, halfPixelCenters) { this.variableNames = ['A']; this.packedInputs = true; this.packedOutput = true; this.outputShape = []; const [batch, oldHeight, oldWidth, depth] = inputShape; this.outputShape = [batch, newHeight, newWidth, depth]; const effectiveInSize = [ (alignCorners && newHeight > 1) ? oldHeight - 1 : oldHeight, (alignCorners && newWidth > 1) ? oldWidth - 1 : oldWidth ]; const effectiveOutSize = [ (alignCorners && newHeight > 1) ? newHeight - 1 : newHeight, (alignCorners && newWidth > 1) ? newWidth - 1 : newWidth ]; // When align corners is false, we rounds the value with floor. const roundBase = alignCorners ? 
'0.5' : '0.0'; let sourceFracIndexRC; if (halfPixelCenters) { sourceFracIndexRC = `max((vec3(yRC) + vec3(0.5)) * ` + `effectiveInputOverOutputRatioRC, vec3(0.0))`; } else { sourceFracIndexRC = `vec3(yRC) * effectiveInputOverOutputRatioRC`; } this.userCode = ` const vec3 effectiveInputOverOutputRatioRC = vec3( ${effectiveInSize[0] / effectiveOutSize[0]}, ${effectiveInSize[1] / effectiveOutSize[1]}, ${effectiveInSize[1] / effectiveOutSize[1]}); const vec3 inputShapeRC = vec3(${oldHeight}.0, ${oldWidth}.0, ${oldWidth}.0); float getAValue(int b, int r, int c, int d) { return getChannel(getA(b, r, c, d), vec2(c, d)); } void main() { ivec4 coords = getOutputCoords(); int b = coords[0]; int d = coords[3]; // Calculate values for next column in yRC.z. ivec3 yRC = coords.yzz + ivec3(0, 0, 1); // Fractional source index. vec3 sourceFracIndexRC = ${sourceFracIndexRC}; // Compute the coordinators of nearest neighbor point. ivec3 sourceNearestRC = ivec3( min(inputShapeRC - 1.0, floor(sourceFracIndexRC + ${roundBase}))); // Should we calculate next column and row elements in 2x2 packed cell. bool hasNextCol = d < ${depth - 1}; bool hasNextRow = coords.z < ${newWidth - 1}; vec4 newValue = vec4( getAValue(b, sourceNearestRC.x, sourceNearestRC.y, d), hasNextCol ? getAValue(b, sourceNearestRC.x, sourceNearestRC.y, d + 1) : 0.0, hasNextRow ? getAValue(b, sourceNearestRC.x, sourceNearestRC.z, d) : 0.0, (hasNextRow && hasNextCol) ? getAValue(b, sourceNearestRC.x, sourceNearestRC.z, d + 1) : 0.0); setOutput(newValue); } `; } } /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function resizeNearestNeighbor(args) { const { inputs, backend, attrs } = args; const { images } = inputs; const { alignCorners, halfPixelCenters, size } = attrs; const [newHeight, newWidth] = size; const program = tf.env().getBool('WEBGL_PACK_IMAGE_OPERATIONS') ? new ResizeNearestNeighborPackedProgram(images.shape, newHeight, newWidth, alignCorners, halfPixelCenters) : new ResizeNearestNeighborProgram(images.shape, newHeight, newWidth, alignCorners, halfPixelCenters); return backend.runWebGLProgram(program, [images], images.dtype); } const resizeNearestNeighborConfig = { kernelName: tf.ResizeNearestNeighbor, backendName: 'webgl', kernelFunc: resizeNearestNeighbor }; /** * @license * Copyright 2018 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ class ResizeNearestNeigborBackpropProgram { constructor(dyShape, inputShape, alignCorners) { this.variableNames = ['dy']; this.outputShape = []; this.outputShape = inputShape; const [, xHeight, xWidth,] = inputShape; const [, yHeight, yWidth] = dyShape; // In the backwards pass, we want to find the pixels that were generated for // each pixel in the input image the forward pass and add the corresponding // coefficient from dy to the gradient (with some interpolation). const effectiveXSize = [ (alignCorners && yHeight > 1) ? xHeight - 1 : xHeight, (alignCorners && yWidth > 1) ? xWidth - 1 : xWidth ]; const effectiveYSize = [ (alignCorners && yHeight > 1) ? yHeight - 1 : yHeight, (alignCorners && yWidth > 1) ? yWidth - 1 : yWidth ]; const heightScale = effectiveXSize[0] / effectiveYSize[0]; const widthScale = effectiveXSize[1] / effectiveYSize[1]; const invHeightScale = 1 / heightScale; const invWidthScale = 1 / widthScale; // This defines the size of the window of values around a particular // index in dy that we want to search for contributions to dx. 
const winHeight = (Math.ceil(invHeightScale) * 2) + 2; const winWidth = (Math.ceil(invWidthScale) * 2) + 2; this.userCode = ` void main() { ivec4 coords = getOutputCoords(); int b = coords[0]; int d = coords[3]; int r = coords[1]; int c = coords[2]; float accumulator = 0.0; const float heightScale = float(${heightScale}); const float widthScale = float(${widthScale}); const float invHeightScale = float(${invHeightScale}); const float invWidthScale = float(${invWidthScale}); const int winHeight = int(${winHeight}); const int winWidth = int(${winWidth}); // Compute bounds for where in dy we will look float startRLerp = floor(float(r) * invHeightScale); int startDyR = int(floor(startRLerp - float(winHeight / 2))); float startCLerp = floor(float(c) * invWidthScale); int startDyC = int(floor(startCLerp - float(winWidth / 2))); // Loop over dy for (int dyROffset = 0; dyROffset < winHeight; dyROffset++) { int dyR = dyROffset + startDyR; // Guard against the window exceeding the bounds of dy if (dyR < 0 || dyR >= ${yHeight}) { continue; } for (int dyCOffset = 0; dyCOffset < winWidth; dyCOffset++) { int dyC = dyCOffset + startDyC; // Guard against the window exceeding the bounds of dy if (dyC < 0 || dyC >= ${yWidth}) { continue; } float sourceFracRow = float(${effectiveXSize[0]}) * (float(dyR) / float(${effectiveYSize[0]})); float sourceFracCol = float(${effectiveXSize[1]}) * (float(dyC) / float(${effectiveYSize[1]})); int sourceNearestRow = int(min( float(int(${xHeight}) - 1), ${alignCorners} ? float(round(sourceFracRow)) : float(floor(sourceFracRow)))); int sourceNearestCol = int(min( float(int(${xWidth}) - 1), ${alignCorners} ? float(round(sourceFracCol)) : float(floor(sourceFracCol)))); if (r == sourceNearestRow && c == sourceNearestCol) { accumulator += getDy(b, dyR, dyC, d); } } } // End loop over dy setOutput(accumulator); } `; } } /** * @license * Copyright 2020 Google LLC. All Rights Reserved. 
* Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function resizeNearestNeighborGrad(args) { const { inputs, backend, attrs } = args; const { images, dy } = inputs; const { alignCorners } = attrs; const program = new ResizeNearestNeigborBackpropProgram(dy.shape, images.shape, alignCorners); return backend.runWebGLProgram(program, [dy], dy.dtype); } const resizeNearestNeighborGradConfig = { kernelName: tf.ResizeNearestNeighborGrad, backendName: 'webgl', kernelFunc: resizeNearestNeighborGrad }; /** * @license * Copyright 2017 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ class ReverseProgram { constructor(xShape, axis) { this.variableNames = ['x']; const rank = xShape.length; if (rank > 4) { throw new Error(`WebGL backend: Reverse of rank-${rank} tensor is not yet supported`); } this.outputShape = xShape; if (rank === 1) { this.userCode = ` void main() { int coord = getOutputCoords(); setOutput(getX(${xShape[0]} - coord - 1)); } `; return; } const getInCoord = (i) => { if (axis.indexOf(i) !== -1 && xShape[i] !== 1) { return `${xShape[i]} - coords[${i}] - 1`; } return `coords[${i}]`; }; const inCoords = xShape.map((_, i) => getInCoord(i)).join(','); const type = getCoordsDataType(rank); this.userCode = ` void main() { ${type} coords = getOutputCoords(); setOutput(getX(${inCoords})); } `; } } /** * @license * Copyright 2019 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ class ReversePackedProgram { constructor(xShape, axis) { this.variableNames = ['x']; this.packedInputs = true; this.packedOutput = true; const rank = xShape.length; if (rank > 4) { throw new Error(`WebGL backend: Reverse of rank-${rank} tensor is not yet supported`); } this.outputShape = xShape; const channels = getChannels('rc', rank); const nextColumn = `${channels[rank - 1]} + 1 < ${this.outputShape[rank - 1]}`; const nextRow = `${channels[rank - 2]} + 1 < ${this.outputShape[rank - 2]}`; const type = getCoordsDataType(rank); if (rank === 1) { this.userCode = ` void main(){ int rc = getOutputCoords(); vec4 result = vec4(0.); result.r = getChannel(getX(${xShape[0]} - rc - 1), ${xShape[0]} - rc - 1); if(${nextColumn}){ result.g = getChannel(getX(${xShape[0]} - (rc + 1) - 1), ${xShape[0]} - (rc + 1) - 1); } setOutput(result); } `; } else { this.userCode = ` void main() { ${type} rc = getOutputCoords(); vec4 result = vec4(0.); result.r = ${getR(channels.slice())}; if(${nextColumn}){ result.g = ${getG(channels.slice())}; } if(${nextRow}) { result.b = ${getB(channels.slice())}; if(${nextColumn}) { result.a = ${getA(channels.slice())}; } } setOutput(result); } `; } function getR(channels) { return getChannel(channels); } function getG(channels) { channels[rank - 1] = '(' + channels[rank - 1] + ` + 1)`; return getChannel(channels); } function getB(channels) { channels[rank - 2] = '(' + channels[rank - 2] + ` + 1)`; return getChannel(channels); } function getA(channels) { channels[rank - 1] = '(' + channels[rank - 1] + ` + 1)`; channels[rank - 2] = '(' + channels[rank - 2] + ` + 1)`; return getChannel(channels); } function getChannel(channels) { const inCoordsArray = xShape.map((_, i) => getInCoord(i, channels)); const inCoords = inCoordsArray.join(','); const innerDims = inCoordsArray.slice(-2).join(','); return `getChannel(getX(${inCoords}), vec2(${innerDims}))`; } function getInCoord(i, 
channels1) { if (axis.indexOf(i) !== -1 && xShape[i] !== 1) { return `${xShape[i]} - ${channels1[i]} - 1`; } else { return `${channels1[i]}`; } } } } /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function reverse(args) { const { inputs, backend, attrs } = args; const { x } = inputs; const { dims } = attrs; const xRank = x.shape.length; const $dims = tf.util.parseAxisParam(dims, x.shape); if (xRank === 0) { return identity({ inputs: { x }, backend }); } const program = tf.env().getBool('WEBGL_PACK_ARRAY_OPERATIONS') ? new ReversePackedProgram(x.shape, $dims) : new ReverseProgram(x.shape, $dims); return backend.runWebGLProgram(program, [x], x.dtype); } const reverseConfig = { kernelName: tf.Reverse, backendName: 'webgl', kernelFunc: reverse }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ class RotateProgram { constructor(imageShape, fillValue) { this.variableNames = ['Image']; this.outputShape = []; this.customUniforms = [{ name: 'params', type: 'vec4' }]; const imageHeight = imageShape[1]; const imageWidth = imageShape[2]; this.outputShape = imageShape; let fillSnippet = ''; if (typeof fillValue === 'number') { fillSnippet = `float outputValue = ${fillValue.toFixed(2)};`; } else { fillSnippet = ` vec3 fill = vec3(${fillValue.join(',')}); float outputValue = fill[coords[3]];`; } this.userCode = ` void main() { ivec4 coords = getOutputCoords(); int x = coords[2]; int y = coords[1]; float coordXFloat = (float(x) - params[0]) * params[3] - (float(y) - params[1]) * params[2]; float coordYFloat = (float(x) - params[0]) * params[2] + (float(y) - params[1]) * params[3]; int coordX = int(round(coordXFloat + params[0])); int coordY = int(round(coordYFloat + params[1])); ${fillSnippet} if(coordX >= 0 && coordX < ${imageWidth} && coordY >= 0 && coordY < ${imageHeight}) { outputValue = getImage(coords[0], coordY, coordX, coords[3]); } setOutput(outputValue); } `; } } /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ const rotateWithOffsetConfig = { kernelName: tf.RotateWithOffset, backendName: 'webgl', kernelFunc: ({ inputs, attrs, backend }) => { const { image } = inputs; const { radians, fillValue, center } = attrs; const webglBackend = backend; const program = new RotateProgram(image.shape, fillValue); const [centerX, centerY] = tf.backend_util.getImageCenter(center, image.shape[1], image.shape[2]); const customValues = [[centerX, centerY, Math.sin(radians), Math.cos(radians)]]; const output = webglBackend.runWebGLProgram(program, [image], image.dtype, customValues); return output; } }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ const ROUND = ` // OpenGL ES does not support round function. // The algorithm is based on banker's rounding. float base = floor(x); if ((x - base) < 0.5) { return floor(x); } else if ((x - base) > 0.5) { return ceil(x); } else { if (mod(base, 2.0) == 0.0) { return base; } else { return base + 1.0; } } `; const round = unaryKernelFunc({ opSnippet: ROUND }); const roundConfig = { kernelName: tf.Round, backendName: 'webgl', kernelFunc: round, }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. 
* Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ const RSQRT = `return inversesqrt(x);`; const rsqrt = unaryKernelFunc({ opSnippet: RSQRT, cpuKernelImpl: rsqrtImplCPU }); const rsqrtConfig = { kernelName: tf.Rsqrt, backendName: 'webgl', kernelFunc: rsqrt }; /** * @license * Copyright 2018 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ class ScatterProgram { constructor(updateSize, sliceDim, indicesRank, updatesRank, strides, shape, summingDupeIndex = true, defaultIsTensor = false) { this.variableNames = ['updates', 'indices', 'defaultValue']; this.outputShape = shape; const stridesType = getCoordsDataType(strides.length); const dtype = getCoordsDataType(shape.length); let indicesString = ''; if (indicesRank === 1) { indicesString = 'i'; } else if (indicesRank === 2) { indicesString = 'i, j'; } const indicesSnippet = `getIndices(${indicesString})`; let updatesString = ''; if (updatesRank === 1) { updatesString = 'i'; } else if (updatesRank === 2) { updatesString = 'i, coords[1]'; } const updatesSnippet = `getUpdates(${updatesString})`; let defaultValuesString = ''; if (defaultIsTensor) { defaultValuesString = 'coords[0], coords[1]'; } const defaultValueSnippet = `getDefaultValue(${defaultValuesString})`; const strideString = sliceDim > 1 ? 'strides[j]' : 'strides'; this.userCode = ` ${stridesType} strides = ${stridesType}(${strides}); void main() { ${dtype} coords = getOutputCoords(); float sum = 0.0; bool found = false; for (int i = 0; i < ${updateSize}; i++) { int flattenedIndex = 0; for (int j = 0; j < ${sliceDim}; j++) { int index = round(${indicesSnippet}); flattenedIndex += index * ${strideString}; } if (flattenedIndex == coords[0]) { sum += ${updatesSnippet}; found = true; } } setOutput(mix(${defaultValueSnippet}, sum, float(found))); } `; } } /** * @license * Copyright 2023 Google LLC. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ class ScatterPackedProgram { constructor(updateSize, sliceDim, indicesRank, updatesRank, strides, shape, summingDupeIndex = true, defaultIsTensor = false) { this.variableNames = ['updates', 'indices', 'defaultValue']; this.packedInputs = true; this.packedOutput = true; this.outputShape = shape; const stridesType = getCoordsDataType(strides.length); const dtype = getCoordsDataType(shape.length); let indicesString = ''; if (indicesRank === 1) { indicesString = 'i'; } else if (indicesRank === 2) { indicesString = 'i, j'; } const indicesSnippet = `getIndices(${indicesString})`; let updatesString = ''; if (updatesRank === 1) { updatesString = 'i'; } else if (updatesRank === 2) { updatesString = 'i, coords[1]'; } const updatesSnippet = `getUpdates(${updatesString})`; let defaultValuesString = ''; if (defaultIsTensor) { defaultValuesString = 'coords[0], coords[1]'; } const defaultValueSnippet = `getDefaultValue(${defaultValuesString})`; const strideString = sliceDim > 1 ? 'strides[j]' : 'strides'; const strideString2 = sliceDim > 1 ? 
'strides[j + 1]' : 'strides'; this.userCode = ` ${stridesType} strides = ${stridesType}(${strides}); void main() { ${dtype} coords = getOutputCoords(); vec4 sum = vec4(0.); vec4 found = vec4(0.); for (int i = 0; i < ${updateSize}; i+=2) { ivec2 flattenedIndex = ivec2(0); for (int j = 0; j < ${sliceDim}; j+=2) { ivec4 index = round(${indicesSnippet}); flattenedIndex += index.xz * ${strideString}; if (j + 1 < ${sliceDim}) { flattenedIndex += index.yw * ${strideString2}; } } if (flattenedIndex[0] == coords[0] || flattenedIndex[1] == coords[0] || flattenedIndex[0] == coords[0] + 1 || flattenedIndex[1] == coords[0] + 1) { vec4 updVals = ${updatesSnippet}; if (flattenedIndex[0] == coords[0]) { sum.xy += updVals.xy; found.xy = vec2(1.); } else if (flattenedIndex[0] == coords[0] + 1) { sum.zw += updVals.xy; found.zw = vec2(1.); } if (flattenedIndex[1] == coords[0]) { sum.xy += updVals.zw; found.xy = vec2(1.); } else if (flattenedIndex[1] == coords[0] + 1) { sum.zw += updVals.zw; found.zw = vec2(1.); } } } setOutput(mix(${defaultValueSnippet}, sum, found)); } `; } } /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ function scatterNd(args) { const { inputs, backend, attrs } = args; const { indices, updates } = inputs; const { shape } = attrs; const { sliceRank, numUpdates, sliceSize, strides, outputSize } = tf.backend_util.calculateShapes(updates, indices, shape); const flattenShape = [outputSize / sliceSize, sliceSize]; if (outputSize === 0) { return backend.makeTensorInfo(shape, indices.dtype); } const flattenIndices = reshape({ inputs: { x: indices }, backend, attrs: { shape: [numUpdates, sliceRank] } }); const flattenX = reshape({ inputs: { x: updates }, backend, attrs: { shape: [numUpdates, sliceSize] } }); const defaultValue = backend.makeTensorInfo([], 'float32', new Float32Array([0])); // scalar(0) let program; if (tf.env().getBool('WEBGL_PACK')) { program = new ScatterPackedProgram(numUpdates, sliceRank, flattenIndices.shape.length, flattenX.shape.length, strides, flattenShape); } else { program = new ScatterProgram(numUpdates, sliceRank, flattenIndices.shape.length, flattenX.shape.length, strides, flattenShape); } const res = backend.runWebGLProgram(program, [flattenX, flattenIndices, defaultValue], flattenX.dtype); const reshaped = reshape({ inputs: { x: res }, backend, attrs: { shape } }); backend.disposeIntermediateTensorInfo(flattenIndices); backend.disposeIntermediateTensorInfo(flattenX); backend.disposeIntermediateTensorInfo(res); backend.disposeIntermediateTensorInfo(defaultValue); return reshaped; } const scatterNdConfig = { kernelName: tf.ScatterNd, backendName: 'webgl', kernelFunc: scatterNd }; /** * @license * Copyright 2022 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ class SearchSortedProgram { constructor(batchSize, numInputs, numValues, side) { this.variableNames = ['sortedSequence', 'values']; this.customUniforms = [{ name: 'numInputs', type: 'int' }]; this.outputShape = [batchSize, numValues]; const webGL2LoopHead = 'while (left < right) {'; // WebGL1 doesn't accept non constant loop conditions, so upper bound loop // iterations. const webGL1LoopHead = `for (int i = 0; i < ${Math.ceil(Math.log2(numInputs + 1))}; ++i) { if (left >= right) break;`; const loopHead = tf.env().getNumber('WEBGL_VERSION') === 2 ? webGL2LoopHead : webGL1LoopHead; // left corresponds to lower bound and right to upper bound. const boundComparator = side === 'left' ? '<' : '<='; this.userCode = ` int findBound(int batch, float value) { int left = 0; int right = numInputs; int mid; ${loopHead} mid = (left + right) / 2; if (getSortedSequence(batch, mid) ${boundComparator} value) { left = mid + 1; } else { right = mid; } } return right; } void main() { ivec2 coords = getOutputCoords(); int batch = coords[0]; int valueIndex = coords[1]; float value = getValues(batch, valueIndex); setOutput(float(findBound(batch, value))); } `; } } /** * @license * Copyright 2022 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function searchSorted(args) { const { inputs, backend, attrs } = args; const { sortedSequence, values } = inputs; const { side } = attrs; const program = new SearchSortedProgram(sortedSequence.shape[0], sortedSequence.shape[1], values.shape[1], side); const customValues = [[sortedSequence.shape[1]]]; return backend.runWebGLProgram(program, [sortedSequence, values], 'int32', customValues); } const searchSortedConfig = { kernelName: tf.SearchSorted, backendName: 'webgl', kernelFunc: searchSorted, }; /** * @license * Copyright 2017 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ class SelectProgram { constructor(cRank, shape, rank) { this.variableNames = ['c', 'a', 'b']; this.outputShape = shape; let cCoords; let abCoords; if (rank > 4) { throw Error(`Where for rank ${rank} is not yet supported`); } if (rank === 1) { abCoords = `resRC`; cCoords = `resRC`; } else { const currentCoords = ['resRC.x', 'resRC.y', 'resRC.z', 'resRC.w']; const cCoordVars = []; const abCoordVars = []; for (let i = 0; i < shape.length; i++) { abCoordVars.push(`${currentCoords[i]}`); if (i < cRank) { cCoordVars.push(`${currentCoords[i]}`); } } cCoords = cCoordVars.join(); abCoords = abCoordVars.join(); } const dtype = getCoordsDataType(rank); this.userCode = ` void main() { ${dtype} resRC = getOutputCoords(); float cVal = getC(${cCoords}); if (cVal >= 1.0) { setOutput(getA(${abCoords})); } else { setOutput(getB(${abCoords})); } } `; } } /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ function select(args) { const { inputs, backend } = args; const { condition, t, e } = inputs; const program = new SelectProgram(condition.shape.length, t.shape, t.shape.length); return backend.runWebGLProgram(program, [condition, t, e], tf.upcastType(t.dtype, e.dtype)); } const selectConfig = { kernelName: tf.Select, backendName: 'webgl', kernelFunc: select }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ const SELU = ` // Stable and Attracting Fixed Point (0, 1) for Normalized Weights. // see: https://arxiv.org/abs/1706.02515 float scaleAlpha = ${tf.backend_util.SELU_SCALEALPHA}; float scale = ${tf.backend_util.SELU_SCALE}; return (x >= 0.0) ? scale * x : scaleAlpha * (exp(x) - 1.0); `; const selu = unaryKernelFunc({ opSnippet: SELU }); const seluConfig = { kernelName: tf.Selu, backendName: 'webgl', kernelFunc: selu, }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ const SIGMOID = CHECK_NAN_SNIPPET_UNARY + ` return 1.0 / (1.0 + exp(-1.0 * x)); `; const SIGMOID_PACKED = ` vec4 result = 1.0 / (1.0 + exp(-1.0 * x)); bvec4 isNaN = isnan(x); result.r = isNaN.r ? x.r : result.r; result.g = isNaN.g ? x.g : result.g; result.b = isNaN.b ? x.b : result.b; result.a = isNaN.a ? x.a : result.a; return result; `; const sigmoid = unaryKernelFunc({ opSnippet: SIGMOID, packedOpSnippet: SIGMOID_PACKED, cpuKernelImpl: sigmoidImplCPU }); const sigmoidConfig = { kernelName: tf.Sigmoid, backendName: 'webgl', kernelFunc: sigmoid, }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ // Sign does not propagate NANs. 
const SIGN = ` if (isnan(x)) { return 0.0; } return sign(x); `; const sign = unaryKernelFunc({ opSnippet: SIGN }); const signConfig = { kernelName: tf.Sign, backendName: 'webgl', kernelFunc: sign, }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ const SIN = CHECK_NAN_SNIPPET_UNARY + ` return sin(x); `; const SIN_PACKED = ` vec4 result = sin(x); bvec4 isNaN = isnan(x); ${CHECK_NAN_SNIPPET_PACKED} return result; `; const sin = unaryKernelFunc({ opSnippet: SIN, packedOpSnippet: SIN_PACKED }); const sinConfig = { kernelName: tf.Sin, backendName: 'webgl', kernelFunc: sin, }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ const SINH = ` float e2x = exp(x); return (e2x - 1.0 / e2x) / 2.0; `; const sinh = unaryKernelFunc({ opSnippet: SINH }); const sinhConfig = { kernelName: tf.Sinh, backendName: 'webgl', kernelFunc: sinh, }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ const SOFTPLUS = ` float epsilon = 1.1920928955078125e-7; float threshold = log(epsilon) + 2.0; bool too_large = x > -threshold; bool too_small = x < threshold; float result; float exp_x = exp(x); if (too_large){ result = x; } else if (too_small){ result = exp_x; } else{ result = log(exp_x + 1.0); } return result; `; const softplus = unaryKernelFunc({ opSnippet: SOFTPLUS }); const softplusConfig = { kernelName: tf.Softplus, backendName: 'webgl', kernelFunc: softplus, }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
* See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ const spaceToBatchND = (args) => { const { inputs, backend, attrs } = args; const { x } = inputs; const { blockShape, paddings } = attrs; tf.util.assert(x.shape.length <= 4, () => 'spaceToBatchND for rank > 4 with a WebGL backend not ' + 'implemented yet'); const prod = blockShape.reduce((a, b) => a * b); const completePaddings = [[0, 0]]; completePaddings.push(...paddings); for (let i = 1 + blockShape.length; i < x.shape.length; ++i) { completePaddings.push([0, 0]); } const toDispose = []; const paddedX = padV2({ inputs: { x }, backend, attrs: { paddings: completePaddings, constantValue: 0 } }); const reshapedPaddedShape = tf.backend_util.getReshaped(paddedX.shape, blockShape, prod, false); const permutedReshapedPaddedPermutation = tf.backend_util.getPermuted(reshapedPaddedShape.length, blockShape.length, false); const flattenShape = tf.backend_util.getReshapedPermuted(paddedX.shape, blockShape, prod, false); const reshapedPaddedX = reshape({ inputs: { x: paddedX }, backend, attrs: { shape: reshapedPaddedShape } }); const paddedXT = transpose({ inputs: { x: reshapedPaddedX }, backend, attrs: { perm: permutedReshapedPaddedPermutation } }); const result = reshape({ inputs: { x: paddedXT }, backend, attrs: { shape: flattenShape } }); toDispose.push(paddedX); toDispose.push(reshapedPaddedX); toDispose.push(paddedXT); toDispose.forEach(t => backend.disposeIntermediateTensorInfo(t)); return result; }; const spaceToBatchNDConfig = { kernelName: tf.SpaceToBatchND, backendName: 'webgl', kernelFunc: spaceToBatchND }; /** * @license * Copyright 2021 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function sparseFillEmptyRows(args) { const { inputs, backend } = args; const { indices, values, denseShape, defaultValue } = inputs; if (denseShape.shape.length !== 1) { throw new Error(`Dense shape must be a vector, saw: ${denseShape.shape}`); } if (indices.shape.length !== 2) { throw new Error(`Indices must be a matrix, saw: ${indices.shape}`); } if (values.shape.length !== 1) { throw new Error(`Values must be a vector, saw: ${values.shape}`); } if (defaultValue.shape.length !== 0) { throw new Error(`Default value must be a scalar, saw: ${defaultValue.shape}`); } const $indices = backend.readSync(indices.dataId); const $values = backend.readSync(values.dataId); const $denseShape = backend.readSync(denseShape.dataId); const $defaultValue = backend.readSync(defaultValue.dataId)[0]; const [outputIndices, outputIndicesShape, outputValues, emptyRowIndicator, reverseIndexMap] = sparseFillEmptyRowsImplCPU($indices, indices.shape, indices.dtype, $values, values.dtype, $denseShape, $defaultValue); return [ backend.makeTensorInfo(outputIndicesShape, indices.dtype, outputIndices), backend.makeTensorInfo([outputIndicesShape[0]], values.dtype, outputValues), backend.makeTensorInfo([emptyRowIndicator.length], 'bool', new Uint8Array(emptyRowIndicator.map((value) => Number(value)))), backend.makeTensorInfo([reverseIndexMap.length], indices.dtype, new Int32Array(reverseIndexMap)), ]; } const sparseFillEmptyRowsConfig = { kernelName: tf.SparseFillEmptyRows, backendName: 'webgl', kernelFunc: 
sparseFillEmptyRows, }; /** * @license * Copyright 2021 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function sparseReshape(args) { const { inputs, backend } = args; const { inputIndices, inputShape, newShape } = inputs; if (inputIndices.shape.length !== 2) { throw new Error(`Input indices should be a matrix but received shape ${inputIndices.shape}`); } if (inputShape.shape.length !== 1) { throw new Error(`Input shape should be a vector but received shape ${inputShape.shape}`); } if (newShape.shape.length !== 1) { throw new Error(`Target shape should be a vector but received shape ${newShape.shape}`); } const $inputShape = Array.from(backend.readSync(inputShape.dataId)); const $inputIndices = backend.readSync(inputIndices.dataId); const targetShape = Array.from(backend.readSync(newShape.dataId)); const [newIndices, indicesShape, outputShape] = sparseReshapeImplCPU($inputIndices, inputIndices.shape, inputIndices.dtype, $inputShape, targetShape); return [ backend.makeTensorInfo(indicesShape, inputIndices.dtype, newIndices), backend.makeTensorInfo([outputShape.length], newShape.dtype, new Int32Array(outputShape)), ]; } const sparseReshapeConfig = { kernelName: tf.SparseReshape, backendName: 'webgl', kernelFunc: sparseReshape, }; /** * @license * Copyright 2021 Google LLC. All Rights Reserved. 
* Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function sparseSegmentMean(args) { const { inputs, backend } = args; const { data, indices, segmentIds } = inputs; if (data.shape.length < 1) { throw new Error(`Data should be at least 1 dimensional but received scalar`); } if (indices.shape.length !== 1) { throw new Error(`Indices should be a vector but received shape ${indices.shape}`); } if (segmentIds.shape.length !== 1) { throw new Error(`Segment ids should be a vector but received shape ${segmentIds.shape}`); } const $data = backend.readSync(data.dataId); const $indices = backend.readSync(indices.dataId); const $segmentIds = backend.readSync(segmentIds.dataId); const [outputData, outputDataShape] = sparseSegmentReductionImplCPU($data, data.shape, data.dtype, $indices, $segmentIds, true); return backend.makeTensorInfo(outputDataShape, data.dtype, outputData); } const sparseSegmentMeanConfig = { kernelName: tf.SparseSegmentMean, backendName: 'webgl', kernelFunc: sparseSegmentMean, }; /** * @license * Copyright 2021 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function sparseSegmentSum(args) { const { inputs, backend } = args; const { data, indices, segmentIds } = inputs; if (data.shape.length < 1) { throw new Error(`Data should be at least 1 dimensional but received scalar`); } if (indices.shape.length !== 1) { throw new Error(`Indices should be a vector but received shape ${indices.shape}`); } if (segmentIds.shape.length !== 1) { throw new Error(`Segment ids should be a vector but received shape ${segmentIds.shape}`); } const $data = backend.readSync(data.dataId); const $indices = backend.readSync(indices.dataId); const $segmentIds = backend.readSync(segmentIds.dataId); const [outputData, outputDataShape] = sparseSegmentReductionImplCPU($data, data.shape, data.dtype, $indices, $segmentIds); return backend.makeTensorInfo(outputDataShape, data.dtype, outputData); } const sparseSegmentSumConfig = { kernelName: tf.SparseSegmentSum, backendName: 'webgl', kernelFunc: sparseSegmentSum, }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
* See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function sparseToDense(args) { const { inputs, backend, attrs } = args; const { sparseIndices, sparseValues, defaultValue } = inputs; const { outputShape } = attrs; const { sliceRank, numUpdates, sliceSize, strides, outputSize } = tf.backend_util.calculateShapes(sparseValues, sparseIndices, outputShape); const sumDupeIndices = false; if (sparseValues.dtype === 'string') { const indicesBuf = backend.bufferSync(sparseIndices); const updatesBuf = backend.bufferSync(sparseValues); const $defaultValue = tf.util.decodeString(backend.readSync(defaultValue.dataId)[0]); const outBuf = scatterImplCPU(indicesBuf, updatesBuf, outputShape, outputSize, sliceSize, numUpdates, sliceRank, strides, $defaultValue, sumDupeIndices); return backend.makeTensorInfo(outputShape, outBuf.dtype, outBuf.values); } const program = new ScatterProgram(numUpdates, sliceRank, sparseIndices.shape.length, sparseValues.shape.length, strides, [outputSize, 1], sumDupeIndices); const res = backend.runWebGLProgram(program, [sparseValues, sparseIndices, defaultValue], sparseValues.dtype); const reshaped = reshape({ inputs: { x: res }, backend, attrs: { shape: outputShape } }); backend.disposeIntermediateTensorInfo(res); return reshaped; } const sparseToDenseConfig = { kernelName: tf.SparseToDense, backendName: 'webgl', kernelFunc: sparseToDense }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function splitV(args) { const { inputs, backend, attrs } = args; const { x } = inputs; const { numOrSizeSplits, axis } = attrs; const $axis = tf.util.parseAxisParam(axis, x.shape)[0]; const splitSizes = tf.backend_util.prepareSplitSize(x, numOrSizeSplits, $axis); const xRank = x.shape.length; const begin = new Array(xRank).fill(0); const size = x.shape.slice(); return splitSizes.map(s => { const sliceSize = [...size]; sliceSize[$axis] = s; const sliceT = slice({ inputs: { x }, backend, attrs: { begin, size: sliceSize } }); begin[$axis] += s; return sliceT; }); } const splitVConfig = { kernelName: tf.SplitV, backendName: 'webgl', kernelFunc: splitV }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ const SQRT = `return sqrt(x);`; const sqrt = unaryKernelFunc({ opSnippet: SQRT, packedOpSnippet: SQRT, cpuKernelImpl: sqrtImplCPU }); const sqrtConfig = { kernelName: tf.Sqrt, backendName: 'webgl', kernelFunc: sqrt }; /** * @license * Copyright 2019 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ const SQUARE = `return x * x;`; const square = unaryKernelFunc({ opSnippet: SQUARE }); const squareConfig = { kernelName: tf.Square, backendName: 'webgl', kernelFunc: square, }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ const SQUARED_DIFFERENCE = 'return (a - b) * (a - b);'; const squaredDifference = binaryKernelFunc({ opSnippet: SQUARED_DIFFERENCE, packedOpSnippet: SQUARED_DIFFERENCE }); const squaredDifferenceConfig = { kernelName: tf.SquaredDifference, backendName: 'webgl', kernelFunc: squaredDifference, }; /** * @license * Copyright 2023 Google LLC. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function staticRegexReplace(args) { const { inputs, backend, attrs } = args; const { x } = inputs; if (x.dtype !== 'string') { throw new Error('Input must be of datatype string'); } const $x = backend.readSync(x.dataId); const stringInput = tf.backend_util.fromUint8ToStringArray($x); const output = staticRegexReplaceImplCPU(stringInput, 'string', attrs); return backend.makeTensorInfo(x.shape, 'string', output); } const staticRegexReplaceConfig = { kernelName: tf.StaticRegexReplace, backendName: 'webgl', kernelFunc: staticRegexReplace, }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function step({ inputs, attrs, backend }) { const { x } = inputs; const opSnippet = CHECK_NAN_SNIPPET$1 + ` return x > 0.0 ? 1.0 : float(${attrs.alpha}); `; const program = new UnaryOpProgram(x.shape, opSnippet); return backend.runWebGLProgram(program, [x], x.dtype); } const stepConfig = { kernelName: tf.Step, backendName: 'webgl', kernelFunc: step, }; /** * @license * Copyright 2017 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ class StridedSliceProgram { constructor(begin, strides, size) { this.variableNames = ['x']; this.outputShape = size; const rank = size.length; const inputDtype = getCoordsDataType(size.length); const dtype = getCoordsDataType(size.length); let newCoords = ''; if (rank === 1) { newCoords = 'coords * strides + begin'; } else { let outputAxis = 0; newCoords = size.map((_, i) => { outputAxis++; return size.length === 1 ? 
`coords * strides[${i}] + begin[${i}]` : `coords[${outputAxis - 1}] * strides[${i}] + begin[${i}]`; }) .join(','); } this.userCode = ` ${inputDtype} begin = ${inputDtype}(${begin}); ${inputDtype} strides = ${inputDtype}(${strides}); void main() { ${dtype} coords = getOutputCoords(); setOutput(getX(${newCoords})); } `; } } /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function stridedSlice(args) { const { inputs, backend, attrs } = args; const { x } = inputs; const { begin, end, strides, beginMask, endMask, ellipsisMask, newAxisMask, shrinkAxisMask } = attrs; const { finalShapeSparse, finalShape, isIdentity, sliceDim0, isSimpleSlice, begin: $begin, end: $end, strides: $strides } = tf.slice_util.sliceInfo(x.shape, begin, end, strides, beginMask, endMask, ellipsisMask, newAxisMask, shrinkAxisMask); let result; if (isIdentity) { // Optimization #1, slice is a no-op plus reshape result = reshape({ inputs: { x }, backend, attrs: { shape: finalShape } }); } else if (sliceDim0 || isSimpleSlice) { // Optimization #2, slice is memory contiguous (only occurs in dim 0) tf.util.assert(x.shape.length >= 1, () => `Input must have rank at least 1, got: ${x.shape.length}`); const size = tf.slice_util.computeOutShape($begin, $end, $strides); // To tolerate begin[0] > end[0] (a 0-output slice), we min(begin, end). 
const sliced = slice({ inputs: { x }, backend, attrs: { begin: $begin, size } }); result = reshape({ inputs: { x: sliced }, backend, attrs: { shape: finalShape } }); backend.disposeIntermediateTensorInfo(sliced); } else { const shouldExecuteOnCPU = backend.shouldExecuteOnCPU([x]); if (shouldExecuteOnCPU) { // tslint:disable-next-line: no-unnecessary-type-assertion const values = backend.readSync(x.dataId); // tslint:disable-next-line: no-unnecessary-type-assertion const xBuf = tf.buffer(x.shape, x.dtype, values); const resultValues = stridedSliceImplCPU(finalShapeSparse, xBuf, $strides, $begin); result = backend.makeTensorInfo(finalShape, x.dtype, resultValues.values); } else { const program = new StridedSliceProgram($begin, $strides, finalShapeSparse); result = backend.runWebGLProgram(program, [x], x.dtype); } } const resultReshaped = reshape({ inputs: { x: result }, backend, attrs: { shape: finalShape } }); backend.disposeIntermediateTensorInfo(result); return resultReshaped; } const stridedSliceConfig = { kernelName: tf.StridedSlice, backendName: 'webgl', kernelFunc: stridedSlice }; /** * @license * Copyright 2021 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ function stringNGrams(args) { const { inputs, backend, attrs } = args; const { separator, nGramWidths, leftPad, rightPad, padWidth, preserveShortSequences } = attrs; const { data, dataSplits } = inputs; const $data = backend.readSync(data.dataId); const $dataSplits = backend.readSync(dataSplits.dataId); const [nGrams, nGramsSplits] = stringNGramsImplCPU($data, $dataSplits, separator, nGramWidths, leftPad, rightPad, padWidth, preserveShortSequences); return [ backend.makeTensorInfo([nGrams.length], 'string', nGrams), backend.makeTensorInfo(dataSplits.shape, 'int32', nGramsSplits), ]; } const stringNGramsConfig = { kernelName: tf.StringNGrams, backendName: 'webgl', kernelFunc: stringNGrams, }; /** * @license * Copyright 2021 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ function stringSplit(args) { const { inputs, backend, attrs } = args; const { skipEmpty } = attrs; const { input, delimiter } = inputs; if (input.dtype !== 'string') { throw new Error('Input must be of datatype string'); } if (input.shape.length !== 1) { throw new Error(`Input must be a vector, got shape: ${input.shape}`); } if (delimiter.shape.length !== 0) { throw new Error(`Delimiter must be a scalar, got shape: ${delimiter.shape}`); } const $input = backend.readSync(input.dataId); const $delimiter = backend.readSync(delimiter.dataId)[0]; const [indices, values, shape] = stringSplitImplCPU($input, $delimiter, skipEmpty); const outputSize = values.length; return [ backend.makeTensorInfo([outputSize, 2], 'int32', indices), backend.makeTensorInfo([outputSize], 'string', values), backend.makeTensorInfo([2], 'int32', new Int32Array(shape)) ]; } const stringSplitConfig = { kernelName: tf.StringSplit, backendName: 'webgl', kernelFunc: stringSplit, }; /** * @license * Copyright 2021 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ function stringToHashBucketFast(args) { const { inputs, backend, attrs } = args; const { numBuckets } = attrs; const { input } = inputs; if (input.dtype !== 'string') { throw new Error('Input must be of datatype string'); } if (numBuckets <= 0) { throw new Error(`Number of buckets must be at least 1`); } const $input = backend.readSync(input.dataId); const output = stringToHashBucketFastImplCPU($input, numBuckets); return backend.makeTensorInfo(input.shape, 'int32', output); } const stringToHashBucketFastConfig = { kernelName: tf.StringToHashBucketFast, backendName: 'webgl', kernelFunc: stringToHashBucketFast, }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ const TAN = `return tan(x);`; const tan = unaryKernelFunc({ opSnippet: TAN }); const tanConfig = { kernelName: tf.Tan, backendName: 'webgl', kernelFunc: tan, }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ const TANH = ` float e2x = exp(-2.0 * abs(x)); return sign(x) * (1.0 - e2x) / (1.0 + e2x); `; const tanh = unaryKernelFunc({ opSnippet: TANH }); const tanhConfig = { kernelName: tf.Tanh, backendName: 'webgl', kernelFunc: tanh, }; /** * @license * Copyright 2022 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ function tensorScatterUpdate(args) { const { inputs, backend, attrs } = args; const { tensor, indices, updates } = inputs; const { sliceRank, numUpdates, sliceSize, strides, outputSize } = tf.backend_util.calculateShapes(updates, indices, tensor.shape); const flattenShape = [outputSize / sliceSize, sliceSize]; if (outputSize === 0) { return backend.makeTensorInfo(tensor.shape, indices.dtype); } const flattenIndices = reshape({ inputs: { x: indices }, backend, attrs: { shape: [numUpdates, sliceRank] } }); const flattenX = reshape({ inputs: { x: updates }, backend, attrs: { shape: [numUpdates, sliceSize] } }); const flattenTensor = reshape({ inputs: { x: tensor }, backend, attrs: { shape: flattenShape } }); const program = new ScatterProgram(numUpdates, sliceRank, flattenIndices.shape.length, flattenX.shape.length, strides, flattenShape, false, true); const res = backend.runWebGLProgram(program, [flattenX, flattenIndices, flattenTensor], flattenTensor.dtype); const reshaped = reshape({ inputs: { x: res }, backend, attrs: { shape: tensor.shape } }); backend.disposeIntermediateTensorInfo(flattenIndices); backend.disposeIntermediateTensorInfo(flattenX); backend.disposeIntermediateTensorInfo(flattenTensor); backend.disposeIntermediateTensorInfo(res); return reshaped; } const tensorScatterUpdateConfig = { kernelName: tf.TensorScatterUpdate, backendName: 'webgl', kernelFunc: tensorScatterUpdate }; /** * @license * Copyright 2017 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
* See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ class TileProgram { constructor(aShape, reps) { this.variableNames = ['A']; const outputShape = new Array(aShape.length); for (let i = 0; i < outputShape.length; i++) { outputShape[i] = aShape[i] * reps[i]; } this.outputShape = outputShape; this.rank = outputShape.length; const dtype = getCoordsDataType(this.rank); const sourceCoords = getSourceCoords(aShape); this.userCode = ` void main() { ${dtype} resRC = getOutputCoords(); setOutput(getA(${sourceCoords})); } `; } } function getSourceCoords(aShape) { const rank = aShape.length; if (rank > 5) { throw Error(`Tile for rank ${rank} is not yet supported`); } if (rank === 1) { return `imod(resRC, ${aShape[0]})`; } const currentCoords = ['resRC.x', 'resRC.y', 'resRC.z', 'resRC.w', 'resRC.u']; const sourceCoords = []; for (let i = 0; i < aShape.length; i++) { sourceCoords.push(`imod(${currentCoords[i]}, ${aShape[i]})`); } return sourceCoords.join(); } /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function tile(params) { const { inputs, backend, attrs } = params; const { x } = inputs; const { reps } = attrs; // tile gpu program cannot handle rank > 5 case. 
if (x.dtype === 'string' || x.shape.length > 5) { // Even thought string tensor is always on CPU, just to be consistent on how // to access tensor data. const data = backend.readSync(x.dataId); const value = x.dtype === 'string' ? data.map(d => tf.util.decodeString(d)) : data; const buf = tf.buffer(x.shape, x.dtype, value); const outBuf = tileImplCPU(buf, reps); return backend.makeTensorInfo(outBuf.shape, outBuf.dtype, outBuf.values); } const program = new TileProgram(x.shape, reps); const output = backend.runWebGLProgram(program, [x], x.dtype); return output; } const tileConfig = { kernelName: tf.Tile, backendName: 'webgl', kernelFunc: tile, }; // Based on Algorithm 2 of Bitonic Top K, ref: // https://anilshanbhag.in/static/papers/gputopk_sigmod18.pdf // The original algorithm is based on computing the top K only, however // since for TFJS we require the indices of the top K values as well then the // algorithm found here is a bit modified. Rather than producing the values // at each step, the indices containing the top K are generated instead. // The output values are not generated to reduce the number of outputs in the // GPU, the values can easily be retrieved from the indices using a gather // op. class SwapProgram { /** * @param shape desired output shape (can be larger than input shape, output * will be padded with -Infinity) */ constructor(shape) { this.variableNames = ['x', 'indices']; // |n| Size of the original input of TopK. // |firstPass|indicates if this is the first time swap is being used which // means no indices input containing the top K is present yet. // |inc| Swaps pairs of indices (0, inc), (1, inc + 1), (2, inc + 2) ... 
this.customUniforms = [ { name: 'n', type: 'int' }, { name: 'firstPass', type: 'int' }, { name: 'negativeInf', type: 'float' }, { name: 'dir', type: 'int' }, { name: 'inc', type: 'int' } ]; this.outputShape = shape; this.userCode = ` void main() { ivec2 coords = getOutputCoords(); int batch = coords[0]; int elemIdx = coords[1]; // We compare elements pair-wise within a group of size 2 * inc. // The comparing rule for each group alternates between ascending // and descending. Within each group, we compare each pair at // positions i and i+inc. To decide whether an element at position i // is x0 or x1, we mod it by 2 * inc, if the result is smaller than // inc, it is in the first half of the group, we denote it as x0, // otherwise we denote it as x1. // For example, as shown in the Bitonic top K paper referenced above, // Figure5(a) shows that element[1] is in the // second half of the group when group size is 2, but it is in the // first half of the group when group size is 4. bool isFirstInPair = imod(elemIdx, 2 * inc) < inc; int i = isFirstInPair ? elemIdx : elemIdx - inc; int i0 = firstPass == 1 ? i : int(getIndices(batch, i)); int i1 = firstPass == 1 ? i + inc : int(getIndices(batch, i + inc)); float x0 = i0 < n ? getX(batch, i0) : negativeInf; float x1 = i1 < n ? getX(batch, i1) : negativeInf; // Denotes which direction indices are in (ascending or descending). 
bool reverse = imod(elemIdx, 2 * dir) >= dir; bool isGreater = x0 > x1 || (x0 == x1 && i1 > i0); if (reverse == isGreater) { // Elements in opposite order of direction int iTemp = i0; i0 = i1; i1 = iTemp; } if (isFirstInPair) { setOutput(float(i0)); } else { setOutput(float(i1)); } } `; } } class MergeProgram { /** * @param shape desired output shape (must be half of the input size) */ constructor(shape) { this.variableNames = ['x', 'indices']; // |n| Size of the original input of TopK // |firstPass| indicates if this is the first time swap is being used which // means no indices input containing the top K is present yet. // |k| Top k elements desired this.customUniforms = [ { name: 'n', type: 'int' }, { name: 'firstPass', type: 'int' }, { name: 'k', type: 'int' } ]; this.outputShape = shape; this.userCode = ` void main() { // Takes max of indices (0, k), (1, k + 1), (2, k + 2) ... ivec2 coords = getOutputCoords(); int batch = coords[0]; int elemIdx = coords[1]; // The output size is half of the previous size. // If the previous sequence is | | | | _ _ _ _ | | | | _ _ _ _ (k=4), // we only need to output the indices at positions |, the indices at // positions _ can be thrown away, see Figure5(b) After Phase 2 // (Merge phase) in the Bitonic Top K paper referenced above. // For example, the paper shows we only need to output the orange bars. // The output sequence should look like this | | | | | | | |. // Because the sequence is halved, to map the output index back // to the previous sequence to find the corresponding value, // we need to double the index. When we double the index, // we basically interpolate a position, so 2i looks like // | _ | _ | _ | _ | _ | _ | _. We move the | to the first k position // of each 2k positions by - elemIdx % k. E.g. 
for output at // index 4,5,6,7, we want to get the corresponding element at // original index 8,9,10,11, for output at index 8,9,10,11, // we want to get the corresponding element at original index // 16,17,18,19, so on and so forth. int i = elemIdx < k ? elemIdx : (elemIdx * 2 - imod(elemIdx, k)); int i0 = firstPass == 1 ? i : int(getIndices(batch, i)); int i1 = firstPass == 1 ? i + k : int(getIndices(batch, i + k)); float x0 = getX(batch, i0); float x1 = i1 < n ? getX(batch, i1) : x0; setOutput(x0 >= x1 ? float(i0) : float(i1)); } `; } } /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function disposeIntermediateTensorInfoOrNull(backend, tensorInfo) { if (tensorInfo !== null) { backend.disposeIntermediateTensorInfo(tensorInfo); } } function roundUpToPow2(num) { let pow2 = 1; while (pow2 < num) { pow2 *= 2; } return pow2; } // Based on Algorithm 2 of Bitonic Top K, ref: // https://anilshanbhag.in/static/papers/gputopk_sigmod18.pdf function topK(args) { const { inputs, backend, attrs } = args; const { x } = inputs; const { k, sorted } = attrs; // Empirically determined constant used to determine last dim threshold for // handing off execution to the CPU. 
const TOPK_LAST_DIM_CPU_HANDOFF_SIZE_THRESHOLD = tf.env().getNumber('TOPK_LAST_DIM_CPU_HANDOFF_SIZE_THRESHOLD'); // Empirically determined constant used to determine k threshold for handing // off execution to the CPU. const TOPK_K_CPU_HANDOFF_THRESHOLD = tf.env().getNumber('TOPK_K_CPU_HANDOFF_THRESHOLD'); const xShape = x.shape; const lastDim = xShape[xShape.length - 1]; if (backend.shouldExecuteOnCPU([x]) || lastDim < TOPK_LAST_DIM_CPU_HANDOFF_SIZE_THRESHOLD || k > TOPK_K_CPU_HANDOFF_THRESHOLD) { const xVals = backend.readSync(x.dataId); const [allTopKVals, allTopKIndices] = topKImplCPU(xVals, xShape, x.dtype, k, sorted); return [ backend.makeTensorInfo(allTopKVals.shape, allTopKVals.dtype, allTopKVals.values), backend.makeTensorInfo(allTopKIndices.shape, allTopKIndices.dtype, allTopKIndices.values) ]; } if (k === 0) { xShape[xShape.length - 1] = 0; return [ backend.makeTensorInfo(xShape, x.dtype, []), backend.makeTensorInfo(xShape, 'int32', []) ]; } if (lastDim === 1 /* firstPass */) { return [ x, fill({ attrs: { shape: xShape, dtype: 'int32', value: 0 }, backend }) ]; } // Eagerly unpack x input since it is passed in to all the shaders which // require unpacked inputs. const xtexData = backend.texData.get(x.dataId); const xIsPacked = xtexData !== null && xtexData.isPacked; const xUnPacked = xIsPacked ? backend.unpackTensor(x) : x; // Reshape into a 2d tensor [batch, lastDim] and compute topk along lastDim. 
const xSize = tf.util.sizeFromShape(xShape); const batch = xSize / lastDim; const x2D = reshape({ inputs: { x: xUnPacked }, attrs: { shape: [batch, lastDim] }, backend }); if (xIsPacked) { disposeIntermediateTensorInfoOrNull(backend, xUnPacked); } const kPow2 = roundUpToPow2(k); const lastDimPow2 = roundUpToPow2(lastDim); // Only the indices containing the top K are kept at every step to reduce // number of outputs in the GPU algorithms, so once the final set of indices // is computed then gather is used to grab the corresponding values // from the original input. let indices = null; // GPU algorithm always takes in an indices input but this input is not used // on the first run of a GPU algorithm, therefore if indices is null we simply // pass in x2D instead of it but the value will not actually be used const getInputs = () => indices === null ? [x2D, x2D] : [x2D, indices]; const runSwap = (dir, inc, shape) => { const inputs = getInputs(); const program = new SwapProgram(shape); const fistPass = indices === null ? 1 : 0; const customValues = [[lastDim], [fistPass], [Number.NEGATIVE_INFINITY], [dir], [inc]]; const prevIndices = indices; indices = backend.runWebGLProgram(program, inputs, 'int32', customValues); disposeIntermediateTensorInfoOrNull(backend, prevIndices); }; // Step 1: local sort for (let len = 1; len < kPow2; len *= 2) { const dir = len * 2; for (let inc = len; inc >= 1; inc /= 2) { runSwap(dir, inc, [batch, lastDimPow2]); } } // Step 2: merge for (let indicesSize = lastDimPow2; indicesSize > kPow2; indicesSize /= 2) { const inputs = getInputs(); const mergeProgram = new MergeProgram([batch, indicesSize / 2]); const firstPass = indices === null ? 
1 : 0; const customValues = [[lastDim], [firstPass], [kPow2]]; const prevIndices = indices; indices = backend.runWebGLProgram(mergeProgram, inputs, 'int32', customValues); disposeIntermediateTensorInfoOrNull(backend, prevIndices); // Step 3: rebuild const len = kPow2 / 2; const dir = len * 2; for (let inc = len; inc >= 1; inc /= 2) { runSwap(dir, inc, indices.shape); } } // Keep only the requested top K results instead of kPow2 let prevIndices = indices; indices = slice({ inputs: { x: indices }, backend, attrs: { begin: 0, size: [batch, k] } }); disposeIntermediateTensorInfoOrNull(backend, prevIndices); // Gather values on last dimension let values = gatherV2({ inputs: { x: x2D, indices }, backend, attrs: { axis: 1, batchDims: 1 } }); disposeIntermediateTensorInfoOrNull(backend, x2D); // Reshape back to the original input shape, except that the last // dimension is k. const newShape = xShape.slice(0, -1); newShape.push(k); prevIndices = indices; indices = reshape({ inputs: { x: indices }, attrs: { shape: newShape }, backend }); disposeIntermediateTensorInfoOrNull(backend, prevIndices); const prevValues = values; values = reshape({ inputs: { x: values }, attrs: { shape: newShape }, backend }); disposeIntermediateTensorInfoOrNull(backend, prevValues); return [values, indices]; } const topKConfig = { kernelName: tf.TopK, backendName: 'webgl', kernelFunc: topK }; /** * @license * Copyright 2021 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */
/**
 * GPGPU program that applies a per-batch projective transform to an image.
 *
 * The shader reads eight coefficients (a1..a3, b1..b3, c1, c2) per batch entry
 * from `Transforms`, maps every output pixel (x, y) to an input position
 * ((a1*x + a2*y + a3) / p, (b1*x + b2*y + b3) / p) with p = c1*x + c2*y + 1,
 * and samples `Image` there. When p is exactly 0 the fill value is written.
 *
 * Shader helpers:
 *  - `mapCoord` remaps an out-of-range coordinate according to the fill mode.
 *  - `readWithFillValue` returns the pixel, or the fill value when the
 *    coordinate lies outside the image.
 */
class TransformProgram {
    /**
     * @param imageHeight Height of the input image, baked into the shader.
     * @param imageWidth Width of the input image, baked into the shader.
     * @param interpolation 'nearest' selects rounding to the nearest pixel; any
     *     other value selects the bilinear branch of the shader.
     * @param fillMode 'constant' | 'reflect' | 'wrap' | 'nearest'; any other
     *     value is treated like 'constant'.
     * @param fillValue Value written for positions outside the image.
     * @param outShape Output shape; the shader reads it as
     *     (batch, y, x, channel).
     */
    constructor(imageHeight, imageWidth, interpolation, fillMode, fillValue, outShape) {
        this.variableNames = ['Image', 'Transforms'];
        this.outputShape = outShape;
        // 1 = nearest neighbour, 2 = bilinear (compared against in the shader).
        const interpolationModeId = interpolation === 'nearest' ? 1 : 2;
        // Numeric fill-mode id interpolated into the shader's mapCoord().
        let fillModeId;
        switch (fillMode) {
            case 'constant':
                fillModeId = 1;
                break;
            case 'reflect':
                fillModeId = 2;
                break;
            case 'wrap':
                fillModeId = 3;
                break;
            case 'nearest':
                fillModeId = 4;
                break;
            default:
                fillModeId = 1;
                break;
        }
        this.userCode = ` float mapCoord(float outCoord, float len) { float inCoord = outCoord; if(${fillModeId} == 2) { if (inCoord < 0.0) { if (len <= 1.0) { inCoord = 0.0; } else { float sz2 = 2.0 * len; if (inCoord < sz2) { inCoord = sz2 * float(int(float(-inCoord / sz2))) + inCoord; } inCoord = inCoord < -len ? inCoord + sz2 : -inCoord - 1.0; } } else if (inCoord > len - 1.0) { if (len <= 1.0) { inCoord = 0.0; } else { float sz2 = 2.0 * len; inCoord -= sz2 * float(int(float(inCoord / sz2))); if (inCoord >= len) { inCoord = sz2 - inCoord - 1.0; } } } return clamp(inCoord, 0.0, len - 1.0); } else if (${fillModeId} == 3) { if (inCoord < 0.0) { if (len <= 1.0) { inCoord = 0.0; } else { float sz = len - 1.0; inCoord += len * (float(int(float(-inCoord / sz))) + 1.0); } } else if (inCoord > len - 1.0) { if (len <= 1.0) { inCoord = 0.0; } else { float sz = len - 1.0; inCoord -= len * float(int(float(inCoord / sz))); } } return clamp(inCoord, 0.0, len - 1.0); } else if (${fillModeId} == 4) { return clamp(outCoord, 0.0, len - 1.0); } else { return outCoord; } } float readWithFillValue(int batch, int coordY, int coordX, int channel) { float outputValue; if (0 <= coordY && coordY < ${imageHeight} && 0 <= coordX && coordX < ${imageWidth}) { outputValue = getImage(batch, coordY, coordX, channel); } else { outputValue = float(${fillValue}); } return outputValue; } void main() { ivec4 coords = getOutputCoords(); float outputValue; int
batch = coords[0]; int x = coords[2]; int y = coords[1]; int channel = coords[3]; float xf = float(x); float yf = float(y); float a1 = getTransforms(batch, 0); float a2 = getTransforms(batch, 1); float a3 = getTransforms(batch, 2); float b1 = getTransforms(batch, 3); float b2 = getTransforms(batch, 4); float b3 = getTransforms(batch, 5); float c1 = getTransforms(batch, 6); float c2 = getTransforms(batch, 7); float projection = c1 * xf + c2 * yf + 1.0; if (projection == 0.0) { outputValue = float(${fillValue}); } else { float inX = (a1 * xf + a2 * yf + a3) / projection; float inY = (b1 * xf + b2 * yf + b3) / projection; float mapX = mapCoord(inX, float(${imageWidth})); float mapY = mapCoord(inY, float(${imageHeight})); if (${interpolationModeId} == 1) { int coordY = int(round(mapY)); int coordX = int(round(mapX)); outputValue = readWithFillValue(batch, coordY, coordX, channel); } else { float yFloor = floor(mapY); float xFloor = floor(mapX); float yCeil = yFloor + 1.0; float xCeil = xFloor + 1.0; float valueYFloor = (xCeil - mapX) * readWithFillValue(batch, int(yFloor), int(xFloor), channel) + (mapX - xFloor) * readWithFillValue(batch, int(yFloor), int(xCeil), channel); float valueYCeil = (xCeil - mapX) * readWithFillValue(batch, int(yCeil), int(xFloor), channel) + (mapX - xFloor) * readWithFillValue(batch, int(yCeil), int(xCeil), channel); outputValue = (yCeil - mapY) * valueYFloor + (mapY - yFloor) * valueYCeil; } } setOutput(outputValue); } `;
    }
}
/** * @license * Copyright 2021 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function transform(args) { const { inputs, backend, attrs } = args; const { image, transforms } = inputs; const { interpolation, fillMode, fillValue, outputShape } = attrs; const [batch, imageHeight, imageWidth, numChannels] = image.shape; const [outHeight, outWidth] = outputShape != null ? outputShape : [imageHeight, imageWidth]; const outShape = [batch, outHeight, outWidth, numChannels]; const program = new TransformProgram(imageHeight, imageWidth, interpolation, fillMode, fillValue, outShape); return backend.runWebGLProgram(program, [image, transforms], 'float32'); } const transformConfig = { kernelName: tf.Transform, backendName: 'webgl', kernelFunc: transform }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the License); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an AS IS BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function unique(args) { const { inputs, attrs, backend } = args; const { axis } = attrs; const { x } = inputs; assertNotComplex(x, 'unique'); // For now, always forward calculation to the CPU backend. 
console.warn('WARNING: ', 'UI might be locked temporarily as data is being downloaded'); const values = backend.readSync(x.dataId); const { outputValues, outputShape, indices } = uniqueImplCPU(values, axis, x.shape, x.dtype); return [ backend.makeTensorInfo(outputShape, x.dtype, outputValues), backend.makeTensorInfo([indices.length], 'int32', indices), ]; } const uniqueConfig = { kernelName: tf.Unique, backendName: 'webgl', kernelFunc: unique, }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ function unpack(args) { const { inputs, backend, attrs } = args; const { value } = inputs; let { axis } = attrs; if (axis < 0) { axis += value.shape.length; } const x = value; const xRank = x.shape.length; const num = value.shape[axis]; const outShape = new Array(xRank - 1); let outIndex = 0; for (let i = 0; i < xRank; i++) { if (i !== axis) { outShape[outIndex++] = x.shape[i]; } } const toDispose = []; const begin = new Array(xRank).fill(0); const size = x.shape.slice(); size[axis] = 1; const res = new Array(num); for (let i = 0; i < res.length; i++) { begin[axis] = i; const sliced = slice({ inputs: { x }, backend, attrs: { begin, size } }); const reshaped = reshape({ inputs: { x: sliced }, backend, attrs: { shape: outShape } }); res[i] = reshaped; toDispose.push(sliced); } toDispose.forEach(t => backend.disposeIntermediateTensorInfo(t)); return res; } const unpackConfig = { kernelName: tf.Unpack, backendName: 'webgl', kernelFunc: unpack }; /** * @license * Copyright 2018 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */
/**
 * GPGPU program that sums, per segment, the values of one window of `x`.
 *
 * The output has shape [batchSize, numSegments * ceil(inSize / windowSize)].
 * For output column `outIdx` the shader derives the window
 * (floor(outIdx / numSegments)) and the segment (outIdx mod numSegments), then
 * accumulates the values in that window whose segment id equals the segment.
 * Values are processed four at a time with a dot product against a 0/1
 * filter vector; a tail of 1-3 values is handled by the trailing branches.
 */
class SegmentOpProgram {
    /**
     * @param segOpInfo Object with `windowSize`, `batchSize`, `inSize` and
     *     `numSegments`.
     * @param segOpType Accepted for the caller's benefit; it is not read by
     *     this constructor.
     */
    constructor(segOpInfo, segOpType) {
        this.variableNames = ['x', 'segmentIds'];
        const windowSize = segOpInfo.windowSize;
        const batchSize = segOpInfo.batchSize;
        const inSize = segOpInfo.inSize;
        const numSegments = segOpInfo.numSegments;
        const outSize = numSegments * Math.ceil(inSize / windowSize);
        this.outputShape = [batchSize, outSize];
        const initializationValue = '0.0';
        const returnValue = `sumValue`;
        // Largest multiple of 4 not exceeding the window, and the 0-3 leftover.
        const windowSizeNearestVec4 = Math.floor(windowSize / 4) * 4;
        const windowSizeVec4Remainder = windowSize % 4;
        const updateSnippet = ` sumValue += dot(values, segFilter); `;
        // Bounds checks are only emitted when the last window is partial.
        let checkValueOutOfBounds = '';
        if (inSize % windowSize > 0) {
            checkValueOutOfBounds = ` if (inIdx < 0 || inIdx >= ${inSize}) { return initializationValue; } `;
        }
        let checkSegmentIdOutOfBounds = '';
        if (inSize % windowSize > 0) {
            checkSegmentIdOutOfBounds = ` if (inIdx < 0 || inIdx >= ${inSize}) { return -1.0; } `;
        }
        this.userCode = ` const float initializationValue = ${initializationValue}; float getValue(int batch, int inIdx) { ${checkValueOutOfBounds} return getX(batch, inIdx); } float getSegmentIdAtIndex(int inIdx) { ${checkSegmentIdOutOfBounds} return getSegmentIds(inIdx); } void main() { ivec2 coords = getOutputCoords(); int batch = coords[0]; int outIdx = coords[1]; int inOffset = int(floor(float(outIdx) / float( ${numSegments})) * float(${windowSize})); int currentSeg = int(mod(float(outIdx), float(${numSegments}))); float sumValue = 0.0; for (int i = 0; i < ${windowSizeNearestVec4}; i += 4) { int inIdx = inOffset + i; vec4 values = vec4( getValue(batch, inIdx), getValue(batch, inIdx + 1), getValue(batch, inIdx + 2), getValue(batch, inIdx + 3) ); vec4 segFilter = vec4( int(getSegmentIdAtIndex(inIdx)) == currentSeg ? 1 : 0, int(getSegmentIdAtIndex(inIdx + 1)) == currentSeg ? 1 : 0, int(getSegmentIdAtIndex(inIdx + 2)) == currentSeg ?
1 : 0, int(getSegmentIdAtIndex(inIdx + 3)) == currentSeg ? 1 : 0 ); ${updateSnippet} } int inIdx = inOffset + ${windowSizeNearestVec4}; if (${windowSizeVec4Remainder === 1}) { vec4 values = vec4( getValue(batch, inIdx), initializationValue, initializationValue, initializationValue ); int inIdxSeg = int(getSegmentIdAtIndex(inIdx)); vec4 segFilter = vec4( int(getSegmentIdAtIndex(inIdx)) == currentSeg ? 1 : 0, 0, 0, 0 ); ${updateSnippet} } else if (${windowSizeVec4Remainder === 2}) { vec4 values = vec4( getValue(batch, inIdx), getValue(batch, inIdx + 1), initializationValue, initializationValue ); vec4 segFilter = vec4( int(getSegmentIdAtIndex(inIdx)) == currentSeg ? 1 : 0, int(getSegmentIdAtIndex(inIdx + 1)) == currentSeg ? 1 : 0, 0, 0 ); ${updateSnippet} } else if (${windowSizeVec4Remainder === 3}) { vec4 values = vec4( getValue(batch, inIdx), getValue(batch, inIdx + 1), getValue(batch, inIdx + 2), initializationValue ); vec4 segFilter = vec4( int(getSegmentIdAtIndex(inIdx)) == currentSeg ? 1 : 0, int(getSegmentIdAtIndex(inIdx + 1)) == currentSeg ? 1 : 0, int(getSegmentIdAtIndex(inIdx + 2)) == currentSeg ? 1 : 0, 0 ); ${updateSnippet} } setOutput(${returnValue}); } `;
    }
}
/** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License.
* ============================================================================= */ function unsortedSegmentSum(args) { const { inputs, backend, attrs } = args; const { x, segmentIds } = inputs; const { numSegments } = attrs; const xRank = x.shape.length; const toDispose = []; let axis = 0; const permutation = tf.backend_util.getAxesPermutation([axis], xRank); let permutedX = x; if (permutation != null) { permutedX = transpose({ inputs: { x }, backend, attrs: { perm: permutation } }); toDispose.push(permutedX); axis = tf.backend_util.getInnerMostAxes(1, xRank)[0]; } const outShape = tf.backend_util.segment_util.computeOutShape(permutedX.shape, axis, numSegments); const inSize = tf.util.sizeFromShape([permutedX.shape[axis]]); const a2D = reshape({ inputs: { x: permutedX }, backend, attrs: { shape: [-1, inSize] } }); toDispose.push(a2D); const outputDType = tf.sumOutType(x.dtype); const segOpCompute = (x, segOpType, segmentIds, dtype, numSegments) => { const batchSize = x.shape[0]; const inSize = x.shape[1]; const windowSize = tf.backend_util.segment_util.segOpComputeOptimalWindowSize(inSize, numSegments); const segOpInfo = { windowSize, inSize, batchSize, numSegments }; const program = new SegmentOpProgram(segOpInfo, segOpType); const output = backend.compileAndRun(program, [x, segmentIds], dtype); toDispose.push(output); // No need to run another GPGPU program. 
if (output.shape[1] === numSegments) { return output; } const rangeInfo = range({ backend, attrs: { start: 0, stop: numSegments, step: 1, dtype: 'float32' } }); const tileInfo = tile({ inputs: { x: rangeInfo }, backend, attrs: { reps: [inSize / windowSize] } }); toDispose.push(rangeInfo); toDispose.push(tileInfo); const result = segOpCompute(output, segOpType, tileInfo, dtype, numSegments); return result; }; const segOpResult = segOpCompute(a2D, 'unsortedSegmentSum', segmentIds, outputDType, numSegments); const reshaped = reshape({ inputs: { x: segOpResult }, backend, attrs: { shape: outShape } }); let result = reshaped; if (permutation != null) { toDispose.push(reshaped); const perm = tf.backend_util.getUndoAxesPermutation(permutation); result = transpose({ inputs: { x: result }, backend, attrs: { perm } }); } toDispose.forEach(t => backend.disposeIntermediateTensorInfo(t)); return result; } const unsortedSegmentSumConfig = { kernelName: tf.UnsortedSegmentSum, backendName: 'webgl', kernelFunc: unsortedSegmentSum }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */
// List all kernel configs here
// Every entry is a { kernelName, backendName, kernelFunc } object defined
// earlier in this file; the loop below registers each one with tf.
const kernelConfigs = [
    _fusedMatMulConfig, absConfig, acosConfig, acoshConfig, addConfig,
    addNConfig, allConfig, anyConfig, argMaxConfig, argMinConfig, asinConfig,
    asinhConfig, atanConfig, atan2Config, atanhConfig, avgPoolConfig,
    avgPool3DConfig, avgPool3DGradConfig, avgPoolGradConfig, batchMatMulConfig,
    batchNormConfig, batchToSpaceNDConfig, bincountConfig, bitwiseAndConfig,
    broadcastArgsConfig, castConfig, ceilConfig, clipByValueConfig,
    complexConfig, complexAbsConfig, concatConfig, conv2DConfig,
    conv2DBackpropFilterConfig, conv2DBackpropInputConfig, conv3DConfig,
    conv3DBackpropFilterV2Config, conv3DBackpropInputConfig, cosConfig,
    coshConfig, cropAndResizeConfig, cumprodConfig, cumsumConfig,
    denseBincountConfig, depthToSpaceConfig, depthwiseConv2dNativeConfig,
    depthwiseConv2dNativeBackpropFilterConfig,
    depthwiseConv2dNativeBackpropInputConfig, diagConfig, dilation2DConfig,
    einsumConfig, eluConfig, eluGradConfig, equalConfig, erfConfig, expConfig,
    expandDimsConfig, expm1Config, fftConfig, fillConfig, flipLeftRightConfig,
    floorConfig, floorDivConfig, fromPixelsConfig, fusedConv2DConfig,
    fusedDepthwiseConv2DConfig, gatherNdConfig, gatherV2Config, greaterConfig,
    greaterEqualConfig, identityConfig, ifftConfig, imagConfig, isFiniteConfig,
    isInfConfig, isNaNConfig, leakyReluConfig, lessConfig, lessEqualConfig,
    linSpaceConfig, logConfig, log1pConfig, logicalAndConfig, logicalNotConfig,
    logicalOrConfig, LRNConfig, LRNGradConfig, maxConfig, maximumConfig,
    maxPoolConfig, maxPool3DConfig, maxPool3DGradConfig, maxPoolGradConfig,
    maxPoolWithArgmaxConfig, meanConfig, minConfig, minimumConfig,
    mirrorPadConfig, modConfig, multinomialConfig, multiplyConfig, negConfig,
    nonMaxSuppressionV3Config, nonMaxSuppressionV4Config,
    nonMaxSuppressionV5Config, notEqualConfig, oneHotConfig, onesLikeConfig,
    packConfig, padV2Config, powConfig, preluConfig, prodConfig,
    raggedGatherConfig, raggedRangeConfig, raggedTensorToTensorConfig,
    rangeConfig, realConfig, realDivConfig, reciprocalConfig, reluConfig,
    relu6Config, reshapeConfig, resizeBilinearConfig, resizeBilinearGradConfig,
    resizeNearestNeighborConfig, resizeNearestNeighborGradConfig, reverseConfig,
    rotateWithOffsetConfig, roundConfig, rsqrtConfig, scatterNdConfig,
    searchSortedConfig, selectConfig, seluConfig, sigmoidConfig, signConfig,
    sinConfig, sinhConfig, sliceConfig, softmaxConfig, softplusConfig,
    spaceToBatchNDConfig, sparseFillEmptyRowsConfig, sparseReshapeConfig,
    sparseSegmentMeanConfig, sparseSegmentSumConfig, sparseToDenseConfig,
    splitVConfig, sqrtConfig, squareConfig, squaredDifferenceConfig,
    staticRegexReplaceConfig, stepConfig, stridedSliceConfig,
    stringNGramsConfig, stringSplitConfig, stringToHashBucketFastConfig,
    subConfig, sumConfig, tanConfig, tanhConfig, tensorScatterUpdateConfig,
    tileConfig, topKConfig, transformConfig, transposeConfig, uniqueConfig,
    unpackConfig, unsortedSegmentSumConfig, zerosLikeConfig
];
// Register every kernel as a side effect of loading this bundle.
for (const kernelConfig of kernelConfigs) {
    tf.registerKernel(kernelConfig);
}
// Public surface of the bundle, attached to the `exports` object passed into
// the enclosing factory function.
exports.GPGPUContext = GPGPUContext;
exports.MathBackendWebGL = MathBackendWebGL;
exports.forceHalfFloat = forceHalfFloat;
exports.gpgpu_util = gpgpu_util;
exports.setWebGLContext = setWebGLContext;
exports.version_webgl = version;
exports.webgl = webgl;
exports.webgl_util = webgl_util;
}));
//# sourceMappingURL=tf-backend-webgl.es2017.js.map