import * as util from '../util';
|
/**
 * Wraps a list of ArrayBuffers into a `slice()`-able object without allocating
 * a large ArrayBuffer.
 *
 * Allocating large ArrayBuffers (~2GB) can be unstable on Chrome. TFJS loads
 * its weights as a list of (usually) 4MB ArrayBuffers and then slices the
 * weight tensors out of them. For small models, it's safe to concatenate all
 * the weight buffers into a single ArrayBuffer and then slice the weight
 * tensors out of it, but for large models, a different approach is needed.
 */
export class CompositeArrayBuffer {
  /**
   * Concatenate a number of ArrayBuffers into one.
   *
   * @param buffers An array of ArrayBuffers to concatenate, or a single
   *     ArrayBuffer.
   * @returns Result of concatenating `buffers` in order.
   */
  static join(buffers) {
    return new CompositeArrayBuffer(buffers).slice();
  }

  /**
   * Builds the shard table describing where each input buffer lives inside the
   * virtual concatenated buffer. No bytes are copied here.
   *
   * @param buffers An ArrayBuffer, a typed array, or an array of either. May
   *     be `null`/`undefined` or empty, in which case the composite is empty.
   */
  constructor(buffers) {
    // Each shard is `{buffer, offset, start, end}`: `offset` is where the
    // shard's bytes begin inside `buffer`, and `[start, end)` is the range the
    // shard occupies in the composite.
    this.shards = [];
    this.previousShardIndex = 0;
    // Always define the length so that an empty composite reports 0 instead
    // of leaving the property undefined.
    this.byteLength = 0;

    if (buffers == null) {
      return;
    }
    // Normalize the `buffers` input to be an array.
    if (!(buffers instanceof Array)) {
      buffers = [buffers];
    }
    // Skip setting up shards if there are no buffers.
    if (buffers.length === 0) {
      return;
    }

    const lastIndex = buffers.length - 1;
    let uniformSize = 0;
    let isUniform = true;
    let start = 0;

    for (let i = 0; i < buffers.length; i++) {
      const source = buffers[i];
      // A typed array may be a view over only part of its backing buffer, so
      // remember the view's offset instead of exposing the whole buffer.
      const isView = util.isTypedArray(source);
      const buffer = isView ? source.buffer : source;
      const offset = isView ? source.byteOffset : 0;
      const length = source.byteLength;

      // All buffers except the last one must share the first buffer's length
      // for the sizes to count as uniform.
      if (i === 0) {
        uniformSize = length;
      } else if (i !== lastIndex && length !== uniformSize) {
        isUniform = false;
      }

      // Create the shard, including its start and end points.
      const end = start + length;
      this.shards.push({ buffer, offset, start, end });
      start = end;
    }

    // A uniform size of zero cannot be used to compute a shard index (it would
    // divide by zero), so fall back to searching in that case.
    this.bufferUniformSize =
        isUniform && uniformSize > 0 ? uniformSize : undefined;
    this.byteLength = start;
  }

  /**
   * Copies the bytes in `[start, end)` of the composite into a new
   * ArrayBuffer.
   *
   * @param start First byte to include. Defaults to 0; NaN is treated as 0 and
   *     values below 0 are clamped to 0.
   * @param end Byte to stop before. Defaults to `byteLength`; NaN is treated
   *     as 0 and values above `byteLength` are clamped to `byteLength`.
   * @returns A new ArrayBuffer holding the requested bytes, or an empty
   *     ArrayBuffer when the range is empty.
   * @throws Error if no shard contains `start` (not expected after clamping).
   */
  slice(start = 0, end = this.byteLength) {
    // If there are no shards, then the CompositeArrayBuffer was initialized
    // with no data.
    if (this.shards.length === 0) {
      return new ArrayBuffer(0);
    }

    // NaN is treated as zero for slicing. This matches ArrayBuffer's behavior.
    start = isNaN(Number(start)) ? 0 : start;
    end = isNaN(Number(end)) ? 0 : end;

    // Fix the bounds to within the array.
    start = Math.max(0, start);
    end = Math.min(this.byteLength, end);
    if (end <= start) {
      return new ArrayBuffer(0);
    }

    const startShardIndex = this.findShardForByte(start);
    if (startShardIndex === -1) {
      // This should not happen since the start and end indices are always
      // within 0 and the composite array's length.
      throw new Error(`Could not find start shard for byte ${start}`);
    }

    const size = end - start;
    const outputBuffer = new ArrayBuffer(size);
    const outputArray = new Uint8Array(outputBuffer);
    let sliced = 0;

    for (let i = startShardIndex; i < this.shards.length; i++) {
      const shard = this.shards[i];

      // Translate the next global byte range into coordinates local to this
      // shard.
      const globalStart = start + sliced;
      const localStart = globalStart - shard.start;
      const globalEnd = Math.min(end, shard.end);
      const localEnd = globalEnd - shard.start;

      // View (not copy) the relevant bytes of the shard, then copy them into
      // the output at the current write position.
      const outputSlice = new Uint8Array(
          shard.buffer, shard.offset + localStart, localEnd - localStart);
      outputArray.set(outputSlice, sliced);
      sliced += outputSlice.length;

      // Stop as soon as the requested range is fully covered.
      if (end <= shard.end) {
        break;
      }
    }
    return outputBuffer;
  }

