/**
 * @license
 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
import { broadcast_util, upcastType, util } from '@tensorflow/tfjs-core';
import { mapActivationToShaderProgram } from '../kernel_utils/kernel_funcs_utils';
import { MatMulPackedProgram } from '../mulmat_packed_gpu';
import { multiply } from './Multiply';
import { reshape } from './Reshape';
import { sum } from './Sum';
import { transpose } from './Transpose';

// Empirically determined minimal shared dimension in matmul before we forward
// to a.mul(b).sum() in order to take advantage of GPU parallelism. See
// https://github.com/tensorflow/tfjs-core/pull/1379 for benchmarks.
export const MATMUL_SHARED_DIM_THRESHOLD = 1000;

/**
 * Batched matrix multiplication for the WebGL backend, optionally fused with a
 * bias add and an activation.
 *
 * Both operands are flattened to rank 3 (`[batch, rows, cols]`). Their leading
 * (batch) dimensions must be broadcast-compatible as checked by
 * `broadcast_util.assertAndGetBroadcastShape`. The result has shape
 * `[...broadcastBatchDims, outerShapeA, outerShapeB]`.
 *
 * @param {object} config
 * @param {TensorInfo} config.a - Left operand; its last two dims are the matrix.
 * @param {TensorInfo} config.b - Right operand; its last two dims are the matrix.
 * @param {boolean} config.transposeA - Treat the last two dims of `a` as swapped.
 * @param {boolean} config.transposeB - Treat the last two dims of `b` as swapped.
 * @param {MathBackendWebGL} config.backend - Backend that runs the kernels and
 *     owns every intermediate created here.
 * @param {TensorInfo|null} [config.bias=null] - Optional bias, fused into the
 *     packed program.
 * @param {TensorInfo|null} [config.preluActivationWeights=null] - Optional
 *     PReLU weights, fused into the packed program.
 * @param {number} [config.leakyreluAlpha=0] - Slope used when
 *     `activation === 'leakyrelu'`.
 * @param {backend_util.Activation|null} [config.activation=null] - Optional
 *     fused activation.
 * @returns {TensorInfo} The product, reshaped to the broadcast output shape.
 *     The caller owns it.
 * @throws {Error} If the inner dimensions differ or the batch dimensions are
 *     not broadcastable.
 */
export function batchMatMulImpl({ a, b, transposeA, transposeB, backend, bias = null, preluActivationWeights = null, leakyreluAlpha = 0, activation = null }) {
  const aRank = a.shape.length;
  const bRank = b.shape.length;

  const innerShapeA = transposeA ? a.shape[aRank - 2] : a.shape[aRank - 1];
  const innerShapeB = transposeB ? b.shape[bRank - 1] : b.shape[bRank - 2];

  const outerShapeA = transposeA ? a.shape[aRank - 1] : a.shape[aRank - 2];
  const outerShapeB = transposeB ? b.shape[bRank - 2] : b.shape[bRank - 1];

  const outerDimsA = a.shape.slice(0, -2);
  const outerDimsB = b.shape.slice(0, -2);

  const batchDimA = util.sizeFromShape(outerDimsA);
  const batchDimB = util.sizeFromShape(outerDimsB);

  const outShapeOuterDims = broadcast_util.assertAndGetBroadcastShape(outerDimsA, outerDimsB);
  const outShape = outShapeOuterDims.concat([outerShapeA, outerShapeB]);

  util.assert(innerShapeA === innerShapeB, () => `Error in matMul: inner shapes (${innerShapeA}) and (` +
      `${innerShapeB}) of Tensors with shapes ${a.shape} and ` +
      `${b.shape} and transposeA=${transposeA}` +
      ` and transposeB=${transposeB} must match.`);

  const a3dShape = transposeA ? [batchDimA, innerShapeA, outerShapeA] :
      [batchDimA, outerShapeA, innerShapeA];
  const b3dShape = transposeB ? [batchDimB, outerShapeB, innerShapeB] :
      [batchDimB, innerShapeB, outerShapeB];

  // NOTE(review): flattening the batch dims to a single axis means broadcasting
  // only works when the two batch sizes are equal or one of them is 1. For
  // example, outer dims [2, 1] vs [1, 3] broadcast to [2, 3] but cannot be
  // expressed this way. Confirm how MatMulPackedProgram treats such inputs.
  const batchDim = Math.max(batchDimA, batchDimB);
  // Equal to a3dShape[1] when transposeA, a3dShape[2] otherwise.
  const sharedDim = innerShapeA;

  const hasBias = bias != null;
  const hasPreluActivationWeights = preluActivationWeights != null;
  const hasLeakyreluAlpha = activation === 'leakyrelu';
  const fusedActivation = activation != null ?
      mapActivationToShaderProgram(activation, true) :
      null;
  const containsFusedOps = hasBias || hasPreluActivationWeights ||
      hasLeakyreluAlpha || fusedActivation != null;

  // Every TensorInfo created below is recorded here and disposed in `finally`,
  // so GPU resources are released even if a kernel throws part-way through.
  const intermediates = [];
  try {
    // The rest of the implementation is designed to operate on rank-3 tensors
    const a3d = reshape({ inputs: { x: a }, backend, attrs: { shape: a3dShape } });
    intermediates.push(a3d);
    const b3d = reshape({ inputs: { x: b }, backend, attrs: { shape: b3dShape } });
    intermediates.push(b3d);

    let out;
    // Since the matrices are vectors, it is faster to call mul().sum()
    // because sum() is O(sqrt(N)) due to divide-and-conquer.
    if ((outerShapeA === 1 || outerShapeB === 1) &&
        sharedDim > MATMUL_SHARED_DIM_THRESHOLD && !containsFusedOps) {
      // Bring both operands to the untransposed layout.
      let aVec = a3d;
      if (transposeA) {
        aVec = transpose({ inputs: { x: a3d }, backend, attrs: { perm: [0, 2, 1] } });
        intermediates.push(aVec);
      }
      let bVec = b3d;
      if (transposeB) {
        bVec = transpose({ inputs: { x: b3d }, backend, attrs: { perm: [0, 2, 1] } });
        intermediates.push(bVec);
      }

      // If B is a column vector, lay it out as [batchB, 1, K] so it broadcasts
      // against A's [batchA, M, K], and reduce axis 2. Otherwise A is a row
      // vector: lay it out as [batchA, K, 1] against B's [batchB, K, N], and
      // reduce axis 1. Each operand keeps its OWN batch size. Reshaping a
      // batch-1 operand to the broadcast `batchDim` would fail reshape's size
      // check, so batch broadcasting is left to `multiply`.
      const reduceOverLastAxis = outerShapeB === 1;
      let aVec3d = aVec;
      let bVec3d = bVec;
      if (reduceOverLastAxis) {
        bVec3d = reshape({
          inputs: { x: bVec },
          backend,
          attrs: { shape: [batchDimB, 1, sharedDim] }
        });
        intermediates.push(bVec3d);
      } else {
        aVec3d = reshape({
          inputs: { x: aVec },
          backend,
          attrs: { shape: [batchDimA, sharedDim, 1] }
        });
        intermediates.push(aVec3d);
      }

      const axis = reduceOverLastAxis ? 2 : 1;
      const product = multiply({ inputs: { a: aVec3d, b: bVec3d }, backend });
      intermediates.push(product);
      out = sum({ inputs: { x: product }, backend, attrs: { axis, keepDims: true } });
    } else {
      const dtype = upcastType(a.dtype, b.dtype);

      const program = new MatMulPackedProgram(a3dShape, b3dShape, [batchDim, outerShapeA, outerShapeB], transposeA, transposeB, hasBias, fusedActivation, hasPreluActivationWeights, hasLeakyreluAlpha);

      // Input order must match what the program was configured for above.
      const inputs = [a3d, b3d];
      if (hasBias) {
        inputs.push(bias);
      }
      if (hasPreluActivationWeights) {
        inputs.push(preluActivationWeights);
      }
      if (hasLeakyreluAlpha) {
        const $leakyreluAlpha = backend.makeTensorInfo([], 'float32', util.createScalarValue(leakyreluAlpha, 'float32'));
        inputs.push($leakyreluAlpha);
        intermediates.push($leakyreluAlpha);
      }

      out = backend.runWebGLProgram(program, inputs, dtype);
    }

    intermediates.push(out);
    return reshape({ inputs: { x: out }, backend, attrs: { shape: outShape } });
  } finally {
    for (const info of intermediates) {
      backend.disposeIntermediateTensorInfo(info);
    }
  }
}

