/**
|
* @license
|
* Copyright 2018 Google Inc. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
|
import {SegOpInfo} from '../../ops/segment_util';
|
import {GPGPUProgram} from './gpgpu_math';
|
|
/**
 * GPGPU program that generates the GLSL for a windowed segment sum.
 *
 * Reads two inputs, `x` and `segmentIds`. For every output coordinate
 * `[batch, outIdx]` the shader:
 *   - picks the window starting at `floor(outIdx / numSegments) * windowSize`,
 *   - picks the segment `outIdx mod numSegments`,
 *   - sums the entries of `x` in that window whose segment id (truncated to
 *     int) equals the picked segment.
 *
 * The output shape is `[batchSize, numSegments * ceil(inSize / windowSize)]`.
 */
export class SegmentOpProgram implements GPGPUProgram {
  variableNames = ['x', 'segmentIds'];
  outputShape: number[];
  userCode: string;

  /**
   * @param segOpInfo Supplies `windowSize`, `batchSize`, `inSize` and
   *     `numSegments`, which are baked into the shader source as constants.
   * @param segOpType Kind of segment op. Not read by this constructor; the
   *     generated shader always computes a sum. Kept so the signature stays
   *     compatible with existing callers.
   */
  constructor(segOpInfo: SegOpInfo, segOpType: 'unsortedSegmentSum') {
    const {windowSize, batchSize, inSize, numSegments} = segOpInfo;
    const outSize = numSegments * Math.ceil(inSize / windowSize);
    this.outputShape = [batchSize, outSize];

    const initializationValue = '0.0';
    const returnValue = `sumValue`;

    // The window is consumed four elements at a time; whatever does not fill
    // a whole vec4 is handled by one of the remainder branches below.
    const windowSizeNearestVec4 = Math.floor(windowSize / 4) * 4;
    const windowSizeVec4Remainder = windowSize % 4;

    const updateSnippet = `
        sumValue += dot(values, segFilter);
    `;

    // Index guards are only emitted when inSize is not a multiple of
    // windowSize, i.e. when the last window can run past the end of the
    // input. Both guards share the same condition, so it is computed once.
    const needsBoundsCheck = inSize % windowSize > 0;

    // Out-of-range values contribute the initialization value to the sum.
    const checkValueOutOfBounds = needsBoundsCheck ? `
        if (inIdx < 0 || inIdx >= ${inSize}) {
          return initializationValue;
        }
      ` : '';

    // Out-of-range segment ids become -1.0, which cannot equal currentSeg
    // because currentSeg is a remainder modulo numSegments.
    const checkSegmentIdOutOfBounds = needsBoundsCheck ? `
        if (inIdx < 0 || inIdx >= ${inSize}) {
          return -1.0;
        }
      ` : '';

    this.userCode = `
      const float initializationValue = ${initializationValue};

      float getValue(int batch, int inIdx) {
        ${checkValueOutOfBounds}
        return getX(batch, inIdx);
      }

      float getSegmentIdAtIndex(int inIdx) {
        ${checkSegmentIdOutOfBounds}
        return getSegmentIds(inIdx);
      }

      void main() {
        ivec2 coords = getOutputCoords();
        int batch = coords[0];
        int outIdx = coords[1];
        int inOffset = int(floor(float(outIdx) / float(
          ${numSegments})) * float(${windowSize}));
        int currentSeg = int(mod(float(outIdx), float(${numSegments})));

        float sumValue = 0.0;

        for (int i = 0; i < ${windowSizeNearestVec4}; i += 4) {
          int inIdx = inOffset + i;
          vec4 values = vec4(
            getValue(batch, inIdx),
            getValue(batch, inIdx + 1),
            getValue(batch, inIdx + 2),
            getValue(batch, inIdx + 3)
          );

          vec4 segFilter = vec4(
            int(getSegmentIdAtIndex(inIdx)) == currentSeg ? 1 : 0,
            int(getSegmentIdAtIndex(inIdx + 1)) == currentSeg ? 1 : 0,
            int(getSegmentIdAtIndex(inIdx + 2)) == currentSeg ? 1 : 0,
            int(getSegmentIdAtIndex(inIdx + 3)) == currentSeg ? 1 : 0
          );

          ${updateSnippet}
        }

        int inIdx = inOffset + ${windowSizeNearestVec4};
        if (${windowSizeVec4Remainder === 1}) {
          vec4 values = vec4(
            getValue(batch, inIdx),
            initializationValue,
            initializationValue,
            initializationValue
          );

          vec4 segFilter = vec4(
            int(getSegmentIdAtIndex(inIdx)) == currentSeg ? 1 : 0,
            0,
            0,
            0
          );

          ${updateSnippet}
        } else if (${windowSizeVec4Remainder === 2}) {
          vec4 values = vec4(
            getValue(batch, inIdx),
            getValue(batch, inIdx + 1),
            initializationValue,
            initializationValue
          );

          vec4 segFilter = vec4(
            int(getSegmentIdAtIndex(inIdx)) == currentSeg ? 1 : 0,
            int(getSegmentIdAtIndex(inIdx + 1)) == currentSeg ? 1 : 0,
            0,
            0
          );

          ${updateSnippet}
        } else if (${windowSizeVec4Remainder === 3}) {
          vec4 values = vec4(
            getValue(batch, inIdx),
            getValue(batch, inIdx + 1),
            getValue(batch, inIdx + 2),
            initializationValue
          );

          vec4 segFilter = vec4(
            int(getSegmentIdAtIndex(inIdx)) == currentSeg ? 1 : 0,
            int(getSegmentIdAtIndex(inIdx + 1)) == currentSeg ? 1 : 0,
            int(getSegmentIdAtIndex(inIdx + 2)) == currentSeg ? 1 : 0,
            0
          );

          ${updateSnippet}
        }
        setOutput(${returnValue});
      }
    `;
  }
}
|