  /**
   * Get the index of the shard that contains the byte at `byteIndex`.
   *
   * @param byteIndex Position of the byte within the composite.
   * @returns The shard index, or -1 if `byteIndex` is outside the composite.
   */
  findShardForByte(byteIndex) {
    if (this.shards.length === 0 || byteIndex < 0 ||
        byteIndex >= this.byteLength) {
      return -1;
    }

    // If the buffers have a uniform size, compute the shard directly. The last
    // buffer is allowed to differ in size, so any index past the end of the
    // shard list belongs to the last shard.
    if (this.bufferUniformSize != null) {
      this.previousShardIndex = Math.min(
          Math.floor(byteIndex / this.bufferUniformSize),
          this.shards.length - 1);
      return this.previousShardIndex;
    }

    // If the buffers don't have a uniform size, we need to search for the
    // shard. That means we need a function to check where the byteIndex lies
    // relative to a given shard.
    function check(shard) {
      if (byteIndex < shard.start) {
        return -1;
      }
      if (byteIndex >= shard.end) {
        return 1;
      }
      return 0;
    }

    // For efficiency, try the previous shard first.
    if (check(this.shards[this.previousShardIndex]) === 0) {
      return this.previousShardIndex;
    }

    // Otherwise, use a generic search function.
    // This should almost never end up being used in practice since the weight
    // entries should always be in order.
    const index = search(this.shards, check);
    if (index === -1) {
      return -1;
    }

    this.previousShardIndex = index;
    return this.previousShardIndex;
  }
}
|
/**
 * Search for an element of a sorted array.
 *
 * @param sortedArray The sorted array to search
 * @param compare A function to compare the current value against the searched
 *     value. Return 0 on a match, negative if the searched value is less than
 *     the value passed to the function, and positive if the searched value is
 *     greater than the value passed to the function.
 * @returns The index of the element, or -1 if it's not in the array.
 */
