/** * @license * Copyright 2023 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ import*as t from"@tensorflow/tfjs-core";import{util as e,backend as n,tidy as s,tensor1d as i,serialization as r,zeros as a,ones as o,mul as l,scalar as u,randomUniform as h,truncatedNormal as c,eye as p,linalg as d,dispose as f,memory as g,cast as m,env as y,nextFrame as b,add as w,div as k,keep as v,train as S,clone as x,argMax as N,reshape as I,Tensor as A,Optimizer as T,io as E,sum as C,abs as z,relu as $,clipByValue as F,leakyRelu as D,prelu as L,elu as _,greater as R,sub as M,exp as O,logSumExp as B,transpose as P,any as U,notEqual as W,zerosLike as j,greaterEqual as V,moments as q,stack as K,tensor as G,range as H,unstack as J,image as Z,expandDims as Y,denseBincount as X,max as Q,min as tt}from"@tensorflow/tfjs-core";function et(t,e){return e.forEach((function(e){e&&"string"!=typeof e&&!Array.isArray(e)&&Object.keys(e).forEach((function(n){if("default"!==n&&!(n in t)){var s=Object.getOwnPropertyDescriptor(e,n);Object.defineProperty(t,n,s.get?s:{enumerable:!0,get:function(){return e[n]}})}}))})),t}class nt extends Error{constructor(t){super(t),Object.setPrototypeOf(this,nt.prototype)}}class st extends Error{constructor(t){super(t),Object.setPrototypeOf(this,st.prototype)}}class it extends Error{constructor(t){super(t),Object.setPrototypeOf(this,it.prototype)}}class rt extends Error{constructor(t){super(t),Object.setPrototypeOf(this,rt.prototype)}}class at extends Error{constructor(t){super(t),Object.setPrototypeOf(this,at.prototype)}}class ot{constructor(t){this.maxEntries=t||100,this.cache=new Map}get(t){let e;return this.cache.has(t)&&(e=this.cache.get(t),this.cache.delete(t),this.cache.set(t,e)),e}put(t,e){if(this.cache.has(t))this.cache.delete(t);else if(this.cache.size>=this.maxEntries){const t=this.cache.keys().next().value;this.cache.delete(t)}this.cache.set(t,e)}getMaxEntries(){return this.maxEntries}setMaxEntries(t){if(t<0)throw new Error(`The maxEntries of LRU caches must be at least 0, but got ${t}.`);if(this.maxEntries>t)for(let e=0;ee.toUpperCase()))}let gt={};function mt(t){if(null==t)return null;const e={};return e.className=t.getClassName(),e.config=t.getConfig(),e}function yt(t){if(null!=t&&"object"==typeof t)if(Array.isArray(t))t.forEach((t=>yt(t)));else{const e=Object.keys(t);for(const n of e){const e=t[n];null!=e&&"object"==typeof e&&(Array.isArray(e)||"ndarray"!==e.type||"number"!=typeof e.value?yt(e):t[n]=e.value)}}}function bt(t,e={},n={},s="object",i=!1){if("string"==typeof t){const i=t;let r;if(i in n)r=n[i];else if(i in gt)r=gt[i];else if(r=e[i],null==r)throw new it(`Unknown ${s}: ${t}. This may be due to one of the following reasons:\n1. The ${s} is defined in Python, in which case it needs to be ported to TensorFlow.js or your JavaScript code.\n2. The custom ${s} is defined in JavaScript, but is not registered properly with tf.serialization.registerClass().`);return r}{const r=t;if(null==r.className||null==r.config)throw new it(`${s}: Improper config format: ${JSON.stringify(r)}.\n'className' and 'config' must set.`);const a=r.className;let o,l;if(a in n?[o,l]=n[a]:a in gt?[o,l]=gt.className:a in e&&([o,l]=e[a]),null==o)throw new it(`Unknown ${s}: ${a}. This may be due to one of the following reasons:\n1. The ${s} is defined in Python, in which case it needs to be ported to TensorFlow.js or your JavaScript code.\n2. The custom ${s} is defined in JavaScript, but is not registered properly with tf.serialization.registerClass().`);if(null!=l){const t={};for(const e of Object.keys(gt))t[e]=gt[e];for(const e of Object.keys(n))t[e]=n[e];r.config.customObjects=t;const e=Object.assign({},gt);for(const t of Object.keys(n))gt[t]=n[t];yt(r.config);const s=l(o,r.config,n,i);return gt=Object.assign({},e),s}{const t=Object.assign({},gt);for(const t of Object.keys(n))gt[t]=n[t];const e=new o(r.config);return gt=Object.assign({},t),e}}}function wt(t,e){return-1*function(t,e){return te?1:0}(t,e)}function kt(t){if(null==t)return t;const e=[];for(const n of t)-1===e.indexOf(n)&&e.push(n);return e}function vt(t){if(null==t)throw new it(`Invalid value in obj: ${JSON.stringify(t)}`);for(const e in t)if(t.hasOwnProperty(e))return!1;return!0}function St(t,e,n){if(null!=n&&t.indexOf(n)<0)throw new it(`${n} is not a valid ${e}. Valid values are ${t} or null/undefined.`)}function xt(t,e,n=0,s=1/0){return ut(n>=0),ut(s>=n),Array.isArray(t)&&t.length>=n&&t.length<=s&&t.every((t=>typeof t===e))}function Nt(t,n){Array.isArray(t)?(e.assert(t.length>0,(()=>`${n} is unexpectedly an empty array.`)),t.forEach(((t,e)=>Nt(t,`element ${e+1} of ${n}`)))):e.assert(Number.isInteger(t)&&t>0,(()=>`Expected ${n} to be a positive integer, but got ${It(t)}.`))}function It(t){return null===t?"null":Array.isArray(t)?"["+t.map((t=>It(t))).join(",")+"]":"string"==typeof t?`"${t}"`:`${t}`}function At(t){return"relu"===t?"relu":"linear"===t?"linear":"elu"===t?"elu":null}let Tt=0;function Et(){return Tt++}const Ct={};function zt(t=""){return t in Ct||(Ct[t]=0),Ct[t]+=1,t+Ct[t].toString()}const $t=["channelsFirst","channelsLast"],Ft=["nearest","bilinear"],Dt=["valid","same","causal"],Lt=["max","avg"],_t=["sum","mul","concat","ave"],Rt=new Map;function Mt(t){St($t,"DataFormat",t)}function Ot(t){St(Dt,"PaddingMode",t)}function Bt(t){St(Lt,"PoolMode",t)}const Pt=[];function Ut(t,e){Pt.push(t);try{const t=e();return Pt.pop(),t}catch(t){throw Pt.pop(),t}}function Wt(t){if(!qt(t))throw new Error("Not a valid tensor name: '"+t+"'");return(0===Pt.length?"":Pt.join("/")+"/")+t}function jt(t){if(!qt(t))throw new Error("Not a valid tensor name: '"+t+"'");Rt.has(t)||Rt.set(t,0);const e=Rt.get(t);if(Rt.set(t,Rt.get(t)+1),e>0){const n=`${t}_${e}`;return Rt.set(n,1),n}return t}const Vt=new RegExp(/^[A-Za-z0-9][-A-Za-z0-9\._\/]*$/);function qt(t){return!!t.match(Vt)}function Kt(t,e,n){null==e&&(e=0),null==n&&(n=t.length);let s=1;for(let i=e;ie&&(e=s)}return e}function Jt(t,e){if(e{switch(e.rank){case 1:return t.slice1d(e,n,i);case 2:return t.slice2d(e,[n,0],[i,e.shape[1]]);case 3:return t.slice3d(e,[n,0,0],[i,e.shape[1],e.shape[2]]);case 4:return t.slice4d(e,[n,0,0,0],[i,e.shape[1],e.shape[2],e.shape[3]]);case 5:return t.slice(e,[n,0,0,0,0],[i,e.shape[1],e.shape[2],e.shape[3],e.shape[4]]);case 6:return t.slice(e,[n,0,0,0,0,0],[i,e.shape[1],e.shape[2],e.shape[3],e.shape[4],e.shape[5]]);default:throw new it(`sliceAlongFirstAxis() received an unsupported tensor rank: ${e.rank}`)}}))}function ee(e,n,i){return s((()=>{switch(e.rank){case 1:return t.slice1d(e,n,i);case 2:return t.slice2d(e,[0,n],[e.shape[0],i]);case 3:return t.slice3d(e,[0,0,n],[e.shape[0],e.shape[1],i]);case 4:return t.slice4d(e,[0,0,0,n],[e.shape[0],e.shape[1],e.shape[2],i]);default:throw new it(`sliceAlongLastAxis() received an unsupported tensor rank: ${e.rank}`)}}))}function ne(e,n,i,r){return s((()=>{switch(e.rank){case 1:return t.slice1d(e,n,i);case 2:switch(r){case 1:return te(e,n,i);case 2:return ee(e,n,i);default:throw new it(`The axis is not within the rank of the tensor ${r}`)}case 3:switch(r){case 1:return te(e,n,i);case 2:return t.slice3d(e,[0,n,0],[e.shape[0],i,e.shape[2]]);case 3:return ee(e,n,i);default:throw new it(`The axis is not within the rank of the tensor ${r}`)}case 4:switch(r){case 1:return te(e,n,i);case 2:return t.slice4d(e,[0,n,0,0],[e.shape[0],i,e.shape[2],e.shape[3]]);case 3:return t.slice4d(e,[0,0,n,0],[e.shape[0],e.shape[1],i,e.shape[3]]);case 4:return ee(e,n,i);default:throw new it(`The axis is not within the rank of the tensor ${r}`)}default:throw new it(`sliceAlongLastAxis() received an unsupported tensor rank: ${e.rank}`)}}))}function se(e,n=-1){let s;return n<0&&(s=e[0].rank,n=0!==s?s:0),n===e[0].rank&&(n=-1),t.concat(e,n)}function ie(e,n){switch(e.rank){case 1:return t.concat1d([e,n]);case 2:return t.concat2d([e,n],0);case 3:return t.concat3d([e,n],0);case 4:return t.concat4d([e,n],0);default:throw new it(`concatAlongFirstAxis() received an unsupported tensor rank: ${e.rank}`)}}function re(e,n){if(Array.isArray(n)||(n=[n]),e.rank!==n.length)throw new it(`The length of input n (${n.length}) does not match the number of dimensions in input x (${e.rank})`);return t.tile(e,n)}function ae(e,n=0,s=1,i,r){return t.randomNormal(e,n,s,i,r)}function oe(e,n,s,i){if(e.rank<2||n.rank<2)throw new rt(`dot requires both inputs to be rank >= 2 but got x shape = ${e.shape} and y shape = ${n.shape}`);if(n.rank>=3){if(e.shape.slice(-1)[0]!==n.shape.slice(-2)[0])throw new rt(`If rank y >= 3, then the second last dim of y must equal the last dim of x but got x shape = ${e.shape} and y shape = ${n.shape}`)}if(2===e.rank&&2===n.rank){const r=!1,a=!1;return t.fused.matMul({a:e,b:n,transposeA:r,transposeB:a,bias:i?he(e.rank,i,"channelsLast"):null,activation:s})}{const r=e.shape.slice(),a=r.pop();e=t.reshape(e,[-1,a]);const o=n.shape.slice(),l=o.pop(),u=o.pop(),h=[...o,l],c=Array.from({length:n.rank},((t,e)=>0===e?n.rank-2:e<=n.rank-2?e-1:e));n=t.reshape(t.transpose(n,c),[u,-1]);const p=[...r,...h],d=!1,f=!1;return t.reshape(t.fused.matMul({a:e,b:n,transposeA:d,transposeB:f,bias:i?he(e.rank,i,"channelsLast"):null,activation:s}),p)}}function le(e,n,r){return s((()=>(n=Array.isArray(n)?i(n,"int32"):t.cast(n,"int32"),t.gather(e,n,r))))}function ue(e){return t.mul(e,e)}function he(e,n,s){const i=n.shape;if(1!==n.rank&&n.rank!==e)throw new it(`Unexpected bias dimensions: ${n.rank}; expected it to be 1 or ${e}`);if(5===e){if("channelsFirst"===s)return 1===i.length?t.reshape(n,[1,i[0],1,1,1]):t.reshape(n,[1,i[3],i[0],i[1],i[2]]);if("channelsLast"===s)return 1===i.length?t.reshape(n,[1,1,1,1,i[0]]):t.reshape(n,[1].concat(i))}else if(4===e){if("channelsFirst"===s)return 1===i.length?t.reshape(n,[1,i[0],1,1]):t.reshape(n,[1,i[2],i[0],i[1]]);if("channelsLast"===s)return 1===i.length?t.reshape(n,[1,1,1,i[0]]):t.reshape(n,[1].concat(i))}else if(3===e){if("channelsFirst"===s)return 1===i.length?t.reshape(n,[1,i[0],1]):t.reshape(n,[1,i[1],i[0]]);if("channelsLast"===s)return 1===i.length?t.reshape(n,[1,1,i[0]]):t.reshape(n,[1].concat(i))}else if(e<3)return n;throw new it(`Unsupported input rank by biasAdd: ${n.rank}`)}function ce(e,n,i){return s((()=>(null==i&&(i="channelsLast"),Mt(i),t.add(e,he(e.rank,n,i)))))}function pe(e,n,i,r){return s((()=>t.dropout(e,n,i,r)))}function de(t,e,n=!1){return n?t():e()}const fe=["fanIn","fanOut","fanAvg"],ge=["normal","uniform","truncatedNormal"];class me extends r.Serializable{fromConfigUsesCustomObjects(){return!1}getConfig(){return{}}}class ye extends me{apply(t,e){return a(t,e)}}ye.className="Zeros",r.registerClass(ye);class be extends me{apply(t,e){return o(t,e)}}be.className="Ones",r.registerClass(be);class we extends me{constructor(t){if(super(),"object"!=typeof t)throw new it(`Expected argument of type ConstantConfig but got ${t}`);if(void 0===t.value)throw new it(`config must have value set but got ${t}`);this.value=t.value}apply(t,e){return s((()=>l(u(this.value),o(t,e))))}getConfig(){return{value:this.value}}}we.className="Constant",r.registerClass(we);class ke extends me{constructor(t){super(),this.DEFAULT_MINVAL=-.05,this.DEFAULT_MAXVAL=.05,this.minval=t.minval||this.DEFAULT_MINVAL,this.maxval=t.maxval||this.DEFAULT_MAXVAL,this.seed=t.seed}apply(t,e){return h(t,this.minval,this.maxval,e,this.seed)}getConfig(){return{minval:this.minval,maxval:this.maxval,seed:this.seed}}}ke.className="RandomUniform",r.registerClass(ke);class ve extends me{constructor(t){super(),this.DEFAULT_MEAN=0,this.DEFAULT_STDDEV=.05,this.mean=t.mean||this.DEFAULT_MEAN,this.stddev=t.stddev||this.DEFAULT_STDDEV,this.seed=t.seed}apply(t,e){if("float32"!==(e=e||"float32")&&"int32"!==e)throw new rt(`randomNormal does not support dType ${e}.`);return ae(t,this.mean,this.stddev,e,this.seed)}getConfig(){return{mean:this.mean,stddev:this.stddev,seed:this.seed}}}ve.className="RandomNormal",r.registerClass(ve);class Se extends me{constructor(t){super(),this.DEFAULT_MEAN=0,this.DEFAULT_STDDEV=.05,this.mean=t.mean||this.DEFAULT_MEAN,this.stddev=t.stddev||this.DEFAULT_STDDEV,this.seed=t.seed}apply(t,e){if("float32"!==(e=e||"float32")&&"int32"!==e)throw new rt(`truncatedNormal does not support dType ${e}.`);return c(t,this.mean,this.stddev,e,this.seed)}getConfig(){return{mean:this.mean,stddev:this.stddev,seed:this.seed}}}Se.className="TruncatedNormal",r.registerClass(Se);let xe=class extends me{constructor(t){super(),this.gain=null!=t.gain?t.gain:1}apply(t,e){return s((()=>{if(2!==t.length||t[0]!==t[1])throw new it("Identity matrix initializer can only be used for 2D square matrices.");return l(this.gain,p(t[0]))}))}getConfig(){return{gain:this.gain}}};xe.className="Identity",r.registerClass(xe);class Ne extends me{constructor(t){if(super(),t.scale<0)throw new it(`scale must be a positive float. Got: ${t.scale}`);var e;this.scale=null==t.scale?1:t.scale,this.mode=null==t.mode?"fanIn":t.mode,e=this.mode,St(fe,"FanMode",e),this.distribution=null==t.distribution?"normal":t.distribution,function(t){St(ge,"Distribution",t)}(this.distribution),this.seed=t.seed}apply(t,e){const n=function(t,e="channelsLast"){let n,s;if(Mt(e),2===t.length)n=t[0],s=t[1];else if(-1!==[3,4,5].indexOf(t.length)){if("channelsFirst"===e){const e=Kt(t,2);n=t[1]*e,s=t[0]*e}else if("channelsLast"===e){const e=Kt(t,0,t.length-2);n=t[t.length-2]*e,s=t[t.length-1]*e}}else{const e=Kt(t);n=Math.sqrt(e),s=Math.sqrt(e)}return[n,s]}(t),s=n[0],i=n[1];let r=this.scale;if("fanIn"===this.mode?r/=Math.max(1,s):"fanOut"===this.mode?r/=Math.max(1,i):r/=Math.max(1,(s+i)/2),"normal"===this.distribution){const n=Math.sqrt(r);if("float32"!==(e=e||"float32")&&"int32"!==e)throw new rt(`${this.getClassName()} does not support dType ${e}.`);return c(t,0,n,e,this.seed)}{const n=Math.sqrt(3*r);return h(t,-n,n,e,this.seed)}}getConfig(){return{scale:this.scale,mode:this.mode,distribution:this.distribution,seed:this.seed}}}Ne.className="VarianceScaling",r.registerClass(Ne);class Ie extends Ne{constructor(t){super({scale:1,mode:"fanAvg",distribution:"uniform",seed:null==t?null:t.seed})}getClassName(){return Ne.className}}Ie.className="GlorotUniform",r.registerClass(Ie);class Ae extends Ne{constructor(t){super({scale:1,mode:"fanAvg",distribution:"normal",seed:null==t?null:t.seed})}getClassName(){return Ne.className}}Ae.className="GlorotNormal",r.registerClass(Ae);class Te extends Ne{constructor(t){super({scale:2,mode:"fanIn",distribution:"normal",seed:null==t?null:t.seed})}getClassName(){return Ne.className}}Te.className="HeNormal",r.registerClass(Te);class Ee extends Ne{constructor(t){super({scale:2,mode:"fanIn",distribution:"uniform",seed:null==t?null:t.seed})}getClassName(){return Ne.className}}Ee.className="HeUniform",r.registerClass(Ee);class Ce extends Ne{constructor(t){super({scale:1,mode:"fanIn",distribution:"normal",seed:null==t?null:t.seed})}getClassName(){return Ne.className}}Ce.className="LeCunNormal",r.registerClass(Ce);class ze extends Ne{constructor(t){super({scale:1,mode:"fanIn",distribution:"uniform",seed:null==t?null:t.seed})}getClassName(){return Ne.className}}ze.className="LeCunUniform",r.registerClass(ze);class $e extends me{constructor(t){super(),this.DEFAULT_GAIN=1,this.ELEMENTS_WARN_SLOW=2e3,this.gain=null==t.gain?this.DEFAULT_GAIN:t.gain,this.seed=t.seed}apply(t,n){return s((()=>{if(t.length<2)throw new rt("Shape must be at least 2D.");if("int32"!==n&&"float32"!==n&&void 0!==n)throw new TypeError(`Unsupported data type ${n}.`);const s=e.sizeFromShape(t.slice(0,-1)),i=t[t.length-1],r=s*i;r>this.ELEMENTS_WARN_SLOW&&console.warn(`Orthogonal initializer is being called on a matrix with more than ${this.ELEMENTS_WARN_SLOW} (${r}) elements: Slowness may result.`);const a=ae([Math.max(i,s),Math.min(i,s)],0,1,n,this.seed),o=d.qr(a,!1);let h=o[0];const c=o[1].flatten().stridedSlice([0],[Math.min(i,s)*Math.min(i,s)],[Math.min(i,s)+1]);return h=l(h,c.sign()),st*e));return e}class Ue{constructor(e,n="float32",s="Variable",i=!0,r=null){this.dtype=null==n?"float32":n,this.shape=e.shape,this.id=Et(),s=null==s?"Variable":s,this.originalName=Wt(s),this.name=jt(this.originalName),this.trainable_=i,this.constraint=r,this.val=t.variable(e,this.trainable_,this.name,this.dtype)}read(){return this.assertNotDisposed(),this.val}write(t){return this.assertNotDisposed(),function(t,e){if(t.shape.toString()!==e.shape.toString())throw new Error("Shape mismatch: "+JSON.stringify(t.shape)+" vs. "+JSON.stringify(e.shape))}(this.val,t),this.val.id!==t.id&&(this.val.assign(t),null!=this.constraint&&this.val.assign(this.constraint.apply(this.val))),this}dispose(){this.assertNotDisposed(),this.val.dispose()}assertNotDisposed(){if(this.val.isDisposed)throw new Error(`LayersVariable ${this.name} is already disposed.`)}get trainable(){return this.trainable_}set trainable(t){this.trainable_=t,this.val.trainable=t}}function We(t){return t.map((t=>t.read()))}function je(t){t.forEach((t=>{t[0].write(t[1])}))}class Ve{constructor(t){this.dtype=t.dtype,this.shape=t.shape,null!=t.shape?this.ndim=t.shape.length:this.ndim=t.ndim,this.maxNDim=t.maxNDim,this.minNDim=t.minNDim,this.axes=t.axes||{}}}class qe{constructor(t,e,n,s,i,r,a){this.dtype=t,this.shape=e,this.sourceLayer=n,this.inputs=s,this.callArgs=i,this.outputTensorIndex=a,this.id=Et(),null!=r&&(this.originalName=Wt(r),this.name=jt(this.originalName)),this.rank=e.length}}let Ke=0;class Ge{constructor(t,e){this.callArgs=e,this.id=Ke++,this.outboundLayer=t.outboundLayer,this.inboundLayers=t.inboundLayers,this.nodeIndices=t.nodeIndices,this.tensorIndices=t.tensorIndices,this.inputTensors=t.inputTensors,this.outputTensors=t.outputTensors,this.inputMasks=t.inputMasks,this.outputMasks=t.outputMasks,this.inputShapes=t.inputShapes,this.outputShapes=t.outputShapes;for(const e of t.inboundLayers)null!=e&&e.outboundNodes.push(this);t.outboundLayer.inboundNodes.push(this)}getConfig(){const t=[];for(const e of this.inboundLayers)null!=e?t.push(e.name):t.push(null);return{outboundLayer:this.outboundLayer?this.outboundLayer.name:null,inboundLayers:t,nodeIndices:this.nodeIndices,tensorIndices:this.tensorIndices}}}let He=0;class Je extends r.Serializable{constructor(t={}){super(),this._callHook=null,this._addedWeightNames=[],this._stateful=!1,this.id=He++,this.activityRegularizer=null,this.inputSpec=null,this.supportsMasking=!1,this._trainableWeights=[],this._nonTrainableWeights=[],this._losses=[],this._updates=[],this._built=!1,this.inboundNodes=[],this.outboundNodes=[];let e=t.name;if(!e){const t=this.getClassName();e=dt(t)+"_"+zt(t)}if(this.name=e,this.trainable_=null==t.trainable||t.trainable,null!=t.inputShape||null!=t.batchInputShape){let e;if(null!=t.batchInputShape)e=t.batchInputShape;else if(null!=t.inputShape){let n=null;null!=t.batchSize&&(n=t.batchSize),e=[n].concat(t.inputShape)}this.batchInputShape=e;let n=t.dtype;null==n&&(n=t.inputDType),null==n&&(n="float32"),this.dtype=n}null!=t.weights?this.initialWeights=t.weights:this.initialWeights=null,this._refCount=null,this.fastWeightInitDuringBuild=!1}static nodeKey(t,e){return t.name+"_ib-"+e.toString()}getNodeAtIndex(t,e){if(0===this.inboundNodes.length)throw new st(`The layer has never been called and thus has no defined ${e}.`);if(this.inboundNodes.length<=t)throw new it(`Asked to get ${e} at node ${t}, but the layer has only ${this.inboundNodes.length} inbound nodes.`);return this.inboundNodes[t]}getInputAt(t){return ct(this.getNodeAtIndex(t,"input").inputTensors)}getOutputAt(t){return ct(this.getNodeAtIndex(t,"output").outputTensors)}get input(){if(this.inboundNodes.length>1)throw new nt(`Layer ${this.name} has multiple inbound nodes, hence the notion of "layer input" is ill-defined. Use \`getInputAt(nodeIndex)\` instead.`);if(0===this.inboundNodes.length)throw new nt(`Layer ${this.name} is not connected, no input to return.`);return ct(this.getNodeAtIndex(0,"input").inputTensors)}get output(){if(0===this.inboundNodes.length)throw new nt(`Layer ${this.name} has no inbound nodes.`);if(this.inboundNodes.length>1)throw new nt(`Layer ${this.name} has multiple inbound nodes, hence the notion of "layer output" is ill-defined. Use \`getOutputAt(nodeIndex)\` instead.`);return ct(this.getNodeAtIndex(0,"output").outputTensors)}get losses(){return this._losses}calculateLosses(){return this.losses.map((t=>t()))}get updates(){return this._updates}get built(){return this._built}set built(t){this._built=t}get trainable(){return this.trainable_}set trainable(t){this._trainableWeights.forEach((e=>e.trainable=t)),this.trainable_=t}get trainableWeights(){return this.trainable_?this._trainableWeights.filter((t=>t.trainable)):[]}set trainableWeights(t){this._trainableWeights=t}get nonTrainableWeights(){return this.trainable?this._trainableWeights.filter((t=>!t.trainable)).concat(this._nonTrainableWeights):this._trainableWeights.concat(this._nonTrainableWeights)}set nonTrainableWeights(t){this._nonTrainableWeights=t}get weights(){return this.trainableWeights.concat(this.nonTrainableWeights)}get stateful(){return this._stateful}resetStates(){if(!this.stateful)throw new Error("Cannot call the resetStates() method of a non-stateful Layer object.")}assertInputCompatibility(t){const e=pt(t);if(null==this.inputSpec||0===this.inputSpec.length)return;const n=pt(this.inputSpec);if(e.length!==n.length)throw new it(`Layer ${this.name} expects ${n.length} inputs, but it received ${e.length} input tensors. Input received: ${t}`);for(let t=0;ti.maxNDim)throw new it(`Input ${t} is incompatible with layer ${this.name}: expected max_ndim=${i.maxNDim}, found ndim=${r}`);if(null!=i.minNDim&&r=0?e[s]:e[e.length+s];if(null!=r&&-1===[r,null].indexOf(a))throw new it(`Input ${t} is incompatible with layer ${this.name}: expected axis ${s} of input shape to have value ${r} but got shape ${e}.`)}}if(null!=i.shape)for(let e=0;e{if(!this.built){this.assertInputCompatibility(t);const e=[];for(const n of pt(t))e.push(n.shape);this.build(ct(e)),this.built=!0,this.initialWeights&&this.setWeights(this.initialWeights),null===this._refCount&&i&&(this._refCount=1)}if(this.assertInputCompatibility(t),i){let s=this.call(t,e);this.supportsMasking&&this.setMaskMetadata(t,s);const i=pt(s),r=[];for(let t of i)-1!==n.indexOf(t)&&(t=t.clone()),r.push(t);if(s=ct(r),null!=this.activityRegularizer)throw new rt("Layer invocation in the presence of activity regularizer(s) is not supported yet.");return s}{const n=function(t){t=pt(t);const e=[];for(const n of t)e.push(n.shape);return ct(e)}(t),s=this.computeOutputShape(n);let i;const r="float32";if(this.warnOnIncompatibleInputShape(Array.isArray(t)?n[0]:n),i=null!=s&&s.length>0&&Array.isArray(s[0])?s.map(((n,s)=>new qe(r,n,this,pt(t),e,this.name,s))):new qe(r,s,this,pt(t),e,this.name),this.addInboundNode(t,i,null,null,n,s,e),this._refCount++,null!=this.activityRegularizer)throw new rt("Layer invocation in the presence of activity regularizer(s) is not supported yet.");return i}}))}warnOnIncompatibleInputShape(t){if(null!=this.batchInputShape)if(t.length!==this.batchInputShape.length)console.warn(`The rank of the input tensor provided (shape: ${JSON.stringify(t)}) does not match that of the batchInputShape (${JSON.stringify(this.batchInputShape)}) of the layer ${this.name}`);else{let e=!1;this.batchInputShape.forEach(((n,s)=>{null!=n&&null!=t[s]&&t[s]!==n&&(e=!0)})),e&&console.warn(`The shape of the input tensor (${JSON.stringify(t)}) does not match the expectation of layer ${this.name}: ${JSON.stringify(this.batchInputShape)}`)}}get outputShape(){if(null==this.inboundNodes||0===this.inboundNodes.length)throw new nt(`The layer ${this.name} has never been called and thus has no defined output shape.`);const t=[];for(const e of this.inboundNodes){const n=JSON.stringify(e.outputShapes);-1===t.indexOf(n)&&t.push(n)}if(1===t.length){const t=this.inboundNodes[0].outputShapes;return Array.isArray(t)&&Array.isArray(t[0])&&1===t.length?t[0]:t}throw new nt(`The layer ${this.name} has multiple inbound nodes with different output shapes. Hence the notion of "output shape" is ill-defined for the layer.`)}countParams(){if(!this.built)throw new st(`You tried to call countParams() on ${this.name}, but the layer is not built yet. Build it first by calling build(batchInputShape).`);return Pe(this.weights)}build(t){this.built=!0}getWeights(t=!1){return We(t?this.trainableWeights:this.weights)}setWeights(t){s((()=>{const n=this.weights;if(n.length!==t.length)throw new it(`You called setWeights(weights) on layer "${this.name}" with a weight list of length ${t.length}, but the layer was expecting ${n.length} weights. Provided weights: ${t}...`);if(0===n.length)return;const s=[],i=We(n);for(let r=0;ri.apply(u.read()))),null==r&&(r=!0),r?this._trainableWeights.push(u):this._nonTrainableWeights.push(u),u}setFastWeightInitDuringBuild(t){this.fastWeightInitDuringBuild=t}addLoss(t){null==t||Array.isArray(t)&&0===t.length||(t=pt(t),void 0!==this._losses&&null!==this._losses&&this.losses.push(...t))}computeOutputShape(t){return t}computeMask(t,e){if(!this.supportsMasking){if(null!=e){if(!Array.isArray(e))throw new TypeError(`Layer ${this.name} does not support masking, but was passed an inputMask.`);e.forEach((t=>{if(null!=t)throw new TypeError(`Layer ${this.name} does not support masking, but was passed an inputMask.`)}))}return null}return e}setMaskMetadata(t,e,n){if(!this.supportsMasking)return;const s=this.computeMask(t,n),i=pt(e),r=pt(s);if(i.length!==r.length)throw new Error(`${this.name} outputs ${i.length} tensors but ${i.length} masks for those tensors`);for(let t=0;tt.dispose())),this.weights.length}assertNotDisposed(){if(0===this._refCount)throw new Error(`Layer '${this.name}' is already disposed.`)}dispose(){if(!this.built)throw new Error(`Cannot dispose Layer ${this.name} because it has not been built yet.`);if(null===this._refCount)throw new Error(`Cannot dispose Layer ${this.name} because it has not been used yet.`);this.assertNotDisposed();let t=0;return 0==--this._refCount&&(t=this.disposeWeights()),{refCountAfterDispose:this._refCount,numDisposedVariables:t}}}function Ze(t,e,n){if((null==e||null!=n&&n>0)&&(e=t.sourceLayer,n=t.nodeIndex),0===e.inboundNodes.length)return[t];{const t=e.inboundNodes[n];if(0===t.inboundLayers.length)return t.inputTensors;{const e=[];for(let n=0;nt.name)),u=[],h=n.names();for(const t of l)-1!==h.indexOf(t)?u.push(n.getValue(t)):u.push(null);null!=i&&(i.maxNumTensors=-1/0,i.minNumTensors=1/0);const c=l.join(",")+"|"+n.names().sort().join(",");let p,d=tn.get(c);if(null==d){const t=function(t,n){e.assert(null!=t&&t.length>0,(()=>"Expected at least one fetch, got none"));let s=[],i={};if(1===t.length){const e=rn(t[0],n);s=e.sorted,i=e.recipientMap}else{const e=new Set;for(const r of t){const{sorted:t,recipientMap:a}=rn(r,n);for(const n of t)e.has(n.name)||(s.push(n),e.add(n.name));for(const t in a)null==i[t]&&(i[t]=new Set),a[t].forEach((e=>i[t].add(e)))}}return{sorted:s,recipientCounts:sn(i)}}(o,n);d=t.sorted,p=t.recipientCounts,tn.put(c,d),en.put(c,p)}p={},r||Object.assign(p,en.get(c));const m=new Qe(n);for(let t=0;ti.maxNumTensors&&(i.maxNumTensors=t),t0;){const t=r[r.length-1];if(n.has(t.name)){r.pop();continue}const e=a[a.length-1]===r.length-1;if(0===t.inputs.length||e)r.pop(),s.push(t),n.add(t.name),e&&a.pop();else{a.push(r.length-1);for(const e of t.inputs)null==i[e.name]&&(i[e.name]=new Set),i[e.name].add(t.name),n.has(e.name)||r.push(e)}}return{sorted:s,recipientMap:i}}function an(t){let e;if(1===t.sourceLayer.inboundNodes.length)e=t.sourceLayer.output;else{let n=null;for(let e=0;e100),(function(t){null!=tn&&tn.setMaxEntries(t),null!=en&&en.setMaxEntries(t)}));function on(t){throw new Error(`'${t}' not yet implemented or not found in the registry. This kernel may not be supported by the tfjs backend you have chosen`)}function ln(t,e){if(!t)throw new Error("string"==typeof e?e:e())}function un(t){if(0===t.length)return 1;let e=t[0];for(let n=1;ne)):[].concat(t)).every((t=>t>=-n&&t`All values in axis param must be in range [-${n}, ${n}) but got axis ${t}`)),ln(t.every((t=>cn(t))),(()=>`All values in axis param must be integers but got axis ${t}`)),t.map((t=>t<0?n+t:t))}function fn(t){if("float32"===t||"int32"===t)return 4;if("complex64"===t)return 8;if("bool"===t)return 1;throw new Error(`Unknown dtype ${t}`)}function gn(t){return"string"==typeof t||t instanceof String}function mn(t){return Array.isArray(t)?mn(t[0]):t instanceof Float32Array?"float32":t instanceof Int32Array||t instanceof Uint8Array||t instanceof Uint8ClampedArray?"int32":"number"==typeof t?"float32":gn(t)?"string":function(t){return"boolean"==typeof t}(t)?"bool":"float32"}function yn(t){return!!(t&&t.constructor&&t.call&&t.apply)}function bn(t){const e=t.length;if(e<2)return[];const n=new Array(e-1);n[e-2]=t[e-1];for(let s=e-3;s>=0;--s)n[s]=n[s+1]*t[s+1];return n}function wn(t,e,n,s=!1){const i=new Array;if(1===e.length){const r=e[0]*(s?2:1);for(let e=0;et*e))*(s?2:1);for(let e=0;et*e))*(n?2:1);if(0===s)return[];if(s!==e.length)throw new Error(`[${t}] does not match the input size ${e.length}${n?" for a complex tensor":""}.`);return wn(0,t,e,n)}function vn(t,e){const n=Sn(t,e);for(let t=0;t{ln(Number.isInteger(e)&&e>=0,(()=>`Tensor must have a shape comprised of positive integers but got shape [${t}].`))}))}function Nn(t){return t&&t.then&&"function"==typeof t.then}class In{constructor(t){this.global=t,this.flags={},this.flagRegistry={},this.urlFlags={},this.getQueryParams=An,this.populateURLFlags()}setPlatform(t,e){null!=this.platform&&(Tn().getBool("IS_TEST")||Tn().getBool("PROD")||console.warn(`Platform ${this.platformName} has already been set. Overwriting the platform with ${t}.`)),this.platformName=t,this.platform=e}registerFlag(t,e,n){if(this.flagRegistry[t]={evaluationFn:e,setHook:n},null!=this.urlFlags[t]){const e=this.urlFlags[t];Tn().getBool("IS_TEST")||Tn().getBool("PROD")||console.warn(`Setting feature override from URL ${t}: ${e}.`),this.set(t,e)}}async getAsync(t){return t in this.flags||(this.flags[t]=await this.evaluateFlag(t)),this.flags[t]}get(t){if(t in this.flags)return this.flags[t];const e=this.evaluateFlag(t);if(Nn(e))throw new Error(`Flag ${t} cannot be synchronously evaluated. Please use getAsync() instead.`);return this.flags[t]=e,this.flags[t]}getNumber(t){return this.get(t)}getBool(t){return this.get(t)}getString(t){return this.get(t)}getFlags(){return this.flags}get features(){return this.flags}set(t,e){if(null==this.flagRegistry[t])throw new Error(`Cannot set flag ${t} as it has not been registered.`);this.flags[t]=e,null!=this.flagRegistry[t].setHook&&this.flagRegistry[t].setHook(e)}evaluateFlag(t){if(null==this.flagRegistry[t])throw new Error(`Cannot evaluate flag '${t}': no evaluation function found.`);return this.flagRegistry[t].evaluationFn()}setFlags(t){this.flags=Object.assign({},t)}reset(){this.flags={},this.urlFlags={},this.populateURLFlags()}populateURLFlags(){if("undefined"==typeof this.global||"undefined"==typeof this.global.location||"undefined"==typeof this.global.location.search)return;const t=this.getQueryParams(this.global.location.search);if("tfjsflags"in t){t.tfjsflags.split(",").forEach((t=>{const[e,n]=t.split(":");this.urlFlags[e]=function(t,e){const n=e.toLowerCase();return"true"===n||"false"===n?"true"===n:""+ +n===n?+n:e}(0,n)}))}}}function An(t){const e={};return t.replace(/[?&]([^=?&]+)(?:=([^&]*))?/g,((t,...n)=>(function(t,e,n){t[decodeURIComponent(e)]=decodeURIComponent(n||"")}(e,n[0],n[1]),n.join("=")))),e}function Tn(){return Cn}let En,Cn=null;function zn(){if(null==En){let t;if("undefined"!=typeof window)t=window;else if("undefined"!=typeof global)t=global;else if("undefined"!=typeof process)t=process;else{if("undefined"==typeof self)throw new Error("Could not find a global object");t=self}En=t}return En}function $n(t,e){const n=function(){const t=zn();return null==t._tfGlobals&&(t._tfGlobals=new Map),t._tfGlobals}();if(n.has(t))return n.get(t);{const s=e();return n.set(t,s),n.get(t)}}function Fn(...t){Tn().getBool("IS_TEST")||Tn().getBool("PROD")||console.warn(...t)}const Dn=$n("kernelRegistry",(()=>new Map)),Ln=$n("gradRegistry",(()=>new Map));function _n(t,e){const n=function(t,e){return`${e}_${t}`}(t,e);return Dn.get(n)}function Rn(t){return Ln.get(t)}function Mn(t){const e=Dn.entries(),n=[];for(;;){const{done:s,value:i}=e.next();if(s)break;const[r,a]=i,[o]=r.split("_");o===t&&n.push(a)}return n}function On(t){const{kernelName:e}=t;Ln.has(e)&&Tn().getBool("DEBUG")&&Fn(`Overriding the gradient for '${e}'`),Ln.set(e,t)}var Bn="undefined"!=typeof globalThis?globalThis:"undefined"!=typeof window?window:"undefined"!=typeof global?global:"undefined"!=typeof self?self:{};function Pn(t){return t&&t.__esModule&&Object.prototype.hasOwnProperty.call(t,"default")?t.default:t}function Un(t){if(t.__esModule)return t;var e=t.default;if("function"==typeof e){var n=function t(){if(this instanceof t){var n=[null];n.push.apply(n,arguments);var s=Function.bind.apply(e,n);return new s}return e.apply(this,arguments)};n.prototype=e.prototype}else n={};return Object.defineProperty(n,"__esModule",{value:!0}),Object.keys(t).forEach((function(e){var s=Object.getOwnPropertyDescriptor(t,e);Object.defineProperty(n,e,s.get?s:{enumerable:!0,get:function(){return t[e]}})})),n}var Wn=Vn,jn=null;try{jn=new WebAssembly.Instance(new WebAssembly.Module(new Uint8Array([0,97,115,109,1,0,0,0,1,13,2,96,0,1,127,96,4,127,127,127,127,1,127,3,7,6,0,1,1,1,1,1,6,6,1,127,1,65,0,11,7,50,6,3,109,117,108,0,1,5,100,105,118,95,115,0,2,5,100,105,118,95,117,0,3,5,114,101,109,95,115,0,4,5,114,101,109,95,117,0,5,8,103,101,116,95,104,105,103,104,0,0,10,191,1,6,4,0,35,0,11,36,1,1,126,32,0,173,32,1,173,66,32,134,132,32,2,173,32,3,173,66,32,134,132,126,34,4,66,32,135,167,36,0,32,4,167,11,36,1,1,126,32,0,173,32,1,173,66,32,134,132,32,2,173,32,3,173,66,32,134,132,127,34,4,66,32,135,167,36,0,32,4,167,11,36,1,1,126,32,0,173,32,1,173,66,32,134,132,32,2,173,32,3,173,66,32,134,132,128,34,4,66,32,135,167,36,0,32,4,167,11,36,1,1,126,32,0,173,32,1,173,66,32,134,132,32,2,173,32,3,173,66,32,134,132,129,34,4,66,32,135,167,36,0,32,4,167,11,36,1,1,126,32,0,173,32,1,173,66,32,134,132,32,2,173,32,3,173,66,32,134,132,130,34,4,66,32,135,167,36,0,32,4,167,11])),{}).exports}catch(t){}function Vn(t,e,n){this.low=0|t,this.high=0|e,this.unsigned=!!n}function qn(t){return!0===(t&&t.__isLong__)}Vn.prototype.__isLong__,Object.defineProperty(Vn.prototype,"__isLong__",{value:!0}),Vn.isLong=qn;var Kn={},Gn={};function Hn(t,e){var n,s,i;return e?(i=0<=(t>>>=0)&&t<256)&&(s=Gn[t])?s:(n=Zn(t,(0|t)<0?-1:0,!0),i&&(Gn[t]=n),n):(i=-128<=(t|=0)&&t<128)&&(s=Kn[t])?s:(n=Zn(t,t<0?-1:0,!1),i&&(Kn[t]=n),n)}function Jn(t,e){if(isNaN(t))return e?rs:is;if(e){if(t<0)return rs;if(t>=es)return hs}else{if(t<=-ns)return cs;if(t+1>=ns)return us}return t<0?Jn(-t,e).neg():Zn(t%ts|0,t/ts|0,e)}function Zn(t,e,n){return new Vn(t,e,n)}Vn.fromInt=Hn,Vn.fromNumber=Jn,Vn.fromBits=Zn;var Yn=Math.pow;function Xn(t,e,n){if(0===t.length)throw Error("empty string");if("NaN"===t||"Infinity"===t||"+Infinity"===t||"-Infinity"===t)return is;if("number"==typeof e?(n=e,e=!1):e=!!e,(n=n||10)<2||360)throw Error("interior hyphen");if(0===s)return Xn(t.substring(1),e,n).neg();for(var i=Jn(Yn(n,8)),r=is,a=0;a>>0:this.low},ps.toNumber=function(){return this.unsigned?(this.high>>>0)*ts+(this.low>>>0):this.high*ts+(this.low>>>0)},ps.toString=function(t){if((t=t||10)<2||36>>0).toString(t);if((r=o).isZero())return l+a;for(;l.length<6;)l="0"+l;a=""+l+a}},ps.getHighBits=function(){return this.high},ps.getHighBitsUnsigned=function(){return this.high>>>0},ps.getLowBits=function(){return this.low},ps.getLowBitsUnsigned=function(){return this.low>>>0},ps.getNumBitsAbs=function(){if(this.isNegative())return this.eq(cs)?64:this.neg().getNumBitsAbs();for(var t=0!=this.high?this.high:this.low,e=31;e>0&&0==(t&1<=0},ps.isOdd=function(){return 1==(1&this.low)},ps.isEven=function(){return 0==(1&this.low)},ps.equals=function(t){return qn(t)||(t=Qn(t)),(this.unsigned===t.unsigned||this.high>>>31!=1||t.high>>>31!=1)&&(this.high===t.high&&this.low===t.low)},ps.eq=ps.equals,ps.notEquals=function(t){return!this.eq(t)},ps.neq=ps.notEquals,ps.ne=ps.notEquals,ps.lessThan=function(t){return this.comp(t)<0},ps.lt=ps.lessThan,ps.lessThanOrEqual=function(t){return this.comp(t)<=0},ps.lte=ps.lessThanOrEqual,ps.le=ps.lessThanOrEqual,ps.greaterThan=function(t){return this.comp(t)>0},ps.gt=ps.greaterThan,ps.greaterThanOrEqual=function(t){return this.comp(t)>=0},ps.gte=ps.greaterThanOrEqual,ps.ge=ps.greaterThanOrEqual,ps.compare=function(t){if(qn(t)||(t=Qn(t)),this.eq(t))return 0;var e=this.isNegative(),n=t.isNegative();return e&&!n?-1:!e&&n?1:this.unsigned?t.high>>>0>this.high>>>0||t.high===this.high&&t.low>>>0>this.low>>>0?-1:1:this.sub(t).isNegative()?-1:1},ps.comp=ps.compare,ps.negate=function(){return!this.unsigned&&this.eq(cs)?cs:this.not().add(as)},ps.neg=ps.negate,ps.add=function(t){qn(t)||(t=Qn(t));var e=this.high>>>16,n=65535&this.high,s=this.low>>>16,i=65535&this.low,r=t.high>>>16,a=65535&t.high,o=t.low>>>16,l=0,u=0,h=0,c=0;return h+=(c+=i+(65535&t.low))>>>16,u+=(h+=s+o)>>>16,l+=(u+=n+a)>>>16,l+=e+r,Zn((h&=65535)<<16|(c&=65535),(l&=65535)<<16|(u&=65535),this.unsigned)},ps.subtract=function(t){return qn(t)||(t=Qn(t)),this.add(t.neg())},ps.sub=ps.subtract,ps.multiply=function(t){if(this.isZero())return is;if(qn(t)||(t=Qn(t)),jn)return Zn(jn.mul(this.low,this.high,t.low,t.high),jn.get_high(),this.unsigned);if(t.isZero())return is;if(this.eq(cs))return t.isOdd()?cs:is;if(t.eq(cs))return this.isOdd()?cs:is;if(this.isNegative())return t.isNegative()?this.neg().mul(t.neg()):this.neg().mul(t).neg();if(t.isNegative())return this.mul(t.neg()).neg();if(this.lt(ss)&&t.lt(ss))return Jn(this.toNumber()*t.toNumber(),this.unsigned);var e=this.high>>>16,n=65535&this.high,s=this.low>>>16,i=65535&this.low,r=t.high>>>16,a=65535&t.high,o=t.low>>>16,l=65535&t.low,u=0,h=0,c=0,p=0;return c+=(p+=i*l)>>>16,h+=(c+=s*l)>>>16,c&=65535,h+=(c+=i*o)>>>16,u+=(h+=n*l)>>>16,h&=65535,u+=(h+=s*o)>>>16,h&=65535,u+=(h+=i*a)>>>16,u+=e*l+n*o+s*a+i*r,Zn((c&=65535)<<16|(p&=65535),(u&=65535)<<16|(h&=65535),this.unsigned)},ps.mul=ps.multiply,ps.divide=function(t){if(qn(t)||(t=Qn(t)),t.isZero())throw Error("division by zero");var e,n,s;if(jn)return this.unsigned||-2147483648!==this.high||-1!==t.low||-1!==t.high?Zn((this.unsigned?jn.div_u:jn.div_s)(this.low,this.high,t.low,t.high),jn.get_high(),this.unsigned):this;if(this.isZero())return this.unsigned?rs:is;if(this.unsigned){if(t.unsigned||(t=t.toUnsigned()),t.gt(this))return rs;if(t.gt(this.shru(1)))return os;s=rs}else{if(this.eq(cs))return t.eq(as)||t.eq(ls)?cs:t.eq(cs)?as:(e=this.shr(1).div(t).shl(1)).eq(is)?t.isNegative()?as:ls:(n=this.sub(t.mul(e)),s=e.add(n.div(t)));if(t.eq(cs))return this.unsigned?rs:is;if(this.isNegative())return t.isNegative()?this.neg().div(t.neg()):this.neg().div(t).neg();if(t.isNegative())return this.div(t.neg()).neg();s=is}for(n=this;n.gte(t);){e=Math.max(1,Math.floor(n.toNumber()/t.toNumber()));for(var i=Math.ceil(Math.log(e)/Math.LN2),r=i<=48?1:Yn(2,i-48),a=Jn(e),o=a.mul(t);o.isNegative()||o.gt(n);)o=(a=Jn(e-=r,this.unsigned)).mul(t);a.isZero()&&(a=as),s=s.add(a),n=n.sub(o)}return s},ps.div=ps.divide,ps.modulo=function(t){return qn(t)||(t=Qn(t)),jn?Zn((this.unsigned?jn.rem_u:jn.rem_s)(this.low,this.high,t.low,t.high),jn.get_high(),this.unsigned):this.sub(this.div(t).mul(t))},ps.mod=ps.modulo,ps.rem=ps.modulo,ps.not=function(){return Zn(~this.low,~this.high,this.unsigned)},ps.and=function(t){return qn(t)||(t=Qn(t)),Zn(this.low&t.low,this.high&t.high,this.unsigned)},ps.or=function(t){return qn(t)||(t=Qn(t)),Zn(this.low|t.low,this.high|t.high,this.unsigned)},ps.xor=function(t){return qn(t)||(t=Qn(t)),Zn(this.low^t.low,this.high^t.high,this.unsigned)},ps.shiftLeft=function(t){return qn(t)&&(t=t.toInt()),0==(t&=63)?this:t<32?Zn(this.low<>>32-t,this.unsigned):Zn(0,this.low<>>t|this.high<<32-t,this.high>>t,this.unsigned):Zn(this.high>>t-32,this.high>=0?0:-1,this.unsigned)},ps.shr=ps.shiftRight,ps.shiftRightUnsigned=function(t){if(qn(t)&&(t=t.toInt()),0===(t&=63))return this;var e=this.high;return t<32?Zn(this.low>>>t|e<<32-t,e>>>t,this.unsigned):Zn(32===t?e:e>>>t-32,0,this.unsigned)},ps.shru=ps.shiftRightUnsigned,ps.shr_u=ps.shiftRightUnsigned,ps.toSigned=function(){return this.unsigned?Zn(this.low,this.high,!1):this},ps.toUnsigned=function(){return this.unsigned?this:Zn(this.low,this.high,!0)},ps.toBytes=function(t){return t?this.toBytesLE():this.toBytesBE()},ps.toBytesLE=function(){var t=this.high,e=this.low;return[255&e,e>>>8&255,e>>>16&255,e>>>24,255&t,t>>>8&255,t>>>16&255,t>>>24]},ps.toBytesBE=function(){var t=this.high,e=this.low;return[t>>>24,t>>>16&255,t>>>8&255,255&t,e>>>24,e>>>16&255,e>>>8&255,255&e]},Vn.fromBytes=function(t,e,n){return n?Vn.fromBytesLE(t,e):Vn.fromBytesBE(t,e)},Vn.fromBytesLE=function(t,e){return new Vn(t[0]|t[1]<<8|t[2]<<16|t[3]<<24,t[4]|t[5]<<8|t[6]<<16|t[7]<<24,e)},Vn.fromBytesBE=function(t,e){return new Vn(t[4]<<24|t[5]<<16|t[6]<<8|t[7],t[0]<<24|t[1]<<16|t[2]<<8|t[3],e)};var ds=Pn(Wn);const fs=ds||et({__proto__:null,default:ds},[Wn]);function gs(t){return fs.fromString(t,!0,16)}function ms(t,e){if("string"===e)throw new Error("Cannot convert a string[] to a TypedArray");if(Array.isArray(t)&&(t=ks(t)),Tn().getBool("DEBUG")&&function(t,e){for(let n=0;n{s=n()};let r;const a=ys();if(this.backendTimer.timerAvailable())r=this.backendTimer.time(i);else{i();for(const t of s)t.dataSync();r=Promise.resolve({kernelMs:ys()-a})}if(Tn().getBool("CHECK_COMPUTATION_FOR_ERRORS"))for(let e=0;e{Ss(e,n.dtype,t)}))}return{kernelName:t,outputs:s,inputs:e,timeMs:r.then((t=>t.kernelMs)),extraInfo:r.then((t=>null!=t.getExtraProfileInfo?t.getExtraProfileInfo():""))}}logKernelProfile(t){const{kernelName:e,outputs:n,timeMs:s,inputs:i,extraInfo:r}=t;n.forEach((t=>{Promise.all([t.data(),s,r]).then((n=>{this.logger.logKernelProfile(e,t,n[0],n[1],i,n[2])}))}))}}function Ss(t,e,n){if("float32"!==e)return!1;for(let e=0;e0?s:""} `}}console.log(`%c${o}\t%c${a}\t%c${l}D ${h}\t%c${u}\t%c${c}\t%c${r}`,"font-weight:bold","color:red","color:blue","color: orange","color: green","color: steelblue")}}function Ns(t,e,n,s){const i=bn(e),r=function(t,e,n,s){const i=un(e),r=s[s.length-1],a=new Array(r).fill(0),o=e.length,l="complex64"===n?Es(t):t;if(o>1)for(let t=0;t" "+t)).join("\n")),l.join("\n")}function Is(t,e,n){let s;return s=Array.isArray(t)?`${parseFloat(t[0].toFixed(7))} + ${parseFloat(t[1].toFixed(7))}j`:gn(t)?`'${t}'`:"bool"===n?As(t):parseFloat(t.toFixed(7)).toString(),pn(s,e)}function As(t){return 0===t?"false":"true"}function Ts(t,e,n,s,i,r=!0){const a="complex64"===n?2:1,o=e[0],l=e.length;if(0===l){if("complex64"===n){return[Is(Es(t)[0],0,n)]}return"bool"===n?[As(t[0])]:[t[0].toString()]}if(1===l){if(o>20){const e=3*a;let s=Array.from(t.slice(0,e)),r=Array.from(t.slice((o-3)*a,o*a));return"complex64"===n&&(s=Es(s),r=Es(r)),["["+s.map(((t,e)=>Is(t,i[e],n))).join(", ")+", ..., "+r.map(((t,e)=>Is(t,i[o-3+e],n))).join(", ")+"]"]}return["["+("complex64"===n?Es(t):Array.from(t)).map(((t,e)=>Is(t,i[e],n))).join(", ")+"]"]}const u=e.slice(1),h=s.slice(1),c=s[0]*a,p=[];if(o>20){for(let e=0;e<3;e++){const s=e*c,r=s+c;p.push(...Ts(t.slice(s,r),u,n,h,i,!1))}p.push("...");for(let e=o-3;e0?p[0]+d:"");for(let t=1;tbs(t)))}catch(t){throw new Error("Failed to decode the string bytes into utf-8. To get the original bytes, call tensor.bytes().")}}return t}dataToGPU(t){return this.throwIfDisposed(),Cs().readToGPU(this.dataId,t)}dataSync(){this.throwIfDisposed();const t=Cs().readSync(this.dataId);if("string"===this.dtype)try{return t.map((t=>bs(t)))}catch(t){throw new Error("Failed to decode the string bytes into utf-8. To get the original bytes, call tensor.bytes().")}return t}async bytes(){this.throwIfDisposed();const t=await Cs().read(this.dataId);return"string"===this.dtype?t:new Uint8Array(t.buffer)}dispose(){this.isDisposed||(this.kerasMask&&this.kerasMask.dispose(),Cs().disposeTensor(this),this.isDisposedInternal=!0)}get isDisposed(){return this.isDisposedInternal}throwIfDisposed(){if(this.isDisposed)throw new Error("Tensor is disposed.")}print(t=!1){return null.print(this,t)}clone(){return this.throwIfDisposed(),null.clone(this)}toString(t=!1){return Ns(this.dataSync(),this.shape,this.dtype,t)}cast(t){return this.throwIfDisposed(),null.cast(this,t)}variable(t=!0,e,n){return this.throwIfDisposed(),Cs().makeVariable(this,t,e,n)}}function $s(){return $n("Tensor",(()=>zs))}Object.defineProperty(zs,Symbol.hasInstance,{value:t=>!!t&&null!=t.data&&null!=t.dataSync&&null!=t.throwIfDisposed}),$s();class Fs extends zs{constructor(t,e,n,s){super(t.shape,t.dtype,t.dataId,s),this.trainable=e,this.name=n}assign(t){if(t.dtype!==this.dtype)throw new Error(`dtype of the new value (${t.dtype}) and previous value (${this.dtype}) must match`);if(!hn(t.shape,this.shape))throw new Error(`shape of the new value (${t.shape}) and previous value (${this.shape}) must match`);Cs().disposeTensor(this),this.dataId=t.dataId,Cs().incRef(this,null)}dispose(){Cs().disposeVariable(this),this.isDisposedInternal=!0}}var Ds,Ls,_s,Rs,Ms;Object.defineProperty(Fs,Symbol.hasInstance,{value:t=>t instanceof zs&&null!=t.assign&&t.assign instanceof Function}),function(t){t.R0="R0",t.R1="R1",t.R2="R2",t.R3="R3",t.R4="R4",t.R5="R5",t.R6="R6"}(Ds||(Ds={})),function(t){t.float32="float32",t.int32="int32",t.bool="int32",t.complex64="complex64"}(Ls||(Ls={})),function(t){t.float32="float32",t.int32="int32",t.bool="bool",t.complex64="complex64"}(_s||(_s={})),function(t){t.float32="float32",t.int32="float32",t.bool="float32",t.complex64="complex64"}(Rs||(Rs={})),function(t){t.float32="complex64",t.int32="complex64",t.bool="complex64",t.complex64="complex64"}(Ms||(Ms={}));const Os={float32:Rs,int32:Ls,bool:_s,complex64:Ms};function Bs(t){return null!=t&&"object"==typeof t&&"texture"in t&&t.texture instanceof WebGLTexture}function Ps(t){return"undefined"!=typeof GPUBuffer&&null!=t&&"object"==typeof t&&"buffer"in t&&t.buffer instanceof GPUBuffer}function Us(t,e){if(t.dtype===e.dtype)return[t,e];const n=function(t,e){if("string"===t||"string"===e){if("string"===t&&"string"===e)return"string";throw new Error(`Can not upcast ${t} with ${e}`)}return Os[t][e]}(t.dtype,e.dtype);return[t.cast(n),e.cast(n)]}function Ws(t){const e=[];return js(t,e,new Set),e}function js(t,e,n){if(null==t)return;if(t instanceof zs)return void e.push(t);if(s=t,!Array.isArray(s)&&"object"!=typeof s)return;var s;const i=t;for(const t in i){const s=i[t];n.has(s)||(n.add(s),js(s,e,n))}}function Vs(t){return null!=t.kernelName}class qs{constructor(){this.registeredVariables={},this.nextTapeNodeId=0,this.numBytes=0,this.numTensors=0,this.numStringTensors=0,this.numDataBuffers=0,this.gradientDepth=0,this.kernelDepth=0,this.scopeStack=[],this.numDataMovesStack=[],this.nextScopeId=0,this.tensorInfo=new WeakMap,this.profiling=!1,this.activeProfile={newBytes:0,newTensors:0,peakBytes:0,kernels:[],result:null,get kernelNames(){return Array.from(new Set(this.kernels.map((t=>t.name))))}}}dispose(){for(const t in this.registeredVariables)this.registeredVariables[t].dispose()}}class Ks{constructor(t){this.ENV=t,this.registry={},this.registryFactory={},this.pendingBackendInitId=0,this.state=new qs}async ready(){if(null!=this.pendingBackendInit)return this.pendingBackendInit.then((()=>{}));if(null!=this.backendInstance)return;const t=this.getSortedBackends();for(let e=0;e{null!=t.setupFunc&&t.setupFunc(this.backendInstance)}))}disposeRegisteredKernels(t){Mn(t).forEach((e=>{null!=e.disposeFunc&&e.disposeFunc(this.registry[t])}))}initializeBackend(t){const e=this.registryFactory[t];if(null==e)throw new Error(`Cannot initialize backend ${t}, no registration found.`);try{const n=e.factory();if(!n||n instanceof class{refCount(t){return on("refCount")}incRef(t){return on("incRef")}timerAvailable(){return!0}time(t){return on("time")}read(t){return on("read")}readSync(t){return on("readSync")}readToGPU(t,e){return on("readToGPU")}numDataIds(){return on("numDataIds")}disposeData(t,e){return on("disposeData")}write(t,e,n){return on("write")}move(t,e,n,s,i){return on("move")}createTensorFromGPUData(t,e,n){return on("createTensorFromGPUData")}memory(){return on("memory")}floatPrecision(){return on("floatPrecision")}epsilon(){return 32===this.floatPrecision()?1e-7:1e-4}dispose(){return on("dispose")}}||"function"!=typeof n.then)return this.registry[t]=n,{success:!0,asyncInit:!1};{const e=++this.pendingBackendInitId,s=n.then((n=>!(e(ethis.registryFactory[e].priority-this.registryFactory[t].priority))}initializeBackendsAndReturnBest(){const t=this.getSortedBackends();for(let e=0;ethis.startScope(s)),(()=>this.endScope(n)),(()=>(n=e(),n instanceof Promise&&console.error("Cannot return a Promise inside of tidy."),n)))}scopedRun(t,e,n){t();try{const t=n();return e(),t}catch(t){throw e(),t}}nextTensorId(){return Ks.nextTensorId++}nextVariableId(){return Ks.nextVariableId++}clone(t){const e=Gs.runKernel("Identity",{x:t}),n={x:t};return this.addTapeNode(this.state.activeScope.name,n,[e],(t=>({x:()=>{const e={x:t},n={dtype:"float32"};return Gs.runKernel("Cast",e,n)}})),[],{}),e}runKernel(t,e,n){null==this.backendName&&this.backend;if(!(null!=_n(t,this.backendName)))throw new Error(`Kernel '${t}' not registered for backend '${this.backendName}'`);return this.runKernelFunc({kernelName:t,inputs:e,attrs:n})}shouldCheckForMemLeaks(){return this.ENV.getBool("IS_TEST")}checkKernelForMemLeak(t,e,n){const s=this.backend.numDataIds();let i=0;n.forEach((t=>{i+="complex64"===t.dtype?3:1}));const r=this.state.numDataMovesStack[this.state.numDataMovesStack.length-1],a=s-e-i-r;if(a>0)throw new Error(`Backend '${this.backendName}' has an internal memory leak (${a} data ids) after running '${t}'`)}runKernelFunc(t){let e,n=[];const s=this.isTapeOn(),i=this.state.numBytes,r=this.state.numTensors;let a,o;this.shouldCheckForMemLeaks()&&this.state.numDataMovesStack.push(0),null==this.backendName&&this.backend;const l=Vs(t)?t.kernelName:null!=this.state.activeScope?this.state.activeScope.name:"";if(Vs(t)){const{kernelName:e,inputs:i,attrs:r}=t;null==this.backendName&&this.backend;const l=_n(e,this.backendName);ln(null!=l,(()=>`Cannot find registered kernel '${e}' for backend '${this.backendName}'`)),a=()=>{const t=this.backend.numDataIds();o=l.kernelFunc({inputs:i,attrs:r,backend:this.backend});const a=Array.isArray(o)?o:[o];this.shouldCheckForMemLeaks()&&this.checkKernelForMemLeak(e,t,a);const u=a.map((t=>null!=t.rank?t:this.makeTensorFromTensorInfo(t)));if(s){const t=this.getTensorsForGradient(e,i,u);n=this.saveTensorsForBackwardMode(t)}return u}}else{const{forwardFunc:e}=t,i=t=>{s&&(n=t.map((t=>this.keep(this.clone(t)))))};a=()=>{const t=this.backend.numDataIds();o=this.tidy((()=>e(this.backend,i)));const n=Array.isArray(o)?o:[o];return this.shouldCheckForMemLeaks()&&this.checkKernelForMemLeak(l,t,n),n}}const{inputs:u,attrs:h}=t,c=Vs(t)?null:t.backwardsFunc;let p;return this.scopedRun((()=>this.state.kernelDepth++),(()=>this.state.kernelDepth--),(()=>{this.ENV.getBool("DEBUG")||this.state.profiling?(p=this.profiler.profileKernel(l,u,(()=>a())),this.ENV.getBool("DEBUG")&&this.profiler.logKernelProfile(p),e=p.outputs):e=a()})),s&&this.addTapeNode(l,u,e,c,n,h),this.state.profiling&&this.state.activeProfile.kernels.push({name:l,bytesAdded:this.state.numBytes-i,totalBytesSnapshot:this.state.numBytes,tensorsAdded:this.state.numTensors-r,totalTensorsSnapshot:this.state.numTensors,inputShapes:Object.keys(u).map((t=>null!=u[t]?u[t].shape:null)),outputShapes:e.map((t=>t.shape)),kernelTimeMs:p.timeMs,extraInfo:p.extraInfo}),Array.isArray(o)?e:e[0]}saveTensorsForBackwardMode(t){const e=t.map((t=>this.keep(this.clone(t))));return e}getTensorsForGradient(t,e,n){const s=Rn(t);if(null!=s){const t=s.inputsToSave||[],i=s.outputsToSave||[];let r;s.saveAllInputs?(ln(Array.isArray(e),(()=>"saveAllInputs is true, expected inputs to be an array.")),r=Object.keys(e).map((t=>e[t]))):r=t.map((t=>e[t]));const a=n.filter(((t,e)=>i[e]));return r.concat(a)}return[]}makeTensor(t,e,n,s){if(null==t)throw new Error("Values passed to engine.makeTensor() are null");n=n||"float32",s=s||this.backend;let i=t;"string"===n&&gn(t[0])&&(i=t.map((t=>function(t,e="utf-8"){return e=e||"utf-8",Tn().platform.encode(t,e)}(t))));const r=s.write(i,e,n),a=new zs(e,n,r,this.nextTensorId());if(this.trackTensor(a,s),"string"===n){const t=this.state.tensorInfo.get(r),e=function(t){if(null==t)return 0;let e=0;return t.forEach((t=>e+=t.length)),e}(i);this.state.numBytes+=e-t.bytes,t.bytes=e}return a}makeTensorFromDataId(t,e,n,s){const i={dataId:t,shape:e,dtype:n=n||"float32"};return this.makeTensorFromTensorInfo(i,s)}makeTensorFromTensorInfo(t,e){const{dataId:n,shape:s,dtype:i}=t,r=new zs(s,i,n,this.nextTensorId());return this.trackTensor(r,e),r}makeVariable(t,e=!0,n,s){n=n||this.nextVariableId().toString(),null!=s&&s!==t.dtype&&(t=t.cast(s));const i=new Fs(t,e,n,this.nextTensorId());if(null!=this.state.registeredVariables[i.name])throw new Error(`Variable with name ${i.name} was already registered`);return this.state.registeredVariables[i.name]=i,this.incRef(i,this.backend),i}trackTensor(t,e){this.state.numTensors++,"string"===t.dtype&&this.state.numStringTensors++;let n=0;"complex64"!==t.dtype&&"string"!==t.dtype&&(n=t.size*fn(t.dtype)),this.state.numBytes+=n,this.state.tensorInfo.has(t.dataId)||(this.state.numDataBuffers++,this.state.tensorInfo.set(t.dataId,{backend:e||this.backend,dtype:t.dtype,shape:t.shape,bytes:n})),t instanceof Fs||this.track(t)}incRef(t,e){this.trackTensor(t,e),this.backend.incRef(t.dataId)}removeDataId(t,e){this.state.tensorInfo.has(t)&&this.state.tensorInfo.get(t).backend===e&&(this.state.tensorInfo.delete(t),this.state.numDataBuffers--)}disposeTensor(t){if(!this.state.tensorInfo.has(t.dataId))return;const e=this.state.tensorInfo.get(t.dataId);if(this.state.numTensors--,"string"===t.dtype&&(this.state.numStringTensors--,this.state.numBytes-=e.bytes),"complex64"!==t.dtype&&"string"!==t.dtype){const e=t.size*fn(t.dtype);this.state.numBytes-=e}e.backend.disposeData(t.dataId)&&this.removeDataId(t.dataId,e.backend)}disposeVariables(){for(const t in this.state.registeredVariables){const e=this.state.registeredVariables[t];this.disposeVariable(e)}}disposeVariable(t){this.disposeTensor(t),null!=this.state.registeredVariables[t.name]&&delete this.state.registeredVariables[t.name]}memory(){const t=this.backend.memory();return t.numTensors=this.state.numTensors,t.numDataBuffers=this.state.numDataBuffers,t.numBytes=this.state.numBytes,this.state.numStringTensors>0&&(t.unreliable=!0,null==t.reasons&&(t.reasons=[]),t.reasons.push("Memory usage by string tensors is approximate (2 bytes per character)")),t}async profile(t){this.state.profiling=!0;const e=this.state.numBytes,n=this.state.numTensors;this.state.activeProfile.kernels=[],this.state.activeProfile.result=await t(),this.state.profiling=!1,this.state.activeProfile.peakBytes=Math.max(...this.state.activeProfile.kernels.map((t=>t.totalBytesSnapshot))),this.state.activeProfile.newBytes=this.state.numBytes-e,this.state.activeProfile.newTensors=this.state.numTensors-n;for(const t of this.state.activeProfile.kernels)t.kernelTimeMs=await t.kernelTimeMs,t.extraInfo=await t.extraInfo;return this.state.activeProfile}isTapeOn(){return this.state.gradientDepth>0&&0===this.state.kernelDepth}addTapeNode(t,e,n,s,i,r){const a={id:this.state.nextTapeNodeId++,kernelName:t,inputs:e,outputs:n,saved:i},o=Rn(t);null!=o&&(s=o.gradFunc),null!=s&&(a.gradient=t=>(t=t.map(((t,e)=>{if(null==t){const t=n[e],s=Sn(t.size,t.dtype);return this.makeTensor(s,t.shape,t.dtype)}return t})),s(t.length>1?t:t[0],i,r))),this.state.activeTape.push(a)}keep(t){return t.kept=!0,t}startTape(){0===this.state.gradientDepth&&(this.state.activeTape=[]),this.state.gradientDepth++}endTape(){this.state.gradientDepth--}startScope(t){const e={track:[],name:"unnamed scope",id:this.state.nextScopeId++};t&&(e.name=t),this.state.scopeStack.push(e),this.state.activeScope=e}endScope(t){const e=Ws(t),n=new Set(e.map((t=>t.id)));for(let t=0;t{t.kept||t.scopeId!==s.id||this.track(t)}))}gradients(t,e,n,s=!1){if(ln(e.length>0,(()=>"gradients() received an empty list of xs.")),null!=n&&"float32"!==n.dtype)throw new Error(`dy must have 'float32' dtype, but has '${n.dtype}'`);const i=this.scopedRun((()=>this.startTape()),(()=>this.endTape()),(()=>this.tidy("forward",t)));ln(i instanceof zs,(()=>"The result y returned by f() must be a tensor."));const r=function(t,e,n){const s={},i={};for(let t=0;ts[t.id]=!0)),o=!0,i[r.id]=!0;break}if(o)break}}const r={};r[n.id]=!0;const a={};for(let e=t.length-1;e>=0;e--){const n=t[e],s=n.inputs;for(let t=0;t0)throw new Error("Cannot compute gradient of y=f(x) with respect to x. Make sure that the f you passed encloses all operations that lead from x to y.");return this.tidy("backward",(()=>{const t={};t[i.id]=null==n?function(t){const e=vn(un(t),"float32");return Gs.makeTensor(e,t,"float32")}(i.shape):n,function(t,e,n,s){for(let i=e.length-1;i>=0;i--){const r=e[i],a=[];if(r.outputs.forEach((e=>{const n=t[e.id];null!=n?a.push(n):a.push(null)})),null==r.gradient)throw new Error(`Cannot compute gradient: gradient function not found for ${r.kernelName}.`);const o=r.gradient(a);for(const e in r.inputs){if(!(e in o))throw new Error(`Cannot backprop through input ${e}. Available gradients found: ${Object.keys(o)}.`);const i=n((()=>o[e]()));if("float32"!==i.dtype)throw new Error(`Error in gradient for op ${r.kernelName}. The gradient of input ${e} must have 'float32' dtype, but has '${i.dtype}'`);const a=r.inputs[e];if(!hn(i.shape,a.shape))throw new Error(`Error in gradient for op ${r.kernelName}. The gradient of input '${e}' has shape '${i.shape}', which does not match the shape of the input '${a.shape}'`);if(null==t[a.id])t[a.id]=i;else{const e=t[a.id];t[a.id]=s(e,i),e.dispose()}}}}(t,r,(t=>this.tidy(t)),Hs);const s=e.map((e=>t[e.id]));return 0===this.state.gradientDepth&&(this.state.activeTape.forEach((t=>{for(const e of t.saved)e.dispose()})),this.state.activeTape=null),{value:i,grads:s}}))}customGrad(t){return ln(yn(t),(()=>"The f passed in customGrad(f) must be a function.")),(...e)=>{let n;ln(e.every((t=>t instanceof zs)),(()=>"The args passed in customGrad(f)(x1, x2,...) must all be tensors"));const s={};e.forEach(((t,e)=>{s[e]=t}));return this.runKernelFunc({forwardFunc:(s,i)=>(n=t(...e,i),ln(n.value instanceof zs,(()=>"The function f passed in customGrad(f) must return an object where `obj.value` is a tensor")),ln(yn(n.gradFunc),(()=>"The function f passed in customGrad(f) must return an object where `obj.gradFunc` is a function.")),n.value),backwardsFunc:(t,s)=>{const i=n.gradFunc(t,s),r=Array.isArray(i)?i:[i];ln(r.length===e.length,(()=>"The function f passed in customGrad(f) must return an object where `obj.gradFunc` is a function that returns the same number of tensors as inputs passed to f(...).")),ln(r.every((t=>t instanceof zs)),(()=>"The function f passed in customGrad(f) must return an object where `obj.gradFunc` is a function that returns a list of only tensors."));const a={};return r.forEach(((t,e)=>{a[e]=()=>t})),a},inputs:s})}}readSync(t){return this.state.tensorInfo.get(t).backend.readSync(t)}read(t){return this.state.tensorInfo.get(t).backend.read(t)}readToGPU(t,e){return this.state.tensorInfo.get(t).backend.readToGPU(t,e)}async time(t){const e=ys(),n=await this.backend.time(t);return n.wallMs=ys()-e,n}track(t){return null!=this.state.activeScope&&(t.scopeId=this.state.activeScope.id,this.state.activeScope.track.push(t)),t}get registeredVariables(){return this.state.registeredVariables}reset(){this.pendingBackendInitId++,this.state.dispose(),this.ENV.reset(),this.state=new qs;for(const t in this.registry)this.disposeRegisteredKernels(t),this.registry[t].dispose(),delete this.registry[t];this.backendName=null,this.backendInstance=null,this.pendingBackendInit=null}}Ks.nextTensorId=0,Ks.nextVariableId=0;const Gs=function(){const t=zn();if(null==t._tfengine){const e=new In(t);t._tfengine=new Ks(e)}var e;return e=t._tfengine.ENV,Cn=e,Cs=()=>t._tfengine,t._tfengine}();function Hs(t,e){const n={a:t,b:e};return Gs.runKernel("Add",n)}function Js(t,e){let n=t;if(ws(t))return"string"===e?[]:[t.length];if(Bs(t)){const e=t.channels||"RGBA";return[t.height,t.width*e.length]}if(Ps(t))return[t.buffer.size/(null==e?4:fn(e))];if(!Array.isArray(t))return[];const s=[];for(;Array.isArray(n)||ws(n)&&"string"!==e;)s.push(n.length),n=n[0];return Array.isArray(t)&&Tn().getBool("TENSORLIKE_CHECK_SHAPE_CONSISTENCY")&&Zs(t,s,[]),s}function Zs(t,e,n){if(n=n||[],!Array.isArray(t)&&!ws(t))return void ln(0===e.length,(()=>`Element arr[${n.join("][")}] is a primitive, but should be an array/TypedArray of ${e[0]} elements`));ln(e.length>0,(()=>`Element arr[${n.join("][")}] should be a primitive, but is an array of ${t.length} elements`)),ln(t.length===e[0],(()=>`Element arr[${n.join("][")}] should have ${e[0]} elements, but has ${t.length} elements`));const s=e.slice(1);for(let e=0;e=0&&(i=s),Ys(s,i,e,n),null==t||!ws(t)&&!Array.isArray(t)&&"number"!=typeof t&&"boolean"!=typeof t&&"string"!=typeof t){const s=null==t?"null":t.constructor.name;throw new Error(`Argument '${e}' passed to '${n}' must be a Tensor or TensorLike, but got '${s}'`)}const r=Js(t,i);ws(t)||Array.isArray(t)||(t=[t]);const a="string"!==i?ms(t,i):ks(t,[],!0);return Gs.makeTensor(a,r,i)}function Qs(t,e,n,s="numeric"){if(!Array.isArray(t))throw new Error(`Argument ${e} passed to ${n} must be a \`Tensor[]\` or \`TensorLike[]\``);return t.map(((t,i)=>Xs(t,`${e}[${i}]`,n,s)))}function ti(t){const e=Object.keys(t);if(1!==e.length)throw new Error(`Please provide an object with a single key (operation name) mapping to a function. Got an object with ${e.length} keys.`);let n=e[0];const s=t[n];n.endsWith("_")&&(n=n.substring(0,n.length-1)),n+="__op";const i=(...t)=>{Gs.startScope(n);try{const e=s(...t);return Nn(e)&&console.error("Cannot return a Promise inside of tidy."),Gs.endScope(e),e}catch(t){throw Gs.endScope(null),t}};return Object.defineProperty(i,"name",{value:n,configurable:!0}),i}const ei=ti({cast_:function(t,e){const n=Xs(t,"x","cast");if(!function(t){return"bool"===t||"complex64"===t||"float32"===t||"int32"===t||"string"===t}(e))throw new Error(`Failed to cast to unknown dtype ${e}`);if("string"===e&&"string"!==n.dtype||"string"!==e&&"string"===n.dtype)throw new Error("Only strings can be casted to strings");const s={x:n},i={dtype:e};return Gs.runKernel("Cast",s,i)}});const ni=ti({mul_:function(t,e){let n=Xs(t,"a","mul"),s=Xs(e,"b","mul");[n,s]=Us(n,s);const i={a:n,b:s};return Gs.runKernel("Multiply",i)}});const si=ti({step_:function(t,e=0){const n={x:Xs(t,"x","step")},s={alpha:e};return Gs.runKernel("Step",n,s)}}),ii={kernelName:"Abs",inputsToSave:["x"],gradFunc:(t,e)=>{const[n]=e;return{x:()=>ni(t,si(ei(n,"float32"),-1))}}};const ri=ti({floorDiv_:function(t,e){let n=Xs(t,"a","floorDiv"),s=Xs(e,"b","floorDiv");[n,s]=Us(n,s);const i={a:n,b:s};return Gs.runKernel("FloorDiv",i)}});const ai=ti({div_:function(t,e){let n=Xs(t,"a","div"),s=Xs(e,"b","div");if([n,s]=Us(n,s),"int32"===n.dtype&&"int32"===s.dtype)return ri(n,s);const i={a:n,b:s};return Gs.runKernel("RealDiv",i,{})}});const oi=ti({neg_:function(t){const e={x:Xs(t,"x","neg")};return Gs.runKernel("Neg",e)}});function li(t,e){if((ws(t)&&"string"!==e||Array.isArray(t))&&"complex64"!==e)throw new Error("Error creating a new Scalar: value must be a primitive (number|boolean|string)");if("string"===e&&ws(t)&&!(t instanceof Uint8Array))throw new Error("When making a scalar from encoded string, the value must be `Uint8Array`.");return function(t,e,n,s){if(null==s)s=mn(t);else if("complex64"===s)throw new Error("Cannot construct a complex64 tensor directly. Please use tf.complex(real, imag).");if(Ps(t)||Bs(t)){if("float32"!==s&&"int32"!==s)throw new Error(`Creating tensor from GPU data only supports 'float32'|'int32' dtype, while the dtype is ${s}.`);return Gs.backend.createTensorFromGPUData(t,e||n,s)}if(!ws(t)&&!Array.isArray(t)&&"number"!=typeof t&&"boolean"!=typeof t&&"string"!=typeof t)throw new Error("values passed to tensor(values) must be a number/boolean/string or an array of numbers/booleans/strings, or a TypedArray");if(null!=e){xn(e);const t=un(e),s=un(n);ln(t===s,(()=>`Based on the provided shape, [${e}], the tensor should have ${t} values but has ${s}`));for(let t=0;t`Error creating a new Tensor. Inferred shape (${n}) does not match the provided shape (${e}). `))}}return ws(t)||Array.isArray(t)||(t=[t]),e=e||n,t="string"!==s?ms(t,s):ks(t,[],!0),Gs.makeTensor(t,e,s)}(t,[],[],e)}const ui=ti({sqrt_:function(t){const e={x:Xs(t,"x","sqrt","float32")};return Gs.runKernel("Sqrt",e)}});const hi=ti({square_:function(t){const e=Xs(t,"x","square");return Gs.runKernel("Square",{x:e},{})}});const ci=ti({sub_:function(t,e){let n=Xs(t,"a","sub"),s=Xs(e,"b","sub");[n,s]=Us(n,s);const i={a:n,b:s};return Gs.runKernel("Sub",i)}}),pi={kernelName:"Acos",inputsToSave:["x"],gradFunc:(t,e)=>{const[n]=e;return{x:()=>{const e=hi(ei(n,"float32")),s=ui(ci(li(1),e));return oi(ai(t,s))}}}},di={kernelName:"Acosh",inputsToSave:["x"],gradFunc:(t,e)=>{const[n]=e;return{x:()=>{const e=ui(ci(hi(ei(n,"float32")),1));return ai(t,e)}}}};function fi(t,e){const n=[];for(let s=0;s1)&&n.unshift(r)}return n}function gi(t,e){const n=Math.max(t.length,e.length),s=new Array(n);for(let i=0;i{const[n,s]=e,i=gi(n.shape,s.shape);return{a:()=>{let e=t;const s=fi(n.shape,i);return s.length>0&&(e=yi(e,s)),mi(e,n.shape)},b:()=>{let e=t;const n=fi(s.shape,i);return n.length>0&&(e=yi(e,n)),mi(e,s.shape)}}}},wi={kernelName:"AddN",saveAllInputs:!0,gradFunc:(t,e)=>{const n={};return e.forEach(((e,s)=>{n[s]=()=>t.clone()})),n}};const ki=ti({zerosLike_:function(t){const e={x:Xs(t,"x","zerosLike")};return Gs.runKernel("ZerosLike",e)}}),vi={kernelName:"ArgMax",inputsToSave:["x"],gradFunc:(t,e)=>{const[n]=e;return{x:()=>ki(n)}}},Si={kernelName:"ArgMin",inputsToSave:["x"],gradFunc:(t,e)=>{const[n]=e;return{x:()=>ki(n)}}},xi={kernelName:"Asin",inputsToSave:["x"],gradFunc:(t,e)=>{const[n]=e;return{x:()=>ai(t,ui(ci(li(1),hi(ei(n,"float32")))))}}};const Ni=ti({add_:function(t,e){let n=Xs(t,"a","add"),s=Xs(e,"b","add");[n,s]=Us(n,s);const i={a:n,b:s};return Gs.runKernel("Add",i)}}),Ii={kernelName:"Asinh",inputsToSave:["x"],gradFunc:(t,e)=>{const[n]=e;return{x:()=>{const e=ui(Ni(li(1),hi(ei(n,"float32"))));return ai(t,e)}}}},Ai={kernelName:"Atan2",inputsToSave:["a","b"],gradFunc:(t,e)=>{const[n,s]=e,i=gi(n.shape,s.shape);return{a:()=>{const e=Ni(hi(n),hi(s));let r=ni(t,ai(s,e));const a=fi(n.shape,i);return a.length>0&&(r=yi(r,a)),mi(r,n.shape)},b:()=>{const e=Ni(hi(n),hi(s));let r=oi(ni(t,ai(n,e)));const a=fi(s.shape,i);return a.length>0&&(r=yi(r,a)),mi(r,s.shape)}}}},Ti={kernelName:"Atan",inputsToSave:["x"],gradFunc:(t,e)=>{const[n]=e;return{x:()=>ai(t,Ni(hi(ei(n,"float32")),1))}}},Ei={kernelName:"Atanh",inputsToSave:["x"],gradFunc:(t,e)=>{const[n]=e;return{x:()=>ai(t,ci(li(1),hi(ei(n,"float32"))))}}};function Ci(t){return"number"==typeof t?[t,t,t]:2===t.length?[t[0],t[1],1]:t}function zi(t){const[e,n,s]=Ci(t);return 1===e&&1===n&&1===s}function $i(t,e){return zi(t)||zi(e)}function Fi(t){return Ci(t).every((t=>t>0))}function Di(t,e,n){if(null!=n){if("string"==typeof e)throw Error(`Error in ${t}: pad must be an integer when using dimRoundingMode ${n} but got pad ${e}.`);if("number"==typeof e)ln(cn(e),(()=>`Error in ${t}: pad must be an integer when using dimRoundingMode ${n} but got pad ${e}.`));else{if("object"!=typeof e)throw Error(`Error in ${t}: Unknown padding parameter: ${e}`);e.forEach((e=>{e.forEach((e=>{ln(cn(e),(()=>`Error in ${t}: pad must be an integer when using dimRoundingMode ${n} but got pad ${e}.`))}))}))}}}const Li=ti({avgPool3dGrad_:function(t,e,n,s,i,r){const a=Xs(t,"dy","avgPool3dGrad"),o=Xs(e,"input","avgPool3dGrad");let l=a,u=o,h=!1;4===o.rank&&(h=!0,l=mi(a,[1,a.shape[0],a.shape[1],a.shape[2],a.shape[3]]),u=mi(o,[1,o.shape[0],o.shape[1],o.shape[2],o.shape[3]])),ln(5===l.rank,(()=>`Error in avgPool3dGrad: dy must be rank 5 but got rank ${l.rank}.`)),ln(5===u.rank,(()=>`Error in avgPool3dGrad: input must be rank 5 but got rank ${u.rank}.`)),Di("avgPool3dGrad",i,r);const c={dy:l,input:u},p={filterSize:n,strides:s,pad:i,dimRoundingMode:r},d=Gs.runKernel("AvgPool3DGrad",c,p);return h?mi(d,[d.shape[1],d.shape[2],d.shape[3],d.shape[4]]):d}}),_i={kernelName:"AvgPool3D",inputsToSave:["x"],gradFunc:(t,e,n)=>{const[s]=e,{filterSize:i,strides:r,pad:a,dimRoundingMode:o}=n;return{x:()=>Li(t,s,i,r,a,o)}}};const Ri=ti({avgPoolGrad_:function(t,e,n,s,i){const r=Xs(t,"dy","avgPoolGrad"),a=Xs(e,"input","avgPoolGrad");ln(a.rank===r.rank,(()=>`Rank of input (${a.rank}) does not match rank of dy (${r.rank})`));let o=a,l=r,u=!1;3===a.rank&&(u=!0,o=mi(a,[1,a.shape[0],a.shape[1],a.shape[2]]),l=mi(r,[1,r.shape[0],r.shape[1],r.shape[2]])),ln(4===l.rank,(()=>`Error in avgPoolGrad: dy must be rank 4 but got rank ${l.rank}.`)),ln(4===o.rank,(()=>`Error in avgPoolGrad: input must be rank 4 but got rank ${o.rank}.`));const h={dy:l,input:o},c={filterSize:n,strides:s,pad:i},p=Gs.runKernel("AvgPoolGrad",h,c);return u?mi(p,[p.shape[1],p.shape[2],p.shape[3]]):p}}),Mi={kernelName:"AvgPool",inputsToSave:["x"],gradFunc:(t,e,n)=>{const[s]=e,{filterSize:i,strides:r,pad:a}=n;return{x:()=>Ri(t,s,i,r,a)}}};const Oi=ti({matMul_:function(t,e,n=!1,s=!1){let i=Xs(t,"a","matMul"),r=Xs(e,"b","matMul");[i,r]=Us(i,r);const a={a:i,b:r},o={transposeA:n,transposeB:s};return Gs.runKernel("BatchMatMul",a,o)}}),Bi={kernelName:"BatchMatMul",inputsToSave:["a","b"],gradFunc:(t,e,n)=>{const[s,i]=e,{transposeA:r,transposeB:a}=n;return r||a?!r&&a?{a:()=>Oi(t,i,!1,!1),b:()=>Oi(t,s,!0,!1)}:r&&!a?{a:()=>Oi(i,t,!1,!0),b:()=>Oi(s,t,!1,!1)}:{a:()=>Oi(i,t,!0,!0),b:()=>Oi(t,s,!0,!0)}:{a:()=>Oi(t,i,!1,!0),b:()=>Oi(s,t,!0,!1)}}};const Pi=ti({spaceToBatchND_:function(t,e,n){const s=Xs(t,"x","spaceToBatchND");ln(s.rank>=1+e.length,(()=>`input rank ${s.rank} should be > than [blockShape] ${e.length}`)),ln(n.length===e.length,(()=>`paddings.shape[0] ${n.length} must be equal to [blockShape] ${e.length}`)),ln(s.shape.reduce(((t,s,i)=>i>0&&i<=e.length?t&&(s+n[i-1][0]+n[i-1][1])%e[i-1]==0:t),!0),(()=>`input spatial dimensions ${s.shape.slice(1)} with paddings ${n.toString()} must be divisible by blockShapes ${e.toString()}`));const i={x:s},r={blockShape:e,paddings:n};return Gs.runKernel("SpaceToBatchND",i,r)}}),Ui={kernelName:"BatchToSpaceND",gradFunc:(t,e,n)=>{const{blockShape:s,crops:i}=n;return{x:()=>Pi(t,s,i)}}},Wi={kernelName:"BroadcastTo",gradFunc:(t,e,n)=>{const s=n,i=s.inputShape,r=s.shape,a=Array.from(r);for(let t=i.length-1;t>=0;t--)if(i[t]===r[t])a[t]=1;else if(1!==i[t])throw new Error(`broadcastTo(): [${i}] cannot be broadcast to [${r}].`);const o=[];for(let t=0;t1&&o.push(t);return{x:()=>yi(t,o,!0)}}},ji={kernelName:"Cast",gradFunc:t=>({x:()=>t.clone()})},Vi={kernelName:"Ceil",gradFunc:t=>({x:()=>ki(t)})};const qi=ti({greaterEqual_:function(t,e){let n=Xs(t,"a","greaterEqual","string_or_numeric"),s=Xs(e,"b","greaterEqual","string_or_numeric");[n,s]=Us(n,s),gi(n.shape,s.shape);const i={a:n,b:s};return Gs.runKernel("GreaterEqual",i)}});const Ki=ti({lessEqual_:function(t,e){let n=Xs(t,"a","lessEqual","string_or_numeric"),s=Xs(e,"b","lessEqual","string_or_numeric");[n,s]=Us(n,s),gi(n.shape,s.shape);const i={a:n,b:s};return Gs.runKernel("LessEqual",i)}});const Gi=ti({logicalAnd_:function(t,e){const n=Xs(t,"a","logicalAnd","bool"),s=Xs(e,"b","logicalAnd","bool");gi(n.shape,s.shape);const i={a:n,b:s};return Gs.runKernel("LogicalAnd",i)}});const Hi=ti({clone_:function(t){const e={x:Xs(t,"x","clone","string_or_numeric")};return Gs.runKernel("Identity",e)}});const Ji=ti({broadcastTo_:function(t,e){let n=Xs(t,"broadcastTo","x");const s=n.shape;if(xn(e),e.lengthn.rank){const t=n.shape.slice();for(;t.length=0;t--)if(i[t]===e[t])r[t]=1;else if(1!==n.shape[t])throw new Error(`broadcastTo(): [${s}] cannot be broadcast to [${e}].`);if(0===r.map(((t,e)=>t>1?e:-1)).filter((t=>t>=0)).length)return Hi(n);const a={x:n},o={reps:r};return Gs.runKernel("Tile",a,o)}});const Zi=ti({where_:function(t,e,n){const s=Xs(e,"a","where"),i=Xs(n,"b","where"),r=Xs(t,"condition","where","bool"),a=gi(gi(r.shape,s.shape),i.shape),o={condition:Ji(r,a),t:Ji(s,a),e:Ji(i,a)};return Gs.runKernel("Select",o)}}),Yi={kernelName:"ClipByValue",inputsToSave:["x"],gradFunc:(t,e,n)=>{const[s]=e,{clipValueMin:i,clipValueMax:r}=n;return{x:()=>Zi(Gi(qi(s,i),Ki(s,r)),t,ki(t))}}},Xi={kernelName:"ComplexAbs",inputsToSave:["x"],gradFunc:ii.gradFunc};const Qi=ti({split_:function(t,e,n=0){const s={x:Xs(t,"x","split")},i={numOrSizeSplits:e,axis:n};return Gs.runKernel("SplitV",s,i)}}),tr={kernelName:"Concat",saveAllInputs:!0,gradFunc:(t,e,n)=>{const s=e.map((t=>t.shape)),{axis:i}=n,r=dn(i,e[0].shape)[0],a=s.map((t=>t[r]));return Qi(t,a,r).map((t=>()=>t))}};const er=ti({conv2DBackpropFilter_:function(t,e,n,s,i,r="NHWC",a){let o=t;3===t.rank&&(o=mi(t,[1,t.shape[0],t.shape[1],t.shape[2]]));let l=e;3===l.rank&&(l=mi(e,[1,e.shape[0],e.shape[1],e.shape[2]])),ln(4===o.rank,(()=>`Error in conv2dDerFilter: input must be rank 4, but got shape ${o.shape}.`)),ln(4===l.rank,(()=>`Error in conv2dDerFilter: dy must be rank 4, but got shape ${l.shape}.`)),ln(4===n.length,(()=>`Error in conv2dDerFilter: filterShape must be length 4, but got ${n}.`));const u="NHWC"===r?o.shape[3]:o.shape[1],h="NHWC"===r?l.shape[3]:l.shape[1];ln(u===n[2],(()=>`Error in conv2dDerFilter: depth of input ${u}) must match input depth in filter (${n[2]}.`)),ln(h===n[3],(()=>`Error in conv2dDerFilter: depth of dy (${h}) must match output depth for filter (${n[3]}).`)),Di("conv2dDerFilter",i,a);const c={x:o,dy:l},p={strides:s,pad:i,dataFormat:r,dimRoundingMode:a,filterShape:n};return Gs.runKernel("Conv2DBackpropFilter",c,p)}});const nr=ti({conv2DBackpropInput_:function(t,e,n,s,i,r="NHWC",a){ln(t.length===e.rank,(()=>`Length of inShape (${t.length}) and rank of dy (${e.rank}) must match`));let o=t,l=e,u=!1;3===e.rank&&(u=!0,l=mi(e,[1,e.shape[0],e.shape[1],e.shape[2]]),o=[1,t[0],t[1],t[2]]),ln(4===o.length,(()=>`Error in conv2dDerInput: inShape must be length 4, but got length ${o.length}.`)),ln(4===l.rank,(()=>`Error in conv2dDerInput: dy must be rank 4, but got rank ${l.rank}`)),ln(4===n.rank,(()=>`Error in conv2dDerInput: filter must be rank 4, but got rank ${n.rank}`));const h="NHWC"===r?o[3]:o[1],c="NHWC"===r?l.shape[3]:l.shape[1];ln(h===n.shape[2],(()=>`Error in conv2dDerInput: depth of input (${h}) must match input depth for filter ${n.shape[2]}.`)),ln(c===n.shape[3],(()=>`Error in conv2dDerInput: depth of output (${c}) must match output depth for filter ${n.shape[3]}.`)),Di("conv2dDerInput",i,a);const p={dy:l,filter:n},d={strides:s,pad:i,dataFormat:r,dimRoundingMode:a,inputShape:o},f=Gs.runKernel("Conv2DBackpropInput",p,d);return u?mi(f,[f.shape[1],f.shape[2],f.shape[3]]):f}}),sr={kernelName:"Conv2D",inputsToSave:["x","filter"],gradFunc:(t,e,n)=>{const[s,i]=e,{dilations:r,strides:a,pad:o,dataFormat:l}=n;return ln(zi(r),(()=>`Error in gradient of conv2D: dilation rates greater than 1 are not yet supported in gradients. Got dilations '${r}'`)),{x:()=>nr(s.shape,t,i,a,o,l),filter:()=>er(s,t,i.shape,a,o,l)}}};const ir=ti({conv2d_:function(t,e,n,s,i="NHWC",r=[1,1],a){const o=Xs(t,"x","conv2d","float32"),l=Xs(e,"filter","conv2d","float32");let u=o,h=!1;3===o.rank&&(h=!0,u=mi(o,[1,o.shape[0],o.shape[1],o.shape[2]])),ln(4===u.rank,(()=>`Error in conv2d: input must be rank 4, but got rank ${u.rank}.`)),ln(4===l.rank,(()=>`Error in conv2d: filter must be rank 4, but got rank ${l.rank}.`)),Di("conv2d",s,a);const c="NHWC"===i?u.shape[3]:u.shape[1];ln(c===l.shape[2],(()=>`Error in conv2d: depth of input (${c}) must match input depth for filter ${l.shape[2]}.`)),ln($i(n,r),(()=>`Error in conv2D: Either strides or dilations must be 1. Got strides ${n} and dilations '${r}'`)),ln(Fi(r),(()=>"Error in conv2D: Dilated rates should be larger than 0.")),ln(Fi(n),(()=>"Error in conv2D: Strides should be larger than 0."));const p={x:u,filter:l},d={strides:n,pad:s,dataFormat:i,dilations:r,dimRoundingMode:a},f=Gs.runKernel("Conv2D",p,d);return h?mi(f,[f.shape[1],f.shape[2],f.shape[3]]):f}}),rr={kernelName:"Conv2DBackpropInput",inputsToSave:["dy","filter"],gradFunc:(t,e,n)=>{const[s,i]=e,{strides:r,pad:a,dataFormat:o,dimRoundingMode:l}=n;return{dy:()=>ir(t,i,r,a,o,1,l),filter:()=>er(t,s,i.shape,r,a,o,l)}}};const ar=ti({conv3DBackpropFilter_:function(t,e,n,s,i){let r=t;4===t.rank&&(r=mi(t,[1,t.shape[0],t.shape[1],t.shape[2],t.shape[3]]));let a=e;4===a.rank&&(a=mi(e,[1,e.shape[0],e.shape[1],e.shape[2],e.shape[3]])),ln(5===r.rank,(()=>`Error in conv3dDerFilter: input must be rank 5, but got shape ${r.shape}.`)),ln(5===a.rank,(()=>`Error in conv3dDerFilter: dy must be rank 5, but got shape ${a.shape}.`)),ln(5===n.length,(()=>`Error in conv3dDerFilter: filterShape must be length 5, but got ${n}.`)),ln(r.shape[4]===n[3],(()=>`Error in conv3dDerFilter: depth of input ${r.shape[4]}) must match input depth in filter (${n[3]}.`)),ln(a.shape[4]===n[4],(()=>`Error in conv3dDerFilter: depth of dy (${a.shape[4]}) must match output depth for filter (${n[4]}).`));const o={x:r,dy:a},l={strides:s,pad:i,filterShape:n};return Gs.runKernel("Conv3DBackpropFilterV2",o,l)}});const or=ti({conv3DBackpropInput_:function(t,e,n,s,i){ln(t.length===e.rank,(()=>`Length of inShape (${t.length}) and rank of dy (${e.rank}) must match`));let r=t,a=e,o=!1;4===e.rank&&(o=!0,a=mi(e,[1,e.shape[0],e.shape[1],e.shape[2],e.shape[3]]),r=[1,t[0],t[1],t[2],t[3]]);const l=r[4],u=a.shape[4];ln(5===r.length,(()=>`Error in conv3dDerInput: inShape must be length 5, but got length ${r.length}.`)),ln(5===a.rank,(()=>`Error in conv3dDerInput: dy must be rank 5, but got rank ${a.rank}`)),ln(5===n.rank,(()=>`Error in conv3dDerInput: filter must be rank 5, but got rank ${n.rank}`)),ln(l===n.shape[3],(()=>`Error in conv3dDerInput: depth of input (${l}) must match input depth for filter ${n.shape[3]}.`)),ln(u===n.shape[4],(()=>`Error in conv3dDerInput: depth of output (${u}) must match output depth for filter ${n.shape[4]}.`));const h={dy:a,filter:n},c={pad:i,strides:s,inputShape:r},p=Gs.runKernel("Conv3DBackpropInputV2",h,c);return o?mi(p,[p.shape[1],p.shape[2],p.shape[3],p.shape[4]]):p}}),lr={kernelName:"Conv3D",inputsToSave:["x","filter"],gradFunc:(t,e,n)=>{const{dilations:s,strides:i,pad:r}=n;ln(zi(s),(()=>`Error in gradient of conv3D: dilation rates greater than 1 are not yet supported in gradients. Got dilations '${s}'`));const[a,o]=e;return{x:()=>or(a.shape,t,o,i,r),filter:()=>ar(a,t,o.shape,i,r)}}};const ur=ti({sin_:function(t){const e={x:Xs(t,"x","sin","float32")};return Gs.runKernel("Sin",e)}}),hr={kernelName:"Cos",inputsToSave:["x"],gradFunc:(t,e)=>{const[n]=e;return{x:()=>ni(oi(ur(ei(n,"float32"))),t)}}};const cr=ti({sinh_:function(t){const e={x:Xs(t,"x","sinh")};return Gs.runKernel("Sinh",e)}}),pr={kernelName:"Cosh",inputsToSave:["x"],gradFunc:(t,e)=>{const[n]=e;return{x:()=>ni(cr(ei(n,"float32")),t)}}};function dr(t,e){return function(t,e,n){const s=t.length+e.length,i=[];let r=0,a=0;for(let o=0;o1)),e)}function fr(t,e){if(function(t,e){for(let n=0;nn.push(t))),n}function gr(t){return t.map(((t,e)=>[e,t])).sort(((t,e)=>t[1]-e[1])).map((t=>t[0]))}const mr=ti({cumsum_:function(t,e=0,n=!1,s=!1){const i={x:Xs(t,"x","cumsum")},r={axis:e,exclusive:n,reverse:s};return Gs.runKernel("Cumsum",i,r)}});const yr=ti({complex_:function(t,e){const n=Xs(t,"real","complex"),s=Xs(e,"imag","complex");!function(t,e,n=""){ln(hn(t,e),(()=>n+` Shapes ${t} and ${e} must match`))}(n.shape,s.shape,`real and imag shapes, ${n.shape} and ${s.shape}, must match in call to tf.complex().`);const i={real:n,imag:s};return Gs.runKernel("Complex",i)}});const br=ti({imag_:function(t){const e={input:Xs(t,"input","imag")};return Gs.runKernel("Imag",e)}});const wr=ti({real_:function(t){const e={input:Xs(t,"input","real")};return Gs.runKernel("Real",e)}});const kr=ti({transpose_:function(t,e,n){const s=Xs(t,"x","transpose");if(null==e&&(e=s.shape.map(((t,e)=>e)).reverse()),ln(s.rank===e.length,(()=>`Error in transpose: rank of input ${s.rank} must match length of perm ${e}.`)),e.forEach((t=>{ln(t>=0&&t"All entries in 'perm' must be between 0 and "+(s.rank-1)+` but got ${e}`))})),s.rank<=1)return s.clone();const i={x:s},r={perm:e};return"complex64"===s.dtype?(a=()=>{let t=wr(s),e=br(s);return t=Gs.runKernel("Transpose",{x:t},r),e=Gs.runKernel("Transpose",{x:e},r),n&&(e=oi(e)),yr(t,e)},Gs.tidy(a,o)):Gs.runKernel("Transpose",i,r);var a,o}}),vr={kernelName:"Cumsum",inputsToSave:["x"],gradFunc:(t,e,n)=>{const[s]=e,{axis:i,exclusive:r,reverse:a}=n;return{x:()=>{const e=fr([i],s.rank);let n=mr(t,i,r,!a);return null!=e&&(n=kr(n,e)),n}}}};const Sr=ti({depthwiseConv2dNativeBackpropFilter_:function(t,e,n,s,i,r=[1,1],a){let o=t;3===t.rank&&(o=mi(t,[1,t.shape[0],t.shape[1],t.shape[2]]));let l=e;3===l.rank&&(l=mi(e,[1,e.shape[0],e.shape[1],e.shape[2]]));const u={x:o,dy:l},h={strides:s,pad:i,dimRoundingMode:a,dilations:r,filterShape:n};return Gs.runKernel("DepthwiseConv2dNativeBackpropFilter",u,h)}});const xr=ti({depthwiseConv2dNativeBackpropInput_:function(t,e,n,s,i,r=[1,1],a){let o=e,l=!1;3===e.rank&&(l=!0,o=mi(e,[1,e.shape[0],e.shape[1],e.shape[2]]));const u={dy:o,filter:n},h={strides:s,pad:i,dimRoundingMode:a,dilations:r,inputShape:t},c=Gs.runKernel("DepthwiseConv2dNativeBackpropInput",u,h);return l?mi(c,[c.shape[1],c.shape[2],c.shape[3]]):c}}),Nr={kernelName:"DepthwiseConv2dNative",inputsToSave:["x","filter"],gradFunc:(t,e,n)=>{const{dilations:s,strides:i,pad:r,dimRoundingMode:a}=n,o=null==s?[1,1]:s;ln(zi(o),(()=>`Error in gradient of depthwiseConv2dNative: dilation rates greater than 1 are not yet supported. Got dilations '${o}'`));const[l,u]=e;return ln(4===l.rank,(()=>`Error in gradient of depthwiseConv2dNative: input must be rank 4, but got rank ${l.rank}.`)),ln(4===u.rank,(()=>`Error in gradient of depthwiseConv2dNative: filter must be rank 4, but got rank ${u.rank}.`)),ln(l.shape[3]===u.shape[2],(()=>`Error in gradient of depthwiseConv2d: number of input channels (${l.shape[3]}) must match the inChannels dimension in filter ${u.shape[2]}.`)),ln($i(i,o),(()=>`Error in gradient of depthwiseConv2d: Either strides or dilations must be 1. Got strides ${i} and dilations '${o}'.`)),Di("depthwiseConv2d",r,a),{x:()=>xr(l.shape,t,u,i,r,o,a),filter:()=>Sr(l,t,u.shape,i,r,o,a)}}},Ir={kernelName:"Dilation2D",inputsToSave:["x","filter"],gradFunc:(t,e,n)=>{const[s,i]=e,r={x:s,filter:i,dy:t},a={x:s,filter:i,dy:t};return{x:()=>Gs.runKernel("Dilation2DBackpropInput",r,n),filter:()=>Gs.runKernel("Dilation2DBackpropFilter",a,n)}}},Ar={kernelName:"Elu",outputsToSave:[!0],gradFunc:(t,e)=>{const[n]=e,s={dy:t,y:n};return{x:()=>Gs.runKernel("EluGrad",s)}}};const Tr=ti({exp_:function(t){const e={x:Xs(t,"x","exp")};return Gs.runKernel("Exp",e)}}),Er={kernelName:"Erf",inputsToSave:["x"],gradFunc:(t,e)=>{const[n]=e,s=ni(Tr(oi(hi(n))),2/Math.sqrt(Math.PI));return{x:()=>ni(t,s)}}},Cr={kernelName:"Exp",outputsToSave:[!0],gradFunc:(t,e)=>{const[n]=e;return{x:()=>ni(t,n)}}},zr={kernelName:"ExpandDims",inputsToSave:["input"],gradFunc:(t,e)=>{const[n]=e;return{input:()=>mi(t,n.shape)}}},$r={kernelName:"Expm1",inputsToSave:["x"],gradFunc:(t,e)=>{const[n]=e;return{x:()=>ni(t,Tr(n))}}},Fr={kernelName:"Floor",gradFunc:t=>({x:()=>ki(t)})},Dr={kernelName:"FloorDiv",inputsToSave:["a","b"],gradFunc:(t,e)=>{const[n,s]=e,i=gi(n.shape,s.shape);return{a:()=>{const e=ai(t,ei(s,"float32")),r=fi(n.shape,i);return r.length>0?mi(yi(e,r),n.shape):e},b:()=>{let e=ni(t,ei(n,"float32"));const r=fi(s.shape,i);r.length>0&&(e=mi(yi(e,r),s.shape));const a=hi(s);return oi(ai(e,ei(a,"float32")))}}}};const Lr=ti({rsqrt_:function(t){const e={x:Xs(t,"x","rsqrt","float32")};return Gs.runKernel("Rsqrt",e)}});const _r=ti({tile_:function(t,e){const n=Xs(t,"x","tile","string_or_numeric");ln(n.rank===e.length,(()=>`Error in transpose: rank of input ${n.rank} must match length of reps ${e}.`));const s={x:n},i={reps:e};return Gs.runKernel("Tile",s,i)}}),Rr={kernelName:"FusedBatchNorm",inputsToSave:["x","mean","variance","scale"],gradFunc:(t,e,n)=>{const{varianceEpsilon:s}=n,[i,r,a,o]=e,l=null==o?li(1):o,u=fi(r.shape,i.shape),h=[];if(1===r.rank){for(let t=0;t1===r.rank?mi(ni(ni(t,_r(mi(d,[1,1,1,r.shape[0]]),h)),l),i.shape):mi(ni(ni(t,d),l),i.shape),mean:()=>{let t=ni(ni(d,li(-1)),p);return 1===r.rank&&(t=yi(t,u)),mi(t,r.shape)},variance:()=>{let t=ni(ni(f,c),p);return 1===r.rank&&(t=yi(t,u)),mi(t,r.shape)},scale:()=>{const e=ni(c,d);let n=ni(t,e);return 1===r.rank&&(n=yi(n,u)),mi(n,r.shape)},offset:()=>{let e=t;return 1===r.rank&&(e=yi(e,u)),mi(e,r.shape)}}}};const Mr=ti({stack_:function(t,e=0){const n=Qs(t,"tensors","stack","string_or_numeric");ln(n.length>=1,(()=>"Pass at least one tensor to tf.stack")),n.length>0&&ln(e<=n[0].rank,(()=>"Axis must be <= rank of the tensor"));const s=n,i={axis:e};return Gs.runKernel("Pack",s,i)}});const Or=ti({unsortedSegmentSum_:function(t,e,n){const s=Xs(t,"x","unsortedSegmentSum"),i=Xs(e,"segmentIds","unsortedSegmentSum","int32");ln(cn(n),(()=>"numSegments must be of dtype int"));const r={x:s,segmentIds:i},a={numSegments:n};return Gs.runKernel("UnsortedSegmentSum",r,a)}}),Br={kernelName:"GatherV2",inputsToSave:["x","indices"],gradFunc:(t,e,n)=>{const[s,i]=e,{axis:r,batchDims:a}=n,o=dn(r,s.shape)[0],l=(t,e,n)=>()=>{const s=t.shape,i=e.size,a=s.slice(0,o),l=a.length,u=s.slice(r,s.length).slice(1),h=u.length,c=Pr(0,l),p=Pr(l+1,l+1+h),d=Ur([a,[i],u]),f=mi(n,d),g=mi(e,[i]),m=Ur([[l],c,p]),y=kr(f,m);let b=Or(y,g,t.shape[o]);const w=gr(m);return b=kr(b,w),b};if(1===a){const e=s.shape[0],n=s.split(e,0);return{x:()=>{const e=Mr(n.map(((e,n)=>l(e,i.slice(n,1),t.slice(n,1))())));return e.reshape(s.shape)},indices:()=>i}}return{x:l(s,i,t),indices:()=>i}}};function Pr(t,e){const n=[];for(let s=t;s{const[n,s]=e;return{a:()=>ki(n),b:()=>ki(s)}}},jr={kernelName:"Identity",gradFunc:t=>({x:()=>ei(t,"float32")})},Vr={kernelName:"IsFinite",gradFunc:t=>({x:()=>ki(t)})},qr={kernelName:"IsInf",gradFunc:t=>({x:()=>ki(t)})},Kr={kernelName:"IsNan",gradFunc:t=>({x:()=>ki(t)})};const Gr=ti({greater_:function(t,e){let n=Xs(t,"a","greater","string_or_numeric"),s=Xs(e,"b","greater","string_or_numeric");[n,s]=Us(n,s),gi(n.shape,s.shape);const i={a:n,b:s};return Gs.runKernel("Greater",i)}}),Hr={kernelName:"LeakyRelu",inputsToSave:["x"],gradFunc:(t,e,n)=>{const[s]=e,{alpha:i}=n,r=Gr(s,0);return{x:()=>Zi(r,t,ni(t,i))}}},Jr={kernelName:"Log1p",inputsToSave:["x"],gradFunc:(t,e)=>{const[n]=e;return{x:()=>ai(t,Ni(n,1))}}},Zr={kernelName:"Log",inputsToSave:["x"],gradFunc:(t,e)=>{const[n]=e;return{x:()=>ai(t,ei(n,"float32"))}}},Yr={kernelName:"LogSoftmax",inputsToSave:[],outputsToSave:[!0],gradFunc:(t,e,n)=>{const[s]=e,{axis:i}=n;return{logits:()=>{const e=Tr(s);return ci(t,ni(yi(t,i,!0),e))}}}};const Xr=ti({localResponseNormalizationBackprop_:function(t,e,n,s=5,i=1,r=1,a=.5){const o={x:t,y:e,dy:n},l={depthRadius:s,bias:i,alpha:r,beta:a};return Gs.runKernel("LRNGrad",o,l)}}),Qr={kernelName:"LRN",inputsToSave:["x"],outputsToSave:[!0],gradFunc:(t,e,n)=>{const[s,i]=e,{depthRadius:r,bias:a,alpha:o,beta:l}=n;return{x:()=>Xr(s,i,t,r,a,o,l)}}};const ta=ti({equal_:function(t,e){let n=Xs(t,"a","equal","string_or_numeric"),s=Xs(e,"b","equal","string_or_numeric");[n,s]=Us(n,s),gi(n.shape,s.shape);const i={a:n,b:s};return Gs.runKernel("Equal",i)}});function ea(t,e,n,s){return e.rankni(t,ei(ta(n,e),t.dtype))}}const na={kernelName:"Max",inputsToSave:["x"],outputsToSave:[!0],gradFunc:(t,e,n)=>{const s=n,{reductionIndices:i}=s,r=e[0],a=ea(t,e[1],r,dn(i,r.shape));return{x:()=>a.x()}}};const sa=ti({less_:function(t,e){let n=Xs(t,"a","less","string_or_numeric"),s=Xs(e,"b","less","string_or_numeric");[n,s]=Us(n,s),gi(n.shape,s.shape);const i={a:n,b:s};return Gs.runKernel("Less",i)}}),ia={kernelName:"Maximum",inputsToSave:["a","b"],gradFunc:(t,e)=>{const[n,s]=e;return{a:()=>ni(t,ei(qi(n,s),"float32")),b:()=>ni(t,ei(sa(n,s),"float32"))}}};const ra=ti({maxPool3dGrad_:function(t,e,n,s,i,r,a){const o=Xs(t,"dy","maxPool3dGrad"),l=Xs(e,"input","maxPool3dGrad"),u=Xs(n,"output","maxPool3dGrad");let h=o,c=l,p=u,d=!1;4===l.rank&&(d=!0,h=mi(o,[1,o.shape[0],o.shape[1],o.shape[2],o.shape[3]]),c=mi(l,[1,l.shape[0],l.shape[1],l.shape[2],l.shape[3]]),p=mi(u,[1,u.shape[0],u.shape[1],u.shape[2],u.shape[3]])),ln(5===h.rank,(()=>`Error in maxPool3dGrad: dy must be rank 5 but got rank ${h.rank}.`)),ln(5===c.rank,(()=>`Error in maxPool3dGrad: input must be rank 5 but got rank ${c.rank}.`)),ln(5===p.rank,(()=>`Error in maxPool3dGrad: output must be rank 5 but got rank ${p.rank}.`)),Di("maxPool3dGrad",r,a);const f={dy:h,input:c,output:p},g={filterSize:s,strides:i,pad:r,dimRoundingMode:a},m=Gs.runKernel("MaxPool3DGrad",f,g);return d?mi(m,[m.shape[1],m.shape[2],m.shape[3],m.shape[4]]):m}}),aa={kernelName:"MaxPool3D",inputsToSave:["x"],outputsToSave:[!0],gradFunc:(t,e,n)=>{const[s,i]=e,{filterSize:r,strides:a,pad:o,dimRoundingMode:l}=n;return{x:()=>ra(t,s,i,r,a,o,l)}}};const oa=ti({maxPoolGrad_:function(t,e,n,s,i,r,a){const o=Xs(t,"dy","maxPoolGrad"),l=Xs(e,"input","maxPoolGrad"),u=Xs(n,"output","maxPoolGrad");ln(l.rank===o.rank,(()=>`Rank of input (${l.rank}) does not match rank of dy (${o.rank})`)),ln(4===o.rank,(()=>`Error in maxPoolGrad: dy must be rank 4 but got rank ${o.rank}.`)),ln(4===l.rank,(()=>`Error in maxPoolGrad: input must be rank 4 but got rank ${l.rank}.`)),Di("maxPoolGrad",r,a);const h={dy:o,input:l,output:u},c={filterSize:s,strides:i,pad:r,dimRoundingMode:a};return Gs.runKernel("MaxPoolGrad",h,c)}}),la={kernelName:"MaxPool",inputsToSave:["x"],outputsToSave:[!0],gradFunc:(t,e,n)=>{const[s,i]=e,{filterSize:r,strides:a,pad:o}=n;return{x:()=>oa(t,s,i,r,a,o)}}};function ua(t,e="float32"){if(xn(t),"complex64"===e){const e=ua(t,"float32"),n=ua(t,"float32");return yr(e,n)}const n=Sn(un(t),e);return Gs.makeTensor(n,t,e)}function ha(t,e="float32"){if(xn(t),"complex64"===e){const e=ha(t,"float32"),n=ua(t,"float32");return yr(e,n)}const n=vn(un(t),e);return Gs.makeTensor(n,t,e)}const ca={kernelName:"Mean",inputsToSave:["x"],gradFunc:(t,e,n)=>{const[s]=e,{axis:i}=n,r=dn(i,s.shape),a=function(t,e){const n=[],s=t.length;for(let i=0;it[e]))]}(s.shape,r),o=un(a[1]);return{x:()=>{const e=s.shape.slice();r.forEach((t=>{e[t]=1}));const n=mi(t,e);return ai(ni(n,ha(s.shape,"float32")),o)}}}},pa={kernelName:"Min",inputsToSave:["x"],outputsToSave:[!0],gradFunc:(t,e,n)=>{const s=n,{axis:i}=s,[r,a]=e,o=ea(t,a,r,dn(i,r.shape));return{x:()=>o.x()}}},da={kernelName:"Minimum",inputsToSave:["a","b"],gradFunc:(t,e)=>{const[n,s]=e;return{a:()=>ni(t,ei(Ki(n,s),"float32")),b:()=>ni(t,ei(Gr(n,s),"float32"))}}};const fa=ti({slice_:function(t,e,n){const s=Xs(t,"x","slice","string_or_numeric");if(0===s.rank)throw new Error("Slicing scalar is not possible");const i={x:s},r={begin:e,size:n};return Gs.runKernel("Slice",i,r)}}),ga={kernelName:"MirrorPad",inputsToSave:["x"],gradFunc:(t,e,n)=>{const s=e[0],{paddings:i}=n,r=i.map((t=>t[0]));return{x:()=>fa(t,r,s.shape)}}};const ma=ti({floor_:function(t){const e={x:Xs(t,"x","floor","float32")};return Gs.runKernel("Floor",e)}}),ya={kernelName:"Mod",inputsToSave:["a","b"],gradFunc:(t,e)=>{const[n,s]=e,i=gi(n.shape,s.shape);return{a:()=>{const e=fi(n.shape,i);return e.length>0?mi(yi(t,e),n.shape):t},b:()=>{const e=ni(t,oi(ma(ai(n,s)))),r=fi(s.shape,i);return r.length>0?mi(yi(e,r),s.shape):e}}}},ba={kernelName:"Multiply",inputsToSave:["a","b"],gradFunc:(t,e)=>{const[n,s]=e,i=gi(n.shape,s.shape);return{a:()=>{const e=ni(t,ei(s,"float32")),r=fi(n.shape,i);return r.length>0?mi(yi(e,r),n.shape):e},b:()=>{const e=ni(t,ei(n,"float32")),r=fi(s.shape,i);return r.length>0?mi(yi(e,r),s.shape):e}}}},wa={kernelName:"Neg",gradFunc:t=>({x:()=>oi(t)})},ka={kernelName:"OneHot",inputsToSave:["indices"],gradFunc:(t,e)=>{const n=e[0];return{indices:()=>ua(n.shape,"float32")}}},va={kernelName:"OnesLike",gradFunc:t=>({x:()=>ki(t)})};const Sa=ti({unstack_:function(t,e=0){const n=Xs(t,"x","unstack","string_or_numeric");ln(e>=-n.shape.length&&e`Axis = ${e} is not in [-${n.shape.length}, ${n.shape.length})`));const s={value:n},i={axis:e};return Gs.runKernel("Unpack",s,i)}}),xa={kernelName:"Pack",saveAllInputs:!0,gradFunc:(t,e,n)=>{const{axis:s}=n;return Sa(t,s).map((t=>()=>t))}},Na={kernelName:"PadV2",inputsToSave:["x"],gradFunc:(t,e,n)=>{const s=e[0],{paddings:i}=n,r=i.map((t=>t[0]));return{x:()=>fa(t,r,s.shape)}}};const Ia=ti({log_:function(t){const e={x:Xs(t,"x","log","float32")};return Gs.runKernel("Log",e)}});const Aa=ti({pow_:function(t,e){let n=Xs(t,"base","pow"),s=Xs(e,"exp","pow");[n,s]=Us(n,s);const i={a:n,b:s};return Gs.runKernel("Pow",i)}}),Ta={kernelName:"Pow",inputsToSave:["a","b"],outputsToSave:[!0],gradFunc:(t,e)=>{const[n,s,i]=e,r=n,a=s,o=gi(r.shape,a.shape);return{a:()=>{const e=ei(a,"float32");let n=ni(t,ni(e,Aa(r,ci(e,li(1)))));const s=fi(r.shape,o);return s.length>0&&(n=yi(n,s)),mi(n,r.shape)},b:()=>{const e=Gr(r,0),n=Zi(e,Ia(r),ki(r));let s=ni(t,ni(i,n));const l=fi(a.shape,o);return l.length>0&&(s=yi(s,l)),mi(s,a.shape)}}}},Ea={kernelName:"Prelu",inputsToSave:["x","alpha"],gradFunc:(t,e)=>{const[n,s]=e,i=Gr(n,0);return{x:()=>Zi(i,t,ni(t,s)),alpha:()=>{let e=Zi(i,ki(t),ni(t,n));const r=fi(s.shape,t.shape);return r.length>0&&(e=yi(e,r)),mi(e,s.shape)}}}};const Ca=Tn();Ca.registerFlag("DEBUG",(()=>!1),(t=>{t&&console.warn("Debugging mode is ON. The output of every math call will be downloaded to CPU and checked for NaNs. This significantly impacts performance.")})),Ca.registerFlag("IS_BROWSER",(()=>"undefined"!=typeof window&&null!=window.document||"undefined"!=typeof WorkerGlobalScope)),Ca.registerFlag("IS_NODE",(()=>"undefined"!=typeof process&&"undefined"!=typeof process.versions&&"undefined"!=typeof process.versions.node)),Ca.registerFlag("IS_CHROME",(()=>"undefined"!=typeof navigator&&null!=navigator&&null!=navigator.userAgent&&/Chrome/.test(navigator.userAgent)&&/Google Inc/.test(navigator.vendor))),Ca.registerFlag("IS_SAFARI",(()=>"undefined"!=typeof navigator&&null!=navigator&&null!=navigator.userAgent&&/Safari/.test(navigator.userAgent)&&/Apple/.test(navigator.vendor))),Ca.registerFlag("PROD",(()=>!1)),Ca.registerFlag("TENSORLIKE_CHECK_SHAPE_CONSISTENCY",(()=>Ca.getBool("DEBUG"))),Ca.registerFlag("DEPRECATION_WARNINGS_ENABLED",(()=>!0)),Ca.registerFlag("IS_TEST",(()=>!1)),Ca.registerFlag("CHECK_COMPUTATION_FOR_ERRORS",(()=>Ca.getBool("DEBUG"))),Ca.registerFlag("WRAP_TO_IMAGEBITMAP",(()=>!1)),Ca.registerFlag("CANVAS2D_WILL_READ_FREQUENTLY_FOR_GPU",(()=>!1)),Ca.registerFlag("USE_SETTIMEOUTCUSTOM",(()=>!1));class za{static join(t){return new za(t).slice()}constructor(t){if(this.shards=[],this.previousShardIndex=0,null==t)return;if(t instanceof Array||(t=[t]),0===(t=t.map((t=>ws(t)?t.buffer:t))).length)return;this.bufferUniformSize=t[0].byteLength;let e=0;for(let n=0;n=this.byteLength)return-1;if(null!=this.bufferUniformSize)return this.previousShardIndex=Math.floor(t/this.bufferUniformSize),this.previousShardIndex;function e(e){return t=e.end?1:0}if(0===e(this.shards[this.previousShardIndex]))return this.previousShardIndex;const n=function(t,e){let n=0,s=t.length;for(;n<=s;){const i=Math.floor((s-n)/2)+n,r=e(t[i]);if(0===r)return i;r<0?s=i:n=i+1}return-1}(this.shards,e);return-1===n?-1:(this.previousShardIndex=n,this.previousShardIndex)}}const $a="undefined"!=typeof Buffer&&("undefined"==typeof Blob||"undefined"==typeof atob||"undefined"==typeof btoa);function Fa(t){return $a?Buffer.byteLength(t,"utf8"):new Blob([t]).size}function Da(t,e){const n={modelTopology:t.modelTopology,format:t.format,generatedBy:t.generatedBy,convertedBy:t.convertedBy,weightsManifest:e};return null!=t.signature&&(n.signature=t.signature),null!=t.userDefinedMetadata&&(n.userDefinedMetadata=t.userDefinedMetadata),null!=t.modelInitializer&&(n.modelInitializer=t.modelInitializer),null!=t.initializerSignature&&(n.initializerSignature=t.initializerSignature),null!=t.trainingConfig&&(n.trainingConfig=t.trainingConfig),n}async function La(t,e){let n,s;return null!=t.weightsManifest&&([n,s]=await e(t.weightsManifest)),function(t,e,n){const s={modelTopology:t.modelTopology,format:t.format,generatedBy:t.generatedBy,convertedBy:t.convertedBy};if(null!=t.trainingConfig&&(s.trainingConfig=t.trainingConfig),null!=t.weightsManifest){if(!e)throw new Error("modelJSON has weightsManifest but weightSpecs is null");if(!n)throw new Error("modelJSON has weightsManifest but weightData is null");s.weightSpecs=e,s.weightData=n}return null!=t.signature&&(s.signature=t.signature),null!=t.userDefinedMetadata&&(s.userDefinedMetadata=t.userDefinedMetadata),null!=t.modelInitializer&&(s.modelInitializer=t.modelInitializer),null!=t.initializerSignature&&(s.initializerSignature=t.initializerSignature),s}(t,n,s)}function _a(t){if(t.modelTopology instanceof ArrayBuffer)throw new Error("Expected JSON model topology, received ArrayBuffer.");return{dateSaved:new Date,modelTopologyType:"JSON",modelTopologyBytes:null==t.modelTopology?0:Fa(JSON.stringify(t.modelTopology)),weightSpecsBytes:null==t.weightSpecs?0:Fa(JSON.stringify(t.weightSpecs)),weightDataBytes:null==t.weightData?0:new za(t.weightData).byteLength}}function Ra(t){const e=[];for(const n of t)e.push(...n.weights);return e}class Ma{constructor(){this.saveRouters=[],this.loadRouters=[]}static getInstance(){return null==Ma.instance&&(Ma.instance=new Ma),Ma.instance}static registerSaveRouter(t){Ma.getInstance().saveRouters.push(t)}static registerLoadRouter(t){Ma.getInstance().loadRouters.push(t)}static getSaveHandlers(t){return Ma.getHandlers(t,"save")}static getLoadHandlers(t,e){return Ma.getHandlers(t,"load",e)}static getHandlers(t,e,n){const s=[];return("load"===e?Ma.getInstance().loadRouters:Ma.getInstance().saveRouters).forEach((e=>{const i=e(t,n);null!==i&&s.push(i)})),s}}class Oa{constructor(t){if(this.indexedDB=function(){if(!Tn().getBool("IS_BROWSER"))throw new Error("Failed to obtain IndexedDB factory because the current environmentis not a web browser.");const t="undefined"==typeof window?self:window,e=t.indexedDB||t.mozIndexedDB||t.webkitIndexedDB||t.msIndexedDB||t.shimIndexedDB;if(null==e)throw new Error("The current browser does not appear to support IndexedDB.");return e}(),null==t||!t)throw new Error("For IndexedDB, modelPath must not be null, undefined or empty.");this.modelPath=t}async save(t){if(t.modelTopology instanceof ArrayBuffer)throw new Error("BrowserLocalStorage.save() does not support saving model topology in binary formats yet.");return this.databaseAction(this.modelPath,t)}async load(){return this.databaseAction(this.modelPath)}databaseAction(t,e){return new Promise(((t,n)=>{const s=this.indexedDB.open("tensorflowjs",1);s.onupgradeneeded=()=>function(t){const e=t.result;e.createObjectStore("models_store",{keyPath:"modelPath"}),e.createObjectStore("model_info_store",{keyPath:"modelPath"})}(s),s.onsuccess=()=>{const i=s.result;if(null==e){const e=i.transaction("models_store","readonly"),s=e.objectStore("models_store").get(this.modelPath);s.onsuccess=()=>{if(null==s.result)return i.close(),n(new Error(`Cannot find model with path '${this.modelPath}' in IndexedDB.`));t(s.result.modelArtifacts)},s.onerror=t=>(i.close(),n(s.error)),e.oncomplete=()=>i.close()}else{e.weightData=za.join(e.weightData);const s=_a(e),r=i.transaction("model_info_store","readwrite");let a,o,l=r.objectStore("model_info_store");try{a=l.put({modelPath:this.modelPath,modelArtifactsInfo:s})}catch(t){return n(t)}a.onsuccess=()=>{o=i.transaction("models_store","readwrite");const a=o.objectStore("models_store");let u;try{u=a.put({modelPath:this.modelPath,modelArtifacts:e,modelArtifactsInfo:s})}catch(t){return n(t)}u.onsuccess=()=>t({modelArtifactsInfo:s}),u.onerror=t=>{l=r.objectStore("model_info_store");const e=l.delete(this.modelPath);e.onsuccess=()=>(i.close(),n(u.error)),e.onerror=t=>(i.close(),n(u.error))}},a.onerror=t=>(i.close(),n(a.error)),r.oncomplete=()=>{null==o?i.close():o.oncomplete=()=>i.close()}}},s.onerror=t=>n(s.error)}))}}Oa.URL_SCHEME="indexeddb://";const Ba=t=>{return Tn().getBool("IS_BROWSER")&&!Array.isArray(t)&&t.startsWith(Oa.URL_SCHEME)?(e=t.slice(Oa.URL_SCHEME.length),new Oa(e)):null;var e};Ma.registerSaveRouter(Ba),Ma.registerLoadRouter(Ba);const Pa="tensorflowjs_models",Ua="info",Wa="model_topology",ja="weight_specs",Va="weight_data",qa="model_metadata";class Ka{constructor(t){if(!Tn().getBool("IS_BROWSER")||"undefined"==typeof window||"undefined"==typeof window.localStorage)throw new Error("The current environment does not support local storage.");if(this.LS=window.localStorage,null==t||!t)throw new Error("For local storage, modelPath must not be null, undefined or empty.");var e;this.modelPath=t,this.keys=(e=this.modelPath,{info:[Pa,e,Ua].join("/"),topology:[Pa,e,Wa].join("/"),weightSpecs:[Pa,e,ja].join("/"),weightData:[Pa,e,Va].join("/"),modelMetadata:[Pa,e,qa].join("/")})}async save(t){if(t.modelTopology instanceof ArrayBuffer)throw new Error("BrowserLocalStorage.save() does not support saving model topology in binary formats yet.");{const e=JSON.stringify(t.modelTopology),n=JSON.stringify(t.weightSpecs),s=_a(t),i=za.join(t.weightData);try{this.LS.setItem(this.keys.info,JSON.stringify(s)),this.LS.setItem(this.keys.topology,e),this.LS.setItem(this.keys.weightSpecs,n),this.LS.setItem(this.keys.weightData,function(t){if($a)return Buffer.from(t).toString("base64");const e=new Uint8Array(t);let n="";for(let t=0,s=e.length;t{return Tn().getBool("IS_BROWSER")&&!Array.isArray(t)&&t.startsWith(Ka.URL_SCHEME)?(e=t.slice(Ka.URL_SCHEME.length),new Ka(e)):null;var e};Ma.registerSaveRouter(Ga),Ma.registerLoadRouter(Ga);function Ha(t){return new Promise((t=>setTimeout(t))).then(t)}class Ja{constructor(t){if(!Tn().getBool("IS_BROWSER"))throw new Error("browserDownloads() cannot proceed because the current environment is not a browser.");t.startsWith(Ja.URL_SCHEME)&&(t=t.slice(Ja.URL_SCHEME.length)),null!=t&&0!==t.length||(t="model"),this.modelJsonFileName=t+".json",this.weightDataFileName=t+".weights.bin"}async save(t){if("undefined"==typeof document)throw new Error("Browser downloads are not supported in this environment since `document` is not present");const e=za.join(t.weightData),n=window.URL.createObjectURL(new Blob([e],{type:"application/octet-stream"}));if(t.modelTopology instanceof ArrayBuffer)throw new Error("BrowserDownloads.save() does not support saving model topology in binary formats yet.");{const e=Da(t,[{paths:["./"+this.weightDataFileName],weights:t.weightSpecs}]),s=window.URL.createObjectURL(new Blob([JSON.stringify(e)],{type:"application/json"})),i=null==this.modelJsonAnchor?document.createElement("a"):this.modelJsonAnchor;if(i.download=this.modelJsonFileName,i.href=s,await Ha((()=>i.dispatchEvent(new MouseEvent("click")))),null!=t.weightData){const t=null==this.weightDataAnchor?document.createElement("a"):this.weightDataAnchor;t.download=this.weightDataFileName,t.href=n,await Ha((()=>t.dispatchEvent(new MouseEvent("click"))))}return{modelArtifactsInfo:_a(t)}}}}Ja.URL_SCHEME="downloads://";function Za(t,e,n,s){!function(t){ln(null!=t&&Array.isArray(t)&&t.length>0,(()=>"promises must be a none empty array"))}(t),function(t,e){ln(t>=0&&t<=1,(()=>`Progress fraction must be in range [0, 1], but got startFraction ${t}`)),ln(e>=0&&e<=1,(()=>`Progress fraction must be in range [0, 1], but got endFraction ${e}`)),ln(e>=t,(()=>`startFraction must be no more than endFraction, but got startFraction ${t} and endFraction ${e}`))}(n=null==n?0:n,s=null==s?1:s);let i=0;return Promise.all(t.map((r=>(r.then((r=>{const a=n+ ++i/t.length*(s-n);return e(a),r})),r))))}Ma.registerSaveRouter((t=>Tn().getBool("IS_BROWSER")&&!Array.isArray(t)&&t.startsWith(Ja.URL_SCHEME)?function(t="model"){return new Ja(t)}(t.slice(Ja.URL_SCHEME.length)):null));class Ya{constructor(t,e){if(this.DEFAULT_METHOD="POST",null==e&&(e={}),this.weightPathPrefix=e.weightPathPrefix,this.weightUrlConverter=e.weightUrlConverter,null!=e.fetchFunc?(ln("function"==typeof e.fetchFunc,(()=>"Must pass a function that matches the signature of `fetch` (see https://developer.mozilla.org/en-US/docs/Web/API/Fetch_API)")),this.fetch=e.fetchFunc):this.fetch=Tn().platform.fetch,ln(null!=t&&t.length>0,(()=>"URL path for http must not be null, undefined or empty.")),Array.isArray(t)&&ln(2===t.length,(()=>`URL paths for http must have a length of 2, (actual length is ${t.length}).`)),this.path=t,null!=e.requestInit&&null!=e.requestInit.body)throw new Error("requestInit is expected to have no pre-existing body, but has one.");this.requestInit=e.requestInit||{},this.loadOptions=e}async save(t){if(t.modelTopology instanceof ArrayBuffer)throw new Error("BrowserHTTPRequest.save() does not support saving model topology in binary formats yet.");const e=Object.assign({method:this.DEFAULT_METHOD},this.requestInit);e.body=new FormData;const n=Da(t,[{paths:["./model.weights.bin"],weights:t.weightSpecs}]);if(e.body.append("model.json",new Blob([JSON.stringify(n)],{type:"application/json"}),"model.json"),null!=t.weightData){const n=za.join(t.weightData);e.body.append("model.weights.bin",new Blob([n],{type:"application/octet-stream"}),"model.weights.bin")}const s=await this.fetch(this.path,e);if(s.ok)return{modelArtifactsInfo:_a(t),responses:[s]};throw new Error(`BrowserHTTPRequest.save() failed due to HTTP response status ${s.status}.`)}async loadModelJSON(){const t=await this.fetch(this.path,this.requestInit);if(!t.ok)throw new Error(`Request to ${this.path} failed with status code ${t.status}. Please verify this URL points to the model JSON of the model to load.`);let e;try{e=await t.json()}catch(t){let e=`Failed to parse model JSON of response from ${this.path}.`;throw this.path.endsWith(".pb")?e+=" Your path contains a .pb file extension. Support for .pb models have been removed in TensorFlow.js 1.0 in favor of .json models. You can re-convert your Python TensorFlow model using the TensorFlow.js 1.0 conversion scripts or you can convert your.pb models with the 'pb2json'NPM script in the tensorflow/tfjs-converter repository.":e+=" Please make sure the server is serving valid JSON for this request.",new Error(e)}const n=e.modelTopology,s=e.weightsManifest;if(null==n&&null==s)throw new Error(`The JSON from HTTP path ${this.path} contains neither model topology or manifest for weights.`);return e}async load(){if(this.loadOptions.streamWeights)return this.loadStream();return La(await this.loadModelJSON(),(t=>this.loadWeights(t)))}async loadStream(){const t=await this.loadModelJSON(),e=await this.getWeightUrls(t.weightsManifest),n=Ra(t.weightsManifest);return Object.assign(Object.assign({},t),{weightSpecs:n,getWeightStream:()=>function(t,e){var n;const s=null==e.fetchFunc?Tn().platform.fetch:e.fetchFunc;let i,r=0;return null===(n=e.onProgress)||void 0===n||n.call(e,0),new ReadableStream({pull:async n=>{for(var a;re?t.substring(n):"";return[s+"/",i]}(e),i=this.weightPathPrefix||n,r=[],a=[];for(const e of t)for(const t of e.paths)null!=this.weightUrlConverter?a.push(this.weightUrlConverter(t)):r.push(i+t+s);return this.weightUrlConverter&&r.push(...await Promise.all(a)),r}async loadWeights(t){const e=await this.getWeightUrls(t),n=Ra(t),s=await async function(t,e){null==e&&(e={});const n=null==e.fetchFunc?Tn().platform.fetch:e.fetchFunc,s=t.map((t=>n(t,e.requestInit,{isBinary:!0}))),i=(null==e.onProgress?await Promise.all(s):await Za(s,e.onProgress,0,.5)).map((t=>t.arrayBuffer()));return null==e.onProgress?await Promise.all(i):await Za(i,e.onProgress,.5,1)}(e,this.loadOptions);return[n,s]}}function Xa(t){return null!=t.match(Ya.URL_SCHEME_REGEX)}Ya.URL_SCHEME_REGEX=/^https?:\/\//;const Qa=(t,e)=>{if("undefined"==typeof fetch&&(null==e||null==e.fetchFunc))return null;{let n=!0;if(n=Array.isArray(t)?t.every((t=>Xa(t))):Xa(t),n)return function(t,e){return new Ya(t,e)}(t,e)}return null};function to(t,e,n){if(e.rank<1)throw new Error(`tf.scatterND() expects the indices to be rank 1 or higher, but the rank was ${e.rank}.`);if(t.rank<1)throw new Error(`tf.scatterND() expects the updates to be rank 1 or higher, but the rank was ${t.rank}.`);if("int32"!==e.dtype)throw new Error(`The dtype of 'indices' should be int32, but got dtype: ${e.dtype}`);if(n.length<1)throw new Error(`Output rank must be greater or equal to 1, but got shape: ${n}`);if(0===n.length){if(0===e.size)throw new Error(`Indices specified for empty output. indices shape: ${e.shape}`);if(0===t.size)throw new Error(`Updates specified for empty output. updates shape: ${t.shape}`)}!function(t,e,n){const s=e.rank>1?e.shape[e.rank-1]:1,i=e.rank>1?e.rank-1:1,r=`Must have updates.shape = indices.shape[:batchDim] + shape[sliceDim:], got updates.shape: ${n.shape}, indices.shape: ${e.shape}, shape: ${t}, sliceDim: ${s}, and batchDim: ${i}.`;if(n.rank{if("complex64"!==t.dtype)throw new Error(`Cannot concatenate complex64 tensors with a tensor\n with dtype ${t.dtype}. `)})),1===n.length)return Hi(n[0]);const s=n,i={axis:e};return Gs.runKernel("Concat",s,i)}});const no=ti({sigmoid_:function(t){const e={x:Xs(t,"x","sigmoid","float32")};return Gs.runKernel("Sigmoid",e)}});const so=ti({batchToSpaceND_:function(t,e,n){const s=Xs(t,"x","batchToSpaceND"),i=e.reduce(((t,e)=>t*e));ln(s.rank>=1+e.length,(()=>`input rank is ${s.rank} but should be > than blockShape.length ${e.length}`)),ln(n.length===e.length,(()=>`crops.length is ${n.length} but should be equal to blockShape.length ${e.length}`)),ln(s.shape[0]%i==0,(()=>`input tensor batch is ${s.shape[0]} but is not divisible by the product of the elements of blockShape ${e.join(" * ")} === ${i}`));const r={x:s},a={blockShape:e,crops:n};return Gs.runKernel("BatchToSpaceND",r,a)}});const io=ti({cos_:function(t){const e={x:Xs(t,"x","cos","float32")};return Gs.runKernel("Cos",e)}});const ro=ti({cosh_:function(t){const e={x:Xs(t,"x","cosh","float32")};return Gs.runKernel("Cosh",e)}});const ao=ti({cumprod_:function(t,e=0,n=!1,s=!1){const i={x:Xs(t,"x","cumprod")},r={axis:e,exclusive:n,reverse:s};return Gs.runKernel("Cumprod",i,r)}});const oo=ti({expandDims_:function(t,e=0){const n=Xs(t,"x","expandDims","string_or_numeric");ln(e<=n.rank,(()=>"Axis must be <= rank of the tensor"));const s={input:n},i={dim:e};return Gs.runKernel("ExpandDims",s,i)}});const lo=ti({gather_:function(t,e,n=0,s=0){const i={x:Xs(t,"x","gather"),indices:Xs(e,"indices","gather","int32")},r={axis:n,batchDims:s};return Gs.runKernel("GatherV2",i,r)}});const uo=ti({logicalNot_:function(t){const e={x:Xs(t,"x","logicalNot","bool")};return Gs.runKernel("LogicalNot",e)}});const ho=ti({maximum_:function(t,e){let n=Xs(t,"a","maximum"),s=Xs(e,"b","maximum");[n,s]=Us(n,s),"bool"===n.dtype&&(n=ei(n,"int32"),s=ei(s,"int32")),gi(n.shape,s.shape);const i={a:n,b:s};return Gs.runKernel("Maximum",i)}});const co=ti({pad_:function(t,e,n=0){const s=Xs(t,"x","pad");if(0===s.rank)throw new Error("pad(scalar) is not defined. Pass non-scalar to pad");const i={paddings:e,constantValue:n},r={x:s};return Gs.runKernel("PadV2",r,i)}});var po={exports:{}};!function(t,e,n){function s(t){var e,n=this,s=(e=4022871197,function(t){t=String(t);for(var n=0;n>>0,e=(s*=e)>>>0,e+=4294967296*(s-=e)}return 2.3283064365386963e-10*(e>>>0)});n.next=function(){var t=2091639*n.s0+2.3283064365386963e-10*n.c;return n.s0=n.s1,n.s1=n.s2,n.s2=t-(n.c=0|t)},n.c=1,n.s0=s(" "),n.s1=s(" "),n.s2=s(" "),n.s0-=s(t),n.s0<0&&(n.s0+=1),n.s1-=s(t),n.s1<0&&(n.s1+=1),n.s2-=s(t),n.s2<0&&(n.s2+=1),s=null}function i(t,e){return e.c=t.c,e.s0=t.s0,e.s1=t.s1,e.s2=t.s2,e}function r(t,e){var n=new s(t),r=e&&e.state,a=n.next;return a.int32=function(){return 4294967296*n.next()|0},a.double=function(){return a()+11102230246251565e-32*(2097152*a()|0)},a.quick=a,r&&("object"==typeof r&&i(r,n),a.state=function(){return i(n,{})}),a}e&&e.exports?e.exports=r:n&&n.amd?n((function(){return r})):this.alea=r}(0,po,!1);var fo=po.exports,go={exports:{}};!function(t,e,n){function s(t){var e=this,n="";e.x=0,e.y=0,e.z=0,e.w=0,e.next=function(){var t=e.x^e.x<<11;return e.x=e.y,e.y=e.z,e.z=e.w,e.w^=e.w>>>19^t^t>>>8},t===(0|t)?e.x=t:n+=t;for(var s=0;s>>0)/4294967296};return a.double=function(){do{var t=((n.next()>>>11)+(n.next()>>>0)/4294967296)/(1<<21)}while(0===t);return t},a.int32=n.next,a.quick=a,r&&("object"==typeof r&&i(r,n),a.state=function(){return i(n,{})}),a}e&&e.exports?e.exports=r:n&&n.amd?n((function(){return r})):this.xor128=r}(0,go,!1);var mo=go.exports,yo={exports:{}};!function(t,e,n){function s(t){var e=this,n="";e.next=function(){var t=e.x^e.x>>>2;return e.x=e.y,e.y=e.z,e.z=e.w,e.w=e.v,(e.d=e.d+362437|0)+(e.v=e.v^e.v<<4^t^t<<1)|0},e.x=0,e.y=0,e.z=0,e.w=0,e.v=0,t===(0|t)?e.x=t:n+=t;for(var s=0;s>>4),e.next()}function i(t,e){return e.x=t.x,e.y=t.y,e.z=t.z,e.w=t.w,e.v=t.v,e.d=t.d,e}function r(t,e){var n=new s(t),r=e&&e.state,a=function(){return(n.next()>>>0)/4294967296};return a.double=function(){do{var t=((n.next()>>>11)+(n.next()>>>0)/4294967296)/(1<<21)}while(0===t);return t},a.int32=n.next,a.quick=a,r&&("object"==typeof r&&i(r,n),a.state=function(){return i(n,{})}),a}e&&e.exports?e.exports=r:n&&n.amd?n((function(){return r})):this.xorwow=r}(0,yo,!1);var bo=yo.exports,wo={exports:{}};!function(t,e,n){function s(t){var e=this;e.next=function(){var t,n,s=e.x,i=e.i;return t=s[i],n=(t^=t>>>7)^t<<24,n^=(t=s[i+1&7])^t>>>10,n^=(t=s[i+3&7])^t>>>3,n^=(t=s[i+4&7])^t<<7,t=s[i+7&7],n^=(t^=t<<13)^t<<9,s[i]=n,e.i=i+1&7,n},function(t,e){var n,s=[];if(e===(0|e))s[0]=e;else for(e=""+e,n=0;n0;--n)t.next()}(e,t)}function i(t,e){return e.x=t.x.slice(),e.i=t.i,e}function r(t,e){null==t&&(t=+new Date);var n=new s(t),r=e&&e.state,a=function(){return(n.next()>>>0)/4294967296};return a.double=function(){do{var t=((n.next()>>>11)+(n.next()>>>0)/4294967296)/(1<<21)}while(0===t);return t},a.int32=n.next,a.quick=a,r&&(r.x&&i(r,n),a.state=function(){return i(n,{})}),a}e&&e.exports?e.exports=r:n&&n.amd?n((function(){return r})):this.xorshift7=r}(0,wo,!1);var ko=wo.exports,vo={exports:{}};!function(t,e,n){function s(t){var e=this;e.next=function(){var t,n,s=e.w,i=e.X,r=e.i;return e.w=s=s+1640531527|0,n=i[r+34&127],t=i[r=r+1&127],n^=n<<13,t^=t<<17,n^=n>>>15,t^=t>>>12,n=i[r]=n^t,e.i=r,n+(s^s>>>16)|0},function(t,e){var n,s,i,r,a,o=[],l=128;for(e===(0|e)?(s=e,e=null):(e+="\0",s=0,l=Math.max(l,e.length)),i=0,r=-32;r>>15,s^=s<<4,s^=s>>>13,r>=0&&(a=a+1640531527|0,i=0==(n=o[127&r]^=s+a)?i+1:0);for(i>=128&&(o[127&(e&&e.length||0)]=-1),i=127,r=512;r>0;--r)s=o[i+34&127],n=o[i=i+1&127],s^=s<<13,n^=n<<17,s^=s>>>15,n^=n>>>12,o[i]=s^n;t.w=a,t.X=o,t.i=i}(e,t)}function i(t,e){return e.i=t.i,e.w=t.w,e.X=t.X.slice(),e}function r(t,e){null==t&&(t=+new Date);var n=new s(t),r=e&&e.state,a=function(){return(n.next()>>>0)/4294967296};return a.double=function(){do{var t=((n.next()>>>11)+(n.next()>>>0)/4294967296)/(1<<21)}while(0===t);return t},a.int32=n.next,a.quick=a,r&&(r.X&&i(r,n),a.state=function(){return i(n,{})}),a}e&&e.exports?e.exports=r:n&&n.amd?n((function(){return r})):this.xor4096=r}(0,vo,!1);var So=vo.exports,xo={exports:{}};!function(t,e,n){function s(t){var e=this,n="";e.next=function(){var t=e.b,n=e.c,s=e.d,i=e.a;return t=t<<25^t>>>7^n,n=n-s|0,s=s<<24^s>>>8^i,i=i-t|0,e.b=t=t<<20^t>>>12^n,e.c=n=n-s|0,e.d=s<<16^n>>>16^i,e.a=i-t|0},e.a=0,e.b=0,e.c=-1640531527,e.d=1367130551,t===Math.floor(t)?(e.a=t/4294967296|0,e.b=0|t):n+=t;for(var s=0;s>>0)/4294967296};return a.double=function(){do{var t=((n.next()>>>11)+(n.next()>>>0)/4294967296)/(1<<21)}while(0===t);return t},a.int32=n.next,a.quick=a,r&&("object"==typeof r&&i(r,n),a.state=function(){return i(n,{})}),a}e&&e.exports?e.exports=r:n&&n.amd?n((function(){return r})):this.tychei=r}(0,xo,!1);var No,Io=xo.exports,Ao={exports:{}},To=Un({__proto__:null,default:{}});No=Ao,function(t,e,n){var s,i=256,r=n.pow(i,6),a=n.pow(2,52),o=2*a,l=255;function u(l,u,g){var m=[],y=d(p((u=1==u?{entropy:!0}:u||{}).entropy?[l,f(e)]:null==l?function(){try{var n;return s&&(n=s.randomBytes)?n=n(i):(n=new Uint8Array(i),(t.crypto||t.msCrypto).getRandomValues(n)),f(n)}catch(n){var r=t.navigator,a=r&&r.plugins;return[+new Date,t,a,t.screen,f(e)]}}():l,3),m),b=new h(m),w=function(){for(var t=b.g(6),e=r,n=0;t=o;)t/=2,e/=2,n>>>=1;return(t+n)/e};return w.int32=function(){return 0|b.g(4)},w.quick=function(){return b.g(4)/4294967296},w.double=w,d(f(b.S),e),(u.pass||g||function(t,e,s,i){return i&&(i.S&&c(i,b),t.state=function(){return c(b,{})}),s?(n.random=t,e):t})(w,y,"global"in u?u.global:this==n,u.state)}function h(t){var e,n=t.length,s=this,r=0,a=s.i=s.j=0,o=s.S=[];for(n||(t=[n++]);rt*e),1);o.push(l);let u=function(t,e,n){const s=t.shape.slice();s[n]=1;const i=mi(e,s),r=ao(t,n,!0,!1),a=ao(t,n,!0,!0),o=ni(r,a);return ni(i,o)}(a.reshape(o),e,i);if(u=u.reshape(a.shape),null!=r){const t=gr(r);u=kr(u,t)}return u}const Bo={kernelName:"SpaceToBatchND",gradFunc:(t,e,n)=>{const{blockShape:s,paddings:i}=n;return{x:()=>so(t,s,i)}}},Po={kernelName:"SplitV",gradFunc:(t,e,n)=>{const{axis:s}=n;return{x:()=>eo(t,s)}}};const Uo=[ii,pi,di,bi,wi,vi,Si,xi,Ii,Ai,Ti,Ei,_i,Mi,Bi,Ui,Wi,ji,Vi,Yi,Xi,tr,rr,sr,lr,hr,pr,vr,Nr,Ir,{kernelName:"RealDiv",inputsToSave:["a","b"],gradFunc:(t,e)=>{const[n,s]=e,i=gi(n.shape,s.shape);return{a:()=>{const e=ai(t,ei(s,"float32")),r=fi(n.shape,i);return r.length>0?mi(yi(e,r),n.shape):e},b:()=>{let e=ni(t,ei(n,"float32"));const r=fi(s.shape,i);r.length>0&&(e=mi(yi(e,r),s.shape));const a=hi(s);return oi(ai(e,ei(a,"float32")))}}}},Ar,Er,Cr,zr,$r,Dr,Fr,Rr,Br,Wr,jr,Vr,qr,Kr,Hr,Jr,Zr,Yr,Qr,na,na,ia,aa,la,ca,pa,da,ga,ya,ba,wa,ka,va,xa,Na,Na,Ta,Ea,{kernelName:"Prod",inputsToSave:["x"],gradFunc:(t,e,n)=>{const[s]=e,{axis:i}=n;let r=[];return r=null==i?s.shape.map(((t,e)=>e)):"number"==typeof i?[i]:i,{x:()=>Oo(s,t,r)}}},{kernelName:"Reciprocal",inputsToSave:["x"],gradFunc:(t,e)=>{const[n]=e;return{x:()=>ai(t,oi(hi(n)))}}},{kernelName:"Relu6",inputsToSave:["x"],gradFunc:(t,e)=>{const[n]=e,s=ni(Ki(n,6),si(n));return{x:()=>ni(t,ei(s,"float32"))}}},{kernelName:"Relu",inputsToSave:["x"],gradFunc:(t,e)=>{const[n]=e;return{x:()=>ni(t,ei(si(n),"float32"))}}},{kernelName:"Reshape",inputsToSave:["x"],gradFunc:(t,e)=>{const[n]=e;return{x:()=>mi(t,n.shape)}}},{kernelName:"ResizeBilinear",inputsToSave:["images"],gradFunc:(t,e,n)=>{const[s]=e,i={dy:t,images:s};return{images:()=>Gs.runKernel("ResizeBilinearGrad",i,n)}}},{kernelName:"ResizeNearestNeighbor",inputsToSave:["images"],gradFunc:(t,e,n)=>{const[s]=e,i={dy:t,images:s};return{images:()=>Gs.runKernel("ResizeNearestNeighborGrad",i,n)}}},{kernelName:"Reverse",gradFunc:(t,e,n)=>{const{dims:s}=n,i=dn(s,t.shape);return{x:()=>_o(t,i)}}},{kernelName:"Round",gradFunc:t=>({x:()=>ki(t)})},{kernelName:"Rsqrt",inputsToSave:["x"],gradFunc:(t,e)=>{const[n]=e;return{x:()=>oi(ai(t,ni(Aa(n,1.5),2)))}}},{kernelName:"Select",inputsToSave:["condition"],gradFunc:(t,e)=>{const[n]=e;return{condition:()=>ei(ki(n),"float32"),t:()=>ni(t,ei(n,t.dtype)),e:()=>ni(t,ei(uo(n),t.dtype))}}},{kernelName:"Selu",inputsToSave:["x"],gradFunc:(t,e)=>{const[n]=e;return{x:()=>{const e=Gr(n,li(0)),s=li(1.7580993408473768),i=li(1.0507009873554805),r=ni(t,i),a=ni(ni(t,s),Tr(ei(n,"float32")));return Zi(e,r,a)}}}},{kernelName:"Sigmoid",outputsToSave:[!0],gradFunc:(t,e)=>{const[n]=e;return{x:()=>ni(t,ni(n,ci(li(1),n)))}}},{kernelName:"Sign",gradFunc:t=>({x:()=>ki(t)})},{kernelName:"Sin",inputsToSave:["x"],gradFunc:(t,e)=>{const[n]=e;return{x:()=>ni(io(ei(n,"float32")),t)}}},{kernelName:"Sinh",inputsToSave:["x"],gradFunc:(t,e)=>{const[n]=e;return{x:()=>ni(ro(ei(n,"float32")),t)}}},{kernelName:"Slice",inputsToSave:["x"],gradFunc:(t,e,n)=>{const[s]=e,{begin:i,size:r}=n,a=s.shape,[o,l]=function(t,e,n){let s;const i=t.shape.length;let r;return s="number"==typeof e?[e,...new Array(i-1).fill(0)]:e.length{ln(-1!==t,(()=>"slice() does not support negative begin indexing."))})),r=null==n?new Array(i).fill(-1):"number"==typeof n?[n,...new Array(i-1).fill(-1)]:n.lengthe>=0?e:(ln(-1===e,(()=>`Negative size values should be exactly -1 but got ${e} for the slice() size at index ${n}.`)),t.shape[n]-s[n]))),[s,r]}(s,i,r),u=[];for(let e=0;eco(t,u)}}},{kernelName:"Softmax",outputsToSave:[!0],gradFunc:(t,e,n)=>{const[s]=e,{dim:i}=n,r=ni(t,s);return{logits:()=>ci(r,ni(yi(r,[i],true),s))}}},{kernelName:"Softplus",inputsToSave:["x"],gradFunc:(t,e)=>{const[n]=e;return{x:()=>ni(t,no(n))}}},Bo,Bo,Po,Po,{kernelName:"Sqrt",inputsToSave:["x"],gradFunc:(t,e)=>{const[n]=e;return{x:()=>ai(t,ni(ui(ei(n,"float32")),2))}}},{kernelName:"SquaredDifference",inputsToSave:["a","b"],gradFunc:(t,e)=>{const[n,s]=e,i=li(2);return{a:()=>ni(t,ni(i,ci(n,s))),b:()=>ni(t,ni(i,ci(s,n)))}}},{kernelName:"Square",inputsToSave:["x"],gradFunc:(t,e)=>{const[n]=e;return{x:()=>ni(t,ni(ei(n,"float32"),2))}}},{kernelName:"Step",gradFunc:t=>({x:()=>ki(t)})},{kernelName:"Sub",inputsToSave:["a","b"],gradFunc:(t,e)=>{const[n,s]=e,i=gi(n.shape,s.shape);return{a:()=>{let e=t;const s=fi(n.shape,i);return s.length>0&&(e=yi(e,s)),mi(e,n.shape)},b:()=>{let e=t;const n=fi(s.shape,i);return n.length>0&&(e=yi(e,n)),mi(oi(e),s.shape)}}}},{kernelName:"Sum",inputsToSave:["x"],gradFunc:(t,e,n)=>{const[s]=e,i=s.shape.slice(),{axis:r}=n;dn(r,s.shape).forEach((t=>{i[t]=1}));const a=mi(t,i),o=ni(a,ha(s.shape,"float32"));return{x:()=>o}}},{kernelName:"Tan",inputsToSave:["x"],gradFunc:(t,e)=>{const[n]=e;return{x:()=>ai(t,hi(io(n)))}}},{kernelName:"Tanh",outputsToSave:[!0],gradFunc:(t,e)=>{const[n]=e;return{x:()=>ni(ci(li(1),hi(n)),t)}}},{kernelName:"Tile",inputsToSave:["x"],gradFunc:(t,e,n)=>{const[s]=e,{reps:i}=n;return{x:()=>{let e=ki(s);if(1===s.rank)for(let n=0;n{const s=n,{perm:i}=s,r=gr(i);return{x:()=>kr(t,r)}}},{kernelName:"Unpack",gradFunc:(t,e,n)=>{const s=n,{axis:i}=s;return{value:()=>Mr(t,i)}}},{kernelName:"UnsortedSegmentSum",inputsToSave:["segmentIds"],gradFunc:(t,e)=>{const[n]=e;return{x:()=>function(t,e){const n=ho(e,ki(e)),s=lo(t,n);let i=qi(e,li(0,"int32"));const r=s.rank-i.rank;for(let t=0;t({x:()=>ki(t)})}];for(const t of Uo)On(t);function Wo(e,n){return s((()=>t.sqrt(t.sum(t.mul(e,e),n,!0))))}class jo extends r.Serializable{getConfig(){return{}}}class Vo extends jo{constructor(t){super(),this.defaultMaxValue=2,this.defaultAxis=0,this.maxValue=null!=t.maxValue?t.maxValue:this.defaultMaxValue,this.axis=null!=t.axis?t.axis:this.defaultAxis}apply(e){return s((()=>{const n=Wo(e,this.axis),s=t.clipByValue(n,0,this.maxValue);return t.mul(e,t.div(s,t.add(Yt(),n)))}))}getConfig(){return{maxValue:this.maxValue,axis:this.axis}}}Vo.className="MaxNorm",r.registerClass(Vo);class qo extends jo{constructor(t){super(),this.defaultAxis=0,this.axis=null!=t.axis?t.axis:this.defaultAxis}apply(e){return s((()=>t.div(e,t.add(Yt(),Wo(e,this.axis)))))}getConfig(){return{axis:this.axis}}}qo.className="UnitNorm",r.registerClass(qo);class Ko extends jo{apply(e){return t.relu(e)}}Ko.className="NonNeg",r.registerClass(Ko);class Go extends jo{constructor(t){super(),this.defaultMinValue=0,this.defaultMaxValue=1,this.defaultRate=1,this.defaultAxis=0,this.minValue=null!=t.minValue?t.minValue:this.defaultMinValue,this.maxValue=null!=t.maxValue?t.maxValue:this.defaultMaxValue,this.rate=null!=t.rate?t.rate:this.defaultRate,this.axis=null!=t.axis?t.axis:this.defaultAxis}apply(e){return s((()=>{const n=Wo(e,this.axis),s=t.add(t.mul(this.rate,t.clipByValue(n,this.minValue,this.maxValue)),t.mul(1-this.rate,n));return t.mul(e,t.div(s,t.add(Yt(),n)))}))}getConfig(){return{minValue:this.minValue,maxValue:this.maxValue,rate:this.rate,axis:this.axis}}}Go.className="MinMaxNorm",r.registerClass(Go);const Ho={maxNorm:"MaxNorm",minMaxNorm:"MinMaxNorm",nonNeg:"NonNeg",unitNorm:"UnitNorm"};function Jo(t){return mt(t)}function Zo(t,e={}){return bt(t,r.SerializationMap.getMap().classNameMap,e,"constraint")}function Yo(t){if(null==t)return null;if("string"==typeof t){return Zo({className:t in Ho?Ho[t]:t,config:{}})}return t instanceof jo?t:Zo(t)}var Xo={__proto__:null,maxNorm:function(t){return new Vo(t)},minMaxNorm:function(t){return new Go(t)},nonNeg:function(){return new Ko},unitNorm:function(t){return new qo(t)}};var Qo,tl={__proto__:null,constant:function(t){return new we(t)},glorotNormal:function(t){return new Ae(t)},glorotUniform:function(t){return new Ie(t)},heNormal:function(t){return new Te(t)},heUniform:function(t){return new Ee(t)},identity:function(t){return new xe(t)},leCunNormal:function(t){return new Ce(t)},leCunUniform:function(t){return new ze(t)},ones:function(){return new be},orthogonal:function(t){return new $e(t)},randomNormal:function(t){return new ve(t)},randomUniform:function(t){return new ke(t)},truncatedNormal:function(t){return new Se(t)},varianceScaling:function(t){return new Ne(t)},zeros:function(){return new ye}};async function el(t){if(null==t)return;const e=[],n=[],s=[];for(const i in t){const r=t[i];if("number"!=typeof r){const t=r;e.push(t.data()),n.push(i),s.push(t)}}if(e.length>0){const i=await Promise.all(e);for(let e=0;ew(this.totals[t],l(i,n))));this.totals[t]=r,null!=e&&e.dispose()}}}async onEpochEnd(t,e){if(null!=e)for(const t of this.params.metrics)null!=this.totals[t]&&("number"==typeof this.totals[t]?e[t]=this.totals[t]/this.seen:s((()=>{const n=l(k(1,this.seen),this.totals[t]);e[t]=n,this.totals[t].dispose(),v(e[t])})))}}class al extends sl{async onTrainBegin(t){this.epoch=[],this.history={}}async onEpochEnd(t,e){null==e&&(e={}),this.epoch.push(t);for(const t in e)null==this.history[t]&&(this.history[t]=[]),this.history[t].push(e[t])}async syncData(){const t=[],e=[],n=[];for(const s in this.history){const i=this.history[s];for(let r=0;r{const o=null!=s?s():e.now();return o-rnew ol(t,e)))}class ul{constructor(){}static registerCallbackConstructor(t,n){e.assert(t>=0&&Number.isInteger(t),(()=>`Verbosity level is expected to be an integer >= 0, but got ${t}`)),ul.checkForDuplicate(n),null==ul.constructors[t]&&(ul.constructors[t]=[]),ul.constructors[t].push(n)}static checkForDuplicate(t){for(const e in ul.constructors){ul.constructors[+e].forEach((e=>{if(e===t)throw new it("Duplicate callback constructor.")}))}}static clear(){ul.constructors={}}static createCallbacks(t){const e=[];for(const n in ul.constructors){const s=+n;t>=s&&e.push(...ul.constructors[s])}return e.map((t=>new t))}}function hl(t,e,n,s,i,r,a,o,l){const u=new al,h=[new rl,...ul.createCallbacks(e)];null!=t&&h.push(...t),h.push(u);const c=new il(h);return c.setParams({epochs:n,initialEpoch:s,samples:i,steps:r,batchSize:a,verbose:e,doValidation:o,metrics:l}),{callbackList:c,history:u}}function cl(t,e={},n=!1){return bt(t,r.SerializationMap.getMap().classNameMap,e,"layer",n)}function pl(e,n){return s((()=>{"float32"!==e.dtype&&(e=t.cast(e,"float32"));const s=t.sum(ue(e),n,!0),i=t.fill(s.shape,Yt()),r=t.sqrt(t.maximum(s,i));return t.div(e,r)}))}function dl(e,n){return s((()=>t.mean(ue(t.sub(n,e)),-1)))}function fl(e,n){return s((()=>t.mean(t.abs(t.sub(n,e)),-1)))}function gl(e,n){return s((()=>{const s=t.sub(e,n),i=t.clipByValue(t.abs(e),Yt(),Number.MAX_VALUE),r=t.abs(t.div(s,i));return t.mul(100,t.mean(r,-1))}))}function ml(e,n,i=!1){return s((()=>{if(i)n=t.softmax(n);else{const e=t.sum(n,n.shape.length-1,!0);n=t.div(n,e)}return n=t.clipByValue(n,Yt(),1-Yt()),t.neg(t.sum(t.mul(t.cast(e,"float32"),t.log(n)),n.shape.length-1))}))}function yl(e,n,i=!1){return s((()=>{const s=t.cast(t.floor(function(e){const n=[Kt(e.shape)];return t.reshape(e,n)}(e)),"int32"),r=(n=t.clipByValue(n,Yt(),1-Yt())).shape;return ml(t.reshape(t.oneHot(s,r[r.length-1]),r),n,i)}))}function bl(n,i){return s((()=>{let r;return r=t.clipByValue(i,Yt(),1-Yt()),r=t.log(t.div(r,t.sub(1,r))),t.mean(function(n,i){if(!e.arraysEqual(n.shape,i.shape))throw new it(`logits and labels must have the same shape, but got shapes ${JSON.stringify(n.shape)} and ${JSON.stringify(i.shape)}`);return s((()=>{const e=t.relu(i),s=t.neg(t.abs(i));return t.add(t.sub(e,t.mul(i,n)),t.log1p(t.exp(s)))}))}(n,r),-1)}))}function wl(e,n){return s((()=>{const s=pl(e,-1),i=pl(n,-1),r=t.mul(s,i);return t.neg(t.sum(r,-1))}))}ul.constructors={};const kl={meanSquaredError:dl,meanAbsoluteError:fl,meanAbsolutePercentageError:gl,meanSquaredLogarithmicError:function(e,n){return s((()=>{const s=t.clipByValue(n,Yt(),Number.MAX_VALUE),i=t.log(t.add(1,s)),r=t.clipByValue(e,Yt(),Number.MAX_VALUE),a=t.log(t.add(1,r));return t.mean(ue(t.sub(i,a)),-1)}))},squaredHinge:function(e,n){return s((()=>{const s=t.maximum(0,t.sub(1,t.mul(e,n)));return t.mean(ue(s),-1)}))},hinge:function(e,n){return s((()=>{const s=t.maximum(0,t.sub(1,t.mul(e,n)));return t.mean(s,-1)}))},categoricalHinge:function(e,n){return s((()=>{const s=t.sum(t.mul(e,n),-1),i=t.max(t.mul(t.sub(1,e),n),-1);return t.maximum(0,t.add(1,t.sub(i,s)))}))},logcosh:function(e,n){return s((()=>{const s=Math.log(2),i=t.sub(n,e),r=t.sub(t.add(i,t.softplus(t.mul(-2,i))),s);return t.mean(r,-1)}))},categoricalCrossentropy:ml,sparseCategoricalCrossentropy:yl,binaryCrossentropy:bl,kullbackLeiblerDivergence:function(e,n){return s((()=>{const s=t.clipByValue(e,Yt(),1),i=t.clipByValue(n,Yt(),1);return t.sum(t.mul(e,t.log(t.div(s,i))),-1)}))},poisson:function(e,n){return s((()=>{const s=t.log(t.add(Yt(),n));return t.mean(t.sub(n,t.mul(e,s)),-1)}))},cosineProximity:wl};function vl(t){if("string"==typeof t){if(t in kl)return kl[t];let e=`Unknown loss ${t}`;throw t.toLowerCase().includes("softmaxcrossentropy")&&(e=`Unknown loss ${t}. Use "categoricalCrossentropy" as the string name for tf.losses.softmaxCrossEntropy`),new it(e)}return t}function Sl(e,n){return s((()=>{const s=t.mul(.5,t.onesLike(n)),i=Xt(t.greater(n,s),e.dtype);return t.mean(t.equal(e,i),-1)}))}function xl(e,n){return s((()=>Xt(t.equal(t.argMax(e,-1),t.argMax(n,-1)),"float32")))}function Nl(e,n){return s((()=>t.cast(t.sum(t.logicalAnd(t.equal(e,1),t.equal(n,1))),"float32")))}function Il(e,n){return s((()=>{const i=Nl(e,n),r=function(e,n){return s((()=>t.cast(t.sum(t.logicalAnd(t.equal(e,0),t.equal(n,1))),"float32")))}(e,n),a=t.add(i,r);return t.cast(t.where(t.greater(a,0),t.div(i,a),0),"float32")}))}function Al(e,n){return s((()=>{const i=Nl(e,n),r=function(e,n){return s((()=>t.cast(t.sum(t.logicalAnd(t.equal(e,1),t.equal(n,0))),"float32")))}(e,n),a=t.add(i,r);return t.cast(t.where(t.greater(a,0),t.div(i,a),0),"float32")}))}function Tl(t,e){return bl(t,e)}function El(e,n){return e.rank===n.rank&&(e=t.squeeze(e,[e.rank-1])),(n=t.argMax(n,-1)).dtype!==e.dtype&&(n=t.cast(n,e.dtype)),t.cast(t.equal(e,n),"float32")}const Cl=ml,zl=yl,$l={binaryAccuracy:Sl,categoricalAccuracy:xl,precision:Il,categoricalCrossentropy:Cl,sparseCategoricalCrossentropy:zl,mse:dl,MSE:dl,mae:fl,MAE:fl,mape:gl,MAPE:gl,cosine:wl};function Fl(t){if("string"==typeof t&&t in $l)return $l[t];if("string"!=typeof t&&null!=t)return t;throw new it(`Unknown metric ${t}`)}function Dl(t){if(ut(null!==t,`Unknown LossOrMetricFn ${t}`),"string"==typeof t)return t;{let e;for(const n of Object.keys(kl))if(kl[n]===t){e=n;break}if(void 0!==e)return e;for(const n of Object.keys($l))if($l[n]===t){e=n;break}return void 0!==e?e:t.name}}function Ll(t,e,n=!1){if(null==t||"object"!=typeof t||Object.getPrototypeOf(t)!==Object.prototype||!_l(t))throw new Error("User-defined metadata is expected to be a JSON object, but is not.");if(n){const n=JSON.stringify(t);n.length>1048576&&console.warn(`User-defined metadata of model "${e}" is too large in size (length=${n.length} when serialized). It is not recommended to store such large objects in user-defined metadata. Please make sure its serialized length is <= 1048576.`)}}function _l(t){if(null===t)return!0;if("object"==typeof t){if(Object.getPrototypeOf(t)===Object.prototype){const e=Object.keys(t);for(const n of e){if("string"!=typeof n)return!1;if(!_l(t[n]))return!1}return!0}if(Array.isArray(t)){for(const e of t)if(!_l(e))return!1;return!0}return!1}{const e=typeof t;return"string"===e||"number"===e||"boolean"===e}}function Rl(t,e,n,s=console.log){const i=function(t){let e=!0;const n=[],s=[];for(const e in t.nodesByDepth)n.push(t.nodesByDepth[e]);for(const t of n){if(t.length>1||1===t.length&&t[0].inboundLayers.length>1){e=!1;break}s.push(...t)}if(e)for(const n of t.layers){let t=!1;for(const i of n.inboundNodes)if(-1!==s.indexOf(i)){if(t){e=!1;break}t=!0}if(!e)break}return e}(t),r=["Layer (type)","Input Shape","Output shape","Param #"];let a;if(i?(e=e||90,n=n||[.32,.61,.89,1]):(e=e||115,n=n||[.24,.48,.7,.8,1]),n[n.length-1]<=1&&(n=n.map((t=>Math.floor(e*t)))),!i){r.push("Receives inputs"),a=[];for(const e in t.nodesByDepth)a.push(...t.nodesByDepth[e])}s("_".repeat(e)),Ml(r,n,s),s("=".repeat(e));const o=t.layers;for(let t=0;t0&&(s=s.slice(0,s.length-1)+" "),s+=t[n],s=s.slice(0,e[n]),s+=" ".repeat(e[n]-s.length);n(s)}function Ol(t,e,n){let s,i;try{i=t.inboundNodes.map((t=>JSON.stringify(t.inputShapes))).join(",")}catch(t){i="multiple"}try{s=JSON.stringify(t.outputShape)}catch(t){s="multiple"}Ml([`${t.name} (${t.getClassName()})`,i,s,t.countParams().toString()],e,n)}function Bl(t,e,n,s){let i,r;try{r=t.inboundNodes.map((t=>JSON.stringify(t.inputShapes))).join(",")}catch(t){r="multiple"}try{i=JSON.stringify(t.outputShape)}catch(t){i="multiple"}const a=[];for(const e of t.inboundNodes)if(!(null!=n&&n.length>0&&-1===n.indexOf(e)))for(let t=0;tt.name))}`);kt(this.outputs).length!==this.outputs.length&&console.warn(`The list of outputs passed to the model is redundant. All outputs should only appear once. Found: ${this.outputs.map((t=>t.name))}`),this.inputLayers=[],this.inputLayersNodeIndices=[],this.inputLayersTensorIndices=[],this.outputLayers=[],this.outputLayersNodeIndices=[],this.outputLayersTensorIndices=[],this.layers=[],this.internalContainerRefs=[];for(const t of this.outputs){const e=t.sourceLayer,n=t.nodeIndex,s=t.tensorIndex;this.outputLayers.push(e),this.outputLayersNodeIndices.push(n),this.outputLayersTensorIndices.push(s)}for(const t of this.inputs){const e=t.sourceLayer,n=t.nodeIndex,s=t.tensorIndex;ut(0===n,"input layer has >1 nodes"),ut(0===s,"input layer has >1 tensors"),this.inputLayers.push(e),this.inputLayersNodeIndices.push(n),this.inputLayersTensorIndices.push(s)}this.inputNames=[],this.outputNames=[],this.feedInputShapes=[],this.feedInputNames=[],this.feedOutputNames=[];for(let e=0;et.shape)),this.internalOutputShapes=this.outputs.map((t=>t.shape));const e={},n={},s={},i={},r={},a=[],o=(t,e,n,s,i,l)=>{null!=s&&null!=i&&null!=l||(s=t.sourceLayer,i=t.nodeIndex,l=t.tensorIndex);const u=s.inboundNodes[i];if(-1!==n.indexOf(u))throw new st(`The tensor ${t.name} at layer "${s.name}" is part of a cycle.`);if(-1!==e.indexOf(u))return;this.containerNodes.add(Vl.nodeKey(s,i)),s.id in r||(r[s.id]=Object.keys(r).length),-1===n.indexOf(u)&&n.push(u);const h=u.inboundLayers.length;for(let t=0;t=0;)n.splice(n.indexOf(u),1);a.push(u)},l=[],u=[];for(const t of this.outputs)o(t,l,u);const h=a.slice().reverse();for(const t of h){n[t.id]=t,t.id in e||(e[t.id]=0);let r=e[t.id];const a=null==s[t.outboundLayer.id]?0:s[t.outboundLayer.id];r=Math.max(r,a),s[t.outboundLayer.id]=r,i[t.outboundLayer.id]=t.outboundLayer,e[t.id]=r;for(let s=0;sparseInt(t,10))).sort(wt);this.layers=[];for(const t of d){const e=p[t];e.sort(((t,e)=>{const n=r[t.id],s=r[e.id];return ns?1:0}));for(const t of e)t instanceof Vl&&this.internalContainerRefs.push(t),this.layers.push(t)}this.layersByDepth=p,d=Object.keys(c).map((t=>parseInt(t,10))).sort(wt);const f=this.inputs.slice(),g=[];for(const t of d)for(const e of c[t]){const t=e.outboundLayer;if(null!=t){for(const n of e.inputTensors)if(-1===f.indexOf(n))throw new st(`Graph disconnected: cannot obtain value for tensor ${n} at layer "${t.name}". The following previous layers were accessed without issue: ${g}`);for(const t of e.outputTensors)f.push(t);g.push(t.name)}}this.nodesByDepth=c;const m=this.layers.map((t=>t.name));for(const t of m){const e=m.filter((e=>e===t)).length;if(1!==e)throw new st(`The name "${t}" is used ${e} times in the model. All layer names should be unique. Layer names: `+JSON.stringify(m))}this.outboundNodes=[],this.inboundNodes=[],new Ge({outboundLayer:this,inboundLayers:[],nodeIndices:[],tensorIndices:[],inputTensors:this.inputs,outputTensors:this.outputs,inputMasks:this.inputs.map((t=>null)),outputMasks:this.outputs.map((t=>null)),inputShapes:this.inputs.map((t=>t.shape)),outputShapes:this.outputs.map((t=>t.shape))}),this.built=!0,this._refCount=1}assertNotDisposed(){if(0===this._refCount)throw new Error(`Container '${this.name}' is already disposed.`)}dispose(){this.assertNotDisposed();const t={refCountAfterDispose:null,numDisposedVariables:0};if(0==--this._refCount){for(const e of this.layers)t.numDisposedVariables+=e.dispose().numDisposedVariables;for(const e of this.internalContainerRefs)t.numDisposedVariables+=e.dispose().numDisposedVariables}return t.refCountAfterDispose=this._refCount,t}get trainable(){return this.trainable_}set trainable(t){this.layers.forEach((e=>{e._trainableWeights.forEach((e=>e.trainable=t))})),this.trainable_=t}get trainableWeights(){if(this._trainableWeights.length>0)throw new it("Container instance unexpectedly contains _trainableWeights.The trainable weights of a Container are a union of the trainable weights of its consituent Layers. Its own _trainableWeights must remain an empty Array.");if(!this.trainable)return[];let t=[];for(const e of this.layers)t=t.concat(e.trainableWeights);return t}get nonTrainableWeights(){const t=[];for(const e of this.layers)t.push(...e.nonTrainableWeights);if(!this.trainable){const e=[];for(const t of this.layers)e.push(...t.trainableWeights);return e.concat(t)}return t}get weights(){return this.trainableWeights.concat(this.nonTrainableWeights)}loadWeights(t,e=!0){const n={};let s=0;const i=(t=>{const e=Object.keys(t);if(0===e.length)return!1;const n=e[0].split("/");return!isNaN(parseInt(n[n.length-1],10))})(t);i&&this.parseWeights(t);for(const t of this.layers)for(const[e,r]of t.weights.entries()){const t=i?`${r.name.split("/").slice(0,-1).join("/")+"/"}${e}`:r.originalName;if(null!=n[t])throw new it(`Duplicate weight name: ${t}`);n[t]=r,s++}const r=[];for(const s in t){let i=s;if(null==n[s]){const t=s.split("/");i=t.slice(0,-2).concat([t[t.length-1]]).join("/")}if(null!=n[i])r.push([n[i],t[s]]);else if(e)throw new it(`Provided weight data has no target variable: ${s}`);delete n[i]}if(e){const t=[];for(const e in n)t.push(e);if(t.length>0)throw new it(`${t.length} of ${s} weights are not set: ${t}`)}je(r)}parseWeights(t){for(const e in Object.keys(t)){const n=e.split("/"),s=["vars","layer_checkpoint_dependencies"],i=n.map((t=>t.startsWith("_")?t.slice(1):t)).filter((t=>!s.includes(t))).join("/");i!==e&&(t[i]=t[e],delete t[e])}}updatedConfig(){const t=this.getConfig(),e={};return e.className=this.getClassName(),e.config=t,e.kerasVersion="tfjs-layers 4.15.0",e.backend="TensorFlow.js",e}toJSON(t,e=!0){const n=Wl(this.updatedConfig());return e?JSON.stringify(n):n}call(t,e){return s((()=>{t=pt(t);const n=new Qe;for(let e=0;e{let n;return t=pt(t),n=null==e?lt(null,t.length):pt(e),this.runInternalGraph(t,n)[1]}))}computeOutputShape(t){const e=Me(t);if(e.length!==this.inputLayers.length)throw new it(`Invalid inputShape argument ${t}: model has ${this.inputLayers.length} tensor inputs.`);const n={};for(let t=0;tparseInt(t,10))).sort(wt);if(s.length>1)for(const t of s){const e=this.nodesByDepth[t];for(const t of e){const e=t.outboundLayer;if(-1!==this.inputLayers.map((t=>t.id)).indexOf(e.id))continue;const s=[];for(let e=0;eparseInt(t,10))).sort(wt);for(const t of s){const e=this.nodesByDepth[t];for(const t of e){const e=t.outboundLayer,s=t.inputTensors,i=t.outputTensors,r=new Array;for(const t of s)t.id in n&&r.push(n[t.id]);if(r.length===s.length){let s,a,o,l,u={};if(null!=t.callArgs&&(u=t.callArgs),1===r.length){const[t,n]=r[0];null==u.mask&&(u.mask=n),o=pt(e.call(t,u)),l=pt(e.computeMask(t,n)),s=[t],a=[n]}else s=r.map((t=>t[0])),a=r.map((t=>t[1])),null==u.mask&&(u.mask=a),o=pt(e.call(s,u)),l=pt(e.computeMask(s,a));if(e.activityRegularizer)throw new rt("LayersModel invocation with concrete Tensor value(s) in the presence of activity regularizer(s) is not supported yet.");for(let t=0;t{const t=[];for(const e of this.layers)for(let n=0;n0){const t=[];for(let n=0;n0&&t.apply(ct(n),s)}function l(t){const n=t.name,r=cl(t,null!=e.customObjects?e.customObjects:{});r.setFastWeightInitDuringBuild(s),i[n]=r;t.inboundNodes.forEach((t=>{if(!(t instanceof Array))throw new it(`Corrupted configuration, expected array for nodeData: ${t}`);a(r,t)}))}const u=e.name,h=e.layers;for(const t of h)l(t);for(;!vt(r);)for(const t of h){const e=i[t.name];if(e.name in r){const t=r[e.name];delete r[e.name];for(const n of t)o(e,n)}}const c=[],p=[],d=e.inputLayers;for(const t of d){const e=t[0],n=t[1],s=t[2];ut(e in i);const r=i[e].inboundNodes[n].outputTensors;c.push(r[s])}const f=e.outputLayers;for(const t of f){const e=t[0],n=t[1],s=t[2];ut(e in i);const r=i[e].inboundNodes[n].outputTensors;p.push(r[s])}return new t({inputs:c,outputs:p,name:u})}get stateful(){if(this._stateful)throw new it("Container instance unexpectedly has _stateful = true. The statefulness of a Container is determined by the Layers it contains. Its _stateful property must remain the default false.");for(const t of this.layers)if(t.stateful)return!0;return!1}resetStates(){s((()=>{this.layers.forEach((t=>{t.stateful&&t.resetStates()}))}))}}function ql(t,e){return function(t,e,n){const s=e.length;if(null==t||Array.isArray(t)&&0===t.length)return e.map((t=>null));if(1===s)return Array.isArray(t)&&1===t.length?t:"object"==typeof t&&e[0]in t?[t[e[0]]]:[t];if(Array.isArray(t)){if(t.length!==s)throw new Error(`Provided ${n} is an array of ${t.length} element(s), but the model has ${s} outputs. Make sure a set of weights is provided for each model output.`);return t}if("object"==typeof t&&Object.keys(t).length>0&&"object"==typeof t[Object.keys(t)[0]]){const n=[];return e.forEach((e=>{e in t?n.push(t[e]):n.push(null)})),n}throw new Error(`The model has multiple (${s}) outputs, so ${n} must be either an array with ${s} elements or an object with ${e} keys. Provided ${n} not understood: ${JSON.stringify(t)}`)}(t,e,"classWeight")}async function Kl(t,e,n,r){if(null!=e||null!=r)throw new Error("Support sampleWeight is not implemented yet");if(null!=n){const e=s((()=>{if(1===t.shape.length)return x(t);if(2===t.shape.length){if(t.shape[1]>1){return N(t,1)}if(1===t.shape[1])return I(t,[t.shape[0]]);throw new Error(`Encountered unexpected last-dimension size (${t.shape[1]}) during handling of class weights. The size is expected to be >= 1.`)}throw new Error(`Unexpected rank of target (y) tensor (${t.rank}) during handling of class weights. The rank is expected to be 1 or 2.`)})),r=Array.from(await e.data());f(e);const a=[];return r.forEach((t=>{if(null==n[t])throw new Error(`classWeight must contain all classes in the training data. The class ${t} exists in the data but not in classWeight`);a.push(n[t])})),i(a,"float32")}return null}function Gl(t,e){return l(t,e)}function Hl(e,n){let s,i;const r=n;s=r.xs,i=r.ys,t.util.assert(null!=s&&null!=i,(()=>`A Dataset iterator for fitDataset() is expected to generate objects of the form \`{xs: xVal, ys: yVal}\`, where the two values may be \`tf.Tensor\`, an array of Tensors, or a map of string to Tensor. The provided Dataset instead generates ${n}`));const a=Jl("input",e.inputNames,s),o=Jl("output",e.outputNames,i),l=a[0].shape[0];t.util.assert(a.length===e.inputs.length,(()=>`LayersModel has ${e.inputs.length} inputs, but the dataset provides ${a.length} inputs. (Expected input keys: ${JSON.stringify(e.inputNames)})`)),t.util.assert(o.length===e.outputs.length,(()=>`LayersModel has ${e.outputs.length} outputs, but the dataset provides ${o.length} outputs. (Expected output keys: ${JSON.stringify(e.outputNames)})`));for(let n=0;n`Batch size mismatch: input ${e.inputNames[n]} has ${a[n].shape[0]}; expected ${l} based on input ${e.inputNames[0]}.`));for(let n=0;n`Batch size mismatch: output ${e.outputNames[n]} has ${o[n].shape[0]}; expected ${l} based on input ${e.inputNames[0]}.`));return{xs:a,ys:o}}function Jl(e,n,s){if(s instanceof t.Tensor)return[s];if(Array.isArray(s))return t.util.assert(s.length===n.length,(()=>`Received an array of ${s.length} Tensors, but expected ${n.length} to match the ${e} keys ${n}.`)),s;{const t=[];for(const i of n){if(null==s[i])throw new it(`The feature data generated by the dataset lacks the required ${e} key '${i}'.`);t.push(s[i])}return t}}async function Zl(e,n,s){const i=null!=s.batchesPerEpoch;if(t.util.assert(null!=e.optimizer,(()=>"You must compile a model before training/testing. Use LayersModel.compile(modelCompileConfig).")),t.util.assert(null!=s,(()=>"For fitDataset(), the 2nd argument (config) is required, but it is not provided in this call.")),t.util.assert(null!=s.epochs&&s.epochs>0&&Number.isInteger(s.epochs),(()=>`For fitDataset(), config.epochs is expected to be a positive integer, but got ${s.epochs}`)),t.util.assert(!i||s.batchesPerEpoch>0&&Number.isInteger(s.batchesPerEpoch),(()=>`For fitDataset(), config.batchesPerEpoch is expected to be a positive integer if specified, but got ${s.batchesPerEpoch}`)),t.util.assert(null==s.validationSplit,(()=>"`validationSplit` is not supported by `fitDataset()`. Use validationData instead.")),e.isTraining)throw new Error("Cannot start training because another fit() call is ongoing.");e.isTraining=!0;try{const r=null!=s.validationData;let a,o;if(r)if(Yl(s.validationData))t.util.assert(null==s.validationBatches||s.validationBatches>0&&Number.isInteger(s.validationBatches),(()=>`For fitDataset() with dataset-based validation, config.validationBatches is expected not to be provided, or to be a positive integer, but got ${s.validationBatches}`));else{const t=function(t){if(3===t.length)throw new rt("Validation with sample weights is not implemented yet.");return{xs:t[0],ys:t[1]}}(s.validationData);a=t.xs,o=t.ys}const l=e.makeTrainFunction(),u=e.getDedupedMetricsNames();let h;h=r?u.slice().concat(u.map((t=>"val_"+t))):u.slice();const c=ll(s.callbacks,s.yieldEvery),p=null==s.verbose?1:s.verbose,{callbackList:d,history:f}=hl(c,p,s.epochs,null,null,function(t,e){let n=null;null!=e.batchesPerEpoch?n=e.batchesPerEpoch:Number.isFinite(t.size)&&(n=t.size);return n}(n,s),null,r,h);d.setModel(e),e.history=f,await d.onTrainBegin(),e.stopTraining_=!1;let g=null==s.initialEpoch?0:s.initialEpoch,m=await n.iterator();for(;g=s.batchesPerEpoch:n.done){if(r){let t;t=Yl(s.validationData)?pt(await e.evaluateDataset(s.validationData,{batches:s.validationBatches})):pt(e.evaluate(a,o,{batchSize:null==s.validationBatchSize?32:s.validationBatchSize,verbose:0}));for(let n=0;n0&&Number.isInteger(e),(()=>`batchSize is required to be a positive integer, but got ${e}`))}function Ql(t,e,n){return null==t?[null]:Array.isArray(t)?t.map((t=>te(t,e,n-e))):te(t,e,n-e)}function tu(e,n){return t.tidy((()=>null==e?null:Array.isArray(e)?e.map((t=>tu(t,n))):le(e,"int32"===n.dtype?n:t.cast(n,"int32"))))}function eu(t,e){const n=[];let s=0,i=null;for(;s=t&&(i=t),n.push([s,i]),s=i;return n}function nu(t){const e=[];t instanceof A&&(t=[t]);for(let n=0;nn.push(t.id)));else if(null!=e)for(const t in e){const s=e[t];n.push(s.id)}const s=[];if(t instanceof A)-1===n.indexOf(t.id)&&s.push(t);else if(Array.isArray(t))t.forEach((t=>{-1===n.indexOf(t.id)&&s.push(t)}));else if(null!=t)for(const e in t){const i=t[e];-1===n.indexOf(i.id)&&s.push(i)}s.forEach((t=>{t.isDisposed||t.dispose()}))}function iu(t){return Array.isArray(t)}function ru(t){return!function(t){return t instanceof A}(t)&&!iu(t)}function au(t,e,n,s=!0,i=""){if(null==e||0===e.length){if(null!=t){let e=!1;if(iu(t)&&t.length>0)e=!0;else if(ru(t)){for(const n in t)if(t.hasOwnProperty(n)){e=!0;break}}else e=!0;if(e)throw new it(`Error when checking model ${i} expected no data, but got ${t}`)}return[]}if(null==t)return e.map((t=>null));let r;if(ru(t)){r=[];for(const n of e){if(null==t[n])throw new it(`No data provided for "${n}". Need data for each key in: ${e}`);r.push(t[n])}}else if(iu(t)){if(t.length!==e.length)throw new it(`Error when checking model ${i}: the Array of Tensors that you are passing to your model is not the size the model expected. Expected to see ${e.length} Tensor(s), but instead got the following list of Tensor(s): ${t}`);r=t}else{if(e.length>1)throw new it(`The model ${i} expects ${e.length} Tensor(s), but only received one Tensor. Found: Tensor with shape ${t.shape}`);r=[t]}if(r=nu(r),null!=n)for(let t=0;t=0&&r!==o)throw new it(`${i} expected a batch of elements where each example has shape [${n[t].slice(1,n[t].length)}] (i.e.,tensor shape [*,${n[t].slice(1,n[t].length)}]) but the ${i} received an input with ${a.shape[0]} examples, each with shape [${a.shape.slice(1,a.shape.length)}] (tensor shape [${a.shape}])`)}}return r}function ou(t,e,n,s=!0,i=""){let r;if(Array.isArray(t)){if(t.length!==e.length)throw new it(`Error when checking model ${i}: the Array of Tensors that you are passing to your model is not the size the the model expected. Expected to see ${e.length} Tensor(s), but instead got ${t.length} Tensors(s).`);r=t}else{if(e.length>1)throw new it(`The model expects ${e.length} ${i} Tensors, but only received one Tensor. Found: array with shape ${JSON.stringify(t.shape)}.`);r=[t]}if(null!=n)for(let t=0;tS.adagrad(.01),Adadelta:()=>S.adadelta(1,.95,Yt()),Adam:()=>S.adam(.001,.9,.999,Yt()),Adamax:()=>S.adamax(.002,.9,.999,Yt(),0),RMSProp:()=>S.rmsprop(.001,.9,0,Yt()),SGD:()=>S.sgd(.01)};if(e.adagrad=e.Adagrad,e.adadelta=e.Adadelta,e.adam=e.Adam,e.adamax=e.Adamax,e.rmsprop=e.RMSProp,e.sgd=e.SGD,t in e)return e[t]();throw new it(`Unknown Optimizer ${t}`)}(t.optimizer),this.isOptimizerOwned=!0;else{if(!(t.optimizer instanceof T))throw new it("User-defined optimizer must be an instance of tf.Optimizer.");this.optimizer_=t.optimizer,this.isOptimizerOwned=!1}let e=[];if(Array.isArray(t.loss)||"string"==typeof t.loss||"function"==typeof t.loss)if(Array.isArray(t.loss)){if(t.loss.length!==this.outputs.length)throw new it(`When passing an Array as loss, it should have one entry per model output. The model has ${this.outputs.length} output(s), but you passed loss=${t.loss}.`);const n=t.loss;e=n.map((t=>vl(t)))}else{const n=vl(t.loss);this.outputs.forEach((t=>{e.push(n)}))}else{t.loss=t.loss;for(const e in t.loss)if(-1===this.outputNames.indexOf(e))throw new it(`Unknown entry in loss dictionary: "${e}". Only expected the following keys: ${this.outputNames}`);for(const n of this.outputNames)null==t.loss[n]&&console.warn(`Output "${n}" is missing from loss dictionary. We assume this was done on purpose, and we will not be expecting data to be passed to ${n} during training`),e.push(vl(t.loss[n]))}this.lossFunctions=e,this.feedOutputNames=[],this.feedOutputShapes=[],this.feedLossFns=[];for(let t=0;t{for(let t=0;t1&&(this.metricsTensors.push([e,t]),this.metricsNames.push(this.outputNames[t]+"_loss"))}}));const s=function(t,e){if(null==t||Array.isArray(t)&&0===t.length)return e.map((t=>[]));let n;if("string"==typeof t||"function"==typeof t)n=[t];else{if(!Array.isArray(t)&&"object"!=typeof t)throw new TypeError(`Type of metrics argument not understood. Expected an string,function, Array, or Object, found: ${t}`);n=t}if(Array.isArray(n))return e.map((t=>n));{const t=[];for(const s of e){let e=n.hasOwnProperty(s)?n[s]:[];Array.isArray(e)||(e=[e]),t.push(e)}return t}}(t.metrics,this.outputNames),i=(t,e,n)=>{this.outputNames.length>1&&(e=this.outputNames[t]+"_"+e),this.metricsNames.push(e),this.metricsTensors.push([n,t])};Ut("metric",(()=>{for(let t=0;t{let n,s,r;for(const a of e){if("string"==typeof a&&-1!==["accuracy","acc","crossentropy","ce"].indexOf(a)){const e=this.internalOutputShapes[t];let i;1===e[e.length-1]||this.lossFunctions[t]===bl?-1!==["accuracy","acc"].indexOf(a)?s=Sl:-1!==["crossentropy","ce"].indexOf(a)&&(s=Tl):this.lossFunctions[t]===yl?-1!==["accuracy","acc"].indexOf(a)?s=El:-1!==["crossentropy","ce"].indexOf(a)&&(s=zl):-1!==["accuracy","acc"].indexOf(a)?s=xl:-1!==["crossentropy","ce"].indexOf(a)&&(s=Cl),-1!==["accuracy","acc"].indexOf(a)?i="acc":-1!==["crossentropy","ce"].indexOf(a)&&(i="ce"),r=s,n=""+i}else{const t=Fl(a);r=t,n=""+Dl(a)}let e;Ut(n,(()=>{e=r})),i(t,n,e)}})(s[t])}})),this.collectedTrainableWeights=this.trainableWeights}checkTrainableWeightsConsistency(){null!=this.collectedTrainableWeights&&this.trainableWeights.length!==this.collectedTrainableWeights.length&&console.warn("Discrepancy between trainableweights and collected trainable weights. Did you set `model.trainable` without calling `model.compile()` afterwards?")}evaluate(t,e,n={}){const s=null==n.batchSize?32:n.batchSize;Xl(s);const i=this.standardizeUserDataXY(t,e,!0,s);try{const r=i[0].concat(i[1]);this.makeTestFunction();const a=this.testFunction;return ct(this.testLoop(a,r,s,n.verbose,n.steps))}finally{su(i[0],t),su(i[1],e)}}async evaluateDataset(e,n){return this.makeTestFunction(),async function(e,n,s){const i=null!=(s=s||{}).batches,r=e.testFunction;let a=[];if(s.verbose>0)throw new rt("Verbose mode is not implemented yet.");t.util.assert(!i||s.batches>0&&Number.isInteger(s.batches),(()=>`Test loop expects \`batches\` to be a positive integer, but received ${JSON.stringify(s.batches)}`));const o="function"==typeof n.next?n:await n.iterator();let l=0,h=0;for(;!i||h{if(n.value){const{xs:s,ys:i}=Hl(e,n.value),o=s.concat(i),c=t.tidy((()=>r(o)));if(t.dispose(o),0===h)for(let t=0;tt.add(a[e],t.mul(p,n)))),h>0&&t.dispose(s)}t.dispose(c),l+=p,++h}return a})),n.done){i&&console.warn(`Your dataset iterator ran out of data during evaluateDataset(). Interrupting evalution. Make sure that your dataset can generate at least \`batches\` batches (in this case, ${s.batches} batches). You may need to use the repeat() function when building your dataset.`);break}}for(let e=0;et.name));for(let s=0;s0){const n=[];throw e.forEach(((e,s)=>{null==e&&n.push(t[s])})),new it(`Cannot find SymbolicTensors for output name(s): ${JSON.stringify(n)}`)}return e}predictLoop(e,n=32,s=!1){return t.tidy((()=>{const i=this.checkNumSamples(e);if(s)throw new rt("Verbose predictLoop() is not implemented yet.");const r=eu(i,n),a=this.outputs.map((t=>[]));for(let n=0;n{const t=r[n][0],s=r[n][1],i=Ql(e,t,s),a=[];if(Array.isArray(i))for(let t=0;ta[e].push(t)))}return ct(a.map((e=>t.concat(e,0))))}))}predict(t,e={}){const n=nu(t);ou(n,this.inputNames,this.feedInputShapes,!1);try{const s=null==e.batchSize?32:e.batchSize;return Xl(s),this.predictLoop(n,s)}finally{su(n,t)}}predictOnBatch(t){ou(t,this.inputNames,this.feedInputShapes,!0);const e=(Array.isArray(t)?t[0]:t).shape[0];return this.predictLoop(t,e)}standardizeUserDataXY(t,n,s=!0,i){if(null==this.optimizer_)throw new st("You must compile a model before training/testing. Use LayersModel.compile(modelCompileArgs).");const r=[];for(let t=0;tt.shape[0])));i.sort();const r=kt(n.map((t=>t.shape[0])));if(r.sort(),i.length>1)throw new it(`All input Tensors (x) should have the same number of samples. Got array shapes: ${JSON.stringify(t.map((t=>t.shape)))}`);if(r.length>1)throw new it(`All target Tensors (y) should have the same number of samples. Got array shapes: ${JSON.stringify(n.map((t=>t.shape)))}`);if(i.length>0&&r.length>0&&!e.arraysEqual(i,r))throw new it(`Input Tensors should have the same number of samples as target Tensors. Found ${i[0]} input sample(s) and ${r[0]} target sample(s).`)}(t=au(t,this.feedInputNames,this.feedInputShapes,!1,"input"),n=au(n,this.feedOutputNames,r,!1,"target")),function(t,e,n){const s=[dl,bl,ml];for(let i=0;i0&&t[0].shape[0]%i!=0)throw new it(`In a stateful network, you should only pass inputs with a number of samples that is divisible by the batch size ${i}. Found: ${t[0].shape[0]} sample(s).`);return[t,n]}async standardizeUserData(t,e,n,s,i=!0,r){const[a,o]=this.standardizeUserDataXY(t,e,i,r);if(null!=n)throw new Error("sample weight is not supported yet.");let l=null;if(null!=s){const t=ql(s,this.outputNames);l=[];for(let e=0;e{const o=this.checkNumSamples(n,s,a,"steps"),l=[];if(r>0)throw new rt("Verbose mode is not implemented yet.");if(null!=a)throw new rt("steps mode in testLoop() is not implemented yet");{const r=eu(o,s),a=i(Jt(0,o));for(let s=0;s1){i+=`_${ht(t.slice(0,n),s)}`}e.push(i)}return e}makeTrainFunction(){return e=>{const n=[],s=e.slice(0,this.inputs.length),i=e.slice(this.inputs.length,this.inputs.length+this.outputs.length),r=e.slice(this.inputs.length+this.outputs.length,this.inputs.length+2*this.outputs.length),a=[],o=this.collectedTrainableWeights.map((t=>t.read()));return[this.optimizer_.minimize((()=>{const e=[];for(let t=0;t1&&e{u=t.add(u,e)})),u}),!0,o)].concat(a)}}makeTestFunction(){this.testFunction=e=>t.tidy((()=>{const n=[];let s;const i=e.slice(0,this.inputs.length),r=e.slice(this.inputs.length,this.inputs.length+this.outputs.length),a=[];for(let t=0;t0){if(y=!0,2!==s.validationData.length)throw 3===s.validationData.length?new rt("validationData including sample weights is not supported yet."):new it(`When passing validation data, it must contain 2 (valX, valY) or 3 (valX, valY, valSampleWeight) items; ${s.validationData} is invalid.`);l=s.validationData[0],u=s.validationData[1];const t=!0,e=await this.standardizeUserData(l,u,null,null,t,d);h=e[0],c=e[1],m=h.concat(c)}else if(null!=s.validationSplit&&s.validationSplit>0&&s.validationSplit<1){y=!0;const t=Math.floor(i[0].shape[0]*(1-s.validationSplit)),e=i[0].shape[0];h=Ql(i,t,e),a=i,i=Ql(i,0,t),c=Ql(r,t,e),o=r,r=Ql(r,0,t),m=h.concat(c)}else null!=s.validationSteps&&(y=!0);const b=i.concat(r).concat(p);this.checkTrainableWeightsConsistency();const w=this.makeTrainFunction(),k=this.getDedupedMetricsNames();let v,S;y?(this.makeTestFunction(),v=this.testFunction,S=k.slice().concat(k.map((t=>"val_"+t)))):(v=null,m=[],S=k.slice());const x=ll(s.callbacks,s.yieldEvery);return await this.fitLoop(w,b,k,d,s.epochs,s.verbose,x,v,m,s.shuffle,S,s.initialEpoch,null,null)}finally{this.isTraining=!1,su(i,e),su(r,n),su(a,e),su(o,n),su(h,l),su(c,u),null!=p&&t.dispose(p)}}async fitLoop(n,s,r,a,o,l,u,h,c,p,d,f,g,m){null==a&&(a=32),null==o&&(o=1),null==p&&(p=!0),null==f&&(f=0);let y=!1;if(null!=h&&null!=c&&(y=!0),null!=m&&(y=!0,null==g))throw new it("Can only use `validationSteps` when doing step-wise training, i.e., `stepsPerEpoch` must be set.");const b=this.checkNumSamples(s,a,g,"steps_per_epoch");let w;null!=b&&(w=Jt(0,b)),null==l&&(l=1);const{callbackList:k,history:v}=hl(u,l,o,f,b,g,a,y,d);k.setModel(this),this.history=v,await k.onTrainBegin(),this.stopTraining_=!1;for(let l=f;l{const p=u[e][0],d=u[e][1],f=te(l,p,d-p);i.batch=e,i.size=d-p;const g=tu(s,f),m=n(g);for(let e=0;edt(t)))}else{const e=Object.keys(this.loss);t={};const n=this.loss;for(const s of e){if("string"!=typeof n[s])throw new Error("Serialization of non-string loss is not supported.");t[s]=dt(n[s])}}return t}getMetricIdentifiers(){if("string"==typeof this.metrics||"function"==typeof this.metrics)return[dt(Dl(this.metrics))];if(Array.isArray(this.metrics))return this.metrics.map((t=>dt(Dl(t))));{const t={};for(const e in this.metrics)t[e]=dt(Dl(this.metrics[e]));return t}}getTrainingConfig(){return{loss:this.getLossIdentifiers(),metrics:this.getMetricIdentifiers(),optimizer_config:{class_name:this.optimizer.getClassName(),config:this.optimizer.getConfig()}}}loadTrainingConfig(t){if(null!=t.weighted_metrics)throw new Error("Loading weight_metrics is not supported yet.");if(null!=t.loss_weights)throw new Error("Loading loss_weights is not supported yet.");if(null!=t.sample_weight_mode)throw new Error("Loading sample_weight_mode is not supported yet.");const e=cl(Ul(t.optimizer_config));let n,s;if("string"==typeof t.loss)n=ft(t.loss);else if(Array.isArray(t.loss))n=t.loss.map((t=>ft(t)));else if(null!=t.loss){n={};for(const e in t.loss)n[e]=ft(t.loss[e])}if(Array.isArray(t.metrics))s=t.metrics.map((t=>ft(t)));else if(null!=t.metrics){s={};for(const e in t.metrics)s[e]=ft(t.metrics[e])}this.compile({loss:n,metrics:s,optimizer:e})}async save(t,e){if("string"==typeof t){const e=E.getSaveHandlers(t);if(0===e.length)throw new it(`Cannot find any save handlers for URL '${t}'`);if(e.length>1)throw new it(`Found more than one (${e.length}) save handlers for URL '${t}'`);t=e[0]}if(null==t.save)throw new it("LayersModel.save() cannot proceed because the IOHandler provided does not have the `save` attribute defined.");const n=await E.encodeWeights(this.getNamedWeights(e)),s={modelTopology:this.toJSON(null,!1),format:"layers-model",generatedBy:"TensorFlow.js tfjs-layers v4.15.0",convertedBy:null};if(null!=e&&e.includeOptimizer&&null!=this.optimizer){s.trainingConfig=this.getTrainingConfig();const t="optimizer",{data:e,specs:i}=await E.encodeWeights(await this.optimizer.getWeights(),t);n.specs.push(...i),n.data=E.concatenateArrayBuffers([n.data,e])}if(null!=this.userDefinedMetadata){const t=!0;Ll(this.userDefinedMetadata,this.name,t),s.userDefinedMetadata=this.userDefinedMetadata}return s.weightData=n.data,s.weightSpecs=n.specs,t.save(s)}setUserDefinedMetadata(t){Ll(t,this.name),this.userDefinedMetadata=t}getUserDefinedMetadata(){return this.userDefinedMetadata}}lu.className="Model",r.registerClass(lu);class uu extends lu{}async function hu(t,e){if(null==e&&(e={}),"string"==typeof t){const n=E.getLoadHandlers(t,e);if(0===n.length)n.push(E.browserHTTPRequest(t,e));else if(n.length>1)throw new it(`Found more than one (${n.length}) load handlers for URL '${t}'`);t=n[0]}return async function(t,e,n){null==n&&(n={});if(null==t.load)throw new it("Cannot proceed with model loading because the IOHandler provided does not have the `load` method implemented.");const s=await t.load();let i=s.modelTopology;null!=i.model_config&&(i=i.model_config);const r=null==n.strict||n.strict,a=null!=s.weightData&&null!=s.weightSpecs&&r,o=cl(Ul(i),e,a),l=s.trainingConfig;null!=l&&o.loadTrainingConfig(l);null!=s.userDefinedMetadata&&o.setUserDefinedMetadata(s.userDefinedMetadata);if(null!=s.weightData){if(null==s.weightSpecs)throw new it("LayersModel artifacts contains weight data, but not weight specs. Therefore loading of weights cannot proceed.");const{modelWeights:t,optimizerWeights:e}=function(t,e){const n=E.decodeWeights(t,e),s={},i=[];return e.forEach((t=>{"optimizer"===t.group?i.push({name:t.name,tensor:n[t.name]}):s[t.name]=n[t.name]})),{modelWeights:s,optimizerWeights:i}}(s.weightData,s.weightSpecs);o.loadWeights(t,r),null!=o.optimizer&&e.length>0&&await o.optimizer.setWeights(e),f(t),f(e.map((t=>t.tensor)))}return o}(t,void 0,e)}uu.className="Functional",r.registerClass(uu);class cu extends lu{constructor(t){if(super({inputs:[],outputs:[]}),t=t||{},this.trainable=!0,this.built=!1,this.name=null!=t.name?t.name:zt("sequential_"),null!=t.layers)for(const e of t.layers)this.add(e)}checkShape(t){if(t.inboundNodes[0].outputTensors[0].shape.some((t=>t<0)))throw new it(`Negative dimension size caused by adding layer ${t.name} with input shape [${t.inboundNodes[0].inputTensors[0].shape}]`)}add(t){const e=t instanceof cu||t instanceof lu;let n;if(e){if(n=t,1!==n.outputs.length)throw new it("All layers in a Sequential model should have a single output tensor. For multi-output layers, use the functional API.");if(1!==n.inputs.length)throw new it("All layers in a Sequential model should have a single input tensor. For multi-input layers, use the functional API.")}if(0===this.outputs.length){if(0===t.inboundNodes.length){if(null==t.batchInputShape)throw new it("The first layer in a Sequential model must get an `inputShape` or `batchInputShape` argument.");const e=Xe({batchShape:t.batchInputShape,dtype:t.dtype,name:t.name+"_input"});t.apply(e)}if(e)this.outputs=n.outputs,this.inputs=n.inputs;else{if(1!==t.inboundNodes.length)throw new it(`A layer added to a Sequential model must not already be connected somewhere else. LayersModel received layer ${t.name} which has ${t.inboundNodes.length} pre-existing inbound connections.`);if(1!==t.inboundNodes[0].outputTensors.length)throw new it("All layers in a Sequential model should have a single output tensor. For multi-output layers, use the functional API.");this.checkShape(t),this.outputs=[t.inboundNodes[0].outputTensors[0]],this.inputs=Ze(this.outputs[0])}this.inboundNodes=[],new Ge({outboundLayer:this,inboundLayers:[],nodeIndices:[],tensorIndices:[],inputTensors:this.inputs,outputTensors:this.outputs,inputMasks:lt(null,this.inputs.length),outputMasks:[null],inputShapes:this.inputs.map((t=>t.shape)),outputShapes:this.outputs[0].shape})}else{const e=t.apply(this.outputs[0]);if(Array.isArray(e))throw new TypeError("All layers in a Sequential model should have a single output tensor. For multi-output layers, use the functional API.");this.checkShape(t),this.outputs=[e],this.inboundNodes[0].outputTensors=this.outputs,this.inboundNodes[0].outputShapes=[this.outputs[0].shape]}this.layers.push(t),this.built=!1}pop(){if(0===this.layers.length)throw new TypeError("There are no layers in the model.");if(this.layers.pop(),0===this.layers.length)this.outputs=[],this.inboundNodes=[],this.outboundNodes=[];else{const t=this.layers.length-1;this.layers[t].outboundNodes=[],this.outputs=[this.layers[t].output],this.inboundNodes[0].outputTensors=this.outputs,this.inboundNodes[0].outputShapes=[this.outputs[0].shape]}}call(t,e){return null==this.model&&this.build(),this.model.call(t,e)}build(t){if(Be(t),0===this.inputs.length||0===this.outputs.length)throw new TypeError("Sequential model cannot be built: model is empty. Add some layers first.");this.model=new lu({inputs:this.inputs,outputs:this.outputs[0],name:this.name+"_model"}),this.model.trainable=this.trainable,this.supportsMasking=this.model.supportsMasking,this.inputLayers=this.model.inputLayers,this.inputLayersNodeIndices=this.model.inputLayersNodeIndices,this.inputLayersTensorIndices=this.model.inputLayersTensorIndices,this.outputLayers=this.model.outputLayers,this.outputLayersNodeIndices=this.model.outputLayersNodeIndices,this.outputLayersTensorIndices=this.model.outputLayersTensorIndices,this.nodesByDepth=this.model.nodesByDepth,this.containerNodes=this.model.containerNodes,this.outputNames=this.model.outputNames,this.inputNames=this.model.inputNames,this.built=!0}countParams(){return this.built||this.build(),super.countParams()}summary(t,e,n=console.log){this.built||this.build(),super.summary(t,e,n)}setWeights(t){null==this.model&&this.build(),this.model.setWeights(t)}evaluate(t,e,n={}){if(!this.built)throw new st("The model needs to be compiled before being used.");return this.model.evaluate(t,e,n)}async evaluateDataset(t,e){if(!this.built)throw new st("The model needs to be compiled before being used.");return this.model.evaluateDataset(t,e)}predict(t,e={}){return null==this.model&&this.build(),this.model.predict(t,e)}predictOnBatch(t){return null==this.model&&this.build(),this.model.predictOnBatch(t)}compile(t){this.build(),this.model.compile(t),this.optimizer_=this.model.optimizer,this.isOptimizerOwned=this.model.isOptimizerOwned,this.loss=this.model.loss,this.metrics=this.model.metrics,this.metricsTensors=this.model.metricsTensors,this.metricsNames=this.model.metricsNames}get optimizer(){return null==this.model?void 0:this.model.optimizer}set optimizer(t){this.model.optimizer=t}async fit(t,e,n={}){if(!this.built)throw new st("The model needs to be compiled before being used.");return this.model.fit(t,e,n)}async fitDataset(t,e){if(!this.built)throw new st("The model needs to be compiled before being used.");return this.model.fitDataset(t,e)}async trainOnBatch(t,e){return this.model.trainOnBatch(t,e)}static fromConfig(t,n,s={},i=!1){let r,a={};if(n instanceof Array){if(null==n[0].className||"Merge"===n[0].className)throw new it("Legacy serialization format not supported yet.");r=n}else e.assert(null!=n.layers,(()=>"When the config data for a Sequential model is not an Array, it must be an Object that contains the 'layers' field.")),r=n.layers,delete n.layers,a=n;const o=new t(a);if(!(o instanceof cu))throw new rt(`Sequential.fromConfig called on non-Sequential input: ${o}`);for(const t of r){const e=cl(t,void 0,i);i&&e.setFastWeightInitDuringBuild(!0),o.add(e)}return o}set stopTraining(t){if(null==this.model)throw new it("Cannot set the stopTraining property of a sequential model before it is compiled.");this.model.stopTraining=t}get stopTraining(){if(null==this.model)throw new it("Cannot get the stopTraining property of a sequential model before it is compiled.");return this.model.stopTraining}getConfig(){const t=[];for(const e of this.layers){const n={};n.className=e.getClassName(),n.config=e.getConfig(),t.push(n)}return{name:this.name,layers:t}}}function pu(t){return new lu(t)}function du(t){return new cu(t)}function fu(t){return Xe(t)}function gu(t,e){ul.registerCallbackConstructor(t,e)}cu.className="Sequential",r.registerClass(cu);let mu=class extends r.Serializable{getConfig(){return{}}};class yu extends mu{apply(e,n=1){return function(e,n=1){if(1!==n)throw new rt(`Support for alpha values other than 1 (${n}) is not implemented yet.`);return t.elu(e)}(e,n)}}yu.className="elu",r.registerClass(yu);class bu extends mu{apply(e){return t.selu(e)}}bu.className="selu",r.registerClass(bu);class wu extends mu{apply(e){return t.relu(e)}}wu.className="relu",r.registerClass(wu);class ku extends mu{apply(e){return s((()=>t.minimum(6,t.relu(e))))}}ku.className="relu6",r.registerClass(ku);class vu extends mu{apply(t){return t}}vu.className="linear",r.registerClass(vu);class Su extends mu{apply(e){return t.sigmoid(e)}}Su.className="sigmoid",r.registerClass(Su);class xu extends mu{apply(e){return function(e){return s((()=>{const n=t.add(.5,t.mul(.2,e));return t.clipByValue(n,0,1)}))}(e)}}xu.className="hardSigmoid",r.registerClass(xu);class Nu extends mu{apply(e){return t.softplus(e)}}Nu.className="softplus",r.registerClass(Nu);class Iu extends mu{apply(e){return function(e){return s((()=>t.div(e,t.add(t.abs(e),1))))}(e)}}Iu.className="softsign",r.registerClass(Iu);class Au extends mu{apply(e){return t.tanh(e)}}Au.className="tanh",r.registerClass(Au);let Tu=class extends mu{apply(e,n=-1){return t.softmax(e,n)}};Tu.className="softmax",r.registerClass(Tu);class Eu extends mu{apply(e,n=-1){return t.logSoftmax(e,n)}}Eu.className="logSoftmax",r.registerClass(Eu);class Cu extends mu{apply(e,n=1){return s((()=>t.mul(t.sigmoid(t.mul(e,n)),e)))}}Cu.className="swish",r.registerClass(Cu);class zu extends mu{apply(e){return s((()=>t.mul(e,t.tanh(t.softplus(e)))))}}function $u(t){return t.getClassName()}function Fu(t,e={}){return bt(t,r.SerializationMap.getMap().classNameMap,e,"activation")}function Du(t){if(null==t){const t={className:"linear",config:{}};return Fu(t)}if("string"==typeof t){const e={};return e.className=t,e.config={},Fu(e)}return t instanceof mu?t:Fu(t)}function Lu(t){if(null!=t&&"object"!=typeof t)throw new Error(`Argument to L1L2 regularizer's constructor is expected to be an object, but received: ${t}`)}zu.className="mish",r.registerClass(zu);class _u extends r.Serializable{}class Ru extends _u{constructor(t){super(),Lu(t),this.l1=null==t||null==t.l1?.01:t.l1,this.l2=null==t||null==t.l2?.01:t.l2,this.hasL1=0!==this.l1,this.hasL2=0!==this.l2}apply(e){return s((()=>{let n=a([1]);return this.hasL1&&(n=w(n,C(t.mul(this.l1,z(e))))),this.hasL2&&(n=w(n,C(t.mul(this.l2,ue(e))))),t.reshape(n,[])}))}getConfig(){return{l1:this.l1,l2:this.l2}}static fromConfig(t,e){return new t({l1:e.l1,l2:e.l2})}}Ru.className="L1L2",r.registerClass(Ru);const Mu={l1l2:"L1L2"};function Ou(t){return mt(t)}function Bu(t,e={}){return bt(t,r.SerializationMap.getMap().classNameMap,e,"regularizer")}function Pu(t){if(null==t)return null;if("string"==typeof t){return Bu({className:t in Mu?Mu[t]:t,config:{}})}return t instanceof _u?t:Bu(t)}class Uu extends Je{constructor(t){super(null==t?{}:t),this.supportsMasking=!0,null!=t&&(this.maxValue=t.maxValue)}call(t,e){t=Oe(t);let n=$(t);return null!=this.maxValue&&(n=F(n,0,this.maxValue)),n}computeOutputShape(t){return t}getConfig(){const t={maxValue:this.maxValue},e=super.getConfig();return Object.assign(t,e),t}}Uu.className="ReLU",r.registerClass(Uu);class Wu extends Je{constructor(t){super(null==t?{}:t),this.DEFAULT_ALPHA=.3,null==t&&(t={}),this.alpha=null==t.alpha?this.DEFAULT_ALPHA:t.alpha}call(t,e){const n=Oe(t);return D(n,this.alpha)}computeOutputShape(t){return t}getConfig(){const t={alpha:this.alpha},e=super.getConfig();return Object.assign(t,e),t}}Wu.className="LeakyReLU",r.registerClass(Wu);class ju extends Je{constructor(t){if(super(null==t?{}:t),this.DEFAULT_ALPHA_INITIALIZER="zeros",null==t&&(t={}),this.supportsMasking=!0,this.alphaInitializer=_e(t.alphaInitializer||this.DEFAULT_ALPHA_INITIALIZER),this.alphaRegularizer=Pu(t.alphaRegularizer),this.alphaConstraint=Yo(t.alphaConstraint),null==t.sharedAxes)this.sharedAxes=null;else if(Array.isArray(t.sharedAxes))this.sharedAxes=t.sharedAxes;else{if("number"!=typeof t.sharedAxes)throw new it(`Expected sharedAxes to be a number or an array of numbers, but got ${t.sharedAxes}`);this.sharedAxes=[t.sharedAxes]}}build(t){const e=(t=Be(t)).slice(1);if(null!=this.sharedAxes)for(const t of this.sharedAxes)e[t-1]=1;this.alpha=this.addWeight("alpha",e,"float32",this.alphaInitializer,this.alphaRegularizer,!0,this.alphaConstraint);const n={};if(null!=this.sharedAxes)for(let e=1;e{let n=Oe(t);const s=e.mask;if(null!=s){const t=l(M(o(n.shape),m(s,n.dtype)),u(-1e9));n=w(n,t)}return this.axis instanceof Array?this.axis.length>1?O(M(n,B(n,this.axis,!0))):this.softmax(n,this.axis[0]):this.softmax(n,this.axis)}))}computeOutputShape(t){return t}getConfig(){const t={axis:this.axis},e=super.getConfig();return Object.assign(t,e),t}}function Gu(t,e,n){if("number"==typeof t)return lt(t,e);if(t.length!==e)throw new it(`The ${n} argument must be an integer or tuple of ${e} integers. Received: ${t.length} elements.`);for(let i=0;i(Mt(n),"channelsFirst"===n?t.transpose(e,[0,2,3,1]):e)))}function Yu(e,n){return s((()=>(Mt(n),"channelsFirst"===n?t.transpose(e,[0,2,3,4,1]):e)))}function Xu(e,n,i,r=[1,1],a="valid",o,l,u=null){return s((()=>{if(null==o&&(o="channelsLast"),Mt(o),3!==e.rank&&4!==e.rank)throw new it(`conv2dWithBiasActivation expects input to be of rank 3 or 4, but received ${e.rank}.`);if(3!==n.rank&&4!==n.rank)throw new it(`conv2dWithBiasActivation expects kernel to be of rank 3 or 4, but received ${e.rank}.`);let s=Zu(e,o);if("causal"===a)throw new rt("The support for CAUSAL padding mode in conv1dWithBias is not implemented yet.");return s=t.fused.conv2d({x:s,filter:n,strides:r,pad:"same"===a?"same":"valid",dilations:l,dataFormat:"NHWC",bias:i,activation:u}),"channelsFirst"===o&&(s=t.transpose(s,[0,3,1,2])),s}))}Ku.className="Softmax",r.registerClass(Ku);class Qu extends Je{constructor(t,e){if(super(e),this.bias=null,this.DEFAULT_KERNEL_INITIALIZER="glorotNormal",this.DEFAULT_BIAS_INITIALIZER="zeros",Qu.verifyArgs(e),this.rank=t,Nt(this.rank,"rank"),1!==this.rank&&2!==this.rank&&3!==this.rank)throw new rt(`Convolution layer for rank other than 1, 2, or 3 (${this.rank}) is not implemented yet.`);if(this.kernelSize=Gu(e.kernelSize,t,"kernelSize"),this.strides=Gu(null==e.strides?1:e.strides,t,"strides"),this.padding=null==e.padding?"valid":e.padding,Ot(this.padding),this.dataFormat=null==e.dataFormat?"channelsLast":e.dataFormat,Mt(this.dataFormat),this.activation=Du(e.activation),this.useBias=null==e.useBias||e.useBias,this.biasInitializer=_e(e.biasInitializer||this.DEFAULT_BIAS_INITIALIZER),this.biasConstraint=Yo(e.biasConstraint),this.biasRegularizer=Pu(e.biasRegularizer),this.activityRegularizer=Pu(e.activityRegularizer),this.dilationRate=Gu(null==e.dilationRate?1:e.dilationRate,t,"dilationRate"),1===this.rank&&Array.isArray(this.dilationRate)&&1!==this.dilationRate.length)throw new it(`dilationRate must be a number or an array of a single number for 1D convolution, but received ${JSON.stringify(this.dilationRate)}`);if(2===this.rank){if("number"==typeof this.dilationRate)this.dilationRate=[this.dilationRate,this.dilationRate];else if(2!==this.dilationRate.length)throw new it(`dilationRate must be a number or array of two numbers for 2D convolution, but received ${JSON.stringify(this.dilationRate)}`)}else if(3===this.rank)if("number"==typeof this.dilationRate)this.dilationRate=[this.dilationRate,this.dilationRate,this.dilationRate];else if(3!==this.dilationRate.length)throw new it(`dilationRate must be a number or array of three numbers for 3D convolution, but received ${JSON.stringify(this.dilationRate)}`)}static verifyArgs(t){if(ut("kernelSize"in t,"required key 'kernelSize' not in config"),"number"!=typeof t.kernelSize&&!xt(t.kernelSize,"number",1,3))throw new it(`BaseConv expects config.kernelSize to be number or number[] with length 1, 2, or 3, but received ${JSON.stringify(t.kernelSize)}.`)}getConfig(){const t={kernelSize:this.kernelSize,strides:this.strides,padding:this.padding,dataFormat:this.dataFormat,dilationRate:this.dilationRate,activation:$u(this.activation),useBias:this.useBias,biasInitializer:Le(this.biasInitializer),biasRegularizer:Ou(this.biasRegularizer),activityRegularizer:Ou(this.activityRegularizer),biasConstraint:Jo(this.biasConstraint)},e=super.getConfig();return Object.assign(t,e),t}}class th extends Qu{constructor(t,e){super(t,e),this.kernel=null,th.verifyArgs(e),this.filters=e.filters,Nt(this.filters,"filters"),this.kernelInitializer=_e(e.kernelInitializer||this.DEFAULT_KERNEL_INITIALIZER),this.kernelConstraint=Yo(e.kernelConstraint),this.kernelRegularizer=Pu(e.kernelRegularizer)}build(t){t=Be(t);const e="channelsFirst"===this.dataFormat?1:t.length-1;if(null==t[e])throw new it(`The channel dimension of the input should be defined. Found ${t[e]}`);const n=t[e],s=this.kernelSize.concat([n,this.filters]);this.kernel=this.addWeight("kernel",s,null,this.kernelInitializer,this.kernelRegularizer,!0,this.kernelConstraint),this.useBias&&(this.bias=this.addWeight("bias",[this.filters],null,this.biasInitializer,this.biasRegularizer,!0,this.biasConstraint)),this.inputSpec=[{ndim:this.rank+2,axes:{[e]:n}}],this.built=!0}call(e,n){return s((()=>{let n;e=Oe(e);const i=null==this.bias?null:this.bias.read(),r=At(this.activation.getClassName());if(null!=r&&2===this.rank)n=Xu(e,this.kernel.read(),i,this.strides,this.padding,this.dataFormat,this.dilationRate,r);else{if(1===this.rank)n=function(e,n,i,r=1,a="valid",o,l=1){return s((()=>{if(null==o&&(o="channelsLast"),Mt(o),3!==e.shape.length)throw new it(`The input of a conv1dWithBias operation should be 3, but is ${e.shape.length} instead.`);if(3!==n.shape.length)throw new it(`The kernel for a conv1dWithBias operation should be 3, but is ${n.shape.length} instead`);if(null!=i&&1!==i.shape.length)throw new it(`The bias for a conv1dWithBias operation should be 1, but is ${n.shape.length} instead`);if("channelsFirst"===o&&(e=t.transpose(e,[0,2,1])),"causal"===a)throw new rt("The support for CAUSAL padding mode in conv1dWithBias is not implemented yet.");let s=t.conv1d(e,n,r,"same"===a?"same":"valid","NWC",l);return null!=i&&(s=ce(s,i)),s}))}(e,this.kernel.read(),i,this.strides[0],this.padding,this.dataFormat,this.dilationRate[0]);else if(2===this.rank)n=Xu(e,this.kernel.read(),i,this.strides,this.padding,this.dataFormat,this.dilationRate);else{if(3!==this.rank)throw new rt("convolutions greater than 3D are not implemented yet.");n=function(e,n,i,r=[1,1,1],a="valid",o,l){return s((()=>{if(null==o&&(o="channelsLast"),Mt(o),4!==e.rank&&5!==e.rank)throw new it(`conv3dWithBias expects input to be of rank 4 or 5, but received ${e.rank}.`);if(4!==n.rank&&5!==n.rank)throw new it(`conv3dWithBias expects kernel to be of rank 4 or 5, but received ${e.rank}.`);let s=Yu(e,o);if("causal"===a)throw new rt("The support for CAUSAL padding mode in conv3dWithBias is not implemented yet.");return s=t.conv3d(s,n,r,"same"===a?"same":"valid","NDHWC",l),null!=i&&(s=ce(s,i)),"channelsFirst"===o&&(s=t.transpose(s,[0,4,1,2,3])),s}))}(e,this.kernel.read(),i,this.strides,this.padding,this.dataFormat,this.dilationRate)}null!=this.activation&&(n=this.activation.apply(n))}return n}))}computeOutputShape(t){t=Be(t);const e=[],n="channelsLast"===this.dataFormat?t.slice(1,t.length-1):t.slice(2);for(let t=0;t 0 but got ${JSON.stringify(t.filters)}`)}}class eh extends th{constructor(t){super(2,t),eh.verifyArgs(t)}getConfig(){const t=super.getConfig();return delete t.rank,t}static verifyArgs(t){if("number"!=typeof t.kernelSize&&!xt(t.kernelSize,"number",1,2))throw new it(`Conv2D expects config.kernelSize to be number or number[] with length 1 or 2, but received ${JSON.stringify(t.kernelSize)}.`)}}eh.className="Conv2D",r.registerClass(eh);class nh extends th{constructor(t){super(3,t),nh.verifyArgs(t)}getConfig(){const t=super.getConfig();return delete t.rank,t}static verifyArgs(t){if("number"!=typeof t.kernelSize&&(!Array.isArray(t.kernelSize)||1!==t.kernelSize.length&&3!==t.kernelSize.length))throw new it(`Conv3D expects config.kernelSize to be number or [number, number, number], but received ${JSON.stringify(t.kernelSize)}.`)}}nh.className="Conv3D",r.registerClass(nh);class sh extends eh{constructor(t){if(super(t),this.inputSpec=[new Ve({ndim:4})],"same"!==this.padding&&"valid"!==this.padding)throw new it(`Conv2DTranspose currently supports only padding modes 'same' and 'valid', but received padding mode ${this.padding}`)}build(t){if(4!==(t=Be(t)).length)throw new it("Input should have rank 4; Received input shape: "+JSON.stringify(t));const e="channelsFirst"===this.dataFormat?1:t.length-1;if(null==t[e])throw new it("The channel dimension of the inputs should be defined. Found `None`.");const n=t[e],s=this.kernelSize.concat([this.filters,n]);this.kernel=this.addWeight("kernel",s,"float32",this.kernelInitializer,this.kernelRegularizer,!0,this.kernelConstraint),this.useBias&&(this.bias=this.addWeight("bias",[this.filters],"float32",this.biasInitializer,this.biasRegularizer,!0,this.biasConstraint)),this.inputSpec=[new Ve({ndim:4,axes:{[e]:n}})],this.built=!0}call(e,n){return t.tidy((()=>{let n=Oe(e);if(4!==n.shape.length)throw new it(`Conv2DTranspose.call() expects input tensor to be rank-4, but received a tensor of rank-${n.shape.length}`);const s=n.shape,i=s[0];let r,a;"channelsFirst"===this.dataFormat?(r=2,a=3):(r=1,a=2);const o=s[r],l=s[a],u=this.kernelSize[0],h=this.kernelSize[1],c=this.strides[0],p=this.strides[1],d=[i,Ju(o,c,u,this.padding),Ju(l,p,h,this.padding),this.filters];"channelsLast"!==this.dataFormat&&(n=t.transpose(n,[0,2,3,1]));let f=t.conv2dTranspose(n,this.kernel.read(),d,this.strides,this.padding);return"channelsLast"!==this.dataFormat&&(f=t.transpose(f,[0,3,1,2])),null!=this.bias&&(f=ce(f,this.bias.read(),this.dataFormat)),null!=this.activation&&(f=this.activation.apply(f)),f}))}computeOutputShape(t){const e=(t=Be(t)).slice();let n,s,i;"channelsFirst"===this.dataFormat?(n=1,s=2,i=3):(n=3,s=1,i=2);const r=this.kernelSize[0],a=this.kernelSize[1],o=this.strides[0],l=this.strides[1];return e[n]=this.filters,e[s]=Ju(e[s],o,r,this.padding),e[i]=Ju(e[i],l,a,this.padding),e}getConfig(){const t=super.getConfig();return delete t.dilationRate,t}}sh.className="Conv2DTranspose",r.registerClass(sh);class ih extends nh{constructor(t){if(super(t),this.inputSpec=[new Ve({ndim:5})],"same"!==this.padding&&"valid"!==this.padding)throw new it(`Conv3DTranspose currently supports only padding modes 'same' and 'valid', but received padding mode ${this.padding}`)}build(t){if(5!==(t=Be(t)).length)throw new it("Input should have rank 5; Received input shape: "+JSON.stringify(t));const e="channelsFirst"===this.dataFormat?1:t.length-1;if(null==t[e])throw new it("The channel dimension of the inputs should be defined. Found `None`.");const n=t[e],s=this.kernelSize.concat([this.filters,n]);this.kernel=this.addWeight("kernel",s,"float32",this.kernelInitializer,this.kernelRegularizer,!0,this.kernelConstraint),this.useBias&&(this.bias=this.addWeight("bias",[this.filters],"float32",this.biasInitializer,this.biasRegularizer,!0,this.biasConstraint)),this.inputSpec=[new Ve({ndim:5,axes:{[e]:n}})],this.built=!0}call(e,n){return t.tidy((()=>{let n=Oe(e);if(5!==n.shape.length)throw new it(`Conv3DTranspose.call() expects input tensor to be rank-4, but received a tensor of rank-${n.shape.length}`);const s=n.shape,i=s[0];let r,a,o;"channelsFirst"===this.dataFormat?(o=2,r=3,a=4):(o=1,r=2,a=3);const l=s[o],u=s[r],h=s[a],c=this.kernelSize[0],p=this.kernelSize[1],d=this.kernelSize[2],f=this.strides[0],g=this.strides[1],m=this.strides[2],y=[i,Ju(l,f,c,this.padding),Ju(u,g,p,this.padding),Ju(h,m,d,this.padding),this.filters];"channelsLast"!==this.dataFormat&&(n=t.transpose(n,[0,2,3,4,1]));let b=t.conv3dTranspose(n,this.kernel.read(),y,this.strides,this.padding);return"channelsLast"!==this.dataFormat&&(b=t.transpose(b,[0,4,1,2,3])),null!==this.bias&&(b=ce(b,this.bias.read(),this.dataFormat)),null!==this.activation&&(b=this.activation.apply(b)),b}))}computeOutputShape(t){const e=(t=Be(t)).slice();let n,s,i,r;"channelsFirst"===this.dataFormat?(n=1,s=2,i=3,r=4):(n=4,s=1,i=2,r=3);const a=this.kernelSize[0],o=this.kernelSize[1],l=this.kernelSize[2],u=this.strides[0],h=this.strides[1],c=this.strides[2];return e[n]=this.filters,e[s]=Ju(e[s],u,a,this.padding),e[i]=Ju(e[i],h,o,this.padding),e[r]=Ju(e[r],c,l,this.padding),e}getConfig(){const t=super.getConfig();return delete t.dilationRate,t}}ih.className="Conv3DTranspose",r.registerClass(ih);class rh extends th{constructor(t,e){if(super(t,e),this.DEFAULT_DEPTHWISE_INITIALIZER="glorotUniform",this.DEFAULT_POINTWISE_INITIALIZER="glorotUniform",this.depthwiseKernel=null,this.pointwiseKernel=null,null==e.filters)throw new it("The `filters` configuration field is required by SeparableConv, but is unspecified.");if(null!=e.kernelInitializer||null!=e.kernelRegularizer||null!=e.kernelConstraint)throw new it("Fields kernelInitializer, kernelRegularizer and kernelConstraint are invalid for SeparableConv2D. Use depthwiseInitializer, depthwiseRegularizer, depthwiseConstraint, pointwiseInitializer, pointwiseRegularizer and pointwiseConstraint instead.");if(null!=e.padding&&"same"!==e.padding&&"valid"!==e.padding)throw new it(`SeparableConv${this.rank}D supports only padding modes: 'same' and 'valid', but received ${JSON.stringify(e.padding)}`);this.depthMultiplier=null==e.depthMultiplier?1:e.depthMultiplier,this.depthwiseInitializer=_e(e.depthwiseInitializer||this.DEFAULT_DEPTHWISE_INITIALIZER),this.depthwiseRegularizer=Pu(e.depthwiseRegularizer),this.depthwiseConstraint=Yo(e.depthwiseConstraint),this.pointwiseInitializer=_e(e.depthwiseInitializer||this.DEFAULT_POINTWISE_INITIALIZER),this.pointwiseRegularizer=Pu(e.pointwiseRegularizer),this.pointwiseConstraint=Yo(e.pointwiseConstraint)}build(t){if((t=Be(t)).length{let n;if(e=Oe(e),1===this.rank)throw new rt("1D separable convolution is not implemented yet.");return 2===this.rank&&("channelsFirst"===this.dataFormat&&(e=t.transpose(e,[0,2,3,1])),n=t.separableConv2d(e,this.depthwiseKernel.read(),this.pointwiseKernel.read(),this.strides,this.padding,this.dilationRate,"NHWC")),this.useBias&&(n=ce(n,this.bias.read(),this.dataFormat)),null!=this.activation&&(n=this.activation.apply(n)),"channelsFirst"===this.dataFormat&&(n=t.transpose(n,[0,3,1,2])),n}))}getConfig(){const t=super.getConfig();return delete t.rank,delete t.kernelInitializer,delete t.kernelRegularizer,delete t.kernelConstraint,t.depthwiseInitializer=Le(this.depthwiseInitializer),t.pointwiseInitializer=Le(this.pointwiseInitializer),t.depthwiseRegularizer=Ou(this.depthwiseRegularizer),t.pointwiseRegularizer=Ou(this.pointwiseRegularizer),t.depthwiseConstraint=Jo(this.depthwiseConstraint),t.pointwiseConstraint=Jo(this.pointwiseConstraint),t}}rh.className="SeparableConv";class ah extends rh{constructor(t){super(2,t)}}ah.className="SeparableConv2D",r.registerClass(ah);class oh extends th{constructor(t){super(1,t),oh.verifyArgs(t),this.inputSpec=[{ndim:3}]}getConfig(){const t=super.getConfig();return delete t.rank,delete t.dataFormat,t}static verifyArgs(t){if("number"!=typeof t.kernelSize&&!xt(t.kernelSize,"number",1,1))throw new it(`Conv1D expects config.kernelSize to be number or number[] with length 1, but received ${JSON.stringify(t.kernelSize)}.`)}}oh.className="Conv1D",r.registerClass(oh);class lh extends Je{constructor(t){super(t),"number"==typeof t.cropping?this.cropping=[[t.cropping,t.cropping],[t.cropping,t.cropping]]:"number"==typeof t.cropping[0]?this.cropping=[[t.cropping[0],t.cropping[0]],[t.cropping[1],t.cropping[1]]]:this.cropping=t.cropping,this.dataFormat=void 0===t.dataFormat?"channelsLast":t.dataFormat,this.inputSpec=[{ndim:4}]}computeOutputShape(t){return"channelsFirst"===this.dataFormat?[t[0],t[1],t[2]-this.cropping[0][0]-this.cropping[0][1],t[3]-this.cropping[1][0]-this.cropping[1][1]]:[t[0],t[1]-this.cropping[0][0]-this.cropping[0][1],t[2]-this.cropping[1][0]-this.cropping[1][1],t[3]]}call(t,e){return s((()=>{if(t=Oe(t),"channelsLast"===this.dataFormat){const e=ne(t,this.cropping[0][0],t.shape[1]-this.cropping[0][0]-this.cropping[0][1],2);return ne(e,this.cropping[1][0],t.shape[2]-this.cropping[1][1]-this.cropping[1][0],3)}{const e=ne(t,this.cropping[0][0],t.shape[2]-this.cropping[0][0]-this.cropping[0][1],3);return ne(e,this.cropping[1][0],t.shape[3]-this.cropping[1][1]-this.cropping[1][0],4)}}))}getConfig(){const t={cropping:this.cropping,dataFormat:this.dataFormat},e=super.getConfig();return Object.assign(t,e),t}}lh.className="Cropping2D",r.registerClass(lh);class uh extends Je{constructor(t){var e;super(t),this.DEFAULT_SIZE=[2,2],this.inputSpec=[{ndim:4}],this.size=null==t.size?this.DEFAULT_SIZE:t.size,this.dataFormat=null==t.dataFormat?"channelsLast":t.dataFormat,Mt(this.dataFormat),this.interpolation=null==t.interpolation?"nearest":t.interpolation,e=this.interpolation,St(Ft,"InterpolationFormat",e)}computeOutputShape(t){if("channelsFirst"===this.dataFormat){const e=null==t[2]?null:this.size[0]*t[2],n=null==t[3]?null:this.size[1]*t[3];return[t[0],t[1],e,n]}{const e=null==t[1]?null:this.size[0]*t[1],n=null==t[2]?null:this.size[1]*t[2];return[t[0],e,n,t[3]]}}call(e,n){return t.tidy((()=>{let n=Oe(e);const s=n.shape;if("channelsFirst"===this.dataFormat){n=t.transpose(n,[0,2,3,1]);const e=this.size[0]*s[2],i=this.size[1]*s[3],r="nearest"===this.interpolation?t.image.resizeNearestNeighbor(n,[e,i]):t.image.resizeBilinear(n,[e,i]);return t.transpose(r,[0,3,1,2])}{const e=this.size[0]*s[1],i=this.size[1]*s[2];return"nearest"===this.interpolation?t.image.resizeNearestNeighbor(n,[e,i]):t.image.resizeBilinear(n,[e,i])}}))}getConfig(){const t={size:this.size,dataFormat:this.dataFormat,interpolation:this.interpolation},e=super.getConfig();return Object.assign(t,e),t}}uh.className="UpSampling2D",r.registerClass(uh);class hh extends Qu{constructor(t){super(2,t),this.depthwiseKernel=null,this.depthMultiplier=null==t.depthMultiplier?1:t.depthMultiplier,this.depthwiseInitializer=_e(t.depthwiseInitializer||this.DEFAULT_KERNEL_INITIALIZER),this.depthwiseConstraint=Yo(t.depthwiseConstraint),this.depthwiseRegularizer=Pu(t.depthwiseRegularizer)}build(t){if((t=Be(t)).length<4)throw new it(`Inputs to DepthwiseConv2D should have rank 4. Received input shape: ${JSON.stringify(t)}.`);const e="channelsFirst"===this.dataFormat?1:3;if(null==t[e]||t[e]<0)throw new it(`The channel dimension of the inputs to DepthwiseConv2D should be defined, but is not (${t[e]}).`);const n=t[e],s=[this.kernelSize[0],this.kernelSize[1],n,this.depthMultiplier];this.depthwiseKernel=this.addWeight("depthwise_kernel",s,null,this.depthwiseInitializer,this.depthwiseRegularizer,!0,this.depthwiseConstraint),this.useBias?this.bias=this.addWeight("bias",[n*this.depthMultiplier],null,this.biasInitializer,this.biasRegularizer,!0,this.biasConstraint):this.bias=null,this.built=!0}call(e,n){return s((()=>{let n=function(e,n,i=[1,1],r="valid",a,o){return s((()=>{null==a&&(a="channelsLast"),Mt(a);let s=Zu(e,a);if(4!==e.rank)throw new it(`Input for depthwiseConv2d is required to be 4-D, but is instead ${e.rank}-D`);if(4!==n.rank)throw new it(`depthwiseKernel is required to be 4-D, but is instead ${n.rank}-D`);return s=t.depthwiseConv2d(s,n,i,"same"===r?"same":"valid","NHWC",o),"channelsFirst"===a&&(s=t.transpose(s,[0,3,1,2])),s}))}(e=Oe(e),this.depthwiseKernel.read(),this.strides,this.padding,this.dataFormat,null);return this.useBias&&(n=ce(n,this.bias.read(),this.dataFormat)),null!=this.activation&&(n=this.activation.apply(n)),n}))}computeOutputShape(t){t=Be(t);const e="channelsFirst"===this.dataFormat?t[2]:t[1],n="channelsFirst"===this.dataFormat?t[3]:t[2],s="channelsFirst"===this.dataFormat?t[1]*this.depthMultiplier:t[3]*this.depthMultiplier,i=Hu(e,this.kernelSize[0],this.padding,this.strides[0]),r=Hu(n,this.kernelSize[1],this.padding,this.strides[1]);return"channelsFirst"===this.dataFormat?[t[0],s,i,r]:[t[0],i,r,s]}getConfig(){const t=super.getConfig();return t.depthMultiplier=this.depthMultiplier,t.depthwiseInitializer=Le(this.depthwiseInitializer),t.depthwiseRegularizer=Ou(this.depthwiseRegularizer),t.depthwiseConstraint=Jo(this.depthwiseRegularizer),t}}function ch(t,e,n,s){if(Array.isArray(t)){if(null!=e||null!=n)throw new it("When inputs is an array, neither initialState or constants should be provided");null!=s&&(n=t.slice(t.length-s,t.length),t=t.slice(0,t.length-s)),t.length>1&&(e=t.slice(1,t.length)),t=t[0]}function i(t){return null==t||Array.isArray(t)?t:[t]}return{inputs:t,initialState:e=i(e),constants:n=i(n)}}function ph(e,n,s,i=!1,r,a,o=!1,l=!1){return t.tidy((()=>{const u=n.shape.length;if(u<3)throw new it(`Input should be at least 3D, but is ${u}D.`);const h=[1,0].concat(Jt(2,u));if(n=t.transpose(n,h),null!=a)throw new rt("The rnn() functoin of the deeplearn.js backend does not support constants yet.");o&&console.warn("Backend rnn(): the unroll = true option is not applicable to the imperative deeplearn.js backend."),null!=r&&((r=t.cast(t.cast(r,"bool"),"float32")).rank===u-1&&(r=t.expandDims(r,-1)),r=t.transpose(r,h)),i&&(n=t.reverse(n,0),null!=r&&(r=t.reverse(r,0)));const c=[];let p,d=s;const f=n.shape[0],g=t.unstack(n);let m,y;null!=r&&(m=t.unstack(r));for(let n=0;ne(s,d)));if(null==r)p=i[0],d=i[1];else{const e=t.tidy((()=>{const e=m[n],s=t.sub(t.onesLike(e),e);return{output:t.add(t.mul(i[0],e),t.mul(d[0],s)),newStates:d.map(((n,r)=>t.add(t.mul(i[1][r],e),t.mul(n,s))))}}));p=e.output,d=e.newStates}l&&c.push(p)}if(l){const e=1;y=t.stack(c,e)}return[p,y,d]}))}hh.className="DepthwiseConv2D",r.registerClass(hh);class dh extends Je{constructor(t){let e;if(super(t),null==t.cell)throw new it("cell property is missing for the constructor of RNN.");if(e=Array.isArray(t.cell)?new vh({cells:t.cell}):t.cell,null==e.stateSize)throw new it("The RNN cell should have an attribute `stateSize` (tuple of integers, one integer per RNN state).");this.cell=e,this.returnSequences=null!=t.returnSequences&&t.returnSequences,this.returnState=null!=t.returnState&&t.returnState,this.goBackwards=null!=t.goBackwards&&t.goBackwards,this._stateful=null!=t.stateful&&t.stateful,this.unroll=null!=t.unroll&&t.unroll,this.supportsMasking=!0,this.inputSpec=[new Ve({ndim:3})],this.stateSpec=null,this.states_=null,this.numConstants=null,this.keptStates=[]}getStates(){if(null==this.states_){return Jt(0,Array.isArray(this.cell.stateSize)?this.cell.stateSize.length:1).map((t=>null))}return this.states_}setStates(t){this.states_=t}computeOutputShape(t){Re(t)&&(t=t[0]);let e=this.cell.stateSize;Array.isArray(e)||(e=[e]);const n=e[0];let s;if(s=this.returnSequences?[t[0],t[1],n]:[t[0],n],this.returnState){const n=[];for(const s of e)n.push([t[0],s]);return[s].concat(n)}return s}computeMask(e,n){return t.tidy((()=>{Array.isArray(n)&&(n=n[0]);const t=this.returnSequences?n:null;if(this.returnState){const e=this.states.map((t=>null));return[t].concat(e)}return t}))}get states(){if(null==this.states_){const t=Array.isArray(this.cell.stateSize)?this.cell.stateSize.length:1,e=[];for(let n=0;nt.shape[t.shape.length-1])),r))throw new it(`An initialState was passed that is not compatible with cell.stateSize. Received stateSpec=${this.stateSpec}; However cell.stateSize is ${this.cell.stateSize}`)}else this.stateSpec=r.map((t=>new Ve({shape:[null,t]})));this.stateful&&this.resetStates()}resetStates(n,i=!1){s((()=>{if(!this.stateful)throw new nt("Cannot call resetStates() on an RNN Layer that is not stateful.");const s=this.inputSpec[0].shape[0];if(null==s)throw new it("If an RNN is stateful, it needs to know its batch size. Specify the batch size of your input tensors: \n- If using a Sequential model, specify the batch size by passing a `batchInputShape` option to your first layer.\n- If using the functional API, specify the batch size by passing a `batchShape` option to your Input layer.");if(null==this.states_)Array.isArray(this.cell.stateSize)?this.states_=this.cell.stateSize.map((e=>t.zeros([s,e]))):this.states_=[t.zeros([s,this.cell.stateSize])];else if(null==n)t.dispose(this.states_),null!=this.keptStates&&(t.dispose(this.keptStates),this.keptStates=[]),Array.isArray(this.cell.stateSize)?this.states_=this.cell.stateSize.map((e=>t.zeros([s,e]))):this.states_[0]=t.zeros([s,this.cell.stateSize]);else{if(Array.isArray(n)||(n=[n]),n.length!==this.states_.length)throw new it(`Layer ${this.name} expects ${this.states_.length} state(s), but it received ${n.length} state value(s). Input received: ${n}`);!0===i?this.keptStates.push(this.states_.slice()):t.dispose(this.states_);for(let t=0;tt.keep(e.clone())))}))}apply(t,e){let n=null==e?null:e.initialState,s=null==e?null:e.constants;null==e&&(e={});const i=ch(t,n,s,this.numConstants);t=i.inputs,n=i.initialState,s=i.constants;let r=[],a=[];if(null!=n){e.initialState=n,r=r.concat(n),this.stateSpec=[];for(const t of n)this.stateSpec.push(new Ve({shape:t.shape}));a=a.concat(this.stateSpec)}null!=s&&(e.constants=s,r=r.concat(s),this.numConstants=s.length);if(r[0]instanceof qe){const n=[t].concat(r),s=this.inputSpec.concat(a),i=this.inputSpec;this.inputSpec=s;const o=super.apply(n,e);return this.inputSpec=i,o}return super.apply(t,e)}call(t,e){return s((()=>{const n=null==e?null:e.mask,s=null==e?null:e.training;let i=null==e?null:e.initialState;t=Oe(t),null==i&&(i=this.stateful?this.states_:this.getInitialState(t));const r=Array.isArray(this.cell.stateSize)?this.cell.stateSize.length:1;if(i.length!==r)throw new it(`RNN Layer has ${r} state(s) but was passed ${i.length} initial state(s).`);this.unroll&&console.warn("Ignoring unroll = true for RNN layer, due to imperative backend.");const a={training:s},o=ph(((t,e)=>{const n=this.cell.call([t].concat(e),a);return[n[0],n.slice(1)]}),t,i,this.goBackwards,n,null,this.unroll,this.returnSequences),l=o[0],u=o[1],h=o[2];this.stateful&&this.resetStates(h,s);const c=this.returnSequences?u:l;return this.returnState?[c].concat(h):c}))}getInitialState(e){return s((()=>{let n=t.zeros(e.shape);return n=t.sum(n,[1,2]),n=Qt(n),Array.isArray(this.cell.stateSize)?this.cell.stateSize.map((t=>t>1?re(n,[1,t]):n)):this.cell.stateSize>1?[re(n,[1,this.cell.stateSize])]:[n]}))}get trainableWeights(){return this.trainable?this.cell.trainableWeights:[]}get nonTrainableWeights(){return this.trainable?this.cell.nonTrainableWeights:this.cell.weights}setFastWeightInitDuringBuild(t){super.setFastWeightInitDuringBuild(t),null!=this.cell&&this.cell.setFastWeightInitDuringBuild(t)}getConfig(){const t=super.getConfig(),e={returnSequences:this.returnSequences,returnState:this.returnState,goBackwards:this.goBackwards,stateful:this.stateful,unroll:this.unroll};null!=this.numConstants&&(e.numConstants=this.numConstants);const n=this.cell.getConfig();return this.getClassName()===dh.className&&(e.cell={className:this.cell.getClassName(),config:n}),Object.assign(Object.assign(Object.assign({},n),t),e)}static fromConfig(t,e,n={}){const s=cl(e.cell,n);return new t(Object.assign(e,{cell:s}))}}dh.className="RNN",r.registerClass(dh);class fh extends Je{}class gh extends fh{constructor(t){super(t),this.DEFAULT_ACTIVATION="tanh",this.DEFAULT_KERNEL_INITIALIZER="glorotNormal",this.DEFAULT_RECURRENT_INITIALIZER="orthogonal",this.DEFAULT_BIAS_INITIALIZER="zeros",this.units=t.units,Nt(this.units,"units"),this.activation=Du(null==t.activation?this.DEFAULT_ACTIVATION:t.activation),this.useBias=null==t.useBias||t.useBias,this.kernelInitializer=_e(t.kernelInitializer||this.DEFAULT_KERNEL_INITIALIZER),this.recurrentInitializer=_e(t.recurrentInitializer||this.DEFAULT_RECURRENT_INITIALIZER),this.biasInitializer=_e(t.biasInitializer||this.DEFAULT_BIAS_INITIALIZER),this.kernelRegularizer=Pu(t.kernelRegularizer),this.recurrentRegularizer=Pu(t.recurrentRegularizer),this.biasRegularizer=Pu(t.biasRegularizer),this.kernelConstraint=Yo(t.kernelConstraint),this.recurrentConstraint=Yo(t.recurrentConstraint),this.biasConstraint=Yo(t.biasConstraint),this.dropout=Gt([1,Ht([0,null==t.dropout?0:t.dropout])]),this.recurrentDropout=Gt([1,Ht([0,null==t.recurrentDropout?0:t.recurrentDropout])]),this.dropoutFunc=t.dropoutFunc,this.stateSize=this.units,this.dropoutMask=null,this.recurrentDropoutMask=null}build(t){t=Be(t),this.kernel=this.addWeight("kernel",[t[t.length-1],this.units],null,this.kernelInitializer,this.kernelRegularizer,!0,this.kernelConstraint),this.recurrentKernel=this.addWeight("recurrent_kernel",[this.units,this.units],null,this.recurrentInitializer,this.recurrentRegularizer,!0,this.recurrentConstraint),this.useBias?this.bias=this.addWeight("bias",[this.units],null,this.biasInitializer,this.biasRegularizer,!0,this.biasConstraint):this.bias=null,this.built=!0}call(e,n){return s((()=>{if(2!==e.length)throw new it(`SimpleRNNCell expects 2 input Tensors, got ${e.length}.`);let s=e[1];e=e[0];const i=null!=n.training&&n.training;let r;0t.onesLike(e),rate:this.dropout,training:i,dropoutFunc:this.dropoutFunc})),0t.onesLike(s),rate:this.recurrentDropout,training:i,dropoutFunc:this.dropoutFunc}));const a=this.dropoutMask,o=this.recurrentDropoutMask;r=oe(null!=a?t.mul(e,a):e,this.kernel.read()),null!=this.bias&&(r=ce(r,this.bias.read())),null!=o&&(s=t.mul(s,o));let l=t.add(r,oe(s,this.recurrentKernel.read()));return null!=this.activation&&(l=this.activation.apply(l)),[l,l]}))}getConfig(){const t=super.getConfig(),e={units:this.units,activation:$u(this.activation),useBias:this.useBias,kernelInitializer:Le(this.kernelInitializer),recurrentInitializer:Le(this.recurrentInitializer),biasInitializer:Le(this.biasInitializer),kernelRegularizer:Ou(this.kernelRegularizer),recurrentRegularizer:Ou(this.recurrentRegularizer),biasRegularizer:Ou(this.biasRegularizer),activityRegularizer:Ou(this.activityRegularizer),kernelConstraint:Jo(this.kernelConstraint),recurrentConstraint:Jo(this.recurrentConstraint),biasConstraint:Jo(this.biasConstraint),dropout:this.dropout,recurrentDropout:this.recurrentDropout};return Object.assign(Object.assign({},t),e)}}gh.className="SimpleRNNCell",r.registerClass(gh);class mh extends dh{constructor(t){t.cell=new gh(t),super(t)}call(e,n){return s((()=>{null!=this.cell.dropoutMask&&(t.dispose(this.cell.dropoutMask),this.cell.dropoutMask=null),null!=this.cell.recurrentDropoutMask&&(t.dispose(this.cell.recurrentDropoutMask),this.cell.recurrentDropoutMask=null);const s=null==n?null:n.mask,i=null==n?null:n.training,r=null==n?null:n.initialState;return super.call(e,{mask:s,training:i,initialState:r})}))}static fromConfig(t,e){return new t(e)}}mh.className="SimpleRNN",r.registerClass(mh);class yh extends fh{constructor(t){if(super(t),this.DEFAULT_ACTIVATION="tanh",this.DEFAULT_RECURRENT_ACTIVATION="hardSigmoid",this.DEFAULT_KERNEL_INITIALIZER="glorotNormal",this.DEFAULT_RECURRENT_INITIALIZER="orthogonal",this.DEFAULT_BIAS_INITIALIZER="zeros",t.resetAfter)throw new it("GRUCell does not support reset_after parameter set to true.");this.units=t.units,Nt(this.units,"units"),this.activation=Du(void 0===t.activation?this.DEFAULT_ACTIVATION:t.activation),this.recurrentActivation=Du(void 0===t.recurrentActivation?this.DEFAULT_RECURRENT_ACTIVATION:t.recurrentActivation),this.useBias=null==t.useBias||t.useBias,this.kernelInitializer=_e(t.kernelInitializer||this.DEFAULT_KERNEL_INITIALIZER),this.recurrentInitializer=_e(t.recurrentInitializer||this.DEFAULT_RECURRENT_INITIALIZER),this.biasInitializer=_e(t.biasInitializer||this.DEFAULT_BIAS_INITIALIZER),this.kernelRegularizer=Pu(t.kernelRegularizer),this.recurrentRegularizer=Pu(t.recurrentRegularizer),this.biasRegularizer=Pu(t.biasRegularizer),this.kernelConstraint=Yo(t.kernelConstraint),this.recurrentConstraint=Yo(t.recurrentConstraint),this.biasConstraint=Yo(t.biasConstraint),this.dropout=Gt([1,Ht([0,null==t.dropout?0:t.dropout])]),this.recurrentDropout=Gt([1,Ht([0,null==t.recurrentDropout?0:t.recurrentDropout])]),this.dropoutFunc=t.dropoutFunc,this.implementation=t.implementation,this.stateSize=this.units,this.dropoutMask=null,this.recurrentDropoutMask=null}build(t){const e=(t=Be(t))[t.length-1];this.kernel=this.addWeight("kernel",[e,3*this.units],null,this.kernelInitializer,this.kernelRegularizer,!0,this.kernelConstraint),this.recurrentKernel=this.addWeight("recurrent_kernel",[this.units,3*this.units],null,this.recurrentInitializer,this.recurrentRegularizer,!0,this.recurrentConstraint),this.useBias?this.bias=this.addWeight("bias",[3*this.units],null,this.biasInitializer,this.biasRegularizer,!0,this.biasConstraint):this.bias=null,this.built=!0}call(e,n){return s((()=>{if(2!==e.length)throw new it(`GRUCell expects 2 input Tensors (inputs, h, c), got ${e.length}.`);const s=null!=n.training&&n.training;let i=e[1];e=e[0],0t.onesLike(e),rate:this.dropout,training:s,count:3,dropoutFunc:this.dropoutFunc})),0t.onesLike(i),rate:this.recurrentDropout,training:s,count:3,dropoutFunc:this.dropoutFunc}));const r=this.dropoutMask,a=this.recurrentDropoutMask;let o,l,u;0{null!=this.cell.dropoutMask&&(t.dispose(this.cell.dropoutMask),this.cell.dropoutMask=null),null!=this.cell.recurrentDropoutMask&&(t.dispose(this.cell.recurrentDropoutMask),this.cell.recurrentDropoutMask=null);const s=null==n?null:n.mask,i=null==n?null:n.training,r=null==n?null:n.initialState;return super.call(e,{mask:s,training:i,initialState:r})}))}static fromConfig(t,e){return 0===e.implmentation&&(e.implementation=1),new t(e)}}bh.className="GRU",r.registerClass(bh);class wh extends fh{constructor(t){super(t),this.DEFAULT_ACTIVATION="tanh",this.DEFAULT_RECURRENT_ACTIVATION="hardSigmoid",this.DEFAULT_KERNEL_INITIALIZER="glorotNormal",this.DEFAULT_RECURRENT_INITIALIZER="orthogonal",this.DEFAULT_BIAS_INITIALIZER="zeros",this.units=t.units,Nt(this.units,"units"),this.activation=Du(void 0===t.activation?this.DEFAULT_ACTIVATION:t.activation),this.recurrentActivation=Du(void 0===t.recurrentActivation?this.DEFAULT_RECURRENT_ACTIVATION:t.recurrentActivation),this.useBias=null==t.useBias||t.useBias,this.kernelInitializer=_e(t.kernelInitializer||this.DEFAULT_KERNEL_INITIALIZER),this.recurrentInitializer=_e(t.recurrentInitializer||this.DEFAULT_RECURRENT_INITIALIZER),this.biasInitializer=_e(t.biasInitializer||this.DEFAULT_BIAS_INITIALIZER),this.unitForgetBias=t.unitForgetBias,this.kernelRegularizer=Pu(t.kernelRegularizer),this.recurrentRegularizer=Pu(t.recurrentRegularizer),this.biasRegularizer=Pu(t.biasRegularizer),this.kernelConstraint=Yo(t.kernelConstraint),this.recurrentConstraint=Yo(t.recurrentConstraint),this.biasConstraint=Yo(t.biasConstraint),this.dropout=Gt([1,Ht([0,null==t.dropout?0:t.dropout])]),this.recurrentDropout=Gt([1,Ht([0,null==t.recurrentDropout?0:t.recurrentDropout])]),this.dropoutFunc=t.dropoutFunc,this.implementation=t.implementation,this.stateSize=[this.units,this.units],this.dropoutMask=null,this.recurrentDropoutMask=null}build(t){var e;const n=(t=Be(t))[t.length-1];let s;if(this.kernel=this.addWeight("kernel",[n,4*this.units],null,this.kernelInitializer,this.kernelRegularizer,!0,this.kernelConstraint),this.recurrentKernel=this.addWeight("recurrent_kernel",[this.units,4*this.units],null,this.recurrentInitializer,this.recurrentRegularizer,!0,this.recurrentConstraint),this.useBias){if(this.unitForgetBias){const t=this.biasInitializer,n=this.units;s=new((e=class extends me{apply(e,s){const i=t.apply([n]),r=(new be).apply([n]),a=t.apply([2*n]);return ie(ie(i,r),a)}}).className="CustomInit",e)}else s=this.biasInitializer;this.bias=this.addWeight("bias",[4*this.units],null,s,this.biasRegularizer,!0,this.biasConstraint)}else this.bias=null;this.built=!0}call(e,n){return s((()=>{const s=null!=n.training&&n.training;if(3!==e.length)throw new it(`LSTMCell expects 3 input Tensors (inputs, h, c), got ${e.length}.`);let i=e[1];const r=e[2];e=e[0],0t.onesLike(e),rate:this.dropout,training:s,count:4,dropoutFunc:this.dropoutFunc})),0t.onesLike(i),rate:this.recurrentDropout,training:s,count:4,dropoutFunc:this.dropoutFunc}));const a=this.dropoutMask,o=this.recurrentDropoutMask;let l,u,h,c;0{null!=this.cell.dropoutMask&&(t.dispose(this.cell.dropoutMask),this.cell.dropoutMask=null),null!=this.cell.recurrentDropoutMask&&(t.dispose(this.cell.recurrentDropoutMask),this.cell.recurrentDropoutMask=null);const s=null==n?null:n.mask,i=null==n?null:n.training,r=null==n?null:n.initialState;return super.call(e,{mask:s,training:i,initialState:r})}))}static fromConfig(t,e){return 0===e.implmentation&&(e.implementation=1),new t(e)}}kh.className="LSTM",r.registerClass(kh);class vh extends fh{constructor(t){super(t),this.cells=t.cells}get stateSize(){const t=[];for(const e of this.cells.slice().reverse())Array.isArray(e.stateSize)?t.push(...e.stateSize):t.push(e.stateSize);return t}call(t,e){return s((()=>{let n=t.slice(1);const s=[];for(const t of this.cells.slice().reverse())Array.isArray(t.stateSize)?s.push(n.splice(0,t.stateSize.length)):s.push(n.splice(0,1));s.reverse();const i=[];let r;for(let a=0;a{Ut(`RNNCell_${s}`,(()=>{n.build(t),e=Array.isArray(n.stateSize)?n.stateSize[0]:n.stateSize,t=[t[0],e]}))})),this.built=!0}getConfig(){const t=super.getConfig(),e={cells:this.cells.map((t=>({className:t.getClassName(),config:t.getConfig()})))};return Object.assign(Object.assign({},t),e)}static fromConfig(t,e,n={}){const s=[];for(const t of e.cells)s.push(cl(t,n));return new t({cells:s})}get trainableWeights(){if(!this.trainable)return[];const t=[];for(const e of this.cells)t.push(...e.trainableWeights);return t}get nonTrainableWeights(){const t=[];for(const e of this.cells)t.push(...e.nonTrainableWeights);if(!this.trainable){const e=[];for(const t of this.cells)e.push(...t.trainableWeights);return e.concat(t)}return t}getWeights(){const t=[];for(const e of this.cells)t.push(...e.weights);return We(t)}setWeights(t){const e=[];for(const n of this.cells){const s=n.weights.length,i=t.splice(s);for(let t=0;tnull!=a?a(n(),s):pe(n(),s),l=()=>de(o,n,i);if(!r||r<=1)return t.keep(l().clone());return Array(r).fill(void 0).map(l).map((e=>t.keep(e.clone())))}vh.className="StackedRNNCells",r.registerClass(vh);var xh=function(t,e){var n={};for(var s in t)Object.prototype.hasOwnProperty.call(t,s)&&e.indexOf(s)<0&&(n[s]=t[s]);if(null!=t&&"function"==typeof Object.getOwnPropertySymbols){var i=0;for(s=Object.getOwnPropertySymbols(t);i{if(null!=this.cell.dropoutMask&&(t.dispose(this.cell.dropoutMask),this.cell.dropoutMask=null),null!=this.cell.recurrentDropoutMask&&(t.dispose(this.cell.recurrentDropoutMask),this.cell.recurrentDropoutMask=null),n&&n.constants)throw new it("ConvRNN2D cell does not support constants");const s=null==n?null:n.mask,i=null==n?null:n.training,r=null==n?null:n.initialState;return super.call(e,{mask:s,training:i,initialState:r})}))}computeOutputShape(t){let e=this.computeSingleOutputShape(t);return this.returnSequences||(e=[e[0],...e.slice(2)]),this.returnState&&(e=[e,...Array(2).fill([t[0],...e.slice(-3)])]),e}getInitialState(e){return t.tidy((()=>{const{stateSize:n}=this.cell,s=e.shape,i=this.computeSingleOutputShape(s),r=[i[0],...i.slice(2)],a=t.zeros(r);return Array.isArray(n)?Array(n.length).fill(a):[a]}))}resetStates(n,s=!1){t.tidy((()=>{if(!this.stateful)throw new nt("Cannot call resetStates() on an RNN Layer that is not stateful.");const i=this.inputSpec[0].shape,r=this.computeSingleOutputShape(i),a=[r[0],...r.slice(2)];if(null==i[0])throw new it("If an RNN is stateful, it needs to know its batch size. Specify the batch size of your input tensors: \n- If using a Sequential model, specify the batch size by passing a `batchInputShape` option to your first layer.\n- If using the functional API, specify the batch size by passing a `batchShape` option to your Input layer.");if(null==this.getStates())Array.isArray(this.cell.stateSize)?this.states_=this.cell.stateSize.map((()=>t.zeros(a))):this.states_=[t.zeros(a)];else if(null==n)t.dispose(this.states_),null!=this.keptStates&&(t.dispose(this.keptStates),this.keptStates=[]),Array.isArray(this.cell.stateSize)?this.states_=this.cell.stateSize.map((()=>t.zeros(a))):this.states_[0]=t.zeros(a);else{if(Array.isArray(n)||(n=[n]),n.length!==this.states_.length)throw new it(`Layer ${this.name} expects ${this.states_.length} state(s), but it received ${n.length} state value(s). Input received: ${n}`);s?this.keptStates.push(this.states_.slice()):t.dispose(this.states_);for(let t=0;tt.keep(e.clone())))}))}computeSingleOutputShape(t){const{dataFormat:e,filters:n,kernelSize:s,padding:i,strides:r,dilationRate:a}=this.cell,o="channelsFirst"===e,l=t[o?3:2],u=t[o?4:3],h=Hu(l,s[0],i,r[0],a[0]),c=Hu(u,s[1],i,r[1],a[1]);return[...t.slice(0,2),...o?[n,h,c]:[h,c,n]]}}Nh.className="ConvRNN2D";class Ih extends wh{constructor(t){const{filters:e,kernelSize:n,strides:s,padding:i,dataFormat:r,dilationRate:a}=t;super(Object.assign(Object.assign({},t),{units:e})),this.filters=e,Nt(this.filters,"filters"),this.kernelSize=Gu(n,2,"kernelSize"),this.kernelSize.forEach((t=>Nt(t,"kernelSize"))),this.strides=Gu(s||1,2,"strides"),this.strides.forEach((t=>Nt(t,"strides"))),this.padding=i||"valid",Ot(this.padding),this.dataFormat=r||"channelsLast",Mt(this.dataFormat),this.dilationRate=Gu(a||1,2,"dilationRate"),this.dilationRate.forEach((t=>Nt(t,"dilationRate")))}build(e){var n;e=Be(e);const s="channelsFirst"===this.dataFormat?1:e.length-1;if(null==e[s])throw new it(`The channel dimension of the input should be defined. Found ${e[s]}`);const i=e[s],r=this.kernelSize.concat([i,4*this.filters]);this.kernel=this.addWeight("kernel",r,null,this.kernelInitializer,this.kernelRegularizer,!0,this.kernelConstraint);const a=this.kernelSize.concat([this.filters,4*this.filters]);if(this.recurrentKernel=this.addWeight("recurrent_kernel",a,null,this.recurrentInitializer,this.recurrentRegularizer,!0,this.recurrentConstraint),this.useBias){let e;if(this.unitForgetBias){const s=this.biasInitializer,i=this.filters;e=new((n=class extends me{apply(e,n){return se([s.apply([i]),t.ones([i]),s.apply([2*i])])}}).className="CustomInit",n)}else e=this.biasInitializer;this.bias=this.addWeight("bias",[4*this.filters],null,e,this.biasRegularizer,!0,this.biasConstraint)}this.built=!0}call(e,n){return t.tidy((()=>{if(3!==e.length)throw new it(`ConvLSTM2DCell expects 3 input Tensors (inputs, h, c), got ${e.length}.`);const s=n.training||!1,i=e[0],r=e[1],a=e[2];0t.onesLike(i),rate:this.dropout,training:s,count:4,dropoutFunc:this.dropoutFunc}));const o=this.dropoutMask,l=(e,n,s)=>n&&n[s]?t.mul(n[s],e):e;let u=l(i,o,0),h=l(i,o,1),c=l(i,o,2),p=l(i,o,3);0t.onesLike(r),rate:this.recurrentDropout,training:s,count:4,dropoutFunc:this.dropoutFunc}));const d=this.recurrentDropoutMask;let f=l(r,d,0),g=l(r,d,1),m=l(r,d,2),y=l(r,d,3);const[b,w,k,v]=t.split(this.kernel.read(),4,3),[S,x,N,I]=this.useBias?t.split(this.bias.read(),4):[null,null,null,null];u=this.inputConv(u,b,S,this.padding),h=this.inputConv(h,w,x,this.padding),c=this.inputConv(c,k,N,this.padding),p=this.inputConv(p,v,I,this.padding);const[A,T,E,C]=t.split(this.recurrentKernel.read(),4,3);f=this.recurrentConv(f,A),g=this.recurrentConv(g,T),m=this.recurrentConv(m,E),y=this.recurrentConv(y,C);const z=this.recurrentActivation.apply(t.add(u,f)),$=this.recurrentActivation.apply(t.add(h,g)),F=t.add(t.mul($,a),t.mul(z,this.activation.apply(t.add(c,m)))),D=t.mul(this.recurrentActivation.apply(t.add(p,y)),this.activation.apply(F));return[D,D,F]}))}getConfig(){const t=super.getConfig(),e=xh(t,["units"]),n={filters:this.filters,kernelSize:this.kernelSize,padding:this.padding,dataFormat:this.dataFormat,dilationRate:this.dilationRate,strides:this.strides};return Object.assign(Object.assign({},e),n)}inputConv(e,n,s,i){const r=t.conv2d(e,n,this.strides,i||"valid","channelsFirst"===this.dataFormat?"NCHW":"NHWC",this.dilationRate);return s?ce(r,s,this.dataFormat):r}recurrentConv(e,n){return t.conv2d(e,n,1,"same","channelsFirst"===this.dataFormat?"NCHW":"NHWC")}}Ih.className="ConvLSTM2DCell",t.serialization.registerClass(Ih);class Ah extends Nh{constructor(t){const e=new Ih(t);super(Object.assign(Object.assign({},t),{cell:e}))}static fromConfig(t,e){return new t(e)}}Ah.className="ConvLSTM2D",t.serialization.registerClass(Ah);class Th extends Je{constructor(t){super(t),this.rate=Math.max(Math.min(t.rate,1),0),this.noiseShape=t.noiseShape,this.seed=t.seed,this.supportsMasking=!0}getNoiseShape(t){if(null==this.noiseShape)return this.noiseShape;const e=t.shape,n=[];for(let t=0;t{this.invokeCallHook(t,e);const n=Oe(t);if(0pe(n,this.rate,s,this.seed)),(()=>n),t)}return t}))}getConfig(){const t={rate:this.rate,noiseShape:this.noiseShape,seed:this.seed},e=super.getConfig();return Object.assign(t,e),t}dispose(){return super.dispose()}}Th.className="Dropout",r.registerClass(Th);class Eh extends Th{constructor(t){super(t),this.inputSpec=[{ndim:3}]}getNoiseShape(t){const e=t.shape;return[e[0],1,e[2]]}}Eh.className="SpatialDropout1D",r.registerClass(Eh);class Ch extends Je{constructor(t){if(super(t),this.activation=null,this.useBias=!0,this.kernel=null,this.bias=null,this.DEFAULT_KERNEL_INITIALIZER="glorotNormal",this.DEFAULT_BIAS_INITIALIZER="zeros",null==t.batchInputShape&&null==t.inputShape&&null!=t.inputDim){let e=null;null!=t.batchSize&&(e=t.batchSize),this.batchInputShape=[e,t.inputDim]}this.units=t.units,Nt(this.units,"units"),this.activation=Du(t.activation),null!=t.useBias&&(this.useBias=t.useBias),this.kernelInitializer=_e(t.kernelInitializer||this.DEFAULT_KERNEL_INITIALIZER),this.biasInitializer=_e(t.biasInitializer||this.DEFAULT_BIAS_INITIALIZER),this.kernelConstraint=Yo(t.kernelConstraint),this.biasConstraint=Yo(t.biasConstraint),this.kernelRegularizer=Pu(t.kernelRegularizer),this.biasRegularizer=Pu(t.biasRegularizer),this.activityRegularizer=Pu(t.activityRegularizer),this.supportsMasking=!0,this.inputSpec=[{minNDim:2}]}build(t){const e=(t=Be(t))[t.length-1];null==this.kernel&&(this.kernel=this.addWeight("kernel",[e,this.units],null,this.kernelInitializer,this.kernelRegularizer,!0,this.kernelConstraint),this.useBias&&(this.bias=this.addWeight("bias",[this.units],null,this.biasInitializer,this.biasRegularizer,!0,this.biasConstraint))),this.inputSpec=[{minNDim:2,axes:{[-1]:e}}],this.built=!0}computeOutputShape(t){const e=(t=Be(t)).slice();return e[e.length-1]=this.units,e}call(t,e){return s((()=>{this.invokeCallHook(t,e);const n=Oe(t),s=At(this.activation.getClassName());let i;return null!=s?i=oe(n,this.kernel.read(),s,this.bias?this.bias.read():null):(i=oe(n,this.kernel.read()),null!=this.bias&&(i=ce(i,this.bias.read())),null!=this.activation&&(i=this.activation.apply(i))),i}))}getConfig(){const t={units:this.units,activation:$u(this.activation),useBias:this.useBias,kernelInitializer:Le(this.kernelInitializer),biasInitializer:Le(this.biasInitializer),kernelRegularizer:Ou(this.kernelRegularizer),biasRegularizer:Ou(this.biasRegularizer),activityRegularizer:Ou(this.activityRegularizer),kernelConstraint:Jo(this.kernelConstraint),biasConstraint:Jo(this.biasConstraint)},e=super.getConfig();return Object.assign(t,e),t}}Ch.className="Dense",r.registerClass(Ch);class zh extends Je{constructor(t){super(t=t||{}),this.inputSpec=[{minNDim:3}],this.dataFormat=t.dataFormat}computeOutputShape(t){t=Be(t);for(const e of t.slice(1))if(null==e)throw new it(`The shape of the input to "Flatten" is not fully defined (got ${t.slice(1)}). Make sure to pass a complete "input_shape" or "batch_input_shape" argument to the first layer in your model.`);return[t[0],Kt(t,1)]}call(e,n){return s((()=>{this.invokeCallHook(e,n);let s=Oe(e);if("channelsFirst"===this.dataFormat&&s.rank>1){const t=[0];for(let e=2;e{this.invokeCallHook(t,e);const n=Oe(t);return this.activation.apply(n)}))}getConfig(){const t={activation:$u(this.activation)},e=super.getConfig();return Object.assign(t,e),t}}$h.className="Activation",r.registerClass($h);class Fh extends Je{constructor(t){super(t),this.n=t.n,this.inputSpec=[{ndim:2}]}computeOutputShape(t){return[t[0],this.n,t[1]]}call(t,e){return s((()=>{return t=Oe(t),e=t,n=this.n,s((()=>{if(2!==e.shape.length)throw new it(`repeat() expects a rank-2 tensor, but received a rank-${e.shape.length} tensor.`);return re(Qt(e,1),[1,n,1])}));var e,n}))}getConfig(){const t={n:this.n},e=super.getConfig();return Object.assign(t,e),t}}Fh.className="RepeatVector",r.registerClass(Fh);class Dh extends Je{constructor(t){super(t),this.targetShape=t.targetShape;for(let t=0;t{this.invokeCallHook(t,e);const n=Oe(t),s=n.shape,i=s.slice(0,1).concat(this.fixUnknownDimension(s.slice(1),this.targetShape));return I(n,i)}))}getConfig(){const t={targetShape:this.targetShape},e=super.getConfig();return Object.assign(t,e),t}}Dh.className="Reshape",r.registerClass(Dh);class Lh extends Je{constructor(t){if(super(t),null==t.dims)throw new Error("Required configuration field `dims` is missing during Permute constructor call.");if(!Array.isArray(t.dims))throw new Error(`Permute constructor requires \`dims\` to be an Array, but received ${t.dims} instead.`);const n=Jt(1,t.dims.length+1);if(!e.arraysEqual(t.dims.slice().sort(),n))throw new Error("Invalid permutation `dims`: "+JSON.stringify(t.dims)+" `dims` must contain consecutive integers starting from 1.");this.dims=t.dims,this.dimsIncludingBatch=[0].concat(this.dims),this.inputSpec=[new Ve({ndim:this.dims.length+1})]}computeOutputShape(t){const e=(t=Be(t)).slice();return this.dims.forEach(((n,s)=>{e[s+1]=t[n]})),e}call(t,e){return P(Oe(t),this.dimsIncludingBatch)}getConfig(){const t={dims:this.dims},e=super.getConfig();return Object.assign(t,e),t}}Lh.className="Permute",r.registerClass(Lh);class _h extends Je{constructor(t){super(null==t?{}:t),this.supportsMasking=!0,this.maskValue=null!=t?null==t.maskValue?0:t.maskValue:0}computeOutputShape(t){return t}getConfig(){const t=super.getConfig(),e={maskValue:this.maskValue};return Object.assign(e,t),e}computeMask(t,e){const n=Oe(t);return U(W(n,this.maskValue),-1)}call(t,e){return s((()=>{this.invokeCallHook(t,e);const n=Oe(t),s=U(W(n,this.maskValue),-1,!0);return l(n,m(s,n.dtype))}))}}_h.className="Masking",r.registerClass(_h);class Rh extends Je{constructor(t){if(super(t),this.embeddings=null,this.DEFAULT_EMBEDDINGS_INITIALIZER="randomUniform",null==t.batchInputShape&&null==t.inputShape){let e=null;null!=t.batchSize&&(e=t.batchSize),null==t.inputLength?this.batchInputShape=[e,null]:this.batchInputShape=[e].concat(pt(t.inputLength))}this.inputDim=t.inputDim,Nt(this.inputDim,"inputDim"),this.outputDim=t.outputDim,Nt(this.outputDim,"outputDim"),this.embeddingsInitializer=_e(t.embeddingsInitializer||this.DEFAULT_EMBEDDINGS_INITIALIZER),this.embeddingsRegularizer=Pu(t.embeddingsRegularizer),this.activityRegularizer=Pu(t.activityRegularizer),this.embeddingsConstraint=Yo(t.embeddingsConstraint),this.maskZero=t.maskZero,this.supportsMasking=t.maskZero,this.inputLength=t.inputLength}build(t){this.embeddings=this.addWeight("embeddings",[this.inputDim,this.outputDim],this.dtype,this.embeddingsInitializer,this.embeddingsRegularizer,!0,this.embeddingsConstraint),this.built=!0}warnOnIncompatibleInputShape(t){}computeMask(t,e){return s((()=>this.maskZero?(t=Oe(t),W(t,j(t))):null))}computeOutputShape(t){if(t=Be(t),null==this.inputLength)return[...t,this.outputDim];const e=pt(this.inputLength);if(e.length!==t.length-1)throw new it(`"inputLength" is ${this.inputLength}, but received input shape has shape ${t}`);{let n=0;for(let s=0;s{this.invokeCallHook(t,e);let n=Oe(t);"int32"!==n.dtype&&(n=Xt(n,"int32"));const s=le(this.embeddings.read(),I(n,[n.size]));return I(s,Be(this.computeOutputShape(n.shape)))}))}getConfig(){const t={inputDim:this.inputDim,outputDim:this.outputDim,embeddingsInitializer:Le(this.embeddingsInitializer),embeddingsRegularizer:Ou(this.embeddingsRegularizer),activityRegularizer:Ou(this.activityRegularizer),embeddingsConstraint:Jo(this.embeddingsConstraint),maskZero:this.maskZero,inputLength:this.inputLength},e=super.getConfig();return Object.assign(t,e),t}}Rh.className="Embedding",r.registerClass(Rh);class Mh extends Je{constructor(t){super(t||{}),this.supportsMasking=!0}mergeFunction(t){throw new rt}computeElementwiseOpOutputShape(t,e){if(null==t||null==e)return null;if(t.length1)throw new it(`Can not merge tensors with different batch sizes. Got tensors with shapes: ${JSON.stringify(t)}.`);let n=null==t[0]?null:t[0].slice(1);for(let e=1;et.length));-1===t.indexOf(null)&&1===kt(s).length?this.reshapeRequired=!1:this.reshapeRequired=!0}call(e,n){return s((()=>{if(this.reshapeRequired){const n=[],s=e.map((t=>t.rank));if(-1===s.indexOf(null)){const t=Ht(s);for(let s of e){const e=s.rank;for(let n=0;n1){const r=Jt(1,e).concat([0]);n.push(t.transpose(i,r)),s=!0}else n.push(i)}let i=this.mergeFunction(n);const r=i.rank;if(s)if(null==r){const e=i.shape,n=e[e.length-1],s=[n].concat(e.slice(0,e.length-1));i=t.reshape(t.transpose(t.reshape(i,[-1,n]),[1,0]),s)}else if(r>1){const e=[r-1].concat(Jt(0,r-1));i=t.transpose(i,e)}return i}}return this.mergeFunction(e)}))}computeOutputShape(t){let e;e=null==t[0]?null:t[0].slice(1);for(let n=1;n{if(null==n)return null;if(!Array.isArray(n))throw new it("`mask` should be an Array");if(!Array.isArray(e))throw new it("`inputs` should be an Array");if(n.length!==e.length)throw new it(`The Array 'inputs' and 'mask' are expected to have the same length, but have different lengths (${e.length} vs ${n.length})`);if(n.every((t=>null==t)))return null;let s=(n=n.map((e=>null==e?e:t.expandDims(e,0))))[0];for(let e=1;e{let n=e[0].clone();for(let s=1;s{let n=e[0].clone();for(let s=1;s{let n=e[0].clone();for(let s=1;s{let n=e[0];for(let s=1;s{let n=e[0];for(let s=1;s1)throw new it("A `Concatenate` layer requires inputs with matching shapes except for the concat axis. Got input shapes: "+JSON.stringify(t))}mergeFunction(t){return s((()=>se(t,this.axis)))}computeOutputShape(t){if(!Array.isArray(t)||!Array.isArray(t[0]))throw new it("A `Concatenate` layer should be called on a list of inputs.");const e=t,n=e[0].slice(),s=this.axis<0?n.length+this.axis:this.axis;for(const t of e.slice(1)){if(null==n[s]||null==t[s]){n[s]=null;break}n[s]+=t[s]}return n}computeMask(e,n){if(null==n)return null;if(!Array.isArray(n))throw new it("`mask` should be an array for Concatenate");if(!Array.isArray(e))throw new it("`inputs` should be an array for Concatenate");if(n.length!==e.length)throw new it(`Mismatch in the length of mask (${n.length}) and the legnth of inputs (${e.length})`);return t.tidy((()=>{let s=!0;if(n.forEach((t=>{null==t||(s=!1)})),s)return null;const i=[];for(let s=0;s"A `Dot` layer should be called on a list of exactly 2 inputs."));const n=e[0],s=e[1];if(n.length>3||s.length>3)throw new rt("Dot layer does not support tensors of 4D or higher rank yet.");const i=this.interpretAxes(n,s);if(n[i[0]]!==s[i[1]])throw new it(`Dimension incompatibility: ${n[i[0]]} !== ${s[i[1]]}`)}mergeFunction(e){if(2!==e.length)throw new it(`A \`Dot\` layer must be called on exactly 2 inputs, but received ${e.length} input(s).`);let n,s=e[0],i=e[1];return n=Array.isArray(this.axes)?this.axes.map(((t,n)=>Vh(t,e[n].shape.length))):[Vh(this.axes,s.shape.length),Vh(this.axes,i.shape.length)],this.normalize&&(s=pl(s,n[0]),i=pl(i,n[1])),function(e,n,s){if(e.shape.length>3||n.shape.length>3)throw new rt("batchDot is not implemented for tensors of 4D or higher rank yet");if(t.util.assert(e.shape.length>=2,(()=>`batchDot requires the rank of x to be >= 2, but got ${e.shape.length}`)),t.util.assert(e.shape.length>=2,(()=>`batchDot requires the rank of y to be >= 2, but got ${n.shape.length}`)),"number"==typeof s&&(s=[s,s]),"complex64"===e.dtype||"complex64"===n.dtype)throw new rt("batchDot is not implemented for complex64-type Tensors yet.");const i=e.shape.length,r=n.shape.length;null==s&&(s=[i-1,r-2]);const a=s;return t.tidy((()=>{let s,o;if(i>r){s=i-r;const e=[];for(let t=0;ti){s=r-i;const n=[];for(let t=0;t0){let e;e=i>r?i+r-3:i-1;const n=[];for(let t=e;t"A `Dot` layer should be called on a list of exactly 2 inputs."));const n=e[0].slice(),s=e[1].slice();if(n.length>3||s.length>3)throw new rt("Dot layer does not support tensors of 4D or higher rank yet.");const i=this.interpretAxes(n,s);n.splice(i[0],1),s.splice(i[1],1),s.splice(0,1);const r=n.concat(s);return 1===r.length&&r.push(1),r}computeMask(t,e){return null}getConfig(){const t={axes:this.axes,normalize:this.normalize},e=super.getConfig();return Object.assign(t,e),t}}qh.className="Dot",r.registerClass(qh);class Kh extends Je{constructor(t){super(t),this.supportsMasking=!0,this.stddev=t.stddev}computeOutputShape(t){return t}getConfig(){const t=super.getConfig(),e={stddev:this.stddev};return Object.assign(e,t),e}call(t,e){return s((()=>{this.invokeCallHook(t,e);const n=Oe(t);return de((()=>w(ae(n.shape,0,this.stddev),n)),(()=>n),e.training||!1)}))}}Kh.className="GaussianNoise",r.registerClass(Kh);class Gh extends Je{constructor(t){super(t),this.supportsMasking=!0,this.rate=t.rate}computeOutputShape(t){return t}getConfig(){const t=super.getConfig(),e={rate:this.rate};return Object.assign(e,t),e}call(t,e){return s((()=>{this.invokeCallHook(t,e);const n=Oe(t);if(this.rate>0&&this.rate<1){return de((()=>{const t=Math.sqrt(this.rate/(1-this.rate));return l(n,ae(n.shape,1,t))}),(()=>n),e.training||!1)}return n}))}}Gh.className="GaussianDropout",r.registerClass(Gh);class Hh extends Je{constructor(t){super(t),this.supportsMasking=!0,this.rate=t.rate,this.noiseShape=t.noiseShape}_getNoiseShape(t){return this.noiseShape||Oe(t).shape}computeOutputShape(t){return t}getConfig(){const t=super.getConfig(),e={rate:this.rate};return Object.assign(e,t),e}call(t,e){return s((()=>{if(this.rate<1&&this.rate>0){const n=this._getNoiseShape(t),s=()=>{const e=Oe(t),s=-1.7580993408473766;let i=V(h(n),this.rate);i=Xt(i,"float32");const r=((1-this.rate)*(1+this.rate*s**2))**-.5,a=-r*s*this.rate,o=w(l(e,i),l(w(i,-1),s));return w(l(o,r),a)};return de(s,(()=>Oe(t)),e.training||!1)}return t}))}}function Jh(e,n,s,i,r,a=.001){let o;if(2===e.rank)o=t.batchNorm2d(e,n,s,i,r,a);else if(3===e.rank)o=t.batchNorm3d(e,n,s,i,r,a);else{if(4!==e.rank)throw new rt(`batchNormalization is not implemented for array of rank ${e.rank} yet`);o=t.batchNorm4d(e,n,s,i,r,a)}return o}function Zh(n,i,r,a,o=.001){return e.arraysEqual(a.slice().sort(),Jt(0,n.rank-1))?function(e,n,i,r,a=.001){return s((()=>{const s=t.moments(e,r),o=s.mean,l=s.variance;return[Jh(e,o,l,i,n,a),o,l]}))}(n,i,r,a,o):function(e,n,i,r,a=.001){return s((()=>{const s=t.moments(e,r),o=s.mean,l=s.variance,u=[];for(const t of Jt(0,e.rank))-1!==r.indexOf(t)?u.push(1):u.push(e.shape[t]);const h=I(o,u),c=I(l,u),p=null==n?null:I(n,u),d=null==i?null:I(i,u);return[Jh(e,h,c,d,p,a),o,l]}))}(n,i,r,a,o)}Hh.className="AlphaDropout",r.registerClass(Hh);class Yh extends Je{constructor(t){null==t&&(t={}),super(t),this.supportsMasking=!0,this.axis=null==t.axis?-1:t.axis,this.momentum=null==t.momentum?.99:t.momentum,this.epsilon=null==t.epsilon?.001:t.epsilon,this.center=null==t.center||t.center,this.scale=null==t.scale||t.scale,this.betaInitializer=_e(t.betaInitializer||"zeros"),this.gammaInitializer=_e(t.gammaInitializer||"ones"),this.movingMeanInitializer=_e(t.movingMeanInitializer||"zeros"),this.movingVarianceInitializer=_e(t.movingVarianceInitializer||"ones"),this.betaConstraint=Yo(t.betaConstraint),this.gammaConstraint=Yo(t.gammaConstraint),this.betaRegularizer=Pu(t.betaRegularizer),this.gammaRegularizer=Pu(t.gammaRegularizer)}build(t){t=Be(t);const e=this.axis>=0?this.axis:this.axis+t.length,n=t[e];if(null==n)throw new it(`Axis ${e} of input tensor should have a defined dimension but the layer received an input with shape ${JSON.stringify(t)}.`);this.inputSpec=[new Ve({ndim:t.length,axes:{[e]:n}})];const s=[n];this.scale&&(this.gamma=this.addWeight("gamma",s,null,this.gammaInitializer,this.gammaRegularizer,!0,this.gammaConstraint)),this.center&&(this.beta=this.addWeight("beta",s,null,this.betaInitializer,this.betaRegularizer,!0,this.betaConstraint)),this.movingMean=this.addWeight("moving_mean",s,null,this.movingMeanInitializer,null,!1),this.movingVariance=this.addWeight("moving_variance",s,null,this.movingVarianceInitializer,null,!1),this.built=!0}call(n,i){return s((()=>{const s=null!=i.training&&i.training,r=Oe(n),a=r.shape,o=a.length,l=Jt(0,o),u=this.axis>=0?this.axis:this.axis+o;l.splice(u,1);const h=lt(1,o);h[u]=a[u];const c=l.slice();c.sort();const p=!e.arraysEqual(c,Jt(0,o).slice(0,o-1));if(!s)return(()=>{if(p){const t=I(this.movingMean.read(),h),e=I(this.movingVariance.read(),h),n=this.center?I(this.beta.read(),h):null,s=this.scale?I(this.gamma.read(),h):null;return Jh(r,t,e,n,s,this.epsilon)}return Jh(r,this.movingMean.read(),this.movingVariance.read(),null==this.beta?null:this.beta.read(),null==this.gamma?null:this.gamma.read(),this.epsilon)})();const[d,f,g]=Zh(r,this.gamma.read(),this.beta.read(),l,this.epsilon),m=(e,n,s)=>{t.tidy((()=>{const i=1-s,r=e.read(),a=t.mul(t.sub(r,n),i);e.write(t.sub(r,a))}))};return(()=>{m(this.movingMean,f,this.momentum),m(this.movingVariance,g,this.momentum)})(),d}))}getConfig(){const t={axis:this.axis,momentum:this.momentum,epsilon:this.epsilon,center:this.center,scale:this.scale,betaInitializer:Le(this.betaInitializer),gammaInitializer:Le(this.gammaInitializer),movingMeanInitializer:Le(this.movingMeanInitializer),movingVarianceInitializer:Le(this.movingVarianceInitializer),betaRegularizer:Ou(this.betaRegularizer),gammaRegularizer:Ou(this.gammaRegularizer),betaConstraint:Jo(this.betaConstraint),gammaConstraint:Jo(this.gammaConstraint)},e=super.getConfig();return Object.assign(t,e),t}}Yh.className="BatchNormalization",r.registerClass(Yh);class Xh extends Je{constructor(t){if(null==t&&(t={}),super(t),this.axis=null==t.axis?-1:t.axis,"number"==typeof this.axis){if(!Number.isInteger(this.axis))throw new Error(`Expected axis to be an integer, but received ${this.axis}`)}else{if(!Array.isArray(this.axis))throw new Error(`Expected axis to be an integer or an array of integers, but received ${JSON.stringify(this.axis)}`);for(const t of this.axis)if(!Number.isInteger(t))throw new Error(`Expected axis to be an array of integers, but received ${JSON.stringify(this.axis)}`)}this.epsilon=null==t.epsilon?.001:t.epsilon,this.center=null==t.center||t.center,this.scale=null==t.scale||t.scale,this.betaInitializer=_e(t.betaInitializer||"zeros"),this.gammaInitializer=_e(t.gammaInitializer||"ones"),this.betaRegularizer=Pu(t.betaRegularizer),this.gammaRegularizer=Pu(t.gammaRegularizer),this.supportsMasking=!0}build(t){const e=(t=Be(t)).length;"number"==typeof this.axis&&(this.axis=[this.axis]);for(let t=0;t=e)throw new Error(`Invalid axis: ${t}`);if(this.axis.length!==kt(this.axis).length)throw new Error(`Found duplicate axes in: ${this.axis}`);const n=this.axis.map((e=>t[e]));this.scale?this.gamma=this.addWeight("gamma",n,"float32",this.gammaInitializer,this.gammaRegularizer,true):this.gamma=null,this.center?this.beta=this.addWeight("beta",n,"float32",this.betaInitializer,this.betaRegularizer,true):this.beta=null,this.built=!0}call(e,n){const i=Oe(e),r=i.shape,a=r.length;return s((()=>{let{mean:e,variance:n}=q(i,this.axis,!0);const s=lt(1,a);for(const t of this.axis)s[t]=r[t];const o=e=>null!=e&&e.shape.length!==a?t.reshape(e,s):e;let l=this.scale?o(this.gamma.read()):null,u=this.center?o(this.beta.read()):null;const h=[],c=[];for(let t=0;t=0?t[2]+this.padding[0][0]+this.padding[0][1]:null,n=null!=t[3]&&t[3]>=0?t[3]+this.padding[1][0]+this.padding[1][1]:null,[t[0],t[1],e,n]):(e=null!=t[1]&&t[1]>=0?t[1]+this.padding[0][0]+this.padding[0][1]:null,n=null!=t[2]&&t[2]>=0?t[2]+this.padding[1][0]+this.padding[1][1]:null,[t[0],e,n,t[3]])}call(e,n){return s((()=>{return n=Oe(e),i=this.padding,r=this.dataFormat,s((()=>{if(4!==n.rank)throw new it(`temporalPadding expects input tensor to be 4-D, but received a ${n.rank}-D tensor.`);if(null==i&&(i=[[1,1],[1,1]]),2!==i.length||2!==i[0].length||2!==i[1].length)throw new it("spatial2dPadding expects `padding` to be an Array of two Arrays, each of which is an Array of two integers.");if(null==r&&(r="channelsLast"),"channelsLast"!==r&&"channelsFirst"!==r)throw new it(`Unknown data format: ${r}. Supported data formats are 'channelsLast' and 'channelsFirst.`);let e;return e="channelsFirst"===r?[[0,0],[0,0],i[0],i[1]]:[[0,0],i[0],i[1],[0,0]],t.pad(n,e)}));var n,i,r}))}getConfig(){const t={padding:this.padding,dataFormat:this.dataFormat},e=super.getConfig();return Object.assign(t,e),t}}function tc(e,n,i,r,a,o){return s((()=>{let s;Mt(a),Bt(o),Ot(r),null==i&&(i=[1,1]),null==r&&(r="valid"),null==a&&(a="channelsLast"),null==o&&(o="max"),e=Zu(e,a);const l="same"===r?"same":"valid";return s="max"===o?t.maxPool(e,n,i,l):t.avgPool(e,n,i,l),"channelsFirst"===a&&(s=t.transpose(s,[0,3,1,2])),s}))}function ec(e,n,i,r,a,o){return s((()=>{let s;Mt(a),Bt(o),Ot(r),null==i&&(i=[1,1,1]),null==r&&(r="valid"),null==a&&(a="channelsLast"),null==o&&(o="max"),e=Yu(e,a);const l="same"===r?"same":"valid";return s="max"===o?t.maxPool3d(e,n,i,l):t.avgPool3d(e,n,i,l),"channelsFirst"===a&&(s=t.transpose(s,[0,4,1,2,3])),s}))}Qh.className="ZeroPadding2D",r.registerClass(Qh);class nc extends Je{constructor(t){if(null==t.poolSize&&(t.poolSize=2),super(t),"number"==typeof t.poolSize)this.poolSize=[t.poolSize];else{if(!Array.isArray(t.poolSize)||1!==t.poolSize.length||"number"!=typeof t.poolSize[0])throw new it(`poolSize for 1D convolutional layer must be a number or an Array of a single number, but received ${JSON.stringify(t.poolSize)}`);this.poolSize=t.poolSize}if(Nt(this.poolSize,"poolSize"),null==t.strides)this.strides=this.poolSize;else if("number"==typeof t.strides)this.strides=[t.strides];else{if(!Array.isArray(t.strides)||1!==t.strides.length||"number"!=typeof t.strides[0])throw new it(`strides for 1D convolutional layer must be a number or an Array of a single number, but received ${JSON.stringify(t.strides)}`);this.strides=t.strides}Nt(this.strides,"strides"),this.padding=null==t.padding?"valid":t.padding,Ot(this.padding),this.inputSpec=[new Ve({ndim:3})]}computeOutputShape(t){const e=Hu((t=Be(t))[1],this.poolSize[0],this.padding,this.strides[0]);return[t[0],e,t[2]]}call(e,n){return s((()=>{this.invokeCallHook(e,n),e=Qt(Oe(e),2);const s=this.poolingFunction(Oe(e),[this.poolSize[0],1],[this.strides[0],1],this.padding,"channelsLast");return t.squeeze(s,[2])}))}getConfig(){const t={poolSize:this.poolSize,padding:this.padding,strides:this.strides},e=super.getConfig();return Object.assign(t,e),t}}class sc extends nc{constructor(t){super(t)}poolingFunction(t,e,n,s,i){return Mt(i),Ot(s),tc(t,e,n,s,i,"max")}}sc.className="MaxPooling1D",r.registerClass(sc);class ic extends nc{constructor(t){super(t)}poolingFunction(t,e,n,s,i){return Mt(i),Ot(s),tc(t,e,n,s,i,"avg")}}ic.className="AveragePooling1D",r.registerClass(ic);class rc extends Je{constructor(t){if(null==t.poolSize&&(t.poolSize=[2,2]),super(t),this.poolSize=Array.isArray(t.poolSize)?t.poolSize:[t.poolSize,t.poolSize],null==t.strides)this.strides=this.poolSize;else if(Array.isArray(t.strides)){if(2!==t.strides.length)throw new it(`If the strides property of a 2D pooling layer is an Array, it is expected to have a length of 2, but received length ${t.strides.length}.`);this.strides=t.strides}else this.strides=[t.strides,t.strides];Nt(this.poolSize,"poolSize"),Nt(this.strides,"strides"),this.padding=null==t.padding?"valid":t.padding,this.dataFormat=null==t.dataFormat?"channelsLast":t.dataFormat,Mt(this.dataFormat),Ot(this.padding),this.inputSpec=[new Ve({ndim:4})]}computeOutputShape(t){t=Be(t);let e="channelsFirst"===this.dataFormat?t[2]:t[1],n="channelsFirst"===this.dataFormat?t[3]:t[2];return e=Hu(e,this.poolSize[0],this.padding,this.strides[0]),n=Hu(n,this.poolSize[1],this.padding,this.strides[1]),"channelsFirst"===this.dataFormat?[t[0],t[1],e,n]:[t[0],e,n,t[3]]}call(t,e){return s((()=>(this.invokeCallHook(t,e),this.poolingFunction(Oe(t),this.poolSize,this.strides,this.padding,this.dataFormat))))}getConfig(){const t={poolSize:this.poolSize,padding:this.padding,strides:this.strides,dataFormat:this.dataFormat},e=super.getConfig();return Object.assign(t,e),t}}class ac extends rc{constructor(t){super(t)}poolingFunction(t,e,n,s,i){return Mt(i),Ot(s),tc(t,e,n,s,i,"max")}}ac.className="MaxPooling2D",r.registerClass(ac);class oc extends rc{constructor(t){super(t)}poolingFunction(t,e,n,s,i){return Mt(i),Ot(s),tc(t,e,n,s,i,"avg")}}oc.className="AveragePooling2D",r.registerClass(oc);class lc extends Je{constructor(t){if(null==t.poolSize&&(t.poolSize=[2,2,2]),super(t),this.poolSize=Array.isArray(t.poolSize)?t.poolSize:[t.poolSize,t.poolSize,t.poolSize],null==t.strides)this.strides=this.poolSize;else if(Array.isArray(t.strides)){if(3!==t.strides.length)throw new it(`If the strides property of a 3D pooling layer is an Array, it is expected to have a length of 3, but received length ${t.strides.length}.`);this.strides=t.strides}else this.strides=[t.strides,t.strides,t.strides];Nt(this.poolSize,"poolSize"),Nt(this.strides,"strides"),this.padding=null==t.padding?"valid":t.padding,this.dataFormat=null==t.dataFormat?"channelsLast":t.dataFormat,Mt(this.dataFormat),Ot(this.padding),this.inputSpec=[new Ve({ndim:5})]}computeOutputShape(t){t=Be(t);let e="channelsFirst"===this.dataFormat?t[2]:t[1],n="channelsFirst"===this.dataFormat?t[3]:t[2],s="channelsFirst"===this.dataFormat?t[4]:t[3];return e=Hu(e,this.poolSize[0],this.padding,this.strides[0]),n=Hu(n,this.poolSize[1],this.padding,this.strides[1]),s=Hu(s,this.poolSize[2],this.padding,this.strides[2]),"channelsFirst"===this.dataFormat?[t[0],t[1],e,n,s]:[t[0],e,n,s,t[4]]}call(t,e){return s((()=>(this.invokeCallHook(t,e),this.poolingFunction(Oe(t),this.poolSize,this.strides,this.padding,this.dataFormat))))}getConfig(){const t={poolSize:this.poolSize,padding:this.padding,strides:this.strides,dataFormat:this.dataFormat},e=super.getConfig();return Object.assign(t,e),t}}class uc extends lc{constructor(t){super(t)}poolingFunction(t,e,n,s,i){return Mt(i),Ot(s),ec(t,e,n,s,i,"max")}}uc.className="MaxPooling3D",r.registerClass(uc);class hc extends lc{constructor(t){super(t)}poolingFunction(t,e,n,s,i){return Mt(i),Ot(s),ec(t,e,n,s,i,"avg")}}hc.className="AveragePooling3D",r.registerClass(hc);class cc extends Je{constructor(t){super(t),this.inputSpec=[new Ve({ndim:3})]}computeOutputShape(t){return[t[0],t[2]]}call(t,e){throw new rt}}class pc extends cc{constructor(t){super(t||{})}call(e,n){return s((()=>{const n=Oe(e);return t.mean(n,1)}))}}pc.className="GlobalAveragePooling1D",r.registerClass(pc);class dc extends cc{constructor(t){super(t||{})}call(e,n){return s((()=>{const n=Oe(e);return t.max(n,1)}))}}dc.className="GlobalMaxPooling1D",r.registerClass(dc);class fc extends Je{constructor(t){super(t),this.dataFormat=null==t.dataFormat?"channelsLast":t.dataFormat,Mt(this.dataFormat),this.inputSpec=[new Ve({ndim:4})]}computeOutputShape(t){return"channelsLast"===this.dataFormat?[t[0],t[3]]:[t[0],t[1]]}call(t,e){throw new rt}getConfig(){const t={dataFormat:this.dataFormat},e=super.getConfig();return Object.assign(t,e),t}}class gc extends fc{call(e,n){return s((()=>{const n=Oe(e);return"channelsLast"===this.dataFormat?t.mean(n,[1,2]):t.mean(n,[2,3])}))}}gc.className="GlobalAveragePooling2D",r.registerClass(gc);class mc extends fc{call(e,n){return s((()=>{const n=Oe(e);return"channelsLast"===this.dataFormat?t.max(n,[1,2]):t.max(n,[2,3])}))}}mc.className="GlobalMaxPooling2D",r.registerClass(mc);class yc extends Je{constructor(t){super(t),this.layer=t.layer}build(t){this.built=!0}get trainable(){return null!=this.layer&&this.layer.trainable}set trainable(t){null!=this.layer&&(this.layer.trainable=t)}get trainableWeights(){return this.layer.trainableWeights}get nonTrainableWeights(){return this.layer.nonTrainableWeights}get updates(){return this.layer._updates}get losses(){return this.layer.losses}getWeights(){return this.layer.getWeights()}setWeights(t){this.layer.setWeights(t)}getConfig(){const t={layer:{className:this.layer.getClassName(),config:this.layer.getConfig()}},e=super.getConfig();return Object.assign(t,e),t}setFastWeightInitDuringBuild(t){super.setFastWeightInitDuringBuild(t),null!=this.layer&&this.layer.setFastWeightInitDuringBuild(t)}static fromConfig(t,e,n={}){const s=cl(e.layer,n);delete e.layer;const i={layer:s};return Object.assign(i,e),new t(i)}}class bc extends yc{constructor(t){super(t),this.supportsMasking=!0}build(t){if((t=Be(t)).length<3)throw new it(`TimeDistributed layer expects an input shape >= 3D, but received input shape ${JSON.stringify(t)}`);this.inputSpec=[{shape:t}];const e=[t[0]].concat(t.slice(2));this.layer.built||(this.layer.build(e),this.layer.built=!0),super.build(t)}computeOutputShape(t){const e=[(t=Be(t))[0]].concat(t.slice(2)),n=this.layer.computeOutputShape(e),s=t[1];return[n[0],s].concat(n.slice(1))}call(t,e){return s((()=>ph(((t,n)=>[Oe(this.layer.call(t,e)),[]]),t=Oe(t),[],!1,null,null,!1,!0)[1]))}}bc.className="TimeDistributed",r.registerClass(bc);class wc extends yc{constructor(t){super(t);const e=t.layer.getConfig(),n={};n.className=t.layer.getClassName(),n.config=e,this.forwardLayer=cl(n),e.goBackwards=!0!==e.goBackwards;const s={};var i;if(s.className=t.layer.getClassName(),s.config=e,this.backwardLayer=cl(s),this.forwardLayer.name="forward_"+this.forwardLayer.name,this.backwardLayer.name="backward_"+this.backwardLayer.name,this.mergeMode=void 0===t.mergeMode?"concat":t.mergeMode,i=this.mergeMode,St(_t,"BidirectionalMergeMode",i),t.weights)throw new rt("weights support is not implemented for Bidirectional layer yet.");this._stateful=t.layer.stateful,this.returnSequences=t.layer.returnSequences,this.returnState=t.layer.returnState,this.supportsMasking=!0,this._trainable=!0,this.inputSpec=t.layer.inputSpec,this.numConstants=null}get trainable(){return this._trainable}set trainable(t){this._trainable=t,null!=this.forwardLayer&&(this.forwardLayer.trainable=t),null!=this.backwardLayer&&(this.backwardLayer.trainable=t)}getWeights(){return this.forwardLayer.getWeights().concat(this.backwardLayer.getWeights())}setWeights(t){const e=t.length,n=Math.floor(e/2);this.forwardLayer.setWeights(t.slice(0,n)),this.backwardLayer.setWeights(t.slice(n))}computeOutputShape(t){let e,n,s,i=this.forwardLayer.computeOutputShape(t);return Array.isArray(i)&&Array.isArray(i[0])||(i=[i]),this.returnState?(s=i.slice(1),e=i[0]):e=i[0],"concat"===this.mergeMode?(e[e.length-1]*=2,n=[e]):n=null==this.mergeMode?[e,e.slice()]:[e],this.returnState?null==this.mergeMode?n.concat(s).concat(s.slice()):[e].concat(s).concat(s.slice()):ct(n)}apply(t,e){let n=null==e?null:e.initialState,s=null==e?null:e.constants;null==e&&(e={});const i=ch(t,n,s,this.numConstants);if(t=i.inputs,n=i.initialState,s=i.constants,Array.isArray(t)&&(n=t.slice(1),t=t[0]),(null==n||0===n.length)&&null==s)return super.apply(t,e);const r=[],a=[];if(null!=n){const t=n.length;if(t%2>0)throw new it("When passing `initialState` to a Bidrectional RNN, the state should be an Array containing the states of the underlying RNNs.");e.initialState=n,r.push(...n);const s=n.map((t=>new Ve({shape:t.shape})));this.forwardLayer.stateSpec=s.slice(0,t/2),this.backwardLayer.stateSpec=s.slice(t/2),a.push(...s)}if(null!=s)throw new rt("Support for constants in Bidirectional layers is not implemented yet.");const o=r[0]instanceof qe;for(const t of r)if(t instanceof qe!==o)throw new it("The initial state of a Bidirectional layer cannot be specified as a mix of symbolic and non-symbolic tensors");if(o){const n=[t].concat(r),s=this.inputSpec.concat(a),i=this.inputSpec;this.inputSpec=s;const o=super.apply(n,e);return this.inputSpec=i,o}return super.apply(t,e)}call(e,n){return s((()=>{const s=n.initialState;let i,r,a,o;if(null==s)i=this.forwardLayer.call(e,n),r=this.backwardLayer.call(e,n);else{const t=s.slice(0,s.length/2),a=s.slice(s.length/2);i=this.forwardLayer.call(e,Object.assign(n,{initialState:t})),r=this.backwardLayer.call(e,Object.assign(n,{initialState:a}))}return this.returnState&&(Array.isArray(i)&&(a=i.slice(1).concat(r.slice(1))),i=i[0],r=r[0]),this.returnSequences&&(r=t.reverse(r,1)),"concat"===this.mergeMode?o=se([i,r]):"sum"===this.mergeMode?o=t.add(i,r):"ave"===this.mergeMode?o=t.mul(.5,t.add(i,r)):"mul"===this.mergeMode?o=t.mul(i,r):null==this.mergeMode&&(o=[i,r]),this.returnState?null==this.mergeMode?o.concat(a):[o].concat(a):o}))}resetStates(t){this.forwardLayer.resetStates(),this.backwardLayer.resetStates()}build(t){Ut(this.forwardLayer.name,(()=>{this.forwardLayer.build(t)})),Ut(this.backwardLayer.name,(()=>{this.backwardLayer.build(t)})),this.built=!0}computeMask(t,e){let n;if(Array.isArray(e)&&(e=e[0]),n=this.returnSequences?null==this.mergeMode?[e,e]:e:null==this.mergeMode?[null,null]:null,this.returnState){const t=this.forwardLayer.states.map((t=>null));return Array.isArray(n)?n.concat(t).concat(t):[n].concat(t).concat(t)}return n}get trainableWeights(){return this.forwardLayer.trainableWeights.concat(this.backwardLayer.trainableWeights)}get nonTrainableWeights(){return this.forwardLayer.nonTrainableWeights.concat(this.backwardLayer.nonTrainableWeights)}setFastWeightInitDuringBuild(t){super.setFastWeightInitDuringBuild(t),null!=this.forwardLayer&&this.forwardLayer.setFastWeightInitDuringBuild(t),null!=this.backwardLayer&&this.backwardLayer.setFastWeightInitDuringBuild(t)}getConfig(){const t={mergeMode:this.mergeMode},e=super.getConfig();return Object.assign(t,e),t}static fromConfig(t,e){const n=cl(e.layer);if(delete e.layer,null!=e.numConstants)throw new rt("Deserialization of a Bidirectional layer with numConstants present is not supported yet.");const s=e;return s.layer=n,new t(s)}}wc.className="Bidirectional",r.registerClass(wc);class kc extends Je{constructor(t){super(t),this.scale=t.scale,t.offset?this.offset=t.offset:this.offset=0}getConfig(){const t={scale:this.scale,offset:this.offset},e=super.getConfig();return Object.assign(t,e),t}call(t,e){return s((()=>("float32"!==(t=Oe(t)).dtype&&(t=Xt(t,"float32")),w(l(t,this.scale),this.offset))))}}kc.className="Rescaling",r.registerClass(kc);const{resizeBilinear:vc,cropAndResize:Sc}=Z;class xc extends Je{constructor(t){super(t),this.height=t.height,this.width=t.width}centerCrop(t,e,n,i,r,a,o,l){return s((()=>{let s,u=!1;const h=[e/a,n/o,(i+e)/a,(r+n)/o],c=[];3===t.rank?(u=!0,s=K([t])):s=t;for(let t=0;tXt(vc(t,[e,n]),i)))}call(t,e){return s((()=>{const e=Oe(t),n=e.dtype,s=e.shape,i=s[s.length-3],r=s[s.length-2];let a=0;i!==this.height&&(a=Math.floor((i-this.height)/2));let o=0;return r!==this.width&&(o=Math.floor((r-this.width)/2),0===o&&(o=1)),a>=0&&o>=0?this.centerCrop(e,a,o,this.height,this.width,i,r,n):this.upsize(t,this.height,this.width,n)}))}getConfig(){const t={height:this.height,width:this.width},e=super.getConfig();return Object.assign(t,e),t}computeOutputShape(t){const e=(t=Be(t)).length-3,n=t.length-2;return t[e]=this.height,t[n]=this.width,t}}xc.className="CenterCrop",r.registerClass(xc);class Nc extends Je{constructor(t){super(t),this.numTokens=t.numTokens,t.outputMode?this.outputMode=t.outputMode:this.outputMode="multiHot"}getConfig(){const t={numTokens:this.numTokens,outputMode:this.outputMode},e=super.getConfig();return Object.assign(t,e),t}computeOutputShape(t){return null==(t=Be(t))?[this.numTokens]:"oneHot"===this.outputMode&&1!==t[t.length-1]?(t.push(this.numTokens),t):(t[t.length-1]=this.numTokens,t)}call(t,e){return s((()=>{let n;if("int32"!==(t=Oe(t)).dtype&&(t=Xt(t,"int32")),"undefined"!=typeof e.countWeights){if("count"!==this.outputMode)throw new it(`countWeights is not used when outputMode !== count.\n Received countWeights=${e.countWeights}`);n=Oe(e.countWeights)}const s=Q(t),i=tt(t),r=R(this.numTokens,s).bufferSync().get(0),a=V(i,0).bufferSync().get(0);if(!r||!a)throw new it(`Input values must be between 0 < values <= numTokens with numTokens=${this.numTokens}`);return function(t,e,n,s){let i=Oe(t);if("int32"!==i.dtype&&(i=Xt(i,"int32")),"int"===e)return i;const r=i.shape;if(0===i.rank&&(i=Y(i,-1)),"oneHot"===e&&1!==i.shape[i.shape.length-1]&&(i=Y(i,-1)),i.rank>2)throw new it(`When outputMode is not int, maximum output rank is 2 Received outputMode ${e} and input shape ${r} which would result in output rank ${i.rank}.`);const a=["multiHot","oneHot"].includes(e);let o;if(o=X(i,"undefined"!=typeof s&&"count"===e?s:[],n,a),"tfIdf"!==e)return o;if(s)return l(o,s);throw new it("When outputMode is 'tfIdf', weights must be provided.")}(t,this.outputMode,this.numTokens,n)}))}}Nc.className="CategoryEncoding",r.registerClass(Nc);const Ic=new Set(["bilinear","nearest"]);class Ac extends Je{constructor(t){if(super(t),this.height=t.height,this.width=t.width,t.interpolation){if(!Ic.has(t.interpolation))throw new it(`Invalid interpolation parameter: ${t.interpolation} is not implemented`);this.interpolation=t.interpolation}else this.interpolation="bilinear";this.cropToAspectRatio=Boolean(t.cropToAspectRatio)}computeOutputShape(t){const e=(t=Be(t))[2];return[this.height,this.width,e]}getConfig(){const t={height:this.height,width:this.width,interpolation:this.interpolation,cropToAspectRatio:this.cropToAspectRatio},e=super.getConfig();return Object.assign(t,e),t}call(t,e){return s((()=>{const e=[this.height,this.width];if("bilinear"===this.interpolation)return Z.resizeBilinear(t,e,!this.cropToAspectRatio);if("nearest"===this.interpolation)return Z.resizeNearestNeighbor(t,e,!this.cropToAspectRatio);throw new Error(`Interpolation is ${this.interpolation} but only ${[...Ic]} are supported`)}))}}Ac.className="Resizing",r.registerClass(Ac);class Tc{constructor(t){this.seed=t}next(){if(void 0!==this.seed)return this.seed++}}Tc.className="RandomSeed";class Ec extends Je{constructor(t){super(t),this.randomGenerator=new Tc(t.seed)}getConfig(){const t={seed:this.randomGenerator.seed},e=super.getConfig();return Object.assign(t,e),t}}Ec.className="BaseRandomLayer";const Cc=new Set(["bilinear","nearest"]);class zc extends Ec{constructor(t){super(t);const{factor:e,interpolation:n="bilinear"}=t;if(this.factor=e,Array.isArray(this.factor)&&2===this.factor.length)this.widthLower=this.factor[0],this.widthUpper=this.factor[1];else{if(Array.isArray(this.factor)||!(this.factor>0))throw new it(`Invalid factor: ${this.factor}. Must be positive number or tuple of 2 numbers`);this.widthLower=-this.factor,this.widthUpper=this.factor}if(this.widthLower<-1||this.widthUpper<-1)throw new it(`factor must have values larger than -1. Got: ${this.factor}`);if(this.widthUpper{const e=Oe(t);this.imgHeight=e.shape[e.shape.length-3];const n=e.shape[e.shape.length-2];this.widthFactor=h([1],1+this.widthLower,1+this.widthUpper,"float32",this.randomGenerator.next());let s=this.widthFactor.dataSync()[0]*n;s=Math.round(s);const i=[this.imgHeight,s];switch(this.interpolation){case"bilinear":return Z.resizeBilinear(t,i);case"nearest":return Z.resizeNearestNeighbor(t,i);default:throw new Error(`Interpolation is ${this.interpolation}\n but only ${[...Cc]} are supported`)}}))}}function $c(t){return new ic(t)}function Fc(t){return new oc(t)}function Dc(t){return new hc(t)}function Lc(t){return new dc(t)}function _c(t){return new mc(t)}function Rc(t){return new sc(t)}function Mc(t){return new ac(t)}zc.className="RandomWidth",r.registerClass(zc);var Oc={__proto__:null,Layer:Je,RNN:dh,RNNCell:fh,activation:function(t){return new $h(t)},add:function(t){return new Oh(t)},alphaDropout:function(t){return new Hh(t)},average:function(t){return new Ph(t)},averagePooling1d:$c,averagePooling2d:Fc,averagePooling3d:Dc,avgPool1d:function(t){return $c(t)},avgPool2d:function(t){return Fc(t)},avgPool3d:function(t){return Dc(t)},avgPooling1d:function(t){return $c(t)},avgPooling2d:function(t){return Fc(t)},avgPooling3d:function(t){return Dc(t)},batchNormalization:function(t){return new Yh(t)},bidirectional:function(t){return new wc(t)},categoryEncoding:function(t){return new Nc(t)},centerCrop:function(t){return new xc(t)},concatenate:function(t){return new jh(t)},conv1d:function(t){return new oh(t)},conv2d:function(t){return new eh(t)},conv2dTranspose:function(t){return new sh(t)},conv3d:function(t){return new nh(t)},conv3dTranspose:function(t){return new ih(t)},convLstm2d:function(t){return new Ah(t)},convLstm2dCell:function(t){return new Ih(t)},cropping2D:function(t){return new lh(t)},dense:function(t){return new Ch(t)},depthwiseConv2d:function(t){return new hh(t)},dot:function(t){return new qh(t)},dropout:function(t){return new Th(t)},elu:function(t){return new Vu(t)},embedding:function(t){return new Rh(t)},flatten:function(t){return new zh(t)},gaussianDropout:function(t){return new Gh(t)},gaussianNoise:function(t){return new Kh(t)},globalAveragePooling1d:function(t){return new pc(t)},globalAveragePooling2d:function(t){return new gc(t)},globalMaxPool1d:Lc,globalMaxPool2d:_c,globalMaxPooling1d:Lc,globalMaxPooling2d:_c,gru:function(t){return new bh(t)},gruCell:function(t){return new yh(t)},input:fu,inputLayer:function(t){return new Ye(t)},layerNormalization:function(t){return new Xh(t)},leakyReLU:function(t){return new Wu(t)},lstm:function(t){return new kh(t)},lstmCell:function(t){return new wh(t)},masking:function(t){return new _h(t)},maxPool1d:Rc,maxPool2d:Mc,maxPooling1d:Rc,maxPooling2d:Mc,maxPooling3d:function(t){return new uc(t)},maximum:function(t){return new Uh(t)},minimum:function(t){return new Wh(t)},multiply:function(t){return new Bh(t)},permute:function(t){return new Lh(t)},prelu:function(t){return new ju(t)},randomWidth:function(t){return new zc(t)},reLU:function(t){return new Uu(t)},repeatVector:function(t){return new Fh(t)},rescaling:function(t){return new kc(t)},reshape:function(t){return new Dh(t)},resizing:function(t){return new Ac(t)},rnn:function(t){return new dh(t)},separableConv2d:function(t){return new ah(t)},simpleRNN:function(t){return new mh(t)},simpleRNNCell:function(t){return new gh(t)},softmax:function(t){return new Ku(t)},spatialDropout1d:function(t){return new Eh(t)},stackedRNNCells:function(t){return new vh(t)},thresholdedReLU:function(t){return new qu(t)},timeDistributed:function(t){return new bc(t)},upSampling2d:function(t){return new uh(t)},zeroPadding2d:function(t){return new Qh(t)}};var Bc={__proto__:null,MAPE:function(t,e){return gl(t,e)},MSE:function(t,e){return dl(t,e)},binaryAccuracy:function(t,e){return Sl(t,e)},binaryCrossentropy:function(t,e){return Tl(t,e)},categoricalAccuracy:function(t,e){return xl(t,e)},categoricalCrossentropy:function(t,e){return Cl(t,e)},cosineProximity:function(t,e){return wl(t,e)},mape:function(t,e){return gl(t,e)},meanAbsoluteError:function(t,e){return fl(t,e)},meanAbsolutePercentageError:function(t,e){return gl(t,e)},meanSquaredError:function(t,e){return dl(t,e)},mse:function(t,e){return dl(t,e)},precision:function(t,e){return Il(t,e)},recall:function(t,e){return Al(t,e)},sparseCategoricalAccuracy:function(t,e){return El(t,e)}},Pc={__proto__:null,modelFromJSON:async function(t,e){"modelTopology"in t||(t={modelTopology:t});let n=t.modelTopology;null!=n.model_config&&(n=n.model_config);const s=cl(Ul(n),e);if(null!=t.weightsManifest){const e=await E.loadWeights(t.weightsManifest,t.pathPrefix,s.weights.map((t=>t.originalName))),n={};for(const t of s.weights)n[t.originalName]=e[t.originalName];s.loadWeights(n),f(e)}return s}};var Uc={__proto__:null,l1:function(t){return Lu(e=t),new Ru({l1:null!=e?e.l1:null,l2:0});var e},l1l2:function(t){return new Ru(t)},l2:function(t){return Lu(e=t),new Ru({l2:null!=e?e.l2:null,l1:0});var e}};class Wc extends sl{constructor(){super(...arguments),this.model=null}setModel(t){if(!(t instanceof lu))throw new Error("model must be a LayersModel, not some other Container");this.model=t}}function jc(t,e){return te}class qc extends Wc{constructor(t){if(super(),null==t&&(t={}),t.restoreBestWeights)throw new rt("restoreBestWeights = True is not implemented in EarlyStopping yet.");this.monitor=t.monitor||"val_loss",this.minDelta=Math.abs(t.minDelta||0),this.patience=t.patience||0,this.verbose=t.verbose||0,this.mode=t.mode||"auto",this.baseline=t.baseline,-1===["auto","min","max"].indexOf(this.mode)&&(console.warn(`EarlyStopping mode '${this.mode}' is invalid. Falling back to mode 'auto'.`),this.mode="auto"),"min"===this.mode?this.monitorFunc=jc:"max"===this.mode||-1!==this.monitor.indexOf("acc")?this.monitorFunc=Vc:this.monitorFunc=jc,this.monitorFunc===jc&&(this.minDelta*=-1)}async onTrainBegin(t){this.wait=0,this.stoppedEpoch=0,null!=this.baseline?this.best=this.baseline:this.best=this.monitorFunc===jc?1/0:-1/0}async onEpochEnd(t,e){await el(e);const n=this.getMonitorValue(e);null!=n&&(this.monitorFunc(n-this.minDelta,this.best)?(this.best=n,this.wait=0):(this.wait++,this.wait>=this.patience&&(this.stoppedEpoch=t,this.model.stopTraining=!0)))}async onTrainEnd(t){this.stoppedEpoch>0&&this.verbose&&console.log(`Epoch ${this.stoppedEpoch}: early stopping.`)}getMonitorValue(t){null==t&&(t={});const e=t[this.monitor];return null==e&&console.warn(`Metric for EarlyStopping ${this.monitor} is not available. Available metrics are: ${Object.keys(t)}`),e}}const Kc={earlyStopping:function(t){return new qc(t)}};export{Wc as Callback,il as CallbackList,ol as CustomCallback,qc as EarlyStopping,al as History,Ve as InputSpec,Ue as LayerVariable,lu as LayersModel,dh as RNN,cu as Sequential,qe as SymbolicTensor,Kc as callbacks,Xo as constraints,tl as initializers,fu as input,Oc as layers,hu as loadLayersModel,Bc as metrics,pu as model,Pc as models,gu as registerCallbackConstructor,Uc as regularizers,du as sequential,jl as version_layers}; //# sourceMappingURL=tf-layers.fesm.min.js.map