Source code for brian2hears.filtering.filterbank

from builtins import all, sum, range

try:
    import weave
except ImportError:
    try:
        from scipy import weave
    except ImportError:
        weave = None
import numpy as np

from brian2hears.bufferable import Bufferable

__all__ = ['Filterbank',
           'RestructureFilterbank',
                'Repeat', 'Tile', 'Join', 'Interleave',
           'FunctionFilterbank',
           'SumFilterbank',
           'DoNothingFilterbank',
           'ControlFilterbank',
           'CombinedFilterbank',
           ]

class Filterbank(Bufferable):
    '''
    Generalised filterbank object

    **Documentation common to all filterbanks**

    Filterbanks all share a few basic attributes:

    .. autoattribute:: source

    .. attribute:: nchannels

        The number of channels.

    .. attribute:: samplerate

        The sample rate.

    .. autoattribute:: duration

    To process the output of a filterbank, the following method can be used:

    .. automethod:: process

    Alternatively, the buffer interface can be used, which is described in
    more detail below.

    Filterbank also defines arithmetical operations for +, -, ``*``, / where
    the other operand can be a filterbank or scalar.

    **Details on the class**

    This class is a base class not designed to be instantiated. A Filterbank
    object should define the interface of :class:`Bufferable`, as well as
    defining a ``source`` attribute. This is normally a :class:`Bufferable`
    object, but could be an iterable of sources (for example, for filterbanks
    that mix or add multiple inputs).

    The ``buffer_fetch_next(samples)`` method has a default implementation
    that fetches the next input, and calls the ``buffer_apply(input)``
    method on it, which can be overridden by a derived class. This is typically
    the easiest way to implement a new filterbank. Filterbanks with multiple
    sources will need to override this default implementation.

    There is a default ``__init__`` method that can be called by a derived class
    that sets the ``source``, ``nchannels`` and ``samplerate`` from that of the
    ``source`` object. For multiple sources, the default implementation will
    check that each source has the same number of channels and samplerate and
    will raise an error if not.

    There is a default ``buffer_init()`` method that calls ``buffer_init()`` on
    the ``source`` (or list of sources).

    **Example of deriving a class**

    The following class takes N input channels and sums them to a single output
    channel::

        class AccumulateFilterbank(Filterbank):
            def __init__(self, source):
                Filterbank.__init__(self, source)
                self.nchannels = 1
            def buffer_apply(self, input):
                return reshape(sum(input, axis=1), (input.shape[0], 1))

    Note that the default ``Filterbank.__init__`` will set the number of
    channels equal to the number of source channels, but we want to change it
    to have a single output channel. We use the ``buffer_apply`` method which
    automatically handles the efficient caching of the buffer for us. The
    method receives the array ``input`` which has shape ``(bufsize, nchannels)``
    and sums over the channels (``axis=1``). It's important to reshape the
    output so that it has shape ``(bufsize, outputnchannels)`` so that it can
    be used as the input to subsequent filterbanks.
    '''
    def __init__(self, source):
        if isinstance(source, Bufferable):
            self.source = source
            self.nchannels = source.nchannels
            self.samplerate = source.samplerate
        else:
            self.nchannels = source[0].nchannels
            self.samplerate = source[0].samplerate
            for s in source:
                if s.nchannels!=self.nchannels:
                    raise ValueError('All sources must have the same number of channels.')
                if int(s.samplerate)!=int(self.samplerate):
                    raise ValueError('All sources must have the same samplerate.')
            self.source = source

    def change_source(self, source):
        if not hasattr(self, '_source') or self._source is None:
            self._source = source
            return
        if isinstance(source, tuple):
            for s in source:
                if int(s.samplerate)!=int(self.samplerate):
                    raise ValueError('source samplerate is wrong.')
            for news, olds in zip(source, self._source):
                if news.nchannels!=olds.nchannels:
                    raise ValueError('New sources have different numbers of channels to old sources.')
            self._source = source
            return
        if source.nchannels==self.nchannels:
            self._source = source
            return
        if source.nchannels==1:
            self._source = Repeat(source, self.nchannels)
        else:
            raise ValueError('New source must have the same number of channels as old source.')

    source = property(fget=lambda self:self._source,
                      fset=lambda self, source:self.change_source(source),
                      doc='''
        The source of the filterbank, a :class:`Bufferable` object,
        e.g. another :class:`Filterbank` or a :class:`Sound`. It can also be a
        tuple of sources. Can be changed after the object is created, although
        note that for some filterbanks this may cause problems if they
        make assumptions about the input based on the first source object they
        were passed. If this is causing problems, you can insert a dummy
        filterbank (:class:`DoNothingFilterbank`) which is guaranteed to work
        if you change the source.
        ''')

    def get_duration(self):
        if hasattr(self, '_duration'):
            return self._duration
        else:
            source = self.source
            if isinstance(source, Bufferable):
                source = [source]
            try:
                durations = [s.duration for s in source]
                duration = max(durations)
                return duration
            except KeyError:
                raise KeyError('Cannot compute duration from sources.')

    def set_duration(self, duration):
        self._duration = duration

    duration = property(fget=get_duration,
                        fset=set_duration,
                        doc='''
        The duration of the filterbank. If it is not specified by the user, it
        is computed by finding the maximum of its source durations. If these
        are not specified a :class:`KeyError` will be raised.
        ''')

    def process(self, func=None, duration=None, buffersize=32):
        '''
        Returns the output of the filterbank for the given duration.

        ``func``
            If a function is specified, it should be a function of one or two
            arguments that will be called on each filtered buffered segment
            (of shape ``(buffersize, nchannels)``) in order. If the function
            has one argument, the argument is the buffered segment. If it has
            two arguments, the second argument is the value returned by the
            previous application of the function (or 0 for the first
            application). In this case, the method will return the final value
            returned by the function. See example below.
        ``duration=None``
            The length of time (in seconds) or number of samples to process.
            If no ``func`` is specified, the method will return an array of
            shape ``(duration, nchannels)`` with the filtered outputs. Note
            that in many cases, this will be too large to fit in memory, in
            which case you will want to process the filtered outputs online,
            by providing a function ``func`` (see example below). If no
            duration is specified, the maximum duration of the inputs to the
            filterbank will be used, or an error raised if they do not have
            durations.
        ``buffersize=32``
            The size of the buffered segments to fetch, as a length of time or
            number of samples. 32 samples typically gives reasonably good
            performance.

        For example, to compute the RMS of each channel in a filterbank, you
        would do::

            def sum_of_squares(input, running_sum_of_squares):
                return running_sum_of_squares+sum(input**2, axis=0)
            rms = sqrt(fb.process(sum_of_squares)/nsamples)
        '''
        if duration is None:
            duration = self.duration
        if not isinstance(duration, int):
            duration = int(duration*self.samplerate)
        if not isinstance(buffersize, int):
            buffersize = int(buffersize*self.samplerate)
        self.buffer_init()
        endpoints = np.hstack((np.arange(0, duration, buffersize), duration))
        zendpoints = zip(endpoints[:-1], endpoints[1:])
        #sizes = diff(endpoints)
        if func is None:
            return np.vstack(tuple(self.buffer_fetch(start, end) for start, end in zendpoints))
        else:
            if func.__code__.co_argcount==1:
                for start, end in zendpoints:
                    func(self.buffer_fetch(start, end))
            else:
                runningval = 0
                for start, end in zendpoints:
                    runningval = func(self.buffer_fetch(start, end), runningval)
                return runningval

    def buffer_init(self):
        Bufferable.buffer_init(self)
        if isinstance(self.source, Bufferable):
            self.source.buffer_init()
        else:
            for s in self.source:
                s.buffer_init()
        self.next_sample = 0

    def buffer_apply(self, input):
        raise NotImplementedError

    def buffer_fetch_next(self, samples):
        start = self.next_sample
        self.next_sample += samples
        end = start+samples
        input = self.source.buffer_fetch(start, end)
        return self.buffer_apply(input)

    def __add__ (self, other):
        if isinstance(other, Bufferable):
            return SumFilterbank((self, other))
        else:
            func = lambda x: other+x
            return FunctionFilterbank(self, func)
    __radd__ = __add__

    def __sub__ (self, other):
        if isinstance(other, Bufferable):
            return SumFilterbank((self, other), (1, -1))
        else:
            func = lambda x: x-other
            return FunctionFilterbank(self, func)

    def __rsub__ (self, other):
        # Note that __rsub__ should return other-self
        if isinstance(other, Bufferable):
            return SumFilterbank((self, other), (-1, 1))
        else:
            func = lambda x: other-x
            return FunctionFilterbank(self, func)

    def __mul__(self, other):
        if isinstance(other, Bufferable):
            func = lambda x, y: x*y
            return FunctionFilterbank((self, other), func)
        else:
            func = lambda x: x*other
            return FunctionFilterbank(self, func)
    __rmul__ = __mul__

    def __div__(self, other):
        if isinstance(other, Bufferable):
            func = lambda x, y: x/y
            return FunctionFilterbank((self, other), func)
        else:
            func = lambda x: x/other
            return FunctionFilterbank(self, func)

    def __rdiv__(self, other):
        # Note __rdiv__ returns other/self
        if isinstance(other, Bufferable):
            func = lambda x, y: x/y
            return FunctionFilterbank((other, self), func)
        else:
            func = lambda x: other/x
            return FunctionFilterbank(self, func)

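The chunking that ``process`` performs can be tried out on a plain NumPy array, without any filterbank. The following is a minimal sketch, not part of brian2hears: ``chunk_endpoints`` and ``process_in_chunks`` are hypothetical helper names, and the array ``signal`` stands in for the output that ``buffer_fetch`` would return. It splits the sample range into segments of ``buffersize`` samples (the last one may be shorter) and folds a two-argument function over them, as in the RMS example from the docstring.

```python
import numpy as np

def chunk_endpoints(nsamples, buffersize):
    """Return (start, end) sample pairs covering [0, nsamples)."""
    # Same construction as in process(): regular starts plus the final end.
    endpoints = np.hstack((np.arange(0, nsamples, buffersize), nsamples))
    return [(int(a), int(b)) for a, b in zip(endpoints[:-1], endpoints[1:])]

def process_in_chunks(signal, func, buffersize=32):
    """Fold a two-argument func over successive segments of signal."""
    runningval = 0  # initial value, as used by process()
    for start, end in chunk_endpoints(signal.shape[0], buffersize):
        # Hypothetical stand-in for self.buffer_fetch(start, end)
        runningval = func(signal[start:end], runningval)
    return runningval

def sum_of_squares(segment, running_sum_of_squares):
    return running_sum_of_squares + np.sum(segment**2, axis=0)

rng = np.random.default_rng(0)
signal = rng.standard_normal((100, 3))  # 100 samples, 3 channels
nsamples = signal.shape[0]
rms = np.sqrt(process_in_chunks(signal, sum_of_squares) / nsamples)
```

Because only one segment is held at a time, the running value can be computed for signals that would be too large to return as a single array.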
class RestructureFilterbank(Filterbank):
    '''
    Filterbank used to restructure channels, including repeating and interleaving.

    **Standard forms of usage:**

    Repeat mono source N times::

        RestructureFilterbank(source, N)

    For a stereo source, N copies of the left channel followed by N copies of
    the right channel::

        RestructureFilterbank(source, N)

    For a stereo source, N copies of the channels tiled as LRLRLR...LR::

        RestructureFilterbank(source, numtile=N)

    For two stereo sources AB and CD, join them together in serial to form the
    output channels in order ABCD::

        RestructureFilterbank((AB, CD))

    For two stereo sources AB and CD, join them together interleaved to form
    the output channels in order ACBD::

        RestructureFilterbank((AB, CD), type='interleave')

    These arguments can also be combined together, for example to turn AB and
    CD into output channels AABBCCDDAABBCCDDAABBCCDD::

        RestructureFilterbank((AB, CD), 2, 'serial', 3)

    The three arguments are the number of repeats before joining, the joining
    type ('serial' or 'interleave') and the number of tilings after joining.
    See below for details.

    **Initialise arguments:**

    ``source``
        Input source or list of sources.
    ``numrepeat=1``
        Number of times each channel in each of the input sources is repeated
        before mixing the source channels. For example, with ``numrepeat=2``
        an input source with channels ``AB`` will be repeated to form ``AABB``.
    ``type='serial'``
        The method for joining the source channels, the options are
        ``'serial'`` to join the channels in series, or ``'interleave'`` to
        interleave them. In the case of ``'interleave'``, each source must
        have the same number of channels. An example of serial, if the input
        sources are ``abc`` and ``def`` the output would be ``abcdef``. For
        interleave, the output would be ``adbecf``.
    ``numtile=1``
        The number of times the joined channels are tiled, so if the joined
        channels are ``ABC`` and ``numtile=3`` the output will be
        ``ABCABCABC``.
    ``indexmapping=None``
        Instead of specifying the restructuring via
        ``numrepeat, type, numtile`` you can directly give the mapping of
        input indices to output indices. So for a single stereo source input,
        ``indexmapping=[1,0]`` would reverse left and right. Similarly, with
        two mono sources, ``indexmapping=[1,0]`` would have channel 0 of the
        output correspond to source 1 and channel 1 of the output
        corresponding to source 0. This is because the indices are counted in
        order of channels starting from the first source and continuing to
        the last. For example, suppose you had two sources, each consisting of
        a stereo sound, say source 0 was ``AB`` and source 1 was ``CD`` then
        ``indexmapping=[1, 0, 3, 2]`` would swap the left and right of each
        source, but leave the order of the sources the same, i.e. the output
        would be ``BADC``.
    '''
    def __init__(self, source, numrepeat=1, type='serial', numtile=1,
                 indexmapping=None):
        self._has_been_optimised = False
        self._reinit(source, numrepeat, type, numtile, indexmapping)

    def _do_reinit(self):
        self._reinit(*self._original_init_arguments)
        if self._has_been_optimised:
            self._optimisation_target._do_reinit()

    def _reinit(self, source, numrepeat, type, numtile, indexmapping):
        self._original_init_arguments = (source, numrepeat, type, numtile, indexmapping)
        if isinstance(source, Bufferable):
            source = (source,)
        if indexmapping is None:
            nchannels = np.array([s.nchannels for s in source])
            idx = np.hstack(([0], np.cumsum(nchannels)))
            I = [np.arange(start, stop) for start, stop in zip(idx[:-1], idx[1:])]
            I = tuple(np.repeat(i, numrepeat) for i in I)
            if type=='serial':
                indexmapping = np.hstack(I)
            elif type=='interleave':
                if len(np.unique(nchannels))!=1:
                    raise ValueError('For interleaving, all inputs must have an equal number of channels.')
                I0 = len(I[0])
                indexmapping = np.zeros(I0*len(I), dtype=int)
                for j, i in enumerate(I):
                    indexmapping[j::len(I)] = i
            else:
                raise ValueError('Type must be "serial" or "interleave"')
            indexmapping = np.tile(indexmapping, numtile)
        if not isinstance(indexmapping, np.ndarray):
            indexmapping = np.array(indexmapping, dtype=int)
        # optimisation to reduce multiple RestructureFilterbanks into a single
        # one, by collating the sources and reconstructing the indexmapping
        # from the individual indexmappings
        if all(isinstance(s, RestructureFilterbank) for s in source):
            newsource = ()
            newsourcesizes = ()
            for s in source:
                s._has_been_optimised = True
                s._optimisation_target = self
                newsource += s.source
                inputsourcesize = sum(inpsource.nchannels for inpsource in s.source)
                newsourcesizes += (inputsourcesize,)
            newsourcesizes = np.array(newsourcesizes)
            newsourceoffsets = np.hstack((0, np.cumsum(newsourcesizes)))
            new_indexmapping = np.zeros_like(indexmapping)
            sourcesizes = np.array(tuple(s.nchannels for s in source))
            sourceoffsets = np.hstack((0, np.cumsum(sourcesizes)))
            # gives the index of the source of each element of indexmapping
            sourceindices = np.digitize(indexmapping, np.cumsum(sourcesizes))
            for i in range(len(indexmapping)):
                source_index = sourceindices[i]
                s = source[source_index]
                relative_index = indexmapping[i]-sourceoffsets[source_index]
                source_relative_index = s.indexmapping[relative_index]
                new_index = source_relative_index+newsourceoffsets[source_index]
                new_indexmapping[i] = new_index
            source = newsource
            indexmapping = new_indexmapping

        self.indexmapping = indexmapping
        self.nchannels = len(indexmapping)
        self.samplerate = source[0].samplerate
        for s in source:
            if int(s.samplerate)!=int(self.samplerate):
                raise ValueError('All sources must have the same samplerate.')
        self._source = source

    def buffer_fetch_next(self, samples):
        start = self.next_sample
        self.next_sample += samples
        end = start+samples
        inputs = tuple(s.buffer_fetch(start, end) for s in self.source)
        input = np.hstack(inputs)
        input = input[:, self.indexmapping]
        return input

    def change_source(self, source):
        if not hasattr(self, '_source') or self._source is None:
            self._source = source
            return
        oldsource, numrepeat, type, numtile, indexmapping = self._original_init_arguments
        self._original_init_arguments = source, numrepeat, type, numtile, indexmapping
        self._do_reinit()
#        self._reinit(source, numrepeat, type, numtile, indexmapping)
#        if self._has_been_optimised:
#            target = self._optimisation_target
#            target._reinit(*target._original_init_arguments)

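The way ``numrepeat``, ``type`` and ``numtile`` turn into an index mapping can be checked in isolation. The sketch below is an illustration only, assuming nothing beyond NumPy: ``build_indexmapping`` is a hypothetical standalone function (its ``join`` parameter plays the role of ``type``), and it takes the channel count of each source rather than source objects. It mirrors the repeat, join, tile order used in ``_reinit`` and labels the resulting channels so the docstring examples can be compared directly.

```python
import numpy as np

def build_indexmapping(nchannels, numrepeat=1, join='serial', numtile=1):
    """Map output channels to input channels (inputs counted across sources)."""
    nchannels = np.asarray(nchannels)
    idx = np.hstack(([0], np.cumsum(nchannels)))
    # One block of (repeated) input indices per source
    blocks = [np.repeat(np.arange(start, stop), numrepeat)
              for start, stop in zip(idx[:-1], idx[1:])]
    if join == 'serial':
        mapping = np.hstack(blocks)
    elif join == 'interleave':
        if len(np.unique(nchannels)) != 1:
            raise ValueError('Interleaving needs equal channel counts.')
        mapping = np.zeros(len(blocks[0]) * len(blocks), dtype=int)
        for j, block in enumerate(blocks):
            mapping[j::len(blocks)] = block
    else:
        raise ValueError('join must be "serial" or "interleave"')
    return np.tile(mapping, numtile)

labels = np.array(list('ABCD'))  # source 0 is AB, source 1 is CD
serial = ''.join(labels[build_indexmapping([2, 2])])
interleaved = ''.join(labels[build_indexmapping([2, 2], join='interleave')])
combined = ''.join(labels[build_indexmapping([2, 2], 2, 'serial', 3)])
```

Applying such a mapping to a buffered segment is then a single fancy-indexing step along the channel axis, as in ``input[:, self.indexmapping]``.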
class Repeat(RestructureFilterbank):
    '''
    Filterbank that repeats each channel from its input, e.g. with 3 repeats
    channels ABC would map to AAABBBCCC.
    '''
    def __init__(self, source, numrepeat):
        RestructureFilterbank.__init__(self, source, numrepeat)


class Tile(RestructureFilterbank):
    '''
    Filterbank that tiles the channels from its input, e.g. with 3 tiles
    channels ABC would map to ABCABCABC.
    '''
    def __init__(self, source, numtile):
        RestructureFilterbank.__init__(self, source, numtile=numtile)


class Join(RestructureFilterbank):
    '''
    Filterbank that joins the channels of its inputs in series, e.g. with two
    input sources with channels AB and CD respectively, the output would have
    channels ABCD. You can initialise with multiple sources separated by
    commas, or by passing a list of sources.
    '''
    def __init__(self, *sources):
        source = []
        for s in sources:
            if isinstance(s, Bufferable):
                source.append(s)
            else:
                source.extend(s)
        RestructureFilterbank.__init__(self, tuple(source), type='serial')


class Interleave(RestructureFilterbank):
    '''
    Filterbank that interleaves the channels of its inputs, e.g. with two
    input sources with channels AB and CD respectively, the output would have
    channels ACBD. You can initialise with multiple sources separated by
    commas, or by passing a list of sources.
    '''
    def __init__(self, *sources):
        source = []
        for s in sources:
            if isinstance(s, Bufferable):
                source.append(s)
            else:
                source.extend(s)
        RestructureFilterbank.__init__(self, tuple(source), type='interleave')

class FunctionFilterbank(Filterbank):
    '''
    Filterbank that just applies a given function. The function should take
    as many arguments as there are sources.

    For example, to half-wave rectify inputs::

        FunctionFilterbank(source, lambda x: clip(x, 0, Inf))

    The syntax ``lambda x: clip(x, 0, Inf)`` defines a function object that
    takes a single argument ``x`` and returns ``clip(x, 0, Inf)``. The numpy
    function ``clip(x, low, high)`` returns the values of ``x`` clipped between
    ``low`` and ``high`` (so if ``x<low`` it returns ``low``, if ``x>high`` it
    returns ``high``, otherwise it returns ``x``). The symbol ``Inf`` means
    infinity, i.e. no clipping of positive values.

    **Technical details**

    Note that functions should operate on arrays, in particular on 2D buffered
    segments, which are arrays of shape ``(bufsize, nchannels)``. Typically,
    most standard functions from numpy will work element-wise.

    If you want a filterbank that changes the shape of the input (e.g. changes
    the number of channels), set the ``nchannels`` keyword argument to the
    number of output channels.
    '''
    def __init__(self, source, func, nchannels=None, **params):
        if isinstance(source, Bufferable):
            source = (source,)
        Filterbank.__init__(self, source)
        self.func = func
        if nchannels is not None:
            self.nchannels = nchannels
        self.params = params

    def buffer_fetch_next(self, samples):
        start = self.cached_buffer_end
        end = start+samples
        inputs = tuple(s.buffer_fetch(start, end) for s in self.source)
        # print inputs,self.params
        return self.func(*inputs, **self.params)

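The kinds of function that suit ``FunctionFilterbank`` can be exercised directly on a small buffered segment. This is a rough sketch using NumPy only; the helper names ``half_wave_rectify``, ``weighted_sum`` and ``channel_sum`` are made up for illustration, and ``np.inf`` is used in place of the ``Inf`` name from the docstring. It shows an element-wise function of one input, a function of two inputs, and a function that changes the number of channels (for which ``nchannels`` would need to be set).

```python
import numpy as np

def half_wave_rectify(x):
    # Element-wise: negative values become 0, positive values are untouched.
    return np.clip(x, 0, np.inf)

def weighted_sum(*inputs, weights):
    # One argument per source, combined with one weight per source.
    return sum(x * w for x, w in zip(inputs, weights))

def channel_sum(x):
    # Changes the shape from (bufsize, nchannels) to (bufsize, 1).
    return np.sum(x, axis=1).reshape(x.shape[0], 1)

# A buffered segment of 3 samples and 2 channels
segment = np.array([[-1.0, 2.0],
                    [3.0, -4.0],
                    [0.5, -0.5]])

rectified = half_wave_rectify(segment)
difference = weighted_sum(segment, rectified, weights=(1, -1))
mono = channel_sum(segment)
```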
class SumFilterbank(FunctionFilterbank):
    '''
    Sum filterbanks together with given weight vectors.

    For example, to take the sum of two filterbanks::

        SumFilterbank((fb1, fb2))

    To take the difference::

        SumFilterbank((fb1, fb2), (1, -1))
    '''
    def __init__(self, source, weights=None):
        if weights is None:
            weights = np.ones(len(source))
        self.weights = weights
        func = lambda *inputs: sum(input*w for input, w in zip(inputs, weights))
        FunctionFilterbank.__init__(self, source, func)


class DoNothingFilterbank(Filterbank):
    '''
    Filterbank that does nothing to its input.

    Useful for removing a set of filters without having to rewrite your code.
    Can also be used for simply writing compound derived classes. For example,
    if you want a compound Filterbank that does AFilterbank and then
    BFilterbank, but you want to encapsulate that into a single class, you
    could do::

        class ABFilterbank(DoNothingFilterbank):
            def __init__(self, source):
                a = AFilterbank(source)
                b = BFilterbank(a)
                DoNothingFilterbank.__init__(self, b)

    However, a more general way of writing compound filterbanks is to use
    :class:`CombinedFilterbank`.
    '''
    def buffer_apply(self, input):
        return input

class ControlFilterbank(Filterbank):
    '''
    Filterbank that can be used for controlling behaviour at runtime.

    Typically, this class is used to implement a control path in an auditory
    model, modifying some filterbank parameters based on the output of other
    filterbanks (or the same ones).

    The controller has a set of input filterbanks whose output values are used
    to modify a set of output filterbanks. The update is done by a user
    specified function or class which is passed these output values. The
    controller should be inserted as the last bank in a chain.

    Initialisation arguments:

    ``source``
        The source filterbank, the values from this are used unmodified as the
        output of this filterbank.
    ``inputs``
        Either a single filterbank, or sequence of filterbanks which are used
        as inputs to the ``updater``.
    ``targets``
        The filterbank or sequence of filterbanks that are modified by the
        updater.
    ``updater``
        The function or class which does the updating, see below.
    ``max_interval``
        If specified, ensures that the updater is called at least as often as
        this interval (but it may be called more often). Can be specified as a
        time or a number of samples.

    **The updater**

    The ``updater`` argument can be either a function or class instance. If it
    is a function, it should have a form like::

        # A single input
        def updater(input):
            ...

        # Two inputs
        def updater(input1, input2):
            ...

        # Arbitrary number of inputs
        def updater(*inputs):
            ...

    Each argument ``input`` to the function is a numpy array of shape
    ``(numsamples, numchannels)`` where ``numsamples`` is the number of samples
    just computed, and ``numchannels`` is the number of channels in the
    corresponding filterbank. The function is not restricted in what it can do
    with these inputs.

    Functions can be used to implement relatively simple controllers, but for
    more complicated situations you may want to maintain some state variables
    for example, and in this case you can use a class. The object ``updater``
    should be an instance of a class that defines the ``__call__`` method (with
    the same syntax as above for functions). In addition, you can define a
    reinitialisation method ``reinit()`` which will be called when the
    ``buffer_init()`` method is called on the filterbank, although this is
    entirely optional.

    **Example**

    The following will do a simple form of gain control, where the gain
    parameter will drift exponentially towards target_rms/rms with a given time
    constant::

        # This class implements the gain (see Filterbank for details)
        class GainFilterbank(Filterbank):
            def __init__(self, source, gain=1.0):
                Filterbank.__init__(self, source)
                self.gain = gain
            def buffer_apply(self, input):
                return self.gain*input

        # This is the class for the updater object
        class GainController(object):
            def __init__(self, target, target_rms, time_constant):
                self.target = target
                self.target_rms = target_rms
                self.time_constant = time_constant
            def reinit(self):
                self.sumsquare = 0
                self.numsamples = 0
            def __call__(self, input):
                T = input.shape[0]/self.target.samplerate
                self.sumsquare += sum(input**2)
                self.numsamples += input.size
                rms = sqrt(self.sumsquare/self.numsamples)
                g = self.target.gain
                g_tgt = self.target_rms/rms
                tau = self.time_constant
                self.target.gain = g_tgt+exp(-T/tau)*(g-g_tgt)

    And an example of using this with an input ``source``, a target RMS of 0.2
    and a time constant of 50 ms, updating every 10 ms::

        gain_fb = GainFilterbank(source)
        updater = GainController(gain_fb, 0.2, 50*ms)
        control = ControlFilterbank(gain_fb, source, gain_fb, updater, 10*ms)
    '''
    def __init__(self, source, inputs, targets, updater, max_interval=None):
        Filterbank.__init__(self, source)
        if not isinstance(inputs, (list, tuple)):
            inputs = [inputs]
        if not isinstance(targets, (list, tuple)):
            targets = [targets]
        self.inputs = inputs
        self.updater = updater
        if max_interval is not None:
            if not isinstance(max_interval, int):
                max_interval = int(max_interval*source.samplerate)
            for x in inputs+targets:
                x.maximum_buffer_size = max_interval
            self.maximum_buffer_size = max_interval

    def buffer_init(self):
        Filterbank.buffer_init(self)
        if hasattr(self.updater, 'reinit'):
            self.updater.reinit()

    def buffer_fetch_next(self, samples):
        start = self.next_sample
        self.next_sample += samples
        end = start+samples
        source_input = self.source.buffer_fetch(start, end)
        input_buffers = [x.buffer_fetch(start, end) for x in self.inputs]
        self.updater(*input_buffers)
        return source_input

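The update rule in the gain control example, ``g_tgt+exp(-T/tau)*(g-g_tgt)``, can be looked at on its own. The sketch below is a simplification and not the library's code: ``update_gain`` is a hypothetical function, times are plain floats in seconds rather than Brian quantities, and the target gain is held fixed, whereas in the docstring example it is recomputed from the running RMS on every call. Under those assumptions it shows the gain relaxing exponentially towards the target in 10 ms steps with a 50 ms time constant.

```python
import math

def update_gain(gain, target_gain, elapsed, time_constant):
    """One step of exponential drift of gain towards target_gain."""
    return target_gain + math.exp(-elapsed / time_constant) * (gain - target_gain)

gain = 1.0
target_gain = 0.5        # assumed constant here for illustration
history = [gain]
for _ in range(50):      # fifty updates, 10 ms apart
    gain = update_gain(gain, target_gain, elapsed=0.010, time_constant=0.050)
    history.append(gain)

# Two 10 ms steps give the same result as a single 20 ms step, so the
# trajectory does not depend on how often the updater happens to be called.
two_steps = update_gain(update_gain(1.0, 0.5, 0.010, 0.050), 0.5, 0.010, 0.050)
one_step = update_gain(1.0, 0.5, 0.020, 0.050)
```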
class CombinedFilterbank(Filterbank):
    '''
    Filterbank that encapsulates a chain of filterbanks internally.

    This class should mostly be used by people writing extensions to Brian
    hears rather than by users directly. The purpose is to take an existing
    chain of filterbanks and wrap them up so they appear to the user as a
    single filterbank which can be used exactly as any other filterbank.

    In order to do this, derive from this class and in your initialisation
    follow this pattern::

        class RectifiedGammatone(CombinedFilterbank):
            def __init__(self, source, cf):
                CombinedFilterbank.__init__(self, source)
                source = self.get_modified_source()
                # At this point, insert your chain of filterbanks acting on
                # the modified source object
                gfb = Gammatone(source, cf)
                rectified = FunctionFilterbank(gfb,
                                lambda input: clip(input, 0, Inf))
                # Finally, set the output filterbank to be the last in your chain
                self.set_output(rectified)

    This combination of a :class:`Gammatone` and a rectification via a
    :class:`FunctionFilterbank` can now be used as a single filterbank, for
    example::

        x = whitenoise(100*ms)
        fb = RectifiedGammatone(x, [1*kHz, 1.5*kHz])
        y = fb.process()

    **Details**

    The reason for the ``get_modified_source()`` call is that the source
    attribute of a filterbank can be changed after creation. The modified
    source provides a buffer (in fact, a :class:`DoNothingFilterbank`) so that
    the input to the chain of filters defined by the derived class doesn't need
    to be changed.
    '''
    def __init__(self, source):
        Filterbank.__init__(self, source)

    def get_duration(self):
        if hasattr(self, '_duration'):
            return self._duration
        else:
            return max(Filterbank.get_duration(self), self.output.duration)

    source = property(fget=lambda self:self._source,
                      fset=lambda self, source:self.change_source(source))

    def change_source(self, source):
        Filterbank.change_source(self, source)
        if hasattr(self, '_modified_source'):
            self._modified_source.source = source

    def get_modified_source(self):
        self._modified_source = DoNothingFilterbank(self.source)
        return self._modified_source

    def set_output(self, output):
        self.output = output
        self.nchannels = output.nchannels

    def buffer_init(self):
        Filterbank.buffer_init(self)
        self.output.buffer_init()

    def buffer_fetch(self, start, end):
        return self.output.buffer_fetch(start, end)