/**
|
* @license
|
* Copyright 2019 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
import * as tf from './index';
|
import { ALL_ENVS, describeWithFlags } from './jasmine_util';
|
import { expectArraysClose } from './test_util';
|
/**
 * Tests for the kernel registry: registering/unregistering kernels per
 * backend, dispatching them through `tf.engine().runKernel`, and the optional
 * per-kernel `setupFunc`/`disposeFunc` hooks.
 *
 * Kernel and backend registrations are global, so every test releases what it
 * registered in a `finally` clause. An unexpected throw therefore cannot leave
 * 'MyKernel' or a custom backend registered and leak state into later tests.
 */
describeWithFlags('kernel_registry', ALL_ENVS, (testEnv) => {
    afterEach(async () => {
        // Revert backend mutation: several tests below switch to custom backends.
        await tf.setBackend(testEnv.backendName);
    });
    it('register a kernel and call it', () => {
        // Make sure the backend is loaded. Perhaps tf.getBackend
        // should call tf.backend to make sure the backend is loaded?
        expect(tf.backend()).toBeDefined();
        // Capture the name once, after the backend is loaded, so registration
        // and cleanup are guaranteed to target the same backend.
        const backendName = tf.getBackend();
        let called = false;
        tf.registerKernel({
            kernelName: 'MyKernel',
            backendName,
            kernelFunc: ({ inputs, attrs }) => {
                expect(attrs.a).toBe(5);
                expect(inputs.x.shape).toEqual([2, 2]);
                expect(inputs.x.dtype).toBe('float32');
                called = true;
                return { dtype: 'float32', shape: [3, 3], dataId: {} };
            }
        });
        try {
            const inputs = { x: tf.zeros([2, 2]) };
            const attrs = { a: 5 };
            const res = tf.engine().runKernel('MyKernel', inputs, attrs);
            expect(called).toBe(true);
            expect(res.dtype).toBe('float32');
            expect(res.shape).toEqual([3, 3]);
        } finally {
            tf.unregisterKernel('MyKernel', backendName);
        }
    });
    it('errors when running non-existent kernel', () => {
        const inputs = {};
        const attrs = {};
        expect(() => tf.engine().runKernel('DoesNotExist', inputs, attrs))
            .toThrowError();
    });
    // TODO (yassogba) double registration happens now because a backend might be
    // imported more than once (e.g. by a top level package and a dependent
    // package). We may want to remove this test long-term but skip it for
    // now.
    // tslint:disable-next-line: ban
    xit('errors when registering the same kernel twice', () => {
        tf.registerBackend('backend1', () => {
            return {
                id: 1,
                dispose: () => null,
                disposeData: (dataId) => null,
                numDataIds: () => 0
            };
        });
        tf.registerKernel({
            kernelName: 'MyKernel',
            backendName: 'backend1',
            kernelFunc: () => {
                return null;
            }
        });
        try {
            expect(() => tf.registerKernel({
                kernelName: 'MyKernel',
                backendName: 'backend1',
                kernelFunc: () => {
                    return null;
                }
            })).toThrowError();
        } finally {
            tf.unregisterKernel('MyKernel', 'backend1');
            tf.removeBackend('backend1');
        }
    });
    it('register same kernel on two different backends', async () => {
        tf.registerBackend('backend1', () => {
            return {
                id: 1,
                dispose: () => null,
                disposeData: (dataId) => true,
                numDataIds: () => 0
            };
        });
        tf.registerBackend('backend2', () => {
            return {
                id: 2,
                dispose: () => null,
                disposeData: (dataId) => null,
                numDataIds: () => 0
            };
        });
        // Records the `id` of the backend the kernel last ran on.
        let lastStorageId = -1;
        const kernelFunc = ({ backend }) => {
            lastStorageId = backend.id;
            return { dataId: {}, shape: [], dtype: 'float32' };
        };
        tf.registerKernel({ kernelName: 'MyKernel', backendName: 'backend1', kernelFunc });
        tf.registerKernel({ kernelName: 'MyKernel', backendName: 'backend2', kernelFunc });
        try {
            // No kernel has been executed yet.
            expect(lastStorageId).toBe(-1);
            // Kernel was executed on the first backend.
            await tf.setBackend('backend1');
            tf.engine().runKernel('MyKernel', {}, {});
            expect(lastStorageId).toBe(1);
            // Kernel was executed on the second backend.
            await tf.setBackend('backend2');
            tf.engine().runKernel('MyKernel', {}, {});
            expect(lastStorageId).toBe(2);
        } finally {
            tf.removeBackend('backend1');
            tf.removeBackend('backend2');
            tf.unregisterKernel('MyKernel', 'backend1');
            tf.unregisterKernel('MyKernel', 'backend2');
        }
    });
    it('register kernel with setup and dispose functions', async () => {
        const backendName = 'custom-backend';
        const kernelName = 'MyKernel';
        const customBackend = {
            dispose: () => null,
            disposeData: (dataId) => true,
            numDataIds: () => 0
        };
        tf.registerBackend(backendName, () => customBackend);
        const kernelFunc = () => {
            return { dataId: {}, shape: [], dtype: 'float32' };
        };
        let setupCalled = false;
        const setupFunc = (backend) => {
            expect(backend).toBe(customBackend);
            setupCalled = true;
        };
        let disposeCalled = false;
        const disposeFunc = (backend) => {
            expect(backend).toBe(customBackend);
            disposeCalled = true;
        };
        tf.registerKernel({ kernelName, backendName, kernelFunc, setupFunc, disposeFunc });
        try {
            // Neither hook runs at registration time.
            expect(setupCalled).toBe(false);
            expect(disposeCalled).toBe(false);
            // Switching to the backend triggers setupFunc only.
            await tf.setBackend(backendName);
            expect(setupCalled).toBe(true);
            expect(disposeCalled).toBe(false);
            // Kernel was executed on the first backend.
            tf.engine().runKernel(kernelName, {}, {});
            // Removing the backend triggers disposeFunc. removeBackend is the
            // action under test, so it stays here rather than in `finally`.
            tf.removeBackend(backendName);
            expect(setupCalled).toBe(true);
            expect(disposeCalled).toBe(true);
        } finally {
            tf.unregisterKernel(kernelName, backendName);
        }
    });
});
|
/**
 * Tests for the gradient registry: gradients registered with `registerGradient`
 * are invoked by `tf.grad`/`tf.grads` for kernels run through
 * `tf.engine().runKernel`. They also check that the tensors requested via
 * `inputsToSave`, `outputsToSave` or `saveAllInputs` reach the gradient
 * function.
 *
 * Kernel and gradient registrations are global, so each test releases them in
 * a `finally` clause to avoid leaking state into later tests when something
 * throws.
 */
