Fixing ModelLoaderUtils.split() to pass tests (#126009)

Prior to these changes, the split method would fail tests. Additionally,
the method had code which could be refactored.

A new variable (numRanges) was introduced to replace the direct usage of numStreams.
The method was refactored to make the code easier to understand. Javadocs were updated.
Tests for this method now pass.
This commit is contained in:
Jason-Whitmore 2025-04-09 12:37:15 -07:00 committed by GitHub
parent 6c6500ec3b
commit 36280d2630
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 23 additions and 23 deletions

View file

@ -0,0 +1,6 @@
pr: 126009
summary: Change ModelLoaderUtils.split to return the correct number of chunks and ranges.
area: Machine Learning
type: bug
issues:
- 121799

View file

@ -336,50 +336,44 @@ final class ModelLoaderUtils {
* Split a stream of size {@code sizeInBytes} into {@code numberOfStreams} +1
* ranges aligned on {@code chunkSizeBytes} boundaries. Each range contains a
* whole number of chunks.
* The first {@code numberOfStreams} ranges will be split evenly (in terms of
* number of chunks not the byte size), the final range split
* All ranges except the final range will be split approximately evenly
* (in terms of number of chunks not the byte size), the final range split
* is for the single final chunk and will be no more than {@code chunkSizeBytes}
* in size. The separate range for the final chunk is because when streaming and
* uploading a large model definition, writing the last part has to handled
* as a special case.
* Less ranges may be returned in case the stream size is too small.
* Fewer ranges may be returned in case the stream size is too small.
* @param sizeInBytes The total size of the stream
* @param numberOfStreams Divide the bulk of the size into this many streams.
* @param chunkSizeBytes The size of each chunk
* @return List of {@code numberOfStreams} + 1 ranges.
* @return List of {@code numberOfStreams} + 1 or fewer ranges.
*/
static List<RequestRange> split(long sizeInBytes, int numberOfStreams, long chunkSizeBytes) {
int numberOfChunks = (int) ((sizeInBytes + chunkSizeBytes - 1) / chunkSizeBytes);
int numberOfRanges = numberOfStreams + 1;
if (numberOfStreams > numberOfChunks) {
numberOfStreams = numberOfChunks;
numberOfRanges = numberOfChunks;
}
var ranges = new ArrayList<RequestRange>();
int baseChunksPerStream = numberOfChunks / numberOfStreams;
int remainder = numberOfChunks % numberOfStreams;
int baseChunksPerRange = (numberOfChunks - 1) / (numberOfRanges - 1);
int remainder = (numberOfChunks - 1) % (numberOfRanges - 1);
long startOffset = 0;
int startChunkIndex = 0;
for (int i = 0; i < numberOfStreams - 1; i++) {
int numChunksInStream = (i < remainder) ? baseChunksPerStream + 1 : baseChunksPerStream;
long rangeEnd = startOffset + (numChunksInStream * chunkSizeBytes) - 1; // range index is 0 based
ranges.add(new RequestRange(startOffset, rangeEnd, startChunkIndex, numChunksInStream));
for (int i = 0; i < numberOfRanges - 1; i++) {
int numChunksInRange = (i < remainder) ? baseChunksPerRange + 1 : baseChunksPerRange;
long rangeEnd = startOffset + (((long) numChunksInRange) * chunkSizeBytes) - 1; // range index is 0 based
ranges.add(new RequestRange(startOffset, rangeEnd, startChunkIndex, numChunksInRange));
startOffset = rangeEnd + 1; // range is inclusive start and end
startChunkIndex += numChunksInStream;
}
// Want the final range request to be a single chunk
if (baseChunksPerStream > 1) {
int numChunksExcludingFinal = baseChunksPerStream - 1;
long rangeEnd = startOffset + (numChunksExcludingFinal * chunkSizeBytes) - 1;
ranges.add(new RequestRange(startOffset, rangeEnd, startChunkIndex, numChunksExcludingFinal));
startOffset = rangeEnd + 1;
startChunkIndex += numChunksExcludingFinal;
startChunkIndex += numChunksInRange;
}
// The final range is a single chunk the end of which should not exceed sizeInBytes
long rangeEnd = Math.min(sizeInBytes, startOffset + (baseChunksPerStream * chunkSizeBytes)) - 1;
long rangeEnd = Math.min(sizeInBytes, startOffset + (baseChunksPerRange * chunkSizeBytes)) - 1;
ranges.add(new RequestRange(startOffset, rangeEnd, startChunkIndex, 1));
return ranges;