gx
chenyc
2025-02-12 ea42ff3ebee1eeb3fb29423aa848a249441db81c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
/**
 * @license
 * Copyright 2017 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
import * as util from './util';
/**
 * Computes a list of TapeNodes that connect x to y, filtering everything else
 * out and preserving the order of the original tape elements.
 *
 * @param tape The tape elements to filter.
 * @param xs The input Tensors.
 * @param y The output Tensor.
 */
export function getFilteredNodesXToY(tape, xs, y) {
    // Forward pass to compute all the nodes and Tensors that are transitively a
    // function of x.
    const tensorsFromX = {};
    const nodesFromX = {};
    for (let i = 0; i < xs.length; i++) {
        tensorsFromX[xs[i].id] = true;
    }
    for (let i = 0; i < tape.length; i++) {
        const node = tape[i];
        const nodeInputs = node.inputs;
        for (const inputName in nodeInputs) {
            const input = nodeInputs[inputName];
            let anyInputFromX = false;
            for (let j = 0; j < xs.length; j++) {
                if (tensorsFromX[input.id]) {
                    node.outputs.forEach(output => tensorsFromX[output.id] = true);
                    anyInputFromX = true;
                    nodesFromX[node.id] = true;
                    break;
                }
            }
            if (anyInputFromX) {
                break;
            }
        }
    }
    // Backward pass to find all of the nodes and Tensors that lead to y.
    const tensorsLeadToY = {};
    tensorsLeadToY[y.id] = true;
    const nodesToY = {};
    for (let i = tape.length - 1; i >= 0; i--) {
        const node = tape[i];
        const nodeInputs = node.inputs;
        // If any of the outputs lead to y, mark all of the inputs as leading to y.
        for (let j = 0; j < node.outputs.length; j++) {
            if (tensorsLeadToY[node.outputs[j].id]) {
                for (const inputName in nodeInputs) {
                    tensorsLeadToY[nodeInputs[inputName].id] = true;
                    nodesToY[node.id] = true;
                }
                break;
            }
        }
    }
    // Return the paths that come from x and lead to y.
    const filteredTape = [];
    for (let i = 0; i < tape.length; i++) {
        const node = tape[i];
        if (nodesFromX[node.id] && nodesToY[node.id]) {
            // Prune the inputs from the node that aren't a function of x.
            const prunedInputs = {};
            for (const inputName in node.inputs) {
                const nodeInput = node.inputs[inputName];
                if (tensorsFromX[nodeInput.id]) {
                    prunedInputs[inputName] = nodeInput;
                }
            }
            // Copy the node and overwrite inputsAndArgs to the pruned version.
            const prunedNode = Object.assign({}, node);
            prunedNode.inputs = prunedInputs;
            prunedNode.outputs = node.outputs;
            filteredTape.push(prunedNode);
        }
    }
    return filteredTape;
}
/**
 * Backpropagate gradients through the filtered TapeNodes.
 *
 * @param tensorAccumulatedGradientMap A map of Tensor to its gradient. This map
 * is mutated by this method.
 * @param filteredTape The filtered TapeNodes to backprop through.
 */
