/** * @license * Copyright 2023 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ !function(t,e){"object"==typeof exports&&"undefined"!=typeof module?e(exports,require("@tensorflow/tfjs-core")):"function"==typeof define&&define.amd?define(["exports","@tensorflow/tfjs-core"],e):e((t="undefined"!=typeof globalThis?globalThis:t||self).tf=t.tf||{},t.tf)}(this,(function(t,e){"use strict";function n(t){var e=Object.create(null);return t&&Object.keys(t).forEach((function(n){if("default"!==n){var s=Object.getOwnPropertyDescriptor(t,n);Object.defineProperty(e,n,s.get?s:{enumerable:!0,get:function(){return t[n]}})}})),e.default=t,e}function s(t,e){return e.forEach((function(e){e&&"string"!=typeof e&&!Array.isArray(e)&&Object.keys(e).forEach((function(n){if("default"!==n&&!(n in t)){var s=Object.getOwnPropertyDescriptor(e,n);Object.defineProperty(t,n,s.get?s:{enumerable:!0,get:function(){return e[n]}})}}))})),t}var i=n(e);class r extends Error{constructor(t){super(t),Object.setPrototypeOf(this,r.prototype)}}class a extends Error{constructor(t){super(t),Object.setPrototypeOf(this,a.prototype)}}class o extends Error{constructor(t){super(t),Object.setPrototypeOf(this,o.prototype)}}class l extends Error{constructor(t){super(t),Object.setPrototypeOf(this,l.prototype)}}class u extends Error{constructor(t){super(t),Object.setPrototypeOf(this,u.prototype)}}class h{constructor(t){this.maxEntries=t||100,this.cache=new Map}get(t){let e;return this.cache.has(t)&&(e=this.cache.get(t),this.cache.delete(t),this.cache.set(t,e)),e}put(t,e){if(this.cache.has(t))this.cache.delete(t);else if(this.cache.size>=this.maxEntries){const t=this.cache.keys().next().value;this.cache.delete(t)}this.cache.set(t,e)}getMaxEntries(){return this.maxEntries}setMaxEntries(t){if(t<0)throw new Error(`The maxEntries of LRU caches must be at least 0, but got ${t}.`);if(this.maxEntries>t)for(let e=0;ee.toUpperCase()))}let b={};function w(t){if(null==t)return null;const e={};return e.className=t.getClassName(),e.config=t.getConfig(),e}function k(t){if(null!=t&&"object"==typeof t)if(Array.isArray(t))t.forEach((t=>k(t)));else{const e=Object.keys(t);for(const n of e){const e=t[n];null!=e&&"object"==typeof e&&(Array.isArray(e)||"ndarray"!==e.type||"number"!=typeof e.value?k(e):t[n]=e.value)}}}function v(t,e={},n={},s="object",i=!1){if("string"==typeof t){const i=t;let r;if(i in n)r=n[i];else if(i in b)r=b[i];else if(r=e[i],null==r)throw new o(`Unknown ${s}: ${t}. This may be due to one of the following reasons:\n1. The ${s} is defined in Python, in which case it needs to be ported to TensorFlow.js or your JavaScript code.\n2. The custom ${s} is defined in JavaScript, but is not registered properly with tf.serialization.registerClass().`);return r}{const r=t;if(null==r.className||null==r.config)throw new o(`${s}: Improper config format: ${JSON.stringify(r)}.\n'className' and 'config' must set.`);const a=r.className;let l,u;if(a in n?[l,u]=n[a]:a in b?[l,u]=b.className:a in e&&([l,u]=e[a]),null==l)throw new o(`Unknown ${s}: ${a}. This may be due to one of the following reasons:\n1. The ${s} is defined in Python, in which case it needs to be ported to TensorFlow.js or your JavaScript code.\n2. The custom ${s} is defined in JavaScript, but is not registered properly with tf.serialization.registerClass().`);if(null!=u){const t={};for(const e of Object.keys(b))t[e]=b[e];for(const e of Object.keys(n))t[e]=n[e];r.config.customObjects=t;const e=Object.assign({},b);for(const t of Object.keys(n))b[t]=n[t];k(r.config);const s=u(l,r.config,n,i);return b=Object.assign({},e),s}{const t=Object.assign({},b);for(const t of Object.keys(n))b[t]=n[t];const e=new l(r.config);return b=Object.assign({},t),e}}}function S(t,e){return-1*function(t,e){return te?1:0}(t,e)}function x(t){if(null==t)return t;const e=[];for(const n of t)-1===e.indexOf(n)&&e.push(n);return e}function N(t){if(null==t)throw new o(`Invalid value in obj: ${JSON.stringify(t)}`);for(const e in t)if(t.hasOwnProperty(e))return!1;return!0}function I(t,e,n){if(null!=n&&t.indexOf(n)<0)throw new o(`${n} is not a valid ${e}. Valid values are ${t} or null/undefined.`)}function A(t,e,n=0,s=1/0){return p(n>=0),p(s>=n),Array.isArray(t)&&t.length>=n&&t.length<=s&&t.every((t=>typeof t===e))}function z(t,n){Array.isArray(t)?(e.util.assert(t.length>0,(()=>`${n} is unexpectedly an empty array.`)),t.forEach(((t,e)=>z(t,`element ${e+1} of ${n}`)))):e.util.assert(Number.isInteger(t)&&t>0,(()=>`Expected ${n} to be a positive integer, but got ${E(t)}.`))}function E(t){return null===t?"null":Array.isArray(t)?"["+t.map((t=>E(t))).join(",")+"]":"string"==typeof t?`"${t}"`:`${t}`}function T(t){return"relu"===t?"relu":"linear"===t?"linear":"elu"===t?"elu":null}let C=0;function $(){return C++}const F={};function D(t=""){return t in F||(F[t]=0),F[t]+=1,t+F[t].toString()}const L=["channelsFirst","channelsLast"],_=["nearest","bilinear"],R=["valid","same","causal"],O=["max","avg"],M=["sum","mul","concat","ave"],B=new Map;function P(t){I(L,"DataFormat",t)}function U(t){I(R,"PaddingMode",t)}function W(t){I(O,"PoolMode",t)}const j=[];function q(t,e){j.push(t);try{const t=e();return j.pop(),t}catch(t){throw j.pop(),t}}function V(t){if(!H(t))throw new Error("Not a valid tensor name: '"+t+"'");return(0===j.length?"":j.join("/")+"/")+t}function K(t){if(!H(t))throw new Error("Not a valid tensor name: '"+t+"'");B.has(t)||B.set(t,0);const e=B.get(t);if(B.set(t,B.get(t)+1),e>0){const n=`${t}_${e}`;return B.set(n,1),n}return t}const G=new RegExp(/^[A-Za-z0-9][-A-Za-z0-9\._\/]*$/);function H(t){return!!t.match(G)}function J(t,e,n){null==e&&(e=0),null==n&&(n=t.length);let s=1;for(let i=e;ie&&(e=s)}return e}function X(t,e){if(e{switch(t.rank){case 1:return i.slice1d(t,n,s);case 2:return i.slice2d(t,[n,0],[s,t.shape[1]]);case 3:return i.slice3d(t,[n,0,0],[s,t.shape[1],t.shape[2]]);case 4:return i.slice4d(t,[n,0,0,0],[s,t.shape[1],t.shape[2],t.shape[3]]);case 5:return i.slice(t,[n,0,0,0,0],[s,t.shape[1],t.shape[2],t.shape[3],t.shape[4]]);case 6:return i.slice(t,[n,0,0,0,0,0],[s,t.shape[1],t.shape[2],t.shape[3],t.shape[4],t.shape[5]]);default:throw new o(`sliceAlongFirstAxis() received an unsupported tensor rank: ${t.rank}`)}}))}function it(t,n,s){return e.tidy((()=>{switch(t.rank){case 1:return i.slice1d(t,n,s);case 2:return i.slice2d(t,[0,n],[t.shape[0],s]);case 3:return i.slice3d(t,[0,0,n],[t.shape[0],t.shape[1],s]);case 4:return i.slice4d(t,[0,0,0,n],[t.shape[0],t.shape[1],t.shape[2],s]);default:throw new o(`sliceAlongLastAxis() received an unsupported tensor rank: ${t.rank}`)}}))}function rt(t,n,s,r){return e.tidy((()=>{switch(t.rank){case 1:return i.slice1d(t,n,s);case 2:switch(r){case 1:return st(t,n,s);case 2:return it(t,n,s);default:throw new o(`The axis is not within the rank of the tensor ${r}`)}case 3:switch(r){case 1:return st(t,n,s);case 2:return i.slice3d(t,[0,n,0],[t.shape[0],s,t.shape[2]]);case 3:return it(t,n,s);default:throw new o(`The axis is not within the rank of the tensor ${r}`)}case 4:switch(r){case 1:return st(t,n,s);case 2:return i.slice4d(t,[0,n,0,0],[t.shape[0],s,t.shape[2],t.shape[3]]);case 3:return i.slice4d(t,[0,0,n,0],[t.shape[0],t.shape[1],s,t.shape[3]]);case 4:return it(t,n,s);default:throw new o(`The axis is not within the rank of the tensor ${r}`)}default:throw new o(`sliceAlongLastAxis() received an unsupported tensor rank: ${t.rank}`)}}))}function at(t,e=-1){let n;return e<0&&(n=t[0].rank,e=0!==n?n:0),e===t[0].rank&&(e=-1),i.concat(t,e)}function ot(t,e){switch(t.rank){case 1:return i.concat1d([t,e]);case 2:return i.concat2d([t,e],0);case 3:return i.concat3d([t,e],0);case 4:return i.concat4d([t,e],0);default:throw new o(`concatAlongFirstAxis() received an unsupported tensor rank: ${t.rank}`)}}function lt(t,e){if(Array.isArray(e)||(e=[e]),t.rank!==e.length)throw new o(`The length of input n (${e.length}) does not match the number of dimensions in input x (${t.rank})`);return i.tile(t,e)}function ut(t,e=0,n=1,s,r){return i.randomNormal(t,e,n,s,r)}function ht(t,e,n,s){if(t.rank<2||e.rank<2)throw new l(`dot requires both inputs to be rank >= 2 but got x shape = ${t.shape} and y shape = ${e.shape}`);if(e.rank>=3){if(t.shape.slice(-1)[0]!==e.shape.slice(-2)[0])throw new l(`If rank y >= 3, then the second last dim of y must equal the last dim of x but got x shape = ${t.shape} and y shape = ${e.shape}`)}if(2===t.rank&&2===e.rank){const r=!1,a=!1;return i.fused.matMul({a:t,b:e,transposeA:r,transposeB:a,bias:s?dt(t.rank,s,"channelsLast"):null,activation:n})}{const r=t.shape.slice(),a=r.pop();t=i.reshape(t,[-1,a]);const o=e.shape.slice(),l=o.pop(),u=o.pop(),h=[...o,l],c=Array.from({length:e.rank},((t,n)=>0===n?e.rank-2:n<=e.rank-2?n-1:n));e=i.reshape(i.transpose(e,c),[u,-1]);const p=[...r,...h],d=!1,f=!1;return i.reshape(i.fused.matMul({a:t,b:e,transposeA:d,transposeB:f,bias:s?dt(t.rank,s,"channelsLast"):null,activation:n}),p)}}function ct(t,n,s){return e.tidy((()=>(n=Array.isArray(n)?e.tensor1d(n,"int32"):i.cast(n,"int32"),i.gather(t,n,s))))}function pt(t){return i.mul(t,t)}function dt(t,e,n){const s=e.shape;if(1!==e.rank&&e.rank!==t)throw new o(`Unexpected bias dimensions: ${e.rank}; expected it to be 1 or ${t}`);if(5===t){if("channelsFirst"===n)return 1===s.length?i.reshape(e,[1,s[0],1,1,1]):i.reshape(e,[1,s[3],s[0],s[1],s[2]]);if("channelsLast"===n)return 1===s.length?i.reshape(e,[1,1,1,1,s[0]]):i.reshape(e,[1].concat(s))}else if(4===t){if("channelsFirst"===n)return 1===s.length?i.reshape(e,[1,s[0],1,1]):i.reshape(e,[1,s[2],s[0],s[1]]);if("channelsLast"===n)return 1===s.length?i.reshape(e,[1,1,1,s[0]]):i.reshape(e,[1].concat(s))}else if(3===t){if("channelsFirst"===n)return 1===s.length?i.reshape(e,[1,s[0],1]):i.reshape(e,[1,s[1],s[0]]);if("channelsLast"===n)return 1===s.length?i.reshape(e,[1,1,s[0]]):i.reshape(e,[1].concat(s))}else if(t<3)return e;throw new o(`Unsupported input rank by biasAdd: ${e.rank}`)}function ft(t,n,s){return e.tidy((()=>(null==s&&(s="channelsLast"),P(s),i.add(t,dt(t.rank,n,s)))))}function gt(t,n,s,r){return e.tidy((()=>i.dropout(t,n,s,r)))}function mt(t,e,n=!1){return n?t():e()}const yt=["fanIn","fanOut","fanAvg"],bt=["normal","uniform","truncatedNormal"];class wt extends e.serialization.Serializable{fromConfigUsesCustomObjects(){return!1}getConfig(){return{}}}class kt extends wt{apply(t,n){return e.zeros(t,n)}}kt.className="Zeros",e.serialization.registerClass(kt);class vt extends wt{apply(t,n){return e.ones(t,n)}}vt.className="Ones",e.serialization.registerClass(vt);class St extends wt{constructor(t){if(super(),"object"!=typeof t)throw new o(`Expected argument of type ConstantConfig but got ${t}`);if(void 0===t.value)throw new o(`config must have value set but got ${t}`);this.value=t.value}apply(t,n){return e.tidy((()=>e.mul(e.scalar(this.value),e.ones(t,n))))}getConfig(){return{value:this.value}}}St.className="Constant",e.serialization.registerClass(St);class xt extends wt{constructor(t){super(),this.DEFAULT_MINVAL=-.05,this.DEFAULT_MAXVAL=.05,this.minval=t.minval||this.DEFAULT_MINVAL,this.maxval=t.maxval||this.DEFAULT_MAXVAL,this.seed=t.seed}apply(t,n){return e.randomUniform(t,this.minval,this.maxval,n,this.seed)}getConfig(){return{minval:this.minval,maxval:this.maxval,seed:this.seed}}}xt.className="RandomUniform",e.serialization.registerClass(xt);class Nt extends wt{constructor(t){super(),this.DEFAULT_MEAN=0,this.DEFAULT_STDDEV=.05,this.mean=t.mean||this.DEFAULT_MEAN,this.stddev=t.stddev||this.DEFAULT_STDDEV,this.seed=t.seed}apply(t,e){if("float32"!==(e=e||"float32")&&"int32"!==e)throw new l(`randomNormal does not support dType ${e}.`);return ut(t,this.mean,this.stddev,e,this.seed)}getConfig(){return{mean:this.mean,stddev:this.stddev,seed:this.seed}}}Nt.className="RandomNormal",e.serialization.registerClass(Nt);class It extends wt{constructor(t){super(),this.DEFAULT_MEAN=0,this.DEFAULT_STDDEV=.05,this.mean=t.mean||this.DEFAULT_MEAN,this.stddev=t.stddev||this.DEFAULT_STDDEV,this.seed=t.seed}apply(t,n){if("float32"!==(n=n||"float32")&&"int32"!==n)throw new l(`truncatedNormal does not support dType ${n}.`);return e.truncatedNormal(t,this.mean,this.stddev,n,this.seed)}getConfig(){return{mean:this.mean,stddev:this.stddev,seed:this.seed}}}It.className="TruncatedNormal",e.serialization.registerClass(It);let At=class extends wt{constructor(t){super(),this.gain=null!=t.gain?t.gain:1}apply(t,n){return e.tidy((()=>{if(2!==t.length||t[0]!==t[1])throw new o("Identity matrix initializer can only be used for 2D square matrices.");return e.mul(this.gain,e.eye(t[0]))}))}getConfig(){return{gain:this.gain}}};At.className="Identity",e.serialization.registerClass(At);class zt extends wt{constructor(t){if(super(),t.scale<0)throw new o(`scale must be a positive float. Got: ${t.scale}`);var e;this.scale=null==t.scale?1:t.scale,this.mode=null==t.mode?"fanIn":t.mode,e=this.mode,I(yt,"FanMode",e),this.distribution=null==t.distribution?"normal":t.distribution,function(t){I(bt,"Distribution",t)}(this.distribution),this.seed=t.seed}apply(t,n){const s=function(t,e="channelsLast"){let n,s;if(P(e),2===t.length)n=t[0],s=t[1];else if(-1!==[3,4,5].indexOf(t.length)){if("channelsFirst"===e){const e=J(t,2);n=t[1]*e,s=t[0]*e}else if("channelsLast"===e){const e=J(t,0,t.length-2);n=t[t.length-2]*e,s=t[t.length-1]*e}}else{const e=J(t);n=Math.sqrt(e),s=Math.sqrt(e)}return[n,s]}(t),i=s[0],r=s[1];let a=this.scale;if("fanIn"===this.mode?a/=Math.max(1,i):"fanOut"===this.mode?a/=Math.max(1,r):a/=Math.max(1,(i+r)/2),"normal"===this.distribution){const s=Math.sqrt(a);if("float32"!==(n=n||"float32")&&"int32"!==n)throw new l(`${this.getClassName()} does not support dType ${n}.`);return e.truncatedNormal(t,0,s,n,this.seed)}{const s=Math.sqrt(3*a);return e.randomUniform(t,-s,s,n,this.seed)}}getConfig(){return{scale:this.scale,mode:this.mode,distribution:this.distribution,seed:this.seed}}}zt.className="VarianceScaling",e.serialization.registerClass(zt);class Et extends zt{constructor(t){super({scale:1,mode:"fanAvg",distribution:"uniform",seed:null==t?null:t.seed})}getClassName(){return zt.className}}Et.className="GlorotUniform",e.serialization.registerClass(Et);class Tt extends zt{constructor(t){super({scale:1,mode:"fanAvg",distribution:"normal",seed:null==t?null:t.seed})}getClassName(){return zt.className}}Tt.className="GlorotNormal",e.serialization.registerClass(Tt);class Ct extends zt{constructor(t){super({scale:2,mode:"fanIn",distribution:"normal",seed:null==t?null:t.seed})}getClassName(){return zt.className}}Ct.className="HeNormal",e.serialization.registerClass(Ct);class $t extends zt{constructor(t){super({scale:2,mode:"fanIn",distribution:"uniform",seed:null==t?null:t.seed})}getClassName(){return zt.className}}$t.className="HeUniform",e.serialization.registerClass($t);class Ft extends zt{constructor(t){super({scale:1,mode:"fanIn",distribution:"normal",seed:null==t?null:t.seed})}getClassName(){return zt.className}}Ft.className="LeCunNormal",e.serialization.registerClass(Ft);class Dt extends zt{constructor(t){super({scale:1,mode:"fanIn",distribution:"uniform",seed:null==t?null:t.seed})}getClassName(){return zt.className}}Dt.className="LeCunUniform",e.serialization.registerClass(Dt);class Lt extends wt{constructor(t){super(),this.DEFAULT_GAIN=1,this.ELEMENTS_WARN_SLOW=2e3,this.gain=null==t.gain?this.DEFAULT_GAIN:t.gain,this.seed=t.seed}apply(t,n){return e.tidy((()=>{if(t.length<2)throw new l("Shape must be at least 2D.");if("int32"!==n&&"float32"!==n&&void 0!==n)throw new TypeError(`Unsupported data type ${n}.`);const s=e.util.sizeFromShape(t.slice(0,-1)),i=t[t.length-1],r=s*i;r>this.ELEMENTS_WARN_SLOW&&console.warn(`Orthogonal initializer is being called on a matrix with more than ${this.ELEMENTS_WARN_SLOW} (${r}) elements: Slowness may result.`);const a=ut([Math.max(i,s),Math.min(i,s)],0,1,n,this.seed),o=e.linalg.qr(a,!1);let u=o[0];const h=o[1].flatten().stridedSlice([0],[Math.min(i,s)*Math.min(i,s)],[Math.min(i,s)+1]);return u=e.mul(u,h.sign()),st*e));return e}const qt="Variable";class Vt{constructor(t,e="float32",n="Variable",s=!0,r=null){this.dtype=null==e?"float32":e,this.shape=t.shape,this.id=$(),n=null==n?qt:n,this.originalName=V(n),this.name=K(this.originalName),this.trainable_=s,this.constraint=r,this.val=i.variable(t,this.trainable_,this.name,this.dtype)}read(){return this.assertNotDisposed(),this.val}write(t){return this.assertNotDisposed(),function(t,e){if(t.shape.toString()!==e.shape.toString())throw new Error("Shape mismatch: "+JSON.stringify(t.shape)+" vs. "+JSON.stringify(e.shape))}(this.val,t),this.val.id!==t.id&&(this.val.assign(t),null!=this.constraint&&this.val.assign(this.constraint.apply(this.val))),this}dispose(){this.assertNotDisposed(),this.val.dispose()}assertNotDisposed(){if(this.val.isDisposed)throw new Error(`LayersVariable ${this.name} is already disposed.`)}get trainable(){return this.trainable_}set trainable(t){this.trainable_=t,this.val.trainable=t}}function Kt(t){return t.map((t=>t.read()))}function Gt(t){t.forEach((t=>{t[0].write(t[1])}))}class Ht{constructor(t){this.dtype=t.dtype,this.shape=t.shape,null!=t.shape?this.ndim=t.shape.length:this.ndim=t.ndim,this.maxNDim=t.maxNDim,this.minNDim=t.minNDim,this.axes=t.axes||{}}}class Jt{constructor(t,e,n,s,i,r,a){this.dtype=t,this.shape=e,this.sourceLayer=n,this.inputs=s,this.callArgs=i,this.outputTensorIndex=a,this.id=$(),null!=r&&(this.originalName=V(r),this.name=K(this.originalName)),this.rank=e.length}}let Zt=0;class Yt{constructor(t,e){this.callArgs=e,this.id=Zt++,this.outboundLayer=t.outboundLayer,this.inboundLayers=t.inboundLayers,this.nodeIndices=t.nodeIndices,this.tensorIndices=t.tensorIndices,this.inputTensors=t.inputTensors,this.outputTensors=t.outputTensors,this.inputMasks=t.inputMasks,this.outputMasks=t.outputMasks,this.inputShapes=t.inputShapes,this.outputShapes=t.outputShapes;for(const e of t.inboundLayers)null!=e&&e.outboundNodes.push(this);t.outboundLayer.inboundNodes.push(this)}getConfig(){const t=[];for(const e of this.inboundLayers)null!=e?t.push(e.name):t.push(null);return{outboundLayer:this.outboundLayer?this.outboundLayer.name:null,inboundLayers:t,nodeIndices:this.nodeIndices,tensorIndices:this.tensorIndices}}}let Xt=0;class Qt extends e.serialization.Serializable{constructor(t={}){super(),this._callHook=null,this._addedWeightNames=[],this._stateful=!1,this.id=Xt++,this.activityRegularizer=null,this.inputSpec=null,this.supportsMasking=!1,this._trainableWeights=[],this._nonTrainableWeights=[],this._losses=[],this._updates=[],this._built=!1,this.inboundNodes=[],this.outboundNodes=[];let e=t.name;if(!e){const t=this.getClassName();e=m(t)+"_"+D(t)}if(this.name=e,this.trainable_=null==t.trainable||t.trainable,null!=t.inputShape||null!=t.batchInputShape){let e;if(null!=t.batchInputShape)e=t.batchInputShape;else if(null!=t.inputShape){let n=null;null!=t.batchSize&&(n=t.batchSize),e=[n].concat(t.inputShape)}this.batchInputShape=e;let n=t.dtype;null==n&&(n=t.inputDType),null==n&&(n="float32"),this.dtype=n}null!=t.weights?this.initialWeights=t.weights:this.initialWeights=null,this._refCount=null,this.fastWeightInitDuringBuild=!1}static nodeKey(t,e){return t.name+"_ib-"+e.toString()}getNodeAtIndex(t,e){if(0===this.inboundNodes.length)throw new a(`The layer has never been called and thus has no defined ${e}.`);if(this.inboundNodes.length<=t)throw new o(`Asked to get ${e} at node ${t}, but the layer has only ${this.inboundNodes.length} inbound nodes.`);return this.inboundNodes[t]}getInputAt(t){return f(this.getNodeAtIndex(t,"input").inputTensors)}getOutputAt(t){return f(this.getNodeAtIndex(t,"output").outputTensors)}get input(){if(this.inboundNodes.length>1)throw new r(`Layer ${this.name} has multiple inbound nodes, hence the notion of "layer input" is ill-defined. Use \`getInputAt(nodeIndex)\` instead.`);if(0===this.inboundNodes.length)throw new r(`Layer ${this.name} is not connected, no input to return.`);return f(this.getNodeAtIndex(0,"input").inputTensors)}get output(){if(0===this.inboundNodes.length)throw new r(`Layer ${this.name} has no inbound nodes.`);if(this.inboundNodes.length>1)throw new r(`Layer ${this.name} has multiple inbound nodes, hence the notion of "layer output" is ill-defined. Use \`getOutputAt(nodeIndex)\` instead.`);return f(this.getNodeAtIndex(0,"output").outputTensors)}get losses(){return this._losses}calculateLosses(){return this.losses.map((t=>t()))}get updates(){return this._updates}get built(){return this._built}set built(t){this._built=t}get trainable(){return this.trainable_}set trainable(t){this._trainableWeights.forEach((e=>e.trainable=t)),this.trainable_=t}get trainableWeights(){return this.trainable_?this._trainableWeights.filter((t=>t.trainable)):[]}set trainableWeights(t){this._trainableWeights=t}get nonTrainableWeights(){return this.trainable?this._trainableWeights.filter((t=>!t.trainable)).concat(this._nonTrainableWeights):this._trainableWeights.concat(this._nonTrainableWeights)}set nonTrainableWeights(t){this._nonTrainableWeights=t}get weights(){return this.trainableWeights.concat(this.nonTrainableWeights)}get stateful(){return this._stateful}resetStates(){if(!this.stateful)throw new Error("Cannot call the resetStates() method of a non-stateful Layer object.")}assertInputCompatibility(t){const e=g(t);if(null==this.inputSpec||0===this.inputSpec.length)return;const n=g(this.inputSpec);if(e.length!==n.length)throw new o(`Layer ${this.name} expects ${n.length} inputs, but it received ${e.length} input tensors. Input received: ${t}`);for(let t=0;ti.maxNDim)throw new o(`Input ${t} is incompatible with layer ${this.name}: expected max_ndim=${i.maxNDim}, found ndim=${r}`);if(null!=i.minNDim&&r=0?e[s]:e[e.length+s];if(null!=r&&-1===[r,null].indexOf(a))throw new o(`Input ${t} is incompatible with layer ${this.name}: expected axis ${s} of input shape to have value ${r} but got shape ${e}.`)}}if(null!=i.shape)for(let e=0;e{if(!this.built){this.assertInputCompatibility(t);const e=[];for(const n of g(t))e.push(n.shape);this.build(f(e)),this.built=!0,this.initialWeights&&this.setWeights(this.initialWeights),null===this._refCount&&i&&(this._refCount=1)}if(this.assertInputCompatibility(t),i){let s=this.call(t,e);this.supportsMasking&&this.setMaskMetadata(t,s);const i=g(s),r=[];for(let t of i)-1!==n.indexOf(t)&&(t=t.clone()),r.push(t);if(s=f(r),null!=this.activityRegularizer)throw new l("Layer invocation in the presence of activity regularizer(s) is not supported yet.");return s}{const n=function(t){t=g(t);const e=[];for(const n of t)e.push(n.shape);return f(e)}(t),s=this.computeOutputShape(n);let i;const r="float32";if(this.warnOnIncompatibleInputShape(Array.isArray(t)?n[0]:n),i=null!=s&&s.length>0&&Array.isArray(s[0])?s.map(((n,s)=>new Jt(r,n,this,g(t),e,this.name,s))):new Jt(r,s,this,g(t),e,this.name),this.addInboundNode(t,i,null,null,n,s,e),this._refCount++,null!=this.activityRegularizer)throw new l("Layer invocation in the presence of activity regularizer(s) is not supported yet.");return i}}))}warnOnIncompatibleInputShape(t){if(null!=this.batchInputShape)if(t.length!==this.batchInputShape.length)console.warn(`The rank of the input tensor provided (shape: ${JSON.stringify(t)}) does not match that of the batchInputShape (${JSON.stringify(this.batchInputShape)}) of the layer ${this.name}`);else{let e=!1;this.batchInputShape.forEach(((n,s)=>{null!=n&&null!=t[s]&&t[s]!==n&&(e=!0)})),e&&console.warn(`The shape of the input tensor (${JSON.stringify(t)}) does not match the expectation of layer ${this.name}: ${JSON.stringify(this.batchInputShape)}`)}}get outputShape(){if(null==this.inboundNodes||0===this.inboundNodes.length)throw new r(`The layer ${this.name} has never been called and thus has no defined output shape.`);const t=[];for(const e of this.inboundNodes){const n=JSON.stringify(e.outputShapes);-1===t.indexOf(n)&&t.push(n)}if(1===t.length){const t=this.inboundNodes[0].outputShapes;return Array.isArray(t)&&Array.isArray(t[0])&&1===t.length?t[0]:t}throw new r(`The layer ${this.name} has multiple inbound nodes with different output shapes. Hence the notion of "output shape" is ill-defined for the layer.`)}countParams(){if(!this.built)throw new a(`You tried to call countParams() on ${this.name}, but the layer is not built yet. Build it first by calling build(batchInputShape).`);return jt(this.weights)}build(t){this.built=!0}getWeights(t=!1){return Kt(t?this.trainableWeights:this.weights)}setWeights(t){e.tidy((()=>{const n=this.weights;if(n.length!==t.length)throw new o(`You called setWeights(weights) on layer "${this.name}" with a weight list of length ${t.length}, but the layer was expecting ${n.length} weights. Provided weights: ${t}...`);if(0===n.length)return;const s=[],i=Kt(n);for(let r=0;ri.apply(h.read()))),null==r&&(r=!0),r?this._trainableWeights.push(h):this._nonTrainableWeights.push(h),h}setFastWeightInitDuringBuild(t){this.fastWeightInitDuringBuild=t}addLoss(t){null==t||Array.isArray(t)&&0===t.length||(t=g(t),void 0!==this._losses&&null!==this._losses&&this.losses.push(...t))}computeOutputShape(t){return t}computeMask(t,e){if(!this.supportsMasking){if(null!=e){if(!Array.isArray(e))throw new TypeError(`Layer ${this.name} does not support masking, but was passed an inputMask.`);e.forEach((t=>{if(null!=t)throw new TypeError(`Layer ${this.name} does not support masking, but was passed an inputMask.`)}))}return null}return e}setMaskMetadata(t,e,n){if(!this.supportsMasking)return;const s=this.computeMask(t,n),i=g(e),r=g(s);if(i.length!==r.length)throw new Error(`${this.name} outputs ${i.length} tensors but ${i.length} masks for those tensors`);for(let t=0;tt.dispose())),this.weights.length}assertNotDisposed(){if(0===this._refCount)throw new Error(`Layer '${this.name}' is already disposed.`)}dispose(){if(!this.built)throw new Error(`Cannot dispose Layer ${this.name} because it has not been built yet.`);if(null===this._refCount)throw new Error(`Cannot dispose Layer ${this.name} because it has not been used yet.`);this.assertNotDisposed();let t=0;return 0==--this._refCount&&(t=this.disposeWeights()),{refCountAfterDispose:this._refCount,numDisposedVariables:t}}}function te(t,e,n){if((null==e||null!=n&&n>0)&&(e=t.sourceLayer,n=t.nodeIndex),0===e.inboundNodes.length)return[t];{const t=e.inboundNodes[n];if(0===t.inboundLayers.length)return t.inputTensors;{const e=[];for(let n=0;nt.name)),u=[],h=n.names();for(const t of l)-1!==h.indexOf(t)?u.push(n.getValue(t)):u.push(null);null!=i&&(i.maxNumTensors=-1/0,i.minNumTensors=1/0);const c=l.join(",")+"|"+n.names().sort().join(",");let p,d=ie.get(c);if(null==d){const t=function(t,n){e.util.assert(null!=t&&t.length>0,(()=>"Expected at least one fetch, got none"));let s=[],i={};if(1===t.length){const e=le(t[0],n);s=e.sorted,i=e.recipientMap}else{const e=new Set;for(const r of t){const{sorted:t,recipientMap:a}=le(r,n);for(const n of t)e.has(n.name)||(s.push(n),e.add(n.name));for(const t in a)null==i[t]&&(i[t]=new Set),a[t].forEach((e=>i[t].add(e)))}}return{sorted:s,recipientCounts:oe(i)}}(o,n);d=t.sorted,p=t.recipientCounts,ie.put(c,d),re.put(c,p)}p={},r||Object.assign(p,re.get(c));const f=new se(n);for(let t=0;ti.maxNumTensors&&(i.maxNumTensors=t),t0;){const t=r[r.length-1];if(n.has(t.name)){r.pop();continue}const e=a[a.length-1]===r.length-1;if(0===t.inputs.length||e)r.pop(),s.push(t),n.add(t.name),e&&a.pop();else{a.push(r.length-1);for(const e of t.inputs)null==i[e.name]&&(i[e.name]=new Set),i[e.name].add(t.name),n.has(e.name)||r.push(e)}}return{sorted:s,recipientMap:i}}function ue(t){let e;if(1===t.sourceLayer.inboundNodes.length)e=t.sourceLayer.output;else{let n=null;for(let e=0;e100),(function(t){null!=ie&&ie.setMaxEntries(t),null!=re&&re.setMaxEntries(t)}));const he="Add",ce="BatchMatMul",pe="BatchToSpaceND",de="Cast",fe="Concat",ge="Conv2D",me="Conv2DBackpropInput",ye="Cosh",be="Cumsum",we="RealDiv",ke="ExpandDims",ve="Floor",Se="FloorDiv",xe="GatherV2",Ne="GreaterEqual",Ie="Identity",Ae="Maximum",ze="Multiply",Ee="Pack",Te="PadV2",Ce="Reshape",$e="Reverse",Fe="Rsqrt",De="Select",Le="Slice",_e="Sinh",Re="Sigmoid",Oe="Sqrt",Me="SpaceToBatchND",Be="SplitV",Pe="Tile",Ue="Transpose",We="Unpack",je="UnsortedSegmentSum",qe="ZerosLike",Ve="Step";function Ke(t){throw new Error(`'${t}' not yet implemented or not found in the registry. This kernel may not be supported by the tfjs backend you have chosen`)}function Ge(t,e){if(!t)throw new Error("string"==typeof e?e:e())}function He(t){if(0===t.length)return 1;let e=t[0];for(let n=1;ne)):[].concat(t)).every((t=>t>=-n&&t`All values in axis param must be in range [-${n}, ${n}) but got axis ${t}`)),Ge(t.every((t=>Ze(t))),(()=>`All values in axis param must be integers but got axis ${t}`)),t.map((t=>t<0?n+t:t))}function Qe(t){if("float32"===t||"int32"===t)return 4;if("complex64"===t)return 8;if("bool"===t)return 1;throw new Error(`Unknown dtype ${t}`)}function tn(t){return"string"==typeof t||t instanceof String}function en(t){return Array.isArray(t)?en(t[0]):t instanceof Float32Array?"float32":t instanceof Int32Array||t instanceof Uint8Array||t instanceof Uint8ClampedArray?"int32":"number"==typeof t?"float32":tn(t)?"string":function(t){return"boolean"==typeof t}(t)?"bool":"float32"}function nn(t){return!!(t&&t.constructor&&t.call&&t.apply)}function sn(t){const e=t.length;if(e<2)return[];const n=new Array(e-1);n[e-2]=t[e-1];for(let s=e-3;s>=0;--s)n[s]=n[s+1]*t[s+1];return n}function rn(t,e,n,s=!1){const i=new Array;if(1===e.length){const r=e[0]*(s?2:1);for(let e=0;et*e))*(s?2:1);for(let e=0;et*e))*(n?2:1);if(0===s)return[];if(s!==e.length)throw new Error(`[${t}] does not match the input size ${e.length}${n?" for a complex tensor":""}.`);return rn(0,t,e,n)}function on(t,e){const n=ln(t,e);for(let t=0;t{Ge(Number.isInteger(e)&&e>=0,(()=>`Tensor must have a shape comprised of positive integers but got shape [${t}].`))}))}function hn(t){return t&&t.then&&"function"==typeof t.then}const cn="tfjsflags";class pn{constructor(t){this.global=t,this.flags={},this.flagRegistry={},this.urlFlags={},this.getQueryParams=dn,this.populateURLFlags()}setPlatform(t,e){null!=this.platform&&(fn().getBool("IS_TEST")||fn().getBool("PROD")||console.warn(`Platform ${this.platformName} has already been set. Overwriting the platform with ${t}.`)),this.platformName=t,this.platform=e}registerFlag(t,e,n){if(this.flagRegistry[t]={evaluationFn:e,setHook:n},null!=this.urlFlags[t]){const e=this.urlFlags[t];fn().getBool("IS_TEST")||fn().getBool("PROD")||console.warn(`Setting feature override from URL ${t}: ${e}.`),this.set(t,e)}}async getAsync(t){return t in this.flags||(this.flags[t]=await this.evaluateFlag(t)),this.flags[t]}get(t){if(t in this.flags)return this.flags[t];const e=this.evaluateFlag(t);if(hn(e))throw new Error(`Flag ${t} cannot be synchronously evaluated. Please use getAsync() instead.`);return this.flags[t]=e,this.flags[t]}getNumber(t){return this.get(t)}getBool(t){return this.get(t)}getString(t){return this.get(t)}getFlags(){return this.flags}get features(){return this.flags}set(t,e){if(null==this.flagRegistry[t])throw new Error(`Cannot set flag ${t} as it has not been registered.`);this.flags[t]=e,null!=this.flagRegistry[t].setHook&&this.flagRegistry[t].setHook(e)}evaluateFlag(t){if(null==this.flagRegistry[t])throw new Error(`Cannot evaluate flag '${t}': no evaluation function found.`);return this.flagRegistry[t].evaluationFn()}setFlags(t){this.flags=Object.assign({},t)}reset(){this.flags={},this.urlFlags={},this.populateURLFlags()}populateURLFlags(){if("undefined"==typeof this.global||"undefined"==typeof this.global.location||"undefined"==typeof this.global.location.search)return;const t=this.getQueryParams(this.global.location.search);if(cn in t){t.tfjsflags.split(",").forEach((t=>{const[e,n]=t.split(":");this.urlFlags[e]=function(t,e){const n=e.toLowerCase();return"true"===n||"false"===n?"true"===n:""+ +n===n?+n:e}(0,n)}))}}}function dn(t){const e={};return t.replace(/[?&]([^=?&]+)(?:=([^&]*))?/g,((t,...n)=>(function(t,e,n){t[decodeURIComponent(e)]=decodeURIComponent(n||"")}(e,n[0],n[1]),n.join("=")))),e}function fn(){return mn}let gn,mn=null;function yn(){if(null==gn){let t;if("undefined"!=typeof window)t=window;else if("undefined"!=typeof global)t=global;else if("undefined"!=typeof process)t=process;else{if("undefined"==typeof self)throw new Error("Could not find a global object");t=self}gn=t}return gn}function bn(t,e){const n=function(){const t=yn();return null==t._tfGlobals&&(t._tfGlobals=new Map),t._tfGlobals}();if(n.has(t))return n.get(t);{const s=e();return n.set(t,s),n.get(t)}}function wn(...t){fn().getBool("IS_TEST")||fn().getBool("PROD")||console.warn(...t)}const kn=bn("kernelRegistry",(()=>new Map)),vn=bn("gradRegistry",(()=>new Map));function Sn(t,e){const n=function(t,e){return`${e}_${t}`}(t,e);return kn.get(n)}function xn(t){return vn.get(t)}function Nn(t){const e=kn.entries(),n=[];for(;;){const{done:s,value:i}=e.next();if(s)break;const[r,a]=i,[o]=r.split("_");o===t&&n.push(a)}return n}function In(t){const{kernelName:e}=t;vn.has(e)&&fn().getBool("DEBUG")&&wn(`Overriding the gradient for '${e}'`),vn.set(e,t)}var An="undefined"!=typeof globalThis?globalThis:"undefined"!=typeof window?window:"undefined"!=typeof global?global:"undefined"!=typeof self?self:{};function zn(t){return t&&t.__esModule&&Object.prototype.hasOwnProperty.call(t,"default")?t.default:t}function En(t){if(t.__esModule)return t;var e=t.default;if("function"==typeof e){var n=function t(){if(this instanceof t){var n=[null];n.push.apply(n,arguments);var s=Function.bind.apply(e,n);return new s}return e.apply(this,arguments)};n.prototype=e.prototype}else n={};return Object.defineProperty(n,"__esModule",{value:!0}),Object.keys(t).forEach((function(e){var s=Object.getOwnPropertyDescriptor(t,e);Object.defineProperty(n,e,s.get?s:{enumerable:!0,get:function(){return t[e]}})})),n}var Tn=$n,Cn=null;try{Cn=new WebAssembly.Instance(new WebAssembly.Module(new Uint8Array([0,97,115,109,1,0,0,0,1,13,2,96,0,1,127,96,4,127,127,127,127,1,127,3,7,6,0,1,1,1,1,1,6,6,1,127,1,65,0,11,7,50,6,3,109,117,108,0,1,5,100,105,118,95,115,0,2,5,100,105,118,95,117,0,3,5,114,101,109,95,115,0,4,5,114,101,109,95,117,0,5,8,103,101,116,95,104,105,103,104,0,0,10,191,1,6,4,0,35,0,11,36,1,1,126,32,0,173,32,1,173,66,32,134,132,32,2,173,32,3,173,66,32,134,132,126,34,4,66,32,135,167,36,0,32,4,167,11,36,1,1,126,32,0,173,32,1,173,66,32,134,132,32,2,173,32,3,173,66,32,134,132,127,34,4,66,32,135,167,36,0,32,4,167,11,36,1,1,126,32,0,173,32,1,173,66,32,134,132,32,2,173,32,3,173,66,32,134,132,128,34,4,66,32,135,167,36,0,32,4,167,11,36,1,1,126,32,0,173,32,1,173,66,32,134,132,32,2,173,32,3,173,66,32,134,132,129,34,4,66,32,135,167,36,0,32,4,167,11,36,1,1,126,32,0,173,32,1,173,66,32,134,132,32,2,173,32,3,173,66,32,134,132,130,34,4,66,32,135,167,36,0,32,4,167,11])),{}).exports}catch(t){}function $n(t,e,n){this.low=0|t,this.high=0|e,this.unsigned=!!n}function Fn(t){return!0===(t&&t.__isLong__)}$n.prototype.__isLong__,Object.defineProperty($n.prototype,"__isLong__",{value:!0}),$n.isLong=Fn;var Dn={},Ln={};function _n(t,e){var n,s,i;return e?(i=0<=(t>>>=0)&&t<256)&&(s=Ln[t])?s:(n=On(t,(0|t)<0?-1:0,!0),i&&(Ln[t]=n),n):(i=-128<=(t|=0)&&t<128)&&(s=Dn[t])?s:(n=On(t,t<0?-1:0,!1),i&&(Dn[t]=n),n)}function Rn(t,e){if(isNaN(t))return e?Kn:Vn;if(e){if(t<0)return Kn;if(t>=Wn)return Yn}else{if(t<=-jn)return Xn;if(t+1>=jn)return Zn}return t<0?Rn(-t,e).neg():On(t%Un|0,t/Un|0,e)}function On(t,e,n){return new $n(t,e,n)}$n.fromInt=_n,$n.fromNumber=Rn,$n.fromBits=On;var Mn=Math.pow;function Bn(t,e,n){if(0===t.length)throw Error("empty string");if("NaN"===t||"Infinity"===t||"+Infinity"===t||"-Infinity"===t)return Vn;if("number"==typeof e?(n=e,e=!1):e=!!e,(n=n||10)<2||360)throw Error("interior hyphen");if(0===s)return Bn(t.substring(1),e,n).neg();for(var i=Rn(Mn(n,8)),r=Vn,a=0;a>>0:this.low},Qn.toNumber=function(){return this.unsigned?(this.high>>>0)*Un+(this.low>>>0):this.high*Un+(this.low>>>0)},Qn.toString=function(t){if((t=t||10)<2||36>>0).toString(t);if((r=o).isZero())return l+a;for(;l.length<6;)l="0"+l;a=""+l+a}},Qn.getHighBits=function(){return this.high},Qn.getHighBitsUnsigned=function(){return this.high>>>0},Qn.getLowBits=function(){return this.low},Qn.getLowBitsUnsigned=function(){return this.low>>>0},Qn.getNumBitsAbs=function(){if(this.isNegative())return this.eq(Xn)?64:this.neg().getNumBitsAbs();for(var t=0!=this.high?this.high:this.low,e=31;e>0&&0==(t&1<=0},Qn.isOdd=function(){return 1==(1&this.low)},Qn.isEven=function(){return 0==(1&this.low)},Qn.equals=function(t){return Fn(t)||(t=Pn(t)),(this.unsigned===t.unsigned||this.high>>>31!=1||t.high>>>31!=1)&&(this.high===t.high&&this.low===t.low)},Qn.eq=Qn.equals,Qn.notEquals=function(t){return!this.eq(t)},Qn.neq=Qn.notEquals,Qn.ne=Qn.notEquals,Qn.lessThan=function(t){return this.comp(t)<0},Qn.lt=Qn.lessThan,Qn.lessThanOrEqual=function(t){return this.comp(t)<=0},Qn.lte=Qn.lessThanOrEqual,Qn.le=Qn.lessThanOrEqual,Qn.greaterThan=function(t){return this.comp(t)>0},Qn.gt=Qn.greaterThan,Qn.greaterThanOrEqual=function(t){return this.comp(t)>=0},Qn.gte=Qn.greaterThanOrEqual,Qn.ge=Qn.greaterThanOrEqual,Qn.compare=function(t){if(Fn(t)||(t=Pn(t)),this.eq(t))return 0;var e=this.isNegative(),n=t.isNegative();return e&&!n?-1:!e&&n?1:this.unsigned?t.high>>>0>this.high>>>0||t.high===this.high&&t.low>>>0>this.low>>>0?-1:1:this.sub(t).isNegative()?-1:1},Qn.comp=Qn.compare,Qn.negate=function(){return!this.unsigned&&this.eq(Xn)?Xn:this.not().add(Gn)},Qn.neg=Qn.negate,Qn.add=function(t){Fn(t)||(t=Pn(t));var e=this.high>>>16,n=65535&this.high,s=this.low>>>16,i=65535&this.low,r=t.high>>>16,a=65535&t.high,o=t.low>>>16,l=0,u=0,h=0,c=0;return h+=(c+=i+(65535&t.low))>>>16,u+=(h+=s+o)>>>16,l+=(u+=n+a)>>>16,l+=e+r,On((h&=65535)<<16|(c&=65535),(l&=65535)<<16|(u&=65535),this.unsigned)},Qn.subtract=function(t){return Fn(t)||(t=Pn(t)),this.add(t.neg())},Qn.sub=Qn.subtract,Qn.multiply=function(t){if(this.isZero())return Vn;if(Fn(t)||(t=Pn(t)),Cn)return On(Cn.mul(this.low,this.high,t.low,t.high),Cn.get_high(),this.unsigned);if(t.isZero())return Vn;if(this.eq(Xn))return t.isOdd()?Xn:Vn;if(t.eq(Xn))return this.isOdd()?Xn:Vn;if(this.isNegative())return t.isNegative()?this.neg().mul(t.neg()):this.neg().mul(t).neg();if(t.isNegative())return this.mul(t.neg()).neg();if(this.lt(qn)&&t.lt(qn))return Rn(this.toNumber()*t.toNumber(),this.unsigned);var e=this.high>>>16,n=65535&this.high,s=this.low>>>16,i=65535&this.low,r=t.high>>>16,a=65535&t.high,o=t.low>>>16,l=65535&t.low,u=0,h=0,c=0,p=0;return c+=(p+=i*l)>>>16,h+=(c+=s*l)>>>16,c&=65535,h+=(c+=i*o)>>>16,u+=(h+=n*l)>>>16,h&=65535,u+=(h+=s*o)>>>16,h&=65535,u+=(h+=i*a)>>>16,u+=e*l+n*o+s*a+i*r,On((c&=65535)<<16|(p&=65535),(u&=65535)<<16|(h&=65535),this.unsigned)},Qn.mul=Qn.multiply,Qn.divide=function(t){if(Fn(t)||(t=Pn(t)),t.isZero())throw Error("division by zero");var e,n,s;if(Cn)return this.unsigned||-2147483648!==this.high||-1!==t.low||-1!==t.high?On((this.unsigned?Cn.div_u:Cn.div_s)(this.low,this.high,t.low,t.high),Cn.get_high(),this.unsigned):this;if(this.isZero())return this.unsigned?Kn:Vn;if(this.unsigned){if(t.unsigned||(t=t.toUnsigned()),t.gt(this))return Kn;if(t.gt(this.shru(1)))return Hn;s=Kn}else{if(this.eq(Xn))return t.eq(Gn)||t.eq(Jn)?Xn:t.eq(Xn)?Gn:(e=this.shr(1).div(t).shl(1)).eq(Vn)?t.isNegative()?Gn:Jn:(n=this.sub(t.mul(e)),s=e.add(n.div(t)));if(t.eq(Xn))return this.unsigned?Kn:Vn;if(this.isNegative())return t.isNegative()?this.neg().div(t.neg()):this.neg().div(t).neg();if(t.isNegative())return this.div(t.neg()).neg();s=Vn}for(n=this;n.gte(t);){e=Math.max(1,Math.floor(n.toNumber()/t.toNumber()));for(var i=Math.ceil(Math.log(e)/Math.LN2),r=i<=48?1:Mn(2,i-48),a=Rn(e),o=a.mul(t);o.isNegative()||o.gt(n);)o=(a=Rn(e-=r,this.unsigned)).mul(t);a.isZero()&&(a=Gn),s=s.add(a),n=n.sub(o)}return s},Qn.div=Qn.divide,Qn.modulo=function(t){return Fn(t)||(t=Pn(t)),Cn?On((this.unsigned?Cn.rem_u:Cn.rem_s)(this.low,this.high,t.low,t.high),Cn.get_high(),this.unsigned):this.sub(this.div(t).mul(t))},Qn.mod=Qn.modulo,Qn.rem=Qn.modulo,Qn.not=function(){return On(~this.low,~this.high,this.unsigned)},Qn.and=function(t){return Fn(t)||(t=Pn(t)),On(this.low&t.low,this.high&t.high,this.unsigned)},Qn.or=function(t){return Fn(t)||(t=Pn(t)),On(this.low|t.low,this.high|t.high,this.unsigned)},Qn.xor=function(t){return Fn(t)||(t=Pn(t)),On(this.low^t.low,this.high^t.high,this.unsigned)},Qn.shiftLeft=function(t){return Fn(t)&&(t=t.toInt()),0==(t&=63)?this:t<32?On(this.low<>>32-t,this.unsigned):On(0,this.low<>>t|this.high<<32-t,this.high>>t,this.unsigned):On(this.high>>t-32,this.high>=0?0:-1,this.unsigned)},Qn.shr=Qn.shiftRight,Qn.shiftRightUnsigned=function(t){if(Fn(t)&&(t=t.toInt()),0===(t&=63))return this;var e=this.high;return t<32?On(this.low>>>t|e<<32-t,e>>>t,this.unsigned):On(32===t?e:e>>>t-32,0,this.unsigned)},Qn.shru=Qn.shiftRightUnsigned,Qn.shr_u=Qn.shiftRightUnsigned,Qn.toSigned=function(){return this.unsigned?On(this.low,this.high,!1):this},Qn.toUnsigned=function(){return this.unsigned?this:On(this.low,this.high,!0)},Qn.toBytes=function(t){return t?this.toBytesLE():this.toBytesBE()},Qn.toBytesLE=function(){var t=this.high,e=this.low;return[255&e,e>>>8&255,e>>>16&255,e>>>24,255&t,t>>>8&255,t>>>16&255,t>>>24]},Qn.toBytesBE=function(){var t=this.high,e=this.low;return[t>>>24,t>>>16&255,t>>>8&255,255&t,e>>>24,e>>>16&255,e>>>8&255,255&e]},$n.fromBytes=function(t,e,n){return n?$n.fromBytesLE(t,e):$n.fromBytesBE(t,e)},$n.fromBytesLE=function(t,e){return new $n(t[0]|t[1]<<8|t[2]<<16|t[3]<<24,t[4]|t[5]<<8|t[6]<<16|t[7]<<24,e)},$n.fromBytesBE=function(t,e){return new $n(t[4]<<24|t[5]<<16|t[6]<<8|t[7],t[0]<<24|t[1]<<16|t[2]<<8|t[3],e)};var ts=zn(Tn);const es=ts||s({__proto__:null,default:ts},[Tn]);function ns(t){return es.fromString(t,!0,16)}function ss(t,e){if("string"===e)throw new Error("Cannot convert a string[] to a TypedArray");if(Array.isArray(t)&&(t=os(t)),fn().getBool("DEBUG")&&function(t,e){for(let n=0;n{s=n()};let r;const a=is();if(this.backendTimer.timerAvailable())r=this.backendTimer.time(i);else{i();for(const t of s)t.dataSync();r=Promise.resolve({kernelMs:is()-a})}if(fn().getBool("CHECK_COMPUTATION_FOR_ERRORS"))for(let e=0;e{us(e,n.dtype,t)}))}return{kernelName:t,outputs:s,inputs:e,timeMs:r.then((t=>t.kernelMs)),extraInfo:r.then((t=>null!=t.getExtraProfileInfo?t.getExtraProfileInfo():""))}}logKernelProfile(t){const{kernelName:e,outputs:n,timeMs:s,inputs:i,extraInfo:r}=t;n.forEach((t=>{Promise.all([t.data(),s,r]).then((n=>{this.logger.logKernelProfile(e,t,n[0],n[1],i,n[2])}))}))}}function us(t,e,n){if("float32"!==e)return!1;for(let e=0;e0?s:""} `}}console.log(`%c${o}\t%c${a}\t%c${l}D ${h}\t%c${u}\t%c${c}\t%c${r}`,"font-weight:bold","color:red","color:blue","color: orange","color: green","color: steelblue")}}function cs(t,e,n,s){const i=sn(e),r=function(t,e,n,s){const i=He(e),r=s[s.length-1],a=new Array(r).fill(0),o=e.length,l="complex64"===n?gs(t):t;if(o>1)for(let t=0;t" "+t)).join("\n")),l.join("\n")}function ps(t,e,n){let s;return s=Array.isArray(t)?`${parseFloat(t[0].toFixed(7))} + ${parseFloat(t[1].toFixed(7))}j`:tn(t)?`'${t}'`:"bool"===n?ds(t):parseFloat(t.toFixed(7)).toString(),Ye(s,e)}function ds(t){return 0===t?"false":"true"}function fs(t,e,n,s,i,r=!0){const a="complex64"===n?2:1,o=e[0],l=e.length;if(0===l){if("complex64"===n){return[ps(gs(t)[0],0,n)]}return"bool"===n?[ds(t[0])]:[t[0].toString()]}if(1===l){if(o>20){const e=3*a;let s=Array.from(t.slice(0,e)),r=Array.from(t.slice((o-3)*a,o*a));return"complex64"===n&&(s=gs(s),r=gs(r)),["["+s.map(((t,e)=>ps(t,i[e],n))).join(", ")+", ..., "+r.map(((t,e)=>ps(t,i[o-3+e],n))).join(", ")+"]"]}return["["+("complex64"===n?gs(t):Array.from(t)).map(((t,e)=>ps(t,i[e],n))).join(", ")+"]"]}const u=e.slice(1),h=s.slice(1),c=s[0]*a,p=[];if(o>20){for(let e=0;e<3;e++){const s=e*c,r=s+c;p.push(...fs(t.slice(s,r),u,n,h,i,!1))}p.push("...");for(let e=o-3;e0?p[0]+d:"");for(let t=1;trs(t)))}catch(t){throw new Error("Failed to decode the string bytes into utf-8. To get the original bytes, call tensor.bytes().")}}return t}dataToGPU(t){return this.throwIfDisposed(),ms().readToGPU(this.dataId,t)}dataSync(){this.throwIfDisposed();const t=ms().readSync(this.dataId);if("string"===this.dtype)try{return t.map((t=>rs(t)))}catch(t){throw new Error("Failed to decode the string bytes into utf-8. To get the original bytes, call tensor.bytes().")}return t}async bytes(){this.throwIfDisposed();const t=await ms().read(this.dataId);return"string"===this.dtype?t:new Uint8Array(t.buffer)}dispose(){this.isDisposed||(this.kerasMask&&this.kerasMask.dispose(),ms().disposeTensor(this),this.isDisposedInternal=!0)}get isDisposed(){return this.isDisposedInternal}throwIfDisposed(){if(this.isDisposed)throw new Error("Tensor is disposed.")}print(t=!1){return ys.print(this,t)}clone(){return this.throwIfDisposed(),ys.clone(this)}toString(t=!1){return cs(this.dataSync(),this.shape,this.dtype,t)}cast(t){return this.throwIfDisposed(),ys.cast(this,t)}variable(t=!0,e,n){return this.throwIfDisposed(),ms().makeVariable(this,t,e,n)}}function ws(){return bn("Tensor",(()=>bs))}Object.defineProperty(bs,Symbol.hasInstance,{value:t=>!!t&&null!=t.data&&null!=t.dataSync&&null!=t.throwIfDisposed}),ws();class ks extends bs{constructor(t,e,n,s){super(t.shape,t.dtype,t.dataId,s),this.trainable=e,this.name=n}assign(t){if(t.dtype!==this.dtype)throw new Error(`dtype of the new value (${t.dtype}) and previous value (${this.dtype}) must match`);if(!Je(t.shape,this.shape))throw new Error(`shape of the new value (${t.shape}) and previous value (${this.shape}) must match`);ms().disposeTensor(this),this.dataId=t.dataId,ms().incRef(this,null)}dispose(){ms().disposeVariable(this),this.isDisposedInternal=!0}}var vs,Ss,xs,Ns,Is;Object.defineProperty(ks,Symbol.hasInstance,{value:t=>t instanceof bs&&null!=t.assign&&t.assign instanceof Function}),function(t){t.R0="R0",t.R1="R1",t.R2="R2",t.R3="R3",t.R4="R4",t.R5="R5",t.R6="R6"}(vs||(vs={})),function(t){t.float32="float32",t.int32="int32",t.bool="int32",t.complex64="complex64"}(Ss||(Ss={})),function(t){t.float32="float32",t.int32="int32",t.bool="bool",t.complex64="complex64"}(xs||(xs={})),function(t){t.float32="float32",t.int32="float32",t.bool="float32",t.complex64="complex64"}(Ns||(Ns={})),function(t){t.float32="complex64",t.int32="complex64",t.bool="complex64",t.complex64="complex64"}(Is||(Is={}));const As={float32:Ns,int32:Ss,bool:xs,complex64:Is};function zs(t){return null!=t&&"object"==typeof t&&"texture"in t&&t.texture instanceof WebGLTexture}function Es(t){return"undefined"!=typeof GPUBuffer&&null!=t&&"object"==typeof t&&"buffer"in t&&t.buffer instanceof GPUBuffer}function Ts(t,e){if(t.dtype===e.dtype)return[t,e];const n=function(t,e){if("string"===t||"string"===e){if("string"===t&&"string"===e)return"string";throw new Error(`Can not upcast ${t} with ${e}`)}return As[t][e]}(t.dtype,e.dtype);return[t.cast(n),e.cast(n)]}function Cs(t){const e=[];return $s(t,e,new Set),e}function $s(t,e,n){if(null==t)return;if(t instanceof bs)return void e.push(t);if(s=t,!Array.isArray(s)&&"object"!=typeof s)return;var s;const i=t;for(const t in i){const s=i[t];n.has(s)||(n.add(s),$s(s,e,n))}}function Fs(t){return null!=t.kernelName}class Ds{constructor(){this.registeredVariables={},this.nextTapeNodeId=0,this.numBytes=0,this.numTensors=0,this.numStringTensors=0,this.numDataBuffers=0,this.gradientDepth=0,this.kernelDepth=0,this.scopeStack=[],this.numDataMovesStack=[],this.nextScopeId=0,this.tensorInfo=new WeakMap,this.profiling=!1,this.activeProfile={newBytes:0,newTensors:0,peakBytes:0,kernels:[],result:null,get kernelNames(){return Array.from(new Set(this.kernels.map((t=>t.name))))}}}dispose(){for(const t in this.registeredVariables)this.registeredVariables[t].dispose()}}class Ls{constructor(t){this.ENV=t,this.registry={},this.registryFactory={},this.pendingBackendInitId=0,this.state=new Ds}async ready(){if(null!=this.pendingBackendInit)return this.pendingBackendInit.then((()=>{}));if(null!=this.backendInstance)return;const t=this.getSortedBackends();for(let e=0;e{null!=t.setupFunc&&t.setupFunc(this.backendInstance)}))}disposeRegisteredKernels(t){Nn(t).forEach((e=>{null!=e.disposeFunc&&e.disposeFunc(this.registry[t])}))}initializeBackend(t){const e=this.registryFactory[t];if(null==e)throw new Error(`Cannot initialize backend ${t}, no registration found.`);try{const n=e.factory();if(!n||n instanceof class{refCount(t){return Ke("refCount")}incRef(t){return Ke("incRef")}timerAvailable(){return!0}time(t){return Ke("time")}read(t){return Ke("read")}readSync(t){return Ke("readSync")}readToGPU(t,e){return Ke("readToGPU")}numDataIds(){return Ke("numDataIds")}disposeData(t,e){return Ke("disposeData")}write(t,e,n){return Ke("write")}move(t,e,n,s,i){return Ke("move")}createTensorFromGPUData(t,e,n){return Ke("createTensorFromGPUData")}memory(){return Ke("memory")}floatPrecision(){return Ke("floatPrecision")}epsilon(){return 32===this.floatPrecision()?1e-7:1e-4}dispose(){return Ke("dispose")}}||"function"!=typeof n.then)return this.registry[t]=n,{success:!0,asyncInit:!1};{const e=++this.pendingBackendInitId,s=n.then((n=>!(e(ethis.registryFactory[e].priority-this.registryFactory[t].priority))}initializeBackendsAndReturnBest(){const t=this.getSortedBackends();for(let e=0;ethis.startScope(s)),(()=>this.endScope(n)),(()=>(n=e(),n instanceof Promise&&console.error("Cannot return a Promise inside of tidy."),n)))}scopedRun(t,e,n){t();try{const t=n();return e(),t}catch(t){throw e(),t}}nextTensorId(){return Ls.nextTensorId++}nextVariableId(){return Ls.nextVariableId++}clone(t){const e=_s.runKernel(Ie,{x:t}),n={x:t};return this.addTapeNode(this.state.activeScope.name,n,[e],(t=>({x:()=>{const e={x:t},n={dtype:"float32"};return _s.runKernel(de,e,n)}})),[],{}),e}runKernel(t,e,n){null==this.backendName&&this.backend;if(!(null!=Sn(t,this.backendName)))throw new Error(`Kernel '${t}' not registered for backend '${this.backendName}'`);return this.runKernelFunc({kernelName:t,inputs:e,attrs:n})}shouldCheckForMemLeaks(){return this.ENV.getBool("IS_TEST")}checkKernelForMemLeak(t,e,n){const s=this.backend.numDataIds();let i=0;n.forEach((t=>{i+="complex64"===t.dtype?3:1}));const r=this.state.numDataMovesStack[this.state.numDataMovesStack.length-1],a=s-e-i-r;if(a>0)throw new Error(`Backend '${this.backendName}' has an internal memory leak (${a} data ids) after running '${t}'`)}runKernelFunc(t){let e,n=[];const s=this.isTapeOn(),i=this.state.numBytes,r=this.state.numTensors;let a,o;this.shouldCheckForMemLeaks()&&this.state.numDataMovesStack.push(0),null==this.backendName&&this.backend;const l=Fs(t)?t.kernelName:null!=this.state.activeScope?this.state.activeScope.name:"";if(Fs(t)){const{kernelName:e,inputs:i,attrs:r}=t;null==this.backendName&&this.backend;const l=Sn(e,this.backendName);Ge(null!=l,(()=>`Cannot find registered kernel '${e}' for backend '${this.backendName}'`)),a=()=>{const t=this.backend.numDataIds();o=l.kernelFunc({inputs:i,attrs:r,backend:this.backend});const a=Array.isArray(o)?o:[o];this.shouldCheckForMemLeaks()&&this.checkKernelForMemLeak(e,t,a);const u=a.map((t=>null!=t.rank?t:this.makeTensorFromTensorInfo(t)));if(s){const t=this.getTensorsForGradient(e,i,u);n=this.saveTensorsForBackwardMode(t)}return u}}else{const{forwardFunc:e}=t,i=t=>{s&&(n=t.map((t=>this.keep(this.clone(t)))))};a=()=>{const t=this.backend.numDataIds();o=this.tidy((()=>e(this.backend,i)));const n=Array.isArray(o)?o:[o];return this.shouldCheckForMemLeaks()&&this.checkKernelForMemLeak(l,t,n),n}}const{inputs:u,attrs:h}=t,c=Fs(t)?null:t.backwardsFunc;let p;return this.scopedRun((()=>this.state.kernelDepth++),(()=>this.state.kernelDepth--),(()=>{this.ENV.getBool("DEBUG")||this.state.profiling?(p=this.profiler.profileKernel(l,u,(()=>a())),this.ENV.getBool("DEBUG")&&this.profiler.logKernelProfile(p),e=p.outputs):e=a()})),s&&this.addTapeNode(l,u,e,c,n,h),this.state.profiling&&this.state.activeProfile.kernels.push({name:l,bytesAdded:this.state.numBytes-i,totalBytesSnapshot:this.state.numBytes,tensorsAdded:this.state.numTensors-r,totalTensorsSnapshot:this.state.numTensors,inputShapes:Object.keys(u).map((t=>null!=u[t]?u[t].shape:null)),outputShapes:e.map((t=>t.shape)),kernelTimeMs:p.timeMs,extraInfo:p.extraInfo}),Array.isArray(o)?e:e[0]}saveTensorsForBackwardMode(t){return t.map((t=>this.keep(this.clone(t))))}getTensorsForGradient(t,e,n){const s=xn(t);if(null!=s){const t=s.inputsToSave||[],i=s.outputsToSave||[];let r;s.saveAllInputs?(Ge(Array.isArray(e),(()=>"saveAllInputs is true, expected inputs to be an array.")),r=Object.keys(e).map((t=>e[t]))):r=t.map((t=>e[t]));const a=n.filter(((t,e)=>i[e]));return r.concat(a)}return[]}makeTensor(t,e,n,s){if(null==t)throw new Error("Values passed to engine.makeTensor() are null");n=n||"float32",s=s||this.backend;let i=t;"string"===n&&tn(t[0])&&(i=t.map((t=>function(t,e="utf-8"){return e=e||"utf-8",fn().platform.encode(t,e)}(t))));const r=s.write(i,e,n),a=new bs(e,n,r,this.nextTensorId());if(this.trackTensor(a,s),"string"===n){const t=this.state.tensorInfo.get(r),e=function(t){if(null==t)return 0;let e=0;return t.forEach((t=>e+=t.length)),e}(i);this.state.numBytes+=e-t.bytes,t.bytes=e}return a}makeTensorFromDataId(t,e,n,s){const i={dataId:t,shape:e,dtype:n=n||"float32"};return this.makeTensorFromTensorInfo(i,s)}makeTensorFromTensorInfo(t,e){const{dataId:n,shape:s,dtype:i}=t,r=new bs(s,i,n,this.nextTensorId());return this.trackTensor(r,e),r}makeVariable(t,e=!0,n,s){n=n||this.nextVariableId().toString(),null!=s&&s!==t.dtype&&(t=t.cast(s));const i=new ks(t,e,n,this.nextTensorId());if(null!=this.state.registeredVariables[i.name])throw new Error(`Variable with name ${i.name} was already registered`);return this.state.registeredVariables[i.name]=i,this.incRef(i,this.backend),i}trackTensor(t,e){this.state.numTensors++,"string"===t.dtype&&this.state.numStringTensors++;let n=0;"complex64"!==t.dtype&&"string"!==t.dtype&&(n=t.size*Qe(t.dtype)),this.state.numBytes+=n,this.state.tensorInfo.has(t.dataId)||(this.state.numDataBuffers++,this.state.tensorInfo.set(t.dataId,{backend:e||this.backend,dtype:t.dtype,shape:t.shape,bytes:n})),t instanceof ks||this.track(t)}incRef(t,e){this.trackTensor(t,e),this.backend.incRef(t.dataId)}removeDataId(t,e){this.state.tensorInfo.has(t)&&this.state.tensorInfo.get(t).backend===e&&(this.state.tensorInfo.delete(t),this.state.numDataBuffers--)}disposeTensor(t){if(!this.state.tensorInfo.has(t.dataId))return;const e=this.state.tensorInfo.get(t.dataId);if(this.state.numTensors--,"string"===t.dtype&&(this.state.numStringTensors--,this.state.numBytes-=e.bytes),"complex64"!==t.dtype&&"string"!==t.dtype){const e=t.size*Qe(t.dtype);this.state.numBytes-=e}e.backend.disposeData(t.dataId)&&this.removeDataId(t.dataId,e.backend)}disposeVariables(){for(const t in this.state.registeredVariables){const e=this.state.registeredVariables[t];this.disposeVariable(e)}}disposeVariable(t){this.disposeTensor(t),null!=this.state.registeredVariables[t.name]&&delete this.state.registeredVariables[t.name]}memory(){const t=this.backend.memory();return t.numTensors=this.state.numTensors,t.numDataBuffers=this.state.numDataBuffers,t.numBytes=this.state.numBytes,this.state.numStringTensors>0&&(t.unreliable=!0,null==t.reasons&&(t.reasons=[]),t.reasons.push("Memory usage by string tensors is approximate (2 bytes per character)")),t}async profile(t){this.state.profiling=!0;const e=this.state.numBytes,n=this.state.numTensors;this.state.activeProfile.kernels=[],this.state.activeProfile.result=await t(),this.state.profiling=!1,this.state.activeProfile.peakBytes=Math.max(...this.state.activeProfile.kernels.map((t=>t.totalBytesSnapshot))),this.state.activeProfile.newBytes=this.state.numBytes-e,this.state.activeProfile.newTensors=this.state.numTensors-n;for(const t of this.state.activeProfile.kernels)t.kernelTimeMs=await t.kernelTimeMs,t.extraInfo=await t.extraInfo;return this.state.activeProfile}isTapeOn(){return this.state.gradientDepth>0&&0===this.state.kernelDepth}addTapeNode(t,e,n,s,i,r){const a={id:this.state.nextTapeNodeId++,kernelName:t,inputs:e,outputs:n,saved:i},o=xn(t);null!=o&&(s=o.gradFunc),null!=s&&(a.gradient=t=>(t=t.map(((t,e)=>{if(null==t){const t=n[e],s=ln(t.size,t.dtype);return this.makeTensor(s,t.shape,t.dtype)}return t})),s(t.length>1?t:t[0],i,r))),this.state.activeTape.push(a)}keep(t){return t.kept=!0,t}startTape(){0===this.state.gradientDepth&&(this.state.activeTape=[]),this.state.gradientDepth++}endTape(){this.state.gradientDepth--}startScope(t){const e={track:[],name:"unnamed scope",id:this.state.nextScopeId++};t&&(e.name=t),this.state.scopeStack.push(e),this.state.activeScope=e}endScope(t){const e=Cs(t),n=new Set(e.map((t=>t.id)));for(let t=0;t{t.kept||t.scopeId!==s.id||this.track(t)}))}gradients(t,e,n,s=!1){if(Ge(e.length>0,(()=>"gradients() received an empty list of xs.")),null!=n&&"float32"!==n.dtype)throw new Error(`dy must have 'float32' dtype, but has '${n.dtype}'`);const i=this.scopedRun((()=>this.startTape()),(()=>this.endTape()),(()=>this.tidy("forward",t)));Ge(i instanceof bs,(()=>"The result y returned by f() must be a tensor."));const r=function(t,e,n){const s={},i={};for(let t=0;ts[t.id]=!0)),o=!0,i[r.id]=!0;break}if(o)break}}const r={};r[n.id]=!0;const a={};for(let e=t.length-1;e>=0;e--){const n=t[e],s=n.inputs;for(let t=0;t0)throw new Error("Cannot compute gradient of y=f(x) with respect to x. Make sure that the f you passed encloses all operations that lead from x to y.");return this.tidy("backward",(()=>{const t={};t[i.id]=null==n?function(t){const e=on(He(t),"float32");return _s.makeTensor(e,t,"float32")}(i.shape):n,function(t,e,n,s){for(let i=e.length-1;i>=0;i--){const r=e[i],a=[];if(r.outputs.forEach((e=>{const n=t[e.id];null!=n?a.push(n):a.push(null)})),null==r.gradient)throw new Error(`Cannot compute gradient: gradient function not found for ${r.kernelName}.`);const o=r.gradient(a);for(const e in r.inputs){if(!(e in o))throw new Error(`Cannot backprop through input ${e}. Available gradients found: ${Object.keys(o)}.`);const i=n((()=>o[e]()));if("float32"!==i.dtype)throw new Error(`Error in gradient for op ${r.kernelName}. The gradient of input ${e} must have 'float32' dtype, but has '${i.dtype}'`);const a=r.inputs[e];if(!Je(i.shape,a.shape))throw new Error(`Error in gradient for op ${r.kernelName}. The gradient of input '${e}' has shape '${i.shape}', which does not match the shape of the input '${a.shape}'`);if(null==t[a.id])t[a.id]=i;else{const e=t[a.id];t[a.id]=s(e,i),e.dispose()}}}}(t,r,(t=>this.tidy(t)),Rs);const s=e.map((e=>t[e.id]));return 0===this.state.gradientDepth&&(this.state.activeTape.forEach((t=>{for(const e of t.saved)e.dispose()})),this.state.activeTape=null),{value:i,grads:s}}))}customGrad(t){return Ge(nn(t),(()=>"The f passed in customGrad(f) must be a function.")),(...e)=>{let n;Ge(e.every((t=>t instanceof bs)),(()=>"The args passed in customGrad(f)(x1, x2,...) must all be tensors"));const s={};e.forEach(((t,e)=>{s[e]=t}));return this.runKernelFunc({forwardFunc:(s,i)=>(n=t(...e,i),Ge(n.value instanceof bs,(()=>"The function f passed in customGrad(f) must return an object where `obj.value` is a tensor")),Ge(nn(n.gradFunc),(()=>"The function f passed in customGrad(f) must return an object where `obj.gradFunc` is a function.")),n.value),backwardsFunc:(t,s)=>{const i=n.gradFunc(t,s),r=Array.isArray(i)?i:[i];Ge(r.length===e.length,(()=>"The function f passed in customGrad(f) must return an object where `obj.gradFunc` is a function that returns the same number of tensors as inputs passed to f(...).")),Ge(r.every((t=>t instanceof bs)),(()=>"The function f passed in customGrad(f) must return an object where `obj.gradFunc` is a function that returns a list of only tensors."));const a={};return r.forEach(((t,e)=>{a[e]=()=>t})),a},inputs:s})}}readSync(t){return this.state.tensorInfo.get(t).backend.readSync(t)}read(t){return this.state.tensorInfo.get(t).backend.read(t)}readToGPU(t,e){return this.state.tensorInfo.get(t).backend.readToGPU(t,e)}async time(t){const e=is(),n=await this.backend.time(t);return n.wallMs=is()-e,n}track(t){return null!=this.state.activeScope&&(t.scopeId=this.state.activeScope.id,this.state.activeScope.track.push(t)),t}get registeredVariables(){return this.state.registeredVariables}reset(){this.pendingBackendInitId++,this.state.dispose(),this.ENV.reset(),this.state=new Ds;for(const t in this.registry)this.disposeRegisteredKernels(t),this.registry[t].dispose(),delete this.registry[t];this.backendName=null,this.backendInstance=null,this.pendingBackendInit=null}}Ls.nextTensorId=0,Ls.nextVariableId=0;const _s=function(){const t=yn();if(null==t._tfengine){const e=new pn(t);t._tfengine=new Ls(e)}var e;return e=t._tfengine.ENV,mn=e,ms=()=>t._tfengine,t._tfengine}();function Rs(t,e){const n={a:t,b:e};return _s.runKernel(he,n)}function Os(t,e){let n=t;if(as(t))return"string"===e?[]:[t.length];if(zs(t)){const e=t.channels||"RGBA";return[t.height,t.width*e.length]}if(Es(t))return[t.buffer.size/(null==e?4:Qe(e))];if(!Array.isArray(t))return[];const s=[];for(;Array.isArray(n)||as(n)&&"string"!==e;)s.push(n.length),n=n[0];return Array.isArray(t)&&fn().getBool("TENSORLIKE_CHECK_SHAPE_CONSISTENCY")&&Ms(t,s,[]),s}function Ms(t,e,n){if(n=n||[],!Array.isArray(t)&&!as(t))return void Ge(0===e.length,(()=>`Element arr[${n.join("][")}] is a primitive, but should be an array/TypedArray of ${e[0]} elements`));Ge(e.length>0,(()=>`Element arr[${n.join("][")}] should be a primitive, but is an array of ${t.length} elements`)),Ge(t.length===e[0],(()=>`Element arr[${n.join("][")}] should have ${e[0]} elements, but has ${t.length} elements`));const s=e.slice(1);for(let e=0;e=0&&(i=s),Bs(s,i,e,n),null==t||!as(t)&&!Array.isArray(t)&&"number"!=typeof t&&"boolean"!=typeof t&&"string"!=typeof t){const s=null==t?"null":t.constructor.name;throw new Error(`Argument '${e}' passed to '${n}' must be a Tensor or TensorLike, but got '${s}'`)}const r=Os(t,i);as(t)||Array.isArray(t)||(t=[t]);const a="string"!==i?ss(t,i):os(t,[],!0);return _s.makeTensor(a,r,i)}function Us(t,e,n,s="numeric"){if(!Array.isArray(t))throw new Error(`Argument ${e} passed to ${n} must be a \`Tensor[]\` or \`TensorLike[]\``);return t.map(((t,i)=>Ps(t,`${e}[${i}]`,n,s)))}function Ws(t){const e=Object.keys(t);if(1!==e.length)throw new Error(`Please provide an object with a single key (operation name) mapping to a function. Got an object with ${e.length} keys.`);let n=e[0];const s=t[n];n.endsWith("_")&&(n=n.substring(0,n.length-1)),n+="__op";const i=(...t)=>{_s.startScope(n);try{const e=s(...t);return hn(e)&&console.error("Cannot return a Promise inside of tidy."),_s.endScope(e),e}catch(t){throw _s.endScope(null),t}};return Object.defineProperty(i,"name",{value:n,configurable:!0}),i}const js=Ws({cast_:function(t,e){const n=Ps(t,"x","cast");if(!function(t){return"bool"===t||"complex64"===t||"float32"===t||"int32"===t||"string"===t}(e))throw new Error(`Failed to cast to unknown dtype ${e}`);if("string"===e&&"string"!==n.dtype||"string"!==e&&"string"===n.dtype)throw new Error("Only strings can be casted to strings");const s={x:n},i={dtype:e};return _s.runKernel(de,s,i)}});const qs=Ws({mul_:function(t,e){let n=Ps(t,"a","mul"),s=Ps(e,"b","mul");[n,s]=Ts(n,s);const i={a:n,b:s};return _s.runKernel(ze,i)}});const Vs=Ws({step_:function(t,e=0){const n={x:Ps(t,"x","step")},s={alpha:e};return _s.runKernel(Ve,n,s)}}),Ks={kernelName:"Abs",inputsToSave:["x"],gradFunc:(t,e)=>{const[n]=e;return{x:()=>qs(t,Vs(js(n,"float32"),-1))}}};const Gs=Ws({floorDiv_:function(t,e){let n=Ps(t,"a","floorDiv"),s=Ps(e,"b","floorDiv");[n,s]=Ts(n,s);const i={a:n,b:s};return _s.runKernel(Se,i)}});const Hs=Ws({div_:function(t,e){let n=Ps(t,"a","div"),s=Ps(e,"b","div");if([n,s]=Ts(n,s),"int32"===n.dtype&&"int32"===s.dtype)return Gs(n,s);const i={a:n,b:s};return _s.runKernel(we,i,{})}});const Js=Ws({neg_:function(t){const e={x:Ps(t,"x","neg")};return _s.runKernel("Neg",e)}});function Zs(t,e){if((as(t)&&"string"!==e||Array.isArray(t))&&"complex64"!==e)throw new Error("Error creating a new Scalar: value must be a primitive (number|boolean|string)");if("string"===e&&as(t)&&!(t instanceof Uint8Array))throw new Error("When making a scalar from encoded string, the value must be `Uint8Array`.");return function(t,e,n,s){if(null==s)s=en(t);else if("complex64"===s)throw new Error("Cannot construct a complex64 tensor directly. Please use tf.complex(real, imag).");if(Es(t)||zs(t)){if("float32"!==s&&"int32"!==s)throw new Error(`Creating tensor from GPU data only supports 'float32'|'int32' dtype, while the dtype is ${s}.`);return _s.backend.createTensorFromGPUData(t,e||n,s)}if(!as(t)&&!Array.isArray(t)&&"number"!=typeof t&&"boolean"!=typeof t&&"string"!=typeof t)throw new Error("values passed to tensor(values) must be a number/boolean/string or an array of numbers/booleans/strings, or a TypedArray");if(null!=e){un(e);const t=He(e),s=He(n);Ge(t===s,(()=>`Based on the provided shape, [${e}], the tensor should have ${t} values but has ${s}`));for(let t=0;t`Error creating a new Tensor. Inferred shape (${n}) does not match the provided shape (${e}). `))}}return as(t)||Array.isArray(t)||(t=[t]),e=e||n,t="string"!==s?ss(t,s):os(t,[],!0),_s.makeTensor(t,e,s)}(t,[],[],e)}const Ys=Ws({sqrt_:function(t){const e={x:Ps(t,"x","sqrt","float32")};return _s.runKernel(Oe,e)}});const Xs=Ws({square_:function(t){const e=Ps(t,"x","square");return _s.runKernel("Square",{x:e},{})}});const Qs=Ws({sub_:function(t,e){let n=Ps(t,"a","sub"),s=Ps(e,"b","sub");[n,s]=Ts(n,s);const i={a:n,b:s};return _s.runKernel("Sub",i)}}),ti={kernelName:"Acos",inputsToSave:["x"],gradFunc:(t,e)=>{const[n]=e;return{x:()=>{const e=Xs(js(n,"float32")),s=Ys(Qs(Zs(1),e));return Js(Hs(t,s))}}}},ei={kernelName:"Acosh",inputsToSave:["x"],gradFunc:(t,e)=>{const[n]=e;return{x:()=>{const e=Ys(Qs(Xs(js(n,"float32")),1));return Hs(t,e)}}}};function ni(t,e){const n=[];for(let s=0;s1)&&n.unshift(r)}return n}function si(t,e){const n=Math.max(t.length,e.length),s=new Array(n);for(let i=0;i{const[n,s]=e,i=si(n.shape,s.shape);return{a:()=>{let e=t;const s=ni(n.shape,i);return s.length>0&&(e=ri(e,s)),ii(e,n.shape)},b:()=>{let e=t;const n=ni(s.shape,i);return n.length>0&&(e=ri(e,n)),ii(e,s.shape)}}}},oi={kernelName:"AddN",saveAllInputs:!0,gradFunc:(t,e)=>{const n={};return e.forEach(((e,s)=>{n[s]=()=>t.clone()})),n}};const li=Ws({zerosLike_:function(t){const e={x:Ps(t,"x","zerosLike")};return _s.runKernel(qe,e)}}),ui={kernelName:"ArgMax",inputsToSave:["x"],gradFunc:(t,e)=>{const[n]=e;return{x:()=>li(n)}}},hi={kernelName:"ArgMin",inputsToSave:["x"],gradFunc:(t,e)=>{const[n]=e;return{x:()=>li(n)}}},ci={kernelName:"Asin",inputsToSave:["x"],gradFunc:(t,e)=>{const[n]=e;return{x:()=>Hs(t,Ys(Qs(Zs(1),Xs(js(n,"float32")))))}}};const pi=Ws({add_:function(t,e){let n=Ps(t,"a","add"),s=Ps(e,"b","add");[n,s]=Ts(n,s);const i={a:n,b:s};return _s.runKernel(he,i)}}),di={kernelName:"Asinh",inputsToSave:["x"],gradFunc:(t,e)=>{const[n]=e;return{x:()=>{const e=Ys(pi(Zs(1),Xs(js(n,"float32"))));return Hs(t,e)}}}},fi={kernelName:"Atan2",inputsToSave:["a","b"],gradFunc:(t,e)=>{const[n,s]=e,i=si(n.shape,s.shape);return{a:()=>{const e=pi(Xs(n),Xs(s));let r=qs(t,Hs(s,e));const a=ni(n.shape,i);return a.length>0&&(r=ri(r,a)),ii(r,n.shape)},b:()=>{const e=pi(Xs(n),Xs(s));let r=Js(qs(t,Hs(n,e)));const a=ni(s.shape,i);return a.length>0&&(r=ri(r,a)),ii(r,s.shape)}}}},gi={kernelName:"Atan",inputsToSave:["x"],gradFunc:(t,e)=>{const[n]=e;return{x:()=>Hs(t,pi(Xs(js(n,"float32")),1))}}},mi={kernelName:"Atanh",inputsToSave:["x"],gradFunc:(t,e)=>{const[n]=e;return{x:()=>Hs(t,Qs(Zs(1),Xs(js(n,"float32"))))}}};function yi(t){return"number"==typeof t?[t,t,t]:2===t.length?[t[0],t[1],1]:t}function bi(t){const[e,n,s]=yi(t);return 1===e&&1===n&&1===s}function wi(t,e){return bi(t)||bi(e)}function ki(t){return yi(t).every((t=>t>0))}function vi(t,e,n){if(null!=n){if("string"==typeof e)throw Error(`Error in ${t}: pad must be an integer when using dimRoundingMode ${n} but got pad ${e}.`);if("number"==typeof e)Ge(Ze(e),(()=>`Error in ${t}: pad must be an integer when using dimRoundingMode ${n} but got pad ${e}.`));else{if("object"!=typeof e)throw Error(`Error in ${t}: Unknown padding parameter: ${e}`);e.forEach((e=>{e.forEach((e=>{Ge(Ze(e),(()=>`Error in ${t}: pad must be an integer when using dimRoundingMode ${n} but got pad ${e}.`))}))}))}}}const Si=Ws({avgPool3dGrad_:function(t,e,n,s,i,r){const a=Ps(t,"dy","avgPool3dGrad"),o=Ps(e,"input","avgPool3dGrad");let l=a,u=o,h=!1;4===o.rank&&(h=!0,l=ii(a,[1,a.shape[0],a.shape[1],a.shape[2],a.shape[3]]),u=ii(o,[1,o.shape[0],o.shape[1],o.shape[2],o.shape[3]])),Ge(5===l.rank,(()=>`Error in avgPool3dGrad: dy must be rank 5 but got rank ${l.rank}.`)),Ge(5===u.rank,(()=>`Error in avgPool3dGrad: input must be rank 5 but got rank ${u.rank}.`)),vi("avgPool3dGrad",i,r);const c={dy:l,input:u},p={filterSize:n,strides:s,pad:i,dimRoundingMode:r},d=_s.runKernel("AvgPool3DGrad",c,p);return h?ii(d,[d.shape[1],d.shape[2],d.shape[3],d.shape[4]]):d}}),xi={kernelName:"AvgPool3D",inputsToSave:["x"],gradFunc:(t,e,n)=>{const[s]=e,{filterSize:i,strides:r,pad:a,dimRoundingMode:o}=n;return{x:()=>Si(t,s,i,r,a,o)}}};const Ni=Ws({avgPoolGrad_:function(t,e,n,s,i){const r=Ps(t,"dy","avgPoolGrad"),a=Ps(e,"input","avgPoolGrad");Ge(a.rank===r.rank,(()=>`Rank of input (${a.rank}) does not match rank of dy (${r.rank})`));let o=a,l=r,u=!1;3===a.rank&&(u=!0,o=ii(a,[1,a.shape[0],a.shape[1],a.shape[2]]),l=ii(r,[1,r.shape[0],r.shape[1],r.shape[2]])),Ge(4===l.rank,(()=>`Error in avgPoolGrad: dy must be rank 4 but got rank ${l.rank}.`)),Ge(4===o.rank,(()=>`Error in avgPoolGrad: input must be rank 4 but got rank ${o.rank}.`));const h={dy:l,input:o},c={filterSize:n,strides:s,pad:i},p=_s.runKernel("AvgPoolGrad",h,c);return u?ii(p,[p.shape[1],p.shape[2],p.shape[3]]):p}}),Ii={kernelName:"AvgPool",inputsToSave:["x"],gradFunc:(t,e,n)=>{const[s]=e,{filterSize:i,strides:r,pad:a}=n;return{x:()=>Ni(t,s,i,r,a)}}};const Ai=Ws({matMul_:function(t,e,n=!1,s=!1){let i=Ps(t,"a","matMul"),r=Ps(e,"b","matMul");[i,r]=Ts(i,r);const a={a:i,b:r},o={transposeA:n,transposeB:s};return _s.runKernel(ce,a,o)}}),zi={kernelName:ce,inputsToSave:["a","b"],gradFunc:(t,e,n)=>{const[s,i]=e,{transposeA:r,transposeB:a}=n;return r||a?!r&&a?{a:()=>Ai(t,i,!1,!1),b:()=>Ai(t,s,!0,!1)}:r&&!a?{a:()=>Ai(i,t,!1,!0),b:()=>Ai(s,t,!1,!1)}:{a:()=>Ai(i,t,!0,!0),b:()=>Ai(t,s,!0,!0)}:{a:()=>Ai(t,i,!1,!0),b:()=>Ai(s,t,!0,!1)}}};const Ei=Ws({spaceToBatchND_:function(t,e,n){const s=Ps(t,"x","spaceToBatchND");Ge(s.rank>=1+e.length,(()=>`input rank ${s.rank} should be > than [blockShape] ${e.length}`)),Ge(n.length===e.length,(()=>`paddings.shape[0] ${n.length} must be equal to [blockShape] ${e.length}`)),Ge(s.shape.reduce(((t,s,i)=>i>0&&i<=e.length?t&&(s+n[i-1][0]+n[i-1][1])%e[i-1]==0:t),!0),(()=>`input spatial dimensions ${s.shape.slice(1)} with paddings ${n.toString()} must be divisible by blockShapes ${e.toString()}`));const i={x:s},r={blockShape:e,paddings:n};return _s.runKernel(Me,i,r)}}),Ti={kernelName:pe,gradFunc:(t,e,n)=>{const{blockShape:s,crops:i}=n;return{x:()=>Ei(t,s,i)}}},Ci={kernelName:"BroadcastTo",gradFunc:(t,e,n)=>{const s=n,i=s.inputShape,r=s.shape,a=Array.from(r);for(let t=i.length-1;t>=0;t--)if(i[t]===r[t])a[t]=1;else if(1!==i[t])throw new Error(`broadcastTo(): [${i}] cannot be broadcast to [${r}].`);const o=[];for(let t=0;t1&&o.push(t);return{x:()=>ri(t,o,!0)}}},$i={kernelName:de,gradFunc:t=>({x:()=>t.clone()})},Fi={kernelName:"Ceil",gradFunc:t=>({x:()=>li(t)})};const Di=Ws({greaterEqual_:function(t,e){let n=Ps(t,"a","greaterEqual","string_or_numeric"),s=Ps(e,"b","greaterEqual","string_or_numeric");[n,s]=Ts(n,s),si(n.shape,s.shape);const i={a:n,b:s};return _s.runKernel(Ne,i)}});const Li=Ws({lessEqual_:function(t,e){let n=Ps(t,"a","lessEqual","string_or_numeric"),s=Ps(e,"b","lessEqual","string_or_numeric");[n,s]=Ts(n,s),si(n.shape,s.shape);const i={a:n,b:s};return _s.runKernel("LessEqual",i)}});const _i=Ws({logicalAnd_:function(t,e){const n=Ps(t,"a","logicalAnd","bool"),s=Ps(e,"b","logicalAnd","bool");si(n.shape,s.shape);const i={a:n,b:s};return _s.runKernel("LogicalAnd",i)}});const Ri=Ws({clone_:function(t){const e={x:Ps(t,"x","clone","string_or_numeric")};return _s.runKernel(Ie,e)}});const Oi=Ws({broadcastTo_:function(t,e){let n=Ps(t,"broadcastTo","x");const s=n.shape;if(un(e),e.lengthn.rank){const t=n.shape.slice();for(;t.length=0;t--)if(i[t]===e[t])r[t]=1;else if(1!==n.shape[t])throw new Error(`broadcastTo(): [${s}] cannot be broadcast to [${e}].`);if(0===r.map(((t,e)=>t>1?e:-1)).filter((t=>t>=0)).length)return Ri(n);const a={x:n},o={reps:r};return _s.runKernel(Pe,a,o)}});const Mi=Ws({where_:function(t,e,n){const s=Ps(e,"a","where"),i=Ps(n,"b","where"),r=Ps(t,"condition","where","bool"),a=si(si(r.shape,s.shape),i.shape),o={condition:Oi(r,a),t:Oi(s,a),e:Oi(i,a)};return _s.runKernel(De,o)}}),Bi={kernelName:"ClipByValue",inputsToSave:["x"],gradFunc:(t,e,n)=>{const[s]=e,{clipValueMin:i,clipValueMax:r}=n;return{x:()=>Mi(_i(Di(s,i),Li(s,r)),t,li(t))}}},Pi={kernelName:"ComplexAbs",inputsToSave:["x"],gradFunc:Ks.gradFunc};const Ui=Ws({split_:function(t,e,n=0){const s={x:Ps(t,"x","split")},i={numOrSizeSplits:e,axis:n};return _s.runKernel(Be,s,i)}}),Wi={kernelName:fe,saveAllInputs:!0,gradFunc:(t,e,n)=>{const s=e.map((t=>t.shape)),{axis:i}=n,r=Xe(i,e[0].shape)[0],a=s.map((t=>t[r]));return Ui(t,a,r).map((t=>()=>t))}};const ji=Ws({conv2DBackpropFilter_:function(t,e,n,s,i,r="NHWC",a){let o=t;3===t.rank&&(o=ii(t,[1,t.shape[0],t.shape[1],t.shape[2]]));let l=e;3===l.rank&&(l=ii(e,[1,e.shape[0],e.shape[1],e.shape[2]])),Ge(4===o.rank,(()=>`Error in conv2dDerFilter: input must be rank 4, but got shape ${o.shape}.`)),Ge(4===l.rank,(()=>`Error in conv2dDerFilter: dy must be rank 4, but got shape ${l.shape}.`)),Ge(4===n.length,(()=>`Error in conv2dDerFilter: filterShape must be length 4, but got ${n}.`));const u="NHWC"===r?o.shape[3]:o.shape[1],h="NHWC"===r?l.shape[3]:l.shape[1];Ge(u===n[2],(()=>`Error in conv2dDerFilter: depth of input ${u}) must match input depth in filter (${n[2]}.`)),Ge(h===n[3],(()=>`Error in conv2dDerFilter: depth of dy (${h}) must match output depth for filter (${n[3]}).`)),vi("conv2dDerFilter",i,a);const c={x:o,dy:l},p={strides:s,pad:i,dataFormat:r,dimRoundingMode:a,filterShape:n};return _s.runKernel("Conv2DBackpropFilter",c,p)}});const qi=Ws({conv2DBackpropInput_:function(t,e,n,s,i,r="NHWC",a){Ge(t.length===e.rank,(()=>`Length of inShape (${t.length}) and rank of dy (${e.rank}) must match`));let o=t,l=e,u=!1;3===e.rank&&(u=!0,l=ii(e,[1,e.shape[0],e.shape[1],e.shape[2]]),o=[1,t[0],t[1],t[2]]),Ge(4===o.length,(()=>`Error in conv2dDerInput: inShape must be length 4, but got length ${o.length}.`)),Ge(4===l.rank,(()=>`Error in conv2dDerInput: dy must be rank 4, but got rank ${l.rank}`)),Ge(4===n.rank,(()=>`Error in conv2dDerInput: filter must be rank 4, but got rank ${n.rank}`));const h="NHWC"===r?o[3]:o[1],c="NHWC"===r?l.shape[3]:l.shape[1];Ge(h===n.shape[2],(()=>`Error in conv2dDerInput: depth of input (${h}) must match input depth for filter ${n.shape[2]}.`)),Ge(c===n.shape[3],(()=>`Error in conv2dDerInput: depth of output (${c}) must match output depth for filter ${n.shape[3]}.`)),vi("conv2dDerInput",i,a);const p={dy:l,filter:n},d={strides:s,pad:i,dataFormat:r,dimRoundingMode:a,inputShape:o},f=_s.runKernel(me,p,d);return u?ii(f,[f.shape[1],f.shape[2],f.shape[3]]):f}}),Vi={kernelName:ge,inputsToSave:["x","filter"],gradFunc:(t,e,n)=>{const[s,i]=e,{dilations:r,strides:a,pad:o,dataFormat:l}=n;return Ge(bi(r),(()=>`Error in gradient of conv2D: dilation rates greater than 1 are not yet supported in gradients. Got dilations '${r}'`)),{x:()=>qi(s.shape,t,i,a,o,l),filter:()=>ji(s,t,i.shape,a,o,l)}}};const Ki=Ws({conv2d_:function(t,e,n,s,i="NHWC",r=[1,1],a){const o=Ps(t,"x","conv2d","float32"),l=Ps(e,"filter","conv2d","float32");let u=o,h=!1;3===o.rank&&(h=!0,u=ii(o,[1,o.shape[0],o.shape[1],o.shape[2]])),Ge(4===u.rank,(()=>`Error in conv2d: input must be rank 4, but got rank ${u.rank}.`)),Ge(4===l.rank,(()=>`Error in conv2d: filter must be rank 4, but got rank ${l.rank}.`)),vi("conv2d",s,a);const c="NHWC"===i?u.shape[3]:u.shape[1];Ge(c===l.shape[2],(()=>`Error in conv2d: depth of input (${c}) must match input depth for filter ${l.shape[2]}.`)),Ge(wi(n,r),(()=>`Error in conv2D: Either strides or dilations must be 1. Got strides ${n} and dilations '${r}'`)),Ge(ki(r),(()=>"Error in conv2D: Dilated rates should be larger than 0.")),Ge(ki(n),(()=>"Error in conv2D: Strides should be larger than 0."));const p={x:u,filter:l},d={strides:n,pad:s,dataFormat:i,dilations:r,dimRoundingMode:a},f=_s.runKernel(ge,p,d);return h?ii(f,[f.shape[1],f.shape[2],f.shape[3]]):f}}),Gi={kernelName:me,inputsToSave:["dy","filter"],gradFunc:(t,e,n)=>{const[s,i]=e,{strides:r,pad:a,dataFormat:o,dimRoundingMode:l}=n;return{dy:()=>Ki(t,i,r,a,o,1,l),filter:()=>ji(t,s,i.shape,r,a,o,l)}}};const Hi=Ws({conv3DBackpropFilter_:function(t,e,n,s,i){let r=t;4===t.rank&&(r=ii(t,[1,t.shape[0],t.shape[1],t.shape[2],t.shape[3]]));let a=e;4===a.rank&&(a=ii(e,[1,e.shape[0],e.shape[1],e.shape[2],e.shape[3]])),Ge(5===r.rank,(()=>`Error in conv3dDerFilter: input must be rank 5, but got shape ${r.shape}.`)),Ge(5===a.rank,(()=>`Error in conv3dDerFilter: dy must be rank 5, but got shape ${a.shape}.`)),Ge(5===n.length,(()=>`Error in conv3dDerFilter: filterShape must be length 5, but got ${n}.`)),Ge(r.shape[4]===n[3],(()=>`Error in conv3dDerFilter: depth of input ${r.shape[4]}) must match input depth in filter (${n[3]}.`)),Ge(a.shape[4]===n[4],(()=>`Error in conv3dDerFilter: depth of dy (${a.shape[4]}) must match output depth for filter (${n[4]}).`));const o={x:r,dy:a},l={strides:s,pad:i,filterShape:n};return _s.runKernel("Conv3DBackpropFilterV2",o,l)}});const Ji=Ws({conv3DBackpropInput_:function(t,e,n,s,i){Ge(t.length===e.rank,(()=>`Length of inShape (${t.length}) and rank of dy (${e.rank}) must match`));let r=t,a=e,o=!1;4===e.rank&&(o=!0,a=ii(e,[1,e.shape[0],e.shape[1],e.shape[2],e.shape[3]]),r=[1,t[0],t[1],t[2],t[3]]);const l=r[4],u=a.shape[4];Ge(5===r.length,(()=>`Error in conv3dDerInput: inShape must be length 5, but got length ${r.length}.`)),Ge(5===a.rank,(()=>`Error in conv3dDerInput: dy must be rank 5, but got rank ${a.rank}`)),Ge(5===n.rank,(()=>`Error in conv3dDerInput: filter must be rank 5, but got rank ${n.rank}`)),Ge(l===n.shape[3],(()=>`Error in conv3dDerInput: depth of input (${l}) must match input depth for filter ${n.shape[3]}.`)),Ge(u===n.shape[4],(()=>`Error in conv3dDerInput: depth of output (${u}) must match output depth for filter ${n.shape[4]}.`));const h={dy:a,filter:n},c={pad:i,strides:s,inputShape:r},p=_s.runKernel("Conv3DBackpropInputV2",h,c);return o?ii(p,[p.shape[1],p.shape[2],p.shape[3],p.shape[4]]):p}}),Zi={kernelName:"Conv3D",inputsToSave:["x","filter"],gradFunc:(t,e,n)=>{const{dilations:s,strides:i,pad:r}=n;Ge(bi(s),(()=>`Error in gradient of conv3D: dilation rates greater than 1 are not yet supported in gradients. Got dilations '${s}'`));const[a,o]=e;return{x:()=>Ji(a.shape,t,o,i,r),filter:()=>Hi(a,t,o.shape,i,r)}}};const Yi=Ws({sin_:function(t){const e={x:Ps(t,"x","sin","float32")};return _s.runKernel("Sin",e)}}),Xi={kernelName:"Cos",inputsToSave:["x"],gradFunc:(t,e)=>{const[n]=e;return{x:()=>qs(Js(Yi(js(n,"float32"))),t)}}};const Qi=Ws({sinh_:function(t){const e={x:Ps(t,"x","sinh")};return _s.runKernel(_e,e)}}),tr={kernelName:ye,inputsToSave:["x"],gradFunc:(t,e)=>{const[n]=e;return{x:()=>qs(Qi(js(n,"float32")),t)}}};function er(t,e){return function(t,e,n){const s=t.length+e.length,i=[];let r=0,a=0;for(let o=0;o1)),e)}function nr(t,e){if(function(t,e){for(let n=0;nn.push(t))),n}function sr(t){return t.map(((t,e)=>[e,t])).sort(((t,e)=>t[1]-e[1])).map((t=>t[0]))}const ir=Ws({cumsum_:function(t,e=0,n=!1,s=!1){const i={x:Ps(t,"x","cumsum")},r={axis:e,exclusive:n,reverse:s};return _s.runKernel(be,i,r)}});const rr=Ws({complex_:function(t,e){const n=Ps(t,"real","complex"),s=Ps(e,"imag","complex");!function(t,e,n=""){Ge(Je(t,e),(()=>n+` Shapes ${t} and ${e} must match`))}(n.shape,s.shape,`real and imag shapes, ${n.shape} and ${s.shape}, must match in call to tf.complex().`);const i={real:n,imag:s};return _s.runKernel("Complex",i)}});const ar=Ws({imag_:function(t){const e={input:Ps(t,"input","imag")};return _s.runKernel("Imag",e)}});const or=Ws({real_:function(t){const e={input:Ps(t,"input","real")};return _s.runKernel("Real",e)}});const lr=Ws({transpose_:function(t,e,n){const s=Ps(t,"x","transpose");if(null==e&&(e=s.shape.map(((t,e)=>e)).reverse()),Ge(s.rank===e.length,(()=>`Error in transpose: rank of input ${s.rank} must match length of perm ${e}.`)),e.forEach((t=>{Ge(t>=0&&t"All entries in 'perm' must be between 0 and "+(s.rank-1)+` but got ${e}`))})),s.rank<=1)return s.clone();const i={x:s},r={perm:e};return"complex64"===s.dtype?(a=()=>{let t=or(s),e=ar(s);return t=_s.runKernel(Ue,{x:t},r),e=_s.runKernel(Ue,{x:e},r),n&&(e=Js(e)),rr(t,e)},_s.tidy(a,o)):_s.runKernel(Ue,i,r);var a,o}}),ur={kernelName:be,inputsToSave:["x"],gradFunc:(t,e,n)=>{const[s]=e,{axis:i,exclusive:r,reverse:a}=n;return{x:()=>{const e=nr([i],s.rank);let n=ir(t,i,r,!a);return null!=e&&(n=lr(n,e)),n}}}};const hr=Ws({depthwiseConv2dNativeBackpropFilter_:function(t,e,n,s,i,r=[1,1],a){let o=t;3===t.rank&&(o=ii(t,[1,t.shape[0],t.shape[1],t.shape[2]]));let l=e;3===l.rank&&(l=ii(e,[1,e.shape[0],e.shape[1],e.shape[2]]));const u={x:o,dy:l},h={strides:s,pad:i,dimRoundingMode:a,dilations:r,filterShape:n};return _s.runKernel("DepthwiseConv2dNativeBackpropFilter",u,h)}});const cr=Ws({depthwiseConv2dNativeBackpropInput_:function(t,e,n,s,i,r=[1,1],a){let o=e,l=!1;3===e.rank&&(l=!0,o=ii(e,[1,e.shape[0],e.shape[1],e.shape[2]]));const u={dy:o,filter:n},h={strides:s,pad:i,dimRoundingMode:a,dilations:r,inputShape:t},c=_s.runKernel("DepthwiseConv2dNativeBackpropInput",u,h);return l?ii(c,[c.shape[1],c.shape[2],c.shape[3]]):c}}),pr={kernelName:"DepthwiseConv2dNative",inputsToSave:["x","filter"],gradFunc:(t,e,n)=>{const{dilations:s,strides:i,pad:r,dimRoundingMode:a}=n,o=null==s?[1,1]:s;Ge(bi(o),(()=>`Error in gradient of depthwiseConv2dNative: dilation rates greater than 1 are not yet supported. Got dilations '${o}'`));const[l,u]=e;return Ge(4===l.rank,(()=>`Error in gradient of depthwiseConv2dNative: input must be rank 4, but got rank ${l.rank}.`)),Ge(4===u.rank,(()=>`Error in gradient of depthwiseConv2dNative: filter must be rank 4, but got rank ${u.rank}.`)),Ge(l.shape[3]===u.shape[2],(()=>`Error in gradient of depthwiseConv2d: number of input channels (${l.shape[3]}) must match the inChannels dimension in filter ${u.shape[2]}.`)),Ge(wi(i,o),(()=>`Error in gradient of depthwiseConv2d: Either strides or dilations must be 1. Got strides ${i} and dilations '${o}'.`)),vi("depthwiseConv2d",r,a),{x:()=>cr(l.shape,t,u,i,r,o,a),filter:()=>hr(l,t,u.shape,i,r,o,a)}}},dr={kernelName:"Dilation2D",inputsToSave:["x","filter"],gradFunc:(t,e,n)=>{const[s,i]=e,r={x:s,filter:i,dy:t},a={x:s,filter:i,dy:t};return{x:()=>_s.runKernel("Dilation2DBackpropInput",r,n),filter:()=>_s.runKernel("Dilation2DBackpropFilter",a,n)}}},fr={kernelName:"Elu",outputsToSave:[!0],gradFunc:(t,e)=>{const[n]=e,s={dy:t,y:n};return{x:()=>_s.runKernel("EluGrad",s)}}};const gr=Ws({exp_:function(t){const e={x:Ps(t,"x","exp")};return _s.runKernel("Exp",e)}}),mr={kernelName:"Erf",inputsToSave:["x"],gradFunc:(t,e)=>{const[n]=e,s=qs(gr(Js(Xs(n))),2/Math.sqrt(Math.PI));return{x:()=>qs(t,s)}}},yr={kernelName:"Exp",outputsToSave:[!0],gradFunc:(t,e)=>{const[n]=e;return{x:()=>qs(t,n)}}},br={kernelName:ke,inputsToSave:["input"],gradFunc:(t,e)=>{const[n]=e;return{input:()=>ii(t,n.shape)}}},wr={kernelName:"Expm1",inputsToSave:["x"],gradFunc:(t,e)=>{const[n]=e;return{x:()=>qs(t,gr(n))}}},kr={kernelName:ve,gradFunc:t=>({x:()=>li(t)})},vr={kernelName:Se,inputsToSave:["a","b"],gradFunc:(t,e)=>{const[n,s]=e,i=si(n.shape,s.shape);return{a:()=>{const e=Hs(t,js(s,"float32")),r=ni(n.shape,i);return r.length>0?ii(ri(e,r),n.shape):e},b:()=>{let e=qs(t,js(n,"float32"));const r=ni(s.shape,i);r.length>0&&(e=ii(ri(e,r),s.shape));const a=Xs(s);return Js(Hs(e,js(a,"float32")))}}}};const Sr=Ws({rsqrt_:function(t){const e={x:Ps(t,"x","rsqrt","float32")};return _s.runKernel(Fe,e)}});const xr=Ws({tile_:function(t,e){const n=Ps(t,"x","tile","string_or_numeric");Ge(n.rank===e.length,(()=>`Error in transpose: rank of input ${n.rank} must match length of reps ${e}.`));const s={x:n},i={reps:e};return _s.runKernel(Pe,s,i)}}),Nr={kernelName:"FusedBatchNorm",inputsToSave:["x","mean","variance","scale"],gradFunc:(t,e,n)=>{const{varianceEpsilon:s}=n,[i,r,a,o]=e,l=null==o?Zs(1):o,u=ni(r.shape,i.shape),h=[];if(1===r.rank){for(let t=0;t1===r.rank?ii(qs(qs(t,xr(ii(d,[1,1,1,r.shape[0]]),h)),l),i.shape):ii(qs(qs(t,d),l),i.shape),mean:()=>{let t=qs(qs(d,Zs(-1)),p);return 1===r.rank&&(t=ri(t,u)),ii(t,r.shape)},variance:()=>{let t=qs(qs(f,c),p);return 1===r.rank&&(t=ri(t,u)),ii(t,r.shape)},scale:()=>{const e=qs(c,d);let n=qs(t,e);return 1===r.rank&&(n=ri(n,u)),ii(n,r.shape)},offset:()=>{let e=t;return 1===r.rank&&(e=ri(e,u)),ii(e,r.shape)}}}};const Ir=Ws({stack_:function(t,e=0){const n=Us(t,"tensors","stack","string_or_numeric");Ge(n.length>=1,(()=>"Pass at least one tensor to tf.stack")),n.length>0&&Ge(e<=n[0].rank,(()=>"Axis must be <= rank of the tensor"));const s=n,i={axis:e};return _s.runKernel(Ee,s,i)}});const Ar=Ws({unsortedSegmentSum_:function(t,e,n){const s=Ps(t,"x","unsortedSegmentSum"),i=Ps(e,"segmentIds","unsortedSegmentSum","int32");Ge(Ze(n),(()=>"numSegments must be of dtype int"));const r={x:s,segmentIds:i},a={numSegments:n};return _s.runKernel(je,r,a)}}),zr={kernelName:xe,inputsToSave:["x","indices"],gradFunc:(t,e,n)=>{const[s,i]=e,{axis:r,batchDims:a}=n,o=Xe(r,s.shape)[0],l=(t,e,n)=>()=>{const s=t.shape,i=e.size,a=s.slice(0,o),l=a.length,u=s.slice(r,s.length).slice(1),h=u.length,c=Er(0,l),p=Er(l+1,l+1+h),d=Tr([a,[i],u]),f=ii(n,d),g=ii(e,[i]),m=Tr([[l],c,p]),y=lr(f,m);let b=Ar(y,g,t.shape[o]);const w=sr(m);return b=lr(b,w),b};if(1===a){const e=s.shape[0],n=s.split(e,0);return{x:()=>{const e=Ir(n.map(((e,n)=>l(e,i.slice(n,1),t.slice(n,1))())));return e.reshape(s.shape)},indices:()=>i}}return{x:l(s,i,t),indices:()=>i}}};function Er(t,e){const n=[];for(let s=t;s{const[n,s]=e;return{a:()=>li(n),b:()=>li(s)}}},$r={kernelName:Ie,gradFunc:t=>({x:()=>js(t,"float32")})},Fr={kernelName:"IsFinite",gradFunc:t=>({x:()=>li(t)})},Dr={kernelName:"IsInf",gradFunc:t=>({x:()=>li(t)})},Lr={kernelName:"IsNan",gradFunc:t=>({x:()=>li(t)})};const _r=Ws({greater_:function(t,e){let n=Ps(t,"a","greater","string_or_numeric"),s=Ps(e,"b","greater","string_or_numeric");[n,s]=Ts(n,s),si(n.shape,s.shape);const i={a:n,b:s};return _s.runKernel("Greater",i)}}),Rr={kernelName:"LeakyRelu",inputsToSave:["x"],gradFunc:(t,e,n)=>{const[s]=e,{alpha:i}=n,r=_r(s,0);return{x:()=>Mi(r,t,qs(t,i))}}},Or={kernelName:"Log1p",inputsToSave:["x"],gradFunc:(t,e)=>{const[n]=e;return{x:()=>Hs(t,pi(n,1))}}},Mr={kernelName:"Log",inputsToSave:["x"],gradFunc:(t,e)=>{const[n]=e;return{x:()=>Hs(t,js(n,"float32"))}}},Br={kernelName:"LogSoftmax",inputsToSave:[],outputsToSave:[!0],gradFunc:(t,e,n)=>{const[s]=e,{axis:i}=n;return{logits:()=>{const e=gr(s);return Qs(t,qs(ri(t,i,!0),e))}}}};const Pr=Ws({localResponseNormalizationBackprop_:function(t,e,n,s=5,i=1,r=1,a=.5){const o={x:t,y:e,dy:n},l={depthRadius:s,bias:i,alpha:r,beta:a};return _s.runKernel("LRNGrad",o,l)}}),Ur={kernelName:"LRN",inputsToSave:["x"],outputsToSave:[!0],gradFunc:(t,e,n)=>{const[s,i]=e,{depthRadius:r,bias:a,alpha:o,beta:l}=n;return{x:()=>Pr(s,i,t,r,a,o,l)}}};const Wr=Ws({equal_:function(t,e){let n=Ps(t,"a","equal","string_or_numeric"),s=Ps(e,"b","equal","string_or_numeric");[n,s]=Ts(n,s),si(n.shape,s.shape);const i={a:n,b:s};return _s.runKernel("Equal",i)}});function jr(t,e,n,s){return e.rankqs(t,js(Wr(n,e),t.dtype))}}const qr={kernelName:"Max",inputsToSave:["x"],outputsToSave:[!0],gradFunc:(t,e,n)=>{const s=n,{reductionIndices:i}=s,r=e[0],a=jr(t,e[1],r,Xe(i,r.shape));return{x:()=>a.x()}}};const Vr=Ws({less_:function(t,e){let n=Ps(t,"a","less","string_or_numeric"),s=Ps(e,"b","less","string_or_numeric");[n,s]=Ts(n,s),si(n.shape,s.shape);const i={a:n,b:s};return _s.runKernel("Less",i)}}),Kr={kernelName:Ae,inputsToSave:["a","b"],gradFunc:(t,e)=>{const[n,s]=e;return{a:()=>qs(t,js(Di(n,s),"float32")),b:()=>qs(t,js(Vr(n,s),"float32"))}}};const Gr=Ws({maxPool3dGrad_:function(t,e,n,s,i,r,a){const o=Ps(t,"dy","maxPool3dGrad"),l=Ps(e,"input","maxPool3dGrad"),u=Ps(n,"output","maxPool3dGrad");let h=o,c=l,p=u,d=!1;4===l.rank&&(d=!0,h=ii(o,[1,o.shape[0],o.shape[1],o.shape[2],o.shape[3]]),c=ii(l,[1,l.shape[0],l.shape[1],l.shape[2],l.shape[3]]),p=ii(u,[1,u.shape[0],u.shape[1],u.shape[2],u.shape[3]])),Ge(5===h.rank,(()=>`Error in maxPool3dGrad: dy must be rank 5 but got rank ${h.rank}.`)),Ge(5===c.rank,(()=>`Error in maxPool3dGrad: input must be rank 5 but got rank ${c.rank}.`)),Ge(5===p.rank,(()=>`Error in maxPool3dGrad: output must be rank 5 but got rank ${p.rank}.`)),vi("maxPool3dGrad",r,a);const f={dy:h,input:c,output:p},g={filterSize:s,strides:i,pad:r,dimRoundingMode:a},m=_s.runKernel("MaxPool3DGrad",f,g);return d?ii(m,[m.shape[1],m.shape[2],m.shape[3],m.shape[4]]):m}}),Hr={kernelName:"MaxPool3D",inputsToSave:["x"],outputsToSave:[!0],gradFunc:(t,e,n)=>{const[s,i]=e,{filterSize:r,strides:a,pad:o,dimRoundingMode:l}=n;return{x:()=>Gr(t,s,i,r,a,o,l)}}};const Jr=Ws({maxPoolGrad_:function(t,e,n,s,i,r,a){const o=Ps(t,"dy","maxPoolGrad"),l=Ps(e,"input","maxPoolGrad"),u=Ps(n,"output","maxPoolGrad");Ge(l.rank===o.rank,(()=>`Rank of input (${l.rank}) does not match rank of dy (${o.rank})`)),Ge(4===o.rank,(()=>`Error in maxPoolGrad: dy must be rank 4 but got rank ${o.rank}.`)),Ge(4===l.rank,(()=>`Error in maxPoolGrad: input must be rank 4 but got rank ${l.rank}.`)),vi("maxPoolGrad",r,a);const h={dy:o,input:l,output:u},c={filterSize:s,strides:i,pad:r,dimRoundingMode:a};return _s.runKernel("MaxPoolGrad",h,c)}}),Zr={kernelName:"MaxPool",inputsToSave:["x"],outputsToSave:[!0],gradFunc:(t,e,n)=>{const[s,i]=e,{filterSize:r,strides:a,pad:o}=n;return{x:()=>Jr(t,s,i,r,a,o)}}};function Yr(t,e="float32"){if(un(t),"complex64"===e){const e=Yr(t,"float32"),n=Yr(t,"float32");return rr(e,n)}const n=ln(He(t),e);return _s.makeTensor(n,t,e)}function Xr(t,e="float32"){if(un(t),"complex64"===e){const e=Xr(t,"float32"),n=Yr(t,"float32");return rr(e,n)}const n=on(He(t),e);return _s.makeTensor(n,t,e)}const Qr={kernelName:"Mean",inputsToSave:["x"],gradFunc:(t,e,n)=>{const[s]=e,{axis:i}=n,r=Xe(i,s.shape),a=function(t,e){const n=[],s=t.length;for(let i=0;it[e]))]}(s.shape,r),o=He(a[1]);return{x:()=>{const e=s.shape.slice();r.forEach((t=>{e[t]=1}));const n=ii(t,e);return Hs(qs(n,Xr(s.shape,"float32")),o)}}}},ta={kernelName:"Min",inputsToSave:["x"],outputsToSave:[!0],gradFunc:(t,e,n)=>{const s=n,{axis:i}=s,[r,a]=e,o=jr(t,a,r,Xe(i,r.shape));return{x:()=>o.x()}}},ea={kernelName:"Minimum",inputsToSave:["a","b"],gradFunc:(t,e)=>{const[n,s]=e;return{a:()=>qs(t,js(Li(n,s),"float32")),b:()=>qs(t,js(_r(n,s),"float32"))}}};const na=Ws({slice_:function(t,e,n){const s=Ps(t,"x","slice","string_or_numeric");if(0===s.rank)throw new Error("Slicing scalar is not possible");const i={x:s},r={begin:e,size:n};return _s.runKernel(Le,i,r)}}),sa={kernelName:"MirrorPad",inputsToSave:["x"],gradFunc:(t,e,n)=>{const s=e[0],{paddings:i}=n,r=i.map((t=>t[0]));return{x:()=>na(t,r,s.shape)}}};const ia=Ws({floor_:function(t){const e={x:Ps(t,"x","floor","float32")};return _s.runKernel(ve,e)}}),ra={kernelName:"Mod",inputsToSave:["a","b"],gradFunc:(t,e)=>{const[n,s]=e,i=si(n.shape,s.shape);return{a:()=>{const e=ni(n.shape,i);return e.length>0?ii(ri(t,e),n.shape):t},b:()=>{const e=qs(t,Js(ia(Hs(n,s)))),r=ni(s.shape,i);return r.length>0?ii(ri(e,r),s.shape):e}}}},aa={kernelName:ze,inputsToSave:["a","b"],gradFunc:(t,e)=>{const[n,s]=e,i=si(n.shape,s.shape);return{a:()=>{const e=qs(t,js(s,"float32")),r=ni(n.shape,i);return r.length>0?ii(ri(e,r),n.shape):e},b:()=>{const e=qs(t,js(n,"float32")),r=ni(s.shape,i);return r.length>0?ii(ri(e,r),s.shape):e}}}},oa={kernelName:"Neg",gradFunc:t=>({x:()=>Js(t)})},la={kernelName:"OneHot",inputsToSave:["indices"],gradFunc:(t,e)=>{const n=e[0];return{indices:()=>Yr(n.shape,"float32")}}},ua={kernelName:"OnesLike",gradFunc:t=>({x:()=>li(t)})};const ha=Ws({unstack_:function(t,e=0){const n=Ps(t,"x","unstack","string_or_numeric");Ge(e>=-n.shape.length&&e`Axis = ${e} is not in [-${n.shape.length}, ${n.shape.length})`));const s={value:n},i={axis:e};return _s.runKernel(We,s,i)}}),ca={kernelName:Ee,saveAllInputs:!0,gradFunc:(t,e,n)=>{const{axis:s}=n;return ha(t,s).map((t=>()=>t))}},pa={kernelName:Te,inputsToSave:["x"],gradFunc:(t,e,n)=>{const s=e[0],{paddings:i}=n,r=i.map((t=>t[0]));return{x:()=>na(t,r,s.shape)}}};const da=Ws({log_:function(t){const e={x:Ps(t,"x","log","float32")};return _s.runKernel("Log",e)}});const fa=Ws({pow_:function(t,e){let n=Ps(t,"base","pow"),s=Ps(e,"exp","pow");[n,s]=Ts(n,s);const i={a:n,b:s};return _s.runKernel("Pow",i)}}),ga={kernelName:"Pow",inputsToSave:["a","b"],outputsToSave:[!0],gradFunc:(t,e)=>{const[n,s,i]=e,r=n,a=s,o=si(r.shape,a.shape);return{a:()=>{const e=js(a,"float32");let n=qs(t,qs(e,fa(r,Qs(e,Zs(1)))));const s=ni(r.shape,o);return s.length>0&&(n=ri(n,s)),ii(n,r.shape)},b:()=>{const e=_r(r,0),n=Mi(e,da(r),li(r));let s=qs(t,qs(i,n));const l=ni(a.shape,o);return l.length>0&&(s=ri(s,l)),ii(s,a.shape)}}}},ma={kernelName:"Prelu",inputsToSave:["x","alpha"],gradFunc:(t,e)=>{const[n,s]=e,i=_r(n,0);return{x:()=>Mi(i,t,qs(t,s)),alpha:()=>{let e=Mi(i,li(t),qs(t,n));const r=ni(s.shape,t.shape);return r.length>0&&(e=ri(e,r)),ii(e,s.shape)}}}};const ya=fn();ya.registerFlag("DEBUG",(()=>!1),(t=>{t&&console.warn("Debugging mode is ON. The output of every math call will be downloaded to CPU and checked for NaNs. This significantly impacts performance.")})),ya.registerFlag("IS_BROWSER",(()=>"undefined"!=typeof window&&null!=window.document||"undefined"!=typeof WorkerGlobalScope)),ya.registerFlag("IS_NODE",(()=>"undefined"!=typeof process&&"undefined"!=typeof process.versions&&"undefined"!=typeof process.versions.node)),ya.registerFlag("IS_CHROME",(()=>"undefined"!=typeof navigator&&null!=navigator&&null!=navigator.userAgent&&/Chrome/.test(navigator.userAgent)&&/Google Inc/.test(navigator.vendor))),ya.registerFlag("IS_SAFARI",(()=>"undefined"!=typeof navigator&&null!=navigator&&null!=navigator.userAgent&&/Safari/.test(navigator.userAgent)&&/Apple/.test(navigator.vendor))),ya.registerFlag("PROD",(()=>!1)),ya.registerFlag("TENSORLIKE_CHECK_SHAPE_CONSISTENCY",(()=>ya.getBool("DEBUG"))),ya.registerFlag("DEPRECATION_WARNINGS_ENABLED",(()=>!0)),ya.registerFlag("IS_TEST",(()=>!1)),ya.registerFlag("CHECK_COMPUTATION_FOR_ERRORS",(()=>ya.getBool("DEBUG"))),ya.registerFlag("WRAP_TO_IMAGEBITMAP",(()=>!1)),ya.registerFlag("CANVAS2D_WILL_READ_FREQUENTLY_FOR_GPU",(()=>!1)),ya.registerFlag("USE_SETTIMEOUTCUSTOM",(()=>!1));class ba{static join(t){return new ba(t).slice()}constructor(t){if(this.shards=[],this.previousShardIndex=0,null==t)return;if(t instanceof Array||(t=[t]),0===(t=t.map((t=>as(t)?t.buffer:t))).length)return;this.bufferUniformSize=t[0].byteLength;let e=0;for(let n=0;n=this.byteLength)return-1;if(null!=this.bufferUniformSize)return this.previousShardIndex=Math.floor(t/this.bufferUniformSize),this.previousShardIndex;function e(e){return t=e.end?1:0}if(0===e(this.shards[this.previousShardIndex]))return this.previousShardIndex;const n=function(t,e){let n=0,s=t.length;for(;n<=s;){const i=Math.floor((s-n)/2)+n,r=e(t[i]);if(0===r)return i;r<0?s=i:n=i+1}return-1}(this.shards,e);return-1===n?-1:(this.previousShardIndex=n,this.previousShardIndex)}}const wa="undefined"!=typeof Buffer&&("undefined"==typeof Blob||"undefined"==typeof atob||"undefined"==typeof btoa);function ka(t){return wa?Buffer.byteLength(t,"utf8"):new Blob([t]).size}function va(t,e){const n={modelTopology:t.modelTopology,format:t.format,generatedBy:t.generatedBy,convertedBy:t.convertedBy,weightsManifest:e};return null!=t.signature&&(n.signature=t.signature),null!=t.userDefinedMetadata&&(n.userDefinedMetadata=t.userDefinedMetadata),null!=t.modelInitializer&&(n.modelInitializer=t.modelInitializer),null!=t.initializerSignature&&(n.initializerSignature=t.initializerSignature),null!=t.trainingConfig&&(n.trainingConfig=t.trainingConfig),n}async function Sa(t,e){let n,s;return null!=t.weightsManifest&&([n,s]=await e(t.weightsManifest)),function(t,e,n){const s={modelTopology:t.modelTopology,format:t.format,generatedBy:t.generatedBy,convertedBy:t.convertedBy};if(null!=t.trainingConfig&&(s.trainingConfig=t.trainingConfig),null!=t.weightsManifest){if(!e)throw new Error("modelJSON has weightsManifest but weightSpecs is null");if(!n)throw new Error("modelJSON has weightsManifest but weightData is null");s.weightSpecs=e,s.weightData=n}return null!=t.signature&&(s.signature=t.signature),null!=t.userDefinedMetadata&&(s.userDefinedMetadata=t.userDefinedMetadata),null!=t.modelInitializer&&(s.modelInitializer=t.modelInitializer),null!=t.initializerSignature&&(s.initializerSignature=t.initializerSignature),s}(t,n,s)}function xa(t){if(t.modelTopology instanceof ArrayBuffer)throw new Error("Expected JSON model topology, received ArrayBuffer.");return{dateSaved:new Date,modelTopologyType:"JSON",modelTopologyBytes:null==t.modelTopology?0:ka(JSON.stringify(t.modelTopology)),weightSpecsBytes:null==t.weightSpecs?0:ka(JSON.stringify(t.weightSpecs)),weightDataBytes:null==t.weightData?0:new ba(t.weightData).byteLength}}function Na(t){const e=[];for(const n of t)e.push(...n.weights);return e}class Ia{constructor(){this.saveRouters=[],this.loadRouters=[]}static getInstance(){return null==Ia.instance&&(Ia.instance=new Ia),Ia.instance}static registerSaveRouter(t){Ia.getInstance().saveRouters.push(t)}static registerLoadRouter(t){Ia.getInstance().loadRouters.push(t)}static getSaveHandlers(t){return Ia.getHandlers(t,"save")}static getLoadHandlers(t,e){return Ia.getHandlers(t,"load",e)}static getHandlers(t,e,n){const s=[];return("load"===e?Ia.getInstance().loadRouters:Ia.getInstance().saveRouters).forEach((e=>{const i=e(t,n);null!==i&&s.push(i)})),s}}const Aa="models_store",za="model_info_store";class Ea{constructor(t){if(this.indexedDB=function(){if(!fn().getBool("IS_BROWSER"))throw new Error("Failed to obtain IndexedDB factory because the current environmentis not a web browser.");const t="undefined"==typeof window?self:window,e=t.indexedDB||t.mozIndexedDB||t.webkitIndexedDB||t.msIndexedDB||t.shimIndexedDB;if(null==e)throw new Error("The current browser does not appear to support IndexedDB.");return e}(),null==t||!t)throw new Error("For IndexedDB, modelPath must not be null, undefined or empty.");this.modelPath=t}async save(t){if(t.modelTopology instanceof ArrayBuffer)throw new Error("BrowserLocalStorage.save() does not support saving model topology in binary formats yet.");return this.databaseAction(this.modelPath,t)}async load(){return this.databaseAction(this.modelPath)}databaseAction(t,e){return new Promise(((t,n)=>{const s=this.indexedDB.open("tensorflowjs",1);s.onupgradeneeded=()=>function(t){const e=t.result;e.createObjectStore(Aa,{keyPath:"modelPath"}),e.createObjectStore(za,{keyPath:"modelPath"})}(s),s.onsuccess=()=>{const i=s.result;if(null==e){const e=i.transaction(Aa,"readonly"),s=e.objectStore(Aa).get(this.modelPath);s.onsuccess=()=>{if(null==s.result)return i.close(),n(new Error(`Cannot find model with path '${this.modelPath}' in IndexedDB.`));t(s.result.modelArtifacts)},s.onerror=t=>(i.close(),n(s.error)),e.oncomplete=()=>i.close()}else{e.weightData=ba.join(e.weightData);const s=xa(e),r=i.transaction(za,"readwrite");let a,o,l=r.objectStore(za);try{a=l.put({modelPath:this.modelPath,modelArtifactsInfo:s})}catch(t){return n(t)}a.onsuccess=()=>{o=i.transaction(Aa,"readwrite");const a=o.objectStore(Aa);let u;try{u=a.put({modelPath:this.modelPath,modelArtifacts:e,modelArtifactsInfo:s})}catch(t){return n(t)}u.onsuccess=()=>t({modelArtifactsInfo:s}),u.onerror=t=>{l=r.objectStore(za);const e=l.delete(this.modelPath);e.onsuccess=()=>(i.close(),n(u.error)),e.onerror=t=>(i.close(),n(u.error))}},a.onerror=t=>(i.close(),n(a.error)),r.oncomplete=()=>{null==o?i.close():o.oncomplete=()=>i.close()}}},s.onerror=t=>n(s.error)}))}}Ea.URL_SCHEME="indexeddb://";const Ta=t=>{return fn().getBool("IS_BROWSER")&&!Array.isArray(t)&&t.startsWith(Ea.URL_SCHEME)?(e=t.slice(Ea.URL_SCHEME.length),new Ea(e)):null;var e};Ia.registerSaveRouter(Ta),Ia.registerLoadRouter(Ta);const Ca="/",$a="tensorflowjs_models",Fa="info",Da="model_topology",La="weight_specs",_a="weight_data",Ra="model_metadata";class Oa{constructor(t){if(!fn().getBool("IS_BROWSER")||"undefined"==typeof window||"undefined"==typeof window.localStorage)throw new Error("The current environment does not support local storage.");if(this.LS=window.localStorage,null==t||!t)throw new Error("For local storage, modelPath must not be null, undefined or empty.");var e;this.modelPath=t,this.keys=(e=this.modelPath,{info:[$a,e,Fa].join(Ca),topology:[$a,e,Da].join(Ca),weightSpecs:[$a,e,La].join(Ca),weightData:[$a,e,_a].join(Ca),modelMetadata:[$a,e,Ra].join(Ca)})}async save(t){if(t.modelTopology instanceof ArrayBuffer)throw new Error("BrowserLocalStorage.save() does not support saving model topology in binary formats yet.");{const e=JSON.stringify(t.modelTopology),n=JSON.stringify(t.weightSpecs),s=xa(t),i=ba.join(t.weightData);try{this.LS.setItem(this.keys.info,JSON.stringify(s)),this.LS.setItem(this.keys.topology,e),this.LS.setItem(this.keys.weightSpecs,n),this.LS.setItem(this.keys.weightData,function(t){if(wa)return Buffer.from(t).toString("base64");const e=new Uint8Array(t);let n="";for(let t=0,s=e.length;t{return fn().getBool("IS_BROWSER")&&!Array.isArray(t)&&t.startsWith(Oa.URL_SCHEME)?(e=t.slice(Oa.URL_SCHEME.length),new Oa(e)):null;var e};Ia.registerSaveRouter(Ma),Ia.registerLoadRouter(Ma);function Ba(t){return new Promise((t=>setTimeout(t))).then(t)}class Pa{constructor(t){if(!fn().getBool("IS_BROWSER"))throw new Error("browserDownloads() cannot proceed because the current environment is not a browser.");t.startsWith(Pa.URL_SCHEME)&&(t=t.slice(Pa.URL_SCHEME.length)),null!=t&&0!==t.length||(t="model"),this.modelJsonFileName=t+".json",this.weightDataFileName=t+".weights.bin"}async save(t){if("undefined"==typeof document)throw new Error("Browser downloads are not supported in this environment since `document` is not present");const e=ba.join(t.weightData),n=window.URL.createObjectURL(new Blob([e],{type:"application/octet-stream"}));if(t.modelTopology instanceof ArrayBuffer)throw new Error("BrowserDownloads.save() does not support saving model topology in binary formats yet.");{const e=va(t,[{paths:["./"+this.weightDataFileName],weights:t.weightSpecs}]),s=window.URL.createObjectURL(new Blob([JSON.stringify(e)],{type:"application/json"})),i=null==this.modelJsonAnchor?document.createElement("a"):this.modelJsonAnchor;if(i.download=this.modelJsonFileName,i.href=s,await Ba((()=>i.dispatchEvent(new MouseEvent("click")))),null!=t.weightData){const t=null==this.weightDataAnchor?document.createElement("a"):this.weightDataAnchor;t.download=this.weightDataFileName,t.href=n,await Ba((()=>t.dispatchEvent(new MouseEvent("click"))))}return{modelArtifactsInfo:xa(t)}}}}Pa.URL_SCHEME="downloads://";function Ua(t,e,n,s){!function(t){Ge(null!=t&&Array.isArray(t)&&t.length>0,(()=>"promises must be a none empty array"))}(t),function(t,e){Ge(t>=0&&t<=1,(()=>`Progress fraction must be in range [0, 1], but got startFraction ${t}`)),Ge(e>=0&&e<=1,(()=>`Progress fraction must be in range [0, 1], but got endFraction ${e}`)),Ge(e>=t,(()=>`startFraction must be no more than endFraction, but got startFraction ${t} and endFraction ${e}`))}(n=null==n?0:n,s=null==s?1:s);let i=0;return Promise.all(t.map((r=>(r.then((r=>{const a=n+ ++i/t.length*(s-n);return e(a),r})),r))))}Ia.registerSaveRouter((t=>fn().getBool("IS_BROWSER")&&!Array.isArray(t)&&t.startsWith(Pa.URL_SCHEME)?function(t="model"){return new Pa(t)}(t.slice(Pa.URL_SCHEME.length)):null));class Wa{constructor(t,e){if(this.DEFAULT_METHOD="POST",null==e&&(e={}),this.weightPathPrefix=e.weightPathPrefix,this.weightUrlConverter=e.weightUrlConverter,null!=e.fetchFunc?(Ge("function"==typeof e.fetchFunc,(()=>"Must pass a function that matches the signature of `fetch` (see https://developer.mozilla.org/en-US/docs/Web/API/Fetch_API)")),this.fetch=e.fetchFunc):this.fetch=fn().platform.fetch,Ge(null!=t&&t.length>0,(()=>"URL path for http must not be null, undefined or empty.")),Array.isArray(t)&&Ge(2===t.length,(()=>`URL paths for http must have a length of 2, (actual length is ${t.length}).`)),this.path=t,null!=e.requestInit&&null!=e.requestInit.body)throw new Error("requestInit is expected to have no pre-existing body, but has one.");this.requestInit=e.requestInit||{},this.loadOptions=e}async save(t){if(t.modelTopology instanceof ArrayBuffer)throw new Error("BrowserHTTPRequest.save() does not support saving model topology in binary formats yet.");const e=Object.assign({method:this.DEFAULT_METHOD},this.requestInit);e.body=new FormData;const n=va(t,[{paths:["./model.weights.bin"],weights:t.weightSpecs}]);if(e.body.append("model.json",new Blob([JSON.stringify(n)],{type:"application/json"}),"model.json"),null!=t.weightData){const n=ba.join(t.weightData);e.body.append("model.weights.bin",new Blob([n],{type:"application/octet-stream"}),"model.weights.bin")}const s=await this.fetch(this.path,e);if(s.ok)return{modelArtifactsInfo:xa(t),responses:[s]};throw new Error(`BrowserHTTPRequest.save() failed due to HTTP response status ${s.status}.`)}async loadModelJSON(){const t=await this.fetch(this.path,this.requestInit);if(!t.ok)throw new Error(`Request to ${this.path} failed with status code ${t.status}. Please verify this URL points to the model JSON of the model to load.`);let e;try{e=await t.json()}catch(t){let e=`Failed to parse model JSON of response from ${this.path}.`;throw this.path.endsWith(".pb")?e+=" Your path contains a .pb file extension. Support for .pb models have been removed in TensorFlow.js 1.0 in favor of .json models. You can re-convert your Python TensorFlow model using the TensorFlow.js 1.0 conversion scripts or you can convert your.pb models with the 'pb2json'NPM script in the tensorflow/tfjs-converter repository.":e+=" Please make sure the server is serving valid JSON for this request.",new Error(e)}const n=e.modelTopology,s=e.weightsManifest;if(null==n&&null==s)throw new Error(`The JSON from HTTP path ${this.path} contains neither model topology or manifest for weights.`);return e}async load(){if(this.loadOptions.streamWeights)return this.loadStream();return Sa(await this.loadModelJSON(),(t=>this.loadWeights(t)))}async loadStream(){const t=await this.loadModelJSON(),e=await this.getWeightUrls(t.weightsManifest),n=Na(t.weightsManifest);return Object.assign(Object.assign({},t),{weightSpecs:n,getWeightStream:()=>function(t,e){var n;const s=null==e.fetchFunc?fn().platform.fetch:e.fetchFunc;let i,r=0;return null===(n=e.onProgress)||void 0===n||n.call(e,0),new ReadableStream({pull:async n=>{for(var a;re?t.substring(n):"";return[s+"/",i]}(e),i=this.weightPathPrefix||n,r=[],a=[];for(const e of t)for(const t of e.paths)null!=this.weightUrlConverter?a.push(this.weightUrlConverter(t)):r.push(i+t+s);return this.weightUrlConverter&&r.push(...await Promise.all(a)),r}async loadWeights(t){const e=await this.getWeightUrls(t),n=Na(t),s=await async function(t,e){null==e&&(e={});const n=null==e.fetchFunc?fn().platform.fetch:e.fetchFunc,s=t.map((t=>n(t,e.requestInit,{isBinary:!0}))),i=(null==e.onProgress?await Promise.all(s):await Ua(s,e.onProgress,0,.5)).map((t=>t.arrayBuffer()));return null==e.onProgress?await Promise.all(i):await Ua(i,e.onProgress,.5,1)}(e,this.loadOptions);return[n,s]}}function ja(t){return null!=t.match(Wa.URL_SCHEME_REGEX)}Wa.URL_SCHEME_REGEX=/^https?:\/\//;const qa=(t,e)=>{if("undefined"==typeof fetch&&(null==e||null==e.fetchFunc))return null;{let n=!0;if(n=Array.isArray(t)?t.every((t=>ja(t))):ja(t),n)return function(t,e){return new Wa(t,e)}(t,e)}return null};function Va(t,e,n){if(e.rank<1)throw new Error(`tf.scatterND() expects the indices to be rank 1 or higher, but the rank was ${e.rank}.`);if(t.rank<1)throw new Error(`tf.scatterND() expects the updates to be rank 1 or higher, but the rank was ${t.rank}.`);if("int32"!==e.dtype)throw new Error(`The dtype of 'indices' should be int32, but got dtype: ${e.dtype}`);if(n.length<1)throw new Error(`Output rank must be greater or equal to 1, but got shape: ${n}`);if(0===n.length){if(0===e.size)throw new Error(`Indices specified for empty output. indices shape: ${e.shape}`);if(0===t.size)throw new Error(`Updates specified for empty output. updates shape: ${t.shape}`)}!function(t,e,n){const s=e.rank>1?e.shape[e.rank-1]:1,i=e.rank>1?e.rank-1:1,r=`Must have updates.shape = indices.shape[:batchDim] + shape[sliceDim:], got updates.shape: ${n.shape}, indices.shape: ${e.shape}, shape: ${t}, sliceDim: ${s}, and batchDim: ${i}.`;if(n.rank{if("complex64"!==t.dtype)throw new Error(`Cannot concatenate complex64 tensors with a tensor\n with dtype ${t.dtype}. `)})),1===n.length)return Ri(n[0]);const s=n,i={axis:e};return _s.runKernel(fe,s,i)}});const Ga=Ws({sigmoid_:function(t){const e={x:Ps(t,"x","sigmoid","float32")};return _s.runKernel(Re,e)}});const Ha=Ws({batchToSpaceND_:function(t,e,n){const s=Ps(t,"x","batchToSpaceND"),i=e.reduce(((t,e)=>t*e));Ge(s.rank>=1+e.length,(()=>`input rank is ${s.rank} but should be > than blockShape.length ${e.length}`)),Ge(n.length===e.length,(()=>`crops.length is ${n.length} but should be equal to blockShape.length ${e.length}`)),Ge(s.shape[0]%i==0,(()=>`input tensor batch is ${s.shape[0]} but is not divisible by the product of the elements of blockShape ${e.join(" * ")} === ${i}`));const r={x:s},a={blockShape:e,crops:n};return _s.runKernel(pe,r,a)}});const Ja=Ws({cos_:function(t){const e={x:Ps(t,"x","cos","float32")};return _s.runKernel("Cos",e)}});const Za=Ws({cosh_:function(t){const e={x:Ps(t,"x","cosh","float32")};return _s.runKernel(ye,e)}});const Ya=Ws({cumprod_:function(t,e=0,n=!1,s=!1){const i={x:Ps(t,"x","cumprod")},r={axis:e,exclusive:n,reverse:s};return _s.runKernel("Cumprod",i,r)}});const Xa=Ws({expandDims_:function(t,e=0){const n=Ps(t,"x","expandDims","string_or_numeric");Ge(e<=n.rank,(()=>"Axis must be <= rank of the tensor"));const s={input:n},i={dim:e};return _s.runKernel(ke,s,i)}});const Qa=Ws({gather_:function(t,e,n=0,s=0){const i={x:Ps(t,"x","gather"),indices:Ps(e,"indices","gather","int32")},r={axis:n,batchDims:s};return _s.runKernel(xe,i,r)}});const to=Ws({logicalNot_:function(t){const e={x:Ps(t,"x","logicalNot","bool")};return _s.runKernel("LogicalNot",e)}});const eo=Ws({maximum_:function(t,e){let n=Ps(t,"a","maximum"),s=Ps(e,"b","maximum");[n,s]=Ts(n,s),"bool"===n.dtype&&(n=js(n,"int32"),s=js(s,"int32")),si(n.shape,s.shape);const i={a:n,b:s};return _s.runKernel(Ae,i)}});const no=Ws({pad_:function(t,e,n=0){const s=Ps(t,"x","pad");if(0===s.rank)throw new Error("pad(scalar) is not defined. Pass non-scalar to pad");const i={paddings:e,constantValue:n},r={x:s};return _s.runKernel(Te,r,i)}});var so={exports:{}};!function(t){!function(t,e,n){function s(t){var e,n=this,s=(e=4022871197,function(t){t=String(t);for(var n=0;n>>0,e=(s*=e)>>>0,e+=4294967296*(s-=e)}return 2.3283064365386963e-10*(e>>>0)});n.next=function(){var t=2091639*n.s0+2.3283064365386963e-10*n.c;return n.s0=n.s1,n.s1=n.s2,n.s2=t-(n.c=0|t)},n.c=1,n.s0=s(" "),n.s1=s(" "),n.s2=s(" "),n.s0-=s(t),n.s0<0&&(n.s0+=1),n.s1-=s(t),n.s1<0&&(n.s1+=1),n.s2-=s(t),n.s2<0&&(n.s2+=1),s=null}function i(t,e){return e.c=t.c,e.s0=t.s0,e.s1=t.s1,e.s2=t.s2,e}function r(t,e){var n=new s(t),r=e&&e.state,a=n.next;return a.int32=function(){return 4294967296*n.next()|0},a.double=function(){return a()+11102230246251565e-32*(2097152*a()|0)},a.quick=a,r&&("object"==typeof r&&i(r,n),a.state=function(){return i(n,{})}),a}e&&e.exports?e.exports=r:n&&n.amd?n((function(){return r})):this.alea=r}(0,t,!1)}(so);var io=so.exports,ro={exports:{}};!function(t){!function(t,e,n){function s(t){var e=this,n="";e.x=0,e.y=0,e.z=0,e.w=0,e.next=function(){var t=e.x^e.x<<11;return e.x=e.y,e.y=e.z,e.z=e.w,e.w^=e.w>>>19^t^t>>>8},t===(0|t)?e.x=t:n+=t;for(var s=0;s>>0)/4294967296};return a.double=function(){do{var t=((n.next()>>>11)+(n.next()>>>0)/4294967296)/(1<<21)}while(0===t);return t},a.int32=n.next,a.quick=a,r&&("object"==typeof r&&i(r,n),a.state=function(){return i(n,{})}),a}e&&e.exports?e.exports=r:n&&n.amd?n((function(){return r})):this.xor128=r}(0,t,!1)}(ro);var ao=ro.exports,oo={exports:{}};!function(t){!function(t,e,n){function s(t){var e=this,n="";e.next=function(){var t=e.x^e.x>>>2;return e.x=e.y,e.y=e.z,e.z=e.w,e.w=e.v,(e.d=e.d+362437|0)+(e.v=e.v^e.v<<4^t^t<<1)|0},e.x=0,e.y=0,e.z=0,e.w=0,e.v=0,t===(0|t)?e.x=t:n+=t;for(var s=0;s>>4),e.next()}function i(t,e){return e.x=t.x,e.y=t.y,e.z=t.z,e.w=t.w,e.v=t.v,e.d=t.d,e}function r(t,e){var n=new s(t),r=e&&e.state,a=function(){return(n.next()>>>0)/4294967296};return a.double=function(){do{var t=((n.next()>>>11)+(n.next()>>>0)/4294967296)/(1<<21)}while(0===t);return t},a.int32=n.next,a.quick=a,r&&("object"==typeof r&&i(r,n),a.state=function(){return i(n,{})}),a}e&&e.exports?e.exports=r:n&&n.amd?n((function(){return r})):this.xorwow=r}(0,t,!1)}(oo);var lo=oo.exports,uo={exports:{}};!function(t){!function(t,e,n){function s(t){var e=this;e.next=function(){var t,n,s=e.x,i=e.i;return t=s[i],n=(t^=t>>>7)^t<<24,n^=(t=s[i+1&7])^t>>>10,n^=(t=s[i+3&7])^t>>>3,n^=(t=s[i+4&7])^t<<7,t=s[i+7&7],n^=(t^=t<<13)^t<<9,s[i]=n,e.i=i+1&7,n},function(t,e){var n,s=[];if(e===(0|e))s[0]=e;else for(e=""+e,n=0;n0;--n)t.next()}(e,t)}function i(t,e){return e.x=t.x.slice(),e.i=t.i,e}function r(t,e){null==t&&(t=+new Date);var n=new s(t),r=e&&e.state,a=function(){return(n.next()>>>0)/4294967296};return a.double=function(){do{var t=((n.next()>>>11)+(n.next()>>>0)/4294967296)/(1<<21)}while(0===t);return t},a.int32=n.next,a.quick=a,r&&(r.x&&i(r,n),a.state=function(){return i(n,{})}),a}e&&e.exports?e.exports=r:n&&n.amd?n((function(){return r})):this.xorshift7=r}(0,t,!1)}(uo);var ho=uo.exports,co={exports:{}};!function(t){!function(t,e,n){function s(t){var e=this;e.next=function(){var t,n,s=e.w,i=e.X,r=e.i;return e.w=s=s+1640531527|0,n=i[r+34&127],t=i[r=r+1&127],n^=n<<13,t^=t<<17,n^=n>>>15,t^=t>>>12,n=i[r]=n^t,e.i=r,n+(s^s>>>16)|0},function(t,e){var n,s,i,r,a,o=[],l=128;for(e===(0|e)?(s=e,e=null):(e+="\0",s=0,l=Math.max(l,e.length)),i=0,r=-32;r>>15,s^=s<<4,s^=s>>>13,r>=0&&(a=a+1640531527|0,i=0==(n=o[127&r]^=s+a)?i+1:0);for(i>=128&&(o[127&(e&&e.length||0)]=-1),i=127,r=512;r>0;--r)s=o[i+34&127],n=o[i=i+1&127],s^=s<<13,n^=n<<17,s^=s>>>15,n^=n>>>12,o[i]=s^n;t.w=a,t.X=o,t.i=i}(e,t)}function i(t,e){return e.i=t.i,e.w=t.w,e.X=t.X.slice(),e}function r(t,e){null==t&&(t=+new Date);var n=new s(t),r=e&&e.state,a=function(){return(n.next()>>>0)/4294967296};return a.double=function(){do{var t=((n.next()>>>11)+(n.next()>>>0)/4294967296)/(1<<21)}while(0===t);return t},a.int32=n.next,a.quick=a,r&&(r.X&&i(r,n),a.state=function(){return i(n,{})}),a}e&&e.exports?e.exports=r:n&&n.amd?n((function(){return r})):this.xor4096=r}(0,t,!1)}(co);var po=co.exports,fo={exports:{}};!function(t){!function(t,e,n){function s(t){var e=this,n="";e.next=function(){var t=e.b,n=e.c,s=e.d,i=e.a;return t=t<<25^t>>>7^n,n=n-s|0,s=s<<24^s>>>8^i,i=i-t|0,e.b=t=t<<20^t>>>12^n,e.c=n=n-s|0,e.d=s<<16^n>>>16^i,e.a=i-t|0},e.a=0,e.b=0,e.c=-1640531527,e.d=1367130551,t===Math.floor(t)?(e.a=t/4294967296|0,e.b=0|t):n+=t;for(var s=0;s>>0)/4294967296};return a.double=function(){do{var t=((n.next()>>>11)+(n.next()>>>0)/4294967296)/(1<<21)}while(0===t);return t},a.int32=n.next,a.quick=a,r&&("object"==typeof r&&i(r,n),a.state=function(){return i(n,{})}),a}e&&e.exports?e.exports=r:n&&n.amd?n((function(){return r})):this.tychei=r}(0,t,!1)}(fo);var go=fo.exports,mo={exports:{}},yo=En({__proto__:null,default:{}});!function(t){!function(e,n,s){var i,r=256,a=s.pow(r,6),o=s.pow(2,52),l=2*o,u=255;function h(t,u,h){var m=[],y=f(d((u=1==u?{entropy:!0}:u||{}).entropy?[t,g(n)]:null==t?function(){try{var t;return i&&(t=i.randomBytes)?t=t(r):(t=new Uint8Array(r),(e.crypto||e.msCrypto).getRandomValues(t)),g(t)}catch(t){var s=e.navigator,a=s&&s.plugins;return[+new Date,e,a,e.screen,g(n)]}}():t,3),m),b=new c(m),w=function(){for(var t=b.g(6),e=a,n=0;t=l;)t/=2,e/=2,n>>>=1;return(t+n)/e};return w.int32=function(){return 0|b.g(4)},w.quick=function(){return b.g(4)/4294967296},w.double=w,f(g(b.S),n),(u.pass||h||function(t,e,n,i){return i&&(i.S&&p(i,b),t.state=function(){return p(b,{})}),n?(s.random=t,e):t})(w,y,"global"in u?u.global:this==s,u.state)}function c(t){var e,n=t.length,s=this,i=0,a=s.i=s.j=0,o=s.S=[];for(n||(t=[n++]);it*e),1);o.push(l);let u=function(t,e,n){const s=t.shape.slice();s[n]=1;const i=ii(e,s),r=Ya(t,n,!0,!1),a=Ya(t,n,!0,!0),o=qs(r,a);return qs(i,o)}(a.reshape(o),e,i);if(u=u.reshape(a.shape),null!=r){const t=sr(r);u=lr(u,t)}return u}const To={kernelName:Me,gradFunc:(t,e,n)=>{const{blockShape:s,paddings:i}=n;return{x:()=>Ha(t,s,i)}}},Co={kernelName:Be,gradFunc:(t,e,n)=>{const{axis:s}=n;return{x:()=>Ka(t,s)}}};const $o=[Ks,ti,ei,ai,oi,ui,hi,ci,di,fi,gi,mi,xi,Ii,zi,Ti,Ci,$i,Fi,Bi,Pi,Wi,Gi,Vi,Zi,Xi,tr,ur,pr,dr,{kernelName:we,inputsToSave:["a","b"],gradFunc:(t,e)=>{const[n,s]=e,i=si(n.shape,s.shape);return{a:()=>{const e=Hs(t,js(s,"float32")),r=ni(n.shape,i);return r.length>0?ii(ri(e,r),n.shape):e},b:()=>{let e=qs(t,js(n,"float32"));const r=ni(s.shape,i);r.length>0&&(e=ii(ri(e,r),s.shape));const a=Xs(s);return Js(Hs(e,js(a,"float32")))}}}},fr,mr,yr,br,wr,vr,kr,Nr,zr,Cr,$r,Fr,Dr,Lr,Rr,Or,Mr,Br,Ur,qr,qr,Kr,Hr,Zr,Qr,ta,ea,sa,ra,aa,oa,la,ua,ca,pa,pa,ga,ma,{kernelName:"Prod",inputsToSave:["x"],gradFunc:(t,e,n)=>{const[s]=e,{axis:i}=n;let r=[];return r=null==i?s.shape.map(((t,e)=>e)):"number"==typeof i?[i]:i,{x:()=>Eo(s,t,r)}}},{kernelName:"Reciprocal",inputsToSave:["x"],gradFunc:(t,e)=>{const[n]=e;return{x:()=>Hs(t,Js(Xs(n)))}}},{kernelName:"Relu6",inputsToSave:["x"],gradFunc:(t,e)=>{const[n]=e,s=qs(Li(n,6),Vs(n));return{x:()=>qs(t,js(s,"float32"))}}},{kernelName:"Relu",inputsToSave:["x"],gradFunc:(t,e)=>{const[n]=e;return{x:()=>qs(t,js(Vs(n),"float32"))}}},{kernelName:Ce,inputsToSave:["x"],gradFunc:(t,e)=>{const[n]=e;return{x:()=>ii(t,n.shape)}}},{kernelName:"ResizeBilinear",inputsToSave:["images"],gradFunc:(t,e,n)=>{const[s]=e,i={dy:t,images:s};return{images:()=>_s.runKernel("ResizeBilinearGrad",i,n)}}},{kernelName:"ResizeNearestNeighbor",inputsToSave:["images"],gradFunc:(t,e,n)=>{const[s]=e,i={dy:t,images:s};return{images:()=>_s.runKernel("ResizeNearestNeighborGrad",i,n)}}},{kernelName:$e,gradFunc:(t,e,n)=>{const{dims:s}=n,i=Xe(s,t.shape);return{x:()=>Io(t,i)}}},{kernelName:"Round",gradFunc:t=>({x:()=>li(t)})},{kernelName:Fe,inputsToSave:["x"],gradFunc:(t,e)=>{const[n]=e;return{x:()=>Js(Hs(t,qs(fa(n,1.5),2)))}}},{kernelName:De,inputsToSave:["condition"],gradFunc:(t,e)=>{const[n]=e;return{condition:()=>js(li(n),"float32"),t:()=>qs(t,js(n,t.dtype)),e:()=>qs(t,js(to(n),t.dtype))}}},{kernelName:"Selu",inputsToSave:["x"],gradFunc:(t,e)=>{const[n]=e;return{x:()=>{const e=_r(n,Zs(0)),s=Zs(1.7580993408473768),i=Zs(1.0507009873554805),r=qs(t,i),a=qs(qs(t,s),gr(js(n,"float32")));return Mi(e,r,a)}}}},{kernelName:Re,outputsToSave:[!0],gradFunc:(t,e)=>{const[n]=e;return{x:()=>qs(t,qs(n,Qs(Zs(1),n)))}}},{kernelName:"Sign",gradFunc:t=>({x:()=>li(t)})},{kernelName:"Sin",inputsToSave:["x"],gradFunc:(t,e)=>{const[n]=e;return{x:()=>qs(Ja(js(n,"float32")),t)}}},{kernelName:_e,inputsToSave:["x"],gradFunc:(t,e)=>{const[n]=e;return{x:()=>qs(Za(js(n,"float32")),t)}}},{kernelName:Le,inputsToSave:["x"],gradFunc:(t,e,n)=>{const[s]=e,{begin:i,size:r}=n,a=s.shape,[o,l]=function(t,e,n){let s;const i=t.shape.length;let r;return s="number"==typeof e?[e,...new Array(i-1).fill(0)]:e.length{Ge(-1!==t,(()=>"slice() does not support negative begin indexing."))})),r=null==n?new Array(i).fill(-1):"number"==typeof n?[n,...new Array(i-1).fill(-1)]:n.lengthe>=0?e:(Ge(-1===e,(()=>`Negative size values should be exactly -1 but got ${e} for the slice() size at index ${n}.`)),t.shape[n]-s[n]))),[s,r]}(s,i,r),u=[];for(let e=0;eno(t,u)}}},{kernelName:"Softmax",outputsToSave:[!0],gradFunc:(t,e,n)=>{const[s]=e,{dim:i}=n,r=qs(t,s);return{logits:()=>Qs(r,qs(ri(r,[i],true),s))}}},{kernelName:"Softplus",inputsToSave:["x"],gradFunc:(t,e)=>{const[n]=e;return{x:()=>qs(t,Ga(n))}}},To,To,Co,Co,{kernelName:Oe,inputsToSave:["x"],gradFunc:(t,e)=>{const[n]=e;return{x:()=>Hs(t,qs(Ys(js(n,"float32")),2))}}},{kernelName:"SquaredDifference",inputsToSave:["a","b"],gradFunc:(t,e)=>{const[n,s]=e,i=Zs(2);return{a:()=>qs(t,qs(i,Qs(n,s))),b:()=>qs(t,qs(i,Qs(s,n)))}}},{kernelName:"Square",inputsToSave:["x"],gradFunc:(t,e)=>{const[n]=e;return{x:()=>qs(t,qs(js(n,"float32"),2))}}},{kernelName:Ve,gradFunc:t=>({x:()=>li(t)})},{kernelName:"Sub",inputsToSave:["a","b"],gradFunc:(t,e)=>{const[n,s]=e,i=si(n.shape,s.shape);return{a:()=>{let e=t;const s=ni(n.shape,i);return s.length>0&&(e=ri(e,s)),ii(e,n.shape)},b:()=>{let e=t;const n=ni(s.shape,i);return n.length>0&&(e=ri(e,n)),ii(Js(e),s.shape)}}}},{kernelName:"Sum",inputsToSave:["x"],gradFunc:(t,e,n)=>{const[s]=e,i=s.shape.slice(),{axis:r}=n;Xe(r,s.shape).forEach((t=>{i[t]=1}));const a=ii(t,i),o=qs(a,Xr(s.shape,"float32"));return{x:()=>o}}},{kernelName:"Tan",inputsToSave:["x"],gradFunc:(t,e)=>{const[n]=e;return{x:()=>Hs(t,Xs(Ja(n)))}}},{kernelName:"Tanh",outputsToSave:[!0],gradFunc:(t,e)=>{const[n]=e;return{x:()=>qs(Qs(Zs(1),Xs(n)),t)}}},{kernelName:Pe,inputsToSave:["x"],gradFunc:(t,e,n)=>{const[s]=e,{reps:i}=n;return{x:()=>{let e=li(s);if(1===s.rank)for(let n=0;n{const s=n,{perm:i}=s,r=sr(i);return{x:()=>lr(t,r)}}},{kernelName:We,gradFunc:(t,e,n)=>{const s=n,{axis:i}=s;return{value:()=>Ir(t,i)}}},{kernelName:je,inputsToSave:["segmentIds"],gradFunc:(t,e)=>{const[n]=e;return{x:()=>function(t,e){const n=eo(e,li(e)),s=Qa(t,n);let i=Di(e,Zs(0,"int32"));const r=s.rank-i.rank;for(let t=0;t({x:()=>li(t)})}];for(const t of $o)In(t);function Fo(t,n){return e.tidy((()=>i.sqrt(i.sum(i.mul(t,t),n,!0))))}class Do extends e.serialization.Serializable{getConfig(){return{}}}class Lo extends Do{constructor(t){super(),this.defaultMaxValue=2,this.defaultAxis=0,this.maxValue=null!=t.maxValue?t.maxValue:this.defaultMaxValue,this.axis=null!=t.axis?t.axis:this.defaultAxis}apply(t){return e.tidy((()=>{const e=Fo(t,this.axis),n=i.clipByValue(e,0,this.maxValue);return i.mul(t,i.div(n,i.add(tt(),e)))}))}getConfig(){return{maxValue:this.maxValue,axis:this.axis}}}Lo.className="MaxNorm",e.serialization.registerClass(Lo);class _o extends Do{constructor(t){super(),this.defaultAxis=0,this.axis=null!=t.axis?t.axis:this.defaultAxis}apply(t){return e.tidy((()=>i.div(t,i.add(tt(),Fo(t,this.axis)))))}getConfig(){return{axis:this.axis}}}_o.className="UnitNorm",e.serialization.registerClass(_o);class Ro extends Do{apply(t){return i.relu(t)}}Ro.className="NonNeg",e.serialization.registerClass(Ro);class Oo extends Do{constructor(t){super(),this.defaultMinValue=0,this.defaultMaxValue=1,this.defaultRate=1,this.defaultAxis=0,this.minValue=null!=t.minValue?t.minValue:this.defaultMinValue,this.maxValue=null!=t.maxValue?t.maxValue:this.defaultMaxValue,this.rate=null!=t.rate?t.rate:this.defaultRate,this.axis=null!=t.axis?t.axis:this.defaultAxis}apply(t){return e.tidy((()=>{const e=Fo(t,this.axis),n=i.add(i.mul(this.rate,i.clipByValue(e,this.minValue,this.maxValue)),i.mul(1-this.rate,e));return i.mul(t,i.div(n,i.add(tt(),e)))}))}getConfig(){return{minValue:this.minValue,maxValue:this.maxValue,rate:this.rate,axis:this.axis}}}Oo.className="MinMaxNorm",e.serialization.registerClass(Oo);const Mo={maxNorm:"MaxNorm",minMaxNorm:"MinMaxNorm",nonNeg:"NonNeg",unitNorm:"UnitNorm"};function Bo(t){return w(t)}function Po(t,n={}){return v(t,e.serialization.SerializationMap.getMap().classNameMap,n,"constraint")}function Uo(t){if(null==t)return null;if("string"==typeof t){return Po({className:t in Mo?Mo[t]:t,config:{}})}return t instanceof Do?t:Po(t)}var Wo={__proto__:null,maxNorm:function(t){return new Lo(t)},minMaxNorm:function(t){return new Oo(t)},nonNeg:function(){return new Ro},unitNorm:function(t){return new _o(t)}};var jo,qo={__proto__:null,constant:function(t){return new St(t)},glorotNormal:function(t){return new Tt(t)},glorotUniform:function(t){return new Et(t)},heNormal:function(t){return new Ct(t)},heUniform:function(t){return new $t(t)},identity:function(t){return new At(t)},leCunNormal:function(t){return new Ft(t)},leCunUniform:function(t){return new Dt(t)},ones:function(){return new vt},orthogonal:function(t){return new Lt(t)},randomNormal:function(t){return new Nt(t)},randomUniform:function(t){return new xt(t)},truncatedNormal:function(t){return new It(t)},varianceScaling:function(t){return new zt(t)},zeros:function(){return new kt}};async function Vo(t){if(null==t)return;const n=[],s=[],i=[];for(const e in t){const r=t[e];if("number"!=typeof r){const t=r;n.push(t.data()),s.push(e),i.push(t)}}if(n.length>0){const r=await Promise.all(n);for(let e=0;ee.add(this.totals[t],e.mul(i,s))));this.totals[t]=r,null!=n&&n.dispose()}}}async onEpochEnd(t,n){if(null!=n)for(const t of this.params.metrics)null!=this.totals[t]&&("number"==typeof this.totals[t]?n[t]=this.totals[t]/this.seen:e.tidy((()=>{const s=e.mul(e.div(1,this.seen),this.totals[t]);n[t]=s,this.totals[t].dispose(),e.keep(n[t])})))}}class Zo extends Go{async onTrainBegin(t){this.epoch=[],this.history={}}async onEpochEnd(t,e){null==e&&(e={}),this.epoch.push(t);for(const t in e)null==this.history[t]&&(this.history[t]=[]),this.history[t].push(e[t])}async syncData(){const t=[],e=[],n=[];for(const s in this.history){const i=this.history[s];for(let r=0;r{const o=null!=s?s():e.util.now();return o-rnew Yo(t,e)))}class Qo{constructor(){}static registerCallbackConstructor(t,n){e.util.assert(t>=0&&Number.isInteger(t),(()=>`Verbosity level is expected to be an integer >= 0, but got ${t}`)),Qo.checkForDuplicate(n),null==Qo.constructors[t]&&(Qo.constructors[t]=[]),Qo.constructors[t].push(n)}static checkForDuplicate(t){for(const e in Qo.constructors){Qo.constructors[+e].forEach((e=>{if(e===t)throw new o("Duplicate callback constructor.")}))}}static clear(){Qo.constructors={}}static createCallbacks(t){const e=[];for(const n in Qo.constructors){const s=+n;t>=s&&e.push(...Qo.constructors[s])}return e.map((t=>new t))}}function tl(t,e,n,s,i,r,a,o,l){const u=new Zo,h=[new Jo,...Qo.createCallbacks(e)];null!=t&&h.push(...t),h.push(u);const c=new Ho(h);return c.setParams({epochs:n,initialEpoch:s,samples:i,steps:r,batchSize:a,verbose:e,doValidation:o,metrics:l}),{callbackList:c,history:u}}function el(t,n={},s=!1){return v(t,e.serialization.SerializationMap.getMap().classNameMap,n,"layer",s)}function nl(t,n){return e.tidy((()=>{"float32"!==t.dtype&&(t=i.cast(t,"float32"));const e=i.sum(pt(t),n,!0),s=i.fill(e.shape,tt()),r=i.sqrt(i.maximum(e,s));return i.div(t,r)}))}function sl(t,n){return e.tidy((()=>i.mean(pt(i.sub(n,t)),-1)))}function il(t,n){return e.tidy((()=>i.mean(i.abs(i.sub(n,t)),-1)))}function rl(t,n){return e.tidy((()=>{const e=i.sub(t,n),s=i.clipByValue(i.abs(t),tt(),Number.MAX_VALUE),r=i.abs(i.div(e,s));return i.mul(100,i.mean(r,-1))}))}function al(t,n,s=!1){return e.tidy((()=>{if(s)n=i.softmax(n);else{const t=i.sum(n,n.shape.length-1,!0);n=i.div(n,t)}return n=i.clipByValue(n,tt(),1-tt()),i.neg(i.sum(i.mul(i.cast(t,"float32"),i.log(n)),n.shape.length-1))}))}function ol(t,n,s=!1){return e.tidy((()=>{const e=i.cast(i.floor(function(t){const e=[J(t.shape)];return i.reshape(t,e)}(t)),"int32"),r=(n=i.clipByValue(n,tt(),1-tt())).shape;return al(i.reshape(i.oneHot(e,r[r.length-1]),r),n,s)}))}function ll(t,n){return e.tidy((()=>{let s;return s=i.clipByValue(n,tt(),1-tt()),s=i.log(i.div(s,i.sub(1,s))),i.mean(function(t,n){if(!e.util.arraysEqual(t.shape,n.shape))throw new o(`logits and labels must have the same shape, but got shapes ${JSON.stringify(t.shape)} and ${JSON.stringify(n.shape)}`);return e.tidy((()=>{const e=i.relu(n),s=i.neg(i.abs(n));return i.add(i.sub(e,i.mul(n,t)),i.log1p(i.exp(s)))}))}(t,s),-1)}))}function ul(t,n){return e.tidy((()=>{const e=nl(t,-1),s=nl(n,-1),r=i.mul(e,s);return i.neg(i.sum(r,-1))}))}Qo.constructors={};const hl={meanSquaredError:sl,meanAbsoluteError:il,meanAbsolutePercentageError:rl,meanSquaredLogarithmicError:function(t,n){return e.tidy((()=>{const e=i.clipByValue(n,tt(),Number.MAX_VALUE),s=i.log(i.add(1,e)),r=i.clipByValue(t,tt(),Number.MAX_VALUE),a=i.log(i.add(1,r));return i.mean(pt(i.sub(s,a)),-1)}))},squaredHinge:function(t,n){return e.tidy((()=>{const e=i.maximum(0,i.sub(1,i.mul(t,n)));return i.mean(pt(e),-1)}))},hinge:function(t,n){return e.tidy((()=>{const e=i.maximum(0,i.sub(1,i.mul(t,n)));return i.mean(e,-1)}))},categoricalHinge:function(t,n){return e.tidy((()=>{const e=i.sum(i.mul(t,n),-1),s=i.max(i.mul(i.sub(1,t),n),-1);return i.maximum(0,i.add(1,i.sub(s,e)))}))},logcosh:function(t,n){return e.tidy((()=>{const e=Math.log(2),s=i.sub(n,t),r=i.sub(i.add(s,i.softplus(i.mul(-2,s))),e);return i.mean(r,-1)}))},categoricalCrossentropy:al,sparseCategoricalCrossentropy:ol,binaryCrossentropy:ll,kullbackLeiblerDivergence:function(t,n){return e.tidy((()=>{const e=i.clipByValue(t,tt(),1),s=i.clipByValue(n,tt(),1);return i.sum(i.mul(t,i.log(i.div(e,s))),-1)}))},poisson:function(t,n){return e.tidy((()=>{const e=i.log(i.add(tt(),n));return i.mean(i.sub(n,i.mul(t,e)),-1)}))},cosineProximity:ul};function cl(t){if("string"==typeof t){if(t in hl)return hl[t];let e=`Unknown loss ${t}`;throw t.toLowerCase().includes("softmaxcrossentropy")&&(e=`Unknown loss ${t}. Use "categoricalCrossentropy" as the string name for tf.losses.softmaxCrossEntropy`),new o(e)}return t}function pl(t,n){return e.tidy((()=>{const e=i.mul(.5,i.onesLike(n)),s=et(i.greater(n,e),t.dtype);return i.mean(i.equal(t,s),-1)}))}function dl(t,n){return e.tidy((()=>et(i.equal(i.argMax(t,-1),i.argMax(n,-1)),"float32")))}function fl(t,n){return e.tidy((()=>i.cast(i.sum(i.logicalAnd(i.equal(t,1),i.equal(n,1))),"float32")))}function gl(t,n){return e.tidy((()=>{const s=fl(t,n),r=function(t,n){return e.tidy((()=>i.cast(i.sum(i.logicalAnd(i.equal(t,0),i.equal(n,1))),"float32")))}(t,n),a=i.add(s,r);return i.cast(i.where(i.greater(a,0),i.div(s,a),0),"float32")}))}function ml(t,n){return e.tidy((()=>{const s=fl(t,n),r=function(t,n){return e.tidy((()=>i.cast(i.sum(i.logicalAnd(i.equal(t,1),i.equal(n,0))),"float32")))}(t,n),a=i.add(s,r);return i.cast(i.where(i.greater(a,0),i.div(s,a),0),"float32")}))}function yl(t,e){return ll(t,e)}function bl(t,e){return t.rank===e.rank&&(t=i.squeeze(t,[t.rank-1])),(e=i.argMax(e,-1)).dtype!==t.dtype&&(e=i.cast(e,t.dtype)),i.cast(i.equal(t,e),"float32")}const wl=al,kl=ol,vl={binaryAccuracy:pl,categoricalAccuracy:dl,precision:gl,categoricalCrossentropy:wl,sparseCategoricalCrossentropy:kl,mse:sl,MSE:sl,mae:il,MAE:il,mape:rl,MAPE:rl,cosine:ul};function Sl(t){if("string"==typeof t&&t in vl)return vl[t];if("string"!=typeof t&&null!=t)return t;throw new o(`Unknown metric ${t}`)}function xl(t){if(p(null!==t,`Unknown LossOrMetricFn ${t}`),"string"==typeof t)return t;{let e;for(const n of Object.keys(hl))if(hl[n]===t){e=n;break}if(void 0!==e)return e;for(const n of Object.keys(vl))if(vl[n]===t){e=n;break}return void 0!==e?e:t.name}}const Nl=1048576;function Il(t,e,n=!1){if(null==t||"object"!=typeof t||Object.getPrototypeOf(t)!==Object.prototype||!Al(t))throw new Error("User-defined metadata is expected to be a JSON object, but is not.");if(n){const n=JSON.stringify(t);n.length>Nl&&console.warn(`User-defined metadata of model "${e}" is too large in size (length=${n.length} when serialized). It is not recommended to store such large objects in user-defined metadata. Please make sure its serialized length is <= 1048576.`)}}function Al(t){if(null===t)return!0;if("object"==typeof t){if(Object.getPrototypeOf(t)===Object.prototype){const e=Object.keys(t);for(const n of e){if("string"!=typeof n)return!1;if(!Al(t[n]))return!1}return!0}if(Array.isArray(t)){for(const e of t)if(!Al(e))return!1;return!0}return!1}{const e=typeof t;return"string"===e||"number"===e||"boolean"===e}}function zl(t,e,n,s=console.log){const i=function(t){let e=!0;const n=[],s=[];for(const e in t.nodesByDepth)n.push(t.nodesByDepth[e]);for(const t of n){if(t.length>1||1===t.length&&t[0].inboundLayers.length>1){e=!1;break}s.push(...t)}if(e)for(const n of t.layers){let t=!1;for(const i of n.inboundNodes)if(-1!==s.indexOf(i)){if(t){e=!1;break}t=!0}if(!e)break}return e}(t),r=["Layer (type)","Input Shape","Output shape","Param #"];let a;if(i?(e=e||90,n=n||[.32,.61,.89,1]):(e=e||115,n=n||[.24,.48,.7,.8,1]),n[n.length-1]<=1&&(n=n.map((t=>Math.floor(e*t)))),!i){r.push("Receives inputs"),a=[];for(const e in t.nodesByDepth)a.push(...t.nodesByDepth[e])}s("_".repeat(e)),El(r,n,s),s("=".repeat(e));const o=t.layers;for(let t=0;t0&&(s=s.slice(0,s.length-1)+" "),s+=t[n],s=s.slice(0,e[n]),s+=" ".repeat(e[n]-s.length);n(s)}function Tl(t,e,n){let s,i;try{i=t.inboundNodes.map((t=>JSON.stringify(t.inputShapes))).join(",")}catch(t){i="multiple"}try{s=JSON.stringify(t.outputShape)}catch(t){s="multiple"}El([`${t.name} (${t.getClassName()})`,i,s,t.countParams().toString()],e,n)}function Cl(t,e,n,s){let i,r;try{r=t.inboundNodes.map((t=>JSON.stringify(t.inputShapes))).join(",")}catch(t){r="multiple"}try{i=JSON.stringify(t.outputShape)}catch(t){i="multiple"}const a=[];for(const e of t.inboundNodes)if(!(null!=n&&n.length>0&&-1===n.indexOf(e)))for(let t=0;tt.name))}`);x(this.outputs).length!==this.outputs.length&&console.warn(`The list of outputs passed to the model is redundant. All outputs should only appear once. Found: ${this.outputs.map((t=>t.name))}`),this.inputLayers=[],this.inputLayersNodeIndices=[],this.inputLayersTensorIndices=[],this.outputLayers=[],this.outputLayersNodeIndices=[],this.outputLayersTensorIndices=[],this.layers=[],this.internalContainerRefs=[];for(const t of this.outputs){const e=t.sourceLayer,n=t.nodeIndex,s=t.tensorIndex;this.outputLayers.push(e),this.outputLayersNodeIndices.push(n),this.outputLayersTensorIndices.push(s)}for(const t of this.inputs){const e=t.sourceLayer,n=t.nodeIndex,s=t.tensorIndex;p(0===n,"input layer has >1 nodes"),p(0===s,"input layer has >1 tensors"),this.inputLayers.push(e),this.inputLayersNodeIndices.push(n),this.inputLayersTensorIndices.push(s)}this.inputNames=[],this.outputNames=[],this.feedInputShapes=[],this.feedInputNames=[],this.feedOutputNames=[];for(let e=0;et.shape)),this.internalOutputShapes=this.outputs.map((t=>t.shape));const e={},n={},s={},i={},r={},l=[],u=(t,e,n,s,i,o)=>{null!=s&&null!=i&&null!=o||(s=t.sourceLayer,i=t.nodeIndex,o=t.tensorIndex);const h=s.inboundNodes[i];if(-1!==n.indexOf(h))throw new a(`The tensor ${t.name} at layer "${s.name}" is part of a cycle.`);if(-1!==e.indexOf(h))return;this.containerNodes.add(_l.nodeKey(s,i)),s.id in r||(r[s.id]=Object.keys(r).length),-1===n.indexOf(h)&&n.push(h);const c=h.inboundLayers.length;for(let t=0;t=0;)n.splice(n.indexOf(h),1);l.push(h)},h=[],c=[];for(const t of this.outputs)u(t,h,c);const d=l.slice().reverse();for(const t of d){n[t.id]=t,t.id in e||(e[t.id]=0);let r=e[t.id];const a=null==s[t.outboundLayer.id]?0:s[t.outboundLayer.id];r=Math.max(r,a),s[t.outboundLayer.id]=r,i[t.outboundLayer.id]=t.outboundLayer,e[t.id]=r;for(let s=0;sparseInt(t,10))).sort(S);this.layers=[];for(const t of m){const e=g[t];e.sort(((t,e)=>{const n=r[t.id],s=r[e.id];return ns?1:0}));for(const t of e)t instanceof _l&&this.internalContainerRefs.push(t),this.layers.push(t)}this.layersByDepth=g,m=Object.keys(f).map((t=>parseInt(t,10))).sort(S);const y=this.inputs.slice(),b=[];for(const t of m)for(const e of f[t]){const t=e.outboundLayer;if(null!=t){for(const n of e.inputTensors)if(-1===y.indexOf(n))throw new a(`Graph disconnected: cannot obtain value for tensor ${n} at layer "${t.name}". The following previous layers were accessed without issue: ${b}`);for(const t of e.outputTensors)y.push(t);b.push(t.name)}}this.nodesByDepth=f;const w=this.layers.map((t=>t.name));for(const t of w){const e=w.filter((e=>e===t)).length;if(1!==e)throw new a(`The name "${t}" is used ${e} times in the model. All layer names should be unique. Layer names: `+JSON.stringify(w))}this.outboundNodes=[],this.inboundNodes=[],new Yt({outboundLayer:this,inboundLayers:[],nodeIndices:[],tensorIndices:[],inputTensors:this.inputs,outputTensors:this.outputs,inputMasks:this.inputs.map((t=>null)),outputMasks:this.outputs.map((t=>null)),inputShapes:this.inputs.map((t=>t.shape)),outputShapes:this.outputs.map((t=>t.shape))}),this.built=!0,this._refCount=1}assertNotDisposed(){if(0===this._refCount)throw new Error(`Container '${this.name}' is already disposed.`)}dispose(){this.assertNotDisposed();const t={refCountAfterDispose:null,numDisposedVariables:0};if(0==--this._refCount){for(const e of this.layers)t.numDisposedVariables+=e.dispose().numDisposedVariables;for(const e of this.internalContainerRefs)t.numDisposedVariables+=e.dispose().numDisposedVariables}return t.refCountAfterDispose=this._refCount,t}get trainable(){return this.trainable_}set trainable(t){this.layers.forEach((e=>{e._trainableWeights.forEach((e=>e.trainable=t))})),this.trainable_=t}get trainableWeights(){if(this._trainableWeights.length>0)throw new o("Container instance unexpectedly contains _trainableWeights.The trainable weights of a Container are a union of the trainable weights of its consituent Layers. Its own _trainableWeights must remain an empty Array.");if(!this.trainable)return[];let t=[];for(const e of this.layers)t=t.concat(e.trainableWeights);return t}get nonTrainableWeights(){const t=[];for(const e of this.layers)t.push(...e.nonTrainableWeights);if(!this.trainable){const e=[];for(const t of this.layers)e.push(...t.trainableWeights);return e.concat(t)}return t}get weights(){return this.trainableWeights.concat(this.nonTrainableWeights)}loadWeights(t,e=!0){const n={};let s=0;const i=(t=>{const e=Object.keys(t);if(0===e.length)return!1;const n=e[0].split("/");return!isNaN(parseInt(n[n.length-1],10))})(t);i&&this.parseWeights(t);for(const t of this.layers)for(const[e,r]of t.weights.entries()){const t=i?`${r.name.split("/").slice(0,-1).join("/")+"/"}${e}`:r.originalName;if(null!=n[t])throw new o(`Duplicate weight name: ${t}`);n[t]=r,s++}const r=[];for(const s in t){let i=s;if(null==n[s]){const t=s.split("/");i=t.slice(0,-2).concat([t[t.length-1]]).join("/")}if(null!=n[i])r.push([n[i],t[s]]);else if(e)throw new o(`Provided weight data has no target variable: ${s}`);delete n[i]}if(e){const t=[];for(const e in n)t.push(e);if(t.length>0)throw new o(`${t.length} of ${s} weights are not set: ${t}`)}Gt(r)}parseWeights(t){for(const e in Object.keys(t)){const n=e.split("/"),s=["vars","layer_checkpoint_dependencies"],i=n.map((t=>t.startsWith("_")?t.slice(1):t)).filter((t=>!s.includes(t))).join("/");i!==e&&(t[i]=t[e],delete t[e])}}updatedConfig(){const t=this.getConfig(),e={};return e.className=this.getClassName(),e.config=t,e.kerasVersion="tfjs-layers 4.15.0",e.backend="TensorFlow.js",e}toJSON(t,e=!0){const n=Dl(this.updatedConfig());return e?JSON.stringify(n):n}call(t,n){return e.tidy((()=>{t=g(t);const e=new se;for(let n=0;n{let e;return t=g(t),e=null==n?c(null,t.length):g(n),this.runInternalGraph(t,e)[1]}))}computeOutputShape(t){const e=Pt(t);if(e.length!==this.inputLayers.length)throw new o(`Invalid inputShape argument ${t}: model has ${this.inputLayers.length} tensor inputs.`);const n={};for(let t=0;tparseInt(t,10))).sort(S);if(s.length>1)for(const t of s){const e=this.nodesByDepth[t];for(const t of e){const e=t.outboundLayer;if(-1!==this.inputLayers.map((t=>t.id)).indexOf(e.id))continue;const s=[];for(let e=0;eparseInt(t,10))).sort(S);for(const t of s){const e=this.nodesByDepth[t];for(const t of e){const e=t.outboundLayer,s=t.inputTensors,i=t.outputTensors,r=new Array;for(const t of s)t.id in n&&r.push(n[t.id]);if(r.length===s.length){let s,a,o,u,h={};if(null!=t.callArgs&&(h=t.callArgs),1===r.length){const[t,n]=r[0];null==h.mask&&(h.mask=n),o=g(e.call(t,h)),u=g(e.computeMask(t,n)),s=[t],a=[n]}else s=r.map((t=>t[0])),a=r.map((t=>t[1])),null==h.mask&&(h.mask=a),o=g(e.call(s,h)),u=g(e.computeMask(s,a));if(e.activityRegularizer)throw new l("LayersModel invocation with concrete Tensor value(s) in the presence of activity regularizer(s) is not supported yet.");for(let t=0;t{const t=[];for(const e of this.layers)for(let n=0;n0){const t=[];for(let n=0;n0&&t.apply(f(n),s)}function u(t){const n=t.name,r=el(t,null!=e.customObjects?e.customObjects:{});r.setFastWeightInitDuringBuild(s),i[n]=r;t.inboundNodes.forEach((t=>{if(!(t instanceof Array))throw new o(`Corrupted configuration, expected array for nodeData: ${t}`);a(r,t)}))}const h=e.name,c=e.layers;for(const t of c)u(t);for(;!N(r);)for(const t of c){const e=i[t.name];if(e.name in r){const t=r[e.name];delete r[e.name];for(const n of t)l(e,n)}}const d=[],g=[],m=e.inputLayers;for(const t of m){const e=t[0],n=t[1],s=t[2];p(e in i);const r=i[e].inboundNodes[n].outputTensors;d.push(r[s])}const y=e.outputLayers;for(const t of y){const e=t[0],n=t[1],s=t[2];p(e in i);const r=i[e].inboundNodes[n].outputTensors;g.push(r[s])}return new t({inputs:d,outputs:g,name:h})}get stateful(){if(this._stateful)throw new o("Container instance unexpectedly has _stateful = true. The statefulness of a Container is determined by the Layers it contains. Its _stateful property must remain the default false.");for(const t of this.layers)if(t.stateful)return!0;return!1}resetStates(){e.tidy((()=>{this.layers.forEach((t=>{t.stateful&&t.resetStates()}))}))}}function Rl(t,e){return function(t,e,n){const s=e.length;if(null==t||Array.isArray(t)&&0===t.length)return e.map((t=>null));if(1===s)return Array.isArray(t)&&1===t.length?t:"object"==typeof t&&e[0]in t?[t[e[0]]]:[t];if(Array.isArray(t)){if(t.length!==s)throw new Error(`Provided ${n} is an array of ${t.length} element(s), but the model has ${s} outputs. Make sure a set of weights is provided for each model output.`);return t}if("object"==typeof t&&Object.keys(t).length>0&&"object"==typeof t[Object.keys(t)[0]]){const n=[];return e.forEach((e=>{e in t?n.push(t[e]):n.push(null)})),n}throw new Error(`The model has multiple (${s}) outputs, so ${n} must be either an array with ${s} elements or an object with ${e} keys. Provided ${n} not understood: ${JSON.stringify(t)}`)}(t,e,"classWeight")}async function Ol(t,n,s,i){if(null!=n||null!=i)throw new Error("Support sampleWeight is not implemented yet");if(null!=s){const n=e.tidy((()=>{if(1===t.shape.length)return e.clone(t);if(2===t.shape.length){if(t.shape[1]>1){const n=1;return e.argMax(t,n)}if(1===t.shape[1])return e.reshape(t,[t.shape[0]]);throw new Error(`Encountered unexpected last-dimension size (${t.shape[1]}) during handling of class weights. The size is expected to be >= 1.`)}throw new Error(`Unexpected rank of target (y) tensor (${t.rank}) during handling of class weights. The rank is expected to be 1 or 2.`)})),i=Array.from(await n.data());e.dispose(n);const r=[];return i.forEach((t=>{if(null==s[t])throw new Error(`classWeight must contain all classes in the training data. The class ${t} exists in the data but not in classWeight`);r.push(s[t])})),e.tensor1d(r,"float32")}return null}function Ml(t,n){return e.mul(t,n)}function Bl(t,e){let n,s;const r=e;n=r.xs,s=r.ys,i.util.assert(null!=n&&null!=s,(()=>`A Dataset iterator for fitDataset() is expected to generate objects of the form \`{xs: xVal, ys: yVal}\`, where the two values may be \`tf.Tensor\`, an array of Tensors, or a map of string to Tensor. The provided Dataset instead generates ${e}`));const a=Pl("input",t.inputNames,n),o=Pl("output",t.outputNames,s),l=a[0].shape[0];i.util.assert(a.length===t.inputs.length,(()=>`LayersModel has ${t.inputs.length} inputs, but the dataset provides ${a.length} inputs. (Expected input keys: ${JSON.stringify(t.inputNames)})`)),i.util.assert(o.length===t.outputs.length,(()=>`LayersModel has ${t.outputs.length} outputs, but the dataset provides ${o.length} outputs. (Expected output keys: ${JSON.stringify(t.outputNames)})`));for(let e=0;e`Batch size mismatch: input ${t.inputNames[e]} has ${a[e].shape[0]}; expected ${l} based on input ${t.inputNames[0]}.`));for(let e=0;e`Batch size mismatch: output ${t.outputNames[e]} has ${o[e].shape[0]}; expected ${l} based on input ${t.inputNames[0]}.`));return{xs:a,ys:o}}function Pl(t,e,n){if(n instanceof i.Tensor)return[n];if(Array.isArray(n))return i.util.assert(n.length===e.length,(()=>`Received an array of ${n.length} Tensors, but expected ${e.length} to match the ${t} keys ${e}.`)),n;{const s=[];for(const i of e){if(null==n[i])throw new o(`The feature data generated by the dataset lacks the required ${t} key '${i}'.`);s.push(n[i])}return s}}async function Ul(t,e,n){const s=null!=n.batchesPerEpoch;if(i.util.assert(null!=t.optimizer,(()=>"You must compile a model before training/testing. Use LayersModel.compile(modelCompileConfig).")),i.util.assert(null!=n,(()=>"For fitDataset(), the 2nd argument (config) is required, but it is not provided in this call.")),i.util.assert(null!=n.epochs&&n.epochs>0&&Number.isInteger(n.epochs),(()=>`For fitDataset(), config.epochs is expected to be a positive integer, but got ${n.epochs}`)),i.util.assert(!s||n.batchesPerEpoch>0&&Number.isInteger(n.batchesPerEpoch),(()=>`For fitDataset(), config.batchesPerEpoch is expected to be a positive integer if specified, but got ${n.batchesPerEpoch}`)),i.util.assert(null==n.validationSplit,(()=>"`validationSplit` is not supported by `fitDataset()`. Use validationData instead.")),t.isTraining)throw new Error("Cannot start training because another fit() call is ongoing.");t.isTraining=!0;try{const r=null!=n.validationData;let a,o;if(r)if(Wl(n.validationData))i.util.assert(null==n.validationBatches||n.validationBatches>0&&Number.isInteger(n.validationBatches),(()=>`For fitDataset() with dataset-based validation, config.validationBatches is expected not to be provided, or to be a positive integer, but got ${n.validationBatches}`));else{const t=function(t){if(3===t.length)throw new l("Validation with sample weights is not implemented yet.");return{xs:t[0],ys:t[1]}}(n.validationData);a=t.xs,o=t.ys}const u=t.makeTrainFunction(),h=t.getDedupedMetricsNames();let c;c=r?h.slice().concat(h.map((t=>"val_"+t))):h.slice();const p=Xo(n.callbacks,n.yieldEvery),d=null==n.verbose?1:n.verbose,{callbackList:f,history:m}=tl(p,d,n.epochs,null,null,function(t,e){let n=null;null!=e.batchesPerEpoch?n=e.batchesPerEpoch:Number.isFinite(t.size)&&(n=t.size);return n}(e,n),null,r,c);f.setModel(t),t.history=m,await f.onTrainBegin(),t.stopTraining_=!1;let y=null==n.initialEpoch?0:n.initialEpoch,b=await e.iterator();for(;y=n.batchesPerEpoch:e.done){if(r){let e;e=Wl(n.validationData)?g(await t.evaluateDataset(n.validationData,{batches:n.validationBatches})):g(t.evaluate(a,o,{batchSize:null==n.validationBatchSize?32:n.validationBatchSize,verbose:0}));for(let n=0;n0&&Number.isInteger(t),(()=>`batchSize is required to be a positive integer, but got ${t}`))}function ql(t,e,n){return null==t?[null]:Array.isArray(t)?t.map((t=>st(t,e,n-e))):st(t,e,n-e)}function Vl(t,e){return i.tidy((()=>null==t?null:Array.isArray(t)?t.map((t=>Vl(t,e))):ct(t,"int32"===e.dtype?e:i.cast(e,"int32"))))}function Kl(t,e){const n=[];let s=0,i=null;for(;s=t&&(i=t),n.push([s,i]),s=i;return n}function Gl(t){const n=[];t instanceof e.Tensor&&(t=[t]);for(let e=0;es.push(t.id)));else if(null!=n)for(const t in n){const e=n[t];s.push(e.id)}const i=[];if(t instanceof e.Tensor)-1===s.indexOf(t.id)&&i.push(t);else if(Array.isArray(t))t.forEach((t=>{-1===s.indexOf(t.id)&&i.push(t)}));else if(null!=t)for(const e in t){const n=t[e];-1===s.indexOf(n.id)&&i.push(n)}i.forEach((t=>{t.isDisposed||t.dispose()}))}function Jl(t){return Array.isArray(t)}function Zl(t){return!function(t){return t instanceof e.Tensor}(t)&&!Jl(t)}function Yl(t,e,n,s=!0,i=""){if(null==e||0===e.length){if(null!=t){let e=!1;if(Jl(t)&&t.length>0)e=!0;else if(Zl(t)){for(const n in t)if(t.hasOwnProperty(n)){e=!0;break}}else e=!0;if(e)throw new o(`Error when checking model ${i} expected no data, but got ${t}`)}return[]}if(null==t)return e.map((t=>null));let r;if(Zl(t)){r=[];for(const n of e){if(null==t[n])throw new o(`No data provided for "${n}". Need data for each key in: ${e}`);r.push(t[n])}}else if(Jl(t)){if(t.length!==e.length)throw new o(`Error when checking model ${i}: the Array of Tensors that you are passing to your model is not the size the model expected. Expected to see ${e.length} Tensor(s), but instead got the following list of Tensor(s): ${t}`);r=t}else{if(e.length>1)throw new o(`The model ${i} expects ${e.length} Tensor(s), but only received one Tensor. Found: Tensor with shape ${t.shape}`);r=[t]}if(r=Gl(r),null!=n)for(let t=0;t=0&&r!==l)throw new o(`${i} expected a batch of elements where each example has shape [${n[t].slice(1,n[t].length)}] (i.e.,tensor shape [*,${n[t].slice(1,n[t].length)}]) but the ${i} received an input with ${a.shape[0]} examples, each with shape [${a.shape.slice(1,a.shape.length)}] (tensor shape [${a.shape}])`)}}return r}function Xl(t,e,n,s=!0,i=""){let r;if(Array.isArray(t)){if(t.length!==e.length)throw new o(`Error when checking model ${i}: the Array of Tensors that you are passing to your model is not the size the the model expected. Expected to see ${e.length} Tensor(s), but instead got ${t.length} Tensors(s).`);r=t}else{if(e.length>1)throw new o(`The model expects ${e.length} ${i} Tensors, but only received one Tensor. Found: array with shape ${JSON.stringify(t.shape)}.`);r=[t]}if(null!=n)for(let t=0;te.train.adagrad(.01),Adadelta:()=>e.train.adadelta(1,.95,tt()),Adam:()=>e.train.adam(.001,.9,.999,tt()),Adamax:()=>e.train.adamax(.002,.9,.999,tt(),0),RMSProp:()=>e.train.rmsprop(.001,.9,0,tt()),SGD:()=>e.train.sgd(.01)};if(n.adagrad=n.Adagrad,n.adadelta=n.Adadelta,n.adam=n.Adam,n.adamax=n.Adamax,n.rmsprop=n.RMSProp,n.sgd=n.SGD,t in n)return n[t]();throw new o(`Unknown Optimizer ${t}`)}(t.optimizer),this.isOptimizerOwned=!0;else{if(!(t.optimizer instanceof e.Optimizer))throw new o("User-defined optimizer must be an instance of tf.Optimizer.");this.optimizer_=t.optimizer,this.isOptimizerOwned=!1}let n=[];if(Array.isArray(t.loss)||"string"==typeof t.loss||"function"==typeof t.loss)if(Array.isArray(t.loss)){if(t.loss.length!==this.outputs.length)throw new o(`When passing an Array as loss, it should have one entry per model output. The model has ${this.outputs.length} output(s), but you passed loss=${t.loss}.`);const e=t.loss;n=e.map((t=>cl(t)))}else{const e=cl(t.loss);this.outputs.forEach((t=>{n.push(e)}))}else{t.loss=t.loss;for(const e in t.loss)if(-1===this.outputNames.indexOf(e))throw new o(`Unknown entry in loss dictionary: "${e}". Only expected the following keys: ${this.outputNames}`);for(const e of this.outputNames)null==t.loss[e]&&console.warn(`Output "${e}" is missing from loss dictionary. We assume this was done on purpose, and we will not be expecting data to be passed to ${e} during training`),n.push(cl(t.loss[e]))}this.lossFunctions=n,this.feedOutputNames=[],this.feedOutputShapes=[],this.feedLossFns=[];for(let t=0;t{for(let t=0;t1&&(this.metricsTensors.push([e,t]),this.metricsNames.push(this.outputNames[t]+"_loss"))}}));const i=function(t,e){if(null==t||Array.isArray(t)&&0===t.length)return e.map((t=>[]));let n;if("string"==typeof t||"function"==typeof t)n=[t];else{if(!Array.isArray(t)&&"object"!=typeof t)throw new TypeError(`Type of metrics argument not understood. Expected an string,function, Array, or Object, found: ${t}`);n=t}if(Array.isArray(n))return e.map((t=>n));{const t=[];for(const s of e){let e=n.hasOwnProperty(s)?n[s]:[];Array.isArray(e)||(e=[e]),t.push(e)}return t}}(t.metrics,this.outputNames),r=(t,e,n)=>{this.outputNames.length>1&&(e=this.outputNames[t]+"_"+e),this.metricsNames.push(e),this.metricsTensors.push([n,t])};q("metric",(()=>{for(let t=0;t{let n,s,i;for(const a of e){if("string"==typeof a&&-1!==["accuracy","acc","crossentropy","ce"].indexOf(a)){const e=this.internalOutputShapes[t];let r;1===e[e.length-1]||this.lossFunctions[t]===ll?-1!==["accuracy","acc"].indexOf(a)?s=pl:-1!==["crossentropy","ce"].indexOf(a)&&(s=yl):this.lossFunctions[t]===ol?-1!==["accuracy","acc"].indexOf(a)?s=bl:-1!==["crossentropy","ce"].indexOf(a)&&(s=kl):-1!==["accuracy","acc"].indexOf(a)?s=dl:-1!==["crossentropy","ce"].indexOf(a)&&(s=wl),-1!==["accuracy","acc"].indexOf(a)?r="acc":-1!==["crossentropy","ce"].indexOf(a)&&(r="ce"),i=s,n=""+r}else{const t=Sl(a);i=t,n=""+xl(a)}let e;q(n,(()=>{e=i})),r(t,n,e)}})(i[t])}})),this.collectedTrainableWeights=this.trainableWeights}checkTrainableWeightsConsistency(){null!=this.collectedTrainableWeights&&this.trainableWeights.length!==this.collectedTrainableWeights.length&&console.warn("Discrepancy between trainableweights and collected trainable weights. Did you set `model.trainable` without calling `model.compile()` afterwards?")}evaluate(t,e,n={}){const s=null==n.batchSize?32:n.batchSize;jl(s);const i=this.standardizeUserDataXY(t,e,!0,s);try{const r=i[0].concat(i[1]);this.makeTestFunction();const a=this.testFunction;return f(this.testLoop(a,r,s,n.verbose,n.steps))}finally{Hl(i[0],t),Hl(i[1],e)}}async evaluateDataset(t,n){return this.makeTestFunction(),async function(t,n,s){const r=null!=(s=s||{}).batches,a=t.testFunction;let o=[];if(s.verbose>0)throw new l("Verbose mode is not implemented yet.");i.util.assert(!r||s.batches>0&&Number.isInteger(s.batches),(()=>`Test loop expects \`batches\` to be a positive integer, but received ${JSON.stringify(s.batches)}`));const u="function"==typeof n.next?n:await n.iterator();let h=0,c=0;for(;!r||c{if(n.value){const{xs:s,ys:r}=Bl(t,n.value),l=s.concat(r),u=i.tidy((()=>a(l)));if(i.dispose(l),0===c)for(let t=0;ti.add(o[t],i.mul(p,e)))),c>0&&i.dispose(n)}i.dispose(u),h+=p,++c}return o})),n.done){r&&console.warn(`Your dataset iterator ran out of data during evaluateDataset(). Interrupting evalution. Make sure that your dataset can generate at least \`batches\` batches (in this case, ${s.batches} batches). You may need to use the repeat() function when building your dataset.`);break}}for(let t=0;tt.name));for(let s=0;s0){const n=[];throw e.forEach(((e,s)=>{null==e&&n.push(t[s])})),new o(`Cannot find SymbolicTensors for output name(s): ${JSON.stringify(n)}`)}return e}predictLoop(t,e=32,n=!1){return i.tidy((()=>{const s=this.checkNumSamples(t);if(n)throw new l("Verbose predictLoop() is not implemented yet.");const r=Kl(s,e),a=this.outputs.map((t=>[]));for(let e=0;e{const n=r[e][0],s=r[e][1],i=ql(t,n,s),a=[];if(Array.isArray(i))for(let t=0;ta[e].push(t)))}return f(a.map((t=>i.concat(t,0))))}))}predict(t,e={}){const n=Gl(t);Xl(n,this.inputNames,this.feedInputShapes,!1);try{const s=null==e.batchSize?32:e.batchSize;return jl(s),this.predictLoop(n,s)}finally{Hl(n,t)}}predictOnBatch(t){Xl(t,this.inputNames,this.feedInputShapes,!0);const e=(Array.isArray(t)?t[0]:t).shape[0];return this.predictLoop(t,e)}standardizeUserDataXY(t,n,s=!0,i){if(null==this.optimizer_)throw new a("You must compile a model before training/testing. Use LayersModel.compile(modelCompileArgs).");const r=[];for(let t=0;tt.shape[0])));i.sort();const r=x(n.map((t=>t.shape[0])));if(r.sort(),i.length>1)throw new o(`All input Tensors (x) should have the same number of samples. Got array shapes: ${JSON.stringify(t.map((t=>t.shape)))}`);if(r.length>1)throw new o(`All target Tensors (y) should have the same number of samples. Got array shapes: ${JSON.stringify(n.map((t=>t.shape)))}`);if(i.length>0&&r.length>0&&!e.util.arraysEqual(i,r))throw new o(`Input Tensors should have the same number of samples as target Tensors. Found ${i[0]} input sample(s) and ${r[0]} target sample(s).`)}(t=Yl(t,this.feedInputNames,this.feedInputShapes,!1,"input"),n=Yl(n,this.feedOutputNames,r,!1,"target")),function(t,e,n){const s=[sl,ll,al];for(let i=0;i0&&t[0].shape[0]%i!=0)throw new o(`In a stateful network, you should only pass inputs with a number of samples that is divisible by the batch size ${i}. Found: ${t[0].shape[0]} sample(s).`);return[t,n]}async standardizeUserData(t,e,n,s,i=!0,r){const[a,o]=this.standardizeUserDataXY(t,e,i,r);if(null!=n)throw new Error("sample weight is not supported yet.");let l=null;if(null!=s){const t=Rl(s,this.outputNames);l=[];for(let e=0;e{const o=this.checkNumSamples(n,s,a,"steps"),u=[];if(r>0)throw new l("Verbose mode is not implemented yet.");if(null!=a)throw new l("steps mode in testLoop() is not implemented yet");{const r=Kl(o,s),a=e.tensor1d(X(0,o));for(let s=0;s1){i+=`_${d(t.slice(0,n),s)}`}e.push(i)}return e}makeTrainFunction(){return t=>{const e=[],n=t.slice(0,this.inputs.length),s=t.slice(this.inputs.length,this.inputs.length+this.outputs.length),r=t.slice(this.inputs.length+this.outputs.length,this.inputs.length+2*this.outputs.length),a=[],o=this.collectedTrainableWeights.map((t=>t.read()));return[this.optimizer_.minimize((()=>{const t=[];for(let e=0;e1&&t{u=i.add(u,t)})),u}),!0,o)].concat(a)}}makeTestFunction(){this.testFunction=t=>i.tidy((()=>{const e=[];let n;const s=t.slice(0,this.inputs.length),r=t.slice(this.inputs.length,this.inputs.length+this.outputs.length),a=[];for(let t=0;t0){if(w=!0,2!==n.validationData.length)throw 3===n.validationData.length?new l("validationData including sample weights is not supported yet."):new o(`When passing validation data, it must contain 2 (valX, valY) or 3 (valX, valY, valSampleWeight) items; ${n.validationData} is invalid.`);h=n.validationData[0],c=n.validationData[1];const t=!0,e=await this.standardizeUserData(h,c,null,null,t,g);p=e[0],d=e[1],b=p.concat(d)}else if(null!=n.validationSplit&&n.validationSplit>0&&n.validationSplit<1){w=!0;const t=Math.floor(s[0].shape[0]*(1-n.validationSplit)),e=s[0].shape[0];p=ql(s,t,e),a=s,s=ql(s,0,t),d=ql(r,t,e),u=r,r=ql(r,0,t),b=p.concat(d)}else null!=n.validationSteps&&(w=!0);const k=s.concat(r).concat(f);this.checkTrainableWeightsConsistency();const v=this.makeTrainFunction(),S=this.getDedupedMetricsNames();let x,N;w?(this.makeTestFunction(),x=this.testFunction,N=S.slice().concat(S.map((t=>"val_"+t)))):(x=null,b=[],N=S.slice());const I=Xo(n.callbacks,n.yieldEvery);return await this.fitLoop(v,k,S,g,n.epochs,n.verbose,I,x,b,n.shuffle,N,n.initialEpoch,null,null)}finally{this.isTraining=!1,Hl(s,t),Hl(r,e),Hl(a,t),Hl(u,e),Hl(p,h),Hl(d,c),null!=f&&i.dispose(f)}}async fitLoop(t,n,s,r,a,u,h,c,p,d,f,g,m,y){null==r&&(r=32),null==a&&(a=1),null==d&&(d=!0),null==g&&(g=0);let b=!1;if(null!=c&&null!=p&&(b=!0),null!=y&&(b=!0,null==m))throw new o("Can only use `validationSteps` when doing step-wise training, i.e., `stepsPerEpoch` must be set.");const w=this.checkNumSamples(n,r,m,"steps_per_epoch");let k;null!=w&&(k=X(0,w)),null==u&&(u=1);const{callbackList:v,history:S}=tl(h,u,a,g,w,m,r,b,f);v.setModel(this),this.history=S,await v.onTrainBegin(),this.stopTraining_=!1;for(let o=g;o{const h=u[e][0],d=u[e][1],f=st(o,h,d-h);l.batch=e,l.size=d-h;const g=Vl(n,f),m=t(g);for(let t=0;tm(t)))}else{const e=Object.keys(this.loss);t={};const n=this.loss;for(const s of e){if("string"!=typeof n[s])throw new Error("Serialization of non-string loss is not supported.");t[s]=m(n[s])}}return t}getMetricIdentifiers(){if("string"==typeof this.metrics||"function"==typeof this.metrics)return[m(xl(this.metrics))];if(Array.isArray(this.metrics))return this.metrics.map((t=>m(xl(t))));{const t={};for(const e in this.metrics)t[e]=m(xl(this.metrics[e]));return t}}getTrainingConfig(){return{loss:this.getLossIdentifiers(),metrics:this.getMetricIdentifiers(),optimizer_config:{class_name:this.optimizer.getClassName(),config:this.optimizer.getConfig()}}}loadTrainingConfig(t){if(null!=t.weighted_metrics)throw new Error("Loading weight_metrics is not supported yet.");if(null!=t.loss_weights)throw new Error("Loading loss_weights is not supported yet.");if(null!=t.sample_weight_mode)throw new Error("Loading sample_weight_mode is not supported yet.");const e=el(Fl(t.optimizer_config));let n,s;if("string"==typeof t.loss)n=y(t.loss);else if(Array.isArray(t.loss))n=t.loss.map((t=>y(t)));else if(null!=t.loss){n={};for(const e in t.loss)n[e]=y(t.loss[e])}if(Array.isArray(t.metrics))s=t.metrics.map((t=>y(t)));else if(null!=t.metrics){s={};for(const e in t.metrics)s[e]=y(t.metrics[e])}this.compile({loss:n,metrics:s,optimizer:e})}async save(t,n){if("string"==typeof t){const n=e.io.getSaveHandlers(t);if(0===n.length)throw new o(`Cannot find any save handlers for URL '${t}'`);if(n.length>1)throw new o(`Found more than one (${n.length}) save handlers for URL '${t}'`);t=n[0]}if(null==t.save)throw new o("LayersModel.save() cannot proceed because the IOHandler provided does not have the `save` attribute defined.");const s=await e.io.encodeWeights(this.getNamedWeights(n)),i={modelTopology:this.toJSON(null,!1),format:"layers-model",generatedBy:"TensorFlow.js tfjs-layers v4.15.0",convertedBy:null};if(null!=n&&n.includeOptimizer&&null!=this.optimizer){i.trainingConfig=this.getTrainingConfig();const t="optimizer",{data:n,specs:r}=await e.io.encodeWeights(await this.optimizer.getWeights(),t);s.specs.push(...r),s.data=e.io.concatenateArrayBuffers([s.data,n])}if(null!=this.userDefinedMetadata){const t=!0;Il(this.userDefinedMetadata,this.name,t),i.userDefinedMetadata=this.userDefinedMetadata}return i.weightData=s.data,i.weightSpecs=s.specs,t.save(i)}setUserDefinedMetadata(t){Il(t,this.name),this.userDefinedMetadata=t}getUserDefinedMetadata(){return this.userDefinedMetadata}}Ql.className="Model",e.serialization.registerClass(Ql);class tu extends Ql{}tu.className="Functional",e.serialization.registerClass(tu);class eu extends Ql{constructor(t){if(super({inputs:[],outputs:[]}),t=t||{},this.trainable=!0,this.built=!1,this.name=null!=t.name?t.name:D("sequential_"),null!=t.layers)for(const e of t.layers)this.add(e)}checkShape(t){if(t.inboundNodes[0].outputTensors[0].shape.some((t=>t<0)))throw new o(`Negative dimension size caused by adding layer ${t.name} with input shape [${t.inboundNodes[0].inputTensors[0].shape}]`)}add(t){const e=t instanceof eu||t instanceof Ql;let n;if(e){if(n=t,1!==n.outputs.length)throw new o("All layers in a Sequential model should have a single output tensor. For multi-output layers, use the functional API.");if(1!==n.inputs.length)throw new o("All layers in a Sequential model should have a single input tensor. For multi-input layers, use the functional API.")}if(0===this.outputs.length){if(0===t.inboundNodes.length){if(null==t.batchInputShape)throw new o("The first layer in a Sequential model must get an `inputShape` or `batchInputShape` argument.");const e=ne({batchShape:t.batchInputShape,dtype:t.dtype,name:t.name+"_input"});t.apply(e)}if(e)this.outputs=n.outputs,this.inputs=n.inputs;else{if(1!==t.inboundNodes.length)throw new o(`A layer added to a Sequential model must not already be connected somewhere else. LayersModel received layer ${t.name} which has ${t.inboundNodes.length} pre-existing inbound connections.`);if(1!==t.inboundNodes[0].outputTensors.length)throw new o("All layers in a Sequential model should have a single output tensor. For multi-output layers, use the functional API.");this.checkShape(t),this.outputs=[t.inboundNodes[0].outputTensors[0]],this.inputs=te(this.outputs[0])}this.inboundNodes=[],new Yt({outboundLayer:this,inboundLayers:[],nodeIndices:[],tensorIndices:[],inputTensors:this.inputs,outputTensors:this.outputs,inputMasks:c(null,this.inputs.length),outputMasks:[null],inputShapes:this.inputs.map((t=>t.shape)),outputShapes:this.outputs[0].shape})}else{const e=t.apply(this.outputs[0]);if(Array.isArray(e))throw new TypeError("All layers in a Sequential model should have a single output tensor. For multi-output layers, use the functional API.");this.checkShape(t),this.outputs=[e],this.inboundNodes[0].outputTensors=this.outputs,this.inboundNodes[0].outputShapes=[this.outputs[0].shape]}this.layers.push(t),this.built=!1}pop(){if(0===this.layers.length)throw new TypeError("There are no layers in the model.");if(this.layers.pop(),0===this.layers.length)this.outputs=[],this.inboundNodes=[],this.outboundNodes=[];else{const t=this.layers.length-1;this.layers[t].outboundNodes=[],this.outputs=[this.layers[t].output],this.inboundNodes[0].outputTensors=this.outputs,this.inboundNodes[0].outputShapes=[this.outputs[0].shape]}}call(t,e){return null==this.model&&this.build(),this.model.call(t,e)}build(t){if(Wt(t),0===this.inputs.length||0===this.outputs.length)throw new TypeError("Sequential model cannot be built: model is empty. Add some layers first.");this.model=new Ql({inputs:this.inputs,outputs:this.outputs[0],name:this.name+"_model"}),this.model.trainable=this.trainable,this.supportsMasking=this.model.supportsMasking,this.inputLayers=this.model.inputLayers,this.inputLayersNodeIndices=this.model.inputLayersNodeIndices,this.inputLayersTensorIndices=this.model.inputLayersTensorIndices,this.outputLayers=this.model.outputLayers,this.outputLayersNodeIndices=this.model.outputLayersNodeIndices,this.outputLayersTensorIndices=this.model.outputLayersTensorIndices,this.nodesByDepth=this.model.nodesByDepth,this.containerNodes=this.model.containerNodes,this.outputNames=this.model.outputNames,this.inputNames=this.model.inputNames,this.built=!0}countParams(){return this.built||this.build(),super.countParams()}summary(t,e,n=console.log){this.built||this.build(),super.summary(t,e,n)}setWeights(t){null==this.model&&this.build(),this.model.setWeights(t)}evaluate(t,e,n={}){if(!this.built)throw new a("The model needs to be compiled before being used.");return this.model.evaluate(t,e,n)}async evaluateDataset(t,e){if(!this.built)throw new a("The model needs to be compiled before being used.");return this.model.evaluateDataset(t,e)}predict(t,e={}){return null==this.model&&this.build(),this.model.predict(t,e)}predictOnBatch(t){return null==this.model&&this.build(),this.model.predictOnBatch(t)}compile(t){this.build(),this.model.compile(t),this.optimizer_=this.model.optimizer,this.isOptimizerOwned=this.model.isOptimizerOwned,this.loss=this.model.loss,this.metrics=this.model.metrics,this.metricsTensors=this.model.metricsTensors,this.metricsNames=this.model.metricsNames}get optimizer(){return null==this.model?void 0:this.model.optimizer}set optimizer(t){this.model.optimizer=t}async fit(t,e,n={}){if(!this.built)throw new a("The model needs to be compiled before being used.");return this.model.fit(t,e,n)}async fitDataset(t,e){if(!this.built)throw new a("The model needs to be compiled before being used.");return this.model.fitDataset(t,e)}async trainOnBatch(t,e){return this.model.trainOnBatch(t,e)}static fromConfig(t,n,s={},i=!1){let r,a={};if(n instanceof Array){if(null==n[0].className||"Merge"===n[0].className)throw new o("Legacy serialization format not supported yet.");r=n}else e.util.assert(null!=n.layers,(()=>"When the config data for a Sequential model is not an Array, it must be an Object that contains the 'layers' field.")),r=n.layers,delete n.layers,a=n;const u=new t(a);if(!(u instanceof eu))throw new l(`Sequential.fromConfig called on non-Sequential input: ${u}`);for(const t of r){const e=el(t,void 0,i);i&&e.setFastWeightInitDuringBuild(!0),u.add(e)}return u}set stopTraining(t){if(null==this.model)throw new o("Cannot set the stopTraining property of a sequential model before it is compiled.");this.model.stopTraining=t}get stopTraining(){if(null==this.model)throw new o("Cannot get the stopTraining property of a sequential model before it is compiled.");return this.model.stopTraining}getConfig(){const t=[];for(const e of this.layers){const n={};n.className=e.getClassName(),n.config=e.getConfig(),t.push(n)}return{name:this.name,layers:t}}}function nu(t){return ne(t)}eu.className="Sequential",e.serialization.registerClass(eu);let su=class extends e.serialization.Serializable{getConfig(){return{}}};class iu extends su{apply(t,e=1){return function(t,e=1){if(1!==e)throw new l(`Support for alpha values other than 1 (${e}) is not implemented yet.`);return i.elu(t)}(t,e)}}iu.className="elu",e.serialization.registerClass(iu);class ru extends su{apply(t){return i.selu(t)}}ru.className="selu",e.serialization.registerClass(ru);class au extends su{apply(t){return i.relu(t)}}au.className="relu",e.serialization.registerClass(au);class ou extends su{apply(t){return e.tidy((()=>i.minimum(6,i.relu(t))))}}ou.className="relu6",e.serialization.registerClass(ou);class lu extends su{apply(t){return t}}lu.className="linear",e.serialization.registerClass(lu);class uu extends su{apply(t){return i.sigmoid(t)}}uu.className="sigmoid",e.serialization.registerClass(uu);class hu extends su{apply(t){return function(t){return e.tidy((()=>{const e=i.add(.5,i.mul(.2,t));return i.clipByValue(e,0,1)}))}(t)}}hu.className="hardSigmoid",e.serialization.registerClass(hu);class cu extends su{apply(t){return i.softplus(t)}}cu.className="softplus",e.serialization.registerClass(cu);class pu extends su{apply(t){return function(t){return e.tidy((()=>i.div(t,i.add(i.abs(t),1))))}(t)}}pu.className="softsign",e.serialization.registerClass(pu);class du extends su{apply(t){return i.tanh(t)}}du.className="tanh",e.serialization.registerClass(du);let fu=class extends su{apply(t,e=-1){return i.softmax(t,e)}};fu.className="softmax",e.serialization.registerClass(fu);class gu extends su{apply(t,e=-1){return i.logSoftmax(t,e)}}gu.className="logSoftmax",e.serialization.registerClass(gu);class mu extends su{apply(t,n=1){return e.tidy((()=>i.mul(i.sigmoid(i.mul(t,n)),t)))}}mu.className="swish",e.serialization.registerClass(mu);class yu extends su{apply(t){return e.tidy((()=>i.mul(t,i.tanh(i.softplus(t)))))}}function bu(t){return t.getClassName()}function wu(t,n={}){return v(t,e.serialization.SerializationMap.getMap().classNameMap,n,"activation")}function ku(t){if(null==t){const t={className:"linear",config:{}};return wu(t)}if("string"==typeof t){const e={};return e.className=t,e.config={},wu(e)}return t instanceof su?t:wu(t)}function vu(t){if(null!=t&&"object"!=typeof t)throw new Error(`Argument to L1L2 regularizer's constructor is expected to be an object, but received: ${t}`)}yu.className="mish",e.serialization.registerClass(yu);class Su extends e.serialization.Serializable{}class xu extends Su{constructor(t){super(),vu(t),this.l1=null==t||null==t.l1?.01:t.l1,this.l2=null==t||null==t.l2?.01:t.l2,this.hasL1=0!==this.l1,this.hasL2=0!==this.l2}apply(t){return e.tidy((()=>{let n=e.zeros([1]);return this.hasL1&&(n=e.add(n,e.sum(i.mul(this.l1,e.abs(t))))),this.hasL2&&(n=e.add(n,e.sum(i.mul(this.l2,pt(t))))),i.reshape(n,[])}))}getConfig(){return{l1:this.l1,l2:this.l2}}static fromConfig(t,e){return new t({l1:e.l1,l2:e.l2})}}xu.className="L1L2",e.serialization.registerClass(xu);const Nu={l1l2:"L1L2"};function Iu(t){return w(t)}function Au(t,n={}){return v(t,e.serialization.SerializationMap.getMap().classNameMap,n,"regularizer")}function zu(t){if(null==t)return null;if("string"==typeof t){return Au({className:t in Nu?Nu[t]:t,config:{}})}return t instanceof Su?t:Au(t)}class Eu extends Qt{constructor(t){super(null==t?{}:t),this.supportsMasking=!0,null!=t&&(this.maxValue=t.maxValue)}call(t,n){t=Ut(t);let s=e.relu(t);return null!=this.maxValue&&(s=e.clipByValue(s,0,this.maxValue)),s}computeOutputShape(t){return t}getConfig(){const t={maxValue:this.maxValue},e=super.getConfig();return Object.assign(t,e),t}}Eu.className="ReLU",e.serialization.registerClass(Eu);class Tu extends Qt{constructor(t){super(null==t?{}:t),this.DEFAULT_ALPHA=.3,null==t&&(t={}),this.alpha=null==t.alpha?this.DEFAULT_ALPHA:t.alpha}call(t,n){const s=Ut(t);return e.leakyRelu(s,this.alpha)}computeOutputShape(t){return t}getConfig(){const t={alpha:this.alpha},e=super.getConfig();return Object.assign(t,e),t}}Tu.className="LeakyReLU",e.serialization.registerClass(Tu);class Cu extends Qt{constructor(t){if(super(null==t?{}:t),this.DEFAULT_ALPHA_INITIALIZER="zeros",null==t&&(t={}),this.supportsMasking=!0,this.alphaInitializer=Mt(t.alphaInitializer||this.DEFAULT_ALPHA_INITIALIZER),this.alphaRegularizer=zu(t.alphaRegularizer),this.alphaConstraint=Uo(t.alphaConstraint),null==t.sharedAxes)this.sharedAxes=null;else if(Array.isArray(t.sharedAxes))this.sharedAxes=t.sharedAxes;else{if("number"!=typeof t.sharedAxes)throw new o(`Expected sharedAxes to be a number or an array of numbers, but got ${t.sharedAxes}`);this.sharedAxes=[t.sharedAxes]}}build(t){const e=(t=Wt(t)).slice(1);if(null!=this.sharedAxes)for(const t of this.sharedAxes)e[t-1]=1;this.alpha=this.addWeight("alpha",e,"float32",this.alphaInitializer,this.alphaRegularizer,!0,this.alphaConstraint);const n={};if(null!=this.sharedAxes)for(let e=1;e{let s=Ut(t);const i=n.mask;if(null!=i){const t=e.mul(e.sub(e.ones(s.shape),e.cast(i,s.dtype)),e.scalar(-1e9));s=e.add(s,t)}return this.axis instanceof Array?this.axis.length>1?e.exp(e.sub(s,e.logSumExp(s,this.axis,!0))):this.softmax(s,this.axis[0]):this.softmax(s,this.axis)}))}computeOutputShape(t){return t}getConfig(){const t={axis:this.axis},e=super.getConfig();return Object.assign(t,e),t}}function Lu(t,e,n){if("number"==typeof t)return c(t,e);if(t.length!==e)throw new o(`The ${n} argument must be an integer or tuple of ${e} integers. Received: ${t.length} elements.`);for(let i=0;i(P(n),"channelsFirst"===n?i.transpose(t,[0,2,3,1]):t)))}function Mu(t,n){return e.tidy((()=>(P(n),"channelsFirst"===n?i.transpose(t,[0,2,3,4,1]):t)))}function Bu(t,n,s,r=[1,1],a="valid",u,h,c=null){return e.tidy((()=>{if(null==u&&(u="channelsLast"),P(u),3!==t.rank&&4!==t.rank)throw new o(`conv2dWithBiasActivation expects input to be of rank 3 or 4, but received ${t.rank}.`);if(3!==n.rank&&4!==n.rank)throw new o(`conv2dWithBiasActivation expects kernel to be of rank 3 or 4, but received ${t.rank}.`);let e=Ou(t,u);if("causal"===a)throw new l("The support for CAUSAL padding mode in conv1dWithBias is not implemented yet.");return e=i.fused.conv2d({x:e,filter:n,strides:r,pad:"same"===a?"same":"valid",dilations:h,dataFormat:"NHWC",bias:s,activation:c}),"channelsFirst"===u&&(e=i.transpose(e,[0,3,1,2])),e}))}Du.className="Softmax",e.serialization.registerClass(Du);class Pu extends Qt{constructor(t,e){if(super(e),this.bias=null,this.DEFAULT_KERNEL_INITIALIZER="glorotNormal",this.DEFAULT_BIAS_INITIALIZER="zeros",Pu.verifyArgs(e),this.rank=t,z(this.rank,"rank"),1!==this.rank&&2!==this.rank&&3!==this.rank)throw new l(`Convolution layer for rank other than 1, 2, or 3 (${this.rank}) is not implemented yet.`);if(this.kernelSize=Lu(e.kernelSize,t,"kernelSize"),this.strides=Lu(null==e.strides?1:e.strides,t,"strides"),this.padding=null==e.padding?"valid":e.padding,U(this.padding),this.dataFormat=null==e.dataFormat?"channelsLast":e.dataFormat,P(this.dataFormat),this.activation=ku(e.activation),this.useBias=null==e.useBias||e.useBias,this.biasInitializer=Mt(e.biasInitializer||this.DEFAULT_BIAS_INITIALIZER),this.biasConstraint=Uo(e.biasConstraint),this.biasRegularizer=zu(e.biasRegularizer),this.activityRegularizer=zu(e.activityRegularizer),this.dilationRate=Lu(null==e.dilationRate?1:e.dilationRate,t,"dilationRate"),1===this.rank&&Array.isArray(this.dilationRate)&&1!==this.dilationRate.length)throw new o(`dilationRate must be a number or an array of a single number for 1D convolution, but received ${JSON.stringify(this.dilationRate)}`);if(2===this.rank){if("number"==typeof this.dilationRate)this.dilationRate=[this.dilationRate,this.dilationRate];else if(2!==this.dilationRate.length)throw new o(`dilationRate must be a number or array of two numbers for 2D convolution, but received ${JSON.stringify(this.dilationRate)}`)}else if(3===this.rank)if("number"==typeof this.dilationRate)this.dilationRate=[this.dilationRate,this.dilationRate,this.dilationRate];else if(3!==this.dilationRate.length)throw new o(`dilationRate must be a number or array of three numbers for 3D convolution, but received ${JSON.stringify(this.dilationRate)}`)}static verifyArgs(t){if(p("kernelSize"in t,"required key 'kernelSize' not in config"),"number"!=typeof t.kernelSize&&!A(t.kernelSize,"number",1,3))throw new o(`BaseConv expects config.kernelSize to be number or number[] with length 1, 2, or 3, but received ${JSON.stringify(t.kernelSize)}.`)}getConfig(){const t={kernelSize:this.kernelSize,strides:this.strides,padding:this.padding,dataFormat:this.dataFormat,dilationRate:this.dilationRate,activation:bu(this.activation),useBias:this.useBias,biasInitializer:Ot(this.biasInitializer),biasRegularizer:Iu(this.biasRegularizer),activityRegularizer:Iu(this.activityRegularizer),biasConstraint:Bo(this.biasConstraint)},e=super.getConfig();return Object.assign(t,e),t}}class Uu extends Pu{constructor(t,e){super(t,e),this.kernel=null,Uu.verifyArgs(e),this.filters=e.filters,z(this.filters,"filters"),this.kernelInitializer=Mt(e.kernelInitializer||this.DEFAULT_KERNEL_INITIALIZER),this.kernelConstraint=Uo(e.kernelConstraint),this.kernelRegularizer=zu(e.kernelRegularizer)}build(t){t=Wt(t);const e="channelsFirst"===this.dataFormat?1:t.length-1;if(null==t[e])throw new o(`The channel dimension of the input should be defined. Found ${t[e]}`);const n=t[e],s=this.kernelSize.concat([n,this.filters]);this.kernel=this.addWeight("kernel",s,null,this.kernelInitializer,this.kernelRegularizer,!0,this.kernelConstraint),this.useBias&&(this.bias=this.addWeight("bias",[this.filters],null,this.biasInitializer,this.biasRegularizer,!0,this.biasConstraint)),this.inputSpec=[{ndim:this.rank+2,axes:{[e]:n}}],this.built=!0}call(t,n){return e.tidy((()=>{let n;t=Ut(t);const s=null==this.bias?null:this.bias.read(),r=T(this.activation.getClassName());if(null!=r&&2===this.rank)n=Bu(t,this.kernel.read(),s,this.strides,this.padding,this.dataFormat,this.dilationRate,r);else{if(1===this.rank)n=function(t,n,s,r=1,a="valid",u,h=1){return e.tidy((()=>{if(null==u&&(u="channelsLast"),P(u),3!==t.shape.length)throw new o(`The input of a conv1dWithBias operation should be 3, but is ${t.shape.length} instead.`);if(3!==n.shape.length)throw new o(`The kernel for a conv1dWithBias operation should be 3, but is ${n.shape.length} instead`);if(null!=s&&1!==s.shape.length)throw new o(`The bias for a conv1dWithBias operation should be 1, but is ${n.shape.length} instead`);if("channelsFirst"===u&&(t=i.transpose(t,[0,2,1])),"causal"===a)throw new l("The support for CAUSAL padding mode in conv1dWithBias is not implemented yet.");let e=i.conv1d(t,n,r,"same"===a?"same":"valid","NWC",h);return null!=s&&(e=ft(e,s)),e}))}(t,this.kernel.read(),s,this.strides[0],this.padding,this.dataFormat,this.dilationRate[0]);else if(2===this.rank)n=Bu(t,this.kernel.read(),s,this.strides,this.padding,this.dataFormat,this.dilationRate);else{if(3!==this.rank)throw new l("convolutions greater than 3D are not implemented yet.");n=function(t,n,s,r=[1,1,1],a="valid",u,h){return e.tidy((()=>{if(null==u&&(u="channelsLast"),P(u),4!==t.rank&&5!==t.rank)throw new o(`conv3dWithBias expects input to be of rank 4 or 5, but received ${t.rank}.`);if(4!==n.rank&&5!==n.rank)throw new o(`conv3dWithBias expects kernel to be of rank 4 or 5, but received ${t.rank}.`);let e=Mu(t,u);if("causal"===a)throw new l("The support for CAUSAL padding mode in conv3dWithBias is not implemented yet.");return e=i.conv3d(e,n,r,"same"===a?"same":"valid","NDHWC",h),null!=s&&(e=ft(e,s)),"channelsFirst"===u&&(e=i.transpose(e,[0,4,1,2,3])),e}))}(t,this.kernel.read(),s,this.strides,this.padding,this.dataFormat,this.dilationRate)}null!=this.activation&&(n=this.activation.apply(n))}return n}))}computeOutputShape(t){t=Wt(t);const e=[],n="channelsLast"===this.dataFormat?t.slice(1,t.length-1):t.slice(2);for(let t=0;t 0 but got ${JSON.stringify(t.filters)}`)}}class Wu extends Uu{constructor(t){super(2,t),Wu.verifyArgs(t)}getConfig(){const t=super.getConfig();return delete t.rank,t}static verifyArgs(t){if("number"!=typeof t.kernelSize&&!A(t.kernelSize,"number",1,2))throw new o(`Conv2D expects config.kernelSize to be number or number[] with length 1 or 2, but received ${JSON.stringify(t.kernelSize)}.`)}}Wu.className="Conv2D",e.serialization.registerClass(Wu);class ju extends Uu{constructor(t){super(3,t),ju.verifyArgs(t)}getConfig(){const t=super.getConfig();return delete t.rank,t}static verifyArgs(t){if("number"!=typeof t.kernelSize&&(!Array.isArray(t.kernelSize)||1!==t.kernelSize.length&&3!==t.kernelSize.length))throw new o(`Conv3D expects config.kernelSize to be number or [number, number, number], but received ${JSON.stringify(t.kernelSize)}.`)}}ju.className="Conv3D",e.serialization.registerClass(ju);class qu extends Wu{constructor(t){if(super(t),this.inputSpec=[new Ht({ndim:4})],"same"!==this.padding&&"valid"!==this.padding)throw new o(`Conv2DTranspose currently supports only padding modes 'same' and 'valid', but received padding mode ${this.padding}`)}build(t){if(4!==(t=Wt(t)).length)throw new o("Input should have rank 4; Received input shape: "+JSON.stringify(t));const e="channelsFirst"===this.dataFormat?1:t.length-1;if(null==t[e])throw new o("The channel dimension of the inputs should be defined. Found `None`.");const n=t[e],s=this.kernelSize.concat([this.filters,n]);this.kernel=this.addWeight("kernel",s,"float32",this.kernelInitializer,this.kernelRegularizer,!0,this.kernelConstraint),this.useBias&&(this.bias=this.addWeight("bias",[this.filters],"float32",this.biasInitializer,this.biasRegularizer,!0,this.biasConstraint)),this.inputSpec=[new Ht({ndim:4,axes:{[e]:n}})],this.built=!0}call(t,e){return i.tidy((()=>{let e=Ut(t);if(4!==e.shape.length)throw new o(`Conv2DTranspose.call() expects input tensor to be rank-4, but received a tensor of rank-${e.shape.length}`);const n=e.shape,s=n[0];let r,a;"channelsFirst"===this.dataFormat?(r=2,a=3):(r=1,a=2);const l=n[r],u=n[a],h=this.kernelSize[0],c=this.kernelSize[1],p=this.strides[0],d=this.strides[1],f=[s,Ru(l,p,h,this.padding),Ru(u,d,c,this.padding),this.filters];"channelsLast"!==this.dataFormat&&(e=i.transpose(e,[0,2,3,1]));let g=i.conv2dTranspose(e,this.kernel.read(),f,this.strides,this.padding);return"channelsLast"!==this.dataFormat&&(g=i.transpose(g,[0,3,1,2])),null!=this.bias&&(g=ft(g,this.bias.read(),this.dataFormat)),null!=this.activation&&(g=this.activation.apply(g)),g}))}computeOutputShape(t){const e=(t=Wt(t)).slice();let n,s,i;"channelsFirst"===this.dataFormat?(n=1,s=2,i=3):(n=3,s=1,i=2);const r=this.kernelSize[0],a=this.kernelSize[1],o=this.strides[0],l=this.strides[1];return e[n]=this.filters,e[s]=Ru(e[s],o,r,this.padding),e[i]=Ru(e[i],l,a,this.padding),e}getConfig(){const t=super.getConfig();return delete t.dilationRate,t}}qu.className="Conv2DTranspose",e.serialization.registerClass(qu);class Vu extends ju{constructor(t){if(super(t),this.inputSpec=[new Ht({ndim:5})],"same"!==this.padding&&"valid"!==this.padding)throw new o(`Conv3DTranspose currently supports only padding modes 'same' and 'valid', but received padding mode ${this.padding}`)}build(t){if(5!==(t=Wt(t)).length)throw new o("Input should have rank 5; Received input shape: "+JSON.stringify(t));const e="channelsFirst"===this.dataFormat?1:t.length-1;if(null==t[e])throw new o("The channel dimension of the inputs should be defined. Found `None`.");const n=t[e],s=this.kernelSize.concat([this.filters,n]);this.kernel=this.addWeight("kernel",s,"float32",this.kernelInitializer,this.kernelRegularizer,!0,this.kernelConstraint),this.useBias&&(this.bias=this.addWeight("bias",[this.filters],"float32",this.biasInitializer,this.biasRegularizer,!0,this.biasConstraint)),this.inputSpec=[new Ht({ndim:5,axes:{[e]:n}})],this.built=!0}call(t,e){return i.tidy((()=>{let e=Ut(t);if(5!==e.shape.length)throw new o(`Conv3DTranspose.call() expects input tensor to be rank-4, but received a tensor of rank-${e.shape.length}`);const n=e.shape,s=n[0];let r,a,l;"channelsFirst"===this.dataFormat?(l=2,r=3,a=4):(l=1,r=2,a=3);const u=n[l],h=n[r],c=n[a],p=this.kernelSize[0],d=this.kernelSize[1],f=this.kernelSize[2],g=this.strides[0],m=this.strides[1],y=this.strides[2],b=[s,Ru(u,g,p,this.padding),Ru(h,m,d,this.padding),Ru(c,y,f,this.padding),this.filters];"channelsLast"!==this.dataFormat&&(e=i.transpose(e,[0,2,3,4,1]));let w=i.conv3dTranspose(e,this.kernel.read(),b,this.strides,this.padding);return"channelsLast"!==this.dataFormat&&(w=i.transpose(w,[0,4,1,2,3])),null!==this.bias&&(w=ft(w,this.bias.read(),this.dataFormat)),null!==this.activation&&(w=this.activation.apply(w)),w}))}computeOutputShape(t){const e=(t=Wt(t)).slice();let n,s,i,r;"channelsFirst"===this.dataFormat?(n=1,s=2,i=3,r=4):(n=4,s=1,i=2,r=3);const a=this.kernelSize[0],o=this.kernelSize[1],l=this.kernelSize[2],u=this.strides[0],h=this.strides[1],c=this.strides[2];return e[n]=this.filters,e[s]=Ru(e[s],u,a,this.padding),e[i]=Ru(e[i],h,o,this.padding),e[r]=Ru(e[r],c,l,this.padding),e}getConfig(){const t=super.getConfig();return delete t.dilationRate,t}}Vu.className="Conv3DTranspose",e.serialization.registerClass(Vu);class Ku extends Uu{constructor(t,e){if(super(t,e),this.DEFAULT_DEPTHWISE_INITIALIZER="glorotUniform",this.DEFAULT_POINTWISE_INITIALIZER="glorotUniform",this.depthwiseKernel=null,this.pointwiseKernel=null,null==e.filters)throw new o("The `filters` configuration field is required by SeparableConv, but is unspecified.");if(null!=e.kernelInitializer||null!=e.kernelRegularizer||null!=e.kernelConstraint)throw new o("Fields kernelInitializer, kernelRegularizer and kernelConstraint are invalid for SeparableConv2D. Use depthwiseInitializer, depthwiseRegularizer, depthwiseConstraint, pointwiseInitializer, pointwiseRegularizer and pointwiseConstraint instead.");if(null!=e.padding&&"same"!==e.padding&&"valid"!==e.padding)throw new o(`SeparableConv${this.rank}D supports only padding modes: 'same' and 'valid', but received ${JSON.stringify(e.padding)}`);this.depthMultiplier=null==e.depthMultiplier?1:e.depthMultiplier,this.depthwiseInitializer=Mt(e.depthwiseInitializer||this.DEFAULT_DEPTHWISE_INITIALIZER),this.depthwiseRegularizer=zu(e.depthwiseRegularizer),this.depthwiseConstraint=Uo(e.depthwiseConstraint),this.pointwiseInitializer=Mt(e.depthwiseInitializer||this.DEFAULT_POINTWISE_INITIALIZER),this.pointwiseRegularizer=zu(e.pointwiseRegularizer),this.pointwiseConstraint=Uo(e.pointwiseConstraint)}build(t){if((t=Wt(t)).length{let e;if(t=Ut(t),1===this.rank)throw new l("1D separable convolution is not implemented yet.");return 2===this.rank&&("channelsFirst"===this.dataFormat&&(t=i.transpose(t,[0,2,3,1])),e=i.separableConv2d(t,this.depthwiseKernel.read(),this.pointwiseKernel.read(),this.strides,this.padding,this.dilationRate,"NHWC")),this.useBias&&(e=ft(e,this.bias.read(),this.dataFormat)),null!=this.activation&&(e=this.activation.apply(e)),"channelsFirst"===this.dataFormat&&(e=i.transpose(e,[0,3,1,2])),e}))}getConfig(){const t=super.getConfig();return delete t.rank,delete t.kernelInitializer,delete t.kernelRegularizer,delete t.kernelConstraint,t.depthwiseInitializer=Ot(this.depthwiseInitializer),t.pointwiseInitializer=Ot(this.pointwiseInitializer),t.depthwiseRegularizer=Iu(this.depthwiseRegularizer),t.pointwiseRegularizer=Iu(this.pointwiseRegularizer),t.depthwiseConstraint=Bo(this.depthwiseConstraint),t.pointwiseConstraint=Bo(this.pointwiseConstraint),t}}Ku.className="SeparableConv";class Gu extends Ku{constructor(t){super(2,t)}}Gu.className="SeparableConv2D",e.serialization.registerClass(Gu);class Hu extends Uu{constructor(t){super(1,t),Hu.verifyArgs(t),this.inputSpec=[{ndim:3}]}getConfig(){const t=super.getConfig();return delete t.rank,delete t.dataFormat,t}static verifyArgs(t){if("number"!=typeof t.kernelSize&&!A(t.kernelSize,"number",1,1))throw new o(`Conv1D expects config.kernelSize to be number or number[] with length 1, but received ${JSON.stringify(t.kernelSize)}.`)}}Hu.className="Conv1D",e.serialization.registerClass(Hu);class Ju extends Qt{constructor(t){super(t),"number"==typeof t.cropping?this.cropping=[[t.cropping,t.cropping],[t.cropping,t.cropping]]:"number"==typeof t.cropping[0]?this.cropping=[[t.cropping[0],t.cropping[0]],[t.cropping[1],t.cropping[1]]]:this.cropping=t.cropping,this.dataFormat=void 0===t.dataFormat?"channelsLast":t.dataFormat,this.inputSpec=[{ndim:4}]}computeOutputShape(t){return"channelsFirst"===this.dataFormat?[t[0],t[1],t[2]-this.cropping[0][0]-this.cropping[0][1],t[3]-this.cropping[1][0]-this.cropping[1][1]]:[t[0],t[1]-this.cropping[0][0]-this.cropping[0][1],t[2]-this.cropping[1][0]-this.cropping[1][1],t[3]]}call(t,n){return e.tidy((()=>{if(t=Ut(t),"channelsLast"===this.dataFormat){const e=rt(t,this.cropping[0][0],t.shape[1]-this.cropping[0][0]-this.cropping[0][1],2);return rt(e,this.cropping[1][0],t.shape[2]-this.cropping[1][1]-this.cropping[1][0],3)}{const e=rt(t,this.cropping[0][0],t.shape[2]-this.cropping[0][0]-this.cropping[0][1],3);return rt(e,this.cropping[1][0],t.shape[3]-this.cropping[1][1]-this.cropping[1][0],4)}}))}getConfig(){const t={cropping:this.cropping,dataFormat:this.dataFormat},e=super.getConfig();return Object.assign(t,e),t}}Ju.className="Cropping2D",e.serialization.registerClass(Ju);class Zu extends Qt{constructor(t){var e;super(t),this.DEFAULT_SIZE=[2,2],this.inputSpec=[{ndim:4}],this.size=null==t.size?this.DEFAULT_SIZE:t.size,this.dataFormat=null==t.dataFormat?"channelsLast":t.dataFormat,P(this.dataFormat),this.interpolation=null==t.interpolation?"nearest":t.interpolation,e=this.interpolation,I(_,"InterpolationFormat",e)}computeOutputShape(t){if("channelsFirst"===this.dataFormat){const e=null==t[2]?null:this.size[0]*t[2],n=null==t[3]?null:this.size[1]*t[3];return[t[0],t[1],e,n]}{const e=null==t[1]?null:this.size[0]*t[1],n=null==t[2]?null:this.size[1]*t[2];return[t[0],e,n,t[3]]}}call(t,e){return i.tidy((()=>{let e=Ut(t);const n=e.shape;if("channelsFirst"===this.dataFormat){e=i.transpose(e,[0,2,3,1]);const t=this.size[0]*n[2],s=this.size[1]*n[3],r="nearest"===this.interpolation?i.image.resizeNearestNeighbor(e,[t,s]):i.image.resizeBilinear(e,[t,s]);return i.transpose(r,[0,3,1,2])}{const t=this.size[0]*n[1],s=this.size[1]*n[2];return"nearest"===this.interpolation?i.image.resizeNearestNeighbor(e,[t,s]):i.image.resizeBilinear(e,[t,s])}}))}getConfig(){const t={size:this.size,dataFormat:this.dataFormat,interpolation:this.interpolation},e=super.getConfig();return Object.assign(t,e),t}}Zu.className="UpSampling2D",e.serialization.registerClass(Zu);class Yu extends Pu{constructor(t){super(2,t),this.depthwiseKernel=null,this.depthMultiplier=null==t.depthMultiplier?1:t.depthMultiplier,this.depthwiseInitializer=Mt(t.depthwiseInitializer||this.DEFAULT_KERNEL_INITIALIZER),this.depthwiseConstraint=Uo(t.depthwiseConstraint),this.depthwiseRegularizer=zu(t.depthwiseRegularizer)}build(t){if((t=Wt(t)).length<4)throw new o(`Inputs to DepthwiseConv2D should have rank 4. Received input shape: ${JSON.stringify(t)}.`);const e="channelsFirst"===this.dataFormat?1:3;if(null==t[e]||t[e]<0)throw new o(`The channel dimension of the inputs to DepthwiseConv2D should be defined, but is not (${t[e]}).`);const n=t[e],s=[this.kernelSize[0],this.kernelSize[1],n,this.depthMultiplier];this.depthwiseKernel=this.addWeight("depthwise_kernel",s,null,this.depthwiseInitializer,this.depthwiseRegularizer,!0,this.depthwiseConstraint),this.useBias?this.bias=this.addWeight("bias",[n*this.depthMultiplier],null,this.biasInitializer,this.biasRegularizer,!0,this.biasConstraint):this.bias=null,this.built=!0}call(t,n){return e.tidy((()=>{let n=function(t,n,s=[1,1],r="valid",a,l){return e.tidy((()=>{null==a&&(a="channelsLast"),P(a);let e=Ou(t,a);if(4!==t.rank)throw new o(`Input for depthwiseConv2d is required to be 4-D, but is instead ${t.rank}-D`);if(4!==n.rank)throw new o(`depthwiseKernel is required to be 4-D, but is instead ${n.rank}-D`);return e=i.depthwiseConv2d(e,n,s,"same"===r?"same":"valid","NHWC",l),"channelsFirst"===a&&(e=i.transpose(e,[0,3,1,2])),e}))}(t=Ut(t),this.depthwiseKernel.read(),this.strides,this.padding,this.dataFormat,null);return this.useBias&&(n=ft(n,this.bias.read(),this.dataFormat)),null!=this.activation&&(n=this.activation.apply(n)),n}))}computeOutputShape(t){t=Wt(t);const e="channelsFirst"===this.dataFormat?t[2]:t[1],n="channelsFirst"===this.dataFormat?t[3]:t[2],s="channelsFirst"===this.dataFormat?t[1]*this.depthMultiplier:t[3]*this.depthMultiplier,i=_u(e,this.kernelSize[0],this.padding,this.strides[0]),r=_u(n,this.kernelSize[1],this.padding,this.strides[1]);return"channelsFirst"===this.dataFormat?[t[0],s,i,r]:[t[0],i,r,s]}getConfig(){const t=super.getConfig();return t.depthMultiplier=this.depthMultiplier,t.depthwiseInitializer=Ot(this.depthwiseInitializer),t.depthwiseRegularizer=Iu(this.depthwiseRegularizer),t.depthwiseConstraint=Bo(this.depthwiseRegularizer),t}}function Xu(t,e,n,s){if(Array.isArray(t)){if(null!=e||null!=n)throw new o("When inputs is an array, neither initialState or constants should be provided");null!=s&&(n=t.slice(t.length-s,t.length),t=t.slice(0,t.length-s)),t.length>1&&(e=t.slice(1,t.length)),t=t[0]}function i(t){return null==t||Array.isArray(t)?t:[t]}return{inputs:t,initialState:e=i(e),constants:n=i(n)}}function Qu(t,e,n,s=!1,r,a,u=!1,h=!1){return i.tidy((()=>{const c=e.shape.length;if(c<3)throw new o(`Input should be at least 3D, but is ${c}D.`);const p=[1,0].concat(X(2,c));if(e=i.transpose(e,p),null!=a)throw new l("The rnn() functoin of the deeplearn.js backend does not support constants yet.");u&&console.warn("Backend rnn(): the unroll = true option is not applicable to the imperative deeplearn.js backend."),null!=r&&((r=i.cast(i.cast(r,"bool"),"float32")).rank===c-1&&(r=i.expandDims(r,-1)),r=i.transpose(r,p)),s&&(e=i.reverse(e,0),null!=r&&(r=i.reverse(r,0)));const d=[];let f,g=n;const m=e.shape[0],y=i.unstack(e);let b,w;null!=r&&(b=i.unstack(r));for(let e=0;et(n,g)));if(null==r)f=s[0],g=s[1];else{const t=i.tidy((()=>{const t=b[e],n=i.sub(i.onesLike(t),t);return{output:i.add(i.mul(s[0],t),i.mul(g[0],n)),newStates:g.map(((e,r)=>i.add(i.mul(s[1][r],t),i.mul(e,n))))}}));f=t.output,g=t.newStates}h&&d.push(f)}if(h){const t=1;w=i.stack(d,t)}return[f,w,g]}))}Yu.className="DepthwiseConv2D",e.serialization.registerClass(Yu);class th extends Qt{constructor(t){let e;if(super(t),null==t.cell)throw new o("cell property is missing for the constructor of RNN.");if(e=Array.isArray(t.cell)?new lh({cells:t.cell}):t.cell,null==e.stateSize)throw new o("The RNN cell should have an attribute `stateSize` (tuple of integers, one integer per RNN state).");this.cell=e,this.returnSequences=null!=t.returnSequences&&t.returnSequences,this.returnState=null!=t.returnState&&t.returnState,this.goBackwards=null!=t.goBackwards&&t.goBackwards,this._stateful=null!=t.stateful&&t.stateful,this.unroll=null!=t.unroll&&t.unroll,this.supportsMasking=!0,this.inputSpec=[new Ht({ndim:3})],this.stateSpec=null,this.states_=null,this.numConstants=null,this.keptStates=[]}getStates(){if(null==this.states_){return X(0,Array.isArray(this.cell.stateSize)?this.cell.stateSize.length:1).map((t=>null))}return this.states_}setStates(t){this.states_=t}computeOutputShape(t){Bt(t)&&(t=t[0]);let e=this.cell.stateSize;Array.isArray(e)||(e=[e]);const n=e[0];let s;if(s=this.returnSequences?[t[0],t[1],n]:[t[0],n],this.returnState){const n=[];for(const s of e)n.push([t[0],s]);return[s].concat(n)}return s}computeMask(t,e){return i.tidy((()=>{Array.isArray(e)&&(e=e[0]);const t=this.returnSequences?e:null;if(this.returnState){const e=this.states.map((t=>null));return[t].concat(e)}return t}))}get states(){if(null==this.states_){const t=Array.isArray(this.cell.stateSize)?this.cell.stateSize.length:1,e=[];for(let n=0;nt.shape[t.shape.length-1])),r))throw new o(`An initialState was passed that is not compatible with cell.stateSize. Received stateSpec=${this.stateSpec}; However cell.stateSize is ${this.cell.stateSize}`)}else this.stateSpec=r.map((t=>new Ht({shape:[null,t]})));this.stateful&&this.resetStates()}resetStates(t,n=!1){e.tidy((()=>{if(!this.stateful)throw new r("Cannot call resetStates() on an RNN Layer that is not stateful.");const s=this.inputSpec[0].shape[0];if(null==s)throw new o("If an RNN is stateful, it needs to know its batch size. Specify the batch size of your input tensors: \n- If using a Sequential model, specify the batch size by passing a `batchInputShape` option to your first layer.\n- If using the functional API, specify the batch size by passing a `batchShape` option to your Input layer.");if(null==this.states_)Array.isArray(this.cell.stateSize)?this.states_=this.cell.stateSize.map((t=>i.zeros([s,t]))):this.states_=[i.zeros([s,this.cell.stateSize])];else if(null==t)i.dispose(this.states_),null!=this.keptStates&&(i.dispose(this.keptStates),this.keptStates=[]),Array.isArray(this.cell.stateSize)?this.states_=this.cell.stateSize.map((t=>i.zeros([s,t]))):this.states_[0]=i.zeros([s,this.cell.stateSize]);else{if(Array.isArray(t)||(t=[t]),t.length!==this.states_.length)throw new o(`Layer ${this.name} expects ${this.states_.length} state(s), but it received ${t.length} state value(s). Input received: ${t}`);!0===n?this.keptStates.push(this.states_.slice()):i.dispose(this.states_);for(let n=0;ni.keep(t.clone())))}))}apply(t,e){let n=null==e?null:e.initialState,s=null==e?null:e.constants;null==e&&(e={});const i=Xu(t,n,s,this.numConstants);t=i.inputs,n=i.initialState,s=i.constants;let r=[],a=[];if(null!=n){e.initialState=n,r=r.concat(n),this.stateSpec=[];for(const t of n)this.stateSpec.push(new Ht({shape:t.shape}));a=a.concat(this.stateSpec)}null!=s&&(e.constants=s,r=r.concat(s),this.numConstants=s.length);if(r[0]instanceof Jt){const n=[t].concat(r),s=this.inputSpec.concat(a),i=this.inputSpec;this.inputSpec=s;const o=super.apply(n,e);return this.inputSpec=i,o}return super.apply(t,e)}call(t,n){return e.tidy((()=>{const e=null==n?null:n.mask,s=null==n?null:n.training;let i=null==n?null:n.initialState;t=Ut(t),null==i&&(i=this.stateful?this.states_:this.getInitialState(t));const r=Array.isArray(this.cell.stateSize)?this.cell.stateSize.length:1;if(i.length!==r)throw new o(`RNN Layer has ${r} state(s) but was passed ${i.length} initial state(s).`);this.unroll&&console.warn("Ignoring unroll = true for RNN layer, due to imperative backend.");const a={training:s},l=Qu(((t,e)=>{const n=this.cell.call([t].concat(e),a);return[n[0],n.slice(1)]}),t,i,this.goBackwards,e,null,this.unroll,this.returnSequences),u=l[0],h=l[1],c=l[2];this.stateful&&this.resetStates(c,s);const p=this.returnSequences?h:u;return this.returnState?[p].concat(c):p}))}getInitialState(t){return e.tidy((()=>{let e=i.zeros(t.shape);return e=i.sum(e,[1,2]),e=nt(e),Array.isArray(this.cell.stateSize)?this.cell.stateSize.map((t=>t>1?lt(e,[1,t]):e)):this.cell.stateSize>1?[lt(e,[1,this.cell.stateSize])]:[e]}))}get trainableWeights(){return this.trainable?this.cell.trainableWeights:[]}get nonTrainableWeights(){return this.trainable?this.cell.nonTrainableWeights:this.cell.weights}setFastWeightInitDuringBuild(t){super.setFastWeightInitDuringBuild(t),null!=this.cell&&this.cell.setFastWeightInitDuringBuild(t)}getConfig(){const t=super.getConfig(),e={returnSequences:this.returnSequences,returnState:this.returnState,goBackwards:this.goBackwards,stateful:this.stateful,unroll:this.unroll};null!=this.numConstants&&(e.numConstants=this.numConstants);const n=this.cell.getConfig();return this.getClassName()===th.className&&(e.cell={className:this.cell.getClassName(),config:n}),Object.assign(Object.assign(Object.assign({},n),t),e)}static fromConfig(t,e,n={}){const s=el(e.cell,n);return new t(Object.assign(e,{cell:s}))}}th.className="RNN",e.serialization.registerClass(th);class eh extends Qt{}class nh extends eh{constructor(t){super(t),this.DEFAULT_ACTIVATION="tanh",this.DEFAULT_KERNEL_INITIALIZER="glorotNormal",this.DEFAULT_RECURRENT_INITIALIZER="orthogonal",this.DEFAULT_BIAS_INITIALIZER="zeros",this.units=t.units,z(this.units,"units"),this.activation=ku(null==t.activation?this.DEFAULT_ACTIVATION:t.activation),this.useBias=null==t.useBias||t.useBias,this.kernelInitializer=Mt(t.kernelInitializer||this.DEFAULT_KERNEL_INITIALIZER),this.recurrentInitializer=Mt(t.recurrentInitializer||this.DEFAULT_RECURRENT_INITIALIZER),this.biasInitializer=Mt(t.biasInitializer||this.DEFAULT_BIAS_INITIALIZER),this.kernelRegularizer=zu(t.kernelRegularizer),this.recurrentRegularizer=zu(t.recurrentRegularizer),this.biasRegularizer=zu(t.biasRegularizer),this.kernelConstraint=Uo(t.kernelConstraint),this.recurrentConstraint=Uo(t.recurrentConstraint),this.biasConstraint=Uo(t.biasConstraint),this.dropout=Z([1,Y([0,null==t.dropout?0:t.dropout])]),this.recurrentDropout=Z([1,Y([0,null==t.recurrentDropout?0:t.recurrentDropout])]),this.dropoutFunc=t.dropoutFunc,this.stateSize=this.units,this.dropoutMask=null,this.recurrentDropoutMask=null}build(t){t=Wt(t),this.kernel=this.addWeight("kernel",[t[t.length-1],this.units],null,this.kernelInitializer,this.kernelRegularizer,!0,this.kernelConstraint),this.recurrentKernel=this.addWeight("recurrent_kernel",[this.units,this.units],null,this.recurrentInitializer,this.recurrentRegularizer,!0,this.recurrentConstraint),this.useBias?this.bias=this.addWeight("bias",[this.units],null,this.biasInitializer,this.biasRegularizer,!0,this.biasConstraint):this.bias=null,this.built=!0}call(t,n){return e.tidy((()=>{if(2!==t.length)throw new o(`SimpleRNNCell expects 2 input Tensors, got ${t.length}.`);let e=t[1];t=t[0];const s=null!=n.training&&n.training;let r;0i.onesLike(t),rate:this.dropout,training:s,dropoutFunc:this.dropoutFunc})),0i.onesLike(e),rate:this.recurrentDropout,training:s,dropoutFunc:this.dropoutFunc}));const a=this.dropoutMask,l=this.recurrentDropoutMask;r=ht(null!=a?i.mul(t,a):t,this.kernel.read()),null!=this.bias&&(r=ft(r,this.bias.read())),null!=l&&(e=i.mul(e,l));let u=i.add(r,ht(e,this.recurrentKernel.read()));return null!=this.activation&&(u=this.activation.apply(u)),[u,u]}))}getConfig(){const t=super.getConfig(),e={units:this.units,activation:bu(this.activation),useBias:this.useBias,kernelInitializer:Ot(this.kernelInitializer),recurrentInitializer:Ot(this.recurrentInitializer),biasInitializer:Ot(this.biasInitializer),kernelRegularizer:Iu(this.kernelRegularizer),recurrentRegularizer:Iu(this.recurrentRegularizer),biasRegularizer:Iu(this.biasRegularizer),activityRegularizer:Iu(this.activityRegularizer),kernelConstraint:Bo(this.kernelConstraint),recurrentConstraint:Bo(this.recurrentConstraint),biasConstraint:Bo(this.biasConstraint),dropout:this.dropout,recurrentDropout:this.recurrentDropout};return Object.assign(Object.assign({},t),e)}}nh.className="SimpleRNNCell",e.serialization.registerClass(nh);class sh extends th{constructor(t){t.cell=new nh(t),super(t)}call(t,n){return e.tidy((()=>{null!=this.cell.dropoutMask&&(i.dispose(this.cell.dropoutMask),this.cell.dropoutMask=null),null!=this.cell.recurrentDropoutMask&&(i.dispose(this.cell.recurrentDropoutMask),this.cell.recurrentDropoutMask=null);const e=null==n?null:n.mask,s=null==n?null:n.training,r=null==n?null:n.initialState;return super.call(t,{mask:e,training:s,initialState:r})}))}static fromConfig(t,e){return new t(e)}}sh.className="SimpleRNN",e.serialization.registerClass(sh);class ih extends eh{constructor(t){if(super(t),this.DEFAULT_ACTIVATION="tanh",this.DEFAULT_RECURRENT_ACTIVATION="hardSigmoid",this.DEFAULT_KERNEL_INITIALIZER="glorotNormal",this.DEFAULT_RECURRENT_INITIALIZER="orthogonal",this.DEFAULT_BIAS_INITIALIZER="zeros",t.resetAfter)throw new o("GRUCell does not support reset_after parameter set to true.");this.units=t.units,z(this.units,"units"),this.activation=ku(void 0===t.activation?this.DEFAULT_ACTIVATION:t.activation),this.recurrentActivation=ku(void 0===t.recurrentActivation?this.DEFAULT_RECURRENT_ACTIVATION:t.recurrentActivation),this.useBias=null==t.useBias||t.useBias,this.kernelInitializer=Mt(t.kernelInitializer||this.DEFAULT_KERNEL_INITIALIZER),this.recurrentInitializer=Mt(t.recurrentInitializer||this.DEFAULT_RECURRENT_INITIALIZER),this.biasInitializer=Mt(t.biasInitializer||this.DEFAULT_BIAS_INITIALIZER),this.kernelRegularizer=zu(t.kernelRegularizer),this.recurrentRegularizer=zu(t.recurrentRegularizer),this.biasRegularizer=zu(t.biasRegularizer),this.kernelConstraint=Uo(t.kernelConstraint),this.recurrentConstraint=Uo(t.recurrentConstraint),this.biasConstraint=Uo(t.biasConstraint),this.dropout=Z([1,Y([0,null==t.dropout?0:t.dropout])]),this.recurrentDropout=Z([1,Y([0,null==t.recurrentDropout?0:t.recurrentDropout])]),this.dropoutFunc=t.dropoutFunc,this.implementation=t.implementation,this.stateSize=this.units,this.dropoutMask=null,this.recurrentDropoutMask=null}build(t){const e=(t=Wt(t))[t.length-1];this.kernel=this.addWeight("kernel",[e,3*this.units],null,this.kernelInitializer,this.kernelRegularizer,!0,this.kernelConstraint),this.recurrentKernel=this.addWeight("recurrent_kernel",[this.units,3*this.units],null,this.recurrentInitializer,this.recurrentRegularizer,!0,this.recurrentConstraint),this.useBias?this.bias=this.addWeight("bias",[3*this.units],null,this.biasInitializer,this.biasRegularizer,!0,this.biasConstraint):this.bias=null,this.built=!0}call(t,n){return e.tidy((()=>{if(2!==t.length)throw new o(`GRUCell expects 2 input Tensors (inputs, h, c), got ${t.length}.`);const e=null!=n.training&&n.training;let s=t[1];t=t[0],0i.onesLike(t),rate:this.dropout,training:e,count:3,dropoutFunc:this.dropoutFunc})),0i.onesLike(s),rate:this.recurrentDropout,training:e,count:3,dropoutFunc:this.dropoutFunc}));const r=this.dropoutMask,a=this.recurrentDropoutMask;let l,u,h;0{null!=this.cell.dropoutMask&&(i.dispose(this.cell.dropoutMask),this.cell.dropoutMask=null),null!=this.cell.recurrentDropoutMask&&(i.dispose(this.cell.recurrentDropoutMask),this.cell.recurrentDropoutMask=null);const e=null==n?null:n.mask,s=null==n?null:n.training,r=null==n?null:n.initialState;return super.call(t,{mask:e,training:s,initialState:r})}))}static fromConfig(t,e){return 0===e.implmentation&&(e.implementation=1),new t(e)}}rh.className="GRU",e.serialization.registerClass(rh);class ah extends eh{constructor(t){super(t),this.DEFAULT_ACTIVATION="tanh",this.DEFAULT_RECURRENT_ACTIVATION="hardSigmoid",this.DEFAULT_KERNEL_INITIALIZER="glorotNormal",this.DEFAULT_RECURRENT_INITIALIZER="orthogonal",this.DEFAULT_BIAS_INITIALIZER="zeros",this.units=t.units,z(this.units,"units"),this.activation=ku(void 0===t.activation?this.DEFAULT_ACTIVATION:t.activation),this.recurrentActivation=ku(void 0===t.recurrentActivation?this.DEFAULT_RECURRENT_ACTIVATION:t.recurrentActivation),this.useBias=null==t.useBias||t.useBias,this.kernelInitializer=Mt(t.kernelInitializer||this.DEFAULT_KERNEL_INITIALIZER),this.recurrentInitializer=Mt(t.recurrentInitializer||this.DEFAULT_RECURRENT_INITIALIZER),this.biasInitializer=Mt(t.biasInitializer||this.DEFAULT_BIAS_INITIALIZER),this.unitForgetBias=t.unitForgetBias,this.kernelRegularizer=zu(t.kernelRegularizer),this.recurrentRegularizer=zu(t.recurrentRegularizer),this.biasRegularizer=zu(t.biasRegularizer),this.kernelConstraint=Uo(t.kernelConstraint),this.recurrentConstraint=Uo(t.recurrentConstraint),this.biasConstraint=Uo(t.biasConstraint),this.dropout=Z([1,Y([0,null==t.dropout?0:t.dropout])]),this.recurrentDropout=Z([1,Y([0,null==t.recurrentDropout?0:t.recurrentDropout])]),this.dropoutFunc=t.dropoutFunc,this.implementation=t.implementation,this.stateSize=[this.units,this.units],this.dropoutMask=null,this.recurrentDropoutMask=null}build(t){var e;const n=(t=Wt(t))[t.length-1];let s;if(this.kernel=this.addWeight("kernel",[n,4*this.units],null,this.kernelInitializer,this.kernelRegularizer,!0,this.kernelConstraint),this.recurrentKernel=this.addWeight("recurrent_kernel",[this.units,4*this.units],null,this.recurrentInitializer,this.recurrentRegularizer,!0,this.recurrentConstraint),this.useBias){if(this.unitForgetBias){const t=this.biasInitializer,n=this.units;s=new((e=class extends wt{apply(e,s){const i=t.apply([n]),r=(new vt).apply([n]),a=t.apply([2*n]);return ot(ot(i,r),a)}}).className="CustomInit",e)}else s=this.biasInitializer;this.bias=this.addWeight("bias",[4*this.units],null,s,this.biasRegularizer,!0,this.biasConstraint)}else this.bias=null;this.built=!0}call(t,n){return e.tidy((()=>{const e=null!=n.training&&n.training;if(3!==t.length)throw new o(`LSTMCell expects 3 input Tensors (inputs, h, c), got ${t.length}.`);let s=t[1];const r=t[2];t=t[0],0i.onesLike(t),rate:this.dropout,training:e,count:4,dropoutFunc:this.dropoutFunc})),0i.onesLike(s),rate:this.recurrentDropout,training:e,count:4,dropoutFunc:this.dropoutFunc}));const a=this.dropoutMask,l=this.recurrentDropoutMask;let u,h,c,p;0{null!=this.cell.dropoutMask&&(i.dispose(this.cell.dropoutMask),this.cell.dropoutMask=null),null!=this.cell.recurrentDropoutMask&&(i.dispose(this.cell.recurrentDropoutMask),this.cell.recurrentDropoutMask=null);const e=null==n?null:n.mask,s=null==n?null:n.training,r=null==n?null:n.initialState;return super.call(t,{mask:e,training:s,initialState:r})}))}static fromConfig(t,e){return 0===e.implmentation&&(e.implementation=1),new t(e)}}oh.className="LSTM",e.serialization.registerClass(oh);class lh extends eh{constructor(t){super(t),this.cells=t.cells}get stateSize(){const t=[];for(const e of this.cells.slice().reverse())Array.isArray(e.stateSize)?t.push(...e.stateSize):t.push(e.stateSize);return t}call(t,n){return e.tidy((()=>{let e=t.slice(1);const s=[];for(const t of this.cells.slice().reverse())Array.isArray(t.stateSize)?s.push(e.splice(0,t.stateSize.length)):s.push(e.splice(0,1));s.reverse();const i=[];let r;for(let a=0;a{q(`RNNCell_${s}`,(()=>{n.build(t),e=Array.isArray(n.stateSize)?n.stateSize[0]:n.stateSize,t=[t[0],e]}))})),this.built=!0}getConfig(){const t=super.getConfig(),e={cells:this.cells.map((t=>({className:t.getClassName(),config:t.getConfig()})))};return Object.assign(Object.assign({},t),e)}static fromConfig(t,e,n={}){const s=[];for(const t of e.cells)s.push(el(t,n));return new t({cells:s})}get trainableWeights(){if(!this.trainable)return[];const t=[];for(const e of this.cells)t.push(...e.trainableWeights);return t}get nonTrainableWeights(){const t=[];for(const e of this.cells)t.push(...e.nonTrainableWeights);if(!this.trainable){const e=[];for(const t of this.cells)e.push(...t.trainableWeights);return e.concat(t)}return t}getWeights(){const t=[];for(const e of this.cells)t.push(...e.weights);return Kt(t)}setWeights(t){const e=[];for(const n of this.cells){const s=n.weights.length,i=t.splice(s);for(let t=0;tnull!=a?a(e(),n):gt(e(),n),l=()=>mt(o,e,s);if(!r||r<=1)return i.keep(l().clone());return Array(r).fill(void 0).map(l).map((t=>i.keep(t.clone())))}lh.className="StackedRNNCells",e.serialization.registerClass(lh);var hh=function(t,e){var n={};for(var s in t)Object.prototype.hasOwnProperty.call(t,s)&&e.indexOf(s)<0&&(n[s]=t[s]);if(null!=t&&"function"==typeof Object.getOwnPropertySymbols){var i=0;for(s=Object.getOwnPropertySymbols(t);i{if(null!=this.cell.dropoutMask&&(i.dispose(this.cell.dropoutMask),this.cell.dropoutMask=null),null!=this.cell.recurrentDropoutMask&&(i.dispose(this.cell.recurrentDropoutMask),this.cell.recurrentDropoutMask=null),e&&e.constants)throw new o("ConvRNN2D cell does not support constants");const n=null==e?null:e.mask,s=null==e?null:e.training,r=null==e?null:e.initialState;return super.call(t,{mask:n,training:s,initialState:r})}))}computeOutputShape(t){let e=this.computeSingleOutputShape(t);return this.returnSequences||(e=[e[0],...e.slice(2)]),this.returnState&&(e=[e,...Array(2).fill([t[0],...e.slice(-3)])]),e}getInitialState(t){return i.tidy((()=>{const{stateSize:e}=this.cell,n=t.shape,s=this.computeSingleOutputShape(n),r=[s[0],...s.slice(2)],a=i.zeros(r);return Array.isArray(e)?Array(e.length).fill(a):[a]}))}resetStates(t,n=!1){i.tidy((()=>{if(!this.stateful)throw new r("Cannot call resetStates() on an RNN Layer that is not stateful.");const s=this.inputSpec[0].shape,a=this.computeSingleOutputShape(s),l=[a[0],...a.slice(2)];if(null==s[0])throw new o("If an RNN is stateful, it needs to know its batch size. Specify the batch size of your input tensors: \n- If using a Sequential model, specify the batch size by passing a `batchInputShape` option to your first layer.\n- If using the functional API, specify the batch size by passing a `batchShape` option to your Input layer.");if(null==this.getStates())Array.isArray(this.cell.stateSize)?this.states_=this.cell.stateSize.map((()=>i.zeros(l))):this.states_=[i.zeros(l)];else if(null==t)i.dispose(this.states_),null!=this.keptStates&&(i.dispose(this.keptStates),this.keptStates=[]),Array.isArray(this.cell.stateSize)?this.states_=this.cell.stateSize.map((()=>i.zeros(l))):this.states_[0]=i.zeros(l);else{if(Array.isArray(t)||(t=[t]),t.length!==this.states_.length)throw new o(`Layer ${this.name} expects ${this.states_.length} state(s), but it received ${t.length} state value(s). Input received: ${t}`);n?this.keptStates.push(this.states_.slice()):i.dispose(this.states_);for(let n=0;ni.keep(t.clone())))}))}computeSingleOutputShape(t){const{dataFormat:e,filters:n,kernelSize:s,padding:i,strides:r,dilationRate:a}=this.cell,o="channelsFirst"===e,l=t[o?3:2],u=t[o?4:3],h=_u(l,s[0],i,r[0],a[0]),c=_u(u,s[1],i,r[1],a[1]);return[...t.slice(0,2),...o?[n,h,c]:[h,c,n]]}}ch.className="ConvRNN2D";class ph extends ah{constructor(t){const{filters:e,kernelSize:n,strides:s,padding:i,dataFormat:r,dilationRate:a}=t;super(Object.assign(Object.assign({},t),{units:e})),this.filters=e,z(this.filters,"filters"),this.kernelSize=Lu(n,2,"kernelSize"),this.kernelSize.forEach((t=>z(t,"kernelSize"))),this.strides=Lu(s||1,2,"strides"),this.strides.forEach((t=>z(t,"strides"))),this.padding=i||"valid",U(this.padding),this.dataFormat=r||"channelsLast",P(this.dataFormat),this.dilationRate=Lu(a||1,2,"dilationRate"),this.dilationRate.forEach((t=>z(t,"dilationRate")))}build(t){var e;t=Wt(t);const n="channelsFirst"===this.dataFormat?1:t.length-1;if(null==t[n])throw new o(`The channel dimension of the input should be defined. Found ${t[n]}`);const s=t[n],r=this.kernelSize.concat([s,4*this.filters]);this.kernel=this.addWeight("kernel",r,null,this.kernelInitializer,this.kernelRegularizer,!0,this.kernelConstraint);const a=this.kernelSize.concat([this.filters,4*this.filters]);if(this.recurrentKernel=this.addWeight("recurrent_kernel",a,null,this.recurrentInitializer,this.recurrentRegularizer,!0,this.recurrentConstraint),this.useBias){let t;if(this.unitForgetBias){const n=this.biasInitializer,s=this.filters;t=new((e=class extends wt{apply(t,e){return at([n.apply([s]),i.ones([s]),n.apply([2*s])])}}).className="CustomInit",e)}else t=this.biasInitializer;this.bias=this.addWeight("bias",[4*this.filters],null,t,this.biasRegularizer,!0,this.biasConstraint)}this.built=!0}call(t,e){return i.tidy((()=>{if(3!==t.length)throw new o(`ConvLSTM2DCell expects 3 input Tensors (inputs, h, c), got ${t.length}.`);const n=e.training||!1,s=t[0],r=t[1],a=t[2];0i.onesLike(s),rate:this.dropout,training:n,count:4,dropoutFunc:this.dropoutFunc}));const l=this.dropoutMask,u=(t,e,n)=>e&&e[n]?i.mul(e[n],t):t;let h=u(s,l,0),c=u(s,l,1),p=u(s,l,2),d=u(s,l,3);0i.onesLike(r),rate:this.recurrentDropout,training:n,count:4,dropoutFunc:this.dropoutFunc}));const f=this.recurrentDropoutMask;let g=u(r,f,0),m=u(r,f,1),y=u(r,f,2),b=u(r,f,3);const[w,k,v,S]=i.split(this.kernel.read(),4,3),[x,N,I,A]=this.useBias?i.split(this.bias.read(),4):[null,null,null,null];h=this.inputConv(h,w,x,this.padding),c=this.inputConv(c,k,N,this.padding),p=this.inputConv(p,v,I,this.padding),d=this.inputConv(d,S,A,this.padding);const[z,E,T,C]=i.split(this.recurrentKernel.read(),4,3);g=this.recurrentConv(g,z),m=this.recurrentConv(m,E),y=this.recurrentConv(y,T),b=this.recurrentConv(b,C);const $=this.recurrentActivation.apply(i.add(h,g)),F=this.recurrentActivation.apply(i.add(c,m)),D=i.add(i.mul(F,a),i.mul($,this.activation.apply(i.add(p,y)))),L=i.mul(this.recurrentActivation.apply(i.add(d,b)),this.activation.apply(D));return[L,L,D]}))}getConfig(){const t=super.getConfig(),e=hh(t,["units"]),n={filters:this.filters,kernelSize:this.kernelSize,padding:this.padding,dataFormat:this.dataFormat,dilationRate:this.dilationRate,strides:this.strides};return Object.assign(Object.assign({},e),n)}inputConv(t,e,n,s){const r=i.conv2d(t,e,this.strides,s||"valid","channelsFirst"===this.dataFormat?"NCHW":"NHWC",this.dilationRate);return n?ft(r,n,this.dataFormat):r}recurrentConv(t,e){return i.conv2d(t,e,1,"same","channelsFirst"===this.dataFormat?"NCHW":"NHWC")}}ph.className="ConvLSTM2DCell",i.serialization.registerClass(ph);class dh extends ch{constructor(t){const e=new ph(t);super(Object.assign(Object.assign({},t),{cell:e}))}static fromConfig(t,e){return new t(e)}}dh.className="ConvLSTM2D",i.serialization.registerClass(dh);class fh extends Qt{constructor(t){super(t),this.rate=Math.max(Math.min(t.rate,1),0),this.noiseShape=t.noiseShape,this.seed=t.seed,this.supportsMasking=!0}getNoiseShape(t){if(null==this.noiseShape)return this.noiseShape;const e=t.shape,n=[];for(let t=0;t{this.invokeCallHook(t,n);const e=Ut(t);if(0gt(e,this.rate,s,this.seed)),(()=>e),t)}return t}))}getConfig(){const t={rate:this.rate,noiseShape:this.noiseShape,seed:this.seed},e=super.getConfig();return Object.assign(t,e),t}dispose(){return super.dispose()}}fh.className="Dropout",e.serialization.registerClass(fh);class gh extends fh{constructor(t){super(t),this.inputSpec=[{ndim:3}]}getNoiseShape(t){const e=t.shape;return[e[0],1,e[2]]}}gh.className="SpatialDropout1D",e.serialization.registerClass(gh);class mh extends Qt{constructor(t){if(super(t),this.activation=null,this.useBias=!0,this.kernel=null,this.bias=null,this.DEFAULT_KERNEL_INITIALIZER="glorotNormal",this.DEFAULT_BIAS_INITIALIZER="zeros",null==t.batchInputShape&&null==t.inputShape&&null!=t.inputDim){let e=null;null!=t.batchSize&&(e=t.batchSize),this.batchInputShape=[e,t.inputDim]}this.units=t.units,z(this.units,"units"),this.activation=ku(t.activation),null!=t.useBias&&(this.useBias=t.useBias),this.kernelInitializer=Mt(t.kernelInitializer||this.DEFAULT_KERNEL_INITIALIZER),this.biasInitializer=Mt(t.biasInitializer||this.DEFAULT_BIAS_INITIALIZER),this.kernelConstraint=Uo(t.kernelConstraint),this.biasConstraint=Uo(t.biasConstraint),this.kernelRegularizer=zu(t.kernelRegularizer),this.biasRegularizer=zu(t.biasRegularizer),this.activityRegularizer=zu(t.activityRegularizer),this.supportsMasking=!0,this.inputSpec=[{minNDim:2}]}build(t){const e=(t=Wt(t))[t.length-1];null==this.kernel&&(this.kernel=this.addWeight("kernel",[e,this.units],null,this.kernelInitializer,this.kernelRegularizer,!0,this.kernelConstraint),this.useBias&&(this.bias=this.addWeight("bias",[this.units],null,this.biasInitializer,this.biasRegularizer,!0,this.biasConstraint))),this.inputSpec=[{minNDim:2,axes:{[-1]:e}}],this.built=!0}computeOutputShape(t){const e=(t=Wt(t)).slice();return e[e.length-1]=this.units,e}call(t,n){return e.tidy((()=>{this.invokeCallHook(t,n);const e=Ut(t),s=T(this.activation.getClassName());let i;return null!=s?i=ht(e,this.kernel.read(),s,this.bias?this.bias.read():null):(i=ht(e,this.kernel.read()),null!=this.bias&&(i=ft(i,this.bias.read())),null!=this.activation&&(i=this.activation.apply(i))),i}))}getConfig(){const t={units:this.units,activation:bu(this.activation),useBias:this.useBias,kernelInitializer:Ot(this.kernelInitializer),biasInitializer:Ot(this.biasInitializer),kernelRegularizer:Iu(this.kernelRegularizer),biasRegularizer:Iu(this.biasRegularizer),activityRegularizer:Iu(this.activityRegularizer),kernelConstraint:Bo(this.kernelConstraint),biasConstraint:Bo(this.biasConstraint)},e=super.getConfig();return Object.assign(t,e),t}}mh.className="Dense",e.serialization.registerClass(mh);class yh extends Qt{constructor(t){super(t=t||{}),this.inputSpec=[{minNDim:3}],this.dataFormat=t.dataFormat}computeOutputShape(t){t=Wt(t);for(const e of t.slice(1))if(null==e)throw new o(`The shape of the input to "Flatten" is not fully defined (got ${t.slice(1)}). Make sure to pass a complete "input_shape" or "batch_input_shape" argument to the first layer in your model.`);return[t[0],J(t,1)]}call(t,n){return e.tidy((()=>{this.invokeCallHook(t,n);let s=Ut(t);if("channelsFirst"===this.dataFormat&&s.rank>1){const t=[0];for(let e=2;e{this.invokeCallHook(t,n);const e=Ut(t);return this.activation.apply(e)}))}getConfig(){const t={activation:bu(this.activation)},e=super.getConfig();return Object.assign(t,e),t}}bh.className="Activation",e.serialization.registerClass(bh);class wh extends Qt{constructor(t){super(t),this.n=t.n,this.inputSpec=[{ndim:2}]}computeOutputShape(t){return[t[0],this.n,t[1]]}call(t,n){return e.tidy((()=>{return t=Ut(t),n=t,s=this.n,e.tidy((()=>{if(2!==n.shape.length)throw new o(`repeat() expects a rank-2 tensor, but received a rank-${n.shape.length} tensor.`);return lt(nt(n,1),[1,s,1])}));var n,s}))}getConfig(){const t={n:this.n},e=super.getConfig();return Object.assign(t,e),t}}wh.className="RepeatVector",e.serialization.registerClass(wh);class kh extends Qt{constructor(t){super(t),this.targetShape=t.targetShape;for(let t=0;t{this.invokeCallHook(t,n);const s=Ut(t),i=s.shape,r=i.slice(0,1).concat(this.fixUnknownDimension(i.slice(1),this.targetShape));return e.reshape(s,r)}))}getConfig(){const t={targetShape:this.targetShape},e=super.getConfig();return Object.assign(t,e),t}}kh.className="Reshape",e.serialization.registerClass(kh);class vh extends Qt{constructor(t){if(super(t),null==t.dims)throw new Error("Required configuration field `dims` is missing during Permute constructor call.");if(!Array.isArray(t.dims))throw new Error(`Permute constructor requires \`dims\` to be an Array, but received ${t.dims} instead.`);const n=X(1,t.dims.length+1);if(!e.util.arraysEqual(t.dims.slice().sort(),n))throw new Error("Invalid permutation `dims`: "+JSON.stringify(t.dims)+" `dims` must contain consecutive integers starting from 1.");this.dims=t.dims,this.dimsIncludingBatch=[0].concat(this.dims),this.inputSpec=[new Ht({ndim:this.dims.length+1})]}computeOutputShape(t){const e=(t=Wt(t)).slice();return this.dims.forEach(((n,s)=>{e[s+1]=t[n]})),e}call(t,n){return e.transpose(Ut(t),this.dimsIncludingBatch)}getConfig(){const t={dims:this.dims},e=super.getConfig();return Object.assign(t,e),t}}vh.className="Permute",e.serialization.registerClass(vh);class Sh extends Qt{constructor(t){super(null==t?{}:t),this.supportsMasking=!0,this.maskValue=null!=t?null==t.maskValue?0:t.maskValue:0}computeOutputShape(t){return t}getConfig(){const t=super.getConfig(),e={maskValue:this.maskValue};return Object.assign(e,t),e}computeMask(t,n){const s=Ut(t);return e.any(e.notEqual(s,this.maskValue),-1)}call(t,n){return e.tidy((()=>{this.invokeCallHook(t,n);const s=Ut(t),i=e.any(e.notEqual(s,this.maskValue),-1,!0);return e.mul(s,e.cast(i,s.dtype))}))}}Sh.className="Masking",e.serialization.registerClass(Sh);class xh extends Qt{constructor(t){if(super(t),this.embeddings=null,this.DEFAULT_EMBEDDINGS_INITIALIZER="randomUniform",null==t.batchInputShape&&null==t.inputShape){let e=null;null!=t.batchSize&&(e=t.batchSize),null==t.inputLength?this.batchInputShape=[e,null]:this.batchInputShape=[e].concat(g(t.inputLength))}this.inputDim=t.inputDim,z(this.inputDim,"inputDim"),this.outputDim=t.outputDim,z(this.outputDim,"outputDim"),this.embeddingsInitializer=Mt(t.embeddingsInitializer||this.DEFAULT_EMBEDDINGS_INITIALIZER),this.embeddingsRegularizer=zu(t.embeddingsRegularizer),this.activityRegularizer=zu(t.activityRegularizer),this.embeddingsConstraint=Uo(t.embeddingsConstraint),this.maskZero=t.maskZero,this.supportsMasking=t.maskZero,this.inputLength=t.inputLength}build(t){this.embeddings=this.addWeight("embeddings",[this.inputDim,this.outputDim],this.dtype,this.embeddingsInitializer,this.embeddingsRegularizer,!0,this.embeddingsConstraint),this.built=!0}warnOnIncompatibleInputShape(t){}computeMask(t,n){return e.tidy((()=>this.maskZero?(t=Ut(t),e.notEqual(t,e.zerosLike(t))):null))}computeOutputShape(t){if(t=Wt(t),null==this.inputLength)return[...t,this.outputDim];const e=g(this.inputLength);if(e.length!==t.length-1)throw new o(`"inputLength" is ${this.inputLength}, but received input shape has shape ${t}`);{let n=0;for(let s=0;s{this.invokeCallHook(t,n);let s=Ut(t);"int32"!==s.dtype&&(s=et(s,"int32"));const i=ct(this.embeddings.read(),e.reshape(s,[s.size]));return e.reshape(i,Wt(this.computeOutputShape(s.shape)))}))}getConfig(){const t={inputDim:this.inputDim,outputDim:this.outputDim,embeddingsInitializer:Ot(this.embeddingsInitializer),embeddingsRegularizer:Iu(this.embeddingsRegularizer),activityRegularizer:Iu(this.activityRegularizer),embeddingsConstraint:Bo(this.embeddingsConstraint),maskZero:this.maskZero,inputLength:this.inputLength},e=super.getConfig();return Object.assign(t,e),t}}xh.className="Embedding",e.serialization.registerClass(xh);class Nh extends Qt{constructor(t){super(t||{}),this.supportsMasking=!0}mergeFunction(t){throw new l}computeElementwiseOpOutputShape(t,e){if(null==t||null==e)return null;if(t.length1)throw new o(`Can not merge tensors with different batch sizes. Got tensors with shapes: ${JSON.stringify(t)}.`);let n=null==t[0]?null:t[0].slice(1);for(let e=1;et.length));-1===t.indexOf(null)&&1===x(s).length?this.reshapeRequired=!1:this.reshapeRequired=!0}call(t,n){return e.tidy((()=>{if(this.reshapeRequired){const e=[],n=t.map((t=>t.rank));if(-1===n.indexOf(null)){const s=Y(n);for(let n of t){const t=n.rank;for(let e=0;e1){const r=X(1,t).concat([0]);e.push(i.transpose(s,r)),n=!0}else e.push(s)}let s=this.mergeFunction(e);const r=s.rank;if(n)if(null==r){const t=s.shape,e=t[t.length-1],n=[e].concat(t.slice(0,t.length-1));s=i.reshape(i.transpose(i.reshape(s,[-1,e]),[1,0]),n)}else if(r>1){const t=[r-1].concat(X(0,r-1));s=i.transpose(s,t)}return s}}return this.mergeFunction(t)}))}computeOutputShape(t){let e;e=null==t[0]?null:t[0].slice(1);for(let n=1;n{if(null==e)return null;if(!Array.isArray(e))throw new o("`mask` should be an Array");if(!Array.isArray(t))throw new o("`inputs` should be an Array");if(e.length!==t.length)throw new o(`The Array 'inputs' and 'mask' are expected to have the same length, but have different lengths (${t.length} vs ${e.length})`);if(e.every((t=>null==t)))return null;let n=(e=e.map((t=>null==t?t:i.expandDims(t,0))))[0];for(let t=1;t{let e=t[0].clone();for(let n=1;n{let e=t[0].clone();for(let n=1;n{let e=t[0].clone();for(let n=1;n{let e=t[0];for(let n=1;n{let e=t[0];for(let n=1;n1)throw new o("A `Concatenate` layer requires inputs with matching shapes except for the concat axis. Got input shapes: "+JSON.stringify(t))}mergeFunction(t){return e.tidy((()=>at(t,this.axis)))}computeOutputShape(t){if(!Array.isArray(t)||!Array.isArray(t[0]))throw new o("A `Concatenate` layer should be called on a list of inputs.");const e=t,n=e[0].slice(),s=this.axis<0?n.length+this.axis:this.axis;for(const t of e.slice(1)){if(null==n[s]||null==t[s]){n[s]=null;break}n[s]+=t[s]}return n}computeMask(t,e){if(null==e)return null;if(!Array.isArray(e))throw new o("`mask` should be an array for Concatenate");if(!Array.isArray(t))throw new o("`inputs` should be an array for Concatenate");if(e.length!==t.length)throw new o(`Mismatch in the length of mask (${e.length}) and the legnth of inputs (${t.length})`);return i.tidy((()=>{let n=!0;if(e.forEach((t=>{null==t||(n=!1)})),n)return null;const s=[];for(let n=0;n"A `Dot` layer should be called on a list of exactly 2 inputs."));const e=t[0],n=t[1];if(e.length>3||n.length>3)throw new l("Dot layer does not support tensors of 4D or higher rank yet.");const s=this.interpretAxes(e,n);if(e[s[0]]!==n[s[1]])throw new o(`Dimension incompatibility: ${e[s[0]]} !== ${n[s[1]]}`)}mergeFunction(t){if(2!==t.length)throw new o(`A \`Dot\` layer must be called on exactly 2 inputs, but received ${t.length} input(s).`);let e,n=t[0],s=t[1];return e=Array.isArray(this.axes)?this.axes.map(((e,n)=>$h(e,t[n].shape.length))):[$h(this.axes,n.shape.length),$h(this.axes,s.shape.length)],this.normalize&&(n=nl(n,e[0]),s=nl(s,e[1])),function(t,e,n){if(t.shape.length>3||e.shape.length>3)throw new l("batchDot is not implemented for tensors of 4D or higher rank yet");if(i.util.assert(t.shape.length>=2,(()=>`batchDot requires the rank of x to be >= 2, but got ${t.shape.length}`)),i.util.assert(t.shape.length>=2,(()=>`batchDot requires the rank of y to be >= 2, but got ${e.shape.length}`)),"number"==typeof n&&(n=[n,n]),"complex64"===t.dtype||"complex64"===e.dtype)throw new l("batchDot is not implemented for complex64-type Tensors yet.");const s=t.shape.length,r=e.shape.length;null==n&&(n=[s-1,r-2]);const a=n;return i.tidy((()=>{let n,o;if(s>r){n=s-r;const t=[];for(let e=0;es){n=r-s;const e=[];for(let t=0;t0){let t;t=s>r?s+r-3:s-1;const e=[];for(let s=t;s"A `Dot` layer should be called on a list of exactly 2 inputs."));const e=t[0].slice(),n=t[1].slice();if(e.length>3||n.length>3)throw new l("Dot layer does not support tensors of 4D or higher rank yet.");const s=this.interpretAxes(e,n);e.splice(s[0],1),n.splice(s[1],1),n.splice(0,1);const r=e.concat(n);return 1===r.length&&r.push(1),r}computeMask(t,e){return null}getConfig(){const t={axes:this.axes,normalize:this.normalize},e=super.getConfig();return Object.assign(t,e),t}}Fh.className="Dot",e.serialization.registerClass(Fh);class Dh extends Qt{constructor(t){super(t),this.supportsMasking=!0,this.stddev=t.stddev}computeOutputShape(t){return t}getConfig(){const t=super.getConfig(),e={stddev:this.stddev};return Object.assign(e,t),e}call(t,n){return e.tidy((()=>{this.invokeCallHook(t,n);const s=Ut(t);return mt((()=>e.add(ut(s.shape,0,this.stddev),s)),(()=>s),n.training||!1)}))}}Dh.className="GaussianNoise",e.serialization.registerClass(Dh);class Lh extends Qt{constructor(t){super(t),this.supportsMasking=!0,this.rate=t.rate}computeOutputShape(t){return t}getConfig(){const t=super.getConfig(),e={rate:this.rate};return Object.assign(e,t),e}call(t,n){return e.tidy((()=>{this.invokeCallHook(t,n);const s=Ut(t);if(this.rate>0&&this.rate<1){return mt((()=>{const t=Math.sqrt(this.rate/(1-this.rate));return e.mul(s,ut(s.shape,1,t))}),(()=>s),n.training||!1)}return s}))}}Lh.className="GaussianDropout",e.serialization.registerClass(Lh);class _h extends Qt{constructor(t){super(t),this.supportsMasking=!0,this.rate=t.rate,this.noiseShape=t.noiseShape}_getNoiseShape(t){return this.noiseShape||Ut(t).shape}computeOutputShape(t){return t}getConfig(){const t=super.getConfig(),e={rate:this.rate};return Object.assign(e,t),e}call(t,n){return e.tidy((()=>{if(this.rate<1&&this.rate>0){const s=this._getNoiseShape(t),i=()=>{const n=Ut(t),i=-1.7580993408473766;let r=e.greaterEqual(e.randomUniform(s),this.rate);r=et(r,"float32");const a=((1-this.rate)*(1+this.rate*i**2))**-.5,o=-a*i*this.rate,l=e.add(e.mul(n,r),e.mul(e.add(r,-1),i));return e.add(e.mul(l,a),o)};return mt(i,(()=>Ut(t)),n.training||!1)}return t}))}}function Rh(t,e,n,s,r,a=.001){let o;if(2===t.rank)o=i.batchNorm2d(t,e,n,s,r,a);else if(3===t.rank)o=i.batchNorm3d(t,e,n,s,r,a);else{if(4!==t.rank)throw new l(`batchNormalization is not implemented for array of rank ${t.rank} yet`);o=i.batchNorm4d(t,e,n,s,r,a)}return o}function Oh(t,n,s,r,a=.001){return e.util.arraysEqual(r.slice().sort(),X(0,t.rank-1))?function(t,n,s,r,a=.001){return e.tidy((()=>{const e=i.moments(t,r),o=e.mean,l=e.variance;return[Rh(t,o,l,s,n,a),o,l]}))}(t,n,s,r,a):function(t,n,s,r,a=.001){return e.tidy((()=>{const o=i.moments(t,r),l=o.mean,u=o.variance,h=[];for(const e of X(0,t.rank))-1!==r.indexOf(e)?h.push(1):h.push(t.shape[e]);const c=e.reshape(l,h),p=e.reshape(u,h),d=null==n?null:e.reshape(n,h),f=null==s?null:e.reshape(s,h);return[Rh(t,c,p,f,d,a),l,u]}))}(t,n,s,r,a)}_h.className="AlphaDropout",e.serialization.registerClass(_h);class Mh extends Qt{constructor(t){null==t&&(t={}),super(t),this.supportsMasking=!0,this.axis=null==t.axis?-1:t.axis,this.momentum=null==t.momentum?.99:t.momentum,this.epsilon=null==t.epsilon?.001:t.epsilon,this.center=null==t.center||t.center,this.scale=null==t.scale||t.scale,this.betaInitializer=Mt(t.betaInitializer||"zeros"),this.gammaInitializer=Mt(t.gammaInitializer||"ones"),this.movingMeanInitializer=Mt(t.movingMeanInitializer||"zeros"),this.movingVarianceInitializer=Mt(t.movingVarianceInitializer||"ones"),this.betaConstraint=Uo(t.betaConstraint),this.gammaConstraint=Uo(t.gammaConstraint),this.betaRegularizer=zu(t.betaRegularizer),this.gammaRegularizer=zu(t.gammaRegularizer)}build(t){t=Wt(t);const e=this.axis>=0?this.axis:this.axis+t.length,n=t[e];if(null==n)throw new o(`Axis ${e} of input tensor should have a defined dimension but the layer received an input with shape ${JSON.stringify(t)}.`);this.inputSpec=[new Ht({ndim:t.length,axes:{[e]:n}})];const s=[n];this.scale&&(this.gamma=this.addWeight("gamma",s,null,this.gammaInitializer,this.gammaRegularizer,!0,this.gammaConstraint)),this.center&&(this.beta=this.addWeight("beta",s,null,this.betaInitializer,this.betaRegularizer,!0,this.betaConstraint)),this.movingMean=this.addWeight("moving_mean",s,null,this.movingMeanInitializer,null,!1),this.movingVariance=this.addWeight("moving_variance",s,null,this.movingVarianceInitializer,null,!1),this.built=!0}call(t,n){return e.tidy((()=>{const s=null!=n.training&&n.training,r=Ut(t),a=r.shape,o=a.length,l=X(0,o),u=this.axis>=0?this.axis:this.axis+o;l.splice(u,1);const h=c(1,o);h[u]=a[u];const p=l.slice();p.sort();const d=!e.util.arraysEqual(p,X(0,o).slice(0,o-1));if(!s)return(()=>{if(d){const t=e.reshape(this.movingMean.read(),h),n=e.reshape(this.movingVariance.read(),h),s=this.center?e.reshape(this.beta.read(),h):null,i=this.scale?e.reshape(this.gamma.read(),h):null;return Rh(r,t,n,s,i,this.epsilon)}return Rh(r,this.movingMean.read(),this.movingVariance.read(),null==this.beta?null:this.beta.read(),null==this.gamma?null:this.gamma.read(),this.epsilon)})();const[f,g,m]=Oh(r,this.gamma.read(),this.beta.read(),l,this.epsilon),y=(t,e,n)=>{i.tidy((()=>{const s=1-n,r=t.read(),a=i.mul(i.sub(r,e),s);t.write(i.sub(r,a))}))};return(()=>{y(this.movingMean,g,this.momentum),y(this.movingVariance,m,this.momentum)})(),f}))}getConfig(){const t={axis:this.axis,momentum:this.momentum,epsilon:this.epsilon,center:this.center,scale:this.scale,betaInitializer:Ot(this.betaInitializer),gammaInitializer:Ot(this.gammaInitializer),movingMeanInitializer:Ot(this.movingMeanInitializer),movingVarianceInitializer:Ot(this.movingVarianceInitializer),betaRegularizer:Iu(this.betaRegularizer),gammaRegularizer:Iu(this.gammaRegularizer),betaConstraint:Bo(this.betaConstraint),gammaConstraint:Bo(this.gammaConstraint)},e=super.getConfig();return Object.assign(t,e),t}}Mh.className="BatchNormalization",e.serialization.registerClass(Mh);class Bh extends Qt{constructor(t){if(null==t&&(t={}),super(t),this.axis=null==t.axis?-1:t.axis,"number"==typeof this.axis){if(!Number.isInteger(this.axis))throw new Error(`Expected axis to be an integer, but received ${this.axis}`)}else{if(!Array.isArray(this.axis))throw new Error(`Expected axis to be an integer or an array of integers, but received ${JSON.stringify(this.axis)}`);for(const t of this.axis)if(!Number.isInteger(t))throw new Error(`Expected axis to be an array of integers, but received ${JSON.stringify(this.axis)}`)}this.epsilon=null==t.epsilon?.001:t.epsilon,this.center=null==t.center||t.center,this.scale=null==t.scale||t.scale,this.betaInitializer=Mt(t.betaInitializer||"zeros"),this.gammaInitializer=Mt(t.gammaInitializer||"ones"),this.betaRegularizer=zu(t.betaRegularizer),this.gammaRegularizer=zu(t.gammaRegularizer),this.supportsMasking=!0}build(t){const e=(t=Wt(t)).length;"number"==typeof this.axis&&(this.axis=[this.axis]);for(let t=0;t=e)throw new Error(`Invalid axis: ${t}`);if(this.axis.length!==x(this.axis).length)throw new Error(`Found duplicate axes in: ${this.axis}`);const n=this.axis.map((e=>t[e]));this.scale?this.gamma=this.addWeight("gamma",n,"float32",this.gammaInitializer,this.gammaRegularizer,true):this.gamma=null,this.center?this.beta=this.addWeight("beta",n,"float32",this.betaInitializer,this.betaRegularizer,true):this.beta=null,this.built=!0}call(t,n){const s=Ut(t),r=s.shape,a=r.length;return e.tidy((()=>{let{mean:t,variance:n}=e.moments(s,this.axis,!0);const o=c(1,a);for(const t of this.axis)o[t]=r[t];const l=t=>null!=t&&t.shape.length!==a?i.reshape(t,o):t;let u=this.scale?l(this.gamma.read()):null,h=this.center?l(this.beta.read()):null;const p=[],d=[];for(let t=0;t=0?t[2]+this.padding[0][0]+this.padding[0][1]:null,n=null!=t[3]&&t[3]>=0?t[3]+this.padding[1][0]+this.padding[1][1]:null,[t[0],t[1],e,n]):(e=null!=t[1]&&t[1]>=0?t[1]+this.padding[0][0]+this.padding[0][1]:null,n=null!=t[2]&&t[2]>=0?t[2]+this.padding[1][0]+this.padding[1][1]:null,[t[0],e,n,t[3]])}call(t,n){return e.tidy((()=>{return n=Ut(t),s=this.padding,r=this.dataFormat,e.tidy((()=>{if(4!==n.rank)throw new o(`temporalPadding expects input tensor to be 4-D, but received a ${n.rank}-D tensor.`);if(null==s&&(s=[[1,1],[1,1]]),2!==s.length||2!==s[0].length||2!==s[1].length)throw new o("spatial2dPadding expects `padding` to be an Array of two Arrays, each of which is an Array of two integers.");if(null==r&&(r="channelsLast"),"channelsLast"!==r&&"channelsFirst"!==r)throw new o(`Unknown data format: ${r}. Supported data formats are 'channelsLast' and 'channelsFirst.`);let t;return t="channelsFirst"===r?[[0,0],[0,0],s[0],s[1]]:[[0,0],s[0],s[1],[0,0]],i.pad(n,t)}));var n,s,r}))}getConfig(){const t={padding:this.padding,dataFormat:this.dataFormat},e=super.getConfig();return Object.assign(t,e),t}}function Uh(t,n,s,r,a,o){return e.tidy((()=>{let e;P(a),W(o),U(r),null==s&&(s=[1,1]),null==r&&(r="valid"),null==a&&(a="channelsLast"),null==o&&(o="max"),t=Ou(t,a);const l="same"===r?"same":"valid";return e="max"===o?i.maxPool(t,n,s,l):i.avgPool(t,n,s,l),"channelsFirst"===a&&(e=i.transpose(e,[0,3,1,2])),e}))}function Wh(t,n,s,r,a,o){return e.tidy((()=>{let e;P(a),W(o),U(r),null==s&&(s=[1,1,1]),null==r&&(r="valid"),null==a&&(a="channelsLast"),null==o&&(o="max"),t=Mu(t,a);const l="same"===r?"same":"valid";return e="max"===o?i.maxPool3d(t,n,s,l):i.avgPool3d(t,n,s,l),"channelsFirst"===a&&(e=i.transpose(e,[0,4,1,2,3])),e}))}Ph.className="ZeroPadding2D",e.serialization.registerClass(Ph);class jh extends Qt{constructor(t){if(null==t.poolSize&&(t.poolSize=2),super(t),"number"==typeof t.poolSize)this.poolSize=[t.poolSize];else{if(!Array.isArray(t.poolSize)||1!==t.poolSize.length||"number"!=typeof t.poolSize[0])throw new o(`poolSize for 1D convolutional layer must be a number or an Array of a single number, but received ${JSON.stringify(t.poolSize)}`);this.poolSize=t.poolSize}if(z(this.poolSize,"poolSize"),null==t.strides)this.strides=this.poolSize;else if("number"==typeof t.strides)this.strides=[t.strides];else{if(!Array.isArray(t.strides)||1!==t.strides.length||"number"!=typeof t.strides[0])throw new o(`strides for 1D convolutional layer must be a number or an Array of a single number, but received ${JSON.stringify(t.strides)}`);this.strides=t.strides}z(this.strides,"strides"),this.padding=null==t.padding?"valid":t.padding,U(this.padding),this.inputSpec=[new Ht({ndim:3})]}computeOutputShape(t){const e=_u((t=Wt(t))[1],this.poolSize[0],this.padding,this.strides[0]);return[t[0],e,t[2]]}call(t,n){return e.tidy((()=>{this.invokeCallHook(t,n),t=nt(Ut(t),2);const e=this.poolingFunction(Ut(t),[this.poolSize[0],1],[this.strides[0],1],this.padding,"channelsLast");return i.squeeze(e,[2])}))}getConfig(){const t={poolSize:this.poolSize,padding:this.padding,strides:this.strides},e=super.getConfig();return Object.assign(t,e),t}}class qh extends jh{constructor(t){super(t)}poolingFunction(t,e,n,s,i){return P(i),U(s),Uh(t,e,n,s,i,"max")}}qh.className="MaxPooling1D",e.serialization.registerClass(qh);class Vh extends jh{constructor(t){super(t)}poolingFunction(t,e,n,s,i){return P(i),U(s),Uh(t,e,n,s,i,"avg")}}Vh.className="AveragePooling1D",e.serialization.registerClass(Vh);class Kh extends Qt{constructor(t){if(null==t.poolSize&&(t.poolSize=[2,2]),super(t),this.poolSize=Array.isArray(t.poolSize)?t.poolSize:[t.poolSize,t.poolSize],null==t.strides)this.strides=this.poolSize;else if(Array.isArray(t.strides)){if(2!==t.strides.length)throw new o(`If the strides property of a 2D pooling layer is an Array, it is expected to have a length of 2, but received length ${t.strides.length}.`);this.strides=t.strides}else this.strides=[t.strides,t.strides];z(this.poolSize,"poolSize"),z(this.strides,"strides"),this.padding=null==t.padding?"valid":t.padding,this.dataFormat=null==t.dataFormat?"channelsLast":t.dataFormat,P(this.dataFormat),U(this.padding),this.inputSpec=[new Ht({ndim:4})]}computeOutputShape(t){t=Wt(t);let e="channelsFirst"===this.dataFormat?t[2]:t[1],n="channelsFirst"===this.dataFormat?t[3]:t[2];return e=_u(e,this.poolSize[0],this.padding,this.strides[0]),n=_u(n,this.poolSize[1],this.padding,this.strides[1]),"channelsFirst"===this.dataFormat?[t[0],t[1],e,n]:[t[0],e,n,t[3]]}call(t,n){return e.tidy((()=>(this.invokeCallHook(t,n),this.poolingFunction(Ut(t),this.poolSize,this.strides,this.padding,this.dataFormat))))}getConfig(){const t={poolSize:this.poolSize,padding:this.padding,strides:this.strides,dataFormat:this.dataFormat},e=super.getConfig();return Object.assign(t,e),t}}class Gh extends Kh{constructor(t){super(t)}poolingFunction(t,e,n,s,i){return P(i),U(s),Uh(t,e,n,s,i,"max")}}Gh.className="MaxPooling2D",e.serialization.registerClass(Gh);class Hh extends Kh{constructor(t){super(t)}poolingFunction(t,e,n,s,i){return P(i),U(s),Uh(t,e,n,s,i,"avg")}}Hh.className="AveragePooling2D",e.serialization.registerClass(Hh);class Jh extends Qt{constructor(t){if(null==t.poolSize&&(t.poolSize=[2,2,2]),super(t),this.poolSize=Array.isArray(t.poolSize)?t.poolSize:[t.poolSize,t.poolSize,t.poolSize],null==t.strides)this.strides=this.poolSize;else if(Array.isArray(t.strides)){if(3!==t.strides.length)throw new o(`If the strides property of a 3D pooling layer is an Array, it is expected to have a length of 3, but received length ${t.strides.length}.`);this.strides=t.strides}else this.strides=[t.strides,t.strides,t.strides];z(this.poolSize,"poolSize"),z(this.strides,"strides"),this.padding=null==t.padding?"valid":t.padding,this.dataFormat=null==t.dataFormat?"channelsLast":t.dataFormat,P(this.dataFormat),U(this.padding),this.inputSpec=[new Ht({ndim:5})]}computeOutputShape(t){t=Wt(t);let e="channelsFirst"===this.dataFormat?t[2]:t[1],n="channelsFirst"===this.dataFormat?t[3]:t[2],s="channelsFirst"===this.dataFormat?t[4]:t[3];return e=_u(e,this.poolSize[0],this.padding,this.strides[0]),n=_u(n,this.poolSize[1],this.padding,this.strides[1]),s=_u(s,this.poolSize[2],this.padding,this.strides[2]),"channelsFirst"===this.dataFormat?[t[0],t[1],e,n,s]:[t[0],e,n,s,t[4]]}call(t,n){return e.tidy((()=>(this.invokeCallHook(t,n),this.poolingFunction(Ut(t),this.poolSize,this.strides,this.padding,this.dataFormat))))}getConfig(){const t={poolSize:this.poolSize,padding:this.padding,strides:this.strides,dataFormat:this.dataFormat},e=super.getConfig();return Object.assign(t,e),t}}class Zh extends Jh{constructor(t){super(t)}poolingFunction(t,e,n,s,i){return P(i),U(s),Wh(t,e,n,s,i,"max")}}Zh.className="MaxPooling3D",e.serialization.registerClass(Zh);class Yh extends Jh{constructor(t){super(t)}poolingFunction(t,e,n,s,i){return P(i),U(s),Wh(t,e,n,s,i,"avg")}}Yh.className="AveragePooling3D",e.serialization.registerClass(Yh);class Xh extends Qt{constructor(t){super(t),this.inputSpec=[new Ht({ndim:3})]}computeOutputShape(t){return[t[0],t[2]]}call(t,e){throw new l}}class Qh extends Xh{constructor(t){super(t||{})}call(t,n){return e.tidy((()=>{const e=Ut(t);return i.mean(e,1)}))}}Qh.className="GlobalAveragePooling1D",e.serialization.registerClass(Qh);class tc extends Xh{constructor(t){super(t||{})}call(t,n){return e.tidy((()=>{const e=Ut(t);return i.max(e,1)}))}}tc.className="GlobalMaxPooling1D",e.serialization.registerClass(tc);class ec extends Qt{constructor(t){super(t),this.dataFormat=null==t.dataFormat?"channelsLast":t.dataFormat,P(this.dataFormat),this.inputSpec=[new Ht({ndim:4})]}computeOutputShape(t){return"channelsLast"===this.dataFormat?[t[0],t[3]]:[t[0],t[1]]}call(t,e){throw new l}getConfig(){const t={dataFormat:this.dataFormat},e=super.getConfig();return Object.assign(t,e),t}}class nc extends ec{call(t,n){return e.tidy((()=>{const e=Ut(t);return"channelsLast"===this.dataFormat?i.mean(e,[1,2]):i.mean(e,[2,3])}))}}nc.className="GlobalAveragePooling2D",e.serialization.registerClass(nc);class sc extends ec{call(t,n){return e.tidy((()=>{const e=Ut(t);return"channelsLast"===this.dataFormat?i.max(e,[1,2]):i.max(e,[2,3])}))}}sc.className="GlobalMaxPooling2D",e.serialization.registerClass(sc);class ic extends Qt{constructor(t){super(t),this.layer=t.layer}build(t){this.built=!0}get trainable(){return null!=this.layer&&this.layer.trainable}set trainable(t){null!=this.layer&&(this.layer.trainable=t)}get trainableWeights(){return this.layer.trainableWeights}get nonTrainableWeights(){return this.layer.nonTrainableWeights}get updates(){return this.layer._updates}get losses(){return this.layer.losses}getWeights(){return this.layer.getWeights()}setWeights(t){this.layer.setWeights(t)}getConfig(){const t={layer:{className:this.layer.getClassName(),config:this.layer.getConfig()}},e=super.getConfig();return Object.assign(t,e),t}setFastWeightInitDuringBuild(t){super.setFastWeightInitDuringBuild(t),null!=this.layer&&this.layer.setFastWeightInitDuringBuild(t)}static fromConfig(t,e,n={}){const s=el(e.layer,n);delete e.layer;const i={layer:s};return Object.assign(i,e),new t(i)}}class rc extends ic{constructor(t){super(t),this.supportsMasking=!0}build(t){if((t=Wt(t)).length<3)throw new o(`TimeDistributed layer expects an input shape >= 3D, but received input shape ${JSON.stringify(t)}`);this.inputSpec=[{shape:t}];const e=[t[0]].concat(t.slice(2));this.layer.built||(this.layer.build(e),this.layer.built=!0),super.build(t)}computeOutputShape(t){const e=[(t=Wt(t))[0]].concat(t.slice(2)),n=this.layer.computeOutputShape(e),s=t[1];return[n[0],s].concat(n.slice(1))}call(t,n){return e.tidy((()=>Qu(((t,e)=>[Ut(this.layer.call(t,n)),[]]),t=Ut(t),[],!1,null,null,!1,!0)[1]))}}rc.className="TimeDistributed",e.serialization.registerClass(rc);class ac extends ic{constructor(t){super(t);const e=t.layer.getConfig(),n={};n.className=t.layer.getClassName(),n.config=e,this.forwardLayer=el(n),e.goBackwards=!0!==e.goBackwards;const s={};var i;if(s.className=t.layer.getClassName(),s.config=e,this.backwardLayer=el(s),this.forwardLayer.name="forward_"+this.forwardLayer.name,this.backwardLayer.name="backward_"+this.backwardLayer.name,this.mergeMode=void 0===t.mergeMode?"concat":t.mergeMode,i=this.mergeMode,I(M,"BidirectionalMergeMode",i),t.weights)throw new l("weights support is not implemented for Bidirectional layer yet.");this._stateful=t.layer.stateful,this.returnSequences=t.layer.returnSequences,this.returnState=t.layer.returnState,this.supportsMasking=!0,this._trainable=!0,this.inputSpec=t.layer.inputSpec,this.numConstants=null}get trainable(){return this._trainable}set trainable(t){this._trainable=t,null!=this.forwardLayer&&(this.forwardLayer.trainable=t),null!=this.backwardLayer&&(this.backwardLayer.trainable=t)}getWeights(){return this.forwardLayer.getWeights().concat(this.backwardLayer.getWeights())}setWeights(t){const e=t.length,n=Math.floor(e/2);this.forwardLayer.setWeights(t.slice(0,n)),this.backwardLayer.setWeights(t.slice(n))}computeOutputShape(t){let e,n,s,i=this.forwardLayer.computeOutputShape(t);return Array.isArray(i)&&Array.isArray(i[0])||(i=[i]),this.returnState?(s=i.slice(1),e=i[0]):e=i[0],"concat"===this.mergeMode?(e[e.length-1]*=2,n=[e]):n=null==this.mergeMode?[e,e.slice()]:[e],this.returnState?null==this.mergeMode?n.concat(s).concat(s.slice()):[e].concat(s).concat(s.slice()):f(n)}apply(t,e){let n=null==e?null:e.initialState,s=null==e?null:e.constants;null==e&&(e={});const i=Xu(t,n,s,this.numConstants);if(t=i.inputs,n=i.initialState,s=i.constants,Array.isArray(t)&&(n=t.slice(1),t=t[0]),(null==n||0===n.length)&&null==s)return super.apply(t,e);const r=[],a=[];if(null!=n){const t=n.length;if(t%2>0)throw new o("When passing `initialState` to a Bidrectional RNN, the state should be an Array containing the states of the underlying RNNs.");e.initialState=n,r.push(...n);const s=n.map((t=>new Ht({shape:t.shape})));this.forwardLayer.stateSpec=s.slice(0,t/2),this.backwardLayer.stateSpec=s.slice(t/2),a.push(...s)}if(null!=s)throw new l("Support for constants in Bidirectional layers is not implemented yet.");const u=r[0]instanceof Jt;for(const t of r)if(t instanceof Jt!==u)throw new o("The initial state of a Bidirectional layer cannot be specified as a mix of symbolic and non-symbolic tensors");if(u){const n=[t].concat(r),s=this.inputSpec.concat(a),i=this.inputSpec;this.inputSpec=s;const o=super.apply(n,e);return this.inputSpec=i,o}return super.apply(t,e)}call(t,n){return e.tidy((()=>{const e=n.initialState;let s,r,a,o;if(null==e)s=this.forwardLayer.call(t,n),r=this.backwardLayer.call(t,n);else{const i=e.slice(0,e.length/2),a=e.slice(e.length/2);s=this.forwardLayer.call(t,Object.assign(n,{initialState:i})),r=this.backwardLayer.call(t,Object.assign(n,{initialState:a}))}return this.returnState&&(Array.isArray(s)&&(a=s.slice(1).concat(r.slice(1))),s=s[0],r=r[0]),this.returnSequences&&(r=i.reverse(r,1)),"concat"===this.mergeMode?o=at([s,r]):"sum"===this.mergeMode?o=i.add(s,r):"ave"===this.mergeMode?o=i.mul(.5,i.add(s,r)):"mul"===this.mergeMode?o=i.mul(s,r):null==this.mergeMode&&(o=[s,r]),this.returnState?null==this.mergeMode?o.concat(a):[o].concat(a):o}))}resetStates(t){this.forwardLayer.resetStates(),this.backwardLayer.resetStates()}build(t){q(this.forwardLayer.name,(()=>{this.forwardLayer.build(t)})),q(this.backwardLayer.name,(()=>{this.backwardLayer.build(t)})),this.built=!0}computeMask(t,e){let n;if(Array.isArray(e)&&(e=e[0]),n=this.returnSequences?null==this.mergeMode?[e,e]:e:null==this.mergeMode?[null,null]:null,this.returnState){const t=this.forwardLayer.states.map((t=>null));return Array.isArray(n)?n.concat(t).concat(t):[n].concat(t).concat(t)}return n}get trainableWeights(){return this.forwardLayer.trainableWeights.concat(this.backwardLayer.trainableWeights)}get nonTrainableWeights(){return this.forwardLayer.nonTrainableWeights.concat(this.backwardLayer.nonTrainableWeights)}setFastWeightInitDuringBuild(t){super.setFastWeightInitDuringBuild(t),null!=this.forwardLayer&&this.forwardLayer.setFastWeightInitDuringBuild(t),null!=this.backwardLayer&&this.backwardLayer.setFastWeightInitDuringBuild(t)}getConfig(){const t={mergeMode:this.mergeMode},e=super.getConfig();return Object.assign(t,e),t}static fromConfig(t,e){const n=el(e.layer);if(delete e.layer,null!=e.numConstants)throw new l("Deserialization of a Bidirectional layer with numConstants present is not supported yet.");const s=e;return s.layer=n,new t(s)}}ac.className="Bidirectional",e.serialization.registerClass(ac);class oc extends Qt{constructor(t){super(t),this.scale=t.scale,t.offset?this.offset=t.offset:this.offset=0}getConfig(){const t={scale:this.scale,offset:this.offset},e=super.getConfig();return Object.assign(t,e),t}call(t,n){return e.tidy((()=>("float32"!==(t=Ut(t)).dtype&&(t=et(t,"float32")),e.add(e.mul(t,this.scale),this.offset))))}}oc.className="Rescaling",e.serialization.registerClass(oc);const{resizeBilinear:lc,cropAndResize:uc}=e.image;class hc extends Qt{constructor(t){super(t),this.height=t.height,this.width=t.width}centerCrop(t,n,s,i,r,a,o,l){return e.tidy((()=>{let u,h=!1;const c=[n/a,s/o,(i+n)/a,(r+s)/o],p=[];3===t.rank?(h=!0,u=e.stack([t])):u=t;for(let t=0;tet(lc(t,[n,s]),i)))}call(t,n){return e.tidy((()=>{const e=Ut(t),n=e.dtype,s=e.shape,i=s[s.length-3],r=s[s.length-2];let a=0;i!==this.height&&(a=Math.floor((i-this.height)/2));let o=0;return r!==this.width&&(o=Math.floor((r-this.width)/2),0===o&&(o=1)),a>=0&&o>=0?this.centerCrop(e,a,o,this.height,this.width,i,r,n):this.upsize(t,this.height,this.width,n)}))}getConfig(){const t={height:this.height,width:this.width},e=super.getConfig();return Object.assign(t,e),t}computeOutputShape(t){const e=(t=Wt(t)).length-3,n=t.length-2;return t[e]=this.height,t[n]=this.width,t}}hc.className="CenterCrop",e.serialization.registerClass(hc);class cc extends Qt{constructor(t){super(t),this.numTokens=t.numTokens,t.outputMode?this.outputMode=t.outputMode:this.outputMode="multiHot"}getConfig(){const t={numTokens:this.numTokens,outputMode:this.outputMode},e=super.getConfig();return Object.assign(t,e),t}computeOutputShape(t){return null==(t=Wt(t))?[this.numTokens]:"oneHot"===this.outputMode&&1!==t[t.length-1]?(t.push(this.numTokens),t):(t[t.length-1]=this.numTokens,t)}call(t,n){return e.tidy((()=>{let s;if("int32"!==(t=Ut(t)).dtype&&(t=et(t,"int32")),"undefined"!=typeof n.countWeights){if("count"!==this.outputMode)throw new o(`countWeights is not used when outputMode !== count.\n Received countWeights=${n.countWeights}`);s=Ut(n.countWeights)}const i=e.max(t),r=e.min(t),a=e.greater(this.numTokens,i).bufferSync().get(0),l=e.greaterEqual(r,0).bufferSync().get(0);if(!a||!l)throw new o(`Input values must be between 0 < values <= numTokens with numTokens=${this.numTokens}`);return function(t,n,s,i){let r=Ut(t);if("int32"!==r.dtype&&(r=et(r,"int32")),"int"===n)return r;const a=r.shape;if(0===r.rank&&(r=e.expandDims(r,-1)),"oneHot"===n&&1!==r.shape[r.shape.length-1]&&(r=e.expandDims(r,-1)),r.rank>2)throw new o(`When outputMode is not int, maximum output rank is 2 Received outputMode ${n} and input shape ${a} which would result in output rank ${r.rank}.`);const l=["multiHot","oneHot"].includes(n),u=r;let h;if(h="undefined"!=typeof i&&"count"===n?e.denseBincount(u,i,s,l):e.denseBincount(u,[],s,l),"tfIdf"!==n)return h;if(i)return e.mul(h,i);throw new o("When outputMode is 'tfIdf', weights must be provided.")}(t,this.outputMode,this.numTokens,s)}))}}cc.className="CategoryEncoding",e.serialization.registerClass(cc);const pc=new Set(["bilinear","nearest"]);class dc extends Qt{constructor(t){if(super(t),this.height=t.height,this.width=t.width,t.interpolation){if(!pc.has(t.interpolation))throw new o(`Invalid interpolation parameter: ${t.interpolation} is not implemented`);this.interpolation=t.interpolation}else this.interpolation="bilinear";this.cropToAspectRatio=Boolean(t.cropToAspectRatio)}computeOutputShape(t){const e=(t=Wt(t))[2];return[this.height,this.width,e]}getConfig(){const t={height:this.height,width:this.width,interpolation:this.interpolation,cropToAspectRatio:this.cropToAspectRatio},e=super.getConfig();return Object.assign(t,e),t}call(t,n){return e.tidy((()=>{const n=[this.height,this.width];if("bilinear"===this.interpolation)return e.image.resizeBilinear(t,n,!this.cropToAspectRatio);if("nearest"===this.interpolation)return e.image.resizeNearestNeighbor(t,n,!this.cropToAspectRatio);throw new Error(`Interpolation is ${this.interpolation} but only ${[...pc]} are supported`)}))}}dc.className="Resizing",e.serialization.registerClass(dc);class fc{constructor(t){this.seed=t}next(){if(void 0!==this.seed)return this.seed++}}fc.className="RandomSeed";class gc extends Qt{constructor(t){super(t),this.randomGenerator=new fc(t.seed)}getConfig(){const t={seed:this.randomGenerator.seed},e=super.getConfig();return Object.assign(t,e),t}}gc.className="BaseRandomLayer";const mc=new Set(["bilinear","nearest"]);class yc extends gc{constructor(t){super(t);const{factor:e,interpolation:n="bilinear"}=t;if(this.factor=e,Array.isArray(this.factor)&&2===this.factor.length)this.widthLower=this.factor[0],this.widthUpper=this.factor[1];else{if(Array.isArray(this.factor)||!(this.factor>0))throw new o(`Invalid factor: ${this.factor}. Must be positive number or tuple of 2 numbers`);this.widthLower=-this.factor,this.widthUpper=this.factor}if(this.widthLower<-1||this.widthUpper<-1)throw new o(`factor must have values larger than -1. Got: ${this.factor}`);if(this.widthUpper{const n=Ut(t);this.imgHeight=n.shape[n.shape.length-3];const s=n.shape[n.shape.length-2];this.widthFactor=e.randomUniform([1],1+this.widthLower,1+this.widthUpper,"float32",this.randomGenerator.next());let i=this.widthFactor.dataSync()[0]*s;i=Math.round(i);const r=[this.imgHeight,i];switch(this.interpolation){case"bilinear":return e.image.resizeBilinear(t,r);case"nearest":return e.image.resizeNearestNeighbor(t,r);default:throw new Error(`Interpolation is ${this.interpolation}\n but only ${[...mc]} are supported`)}}))}}function bc(t){return new Vh(t)}function wc(t){return new Hh(t)}function kc(t){return new Yh(t)}function vc(t){return new tc(t)}function Sc(t){return new sc(t)}function xc(t){return new qh(t)}function Nc(t){return new Gh(t)}yc.className="RandomWidth",e.serialization.registerClass(yc);var Ic={__proto__:null,Layer:Qt,RNN:th,RNNCell:eh,activation:function(t){return new bh(t)},add:function(t){return new Ih(t)},alphaDropout:function(t){return new _h(t)},average:function(t){return new zh(t)},averagePooling1d:bc,averagePooling2d:wc,averagePooling3d:kc,avgPool1d:function(t){return bc(t)},avgPool2d:function(t){return wc(t)},avgPool3d:function(t){return kc(t)},avgPooling1d:function(t){return bc(t)},avgPooling2d:function(t){return wc(t)},avgPooling3d:function(t){return kc(t)},batchNormalization:function(t){return new Mh(t)},bidirectional:function(t){return new ac(t)},categoryEncoding:function(t){return new cc(t)},centerCrop:function(t){return new hc(t)},concatenate:function(t){return new Ch(t)},conv1d:function(t){return new Hu(t)},conv2d:function(t){return new Wu(t)},conv2dTranspose:function(t){return new qu(t)},conv3d:function(t){return new ju(t)},conv3dTranspose:function(t){return new Vu(t)},convLstm2d:function(t){return new dh(t)},convLstm2dCell:function(t){return new ph(t)},cropping2D:function(t){return new Ju(t)},dense:function(t){return new mh(t)},depthwiseConv2d:function(t){return new Yu(t)},dot:function(t){return new Fh(t)},dropout:function(t){return new fh(t)},elu:function(t){return new $u(t)},embedding:function(t){return new xh(t)},flatten:function(t){return new yh(t)},gaussianDropout:function(t){return new Lh(t)},gaussianNoise:function(t){return new Dh(t)},globalAveragePooling1d:function(t){return new Qh(t)},globalAveragePooling2d:function(t){return new nc(t)},globalMaxPool1d:vc,globalMaxPool2d:Sc,globalMaxPooling1d:vc,globalMaxPooling2d:Sc,gru:function(t){return new rh(t)},gruCell:function(t){return new ih(t)},input:nu,inputLayer:function(t){return new ee(t)},layerNormalization:function(t){return new Bh(t)},leakyReLU:function(t){return new Tu(t)},lstm:function(t){return new oh(t)},lstmCell:function(t){return new ah(t)},masking:function(t){return new Sh(t)},maxPool1d:xc,maxPool2d:Nc,maxPooling1d:xc,maxPooling2d:Nc,maxPooling3d:function(t){return new Zh(t)},maximum:function(t){return new Eh(t)},minimum:function(t){return new Th(t)},multiply:function(t){return new Ah(t)},permute:function(t){return new vh(t)},prelu:function(t){return new Cu(t)},randomWidth:function(t){return new yc(t)},reLU:function(t){return new Eu(t)},repeatVector:function(t){return new wh(t)},rescaling:function(t){return new oc(t)},reshape:function(t){return new kh(t)},resizing:function(t){return new dc(t)},rnn:function(t){return new th(t)},separableConv2d:function(t){return new Gu(t)},simpleRNN:function(t){return new sh(t)},simpleRNNCell:function(t){return new nh(t)},softmax:function(t){return new Du(t)},spatialDropout1d:function(t){return new gh(t)},stackedRNNCells:function(t){return new lh(t)},thresholdedReLU:function(t){return new Fu(t)},timeDistributed:function(t){return new rc(t)},upSampling2d:function(t){return new Zu(t)},zeroPadding2d:function(t){return new Ph(t)}};var Ac={__proto__:null,MAPE:function(t,e){return rl(t,e)},MSE:function(t,e){return sl(t,e)},binaryAccuracy:function(t,e){return pl(t,e)},binaryCrossentropy:function(t,e){return yl(t,e)},categoricalAccuracy:function(t,e){return dl(t,e)},categoricalCrossentropy:function(t,e){return wl(t,e)},cosineProximity:function(t,e){return ul(t,e)},mape:function(t,e){return rl(t,e)},meanAbsoluteError:function(t,e){return il(t,e)},meanAbsolutePercentageError:function(t,e){return rl(t,e)},meanSquaredError:function(t,e){return sl(t,e)},mse:function(t,e){return sl(t,e)},precision:function(t,e){return gl(t,e)},recall:function(t,e){return ml(t,e)},sparseCategoricalAccuracy:function(t,e){return bl(t,e)}},zc={__proto__:null,modelFromJSON:async function(t,n){"modelTopology"in t||(t={modelTopology:t});let s=t.modelTopology;null!=s.model_config&&(s=s.model_config);const i=el(Fl(s),n);if(null!=t.weightsManifest){const n=await e.io.loadWeights(t.weightsManifest,t.pathPrefix,i.weights.map((t=>t.originalName))),s={};for(const t of i.weights)s[t.originalName]=n[t.originalName];i.loadWeights(s),e.dispose(n)}return i}};var Ec={__proto__:null,l1:function(t){return vu(e=t),new xu({l1:null!=e?e.l1:null,l2:0});var e},l1l2:function(t){return new xu(t)},l2:function(t){return vu(e=t),new xu({l2:null!=e?e.l2:null,l1:0});var e}};class Tc extends Go{constructor(){super(...arguments),this.model=null}setModel(t){if(!(t instanceof Ql))throw new Error("model must be a LayersModel, not some other Container");this.model=t}}function Cc(t,e){return te}class Fc extends Tc{constructor(t){if(super(),null==t&&(t={}),t.restoreBestWeights)throw new l("restoreBestWeights = True is not implemented in EarlyStopping yet.");this.monitor=t.monitor||"val_loss",this.minDelta=Math.abs(t.minDelta||0),this.patience=t.patience||0,this.verbose=t.verbose||0,this.mode=t.mode||"auto",this.baseline=t.baseline,-1===["auto","min","max"].indexOf(this.mode)&&(console.warn(`EarlyStopping mode '${this.mode}' is invalid. Falling back to mode 'auto'.`),this.mode="auto"),"min"===this.mode?this.monitorFunc=Cc:"max"===this.mode||-1!==this.monitor.indexOf("acc")?this.monitorFunc=$c:this.monitorFunc=Cc,this.monitorFunc===Cc&&(this.minDelta*=-1)}async onTrainBegin(t){this.wait=0,this.stoppedEpoch=0,null!=this.baseline?this.best=this.baseline:this.best=this.monitorFunc===Cc?1/0:-1/0}async onEpochEnd(t,e){await Vo(e);const n=this.getMonitorValue(e);null!=n&&(this.monitorFunc(n-this.minDelta,this.best)?(this.best=n,this.wait=0):(this.wait++,this.wait>=this.patience&&(this.stoppedEpoch=t,this.model.stopTraining=!0)))}async onTrainEnd(t){this.stoppedEpoch>0&&this.verbose&&console.log(`Epoch ${this.stoppedEpoch}: early stopping.`)}getMonitorValue(t){null==t&&(t={});const e=t[this.monitor];return null==e&&console.warn(`Metric for EarlyStopping ${this.monitor} is not available. Available metrics are: ${Object.keys(t)}`),e}}const Dc={earlyStopping:function(t){return new Fc(t)}};t.Callback=Tc,t.CallbackList=Ho,t.CustomCallback=Yo,t.EarlyStopping=Fc,t.History=Zo,t.InputSpec=Ht,t.LayerVariable=Vt,t.LayersModel=Ql,t.RNN=th,t.Sequential=eu,t.SymbolicTensor=Jt,t.callbacks=Dc,t.constraints=Wo,t.initializers=qo,t.input=nu,t.layers=Ic,t.loadLayersModel=async function(t,n){if(null==n&&(n={}),"string"==typeof t){const s=e.io.getLoadHandlers(t,n);if(0===s.length)s.push(e.io.browserHTTPRequest(t,n));else if(s.length>1)throw new o(`Found more than one (${s.length}) load handlers for URL '${t}'`);t=s[0]}return async function(t,n,s){null==s&&(s={});if(null==t.load)throw new o("Cannot proceed with model loading because the IOHandler provided does not have the `load` method implemented.");const i=await t.load();let r=i.modelTopology;null!=r.model_config&&(r=r.model_config);const a=null==s.strict||s.strict,l=null!=i.weightData&&null!=i.weightSpecs&&a,u=el(Fl(r),n,l),h=i.trainingConfig;null!=h&&u.loadTrainingConfig(h);null!=i.userDefinedMetadata&&u.setUserDefinedMetadata(i.userDefinedMetadata);if(null!=i.weightData){if(null==i.weightSpecs)throw new o("LayersModel artifacts contains weight data, but not weight specs. Therefore loading of weights cannot proceed.");const{modelWeights:t,optimizerWeights:n}=function(t,n){const s=e.io.decodeWeights(t,n),i={},r=[];return n.forEach((t=>{"optimizer"===t.group?r.push({name:t.name,tensor:s[t.name]}):i[t.name]=s[t.name]})),{modelWeights:i,optimizerWeights:r}}(i.weightData,i.weightSpecs);u.loadWeights(t,a),null!=u.optimizer&&n.length>0&&await u.optimizer.setWeights(n),e.dispose(t),e.dispose(n.map((t=>t.tensor)))}return u}(t,void 0,n)},t.metrics=Ac,t.model=function(t){return new Ql(t)},t.models=zc,t.registerCallbackConstructor=function(t,e){Qo.registerCallbackConstructor(t,e)},t.regularizers=Ec,t.sequential=function(t){return new eu(t)},t.version_layers=Ll})); //# sourceMappingURL=tf-layers.es2017.min.js.map