An update on v2 audio: it is mostly functional; we just need some pieces of the high-level API put on top, and to fix a few places where we feel we could be doing things better/faster/more v2-like but aren't sure exactly how to get there. I'm going to make a series of posts about each issue, so if you have ideas or questions, or want to participate in the discussion (even if you know nothing about audio), please jump in. Any feedback on making the posts more useful/readable is also appreciated.
Issue 1: Getting transforms to use TypeDispatch, RandTransform, and be GPU-friendly
We have a lot of transforms for AudioSignals and AudioSpectrograms that were originally written as simple independent functions. Many transforms for signals are the same as those for spectrograms, just with an extra dimension. For instance, we may want to cut out a section of the data: for AudioSignals this is "cutout", for AudioSpectrograms it is "time masking", but they're the same idea applied to 2D (channels x samples) and 3D (channels x height x width) tensors respectively. Given the functionality of v2, we would like to refactor many of these to use RandTransform and TypeDispatch, and to be easily transferable to GPU.
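For anyone unfamiliar with the TypeDispatch piece: the idea is close to Python's stdlib functools.singledispatch, where the implementation is selected by the argument's type, so one transform name can cover both data types. A minimal stdlib analogue of the idea (the classes and strings here are hypothetical stand-ins, not the real audio types):

```python
from functools import singledispatch

class FakeSignal: pass        # hypothetical stand-in for an AudioSignal (2D tensor wrapper)
class FakeSpectrogram: pass   # hypothetical stand-in for an AudioSpectrogram (3D tensor wrapper)

@singledispatch
def cut_section(x):
    raise TypeError(f"no implementation for {type(x)}")

@cut_section.register
def _(x: FakeSignal):
    # 2D path: cut a span of samples
    return "cutout on (channels x samples)"

@cut_section.register
def _(x: FakeSpectrogram):
    # 3D path: mask a span of time columns
    return "time-mask on (channels x height x width)"
```

fastai's TypeDispatch does this by letting a Transform define multiple `encodes` methods with different type annotations, which is what the code below attempts.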
I am having trouble figuring out exactly how to design a transform that works on signals and spectrograms, individual items and batches, implements RandTransform, and is fast. If I get one, I should be able to copy the pattern to the others. I tried looking in 09_data_augment but it wasn’t that clear to me how to make it work for my code.
Here is my best attempt so far. It's for shifting the data horizontally (roll adds wraparound). It works properly, but is much slower than my original transforms: 557µs for a single item and 2.35ms for a batch of 32 signals (CPU), compared to 54µs for a single item previously.
- Is it more or less correctly implemented, and how can I improve it with respect to…
- TypeDispatch - Making it work for both signal/spectrogram simultaneously?
- GPU - Making it work on both batches and individual items?
- RandTransform - Am I using it properly here?
- Any ideas on how to make it faster? Code of the original transform included at bottom.
Code:

```python
class SignalShifter(RandTransform):
    def __init__(self, p=0.5, max_pct=0.2, max_time=None, direction=0, roll=False):
        if direction not in [-1, 0, 1]:
            raise ValueError("Direction must be -1(left) 0(bidirectional) or 1(right)")
        store_attr(self, "max_pct,max_time,direction,roll")
        super().__init__(p=p, as_item=True)

    def before_call(self, b, split_idx):
        super().before_call(b, split_idx)
        self.shift_factor = random.uniform(-1, 1)
        if self.direction != 0: self.shift_factor = self.direction*abs(self.shift_factor)

    def encodes(self, ai:AudioItem):
        if self.max_time is None: s = self.shift_factor*self.max_pct*ai.nsamples
        else:                     s = self.shift_factor*self.max_time*ai.sr
        ai.sig[:] = shift_signal(ai.sig, int(s), self.roll)
        return ai

    def encodes(self, sg:AudioSpectrogram):
        if self.max_time is None: s = self.shift_factor*self.max_pct*sg.width
        else:                     s = self.shift_factor*self.max_time*sg.sr
        return shift_signal(sg, int(s), self.roll)

def _shift(sig, s):
    samples = sig.shape[-1]
    if   s == 0: return torch.clone(sig)
    elif s < 0:  return torch.cat([sig[...,-1*s:], torch.zeros_like(sig)[...,s:]], dim=-1)
    else:        return torch.cat([torch.zeros_like(sig)[...,:s], sig[...,:samples-s]], dim=-1)

def shift_signal(t:torch.Tensor, shift, roll):
    # TODO: refactor 2nd half of this statement to just take and roll the final axis
    if roll: t[:] = torch.from_numpy(np.roll(t.numpy(), shift, axis=-1))
    else:    t[:] = _shift(t[:], shift)
    return t
```
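One idea for the speed question (a sketch only, not benchmarked against the numbers above): `torch.roll` handles the wraparound case natively on CPU or GPU, which avoids the numpy round-trip, and the zero-fill case can write slices into a preallocated tensor instead of concatenating. Because it only indexes the last axis with `...`, the same function runs on items, spectrograms, or batches:

```python
import torch

def shift_fast(t, s, roll=False):
    """Shift the last axis by s; wraparound if roll, else zero-fill. Sketch only."""
    if roll: return torch.roll(t, s, dims=-1)     # native wraparound, works on GPU
    if s == 0: return t.clone()
    out = torch.zeros_like(t)
    if s > 0: out[..., s:] = t[..., :-s]          # shift right, zeros enter on the left
    else:     out[..., :s] = t[..., -s:]          # shift left, zeros enter on the right
    return out
```

Since nothing here touches numpy or assumes a fixed number of dimensions, moving it to GPU should just be a matter of the input tensor's device.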
Here's the original code that works only on a signal:

```python
def _shift(sig, s):
    channels, samples = sig.shape[-2:]
    if   s == 0: return torch.clone(sig)
    elif s < 0:  return torch.cat([sig[...,-1*s:], torch.zeros_like(sig)[...,s:]], dim=-1)
    else:        return torch.cat([torch.zeros_like(sig)[...,:s], sig[...,:samples-s]], dim=-1)

#export
def ShiftSignal(max_pct=0.2, max_time=None, roll=False):
    def _inner(ai: AudioItem)->AudioItem:
        s = int(random.uniform(-1, 1)*max_pct*ai.nsamples if max_time is None else random.uniform(-1, 1)*max_time*ai.sr)
        sig = torch.from_numpy(np.roll(ai.sig.numpy(), s, axis=1)) if roll else _shift(ai.sig, s)
        return AudioItem((sig, ai.sr, ai.path))
    return _inner
```