export function backpropagateGradients(tensorAccumulatedGradientMap, filteredTape, tidy, add) {
    // Walk the tape backward and keep a map of Tensor to its gradient.
    for (let i = filteredTape.length - 1; i >= 0; i--) {
        const node = filteredTape[i];
        const dys = [];
        node.outputs.forEach(o => {
            const gradTensor = tensorAccumulatedGradientMap[o.id];
            if (gradTensor != null) {
                dys.push(gradTensor);
            }
            else {
                // This particular output is not in the back-propagation subgraph, so it
                // does not affect the final output, thus we put null for its dy.
                dys.push(null);
            }
        });
        if (node.gradient == null) {
            throw new Error(`Cannot compute gradient: gradient function not found ` +
                `for ${node.kernelName}.`);
        }
        // Backprop dy through this node and accumulate gradients over the inputs.
        const inputGradients = node.gradient(dys);
        for (const inputName in node.inputs) {
            if (!(inputName in inputGradients)) {
                throw new Error(`Cannot backprop through input ${inputName}. ` +
                    `Available gradients found: ${Object.keys(inputGradients)}.`);
            }
            // Call the gradient function.
            const dx = tidy(() => inputGradients[inputName]());
            if (dx.dtype !== 'float32') {
                throw new Error(`Error in gradient for op ${node.kernelName}. The gradient of input ` +
                    `${inputName} must have 'float32' dtype, but has '${dx.dtype}'`);
            }
            const x = node.inputs[inputName];
            if (!util.arraysEqual(dx.shape, x.shape)) {
                throw new Error(`Error in gradient for op ${node.kernelName}. The gradient of input ` +
                    `'${inputName}' has shape '${dx.shape}', which does not match ` +
                    `the shape of the input '${x.shape}'`);
            }
            if (tensorAccumulatedGradientMap[x.id] == null) {
                tensorAccumulatedGradientMap[x.id] = dx;
            }
            else {
                const curGradient = tensorAccumulatedGradientMap[x.id];
                tensorAccumulatedGradientMap[x.id] = add(curGradient, dx);
                curGradient.dispose();
            }
        }
    }
}
//# sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoidGFwZS5qcyIsInNvdXJjZVJvb3QiOiIiLCJzb3VyY2VzIjpbIi4uLy4uLy4uLy4uLy4uL3RmanMtY29yZS9zcmMvdGFwZS50cyJdLCJuYW1lcyI6W10sIm1hcHBpbmdzIjoiQUFBQTs7Ozs7Ozs7Ozs7Ozs7O0dBZUc7QUFJSCxPQUFPLEtBQUssSUFBSSxNQUFNLFFBQVEsQ0FBQztBQWdCL0I7Ozs7Ozs7R0FPRztBQUNILE1BQU0sVUFBVSxvQkFBb0IsQ0FDaEMsSUFBZ0IsRUFBRSxFQUFZLEVBQUUsQ0FBUztJQUMzQyw0RUFBNEU7SUFDNUUsaUJBQWlCO0lBQ2pCLE1BQU0sWUFBWSxHQUFrQyxFQUFFLENBQUM7SUFDdkQsTUFBTSxVQUFVLEdBQWdDLEVBQUUsQ0FBQztJQUNuRCxLQUFLLElBQUksQ0FBQyxHQUFHLENBQUMsRUFBRSxDQUFDLEdBQUcsRUFBRSxDQUFDLE1BQU0sRUFBRSxDQUFDLEVBQUUsRUFBRTtRQUNsQyxZQUFZLENBQUMsRUFBRSxDQUFDLENBQUMsQ0FBQyxDQUFDLEVBQUUsQ0FBQyxHQUFHLElBQUksQ0FBQztLQUMvQjtJQUVELEtBQUssSUFBSSxDQUFDLEdBQUcsQ0FBQyxFQUFFLENBQUMsR0FBRyxJQUFJLENBQUMsTUFBTSxFQUFFLENBQUMsRUFBRSxFQUFFO1FBQ3BDLE1BQU0sSUFBSSxHQUFHLElBQUksQ0FBQyxDQUFDLENBQUMsQ0FBQztRQUNyQixNQUFNLFVBQVUsR0FBRyxJQUFJLENBQUMsTUFBTSxDQUFDO1FBQy9CLEtBQUssTUFBTSxTQUFTLElBQUksVUFBVSxFQUFFO1lBQ2xDLE1BQU0sS0FBSyxHQUFHLFVBQVUsQ0FBQyxTQUFTLENBQUMsQ0FBQztZQUVwQyxJQUFJLGFBQWEsR0FBRyxLQUFLLENBQUM7WUFDMUIsS0FBSyxJQUFJLENBQUMsR0FBRyxDQUFDLEVBQUUsQ0FBQyxHQUFHLEVBQUUsQ0FBQyxNQUFNLEVBQUUsQ0FBQyxFQUFFLEVBQUU7Z0JBQ2xDLElBQUksWUFBWSxDQUFDLEtBQUssQ0FBQyxFQUFFLENBQUMsRUFBRTtvQkFDMUIsSUFBSSxDQUFDLE9BQU8sQ0FBQyxPQUFPLENBQUMsTUFBTSxDQUFDLEVBQUUsQ0FBQyxZQUFZLENBQUMsTUFBTSxDQUFDLEVBQUUsQ0FBQyxHQUFHLElBQUksQ0FBQyxDQUFDO29CQUMvRCxhQUFhLEdBQUcsSUFBSSxDQUFDO29CQUNyQixVQUFVLENBQUMsSUFBSSxDQUFDLEVBQUUsQ0FBQyxHQUFHLElBQUksQ0FBQztvQkFDM0IsTUFBTTtpQkFDUDthQUNGO1lBRUQsSUFBSSxhQUFhLEVBQUU7Z0JBQ2pCLE1BQU07YUFDUDtTQUNGO0tBQ0Y7SUFFRCxxRUFBcUU7SUFDckUsTUFBTSxjQUFjLEdBQWtDLEVBQUUsQ0FBQztJQUN6RCxjQUFjLENBQUMsQ0FBQyxDQUFDLEVBQUUsQ0FBQyxHQUFHLElBQUksQ0FBQztJQUM1QixNQUFNLFFBQVEsR0FBZ0MsRUFBRSxDQUFDO0lBRWpELEtBQUssSUFBSSxDQUFDLEdBQUcsSUFBSSxDQUFDLE1BQU0sR0FBRyxDQUFDLEVBQUUsQ0FBQyxJQUFJLENBQUMsRUFBRSxDQUFDLEVBQUUsRUFBRTtRQUN6QyxNQUFNLElBQUksR0FBRyxJQUFJLENBQUMsQ0FBQyxDQUFDLENBQUM7UUFDckIsTUFBTSxVQUFVLEdBQUcsSUFBSSxDQUFDLE1BQU0sQ0FBQztRQUUvQiwyRUFBMkU7UUFDM0UsS0FBSyxJQUFJLENBQUMsR0FBRyxDQUFDLEVBQUUsQ0FBQyxHQUFHLElBQUksQ0FBQyxPQUFPLENBQUMsTUFBTSxFQUFFLENBQUMsRUFBRSxFQUFFO1lBQzVDLElBQUksY0FBYyxDQUFDLElBQUksQ0FBQyxPQUFPLENBQUMsQ0FBQyxDQUFDLENBQUMsRUFBRSxDQUFDLEVBQUU7Z0JBQ3RDLEtBQUssTUFBTSxTQUFTLElBQUksVUFBVSxFQUFFO29CQUNsQyxjQUFjLENBQUMsVUFBVSxDQUFDLFNBQVMsQ0FBQyxDQUFDLEVBQUUsQ0FBQyxHQUFHLElBQUksQ0FBQztvQkFDaEQsUUFBUSxDQUFDLElBQUksQ0FBQyxFQUFFLENBQUMsR0FBRyxJQUFJLENBQUM7aUJBQzFCO2dCQUNELE1BQU07YUFDUDtTQUNGO0tBQ0Y7SUFFRCxtREFBbUQ7SUFDbkQsTUFBTSxZQUFZLEdBQWUsRUFBRSxDQUFDO0lBQ3BDLEtBQUssSUFBSSxDQUFDLEdBQUcsQ0FBQyxFQUFFLENBQUMsR0FBRyxJQUFJLENBQUMsTUFBTSxFQUFFLENBQUMsRUFBRSxFQUFFO1FBQ3BDLE1BQU0sSUFBSSxHQUFHLElBQUksQ0FBQyxDQUFDLENBQUMsQ0FBQztRQUVyQixJQUFJLFVBQVUsQ0FBQyxJQUFJLENBQUMsRUFBRSxDQUFDLElBQUksUUFBUSxDQUFDLElBQUksQ0FBQyxFQUFFLENBQUMsRUFBRTtZQUM1Qyw4REFBOEQ7WUFDOUQsTUFBTSxZQUFZLEdBQWtDLEVBQUUsQ0FBQztZQUN2RCxLQUFLLE1BQU0sU0FBUyxJQUFJLElBQUksQ0FBQyxNQUFNLEVBQUU7Z0JBQ25DLE1BQU0sU0FBUyxHQUFHLElBQUksQ0FBQyxNQUFNLENBQUMsU0FBUyxDQUFDLENBQUM7Z0JBQ3pDLElBQUksWUFBWSxDQUFDLFNBQVMsQ0FBQyxFQUFFLENBQUMsRUFBRTtvQkFDOUIsWUFBWSxDQUFDLFNBQVMsQ0FBQyxHQUFHLFNBQVMsQ0FBQztpQkFDckM7YUFDRjtZQUVELG1FQUFtRTtZQUNuRSxNQUFNLFVBQVUsR0FBRyxNQUFNLENBQUMsTUFBTSxDQUFDLEVBQUUsRUFBRSxJQUFJLENBQUMsQ0FBQztZQUMzQyxVQUFVLENBQUMsTUFBTSxHQUFHLFlBQVksQ0FBQztZQUNqQyxVQUFVLENBQUMsT0FBTyxHQUFHLElBQUksQ0FBQyxPQUFPLENBQUM7WUFFbEMsWUFBWSxDQUFDLElBQUksQ0FBQyxVQUFVLENBQUMsQ0FBQztTQUMvQjtLQUNGO0lBRUQsT0FBTyxZQUFZLENBQUM7QUFDdEIsQ0FBQztBQUVEOzs7Ozs7R0FNRztBQUNILE1BQU0sVUFBVSxzQkFBc0IsQ0FDbEMsNEJBQTBELEVBQzFELFlBQXdCLEVBQUUsSUFBNkIsRUFDdkQsR0FBcUM7SUFDdkMsbUVBQW1FO0lBQ25FLEtBQUssSUFBSSxDQUFDLEdBQUcsWUFBWSxDQUFDLE1BQU0sR0FBRyxDQUFDLEVBQUUsQ0FBQyxJQUFJLENBQUMsRUFBRSxDQUFDLEVBQUUsRUFBRTtRQUNqRCxNQUFNLElBQUksR0FBRyxZQUFZLENBQUMsQ0FBQyxDQUFDLENBQUM7UUFFN0IsTUFBTSxHQUFHLEdBQWEsRUFBRSxDQUFDO1FBQ3pCLElBQUksQ0FBQyxPQUFPLENBQUMsT0FBTyxDQUFDLENBQUMsQ0FBQyxFQUFFO1lBQ3ZCLE1BQU0sVUFBVSxHQUFHLDRCQUE0QixDQUFDLENBQUMsQ0FBQyxFQUFFLENBQUMsQ0FBQztZQUN0RCxJQUFJLFVBQVUsSUFBSSxJQUFJLEVBQUU7Z0JBQ3RCLEdBQUcsQ0FBQyxJQUFJLENBQUMsVUFBVSxDQUFDLENBQUM7YUFDdEI7aUJBQU07Z0JBQ0wsd0VBQXdFO2dCQUN4RSxpRUFBaUU7Z0JBQ2pFLEdBQUcsQ0FBQyxJQUFJLENBQUMsSUFBSSxDQUFDLENBQUM7YUFDaEI7UUFDSCxDQUFDLENBQUMsQ0FBQztRQUVILElBQUksSUFBSSxDQUFDLFFBQVEsSUFBSSxJQUFJLEVBQUU7WUFDekIsTUFBTSxJQUFJLEtBQUssQ0FDWCx1REFBdUQ7Z0JBQ3ZELE9BQU8sSUFBSSxDQUFDLFVBQVUsR0FBRyxDQUFDLENBQUM7U0FDaEM7UUFFRCwwRUFBMEU7UUFDMUUsTUFBTSxjQUFjLEdBQUcsSUFBSSxDQUFDLFFBQVEsQ0FBQyxHQUFHLENBQUMsQ0FBQztRQUUxQyxLQUFLLE1BQU0sU0FBUyxJQUFJLElBQUksQ0FBQyxNQUFNLEVBQUU7WUFDbkMsSUFBSSxDQUFDLENBQUMsU0FBUyxJQUFJLGNBQWMsQ0FBQyxFQUFFO2dCQUNsQyxNQUFNLElBQUksS0FBSyxDQUNYLGlDQUFpQyxTQUFTLElBQUk7b0JBQzlDLDhCQUE4QixNQUFNLENBQUMsSUFBSSxDQUFDLGNBQWMsQ0FBQyxHQUFHLENBQUMsQ0FBQzthQUNuRTtZQUVELDhCQUE4QjtZQUM5QixNQUFNLEVBQUUsR0FBRyxJQUFJLENBQUMsR0FBRyxFQUFFLENBQUMsY0FBYyxDQUFDLFNBQVMsQ0FBQyxFQUFFLENBQUMsQ0FBQztZQUNuRCxJQUFJLEVBQUUsQ0FBQyxLQUFLLEtBQUssU0FBUyxFQUFFO2dCQUMxQixNQUFNLElBQUksS0FBSyxDQUNYLDRCQUNJLElBQUksQ0FBQyxVQUFVLDBCQUEwQjtvQkFDN0MsR0FBRyxTQUFTLHdDQUF3QyxFQUFFLENBQUMsS0FBSyxHQUFHLENBQUMsQ0FBQzthQUN0RTtZQUNELE1BQU0sQ0FBQyxHQUFHLElBQUksQ0FBQyxNQUFNLENBQUMsU0FBUyxDQUFDLENBQUM7WUFDakMsSUFBSSxDQUFDLElBQUksQ0FBQyxXQUFXLENBQUMsRUFBRSxDQUFDLEtBQUssRUFBRSxDQUFDLENBQUMsS0FBSyxDQUFDLEVBQUU7Z0JBQ3hDLE1BQU0sSUFBSSxLQUFLLENBQ1gsNEJBQ0ksSUFBSSxDQUFDLFVBQVUsMEJBQTBCO29CQUM3QyxJQUFJLFNBQVMsZ0JBQWdCLEVBQUUsQ0FBQyxLQUFLLDBCQUEwQjtvQkFDL0QsMkJBQTJCLENBQUMsQ0FBQyxLQUFLLEdBQUcsQ0FBQyxDQUFDO2FBQzVDO1lBRUQsSUFBSSw0QkFBNEIsQ0FBQyxDQUFDLENBQUMsRUFBRSxDQUFDLElBQUksSUFBSSxFQUFFO2dCQUM5Qyw0QkFBNEIsQ0FBQyxDQUFDLENBQUMsRUFBRSxDQUFDLEdBQUcsRUFBRSxDQUFDO2FBQ3pDO2lCQUFNO2dCQUNMLE1BQU0sV0FBVyxHQUFHLDRCQUE0QixDQUFDLENBQUMsQ0FBQyxFQUFFLENBQUMsQ0FBQztnQkFDdkQsNEJBQTRCLENBQUMsQ0FBQyxDQUFDLEVBQUUsQ0FBQyxHQUFHLEdBQUcsQ0FBQyxXQUFXLEVBQUUsRUFBRSxDQUFDLENBQUM7Z0JBQzFELFdBQVcsQ0FBQyxPQUFPLEVBQUUsQ0FBQzthQUN2QjtTQUNGO0tBQ0Y7QUFDSCxDQUFDIiwic291cmNlc0NvbnRlbnQiOlsiLyoqXG4gKiBAbGljZW5zZVxuICogQ29weXJpZ2h0IDIwMTcgR29vZ2xlIExMQy4gQWxsIFJpZ2h0cyBSZXNlcnZlZC5cbiAqIExpY2Vuc2VkIHVuZGVyIHRoZSBBcGFjaGUgTGljZW5zZSwgVmVyc2lvbiAyLjAgKHRoZSBcIkxpY2Vuc2VcIik7XG4gKiB5b3UgbWF5IG5vdCB1c2UgdGhpcyBmaWxlIGV4Y2VwdCBpbiBjb21wbGlhbmNlIHdpdGggdGhlIExpY2Vuc2UuXG4gKiBZb3UgbWF5IG9idGFpbiBhIGNvcHkgb2YgdGhlIExpY2Vuc2UgYXRcbiAqXG4gKiBodHRwOi8vd3d3LmFwYWNoZS5vcmcvbGljZW5zZXMvTElDRU5TRS0yLjBcbiAqXG4gKiBVbmxlc3MgcmVxdWlyZWQgYnkgYXBwbGljYWJsZSBsYXcgb3IgYWdyZWVkIHRvIGluIHdyaXRpbmcsIHNvZnR3YXJlXG4gKiBkaXN0cmlidXRlZCB1bmRlciB0aGUgTGljZW5zZSBpcyBkaXN0cmlidXRlZCBvbiBhbiBcIkFTIElTXCIgQkFTSVMsXG4gKiBXSVRIT1VUIFdBUlJBTlRJRVMgT1IgQ09ORElUSU9OUyBPRiBBTlkgS0lORCwgZWl0aGVyIGV4cHJlc3Mgb3IgaW1wbGllZC5cbiAqIFNlZSB0aGUgTGljZW5zZSBmb3IgdGhlIHNwZWNpZmljIGxhbmd1YWdlIGdvdmVybmluZyBwZXJtaXNzaW9ucyBhbmRcbiAqIGxpbWl0YXRpb25zIHVuZGVyIHRoZSBMaWNlbnNlLlxuICogPT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT1cbiAqL1xuXG5pbXBvcnQge1RlbnNvcn0gZnJvbSAnLi90ZW5zb3InO1xuaW1wb3J0IHtOYW1lZFRlbnNvck1hcH0gZnJvbSAnLi90ZW5zb3JfdHlwZXMnO1xuaW1wb3J0ICogYXMgdXRpbCBmcm9tICcuL3V0aWwnO1xuXG5leHBvcnQgaW50ZXJmYWNlIFRhcGVOb2RlIHtcbiAgaWQ6IG51bWJlcjtcbiAga2VybmVsTmFtZTogc3RyaW5nO1xuICBvdXRwdXRzOiBUZW5zb3JbXTtcbiAgaW5wdXRzOiBOYW1lZFRlbnNvck1hcDtcbiAgLy8gT3B0aW9uYWwgcGFyYW1zLCBkZWZpbmVkIG9ubHkgZm9yIG9wcyB3aXRoIGdyYWRpZW50IGltcGwuXG4gIGdyYWRpZW50PzogKGR5czogVGVuc29yW10pID0+IE5hbWVkR3JhZGllbnRNYXA7XG4gIHNhdmVkPzogVGVuc29yW107XG59XG5cbmV4cG9ydCB0eXBlIE5hbWVkR3JhZGllbnRNYXAgPSB7XG4gIFtpbnB1dE5hbWU6IHN0cmluZ106ICgpID0+IFRlbnNvcjtcbn07XG5cbi8qKlxuICogQ29tcHV0ZXMgYSBsaXN0IG9mIFRhcGVOb2RlcyB0aGF0IGNvbm5lY3QgeCB0byB5LCBmaWx0ZXJpbmcgZXZlcnl0aGluZyBlbHNlXG4gKiBvdXQgYW5kIHByZXNlcnZpbmcgdGhlIG9yZGVyIG9mIHRoZSBvcmlnaW5hbCB0YXBlIGVsZW1lbnRzLlxuICpcbiAqIEBwYXJhbSB0YXBlIFRoZSB0YXBlIGVsZW1lbnRzIHRvIGZpbHRlci5cbiAqIEBwYXJhbSB4cyBUaGUgaW5wdXQgVGVuc29ycy5cbiAqIEBwYXJhbSB5IFRoZSBvdXRwdXQgVGVuc29yLlxuICovXG5leHBvcnQgZnVuY3Rpb24gZ2V0RmlsdGVyZWROb2Rlc1hUb1koXG4gICAgdGFwZTogVGFwZU5vZGVbXSwgeHM6IFRlbnNvcltdLCB5OiBUZW5zb3IpOiBUYXBlTm9kZVtdIHtcbiAgLy8gRm9yd2FyZCBwYXNzIHRvIGNvbXB1dGUgYWxsIHRoZSBub2RlcyBhbmQgVGVuc29ycyB0aGF0IGFyZSB0cmFuc2l0aXZlbHkgYVxuICAvLyBmdW5jdGlvbiBvZiB4LlxuICBjb25zdCB0ZW5zb3JzRnJvbVg6IHtbdGVuc29ySWQ6IG51bWJlcl06IGJvb2xlYW59ID0ge307XG4gIGNvbnN0IG5vZGVzRnJvbVg6IHtbbm9kZUlkOiBudW1iZXJdOiBib29sZWFufSA9IHt9O1xuICBmb3IgKGxldCBpID0gMDsgaSA8IHhzLmxlbmd0aDsgaSsrKSB7XG4gICAgdGVuc29yc0Zyb21YW3hzW2ldLmlkXSA9IHRydWU7XG4gIH1cblxuICBmb3IgKGxldCBpID0gMDsgaSA8IHRhcGUubGVuZ3RoOyBpKyspIHtcbiAgICBjb25zdCBub2RlID0gdGFwZVtpXTtcbiAgICBjb25zdCBub2RlSW5wdXRzID0gbm9kZS5pbnB1dHM7XG4gICAgZm9yIChjb25zdCBpbnB1dE5hbWUgaW4gbm9kZUlucHV0cykge1xuICAgICAgY29uc3QgaW5wdXQgPSBub2RlSW5wdXRzW2lucHV0TmFtZV07XG5cbiAgICAgIGxldCBhbnlJbnB1dEZyb21YID0gZmFsc2U7XG4gICAgICBmb3IgKGxldCBqID0gMDsgaiA8IHhzLmxlbmd0aDsgaisrKSB7XG4gICAgICAgIGlmICh0ZW5zb3JzRnJvbVhbaW5wdXQuaWRdKSB7XG4gICAgICAgICAgbm9kZS5vdXRwdXRzLmZvckVhY2gob3V0cHV0ID0+IHRlbnNvcnNGcm9tWFtvdXRwdXQuaWRdID0gdHJ1ZSk7XG4gICAgICAgICAgYW55SW5wdXRGcm9tWCA9IHRydWU7XG4gICAgICAgICAgbm9kZXNGcm9tWFtub2RlLmlkXSA9IHRydWU7XG4gICAgICAgICAgYnJlYWs7XG4gICAgICAgIH1cbiAgICAgIH1cblxuICAgICAgaWYgKGFueUlucHV0RnJvbVgpIHtcbiAgICAgICAgYnJlYWs7XG4gICAgICB9XG4gICAgfVxuICB9XG5cbiAgLy8gQmFja3dhcmQgcGFzcyB0byBmaW5kIGFsbCBvZiB0aGUgbm9kZXMgYW5kIFRlbnNvcnMgdGhhdCBsZWFkIHRvIHkuXG4gIGNvbnN0IHRlbnNvcnNMZWFkVG9ZOiB7W3RlbnNvcklkOiBudW1iZXJdOiBib29sZWFufSA9IHt9O1xuICB0ZW5zb3JzTGVhZFRvWVt5LmlkXSA9IHRydWU7XG4gIGNvbnN0IG5vZGVzVG9ZOiB7W25vZGVJZDogbnVtYmVyXTogYm9vbGVhbn0gPSB7fTtcblxuICBmb3IgKGxldCBpID0gdGFwZS5sZW5ndGggLSAxOyBpID49IDA7IGktLSkge1xuICAgIGNvbnN0IG5vZGUgPSB0YXBlW2ldO1xuICAgIGNvbnN0IG5vZGVJbnB1dHMgPSBub2RlLmlucHV0cztcblxuICAgIC8vIElmIGFueSBvZiB0aGUgb3V0cHV0cyBsZWFkIHRvIHksIG1hcmsgYWxsIG9mIHRoZSBpbnB1dHMgYXMgbGVhZGluZyB0byB5LlxuICAgIGZvciAobGV0IGogPSAwOyBqIDwgbm9kZS5vdXRwdXRzLmxlbmd0aDsgaisrKSB7XG4gICAgICBpZiAodGVuc29yc0xlYWRUb1lbbm9kZS5vdXRwdXRzW2pdLmlkXSkge1xuICAgICAgICBmb3IgKGNvbnN0IGlucHV0TmFtZSBpbiBub2RlSW5wdXRzKSB7XG4gICAgICAgICAgdGVuc29yc0xlYWRUb1lbbm9kZUlucHV0c1tpbnB1dE5hbWVdLmlkXSA9IHRydWU7XG4gICAgICAgICAgbm9kZXNUb1lbbm9kZS5pZF0gPSB0cnVlO1xuICAgICAgICB9XG4gICAgICAgIGJyZWFrO1xuICAgICAgfVxuICAgIH1cbiAgfVxuXG4gIC8vIFJldHVybiB0aGUgcGF0aHMgdGhhdCBjb21lIGZyb20geCBhbmQgbGVhZCB0byB5LlxuICBjb25zdCBmaWx0ZXJlZFRhcGU6IFRhcGVOb2RlW10gPSBbXTtcbiAgZm9yIChsZXQgaSA9IDA7IGkgPCB0YXBlLmxlbmd0aDsgaSsrKSB7XG4gICAgY29uc3Qgbm9kZSA9IHRhcGVbaV07XG5cbiAgICBpZiAobm9kZXNGcm9tWFtub2RlLmlkXSAmJiBub2Rlc1RvWVtub2RlLmlkXSkge1xuICAgICAgLy8gUHJ1bmUgdGhlIGlucHV0cyBmcm9tIHRoZSBub2RlIHRoYXQgYXJlbid0IGEgZnVuY3Rpb24gb2YgeC5cbiAgICAgIGNvbnN0IHBydW5lZElucHV0czoge1tpbnB1dE5hbWU6IHN0cmluZ106IFRlbnNvcn0gPSB7fTtcbiAgICAgIGZvciAoY29uc3QgaW5wdXROYW1lIGluIG5vZGUuaW5wdXRzKSB7XG4gICAgICAgIGNvbnN0IG5vZGVJbnB1dCA9IG5vZGUuaW5wdXRzW2lucHV0TmFtZV07XG4gICAgICAgIGlmICh0ZW5zb3JzRnJvbVhbbm9kZUlucHV0LmlkXSkge1xuICAgICAgICAgIHBydW5lZElucHV0c1tpbnB1dE5hbWVdID0gbm9kZUlucHV0O1xuICAgICAgICB9XG4gICAgICB9XG5cbiAgICAgIC8vIENvcHkgdGhlIG5vZGUgYW5kIG92ZXJ3cml0ZSBpbnB1dHNBbmRBcmdzIHRvIHRoZSBwcnVuZWQgdmVyc2lvbi5cbiAgICAgIGNvbnN0IHBydW5lZE5vZGUgPSBPYmplY3QuYXNzaWduKHt9LCBub2RlKTtcbiAgICAgIHBydW5lZE5vZGUuaW5wdXRzID0gcHJ1bmVkSW5wdXRzO1xuICAgICAgcHJ1bmVkTm9kZS5vdXRwdXRzID0gbm9kZS5vdXRwdXRzO1xuXG4gICAgICBmaWx0ZXJlZFRhcGUucHVzaChwcnVuZWROb2RlKTtcbiAgICB9XG4gIH1cblxuICByZXR1cm4gZmlsdGVyZWRUYXBlO1xufVxuXG4vKipcbiAqIEJhY2twcm9wYWdhdGUgZ3JhZGllbnRzIHRocm91Z2ggdGhlIGZpbHRlcmVkIFRhcGVOb2Rlcy5cbiAqXG4gKiBAcGFyYW0gdGVuc29yQWNjdW11bGF0ZWRHcmFkaWVudE1hcCBBIG1hcCBvZiBUZW5zb3IgdG8gaXRzIGdyYWRpZW50LiBUaGlzIG1hcFxuICogaXMgbXV0YXRlZCBieSB0aGlzIG1ldGhvZC5cbiAqIEBwYXJhbSBmaWx0ZXJlZFRhcGUgVGhlIGZpbHRlcmVkIFRhcGVOb2RlcyB0byBiYWNrcHJvcCB0aHJvdWdoLlxuICovXG5leHBvcnQgZnVuY3Rpb24gYmFja3Byb3BhZ2F0ZUdyYWRpZW50cyhcbiAgICB0ZW5zb3JBY2N1bXVsYXRlZEdyYWRpZW50TWFwOiB7W3RlbnNvcklkOiBudW1iZXJdOiBUZW5zb3J9LFxuICAgIGZpbHRlcmVkVGFwZTogVGFwZU5vZGVbXSwgdGlkeTogKGY6IEZ1bmN0aW9uKSA9PiBUZW5zb3IsXG4gICAgYWRkOiAoYTogVGVuc29yLCBiOiBUZW5zb3IpID0+IFRlbnNvcikge1xuICAvLyBXYWxrIHRoZSB0YXBlIGJhY2t3YXJkIGFuZCBrZWVwIGEgbWFwIG9mIFRlbnNvciB0byBpdHMgZ3JhZGllbnQuXG4gIGZvciAobGV0IGkgPSBmaWx0ZXJlZFRhcGUubGVuZ3RoIC0gMTsgaSA+PSAwOyBpLS0pIHtcbiAgICBjb25zdCBub2RlID0gZmlsdGVyZWRUYXBlW2ldO1xuXG4gICAgY29uc3QgZHlzOiBUZW5zb3JbXSA9IFtdO1xuICAgIG5vZGUub3V0cHV0cy5mb3JFYWNoKG8gPT4ge1xuICAgICAgY29uc3QgZ3JhZFRlbnNvciA9IHRlbnNvckFjY3VtdWxhdGVkR3JhZGllbnRNYXBbby5pZF07XG4gICAgICBpZiAoZ3JhZFRlbnNvciAhPSBudWxsKSB7XG4gICAgICAgIGR5cy5wdXNoKGdyYWRUZW5zb3IpO1xuICAgICAgfSBlbHNlIHtcbiAgICAgICAgLy8gVGhpcyBwYXJ0aWN1bGFyIG91dHB1dCBpcyBub3QgaW4gdGhlIGJhY2stcHJvcGFnYXRpb24gc3ViZ3JhcGgsIHNvIGl0XG4gICAgICAgIC8vIGRvZXMgbm90IGFmZmVjdCB0aGUgZmluYWwgb3V0cHV0LCB0aHVzIHdlIHB1dCBudWxsIGZvciBpdHMgZHkuXG4gICAgICAgIGR5cy5wdXNoKG51bGwpO1xuICAgICAgfVxuICAgIH0pO1xuXG4gICAgaWYgKG5vZGUuZ3JhZGllbnQgPT0gbnVsbCkge1xuICAgICAgdGhyb3cgbmV3IEVycm9yKFxuICAgICAgICAgIGBDYW5ub3QgY29tcHV0ZSBncmFkaWVudDogZ3JhZGllbnQgZnVuY3Rpb24gbm90IGZvdW5kIGAgK1xuICAgICAgICAgIGBmb3IgJHtub2RlLmtlcm5lbE5hbWV9LmApO1xuICAgIH1cblxuICAgIC8vIEJhY2twcm9wIGR5IHRocm91Z2ggdGhpcyBub2RlIGFuZCBhY2N1bXVsYXRlIGdyYWRpZW50cyBvdmVyIHRoZSBpbnB1dHMuXG4gICAgY29uc3QgaW5wdXRHcmFkaWVudHMgPSBub2RlLmdyYWRpZW50KGR5cyk7XG5cbiAgICBmb3IgKGNvbnN0IGlucHV0TmFtZSBpbiBub2RlLmlucHV0cykge1xuICAgICAgaWYgKCEoaW5wdXROYW1lIGluIGlucHV0R3JhZGllbnRzKSkge1xuICAgICAgICB0aHJvdyBuZXcgRXJyb3IoXG4gICAgICAgICAgICBgQ2Fubm90IGJhY2twcm9wIHRocm91Z2ggaW5wdXQgJHtpbnB1dE5hbWV9LiBgICtcbiAgICAgICAgICAgIGBBdmFpbGFibGUgZ3JhZGllbnRzIGZvdW5kOiAke09iamVjdC5rZXlzKGlucHV0R3JhZGllbnRzKX0uYCk7XG4gICAgICB9XG5cbiAgICAgIC8vIENhbGwgdGhlIGdyYWRpZW50IGZ1bmN0aW9uLlxuICAgICAgY29uc3QgZHggPSB0aWR5KCgpID0+IGlucHV0R3JhZGllbnRzW2lucHV0TmFtZV0oKSk7XG4gICAgICBpZiAoZHguZHR5cGUgIT09ICdmbG9hdDMyJykge1xuICAgICAgICB0aHJvdyBuZXcgRXJyb3IoXG4gICAgICAgICAgICBgRXJyb3IgaW4gZ3JhZGllbnQgZm9yIG9wICR7XG4gICAgICAgICAgICAgICAgbm9kZS5rZXJuZWxOYW1lfS4gVGhlIGdyYWRpZW50IG9mIGlucHV0IGAgK1xuICAgICAgICAgICAgYCR7aW5wdXROYW1lfSBtdXN0IGhhdmUgJ2Zsb2F0MzInIGR0eXBlLCBidXQgaGFzICcke2R4LmR0eXBlfSdgKTtcbiAgICAgIH1cbiAgICAgIGNvbnN0IHggPSBub2RlLmlucHV0c1tpbnB1dE5hbWVdO1xuICAgICAgaWYgKCF1dGlsLmFycmF5c0VxdWFsKGR4LnNoYXBlLCB4LnNoYXBlKSkge1xuICAgICAgICB0aHJvdyBuZXcgRXJyb3IoXG4gICAgICAgICAgICBgRXJyb3IgaW4gZ3JhZGllbnQgZm9yIG9wICR7XG4gICAgICAgICAgICAgICAgbm9kZS5rZXJuZWxOYW1lfS4gVGhlIGdyYWRpZW50IG9mIGlucHV0IGAgK1xuICAgICAgICAgICAgYCcke2lucHV0TmFtZX0nIGhhcyBzaGFwZSAnJHtkeC5zaGFwZX0nLCB3aGljaCBkb2VzIG5vdCBtYXRjaCBgICtcbiAgICAgICAgICAgIGB0aGUgc2hhcGUgb2YgdGhlIGlucHV0ICcke3guc2hhcGV9J2ApO1xuICAgICAgfVxuXG4gICAgICBpZiAodGVuc29yQWNjdW11bGF0ZWRHcmFkaWVudE1hcFt4LmlkXSA9PSBudWxsKSB7XG4gICAgICAgIHRlbnNvckFjY3VtdWxhdGVkR3JhZGllbnRNYXBbeC5pZF0gPSBkeDtcbiAgICAgIH0gZWxzZSB7XG4gICAgICAgIGNvbnN0IGN1ckdyYWRpZW50ID0gdGVuc29yQWNjdW11bGF0ZWRHcmFkaWVudE1hcFt4LmlkXTtcbiAgICAgICAgdGVuc29yQWNjdW11bGF0ZWRHcmFkaWVudE1hcFt4LmlkXSA9IGFkZChjdXJHcmFkaWVudCwgZHgpO1xuICAgICAgICBjdXJHcmFkaWVudC5kaXNwb3NlKCk7XG4gICAgICB9XG4gICAgfVxuICB9XG59XG4iXX0=