export function search(sortedArray, compare) {
  // Binary search over the inclusive index range [min, max]. Both bounds move
  // past the probed element on every step, so the range always shrinks and
  // `compare` is only ever called with elements that exist in the array.
  let min = 0;
  let max = sortedArray.length - 1;

  while (min <= max) {
    const middle = min + Math.floor((max - min) / 2);
    const side = compare(sortedArray[middle]);

    if (side === 0) {
      return middle;
    }
    if (side < 0) {
      // The searched value lies before `middle`.
      max = middle - 1;
    } else {
      // The searched value lies after `middle`.
      min = middle + 1;
    }
  }
  return -1;
}
|
//# sourceMappingURL=data:application/json;base64,{"version":3,"file":"composite_array_buffer.js","sourceRoot":"","sources":["../../../../../../tfjs-core/src/io/composite_array_buffer.ts"],"names":[],"mappings":"AAiBA,OAAO,KAAK,IAAI,MAAM,SAAS,CAAC;AAQhC;;;;;;;;;GASG;AAEH,MAAM,OAAO,oBAAoB;IAM/B;;;;;;OAMG;IACH,MAAM,CAAC,IAAI,CAAC,OAAqC;QAC/C,OAAO,IAAI,oBAAoB,CAAC,OAAO,CAAC,CAAC,KAAK,EAAE,CAAC;IACnD,CAAC;IAED,YAAY,OACE;QAjBN,WAAM,GAAkB,EAAE,CAAC;QAC3B,uBAAkB,GAAG,CAAC,CAAC;QAiB7B,IAAI,OAAO,IAAI,IAAI,EAAE;YACnB,OAAO;SACR;QACD,uDAAuD;QACvD,IAAI,CAAC,CAAC,OAAO,YAAY,KAAK,CAAC,EAAE;YAC/B,OAAO,GAAG,CAAC,OAAO,CAAC,CAAC;SACrB;QACD,OAAO,GAAG,OAAO,CAAC,GAAG,CAAC,CAAC,kBAAkB,EAAE,EAAE;YAC3C,IAAI,IAAI,CAAC,YAAY,CAAC,kBAAkB,CAAC,EAAE;gBACzC,OAAO,kBAAkB,CAAC,MAAM,CAAC;aAClC;YACD,OAAO,kBAAkB,CAAC;QAC5B,CAAC,CAAC,CAAC;QAEH,kDAAkD;QAClD,IAAI,OAAO,CAAC,MAAM,KAAK,CAAC,EAAE;YACxB,OAAO;SACR;QAED,IAAI,CAAC,iBAAiB,GAAG,OAAO,CAAC,CAAC,CAAC,CAAC,UAAU,CAAC;QAC/C,IAAI,KAAK,GAAG,CAAC,CAAC;QAEd,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,OAAO,CAAC,MAAM,EAAE,CAAC,EAAE,EAAE;YACvC,MAAM,MAAM,GAAG,OAAO,CAAC,CAAC,CAAC,CAAC;YAC1B,mEAAmE;YACnE,IAAI,CAAC,KAAK,OAAO,CAAC,MAAM,GAAG,CAAC;gBAC1B,MAAM,CAAC,UAAU,KAAK,IAAI,CAAC,iBAAiB,EAAE;gBAC9C,gEAAgE;gBAChE,WAAW;gBACX,IAAI,CAAC,iBAAiB,GAAG,SAAS,CAAC;aACpC;YAED,2DAA2D;YAC3D,MAAM,GAAG,GAAG,KAAK,GAAG,MAAM,CAAC,UAAU,CAAC;YACtC,IAAI,CAAC,MAAM,CAAC,IAAI,CAAC,EAAE,MAAM,EAAE,KAAK,EAAE,GAAG,EAAE,CAAC,CAAC;YACzC,KAAK,GAAG,GAAG,CAAC;SACb;QAED,sBAAsB;QACtB,IAAI,IAAI,CAAC,MAAM,CAAC,MAAM,KAAK,CAAC,EAAE;YAC5B,IAAI,CAAC,UAAU,GAAG,CAAC,CAAC;SACrB;QACD,IAAI,CAAC,UAAU,GAAG,IAAI,CAAC,MAAM,CAAC,IAAI,CAAC,MAAM,CAAC,MAAM,GAAG,CAAC,CAAC,CAAC,GAAG,CAAC;IAC5D,CAAC;IAED,KAAK,CAAC,KAAK,GAAG,CAAC,EAAE,GAAG,GAAG,IAAI,CAAC,UAAU;QACpC,wEAAwE;QACxE,gBAAgB;QAChB,IAAI,IAAI,CAAC,MAAM,CAAC,MAAM,KAAK,CAAC,EAAE;YAC5B,OAAO,IAAI,WAAW,CAAC,CAAC,CAAC,CAAC;SAC3B;QAED,2EAA2E;QAC3E,KAAK,GAAG,KAAK,CAAC,MAAM,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,KAAK,CAAC;QACzC,GAAG,GAAG,KAAK,CAAC,MAAM,CAAC,GAAG
,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,GAAG,CAAC;QAEnC,sCAAsC;QACtC,KAAK,GAAG,IAAI,CAAC,GAAG,CAAC,CAAC,EAAE,KAAK,CAAC,CAAC;QAC3B,GAAG,GAAG,IAAI,CAAC,GAAG,CAAC,IAAI,CAAC,UAAU,EAAE,GAAG,CAAC,CAAC;QACrC,IAAI,GAAG,IAAI,KAAK,EAAE;YAChB,OAAO,IAAI,WAAW,CAAC,CAAC,CAAC,CAAC;SAC3B;QAED,MAAM,eAAe,GAAG,IAAI,CAAC,gBAAgB,CAAC,KAAK,CAAC,CAAC;QACrD,IAAI,eAAe,KAAK,CAAC,CAAC,EAAE;YAC1B,oEAAoE;YACpE,6CAA6C;YAC7C,MAAM,IAAI,KAAK,CAAC,uCAAuC,KAAK,EAAE,CAAC,CAAC;SACjE;QAED,MAAM,IAAI,GAAG,GAAG,GAAG,KAAK,CAAC;QACzB,MAAM,YAAY,GAAG,IAAI,WAAW,CAAC,IAAI,CAAC,CAAC;QAC3C,MAAM,WAAW,GAAG,IAAI,UAAU,CAAC,YAAY,CAAC,CAAC;QACjD,IAAI,MAAM,GAAG,CAAC,CAAC;QACf,KAAK,IAAI,CAAC,GAAG,eAAe,EAAE,CAAC,GAAG,IAAI,CAAC,MAAM,CAAC,MAAM,EAAE,CAAC,EAAE,EAAE;YACzD,MAAM,KAAK,GAAG,IAAI,CAAC,MAAM,CAAC,CAAC,CAAC,CAAC;YAE7B,MAAM,WAAW,GAAG,KAAK,GAAG,MAAM,CAAC;YACnC,MAAM,UAAU,GAAG,WAAW,GAAG,KAAK,CAAC,KAAK,CAAC;YAC7C,MAAM,WAAW,GAAG,MAAM,CAAC;YAE3B,MAAM,SAAS,GAAG,IAAI,CAAC,GAAG,CAAC,GAAG,EAAE,KAAK,CAAC,GAAG,CAAC,CAAC;YAC3C,MAAM,QAAQ,GAAG,SAAS,GAAG,KAAK,CAAC,KAAK,CAAC;YAEzC,MAAM,WAAW,GAAG,IAAI,UAAU,CAAC,KAAK,CAAC,MAAM,EAAE,UAAU,EACxB,QAAQ,GAAG,UAAU,CAAC,CAAC;YAC1D,WAAW,CAAC,GAAG,CAAC,WAAW,EAAE,WAAW,CAAC,CAAC;YAC1C,MAAM,IAAI,WAAW,CAAC,MAAM,CAAC;YAE7B,IAAI,GAAG,GAAG,KAAK,CAAC,GAAG,EAAE;gBACnB,MAAM;aACP;SACF;QACD,OAAO,YAAY,CAAC;IACtB,CAAC;IAED;;OAEG;IACK,gBAAgB,CAAC,SAAiB;QACxC,IAAI,IAAI,CAAC,MAAM,CAAC,MAAM,KAAK,CAAC,IAAI,SAAS,GAAG,CAAC;YAC3C,SAAS,IAAI,IAAI,CAAC,UAAU,EAAE;YAC9B,OAAO,CAAC,CAAC,CAAC;SACX;QAED,kEAAkE;QAClE,IAAI,IAAI,CAAC,iBAAiB,IAAI,IAAI,EAAE;YAClC,IAAI,CAAC,kBAAkB,GAAG,IAAI,CAAC,KAAK,CAAC,SAAS,GAAG,IAAI,CAAC,iBAAiB,CAAC,CAAC;YACzE,OAAO,IAAI,CAAC,kBAAkB,CAAC;SAChC;QAED,sEAAsE;QACtE,yEAAyE;QACzE,6BAA6B;QAC7B,SAAS,KAAK,CAAC,KAAkB;YAC/B,IAAI,SAAS,GAAG,KAAK,CAAC,KAAK,EAAE;gBAC3B,OAAO,CAAC,CAAC,CAAC;aACX;YACD,IAAI,SAAS,IAAI,KAAK,CAAC,GAAG,EAAE;gBAC1B,OAAO,CAAC,CAAC;aACV;YACD,OAAO,CAAC,CAAC;QACX,CAAC;QAED,gDAAgD;QAChD,IAAI,KAAK,CAAC,IAAI,CAAC,MAAM,CAAC,IAAI,CAAC,kBAAkB,CAAC,CAAC,KAAK,CAAC,EAAE;YACrD,OAAO,IAAI,CAAC,kBAA
kB,CAAC;SAChC;QAED,4CAA4C;QAC5C,0EAA0E;QAC1E,qCAAqC;QACrC,MAAM,KAAK,GAAG,MAAM,CAAC,IAAI,CAAC,MAAM,EAAE,KAAK,CAAC,CAAC;QACzC,IAAI,KAAK,KAAK,CAAC,CAAC,EAAE;YAChB,OAAO,CAAC,CAAC,CAAC;SACX;QAED,IAAI,CAAC,kBAAkB,GAAG,KAAK,CAAC;QAChC,OAAO,IAAI,CAAC,kBAAkB,CAAC;IACjC,CAAC;CACF;AAED;;;;;;;;;GASG;AACH,MAAM,UAAU,MAAM,CAAI,WAAgB,EAAE,OAAyB;IACnE,gBAAgB;IAChB,IAAI,GAAG,GAAG,CAAC,CAAC;IACZ,IAAI,GAAG,GAAG,WAAW,CAAC,MAAM,CAAC;IAE7B,OAAO,GAAG,IAAI,GAAG,EAAE;QACjB,MAAM,MAAM,GAAG,IAAI,CAAC,KAAK,CAAC,CAAC,GAAG,GAAG,GAAG,CAAC,GAAG,CAAC,CAAC,GAAG,GAAG,CAAC;QACjD,MAAM,IAAI,GAAG,OAAO,CAAC,WAAW,CAAC,MAAM,CAAC,CAAC,CAAC;QAE1C,IAAI,IAAI,KAAK,CAAC,EAAE;YACd,OAAO,MAAM,CAAC;SACf;aAAM,IAAI,IAAI,GAAG,CAAC,EAAE;YACnB,GAAG,GAAG,MAAM,CAAC;SACd;aAAM;YACL,GAAG,GAAG,MAAM,GAAG,CAAC,CAAC;SAClB;KACF;IACD,OAAO,CAAC,CAAC,CAAC;AACZ,CAAC","sourcesContent":["/**\n * @license\n * Copyright 2023 Google LLC. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\nimport {TypedArray} from '../types';\nimport * as util from '../util';\n\ntype BufferShard = {\n  start: number,\n  end: number,\n  buffer: ArrayBuffer,\n};\n\n/**\n * Wraps a list of ArrayBuffers into a `slice()`-able object without allocating\n * a large ArrayBuffer.\n *\n * Allocating large ArrayBuffers (~2GB) can be unstable on Chrome. 
TFJS loads\n * its weights as a list of (usually) 4MB ArrayBuffers and then slices the\n * weight tensors out of them. For small models, it's safe to concatenate all\n * the weight buffers into a single ArrayBuffer and then slice the weight\n * tensors out of it, but for large models, a different approach is needed.\n */\n\nexport class CompositeArrayBuffer {\n  private shards: BufferShard[] = [];\n  private previousShardIndex = 0;\n  private bufferUniformSize?: number;\n  public readonly byteLength: number;\n\n  /**\n   * Concatenate a number of ArrayBuffers into one.\n   *\n   * @param buffers An array of ArrayBuffers to concatenate, or a single\n   *     ArrayBuffer.\n   * @returns Result of concatenating `buffers` in order.\n   */\n  static join(buffers?: ArrayBuffer[] | ArrayBuffer) {\n    return new CompositeArrayBuffer(buffers).slice();\n  }\n\n  constructor(buffers?: ArrayBuffer | ArrayBuffer[] | TypedArray |\n    TypedArray[]) {\n    if (buffers == null) {\n      return;\n    }\n    // Normalize the `buffers` input to be `ArrayBuffer[]`.\n    if (!(buffers instanceof Array)) {\n      buffers = [buffers];\n    }\n    buffers = buffers.map((bufferOrTypedArray) => {\n      if (util.isTypedArray(bufferOrTypedArray)) {\n        return bufferOrTypedArray.buffer;\n      }\n      return bufferOrTypedArray;\n    });\n\n    // Skip setting up shards if there are no buffers.\n    if (buffers.length === 0) {\n      return;\n    }\n\n    this.bufferUniformSize = buffers[0].byteLength;\n    let start = 0;\n\n    for (let i = 0; i < buffers.length; i++) {\n      const buffer = buffers[i];\n      // Check that all buffers except the last one have the same length.\n      if (i !== buffers.length - 1 &&\n        buffer.byteLength !== this.bufferUniformSize) {\n        // Unset the buffer uniform size, since the buffer sizes are not\n        // uniform.\n        this.bufferUniformSize = undefined;\n      }\n\n      // Create the shards, including their start and end 
points.\n      const end = start + buffer.byteLength;\n      this.shards.push({ buffer, start, end });\n      start = end;\n    }\n\n    // Set the byteLenghth\n    if (this.shards.length === 0) {\n      this.byteLength = 0;\n    }\n    this.byteLength = this.shards[this.shards.length - 1].end;\n  }\n\n  slice(start = 0, end = this.byteLength): ArrayBuffer {\n    // If there are no shards, then the CompositeArrayBuffer was initialized\n    // with no data.\n    if (this.shards.length === 0) {\n      return new ArrayBuffer(0);\n    }\n\n    // NaN is treated as zero for slicing. This matches ArrayBuffer's behavior.\n    start = isNaN(Number(start)) ? 0 : start;\n    end = isNaN(Number(end)) ? 0 : end;\n\n    // Fix the bounds to within the array.\n    start = Math.max(0, start);\n    end = Math.min(this.byteLength, end);\n    if (end <= start) {\n      return new ArrayBuffer(0);\n    }\n\n    const startShardIndex = this.findShardForByte(start);\n    if (startShardIndex === -1) {\n      // This should not happen since the start and end indices are always\n      // within 0 and the composite array's length.\n      throw new Error(`Could not find start shard for byte ${start}`);\n    }\n\n    const size = end - start;\n    const outputBuffer = new ArrayBuffer(size);\n    const outputArray = new Uint8Array(outputBuffer);\n    let sliced = 0;\n    for (let i = startShardIndex; i < this.shards.length; i++) {\n      const shard = this.shards[i];\n\n      const globalStart = start + sliced;\n      const localStart = globalStart - shard.start;\n      const outputStart = sliced;\n\n      const globalEnd = Math.min(end, shard.end);\n      const localEnd = globalEnd - shard.start;\n\n      const outputSlice = new Uint8Array(shard.buffer, localStart,\n                                         localEnd - localStart);\n      outputArray.set(outputSlice, outputStart);\n      sliced += outputSlice.length;\n\n      if (end < shard.end) {\n        break;\n      }\n    }\n    return 
outputBuffer;\n  }\n\n  /**\n   * Get the index of the shard that contains the byte at `byteIndex`.\n   */\n  private findShardForByte(byteIndex: number): number {\n    if (this.shards.length === 0 || byteIndex < 0 ||\n      byteIndex >= this.byteLength) {\n      return -1;\n    }\n\n    // If the buffers have a uniform size, compute the shard directly.\n    if (this.bufferUniformSize != null) {\n      this.previousShardIndex = Math.floor(byteIndex / this.bufferUniformSize);\n      return this.previousShardIndex;\n    }\n\n    // If the buffers don't have a uniform size, we need to search for the\n    // shard. That means we need a function to check where the byteIndex lies\n    // relative to a given shard.\n    function check(shard: BufferShard) {\n      if (byteIndex < shard.start) {\n        return -1;\n      }\n      if (byteIndex >= shard.end) {\n        return 1;\n      }\n      return 0;\n    }\n\n    // For efficiency, try the previous shard first.\n    if (check(this.shards[this.previousShardIndex]) === 0) {\n      return this.previousShardIndex;\n    }\n\n    // Otherwise, use a generic search function.\n    // This should almost never end up being used in practice since the weight\n    // entries should always be in order.\n    const index = search(this.shards, check);\n    if (index === -1) {\n      return -1;\n    }\n\n    this.previousShardIndex = index;\n    return this.previousShardIndex;\n  }\n}\n\n/**\n * Search for an element of a sorted array.\n *\n * @param sortedArray The sorted array to search\n * @param compare A function to compare the current value against the searched\n *     value. 
Return 0 on a match, negative if the searched value is less than\n *     the value passed to the function, and positive if the searched value is\n *     greater than the value passed to the function.\n * @returns The index of the element, or -1 if it's not in the array.\n */\nexport function search<T>(sortedArray: T[], compare: (t: T) => number): number {\n  // Binary search\n  let min = 0;\n  let max = sortedArray.length;\n\n  while (min <= max) {\n    const middle = Math.floor((max - min) / 2) + min;\n    const side = compare(sortedArray[middle]);\n\n    if (side === 0) {\n      return middle;\n    } else if (side < 0) {\n      max = middle;\n    } else {\n      min = middle + 1;\n    }\n  }\n  return -1;\n}\n"]}
|