Batching range adaptor?

It’s a minor thing, but a lot of the workbooks (e.g. 03) have duplicate code that looks like this:

    for i in 0 ..< (n-1)/bs {
        let startIdx = i * bs
        let endIdx = startIdx + bs
        let xb = xTrain[startIdx..<endIdx]
        let yb = yTrain[startIdx..<endIdx]
        // ...
    }

It would be much nicer if we had some sort of batching range, allowing us to write something along the lines of:

    for indices in batchedIndices(n, bs) {
        let xb = xTrain[indices]
        let yb = yTrain[indices]
        // ...
    }

Anyone interested in a project?
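For reference, a minimal sketch of such a helper, built on the standard library's stride(from:to:by:) (the name and signature here are just the ones proposed above, not anything that exists yet):

```swift
// Sketch: produce half-open ranges covering 0..<n in chunks of batchSize,
// with a possibly shorter final batch.
func batchedIndices(_ n: Int, _ batchSize: Int) -> [Range<Int>] {
    return stride(from: 0, to: n, by: batchSize).map { start in
        start ..< min(start + batchSize, n)
    }
}

// e.g. batchedIndices(10, 4) yields [0..<4, 4..<8, 8..<10]
```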

I’ll have a look. I’ve been looking at the data flow API, thinking about how to make it more functional and almost certainly less generic, and this is the same sort of thing.

Awesome, thanks Alexis!

Note that this was just part of the progressive refactoring, as in the Python version, so it’s doing by hand what Dataset.batched does a bit later.

As @sgugger mentioned, often the dupe code is intentional, so we can show how it’s refactored later. However, if we don’t refactor it later in the notebook, then that’s a mistake! :slight_smile:

Okay, here’s something that does the trick. https://gist.github.com/algal/0fe7de7a083cdac102161d1503aab44f

I suspect this could be made much shorter just by building it from a closure-based AnySequence.

But it seems like there might be a didactic purpose behind not simplifying some of the loops, so I’ll just leave it as is for now. (Unless I can’t stop myself and start tweaking it tonight. :stuck_out_tongue_winking_eye: )

The gist above defines a batchedIndices(n:Int,batchSize:Int), which works like you suggested, as well as a generic batchedIndices&lt;C:Collection&gt;(coll:C,batchSize:Int), which might be safer to use in case someone wanted to generate batches over a collection whose startIndex is not zero (which can happen with ArraySlice, if I recall correctly).
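To illustrate the non-zero-startIndex point, here is a standalone example (not from the gist):

```swift
let xs = Array(0..<10)
let slice = xs[4...]        // ArraySlice shares its parent's indices
// slice.startIndex == 4, not 0, so code that assumes zero-based
// indices would read the wrong elements or trap.
print(slice.startIndex)     // 4
print(slice[4])             // 4 — subscripting uses the parent's positions
```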

Feel free to show a tweaked version as well so we can compare :slight_smile:

Looks nice :slight_smile: Does __consuming do anything at the moment?

No, it’s not strictly necessary right now.

It’s an annotation to indicate that once you call makeIterator() on an object, then the object is invalid for producing another iterator. This is the usual contract for Sequence so I stuck it in out of habit and following standard library patterns.

But now that I think of it, it’s not actually true in this case, since RangeStrideThrough has no mutable state.

So in short, it shouldn’t be there. It has no effect now, but it might impose a performance cost in some future version of Swift that tries to be more clever about move operations.
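For the curious, here is a minimal sketch of where the annotation goes (the BatchedRanges type is hypothetical; __consuming compiles today but, as noted above, has no effect yet):

```swift
struct BatchedRanges: Sequence {
    let count: Int
    let batchSize: Int

    // __consuming documents that making the iterator may consume the
    // sequence value. It is harmless here, since this struct is immutable.
    __consuming func makeIterator() -> AnyIterator<Range<Int>> {
        var start = 0
        return AnyIterator {
            guard start < self.count else { return nil }
            let end = min(start + self.batchSize, self.count)
            defer { start = end }
            return start ..< end
        }
    }
}
```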

Okay, Jeremy. Now you’ve done it. You’ve provoked my worst impulses. :wink:

This does the same thing but is perhaps not a joy to read:

    func batchedIndices<C:Collection>(_ coll: C, _ batchSize: Int) -> AnySequence<Range<Int>>
      where C.Index == Int
    {
      var startIndex = coll.startIndex

      return AnySequence.init {
        () -> AnyIterator<Range<Int>> in
        return AnyIterator.init {
          () -> Range<Int>? in
          let remaining = coll.endIndex - startIndex
          guard remaining > 0 else { return nil }
          let thisBatchSize = min(batchSize, remaining)
          let thisEndIndex = startIndex.advanced(by: thisBatchSize)
          defer { startIndex = thisEndIndex }
          return startIndex ..< thisEndIndex
        }
      }
    }

Very nice Alexis, thank you! Would it be any cleaner to make use of this function from the standard library?

    func sequence<T>(first: T, next: @escaping (T) -> T?)
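For anyone unfamiliar with it, sequence(first:next:) unfolds a sequence from a seed, stopping when next returns nil:

```swift
// Doubles the seed until it passes 100, then stops.
let powers = sequence(first: 1) { $0 < 100 ? $0 * 2 : nil }
print(Array(powers))   // [1, 2, 4, 8, 16, 32, 64, 128]
```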

-Chris

Interesting. I didn’t know about that API! That’s not quite what’s needed, but there’s another one, func sequence&lt;T, State&gt;(state: State, next: @escaping (inout State) -> T?), which does the trick. That gives you the following:

    private func batchedRanges(startIndex: Int, endIndex: Int, batchSize: Int) -> UnfoldSequence<Range<Int>, Int>
    {
      return sequence(state: startIndex) { (batchStartIndex) -> Range<Int>? in
        let remaining = endIndex - batchStartIndex
        guard remaining > 0 else { return nil }
        let currentBatchSize = min(batchSize, remaining)
        let batchEndIndex = batchStartIndex.advanced(by: currentBatchSize)
        defer { batchStartIndex = batchEndIndex }
        return batchStartIndex ..< batchEndIndex
      }
    }

This is better. The mutable state is called out in the API, rather than hiding in a closed-over local function variable.
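As a standalone check of sequence(state:next:), here is the same batching logic inlined for a fixed 0..&lt;10 with batch size 4:

```swift
// Unfold batch ranges over 0..<10 with batch size 4.
let ranges = sequence(state: 0) { (start: inout Int) -> Range<Int>? in
    guard start < 10 else { return nil }
    let end = min(start + 4, 10)
    defer { start = end }
    return start ..< end
}
print(Array(ranges))   // [0..<4, 4..<8, 8..<10]
```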


Nice! If you’re curious, they were introduced in Swift 3.

-Chris

Cool. Not sure how I missed it.

I also like this SE proposal because it is another piece of evidence for my semi-facetious thesis that Swift has every functional-programming language feature that people no longer think of as functional programming, either because it has become mainstream or because it has been renamed.

To be fair, the name is entirely undiscoverable… :frowning: