gx
chenyc
2025-02-12 ea42ff3ebee1eeb3fb29423aa848a249441db81c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
/**
 * @license
 * Copyright 2018 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
// tslint:disable-next-line: no-imports-from-dist
import * as tfOps from '@tensorflow/tfjs-core/dist/ops/ops_for_converter';
import { cloneTensor, getParamValue, getTensor } from './utils';
/**
 * Executes one op from the 'graph' category: constants, placeholders,
 * identity-like ops, shape queries, NoOp and Print.
 *
 * @param node Graph node to run. `node.op` selects the implementation and
 *     `node.name` is used to look up constants and placeholders.
 * @param tensorMap Named tensor map. It is indexed directly by `node.name` for
 *     'Const' and passed through to `getParamValue` / `getTensor` otherwise.
 * @param context Execution context, passed through to the param helpers.
 * @param ops Op namespace used to create new tensors (`tensor1d`, `scalar`).
 *     Defaults to the tfjs-core converter ops; injectable, e.g. for tests.
 * @returns The node's output tensors. For 'Const' this is the list stored in
 *     `tensorMap` under `node.name`, returned as-is (not cloned).
 * @throws {TypeError} If `node.op` is not handled by this executor.
 */
export const executeOp = (node, tensorMap, context, ops = tfOps) => {
    switch (node.op) {
        case 'Const': {
            return tensorMap[node.name];
        }
        case 'PlaceholderWithDefault': {
            const def = getParamValue('default', node, tensorMap, context);
            // Use the tensor resolved for this node's name if there is one;
            // otherwise fall back to the 'default' input.
            return [getTensor(node.name, tensorMap, context) || def];
        }
        case 'Placeholder':
            return [getTensor(node.name, tensorMap, context)];
        case 'Identity':
        case 'StopGradient':
        case 'Snapshot':
        case 'FakeQuantWithMinMaxVars': {
            // All of these just forward a clone of 'x'. For
            // FakeQuantWithMinMaxVars the quantization itself is ignored.
            const data = getParamValue('x', node, tensorMap, context);
            return [cloneTensor(data)];
        }
        case 'IdentityN':
            return getParamValue('x', node, tensorMap, context)
                .map((t) => cloneTensor(t));
        case 'Shape':
            return [ops.tensor1d(getParamValue('x', node, tensorMap, context).shape, 'int32')];
        case 'ShapeN':
            // Emit int32 shapes, consistent with 'Shape' above.
            return getParamValue('x', node, tensorMap, context)
                .map((t) => ops.tensor1d(t.shape, 'int32'));
        case 'Size':
            return [ops.scalar(getParamValue('x', node, tensorMap, context).size, 'int32')];
        case 'Rank':
            return [ops.scalar(getParamValue('x', node, tensorMap, context).rank, 'int32')];
        case 'NoOp':
            return [ops.scalar(1)];
        case 'Print': {
            const input = getParamValue('x', node, tensorMap, context);
            const data = getParamValue('data', node, tensorMap, context);
            const message = getParamValue('message', node, tensorMap, context);
            const summarize = getParamValue('summarize', node, tensorMap, context);
            console.warn('The graph has a tf.print() operation, ' +
                'usually used for debugging, which slows down performance.');
            console.log(message);
            // Log at most `summarize` leading values of each data tensor
            // (all of them when `summarize` is undefined).
            for (const tensor of data) {
                console.log(Array.prototype.slice.call(tensor.dataSync())
                    .slice(0, summarize));
            }
            // Print is a pass-through: the 'x' input is the output.
            return [input];
        }
        default:
            throw new TypeError(`Node type ${node.op} is not implemented`);
    }
};
/**
 * Category label exported alongside `executeOp` in this module.
 * NOTE(review): presumably used by the caller to route 'graph' ops to this
 * executor — confirm against the dispatching code.
 */