describeWithFlags('gradient registry', ALL_ENVS, () => {
    it('register a kernel with gradient and call it', async () => {
        let kernelWasCalled = false;
        let gradientWasCalled = false;
        const kernelName = 'MyKernel';
        const x = tf.zeros([2, 2]);
        // Read after creating a tensor so the backend is initialized.
        const backendName = tf.getBackend();
        tf.registerKernel({
            kernelName,
            backendName,
            kernelFunc: () => {
                kernelWasCalled = true;
                return { dtype: 'float32', shape: [3, 3], dataId: {} };
            }
        });
        tf.registerGradient({
            kernelName,
            inputsToSave: ['x'],
            gradFunc: (dy, saved) => {
                // Make sure saved input (x) was passed to the gradient function.
                expect(saved[0].dataId).toEqual(x.dataId);
                // Make sure dy matches the shape of the output.
                expect(dy.shape).toEqual([3, 3]);
                gradientWasCalled = true;
                return { x: () => tf.fill([2, 2], 3) };
            },
        });
        try {
            const gradFunc = tf.grad(x => tf.engine().runKernel(kernelName, { x }, {} /* attrs */));
            const dx = gradFunc(x);
            expect(kernelWasCalled).toBe(true);
            expect(gradientWasCalled).toBe(true);
            expect(dx.dtype).toBe('float32');
            expect(dx.shape).toEqual([2, 2]);
            expectArraysClose(await dx.data(), [3, 3, 3, 3]);
        } finally {
            tf.unregisterKernel(kernelName, backendName);
            tf.unregisterGradient(kernelName);
        }
    });
    it('register a kernel with gradient that specifies outputsToSave and call it', async () => {
        let kernelWasCalled = false;
        let gradientWasCalled = false;
        const kernelName = 'MyKernel';
        const tensor = tf.zeros([3, 3], 'float32');
        // Read after creating a tensor so the backend is initialized.
        const backendName = tf.getBackend();
        // The kernel returns this tensor's data, so the saved output can be
        // identified by its dataId in the gradient function.
        const forwardReturnDataId = tensor.dataId;
        tf.registerKernel({
            kernelName,
            backendName,
            kernelFunc: () => {
                kernelWasCalled = true;
                return {
                    dtype: tensor.dtype,
                    shape: tensor.shape,
                    dataId: forwardReturnDataId
                };
            }
        });
        tf.registerGradient({
            kernelName,
            outputsToSave: [true],
            gradFunc: (dy, saved) => {
                // Make sure saved output was passed to the gradient function.
                expect(saved[0].dataId).toEqual(forwardReturnDataId);
                // Make sure dy matches the shape of the output.
                expect(dy.shape).toEqual([3, 3]);
                gradientWasCalled = true;
                return { x: () => tf.fill([2, 2], 3) };
            },
        });
        try {
            const gradFunc = tf.grad(x => tf.engine().runKernel(kernelName, { x }, {} /* attrs */));
            const x = tf.zeros([2, 2]);
            const dx = gradFunc(x);
            expect(kernelWasCalled).toBe(true);
            expect(gradientWasCalled).toBe(true);
            expect(dx.dtype).toBe('float32');
            expect(dx.shape).toEqual([2, 2]);
        } finally {
            tf.unregisterKernel(kernelName, backendName);
            tf.unregisterGradient(kernelName);
        }
    });
    it('register a kernel with array inputs and saveAllInputs true', async () => {
        let kernelWasCalled = false;
        let gradientWasCalled = false;
        const kernelName = 'MyKernel';
        const x = [tf.zeros([2, 2]), tf.zeros([2, 2])];
        // Read after creating tensors so the backend is initialized.
        const backendName = tf.getBackend();
        const forwardReturnDataId = {};
        tf.registerKernel({
            kernelName,
            backendName,
            kernelFunc: () => {
                kernelWasCalled = true;
                return { dtype: 'float32', shape: [3, 3], dataId: forwardReturnDataId };
            }
        });
        tf.registerGradient({
            kernelName,
            saveAllInputs: true,
            gradFunc: (dy, saved) => {
                // Make sure the saved inputs (x) were passed to the gradient
                // function. Destructure from `saved`, not `x`: comparing `x`
                // with itself would pass even if nothing had been saved.
                const [$x0, $x1] = saved;
                expect(saved.length).toEqual(x.length);
                expect($x0.dataId).toEqual(x[0].dataId);
                expect($x1.dataId).toEqual(x[1].dataId);
                gradientWasCalled = true;
                return { 0: () => tf.fill([2, 2], 3), 1: () => tf.fill([2, 2], 3) };
            }
        });
        try {
            // Inputs as array.
            const z = (...x) => tf.engine().runKernel(kernelName, x, {} /* attrs */);
            const gradFunc = tf.grads(z);
            const dx = gradFunc(x);
            expect(kernelWasCalled).toBe(true);
            expect(gradientWasCalled).toBe(true);
            expect(dx.length).toEqual(2);
            expect(dx[0].dtype).toBe('float32');
            expect(dx[0].shape).toEqual([2, 2]);
            expect(dx[1].dtype).toBe('float32');
            expect(dx[1].shape).toEqual([2, 2]);
            expectArraysClose(await dx[0].data(), [3, 3, 3, 3]);
            expectArraysClose(await dx[1].data(), [3, 3, 3, 3]);
        } finally {
            tf.unregisterKernel(kernelName, backendName);
            tf.unregisterGradient(kernelName);
        }
    });
    it('register a kernel with map inputs and saveAllInputs true should throw ' +
        'error', async () => {
        const kernelName = 'MyKernel';
        const x0 = tf.zeros([2, 2]);
        const x1 = tf.zeros([2, 2]);
        // Read after creating tensors so the backend is initialized.
        const backendName = tf.getBackend();
        const forwardReturnDataId = {};
        tf.registerKernel({
            kernelName,
            backendName,
            kernelFunc: () => {
                return {
                    dtype: 'float32',
                    shape: [3, 3],
                    dataId: forwardReturnDataId
                };
            }
        });
        tf.registerGradient({
            kernelName,
            saveAllInputs: true,
            gradFunc: (dy, saved) => {
                // Make sure saved input (x) was passed to the gradient function.
                const [$x0, $x1] = saved;
                expect($x0.dataId).toEqual(x0.dataId);
                expect($x1.dataId).toEqual(x1.dataId);
                return { x0: () => tf.fill([2, 2], 3), x1: () => tf.fill([2, 2], 3) };
            }
        });
        try {
            // Inputs as map: incompatible with saveAllInputs, so grads must throw.
            const z = (x0, x1) => tf.engine().runKernel(kernelName, { x0, x1 }, {} /* attrs */);
            const gradFunc = tf.grads(z);
            expect(() => gradFunc([x0, x1]))
                .toThrowError(/saveAllInputs is true, expected inputs to be an array/);
        } finally {
            tf.unregisterKernel(kernelName, backendName);
            tf.unregisterGradient(kernelName);
        }
    });
    it('errors when running non-existent gradient', () => {
        const kernelName = 'MyKernel';
        const x = tf.zeros([2, 2]);
        // Read after creating a tensor so the backend is initialized.
        const backendName = tf.getBackend();
        tf.registerKernel({
            kernelName,
            backendName,
            kernelFunc: () => ({ dtype: 'float32', shape: [3, 3], dataId: {} })
        });
        try {
            const gradFunc = tf.grad(x => tf.engine().runKernel(kernelName, { x }, {} /* attrs */));
            expect(() => gradFunc(x))
                .toThrowError(/gradient function not found for MyKernel/);
        } finally {
            tf.unregisterKernel(kernelName, backendName);
        }
    });
    // tslint:disable-next-line: ban
    xit('warning when registering the same gradient twice', () => {
        const kernelName = 'MyKernel';
        tf.registerGradient({ kernelName, gradFunc: () => null });
        spyOn(console, 'warn').and.callFake((msg) => {
            expect(msg).toBe('Overriding the gradient for \'MyKernel\'');
        });
        try {
            tf.registerGradient({ kernelName, gradFunc: () => null });
        } finally {
            tf.unregisterGradient(kernelName);
        }
    });
});
|
//# sourceMappingURL=data:application/json;base64,{"version":3,"file":"kernel_registry_test.js","sourceRoot":"","sources":["../../../../../tfjs-core/src/kernel_registry_test.ts"],"names":[],"mappings":"AAAA;;;;;;;;;;;;;;;GAeG;AAEH,OAAO,KAAK,EAAE,MAAM,SAAS,CAAC;AAE9B,OAAO,EAAC,QAAQ,EAAE,iBAAiB,EAAU,MAAM,gBAAgB,CAAC;AAGpE,OAAO,EAAC,iBAAiB,EAAC,MAAM,aAAa,CAAC;AAE9C,iBAAiB,CAAC,iBAAiB,EAAE,QAAQ,EAAE,CAAC,OAAgB,EAAE,EAAE;IAClE,SAAS,CAAC,KAAK,IAAI,EAAE;QACnB,2BAA2B;QAC3B,MAAM,EAAE,CAAC,UAAU,CAAC,OAAO,CAAC,WAAW,CAAC,CAAC;IAC3C,CAAC,CAAC,CAAC;IAEH,EAAE,CAAC,+BAA+B,EAAE,GAAG,EAAE;QACvC,yDAAyD;QACzD,6DAA6D;QAC7D,MAAM,CAAC,EAAE,CAAC,OAAO,EAAE,CAAC,CAAC,WAAW,EAAE,CAAC;QACnC,IAAI,MAAM,GAAG,KAAK,CAAC;QACnB,EAAE,CAAC,cAAc,CAAC;YAChB,UAAU,EAAE,UAAU;YACtB,WAAW,EAAE,EAAE,CAAC,UAAU,EAAE;YAC5B,UAAU,EAAE,CAAC,EAAC,MAAM,EAAE,KAAK,EAAC,EAAE,EAAE;gBAC9B,MAAM,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC;gBACxB,MAAM,CAAC,MAAM,CAAC,CAAC,CAAC,KAAK,CAAC,CAAC,OAAO,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;gBACvC,MAAM,CAAC,MAAM,CAAC,CAAC,CAAC,KAAK,CAAC,CAAC,IAAI,CAAC,SAAS,CAAC,CAAC;gBACvC,MAAM,GAAG,IAAI,CAAC;gBACd,OAAO,EAAC,KAAK,EAAE,SAAS,EAAE,KAAK,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,MAAM,EAAE,EAAE,EAAC,CAAC;YACvD,CAAC;SACF,CAAC,CAAC;QAEH,MAAM,MAAM,GAAG,EAAC,CAAC,EAAE,EAAE,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,EAAC,CAAC;QACrC,MAAM,KAAK,GAAG,EAAC,CAAC,EAAE,CAAC,EAAC,CAAC;QACrB,MAAM,GAAG,GAAG,EAAE,CAAC,MAAM,EAAE,CAAC,SAAS,CAAC,UAAU,EAAE,MAAM,EAAE,KAAK,CAAe,CAAC;QAE3E,MAAM,CAAC,MAAM,CAAC,CAAC,IAAI,CAAC,IAAI,CAAC,CAAC;QAC1B,MAAM,CAAC,GAAG,CAAC,KAAK,CAAC,CAAC,IAAI,CAAC,SAAS,CAAC,CAAC;QAClC,MAAM,CAAC,GAAG,CAAC,KAAK,CAAC,CAAC,OAAO,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QAElC,EAAE,CAAC,gBAAgB,CAAC,UAAU,EAAE,EAAE,CAAC,UAAU,EAAE,CAAC,CAAC;IACnD,CAAC,CAAC,CAAC;IAEH,EAAE,CAAC,yCAAyC,EAAE,GAAG,EAAE;QACjD,MAAM,MAAM,GAAG,EAAE,CAAC;QAClB,MAAM,KAAK,GAAG,EAAE,CAAC;QACjB,MAAM,CAAC,GAAG,EAAE,CAAC,EAAE,CAAC,MAAM,EAAE,CAAC,SAAS,CAAC,cAAc,EAAE,MAAM,EAAE,KAAK,CAAC,CAAC;aAC7D,YAAY,EAAE,CAAC;IACtB,CAAC,CAAC
,CAAC;IAEH,6EAA6E;IAC7E,uEAAuE;IACvE,sEAAsE;IACtE,OAAO;IACP,gCAAgC;IAChC,GAAG,CAAC,+CAA+C,EAAE,GAAG,EAAE;QAIxD,EAAE,CAAC,eAAe,CAAC,UAAU,EAAE,GAAG,EAAE;YAClC,OAAO;gBACL,EAAE,EAAE,CAAC;gBACL,OAAO,EAAE,GAAG,EAAE,CAAC,IAAI;gBACnB,WAAW,EAAE,CAAC,MAAU,EAAE,EAAE,CAAC,IAAI;gBACjC,UAAU,EAAE,GAAG,EAAE,CAAC,CAAC;aACL,CAAC;QACnB,CAAC,CAAC,CAAC;QAEH,EAAE,CAAC,cAAc,CAAC;YAChB,UAAU,EAAE,UAAU;YACtB,WAAW,EAAE,UAAU;YACvB,UAAU,EAAE,GAAG,EAAE;gBACf,OAAO,IAAI,CAAC;YACd,CAAC;SACF,CAAC,CAAC;QACH,MAAM,CAAC,GAAG,EAAE,CAAC,EAAE,CAAC,cAAc,CAAC;YAC7B,UAAU,EAAE,UAAU;YACtB,WAAW,EAAE,UAAU;YACvB,UAAU,EAAE,GAAG,EAAE;gBACf,OAAO,IAAI,CAAC;YACd,CAAC;SACF,CAAC,CAAC,CAAC,YAAY,EAAE,CAAC;QAEnB,EAAE,CAAC,gBAAgB,CAAC,UAAU,EAAE,UAAU,CAAC,CAAC;QAC5C,EAAE,CAAC,aAAa,CAAC,UAAU,CAAC,CAAC;IAC/B,CAAC,CAAC,CAAC;IAEH,EAAE,CAAC,gDAAgD,EAAE,KAAK,IAAI,EAAE;QAI9D,EAAE,CAAC,eAAe,CAAC,UAAU,EAAE,GAAG,EAAE;YAClC,OAAO;gBACL,EAAE,EAAE,CAAC;gBACL,OAAO,EAAE,GAAG,EAAE,CAAC,IAAI;gBACnB,WAAW,EAAE,CAAC,MAAU,EAAE,EAAE,CAAC,IAAI;gBACjC,UAAU,EAAE,GAAG,EAAE,CAAC,CAAC;aACL,CAAC;QACnB,CAAC,CAAC,CAAC;QACH,EAAE,CAAC,eAAe,CAAC,UAAU,EAAE,GAAG,EAAE;YAClC,OAAO;gBACL,EAAE,EAAE,CAAC;gBACL,OAAO,EAAE,GAAG,EAAE,CAAC,IAAI;gBACnB,WAAW,EAAE,CAAC,MAAU,EAAE,EAAE,CAAC,IAAI;gBACjC,UAAU,EAAE,GAAG,EAAE,CAAC,CAAC;aACL,CAAC;QACnB,CAAC,CAAC,CAAC;QAEH,IAAI,aAAa,GAAG,CAAC,CAAC,CAAC;QACvB,MAAM,UAAU,GAAe,CAAC,EAAC,OAAO,EAAC,EAAE,EAAE;YAC3C,aAAa,GAAI,OAAuB,CAAC,EAAE,CAAC;YAC5C,OAAO,EAAC,MAAM,EAAE,EAAE,EAAE,KAAK,EAAE,EAAE,EAAE,KAAK,EAAE,SAAS,EAAC,CAAC;QACnD,CAAC,CAAC;QACF,EAAE,CAAC,cAAc,CACb,EAAC,UAAU,EAAE,UAAU,EAAE,WAAW,EAAE,UAAU,EAAE,UAAU,EAAC,CAAC,CAAC;QACnE,EAAE,CAAC,cAAc,CACb,EAAC,UAAU,EAAE,UAAU,EAAE,WAAW,EAAE,UAAU,EAAE,UAAU,EAAC,CAAC,CAAC;QAEnE,mCAAmC;QACnC,MAAM,CAAC,aAAa,CAAC,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC,CAAC;QAE/B,4CAA4C;QAC5C,MAAM,EAAE,CAAC,UAAU,CAAC,UAAU,CAAC,CAAC;QAChC,EAAE,CAAC,MAAM,EAAE,CAAC,SAAS,CAAC,UAAU,EAAE,EAAE,EAAE,EAAE,CAAC,CAAC;QAC1C,MAAM,CAAC,aAAa,CAAC,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC;QAE9B,6CAA6C;QAC7C,MAAM,EAAE,CAAC,UAAU,CAAC,UAAU,CAAC,CAAC;QACh
C,EAAE,CAAC,MAAM,EAAE,CAAC,SAAS,CAAC,UAAU,EAAE,EAAE,EAAE,EAAE,CAAC,CAAC;QAC1C,MAAM,CAAC,aAAa,CAAC,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC;QAE9B,EAAE,CAAC,aAAa,CAAC,UAAU,CAAC,CAAC;QAC7B,EAAE,CAAC,aAAa,CAAC,UAAU,CAAC,CAAC;QAC7B,EAAE,CAAC,gBAAgB,CAAC,UAAU,EAAE,UAAU,CAAC,CAAC;QAC5C,EAAE,CAAC,gBAAgB,CAAC,UAAU,EAAE,UAAU,CAAC,CAAC;IAC9C,CAAC,CAAC,CAAC;IAEH,EAAE,CAAC,kDAAkD,EAAE,KAAK,IAAI,EAAE;QAChE,MAAM,WAAW,GAAG,gBAAgB,CAAC;QACrC,MAAM,UAAU,GAAG,UAAU,CAAC;QAE9B,MAAM,aAAa,GAAG;YACpB,OAAO,EAAE,GAAG,EAAE,CAAC,IAAI;YACnB,WAAW,EAAE,CAAC,MAAU,EAAE,EAAE,CAAC,IAAI;YACjC,UAAU,EAAE,GAAG,EAAE,CAAC,CAAC;SACL,CAAC;QACjB,EAAE,CAAC,eAAe,CAAC,WAAW,EAAE,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC;QAErD,MAAM,UAAU,GAAe,GAAG,EAAE;YAClC,OAAO,EAAC,MAAM,EAAE,EAAE,EAAE,KAAK,EAAE,EAAE,EAAE,KAAK,EAAE,SAAS,EAAC,CAAC;QACnD,CAAC,CAAC;QACF,IAAI,WAAW,GAAG,KAAK,CAAC;QACxB,MAAM,SAAS,GAAG,CAAC,OAAsB,EAAE,EAAE;YAC3C,MAAM,CAAC,OAAO,CAAC,CAAC,IAAI,CAAC,aAAa,CAAC,CAAC;YACpC,WAAW,GAAG,IAAI,CAAC;QACrB,CAAC,CAAC;QACF,IAAI,aAAa,GAAG,KAAK,CAAC;QAC1B,MAAM,WAAW,GAAG,CAAC,OAAsB,EAAE,EAAE;YAC7C,MAAM,CAAC,OAAO,CAAC,CAAC,IAAI,CAAC,aAAa,CAAC,CAAC;YACpC,aAAa,GAAG,IAAI,CAAC;QACvB,CAAC,CAAC;QACF,EAAE,CAAC,cAAc,CACb,EAAC,UAAU,EAAE,WAAW,EAAE,UAAU,EAAE,SAAS,EAAE,WAAW,EAAC,CAAC,CAAC;QAEnE,MAAM,CAAC,WAAW,CAAC,CAAC,IAAI,CAAC,KAAK,CAAC,CAAC;QAChC,MAAM,CAAC,aAAa,CAAC,CAAC,IAAI,CAAC,KAAK,CAAC,CAAC;QAElC,MAAM,EAAE,CAAC,UAAU,CAAC,WAAW,CAAC,CAAC;QACjC,MAAM,CAAC,WAAW,CAAC,CAAC,IAAI,CAAC,IAAI,CAAC,CAAC;QAC/B,MAAM,CAAC,aAAa,CAAC,CAAC,IAAI,CAAC,KAAK,CAAC,CAAC;QAElC,4CAA4C;QAC5C,EAAE,CAAC,MAAM,EAAE,CAAC,SAAS,CAAC,UAAU,EAAE,EAAE,EAAE,EAAE,CAAC,CAAC;QAE1C,EAAE,CAAC,aAAa,CAAC,WAAW,CAAC,CAAC;QAC9B,MAAM,CAAC,WAAW,CAAC,CAAC,IAAI,CAAC,IAAI,CAAC,CAAC;QAC/B,MAAM,CAAC,aAAa,CAAC,CAAC,IAAI,CAAC,IAAI,CAAC,CAAC;QAEjC,EAAE,CAAC,gBAAgB,CAAC,UAAU,EAAE,WAAW,CAAC,CAAC;IAC/C,CAAC,CAAC,CAAC;AACL,CAAC,CAAC,CAAC;AAEH,iBAAiB,CAAC,mBAAmB,EAAE,QAAQ,EAAE,GAAG,EAAE;IACpD,EAAE,CAAC,6CAA6C,EAAE,KAAK,IAAI,EAAE;QAC3D,IAAI,eAAe,GAAG,KAAK,CAAC;QAC5B,IAAI,iBAAiB,GAAG,KAAK,CAAC;QAC9B,MAAM,UAAU,GAAG,U
AAU,CAAC;QAC9B,MAAM,CAAC,GAAG,EAAE,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QAE3B,EAAE,CAAC,cAAc,CAAC;YAChB,UAAU;YACV,WAAW,EAAE,EAAE,CAAC,UAAU,EAAE;YAC5B,UAAU,EAAE,GAAG,EAAE;gBACf,eAAe,GAAG,IAAI,CAAC;gBACvB,OAAO,EAAC,KAAK,EAAE,SAAS,EAAE,KAAK,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,MAAM,EAAE,EAAE,EAAC,CAAC;YACvD,CAAC;SACF,CAAC,CAAC;QAEH,EAAE,CAAC,gBAAgB,CAAC;YAClB,UAAU;YACV,YAAY,EAAE,CAAC,GAAG,CAAC;YACnB,QAAQ,EAAE,CAAC,EAAa,EAAE,KAAK,EAAE,EAAE;gBACjC,iEAAiE;gBACjE,MAAM,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,MAAM,CAAC,CAAC,OAAO,CAAC,CAAC,CAAC,MAAM,CAAC,CAAC;gBAC1C,gDAAgD;gBAChD,MAAM,CAAC,EAAE,CAAC,KAAK,CAAC,CAAC,OAAO,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;gBACjC,iBAAiB,GAAG,IAAI,CAAC;gBACzB,OAAO,EAAC,CAAC,EAAE,GAAG,EAAE,CAAC,EAAE,CAAC,IAAI,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,EAAC,CAAC;YACvC,CAAC;SACF,CAAC,CAAC;QAEH,MAAM,QAAQ,GACV,EAAE,CAAC,IAAI,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,MAAM,EAAE,CAAC,SAAS,CAAC,UAAU,EAAE,EAAC,CAAC,EAAC,EAAE,EAAE,CAAC,WAAW,CAAC,CAAC,CAAC;QACzE,MAAM,EAAE,GAAG,QAAQ,CAAC,CAAC,CAAC,CAAC;QACvB,MAAM,CAAC,eAAe,CAAC,CAAC,IAAI,CAAC,IAAI,CAAC,CAAC;QACnC,MAAM,CAAC,iBAAiB,CAAC,CAAC,IAAI,CAAC,IAAI,CAAC,CAAC;QACrC,MAAM,CAAC,EAAE,CAAC,KAAK,CAAC,CAAC,IAAI,CAAC,SAAS,CAAC,CAAC;QACjC,MAAM,CAAC,EAAE,CAAC,KAAK,CAAC,CAAC,OAAO,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QACjC,iBAAiB,CAAC,MAAM,EAAE,CAAC,IAAI,EAAE,EAAE,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QACjD,EAAE,CAAC,gBAAgB,CAAC,UAAU,EAAE,EAAE,CAAC,UAAU,EAAE,CAAC,CAAC;QACjD,EAAE,CAAC,kBAAkB,CAAC,UAAU,CAAC,CAAC;IACpC,CAAC,CAAC,CAAC;IAEH,EAAE,CAAC,0EAA0E,EAC1E,KAAK,IAAI,EAAE;QACT,IAAI,eAAe,GAAG,KAAK,CAAC;QAC5B,IAAI,iBAAiB,GAAG,KAAK,CAAC;QAC9B,MAAM,UAAU,GAAG,UAAU,CAAC;QAE9B,MAAM,MAAM,GAAG,EAAE,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,SAAS,CAAC,CAAC;QAC3C,MAAM,mBAAmB,GAAG,MAAM,CAAC,MAAM,CAAC;QAC1C,EAAE,CAAC,cAAc,CAAC;YAChB,UAAU;YACV,WAAW,EAAE,EAAE,CAAC,UAAU,EAAE;YAC5B,UAAU,EAAE,GAAG,EAAE;gBACf,eAAe,GAAG,IAAI,CAAC;gBACvB,OAAO;oBACL,KAAK,EAAE,MAAM,CAAC,KAAK;oBACnB,KAAK,EAAE,MAAM,CAAC,KAAK;oBACnB,M
AAM,EAAE,mBAAmB;iBAC5B,CAAC;YACJ,CAAC;SACF,CAAC,CAAC;QAEH,EAAE,CAAC,gBAAgB,CAAC;YAClB,UAAU;YACV,aAAa,EAAE,CAAC,IAAI,CAAC;YACrB,QAAQ,EAAE,CAAC,EAAa,EAAE,KAAK,EAAE,EAAE;gBACjC,8DAA8D;gBAC9D,MAAM,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,MAAM,CAAC,CAAC,OAAO,CAAC,mBAAmB,CAAC,CAAC;gBACrD,gDAAgD;gBAChD,MAAM,CAAC,EAAE,CAAC,KAAK,CAAC,CAAC,OAAO,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;gBACjC,iBAAiB,GAAG,IAAI,CAAC;gBACzB,OAAO,EAAC,CAAC,EAAE,GAAG,EAAE,CAAC,EAAE,CAAC,IAAI,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,EAAC,CAAC;YACvC,CAAC;SACF,CAAC,CAAC;QAEH,MAAM,QAAQ,GAAG,EAAE,CAAC,IAAI,CACpB,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,MAAM,EAAE,CAAC,SAAS,CACtB,UAAU,EAAE,EAAC,CAAC,EAAC,EAAE,EAAE,CAAC,WAAW,CAC9B,CAAC,CAAC;QACX,MAAM,CAAC,GAAG,EAAE,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QAC3B,MAAM,EAAE,GAAG,QAAQ,CAAC,CAAC,CAAC,CAAC;QACvB,MAAM,CAAC,eAAe,CAAC,CAAC,IAAI,CAAC,IAAI,CAAC,CAAC;QACnC,MAAM,CAAC,iBAAiB,CAAC,CAAC,IAAI,CAAC,IAAI,CAAC,CAAC;QACrC,MAAM,CAAC,EAAE,CAAC,KAAK,CAAC,CAAC,IAAI,CAAC,SAAS,CAAC,CAAC;QACjC,MAAM,CAAC,EAAE,CAAC,KAAK,CAAC,CAAC,OAAO,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QACjC,EAAE,CAAC,gBAAgB,CAAC,UAAU,EAAE,EAAE,CAAC,UAAU,EAAE,CAAC,CAAC;QACjD,EAAE,CAAC,kBAAkB,CAAC,UAAU,CAAC,CAAC;IACpC,CAAC,CAAC,CAAC;IAEN,EAAE,CAAC,4DAA4D,EAAE,KAAK,IAAI,EAAE;QAC1E,IAAI,eAAe,GAAG,KAAK,CAAC;QAC5B,IAAI,iBAAiB,GAAG,KAAK,CAAC;QAC9B,MAAM,UAAU,GAAG,UAAU,CAAC;QAC9B,MAAM,CAAC,GAAG,CAAC,EAAE,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,EAAE,EAAE,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC,CAAC;QAE/C,MAAM,mBAAmB,GAAG,EAAE,CAAC;QAC/B,EAAE,CAAC,cAAc,CAAC;YAChB,UAAU;YACV,WAAW,EAAE,EAAE,CAAC,UAAU,EAAE;YAC5B,UAAU,EAAE,GAAG,EAAE;gBACf,eAAe,GAAG,IAAI,CAAC;gBACvB,OAAO,EAAC,KAAK,EAAE,SAAS,EAAE,KAAK,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,MAAM,EAAE,mBAAmB,EAAC,CAAC;YACxE,CAAC;SACF,CAAC,CAAC;QAEH,EAAE,CAAC,gBAAgB,CAAC;YAClB,UAAU;YACV,aAAa,EAAE,IAAI;YACnB,QAAQ,EAAE,CAAC,EAAa,EAAE,KAAK,EAAE,EAAE;gBACjC,iEAAiE;gBACjE,MAAM,CAAC,GAAG,EAAE,GAAG,CAAC,GAAG,CAAC,CAAC;gBACrB,MAAM,CAAC,KAAK,CAAC,MAAM,CAAC,CAAC,OAAO,CAAC,CAA
C,CAAC,MAAM,CAAC,CAAC;gBACvC,MAAM,CAAC,GAAG,CAAC,MAAM,CAAC,CAAC,OAAO,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,MAAM,CAAC,CAAC;gBACxC,MAAM,CAAC,GAAG,CAAC,MAAM,CAAC,CAAC,OAAO,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,MAAM,CAAC,CAAC;gBACxC,iBAAiB,GAAG,IAAI,CAAC;gBACzB,OAAO,EAAC,CAAC,EAAE,GAAG,EAAE,CAAC,EAAE,CAAC,IAAI,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,EAAE,GAAG,EAAE,CAAC,EAAE,CAAC,IAAI,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,EAAC,CAAC;YACpE,CAAC;SACF,CAAC,CAAC;QAEH,mBAAmB;QACnB,MAAM,CAAC,GAAG,CAAC,GAAG,CAAc,EAAE,EAAE;QAC5B,0DAA0D;QAC1D,EAAE,CAAC,MAAM,EAAE,CAAC,SAAS,CACjB,UAAU,EAAE,CAAiC,EAAE,EAAE,CAAC,WAAW,CACxD,CAAC;QACd,MAAM,QAAQ,GAAG,EAAE,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC;QAC7B,MAAM,EAAE,GAAG,QAAQ,CAAC,CAAC,CAAC,CAAC;QACvB,MAAM,CAAC,eAAe,CAAC,CAAC,IAAI,CAAC,IAAI,CAAC,CAAC;QACnC,MAAM,CAAC,iBAAiB,CAAC,CAAC,IAAI,CAAC,IAAI,CAAC,CAAC;QACrC,MAAM,CAAC,EAAE,CAAC,MAAM,CAAC,CAAC,OAAO,CAAC,CAAC,CAAC,CAAC;QAC7B,MAAM,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC,KAAK,CAAC,CAAC,IAAI,CAAC,SAAS,CAAC,CAAC;QACpC,MAAM,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC,KAAK,CAAC,CAAC,OAAO,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QACpC,MAAM,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC,KAAK,CAAC,CAAC,IAAI,CAAC,SAAS,CAAC,CAAC;QACpC,MAAM,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC,KAAK,CAAC,CAAC,OAAO,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QACpC,iBAAiB,CAAC,MAAM,EAAE,CAAC,CAAC,CAAC,CAAC,IAAI,EAAE,EAAE,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QACpD,iBAAiB,CAAC,MAAM,EAAE,CAAC,CAAC,CAAC,CAAC,IAAI,EAAE,EAAE,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QACpD,EAAE,CAAC,gBAAgB,CAAC,UAAU,EAAE,EAAE,CAAC,UAAU,EAAE,CAAC,CAAC;QACjD,EAAE,CAAC,kBAAkB,CAAC,UAAU,CAAC,CAAC;IACpC,CAAC,CAAC,CAAC;IAEH,EAAE,CAAC,wEAAwE;QACpE,OAAO,EACX,KAAK,IAAI,EAAE;QACT,MAAM,UAAU,GAAG,UAAU,CAAC;QAC9B,MAAM,EAAE,GAAG,EAAE,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QAC5B,MAAM,EAAE,GAAG,EAAE,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QAE5B,MAAM,mBAAmB,GAAG,EAAE,CAAC;QAC/B,EAAE,CAAC,cAAc,CAAC;YAChB,UAAU;YACV,WAAW,EAAE,EAAE,CAAC,UAAU,EAAE;YAC5B,UAAU,EAAE,GAAG,EAA
E;gBACf,OAAO;oBACL,KAAK,EAAE,SAAS;oBAChB,KAAK,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC;oBACb,MAAM,EAAE,mBAAmB;iBAC5B,CAAC;YACJ,CAAC;SACF,CAAC,CAAC;QAEH,EAAE,CAAC,gBAAgB,CAAC;YAClB,UAAU;YACV,aAAa,EAAE,IAAI;YACnB,QAAQ,EAAE,CAAC,EAAa,EAAE,KAAK,EAAE,EAAE;gBACjC,iEAAiE;gBACjE,MAAM,CAAC,GAAG,EAAE,GAAG,CAAC,GAAG,KAAK,CAAC;gBACzB,MAAM,CAAC,GAAG,CAAC,MAAM,CAAC,CAAC,OAAO,CAAC,EAAE,CAAC,MAAM,CAAC,CAAC;gBACtC,MAAM,CAAC,GAAG,CAAC,MAAM,CAAC,CAAC,OAAO,CAAC,EAAE,CAAC,MAAM,CAAC,CAAC;gBACtC,OAAO,EAAC,EAAE,EAAE,GAAG,EAAE,CAAC,EAAE,CAAC,IAAI,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,EAAE,EAAE,GAAG,EAAE,CAAC,EAAE,CAAC,IAAI,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,EAAC,CAAC;YACtE,CAAC;SACF,CAAC,CAAC;QAEH,iBAAiB;QACjB,MAAM,CAAC,GAAG,CAAC,EAAa,EAAE,EAAa,EAAE,EAAE;QACvC,0DAA0D;QAC1D,EAAE,CAAC,MAAM,EAAE,CAAC,SAAS,CAAC,UAAU,EAAE,EAAC,EAAE,EAAE,EAAE,EAAC,EAAE,EAAE,CAAC,WAAW,CACjD,CAAC;QACd,MAAM,QAAQ,GAAG,EAAE,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC;QAC7B,MAAM,CAAC,GAAG,EAAE,CAAC,QAAQ,CAAC,CAAC,EAAE,EAAE,EAAE,CAAC,CAAC,CAAC;aAC3B,YAAY,CACT,uDAAuD,CAAC,CAAC;QACjE,EAAE,CAAC,gBAAgB,CAAC,UAAU,EAAE,EAAE,CAAC,UAAU,EAAE,CAAC,CAAC;QACjD,EAAE,CAAC,kBAAkB,CAAC,UAAU,CAAC,CAAC;IACpC,CAAC,CAAC,CAAC;IAEN,EAAE,CAAC,2CAA2C,EAAE,GAAG,EAAE;QACnD,MAAM,UAAU,GAAG,UAAU,CAAC;QAC9B,MAAM,CAAC,GAAG,EAAE,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QAE3B,EAAE,CAAC,cAAc,CAAC;YAChB,UAAU;YACV,WAAW,EAAE,EAAE,CAAC,UAAU,EAAE;YAC5B,UAAU,EAAE,GAAG,EAAE,CAAC,CAAC,EAAC,KAAK,EAAE,SAAS,EAAE,KAAK,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,MAAM,EAAE,EAAE,EAAC,CAAC;SAClE,CAAC,CAAC;QAEH,MAAM,QAAQ,GACV,EAAE,CAAC,IAAI,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,MAAM,EAAE,CAAC,SAAS,CAAC,UAAU,EAAE,EAAC,CAAC,EAAC,EAAE,EAAE,CAAC,WAAW,CAAC,CAAC,CAAC;QACzE,MAAM,CAAC,GAAG,EAAE,CAAC,QAAQ,CAAC,CAAC,CAAC,CAAC;aACpB,YAAY,CAAC,0CAA0C,CAAC,CAAC;QAE9D,EAAE,CAAC,gBAAgB,CAAC,UAAU,EAAE,EAAE,CAAC,UAAU,EAAE,CAAC,CAAC;IACnD,CAAC,CAAC,CAAC;IAEH,gCAAgC;IAChC,GAAG,CAAC,kDAAkD,EAAE,GAAG,EAAE;QAC3D,MAAM,UAAU,GAAG,UAAU,CAAC;QAC9B,EAAE,CAAC,gBAAgB,CAAC,EAAC,UAAU,EAAE,QAAQ,EAAE,GAAG,EAA
E,CAAC,IAAI,EAAC,CAAC,CAAC;QACxD,KAAK,CAAC,OAAO,EAAE,MAAM,CAAC,CAAC,GAAG,CAAC,QAAQ,CAAC,CAAC,GAAW,EAAE,EAAE;YAClD,MAAM,CAAC,GAAG,CAAC,CAAC,IAAI,CAAC,0CAA0C,CAAC,CAAC;QAC/D,CAAC,CAAC,CAAC;QACH,EAAE,CAAC,gBAAgB,CAAC,EAAC,UAAU,EAAE,QAAQ,EAAE,GAAG,EAAE,CAAC,IAAI,EAAC,CAAC,CAAC;QACxD,EAAE,CAAC,kBAAkB,CAAC,UAAU,CAAC,CAAC;IACpC,CAAC,CAAC,CAAC;AACL,CAAC,CAAC,CAAC","sourcesContent":["/**\n * @license\n * Copyright 2019 Google LLC. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport * as tf from './index';\nimport {KernelBackend} from './index';\nimport {ALL_ENVS, describeWithFlags, TestEnv} from './jasmine_util';\nimport { KernelFunc } from './kernel_registry';\nimport { TensorInfo } from './tensor_info';\nimport {expectArraysClose} from './test_util';\n\ndescribeWithFlags('kernel_registry', ALL_ENVS, (testEnv: TestEnv) => {\n  afterEach(async () => {\n    // Revert backend mutation.\n    await tf.setBackend(testEnv.backendName);\n  });\n\n  it('register a kernel and call it', () => {\n    // Make sure the backend is loaded. 
Perhaps tf.getBackend\n    // should call tf.backend to make sure the backend is loaded?\n    expect(tf.backend()).toBeDefined();\n    let called = false;\n    tf.registerKernel({\n      kernelName: 'MyKernel',\n      backendName: tf.getBackend(),\n      kernelFunc: ({inputs, attrs}) => {\n        expect(attrs.a).toBe(5);\n        expect(inputs.x.shape).toEqual([2, 2]);\n        expect(inputs.x.dtype).toBe('float32');\n        called = true;\n        return {dtype: 'float32', shape: [3, 3], dataId: {}};\n      }\n    });\n\n    const inputs = {x: tf.zeros([2, 2])};\n    const attrs = {a: 5};\n    const res = tf.engine().runKernel('MyKernel', inputs, attrs) as TensorInfo;\n\n    expect(called).toBe(true);\n    expect(res.dtype).toBe('float32');\n    expect(res.shape).toEqual([3, 3]);\n\n    tf.unregisterKernel('MyKernel', tf.getBackend());\n  });\n\n  it('errors when running non-existent kernel', () => {\n    const inputs = {};\n    const attrs = {};\n    expect(() => tf.engine().runKernel('DoesNotExist', inputs, attrs))\n        .toThrowError();\n  });\n\n  // TODO (yassogba) double registration happens now because a backend might be\n  // imported more than once (e.g. by a top level package and a dependent\n  // package). 
We may want to remove this test long-term but skip it for\n  // now.\n  // tslint:disable-next-line: ban\n  xit('errors when registering the same kernel twice', () => {\n    interface TestBackend extends KernelBackend {\n      id: number;\n    }\n    tf.registerBackend('backend1', () => {\n      return {\n        id: 1,\n        dispose: () => null,\n        disposeData: (dataId: {}) => null,\n        numDataIds: () => 0\n      } as TestBackend;\n    });\n\n    tf.registerKernel({\n      kernelName: 'MyKernel',\n      backendName: 'backend1',\n      kernelFunc: () => {\n        return null;\n      }\n    });\n    expect(() => tf.registerKernel({\n      kernelName: 'MyKernel',\n      backendName: 'backend1',\n      kernelFunc: () => {\n        return null;\n      }\n    })).toThrowError();\n\n    tf.unregisterKernel('MyKernel', 'backend1');\n    tf.removeBackend('backend1');\n  });\n\n  it('register same kernel on two different backends', async () => {\n    interface TestBackend extends KernelBackend {\n      id: number;\n    }\n    tf.registerBackend('backend1', () => {\n      return {\n        id: 1,\n        dispose: () => null,\n        disposeData: (dataId: {}) => true,\n        numDataIds: () => 0\n      } as TestBackend;\n    });\n    tf.registerBackend('backend2', () => {\n      return {\n        id: 2,\n        dispose: () => null,\n        disposeData: (dataId: {}) => null,\n        numDataIds: () => 0\n      } as TestBackend;\n    });\n\n    let lastStorageId = -1;\n    const kernelFunc: KernelFunc = ({backend}) => {\n      lastStorageId = (backend as TestBackend).id;\n      return {dataId: {}, shape: [], dtype: 'float32'};\n    };\n    tf.registerKernel(\n        {kernelName: 'MyKernel', backendName: 'backend1', kernelFunc});\n    tf.registerKernel(\n        {kernelName: 'MyKernel', backendName: 'backend2', kernelFunc});\n\n    // No kernel has been executed yet.\n    expect(lastStorageId).toBe(-1);\n\n    // Kernel was executed on the first backend.\n   
 await tf.setBackend('backend1');\n    tf.engine().runKernel('MyKernel', {}, {});\n    expect(lastStorageId).toBe(1);\n\n    // Kernel was executed on the second backend.\n    await tf.setBackend('backend2');\n    tf.engine().runKernel('MyKernel', {}, {});\n    expect(lastStorageId).toBe(2);\n\n    tf.removeBackend('backend1');\n    tf.removeBackend('backend2');\n    tf.unregisterKernel('MyKernel', 'backend1');\n    tf.unregisterKernel('MyKernel', 'backend2');\n  });\n\n  it('register kernel with setup and dispose functions', async () => {\n    const backendName = 'custom-backend';\n    const kernelName = 'MyKernel';\n    interface TestBackend extends KernelBackend {}\n    const customBackend = {\n      dispose: () => null,\n      disposeData: (dataId: {}) => true,\n      numDataIds: () => 0\n    } as TestBackend;\n    tf.registerBackend(backendName, () => customBackend);\n\n    const kernelFunc: KernelFunc = () => {\n      return {dataId: {}, shape: [], dtype: 'float32'};\n    };\n    let setupCalled = false;\n    const setupFunc = (backend: KernelBackend) => {\n      expect(backend).toBe(customBackend);\n      setupCalled = true;\n    };\n    let disposeCalled = false;\n    const disposeFunc = (backend: KernelBackend) => {\n      expect(backend).toBe(customBackend);\n      disposeCalled = true;\n    };\n    tf.registerKernel(\n        {kernelName, backendName, kernelFunc, setupFunc, disposeFunc});\n\n    expect(setupCalled).toBe(false);\n    expect(disposeCalled).toBe(false);\n\n    await tf.setBackend(backendName);\n    expect(setupCalled).toBe(true);\n    expect(disposeCalled).toBe(false);\n\n    // Kernel was executed on the first backend.\n    tf.engine().runKernel(kernelName, {}, {});\n\n    tf.removeBackend(backendName);\n    expect(setupCalled).toBe(true);\n    expect(disposeCalled).toBe(true);\n\n    tf.unregisterKernel(kernelName, backendName);\n  });\n});\n\ndescribeWithFlags('gradient registry', ALL_ENVS, () => {\n  it('register a kernel with gradient 
and call it', async () => {\n    let kernelWasCalled = false;\n    let gradientWasCalled = false;\n    const kernelName = 'MyKernel';\n    const x = tf.zeros([2, 2]);\n\n    tf.registerKernel({\n      kernelName,\n      backendName: tf.getBackend(),\n      kernelFunc: () => {\n        kernelWasCalled = true;\n        return {dtype: 'float32', shape: [3, 3], dataId: {}};\n      }\n    });\n\n    tf.registerGradient({\n      kernelName,\n      inputsToSave: ['x'],\n      gradFunc: (dy: tf.Tensor, saved) => {\n        // Make sure saved input (x) was passed to the gradient function.\n        expect(saved[0].dataId).toEqual(x.dataId);\n        // Make sure dy matches the shape of the output.\n        expect(dy.shape).toEqual([3, 3]);\n        gradientWasCalled = true;\n        return {x: () => tf.fill([2, 2], 3)};\n      },\n    });\n\n    const gradFunc =\n        tf.grad(x => tf.engine().runKernel(kernelName, {x}, {} /* attrs */));\n    const dx = gradFunc(x);\n    expect(kernelWasCalled).toBe(true);\n    expect(gradientWasCalled).toBe(true);\n    expect(dx.dtype).toBe('float32');\n    expect(dx.shape).toEqual([2, 2]);\n    expectArraysClose(await dx.data(), [3, 3, 3, 3]);\n    tf.unregisterKernel(kernelName, tf.getBackend());\n    tf.unregisterGradient(kernelName);\n  });\n\n  it('register a kernel with gradient that specifies outputsToSave and call it',\n     async () => {\n       let kernelWasCalled = false;\n       let gradientWasCalled = false;\n       const kernelName = 'MyKernel';\n\n       const tensor = tf.zeros([3, 3], 'float32');\n       const forwardReturnDataId = tensor.dataId;\n       tf.registerKernel({\n         kernelName,\n         backendName: tf.getBackend(),\n         kernelFunc: () => {\n           kernelWasCalled = true;\n           return {\n             dtype: tensor.dtype,\n             shape: tensor.shape,\n             dataId: forwardReturnDataId\n           };\n         }\n       });\n\n       tf.registerGradient({\n         kernelName,\n 
        outputsToSave: [true],\n         gradFunc: (dy: tf.Tensor, saved) => {\n           // Make sure saved output was passed to the gradient function.\n           expect(saved[0].dataId).toEqual(forwardReturnDataId);\n           // Make sure dy matches the shape of the output.\n           expect(dy.shape).toEqual([3, 3]);\n           gradientWasCalled = true;\n           return {x: () => tf.fill([2, 2], 3)};\n         },\n       });\n\n       const gradFunc = tf.grad(\n           x => tf.engine().runKernel(\n               kernelName, {x}, {} /* attrs */\n               ));\n       const x = tf.zeros([2, 2]);\n       const dx = gradFunc(x);\n       expect(kernelWasCalled).toBe(true);\n       expect(gradientWasCalled).toBe(true);\n       expect(dx.dtype).toBe('float32');\n       expect(dx.shape).toEqual([2, 2]);\n       tf.unregisterKernel(kernelName, tf.getBackend());\n       tf.unregisterGradient(kernelName);\n     });\n\n  it('register a kernel with array inputs and saveAllInputs true', async () => {\n    let kernelWasCalled = false;\n    let gradientWasCalled = false;\n    const kernelName = 'MyKernel';\n    const x = [tf.zeros([2, 2]), tf.zeros([2, 2])];\n\n    const forwardReturnDataId = {};\n    tf.registerKernel({\n      kernelName,\n      backendName: tf.getBackend(),\n      kernelFunc: () => {\n        kernelWasCalled = true;\n        return {dtype: 'float32', shape: [3, 3], dataId: forwardReturnDataId};\n      }\n    });\n\n    tf.registerGradient({\n      kernelName,\n      saveAllInputs: true,\n      gradFunc: (dy: tf.Tensor, saved) => {\n        // Make sure saved input (x) was passed to the gradient function.\n        const [$x0, $x1] = x;\n        expect(saved.length).toEqual(x.length);\n        expect($x0.dataId).toEqual(x[0].dataId);\n        expect($x1.dataId).toEqual(x[1].dataId);\n        gradientWasCalled = true;\n        return {0: () => tf.fill([2, 2], 3), 1: () => tf.fill([2, 2], 3)};\n      }\n    });\n\n    // Inputs as array.\n    
const z = (...x: tf.Tensor[]) =>\n        // tslint:disable-next-line: no-unnecessary-type-assertion\n        tf.engine().runKernel(\n            kernelName, x as unknown as tf.NamedTensorMap, {} /* attrs */) as\n        tf.Tensor;\n    const gradFunc = tf.grads(z);\n    const dx = gradFunc(x);\n    expect(kernelWasCalled).toBe(true);\n    expect(gradientWasCalled).toBe(true);\n    expect(dx.length).toEqual(2);\n    expect(dx[0].dtype).toBe('float32');\n    expect(dx[0].shape).toEqual([2, 2]);\n    expect(dx[1].dtype).toBe('float32');\n    expect(dx[1].shape).toEqual([2, 2]);\n    expectArraysClose(await dx[0].data(), [3, 3, 3, 3]);\n    expectArraysClose(await dx[1].data(), [3, 3, 3, 3]);\n    tf.unregisterKernel(kernelName, tf.getBackend());\n    tf.unregisterGradient(kernelName);\n  });\n\n  it('register a kernel with map inputs and saveAllInputs true should throw ' +\n         'error',\n     async () => {\n       const kernelName = 'MyKernel';\n       const x0 = tf.zeros([2, 2]);\n       const x1 = tf.zeros([2, 2]);\n\n       const forwardReturnDataId = {};\n       tf.registerKernel({\n         kernelName,\n         backendName: tf.getBackend(),\n         kernelFunc: () => {\n           return {\n             dtype: 'float32',\n             shape: [3, 3],\n             dataId: forwardReturnDataId\n           };\n         }\n       });\n\n       tf.registerGradient({\n         kernelName,\n         saveAllInputs: true,\n         gradFunc: (dy: tf.Tensor, saved) => {\n           // Make sure saved input (x) was passed to the gradient function.\n           const [$x0, $x1] = saved;\n           expect($x0.dataId).toEqual(x0.dataId);\n           expect($x1.dataId).toEqual(x1.dataId);\n           return {x0: () => tf.fill([2, 2], 3), x1: () => tf.fill([2, 2], 3)};\n         }\n       });\n\n       // Inputs as map.\n       const z = (x0: tf.Tensor, x1: tf.Tensor) =>\n           // tslint:disable-next-line: no-unnecessary-type-assertion\n           
tf.engine().runKernel(kernelName, {x0, x1}, {} /* attrs */) as\n           tf.Tensor;\n       const gradFunc = tf.grads(z);\n       expect(() => gradFunc([x0, x1]))\n           .toThrowError(\n               /saveAllInputs is true, expected inputs to be an array/);\n       tf.unregisterKernel(kernelName, tf.getBackend());\n       tf.unregisterGradient(kernelName);\n     });\n\n  it('errors when running non-existent gradient', () => {\n    const kernelName = 'MyKernel';\n    const x = tf.zeros([2, 2]);\n\n    tf.registerKernel({\n      kernelName,\n      backendName: tf.getBackend(),\n      kernelFunc: () => ({dtype: 'float32', shape: [3, 3], dataId: {}})\n    });\n\n    const gradFunc =\n        tf.grad(x => tf.engine().runKernel(kernelName, {x}, {} /* attrs */));\n    expect(() => gradFunc(x))\n        .toThrowError(/gradient function not found for MyKernel/);\n\n    tf.unregisterKernel(kernelName, tf.getBackend());\n  });\n\n  // tslint:disable-next-line: ban\n  xit('warning when registering the same gradient twice', () => {\n    const kernelName = 'MyKernel';\n    tf.registerGradient({kernelName, gradFunc: () => null});\n    spyOn(console, 'warn').and.callFake((msg: string) => {\n      expect(msg).toBe('Overriding the gradient for \\'MyKernel\\'');\n    });\n    tf.registerGradient({kernelName, gradFunc: () => null});\n    tf.unregisterGradient(kernelName);\n  });\n});\n"]}
|