// NOTE(review): this file is compiled output. The inline source map that
// follows embeds the original BatchMatMul_impl.ts, which needs the same fix.
//# 
sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoiQmF0Y2hNYXRNdWxfaW1wbC5qcyIsInNvdXJjZVJvb3QiOiIiLCJzb3VyY2VzIjpbIi4uLy4uLy4uLy4uLy4uLy4uL3RmanMtYmFja2VuZC13ZWJnbC9zcmMva2VybmVscy9CYXRjaE1hdE11bF9pbXBsLnRzIl0sIm5hbWVzIjpbXSwibWFwcGluZ3MiOiJBQUFBOzs7Ozs7Ozs7Ozs7Ozs7R0FlRztBQUVILE9BQU8sRUFBZSxjQUFjLEVBQWMsVUFBVSxFQUFFLElBQUksRUFBQyxNQUFNLHVCQUF1QixDQUFDO0FBR2pHLE9BQU8sRUFBQyw0QkFBNEIsRUFBQyxNQUFNLG9DQUFvQyxDQUFDO0FBQ2hGLE9BQU8sRUFBQyxtQkFBbUIsRUFBQyxNQUFNLHNCQUFzQixDQUFDO0FBRXpELE9BQU8sRUFBQyxRQUFRLEVBQUMsTUFBTSxZQUFZLENBQUM7QUFDcEMsT0FBTyxFQUFDLE9BQU8sRUFBQyxNQUFNLFdBQVcsQ0FBQztBQUNsQyxPQUFPLEVBQUMsR0FBRyxFQUFDLE1BQU0sT0FBTyxDQUFDO0FBQzFCLE9BQU8sRUFBQyxTQUFTLEVBQUMsTUFBTSxhQUFhLENBQUM7QUFFdEMsOEVBQThFO0FBQzlFLHVFQUF1RTtBQUN2RSxvRUFBb0U7QUFDcEUsTUFBTSxDQUFDLE1BQU0sMkJBQTJCLEdBQUcsSUFBSSxDQUFDO0FBY2hELE1BQU0sVUFBVSxlQUFlLENBQUMsRUFDOUIsQ0FBQyxFQUNELENBQUMsRUFDRCxVQUFVLEVBQ1YsVUFBVSxFQUNWLE9BQU8sRUFDUCxJQUFJLEdBQUcsSUFBSSxFQUNYLHNCQUFzQixHQUFHLElBQUksRUFDN0IsY0FBYyxHQUFHLENBQUMsRUFDbEIsVUFBVSxHQUFHLElBQUksRUFDQztJQUNsQixNQUFNLEtBQUssR0FBRyxDQUFDLENBQUMsS0FBSyxDQUFDLE1BQU0sQ0FBQztJQUM3QixNQUFNLEtBQUssR0FBRyxDQUFDLENBQUMsS0FBSyxDQUFDLE1BQU0sQ0FBQztJQUU3QixNQUFNLFdBQVcsR0FBRyxVQUFVLENBQUMsQ0FBQyxDQUFDLENBQUMsQ0FBQyxLQUFLLENBQUMsS0FBSyxHQUFHLENBQUMsQ0FBQyxDQUFDLENBQUMsQ0FBQyxDQUFDLENBQUMsS0FBSyxDQUFDLEtBQUssR0FBRyxDQUFDLENBQUMsQ0FBQztJQUN6RSxNQUFNLFdBQVcsR0FBRyxVQUFVLENBQUMsQ0FBQyxDQUFDLENBQUMsQ0FBQyxLQUFLLENBQUMsS0FBSyxHQUFHLENBQUMsQ0FBQyxDQUFDLENBQUMsQ0FBQyxDQUFDLENBQUMsS0FBSyxDQUFDLEtBQUssR0FBRyxDQUFDLENBQUMsQ0FBQztJQUV6RSxNQUFNLFdBQVcsR0FBRyxVQUFVLENBQUMsQ0FBQyxDQUFDLENBQUMsQ0FBQyxLQUFLLENBQUMsS0FBSyxHQUFHLENBQUMsQ0FBQyxDQUFDLENBQUMsQ0FBQyxDQUFDLENBQUMsS0FBSyxDQUFDLEtBQUssR0FBRyxDQUFDLENBQUMsQ0FBQztJQUN6RSxNQUFNLFdBQVcsR0FBRyxVQUFVLENBQUMsQ0FBQyxDQUFDLENBQUMsQ0FBQyxLQUFLLENBQUMsS0FBSyxHQUFHLENBQUMsQ0FBQyxDQUFDLENBQUMsQ0FBQyxDQUFDLENBQUMsS0FBSyxDQUFDLEtBQUssR0FBRyxDQUFDLENBQUMsQ0FBQztJQUV6RSxNQUFNLFVBQVUsR0FBRyxDQUFDLENBQUMsS0FBSyxDQUFDLEtBQUssQ0
FBQyxDQUFDLEVBQUUsQ0FBQyxDQUFDLENBQUMsQ0FBQztJQUN4QyxNQUFNLFVBQVUsR0FBRyxDQUFDLENBQUMsS0FBSyxDQUFDLEtBQUssQ0FBQyxDQUFDLEVBQUUsQ0FBQyxDQUFDLENBQUMsQ0FBQztJQUV4QyxNQUFNLFNBQVMsR0FBRyxJQUFJLENBQUMsYUFBYSxDQUFDLFVBQVUsQ0FBQyxDQUFDO0lBQ2pELE1BQU0sU0FBUyxHQUFHLElBQUksQ0FBQyxhQUFhLENBQUMsVUFBVSxDQUFDLENBQUM7SUFFakQsTUFBTSxpQkFBaUIsR0FBRyxjQUFjLENBQUMsMEJBQTBCLENBQy9ELENBQUMsQ0FBQyxLQUFLLENBQUMsS0FBSyxDQUFDLENBQUMsRUFBRSxDQUFDLENBQUMsQ0FBQyxFQUFFLENBQUMsQ0FBQyxLQUFLLENBQUMsS0FBSyxDQUFDLENBQUMsRUFBRSxDQUFDLENBQUMsQ0FBQyxDQUFDLENBQUM7SUFDaEQsTUFBTSxRQUFRLEdBQUcsaUJBQWlCLENBQUMsTUFBTSxDQUFDLENBQUMsV0FBVyxFQUFFLFdBQVcsQ0FBQyxDQUFDLENBQUM7SUFFdEUsSUFBSSxDQUFDLE1BQU0sQ0FDUCxXQUFXLEtBQUssV0FBVyxFQUMzQixHQUFHLEVBQUUsQ0FBQyxrQ0FBa0MsV0FBVyxTQUFTO1FBQ3hELEdBQUcsV0FBVyw0QkFBNEIsQ0FBQyxDQUFDLEtBQUssT0FBTztRQUN4RCxHQUFHLENBQUMsQ0FBQyxLQUFLLG1CQUFtQixVQUFVLEVBQUU7UUFDekMsbUJBQW1CLFVBQVUsY0FBYyxDQUFDLENBQUM7SUFFckQsTUFBTSxRQUFRLEdBQTZCLFVBQVUsQ0FBQyxDQUFDO1FBQ25ELENBQUMsU0FBUyxFQUFFLFdBQVcsRUFBRSxXQUFXLENBQUMsQ0FBQyxDQUFDO1FBQ3ZDLENBQUMsU0FBUyxFQUFFLFdBQVcsRUFBRSxXQUFXLENBQUMsQ0FBQztJQUMxQyxNQUFNLFFBQVEsR0FBNkIsVUFBVSxDQUFDLENBQUM7UUFDbkQsQ0FBQyxTQUFTLEVBQUUsV0FBVyxFQUFFLFdBQVcsQ0FBQyxDQUFDLENBQUM7UUFDdkMsQ0FBQyxTQUFTLEVBQUUsV0FBVyxFQUFFLFdBQVcsQ0FBQyxDQUFDO0lBRTFDLDBFQUEwRTtJQUMxRSxNQUFNLEdBQUcsR0FBRyxPQUFPLENBQUMsRUFBQyxNQUFNLEVBQUUsRUFBQyxDQUFDLEVBQUUsQ0FBQyxFQUFDLEVBQUUsT0FBTyxFQUFFLEtBQUssRUFBRSxFQUFDLEtBQUssRUFBRSxRQUFRLEVBQUMsRUFBQyxDQUFDLENBQUM7SUFDekUsTUFBTSxHQUFHLEdBQUcsT0FBTyxDQUFDLEVBQUMsTUFBTSxFQUFFLEVBQUMsQ0FBQyxFQUFFLENBQUMsRUFBQyxFQUFFLE9BQU8sRUFBRSxLQUFLLEVBQUUsRUFBQyxLQUFLLEVBQUUsUUFBUSxFQUFDLEVBQUMsQ0FBQyxDQUFDO0lBRXpFLE1BQU0sYUFBYSxHQUFpQixDQUFDLEdBQUcsRUFBRSxHQUFHLENBQUMsQ0FBQztJQUUvQyxNQUFNLFFBQVEsR0FBRyxJQUFJLENBQUMsR0FBRyxDQUFDLFNBQVMsRUFBRSxTQUFTLENBQUMsQ0FBQztJQUNoRCxNQUFNLFNBQVMsR0FBRyxVQUFVLENBQUMsQ0FBQyxDQUFDLEdBQUcsQ0FBQyxLQUFLLENBQUMsQ0FBQyxDQUFDLENBQUMsQ0FBQyxDQUFDLEdBQUcsQ0FBQyxLQUFLLENBQUMsQ0FBQyxDQUFDLENBQUM7SUFFM0QsTUFBTSxPQUFPLEdBQUcsSUFBSSxJQUFJLElBQU
ksQ0FBQztJQUM3QixNQUFNLHlCQUF5QixHQUFHLHNCQUFzQixJQUFJLElBQUksQ0FBQztJQUNqRSxNQUFNLGlCQUFpQixHQUFHLFVBQVUsS0FBSyxXQUFXLENBQUM7SUFDckQsTUFBTSxlQUFlLEdBQUcsVUFBVSxJQUFJLElBQUksQ0FBQyxDQUFDO1FBQ3hDLDRCQUE0QixDQUFDLFVBQVUsRUFBRSxJQUFJLENBQUMsQ0FBQyxDQUFDO1FBQ2hELElBQUksQ0FBQztJQUNULE1BQU0sZ0JBQWdCLEdBQUcsT0FBTyxJQUFJLHlCQUF5QjtRQUN6RCxpQkFBaUIsSUFBSSxlQUFlLElBQUksSUFBSSxDQUFDO0lBQ2pELElBQUksR0FBZSxDQUFDO0lBRXBCLG1FQUFtRTtJQUNuRSx5REFBeUQ7SUFDekQsSUFBSSxDQUFDLFdBQVcsS0FBSyxDQUFDLElBQUksV0FBVyxLQUFLLENBQUMsQ0FBQztRQUN4QyxTQUFTLEdBQUcsMkJBQTJCLElBQUksZ0JBQWdCLEtBQUssS0FBSyxFQUFFO1FBQ3pFLElBQUksSUFBSSxHQUFHLEdBQUcsQ0FBQztRQUNmLElBQUksSUFBSSxHQUFHLEdBQUcsQ0FBQztRQUNmLElBQUksVUFBVSxFQUFFO1lBQ2QsSUFBSSxHQUFHLFNBQVMsQ0FBQyxFQUFDLE1BQU0sRUFBRSxFQUFDLENBQUMsRUFBRSxHQUFHLEVBQUMsRUFBRSxPQUFPLEVBQUUsS0FBSyxFQUFFLEVBQUMsSUFBSSxFQUFFLENBQUMsQ0FBQyxFQUFFLENBQUMsRUFBRSxDQUFDLENBQUMsRUFBQyxFQUFDLENBQUMsQ0FBQztZQUN4RSxhQUFhLENBQUMsSUFBSSxDQUFDLElBQUksQ0FBQyxDQUFDO1NBQzFCO1FBQ0QsSUFBSSxVQUFVLEVBQUU7WUFDZCxJQUFJLEdBQUcsU0FBUyxDQUFDLEVBQUMsTUFBTSxFQUFFLEVBQUMsQ0FBQyxFQUFFLEdBQUcsRUFBQyxFQUFFLE9BQU8sRUFBRSxLQUFLLEVBQUUsRUFBQyxJQUFJLEVBQUUsQ0FBQyxDQUFDLEVBQUUsQ0FBQyxFQUFFLENBQUMsQ0FBQyxFQUFDLEVBQUMsQ0FBQyxDQUFDO1lBQ3hFLGFBQWEsQ0FBQyxJQUFJLENBQUMsSUFBSSxDQUFDLENBQUM7U0FDMUI7UUFFRCxNQUFNLGNBQWMsR0FBRyxXQUFXLEtBQUssQ0FBQyxDQUFDO1FBQ3pDLE1BQU0sY0FBYyxHQUFHLFdBQVcsS0FBSyxDQUFDLENBQUM7UUFFekMsSUFBSSxNQUFNLEdBQUcsSUFBSSxDQUFDO1FBQ2xCLElBQUksY0FBYyxFQUFFO1lBQ2xCLE1BQU0sR0FBRyxPQUFPLENBQUM7Z0JBQ2YsTUFBTSxFQUFFLEVBQUMsQ0FBQyxFQUFFLElBQUksRUFBQztnQkFDakIsT0FBTztnQkFDUCxLQUFLLEVBQUUsRUFBQyxLQUFLLEVBQUUsQ0FBQyxRQUFRLEVBQUUsU0FBUyxFQUFFLENBQUMsQ0FBQyxFQUFDO2FBQ3pDLENBQUMsQ0FBQztZQUVILGFBQWEsQ0FBQyxJQUFJLENBQUMsTUFBTSxDQUFDLENBQUM7U0FDNUI7UUFFRCxNQUFNLElBQUksR0FBRyxXQUFXLEtBQUssQ0FBQyxDQUFDLENBQUMsQ0FBQyxDQUFDLENBQUMsQ0FBQyxDQUFDLENBQUMsQ0FBQztRQUV2QyxJQUFJLE1BQU0sR0FBRyxJQUFJLENBQUM7UUFDbEIsSUFBSSxjQUFjLEVBQUU7WUFDbEIsTUFBTSxHQUFHLE9BQU8sQ0FBQztnQkFDZixNQUFNLEVBQUUsRUFBQyxDQUFDLEVBQUUsSUFBSSxFQUFDO2dCQUNqQi
xPQUFPO2dCQUNQLEtBQUssRUFBRSxFQUFDLEtBQUssRUFBRSxDQUFDLFFBQVEsRUFBRSxDQUFDLEVBQUUsU0FBUyxDQUFDLEVBQUM7YUFDekMsQ0FBQyxDQUFDO1lBRUgsYUFBYSxDQUFDLElBQUksQ0FBQyxNQUFNLENBQUMsQ0FBQztTQUM1QjtRQUVELE1BQU0sT0FBTyxHQUFHLFFBQVEsQ0FBQyxFQUFDLE1BQU0sRUFBRSxFQUFDLENBQUMsRUFBRSxNQUFNLEVBQUUsQ0FBQyxFQUFFLE1BQU0sRUFBQyxFQUFFLE9BQU8sRUFBQyxDQUFDLENBQUM7UUFDcEUsR0FBRyxHQUFHLEdBQUcsQ0FBQyxFQUFDLE1BQU0sRUFBRSxFQUFDLENBQUMsRUFBRSxPQUFPLEVBQUMsRUFBRSxPQUFPLEVBQUUsS0FBSyxFQUFFLEVBQUMsSUFBSSxFQUFFLFFBQVEsRUFBRSxJQUFJLEVBQUMsRUFBQyxDQUFDLENBQUM7UUFDMUUsYUFBYSxDQUFDLElBQUksQ0FBQyxPQUFPLENBQUMsQ0FBQztLQUM3QjtTQUFNO1FBQ0wsTUFBTSxLQUFLLEdBQUcsVUFBVSxDQUFDLENBQUMsQ0FBQyxLQUFLLEVBQUUsQ0FBQyxDQUFDLEtBQUssQ0FBQyxDQUFDO1FBRTNDLE1BQU0sT0FBTyxHQUFHLElBQUksbUJBQW1CLENBQ25DLFFBQVEsRUFBRSxRQUFRLEVBQUUsQ0FBQyxRQUFRLEVBQUUsV0FBVyxFQUFFLFdBQVcsQ0FBQyxFQUFFLFVBQVUsRUFDcEUsVUFBVSxFQUFFLE9BQU8sRUFBRSxlQUFlLEVBQUUseUJBQXlCLEVBQy9ELGlCQUFpQixDQUFDLENBQUM7UUFFdkIsTUFBTSxNQUFNLEdBQWlCLENBQUMsR0FBRyxFQUFFLEdBQUcsQ0FBQyxDQUFDO1FBQ3hDLElBQUksSUFBSSxJQUFJLElBQUksRUFBRTtZQUNoQixNQUFNLENBQUMsSUFBSSxDQUFDLElBQUksQ0FBQyxDQUFDO1NBQ25CO1FBQ0QsSUFBSSx5QkFBeUIsRUFBRTtZQUM3QixNQUFNLENBQUMsSUFBSSxDQUFDLHNCQUFzQixDQUFDLENBQUM7U0FDckM7UUFDRCxJQUFJLGlCQUFpQixFQUFFO1lBQ3JCLE1BQU0sZUFBZSxHQUFHLE9BQU8sQ0FBQyxjQUFjLENBQzFDLEVBQUUsRUFBRSxTQUFTLEVBQ2IsSUFBSSxDQUFDLGlCQUFpQixDQUFDLGNBQXNDLEVBQUUsU0FBUyxDQUFDLENBQUMsQ0FBQztZQUMvRSxNQUFNLENBQUMsSUFBSSxDQUFDLGVBQWUsQ0FBQyxDQUFDO1lBQzdCLGFBQWEsQ0FBQyxJQUFJLENBQUMsZUFBZSxDQUFDLENBQUM7U0FDckM7UUFFRCxHQUFHLEdBQUcsT0FBTyxDQUFDLGVBQWUsQ0FBQyxPQUFPLEVBQUUsTUFBTSxFQUFFLEtBQUssQ0FBQyxDQUFDO0tBQ3ZEO0lBRUQsTUFBTSxXQUFXLEdBQ2IsT0FBTyxDQUFDLEVBQUMsTUFBTSxFQUFFLEVBQUMsQ0FBQyxFQUFFLEdBQUcsRUFBQyxFQUFFLE9BQU8sRUFBRSxLQUFLLEVBQUUsRUFBQyxLQUFLLEVBQUUsUUFBUSxFQUFDLEVBQUMsQ0FBQyxDQUFDO0lBQ25FLGFBQWEsQ0FBQyxJQUFJLENBQUMsR0FBRyxDQUFDLENBQUM7SUFDeEIsS0FBSyxNQUFNLENBQUMsSUFBSSxhQUFhLEVBQUU7UUFDN0IsT0FBTyxDQUFDLDZCQUE2QixDQUFDLENBQUMsQ0FBQyxDQUFDO0tBQzFDO0lBQ0QsT0FBTyxXQUFXLENBQUM7QUFDckIsQ0FBQyIsInNvdXJjZXNDb250ZW
50IjpbIi8qKlxuICogQGxpY2Vuc2VcbiAqIENvcHlyaWdodCAyMDIwIEdvb2dsZSBMTEMuIEFsbCBSaWdodHMgUmVzZXJ2ZWQuXG4gKiBMaWNlbnNlZCB1bmRlciB0aGUgQXBhY2hlIExpY2Vuc2UsIFZlcnNpb24gMi4wICh0aGUgXCJMaWNlbnNlXCIpO1xuICogeW91IG1heSBub3QgdXNlIHRoaXMgZmlsZSBleGNlcHQgaW4gY29tcGxpYW5jZSB3aXRoIHRoZSBMaWNlbnNlLlxuICogWW91IG1heSBvYnRhaW4gYSBjb3B5IG9mIHRoZSBMaWNlbnNlIGF0XG4gKlxuICogaHR0cDovL3d3dy5hcGFjaGUub3JnL2xpY2Vuc2VzL0xJQ0VOU0UtMi4wXG4gKlxuICogVW5sZXNzIHJlcXVpcmVkIGJ5IGFwcGxpY2FibGUgbGF3IG9yIGFncmVlZCB0byBpbiB3cml0aW5nLCBzb2Z0d2FyZVxuICogZGlzdHJpYnV0ZWQgdW5kZXIgdGhlIExpY2Vuc2UgaXMgZGlzdHJpYnV0ZWQgb24gYW4gXCJBUyBJU1wiIEJBU0lTLFxuICogV0lUSE9VVCBXQVJSQU5USUVTIE9SIENPTkRJVElPTlMgT0YgQU5ZIEtJTkQsIGVpdGhlciBleHByZXNzIG9yIGltcGxpZWQuXG4gKiBTZWUgdGhlIExpY2Vuc2UgZm9yIHRoZSBzcGVjaWZpYyBsYW5ndWFnZSBnb3Zlcm5pbmcgcGVybWlzc2lvbnMgYW5kXG4gKiBsaW1pdGF0aW9ucyB1bmRlciB0aGUgTGljZW5zZS5cbiAqID09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09XG4gKi9cblxuaW1wb3J0IHtiYWNrZW5kX3V0aWwsIGJyb2FkY2FzdF91dGlsLCBUZW5zb3JJbmZvLCB1cGNhc3RUeXBlLCB1dGlsfSBmcm9tICdAdGVuc29yZmxvdy90ZmpzLWNvcmUnO1xuXG5pbXBvcnQge01hdGhCYWNrZW5kV2ViR0x9IGZyb20gJy4uL2JhY2tlbmRfd2ViZ2wnO1xuaW1wb3J0IHttYXBBY3RpdmF0aW9uVG9TaGFkZXJQcm9ncmFtfSBmcm9tICcuLi9rZXJuZWxfdXRpbHMva2VybmVsX2Z1bmNzX3V0aWxzJztcbmltcG9ydCB7TWF0TXVsUGFja2VkUHJvZ3JhbX0gZnJvbSAnLi4vbXVsbWF0X3BhY2tlZF9ncHUnO1xuXG5pbXBvcnQge211bHRpcGx5fSBmcm9tICcuL011bHRpcGx5JztcbmltcG9ydCB7cmVzaGFwZX0gZnJvbSAnLi9SZXNoYXBlJztcbmltcG9ydCB7c3VtfSBmcm9tICcuL1N1bSc7XG5pbXBvcnQge3RyYW5zcG9zZX0gZnJvbSAnLi9UcmFuc3Bvc2UnO1xuXG4vLyBFbXBpcmljYWxseSBkZXRlcm1pbmVkIG1pbmltYWwgc2hhcmVkIGRpbWVuc2lvbiBpbiBtYXRtdWwgYmVmb3JlIHdlIGZvcndhcmRcbi8vIHRvIGEubXVsKGIpLnN1bSgpIGluIG9yZGVyIHRvIHRha2UgYWR2YW50YWdlIG9mIEdQVSBwYXJhbGxlbGlzbS4gU2VlXG4vLyBodHRwczovL2dpdGh1Yi5jb20vdGVuc29yZmxvdy90ZmpzLWNvcmUvcHVsbC8xMzc5IGZvciBiZW5jaG1hcmtzLlxuZXhwb3J0IGNvbnN0IE1BVE1VTF9TSEFSRURfRElNX1RIUkVTSE9MRCA9IDEwMDA7XG5cbnR5cGUgQmF0Y2hNYXRNdWxDb25maWcgPSB7XG4gIGE6IFRlbnNvckluZm8sXG4gIGI6IFRlbn
NvckluZm8sXG4gIHRyYW5zcG9zZUE6IGJvb2xlYW4sXG4gIHRyYW5zcG9zZUI6IGJvb2xlYW4sXG4gIGJhY2tlbmQ6IE1hdGhCYWNrZW5kV2ViR0wsXG4gIGJpYXM/OiBUZW5zb3JJbmZvLFxuICBwcmVsdUFjdGl2YXRpb25XZWlnaHRzPzogVGVuc29ySW5mbyxcbiAgbGVha3lyZWx1QWxwaGE/OiBudW1iZXIsXG4gIGFjdGl2YXRpb24/OiBiYWNrZW5kX3V0aWwuQWN0aXZhdGlvblxufTtcblxuZXhwb3J0IGZ1bmN0aW9uIGJhdGNoTWF0TXVsSW1wbCh7XG4gIGEsXG4gIGIsXG4gIHRyYW5zcG9zZUEsXG4gIHRyYW5zcG9zZUIsXG4gIGJhY2tlbmQsXG4gIGJpYXMgPSBudWxsLFxuICBwcmVsdUFjdGl2YXRpb25XZWlnaHRzID0gbnVsbCxcbiAgbGVha3lyZWx1QWxwaGEgPSAwLFxuICBhY3RpdmF0aW9uID0gbnVsbFxufTogQmF0Y2hNYXRNdWxDb25maWcpOiBUZW5zb3JJbmZvIHtcbiAgY29uc3QgYVJhbmsgPSBhLnNoYXBlLmxlbmd0aDtcbiAgY29uc3QgYlJhbmsgPSBiLnNoYXBlLmxlbmd0aDtcblxuICBjb25zdCBpbm5lclNoYXBlQSA9IHRyYW5zcG9zZUEgPyBhLnNoYXBlW2FSYW5rIC0gMl0gOiBhLnNoYXBlW2FSYW5rIC0gMV07XG4gIGNvbnN0IGlubmVyU2hhcGVCID0gdHJhbnNwb3NlQiA/IGIuc2hhcGVbYlJhbmsgLSAxXSA6IGIuc2hhcGVbYlJhbmsgLSAyXTtcblxuICBjb25zdCBvdXRlclNoYXBlQSA9IHRyYW5zcG9zZUEgPyBhLnNoYXBlW2FSYW5rIC0gMV0gOiBhLnNoYXBlW2FSYW5rIC0gMl07XG4gIGNvbnN0IG91dGVyU2hhcGVCID0gdHJhbnNwb3NlQiA/IGIuc2hhcGVbYlJhbmsgLSAyXSA6IGIuc2hhcGVbYlJhbmsgLSAxXTtcblxuICBjb25zdCBvdXRlckRpbXNBID0gYS5zaGFwZS5zbGljZSgwLCAtMik7XG4gIGNvbnN0IG91dGVyRGltc0IgPSBiLnNoYXBlLnNsaWNlKDAsIC0yKTtcblxuICBjb25zdCBiYXRjaERpbUEgPSB1dGlsLnNpemVGcm9tU2hhcGUob3V0ZXJEaW1zQSk7XG4gIGNvbnN0IGJhdGNoRGltQiA9IHV0aWwuc2l6ZUZyb21TaGFwZShvdXRlckRpbXNCKTtcblxuICBjb25zdCBvdXRTaGFwZU91dGVyRGltcyA9IGJyb2FkY2FzdF91dGlsLmFzc2VydEFuZEdldEJyb2FkY2FzdFNoYXBlKFxuICAgICAgYS5zaGFwZS5zbGljZSgwLCAtMiksIGIuc2hhcGUuc2xpY2UoMCwgLTIpKTtcbiAgY29uc3Qgb3V0U2hhcGUgPSBvdXRTaGFwZU91dGVyRGltcy5jb25jYXQoW291dGVyU2hhcGVBLCBvdXRlclNoYXBlQl0pO1xuXG4gIHV0aWwuYXNzZXJ0KFxuICAgICAgaW5uZXJTaGFwZUEgPT09IGlubmVyU2hhcGVCLFxuICAgICAgKCkgPT4gYEVycm9yIGluIG1hdE11bDogaW5uZXIgc2hhcGVzICgke2lubmVyU2hhcGVBfSkgYW5kIChgICtcbiAgICAgICAgICBgJHtpbm5lclNoYXBlQn0pIG9mIFRlbnNvcnMgd2l0aCBzaGFwZXMgJHthLnNoYXBlfSBhbmQgYCArXG4gICAgICAgICAgYCR7Yi5zaGFwZX0gYW5kIHRyYW5zcG9zZUE9JHt0cmFuc3Bvc2VBfWAgK1xuICAgICAgICAgIGAgYW5kIHRyYW5zcG9zZUI9JHt0cm
Fuc3Bvc2VCfSBtdXN0IG1hdGNoLmApO1xuXG4gIGNvbnN0IGEzZFNoYXBlOiBbbnVtYmVyLCBudW1iZXIsIG51bWJlcl0gPSB0cmFuc3Bvc2VBID9cbiAgICAgIFtiYXRjaERpbUEsIGlubmVyU2hhcGVBLCBvdXRlclNoYXBlQV0gOlxuICAgICAgW2JhdGNoRGltQSwgb3V0ZXJTaGFwZUEsIGlubmVyU2hhcGVBXTtcbiAgY29uc3QgYjNkU2hhcGU6IFtudW1iZXIsIG51bWJlciwgbnVtYmVyXSA9IHRyYW5zcG9zZUIgP1xuICAgICAgW2JhdGNoRGltQiwgb3V0ZXJTaGFwZUIsIGlubmVyU2hhcGVCXSA6XG4gICAgICBbYmF0Y2hEaW1CLCBpbm5lclNoYXBlQiwgb3V0ZXJTaGFwZUJdO1xuXG4gIC8vIFRoZSByZXN0IG9mIHRoZSBpbXBsZW1lbnRhdGlvbiBpcyBkZXNpZ25lZCB0byBvcGVyYXRlIG9uIHJhbmstMyB0ZW5zb3JzXG4gIGNvbnN0IGEzZCA9IHJlc2hhcGUoe2lucHV0czoge3g6IGF9LCBiYWNrZW5kLCBhdHRyczoge3NoYXBlOiBhM2RTaGFwZX19KTtcbiAgY29uc3QgYjNkID0gcmVzaGFwZSh7aW5wdXRzOiB7eDogYn0sIGJhY2tlbmQsIGF0dHJzOiB7c2hhcGU6IGIzZFNoYXBlfX0pO1xuXG4gIGNvbnN0IGludGVybWVkaWF0ZXM6IFRlbnNvckluZm9bXSA9IFthM2QsIGIzZF07XG5cbiAgY29uc3QgYmF0Y2hEaW0gPSBNYXRoLm1heChiYXRjaERpbUEsIGJhdGNoRGltQik7XG4gIGNvbnN0IHNoYXJlZERpbSA9IHRyYW5zcG9zZUEgPyBhM2Quc2hhcGVbMV0gOiBhM2Quc2hhcGVbMl07XG5cbiAgY29uc3QgaGFzQmlhcyA9IGJpYXMgIT0gbnVsbDtcbiAgY29uc3QgaGFzUHJlbHVBY3RpdmF0aW9uV2VpZ2h0cyA9IHByZWx1QWN0aXZhdGlvbldlaWdodHMgIT0gbnVsbDtcbiAgY29uc3QgaGFzTGVha3lyZWx1QWxwaGEgPSBhY3RpdmF0aW9uID09PSAnbGVha3lyZWx1JztcbiAgY29uc3QgZnVzZWRBY3RpdmF0aW9uID0gYWN0aXZhdGlvbiAhPSBudWxsID9cbiAgICAgIG1hcEFjdGl2YXRpb25Ub1NoYWRlclByb2dyYW0oYWN0aXZhdGlvbiwgdHJ1ZSkgOlxuICAgICAgbnVsbDtcbiAgY29uc3QgY29udGFpbnNGdXNlZE9wcyA9IGhhc0JpYXMgfHwgaGFzUHJlbHVBY3RpdmF0aW9uV2VpZ2h0cyB8fFxuICAgICAgaGFzTGVha3lyZWx1QWxwaGEgfHwgZnVzZWRBY3RpdmF0aW9uICE9IG51bGw7XG4gIGxldCBvdXQ6IFRlbnNvckluZm87XG5cbiAgLy8gU2luY2UgdGhlIG1hdHJpY2VzIGFyZSB2ZWN0b3JzLCBpdCBpcyBmYXN0ZXIgdG8gY2FsbCBtdWwoKS5zdW0oKVxuICAvLyBiZWNhdXNlIHN1bSgpIGlzIE8oc3FydChOKSkgZHVlIHRvIGRpdmlkZS1hbmQtY29ucXVlci5cbiAgaWYgKChvdXRlclNoYXBlQSA9PT0gMSB8fCBvdXRlclNoYXBlQiA9PT0gMSkgJiZcbiAgICAgIHNoYXJlZERpbSA+IE1BVE1VTF9TSEFSRURfRElNX1RIUkVTSE9MRCAmJiBjb250YWluc0Z1c2VkT3BzID09PSBmYWxzZSkge1xuICAgIGxldCBhVmVjID0gYTNkO1xuICAgIGxldCBiVmVjID0gYjNkO1xuICAgIGlmICh0cmFuc3Bvc2VBKSB7XG4gICAgICBhVmVjID
0gdHJhbnNwb3NlKHtpbnB1dHM6IHt4OiBhM2R9LCBiYWNrZW5kLCBhdHRyczoge3Blcm06IFswLCAyLCAxXX19KTtcbiAgICAgIGludGVybWVkaWF0ZXMucHVzaChhVmVjKTtcbiAgICB9XG4gICAgaWYgKHRyYW5zcG9zZUIpIHtcbiAgICAgIGJWZWMgPSB0cmFuc3Bvc2Uoe2lucHV0czoge3g6IGIzZH0sIGJhY2tlbmQsIGF0dHJzOiB7cGVybTogWzAsIDIsIDFdfX0pO1xuICAgICAgaW50ZXJtZWRpYXRlcy5wdXNoKGJWZWMpO1xuICAgIH1cblxuICAgIGNvbnN0IHNob3VsZFJlc2hhcGVBID0gb3V0ZXJTaGFwZUIgIT09IDE7XG4gICAgY29uc3Qgc2hvdWxkUmVzaGFwZUIgPSBvdXRlclNoYXBlQiA9PT0gMTtcblxuICAgIGxldCBhVmVjM2QgPSBhVmVjO1xuICAgIGlmIChzaG91bGRSZXNoYXBlQSkge1xuICAgICAgYVZlYzNkID0gcmVzaGFwZSh7XG4gICAgICAgIGlucHV0czoge3g6IGFWZWN9LFxuICAgICAgICBiYWNrZW5kLFxuICAgICAgICBhdHRyczoge3NoYXBlOiBbYmF0Y2hEaW0sIHNoYXJlZERpbSwgMV19XG4gICAgICB9KTtcblxuICAgICAgaW50ZXJtZWRpYXRlcy5wdXNoKGFWZWMzZCk7XG4gICAgfVxuXG4gICAgY29uc3QgYXhpcyA9IG91dGVyU2hhcGVCID09PSAxID8gMiA6IDE7XG5cbiAgICBsZXQgYlZlYzNkID0gYlZlYztcbiAgICBpZiAoc2hvdWxkUmVzaGFwZUIpIHtcbiAgICAgIGJWZWMzZCA9IHJlc2hhcGUoe1xuICAgICAgICBpbnB1dHM6IHt4OiBiVmVjfSxcbiAgICAgICAgYmFja2VuZCxcbiAgICAgICAgYXR0cnM6IHtzaGFwZTogW2JhdGNoRGltLCAxLCBzaGFyZWREaW1dfVxuICAgICAgfSk7XG5cbiAgICAgIGludGVybWVkaWF0ZXMucHVzaChiVmVjM2QpO1xuICAgIH1cblxuICAgIGNvbnN0IHByb2R1Y3QgPSBtdWx0aXBseSh7aW5wdXRzOiB7YTogYVZlYzNkLCBiOiBiVmVjM2R9LCBiYWNrZW5kfSk7XG4gICAgb3V0ID0gc3VtKHtpbnB1dHM6IHt4OiBwcm9kdWN0fSwgYmFja2VuZCwgYXR0cnM6IHtheGlzLCBrZWVwRGltczogdHJ1ZX19KTtcbiAgICBpbnRlcm1lZGlhdGVzLnB1c2gocHJvZHVjdCk7XG4gIH0gZWxzZSB7XG4gICAgY29uc3QgZHR5cGUgPSB1cGNhc3RUeXBlKGEuZHR5cGUsIGIuZHR5cGUpO1xuXG4gICAgY29uc3QgcHJvZ3JhbSA9IG5ldyBNYXRNdWxQYWNrZWRQcm9ncmFtKFxuICAgICAgICBhM2RTaGFwZSwgYjNkU2hhcGUsIFtiYXRjaERpbSwgb3V0ZXJTaGFwZUEsIG91dGVyU2hhcGVCXSwgdHJhbnNwb3NlQSxcbiAgICAgICAgdHJhbnNwb3NlQiwgaGFzQmlhcywgZnVzZWRBY3RpdmF0aW9uLCBoYXNQcmVsdUFjdGl2YXRpb25XZWlnaHRzLFxuICAgICAgICBoYXNMZWFreXJlbHVBbHBoYSk7XG5cbiAgICBjb25zdCBpbnB1dHM6IFRlbnNvckluZm9bXSA9IFthM2QsIGIzZF07XG4gICAgaWYgKGJpYXMgIT0gbnVsbCkge1xuICAgICAgaW5wdXRzLnB1c2goYmlhcyk7XG4gICAgfVxuICAgIGlmIChoYXNQcmVsdUFjdGl2YXRpb25XZWlnaHRzKSB7XG4gICAgICBpbnB1dHMucHVzaChwcm
VsdUFjdGl2YXRpb25XZWlnaHRzKTtcbiAgICB9XG4gICAgaWYgKGhhc0xlYWt5cmVsdUFscGhhKSB7XG4gICAgICBjb25zdCAkbGVha3lyZWx1QWxwaGEgPSBiYWNrZW5kLm1ha2VUZW5zb3JJbmZvKFxuICAgICAgICAgIFtdLCAnZmxvYXQzMicsXG4gICAgICAgICAgdXRpbC5jcmVhdGVTY2FsYXJWYWx1ZShsZWFreXJlbHVBbHBoYSBhcyB1bmtub3duIGFzICdmbG9hdDMyJywgJ2Zsb2F0MzInKSk7XG4gICAgICBpbnB1dHMucHVzaCgkbGVha3lyZWx1QWxwaGEpO1xuICAgICAgaW50ZXJtZWRpYXRlcy5wdXNoKCRsZWFreXJlbHVBbHBoYSk7XG4gICAgfVxuXG4gICAgb3V0ID0gYmFja2VuZC5ydW5XZWJHTFByb2dyYW0ocHJvZ3JhbSwgaW5wdXRzLCBkdHlwZSk7XG4gIH1cblxuICBjb25zdCBvdXRSZXNoYXBlZCA9XG4gICAgICByZXNoYXBlKHtpbnB1dHM6IHt4OiBvdXR9LCBiYWNrZW5kLCBhdHRyczoge3NoYXBlOiBvdXRTaGFwZX19KTtcbiAgaW50ZXJtZWRpYXRlcy5wdXNoKG91dCk7XG4gIGZvciAoY29uc3QgaSBvZiBpbnRlcm1lZGlhdGVzKSB7XG4gICAgYmFja2VuZC5kaXNwb3NlSW50ZXJtZWRpYXRlVGVuc29ySW5mbyhpKTtcbiAgfVxuICByZXR1cm4gb3V0UmVzaGFwZWQ7XG59XG4iXX0=