export const CATEGORY = 'graph';
//# sourceMappingURL=data:application/json;base64,{"version":3,"file":"graph_executor.js","sourceRoot":"","sources":["../../../../../../../tfjs-converter/src/operations/executors/graph_executor.ts"],"names":[],"mappings":"AAAA;;;;;;;;;;;;;;;GAeG;AAGH,iDAAiD;AACjD,OAAO,KAAK,KAAK,MAAM,kDAAkD,CAAC;AAM1E,OAAO,EAAC,WAAW,EAAE,aAAa,EAAE,SAAS,EAAC,MAAM,SAAS,CAAC;AAE9D,MAAM,CAAC,MAAM,SAAS,GAClB,CAAC,IAAU,EAAE,SAA0B,EACtC,OAAyB,EAAE,GAAG,GAAG,KAAK,EAAY,EAAE;IACnD,QAAQ,IAAI,CAAC,EAAE,EAAE;QACf,KAAK,OAAO,CAAC,CAAC;YACZ,OAAO,SAAS,CAAC,IAAI,CAAC,IAAI,CAAC,CAAC;SAC7B;QACD,KAAK,wBAAwB;YAC3B,MAAM,GAAG,GACL,aAAa,CAAC,SAAS,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAW,CAAC;YACjE,OAAO,CAAC,SAAS,CAAC,IAAI,CAAC,IAAI,EAAE,SAAS,EAAE,OAAO,CAAC,IAAI,GAAG,CAAC,CAAC;QAC3D,KAAK,aAAa;YAChB,OAAO,CAAC,SAAS,CAAC,IAAI,CAAC,IAAI,EAAE,SAAS,EAAE,OAAO,CAAC,CAAC,CAAC;QACpD,KAAK,UAAU,CAAC;QAChB,KAAK,cAAc,CAAC;QACpB,KAAK,yBAAyB,CAAC,CAAC,EAAG,gCAAgC;YACjE,MAAM,IAAI,GAAG,aAAa,CAAC,GAAG,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAW,CAAC;YACpE,OAAO,CAAC,WAAW,CAAC,IAAI,CAAC,CAAC,CAAC;SAC5B;QACD,KAAK,WAAW;YACd,OAAQ,aAAa,CAAC,GAAG,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAc;iBAC5D,GAAG,CAAC,CAAC,CAAS,EAAE,EAAE,CAAC,WAAW,CAAC,CAAC,CAAC,CAAC,CAAC;QAC1C,KAAK,UAAU;YACb,MAAM,QAAQ,GACT,aAAa,CAAC,GAAG,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAY,CAAC;YAC7D,OAAO,CAAC,WAAW,CAAC,QAAQ,CAAC,CAAC,CAAC;QACjC,KAAK,OAAO;YACV,OAAO,CAAC,GAAG,CAAC,QAAQ,CACf,aAAa,CAAC,GAAG,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAY,CAAC,KAAK,EAC9D,OAAO,CAAC,CAAC,CAAC;QAChB,KAAK,QAAQ;YACX,OAAQ,aAAa,CAAC,GAAG,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAc;iBAC5D,GAAG,CAAC,CAAC,CAAS,EAAE,EAAE,CAAC,GAAG,CAAC,QAAQ,CAAC,CAAC,CAAC,KAAK,CAAC,CAAC,CAAC;QACjD,KAAK,MAAM;YACT,OAAO,CAAC,GAAG,CAAC,MAAM,CACb,aAAa,CAAC,GAAG,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAY,CAAC,IAAI,EAC7D,OAAO,CAAC,CAAC,CAAC;QAChB,KAAK,MAAM;YACT,OAAO,CAAC,GAAG,CAAC,MAAM,CACb,aAAa,CAAC,GAAG,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAY,CAAC,IAAI,EAC7D,OAAO,CAAC,CAAC,CAAC;QAChB,KAAK,MAAM;YACT,OAAO,CAAC,GAAG,CAAC,MAAM,CAAC,CAAC,CAAC,CAAC,CAAC;QACzB,KAAK,OAAO;YACV,MAAM,
KAAK,GAAG,aAAa,CAAC,GAAG,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAW,CAAC;YACrE,MAAM,IAAI,GACN,aAAa,CAAC,MAAM,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAa,CAAC;YAChE,MAAM,OAAO,GACT,aAAa,CAAC,SAAS,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAW,CAAC;YACjE,MAAM,SAAS,GACX,aAAa,CAAC,WAAW,EAAE,IAAI,EAAE,SAAS,EAAE,OAAO,CAAW,CAAC;YACnE,OAAO,CAAC,IAAI,CACR,uCAAuC;gBACvC,2DAA2D,CAAC,CAAC;YACjE,OAAO,CAAC,GAAG,CAAC,OAAO,CAAC,CAAC;YACrB,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,IAAI,CAAC,MAAM,EAAE,CAAC,EAAE,EAAE;gBACpC,OAAO,CAAC,GAAG,CAAC,KAAK,CAAC,SAAS,CAAC,KAAK,CAAC,IAAI,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC,QAAQ,EAAE,CAAC;qBACzC,KAAK,CAAC,CAAC,EAAE,SAAS,CAAC,CAAC,CAAC;aACvC;YACD,OAAO,CAAC,KAAK,CAAC,CAAC;QAEjB;YACE,MAAM,SAAS,CAAC,aAAa,IAAI,CAAC,EAAE,qBAAqB,CAAC,CAAC;KAC9D;AACH,CAAC,CAAC;AAEN,MAAM,CAAC,MAAM,QAAQ,GAAG,OAAO,CAAC","sourcesContent":["/**\n * @license\n * Copyright 2018 Google LLC. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {Tensor} from '@tensorflow/tfjs-core';\n// tslint:disable-next-line: no-imports-from-dist\nimport * as tfOps from '@tensorflow/tfjs-core/dist/ops/ops_for_converter';\n\nimport {NamedTensorsMap} from '../../data/types';\nimport {ExecutionContext} from '../../executor/execution_context';\nimport {InternalOpExecutor, Node} from '../types';\n\nimport {cloneTensor, getParamValue, getTensor} from './utils';\n\nexport const executeOp: InternalOpExecutor =\n    
(node: Node, tensorMap: NamedTensorsMap,\n     context: ExecutionContext, ops = tfOps): Tensor[] => {\n      switch (node.op) {\n        case 'Const': {\n          return tensorMap[node.name];\n        }\n        case 'PlaceholderWithDefault':\n          const def =\n              getParamValue('default', node, tensorMap, context) as Tensor;\n          return [getTensor(node.name, tensorMap, context) || def];\n        case 'Placeholder':\n          return [getTensor(node.name, tensorMap, context)];\n        case 'Identity':\n        case 'StopGradient':\n        case 'FakeQuantWithMinMaxVars': {  // This op is currently ignored.\n          const data = getParamValue('x', node, tensorMap, context) as Tensor;\n          return [cloneTensor(data)];\n        }\n        case 'IdentityN':\n          return (getParamValue('x', node, tensorMap, context) as Tensor[])\n              .map((t: Tensor) => cloneTensor(t));\n        case 'Snapshot':\n          const snapshot =\n              (getParamValue('x', node, tensorMap, context) as Tensor);\n          return [cloneTensor(snapshot)];\n        case 'Shape':\n          return [ops.tensor1d(\n              (getParamValue('x', node, tensorMap, context) as Tensor).shape,\n              'int32')];\n        case 'ShapeN':\n          return (getParamValue('x', node, tensorMap, context) as Tensor[])\n              .map((t: Tensor) => ops.tensor1d(t.shape));\n        case 'Size':\n          return [ops.scalar(\n              (getParamValue('x', node, tensorMap, context) as Tensor).size,\n              'int32')];\n        case 'Rank':\n          return [ops.scalar(\n              (getParamValue('x', node, tensorMap, context) as Tensor).rank,\n              'int32')];\n        case 'NoOp':\n          return [ops.scalar(1)];\n        case 'Print':\n          const input = getParamValue('x', node, tensorMap, context) as Tensor;\n          const data =\n              getParamValue('data', node, tensorMap, context) as Tensor[];\n          
const message =\n              getParamValue('message', node, tensorMap, context) as string;\n          const summarize =\n              getParamValue('summarize', node, tensorMap, context) as number;\n          console.warn(\n              'The graph has a tf.print() operation,' +\n              'usually used for debugging, which slows down performance.');\n          console.log(message);\n          for (let i = 0; i < data.length; i++) {\n            console.log(Array.prototype.slice.call(data[i].dataSync())\n                            .slice(0, summarize));\n          }\n          return [input];\n\n        default:\n          throw TypeError(`Node type ${node.op} is not implemented`);\n      }\n    };\n\nexport const CATEGORY = 'graph';\n"]}