# Source code for brian2hears.hrtf.hrtf

from builtins import range
from copy import copy

import numpy as np
from numpy.fft import fft, ifft

from brian2hears.sounds import Sound
from brian2hears.filtering import FIRFilterbank


# Public names of this module: these are what ``from <module> import *``
# exports.
__all__ = ['HRTF', 'HRTFSet', 'HRTFDatabase',
           'make_coordinates']

class HRTF(object):
    '''
    Head related transfer function.

    **Attributes**

    ``impulse_response``
        The pair of impulse responses (as stereo :class:`Sound` objects)
    ``fir``
        The impulse responses in a format suitable for using with
        :class:`FIRFilterbank` (the transpose of ``impulse_response``).
    ``left``, ``right``
        The two HRTFs (mono :class:`Sound` objects)
    ``samplerate``
        The sample rate of the HRTFs.

    **Methods**

    .. automethod:: apply
    .. automethod:: filterbank

    You can get the number of samples in the impulse response with
    ``len(hrtf)``.
    '''

    def __init__(self, hrir_l, hrir_r=None):
        # With a single argument it is taken to be the stereo pair already;
        # with two arguments they are combined into one stereo Sound that
        # inherits the sample rate of the left impulse response.
        if hrir_r is not None:
            pair = Sound((hrir_l, hrir_r), samplerate=hrir_l.samplerate)
        else:
            pair = hrir_l
        self.samplerate = pair.samplerate
        self.impulse_response = pair
        self.left = pair.left
        self.right = pair.right

    def apply(self, sound):
        '''
        Returns a stereo :class:`Sound` object formed by applying the pair of
        HRTFs to the mono ``sound`` input. Equivalently, you can write
        ``hrtf(sound)`` for ``hrtf`` an :class:`HRTF` object.

        Raises ``ValueError`` if ``sound`` is not mono, or if its sample rate
        (compared after conversion to integer) differs from that of the HRTF.
        '''
        # The filtering is done by multiplying FFTs rather than by direct
        # convolution: mathematically equivalent (accurate to 1e-15 in
        # practice) and around 100x faster.
        if sound.nchannels != 1:
            raise ValueError('HRTF can only be applied to mono sounds')
        rates = np.array([self.samplerate, sound.samplerate], dtype=int)
        if len(np.unique(rates)) > 1:
            raise ValueError('HRTF and sound samplerates do not match.')

        signal = np.asarray(sound).flatten()
        ir_left = np.asarray(self.left).flatten()
        ir_right = np.asarray(self.right).flatten()
        n_signal = len(signal)

        # Every array is zero padded to one common length, rounded up to a
        # power of 2 for FFT efficiency, so that the spectra can simply be
        # multiplied. The sound is additionally preceded by ir_len zeros:
        # the first output samples are not guaranteed to be equal because of
        # the delays in the impulse response, but they equalise exactly after
        # the length of the impulse response.
        ir_len = max(len(ir_left), len(ir_right))
        n_fft = max(ir_len, n_signal) + ir_len
        n_fft = 2**int(np.ceil(np.log2(n_fft)))

        signal_padded = np.hstack((np.zeros(ir_len),
                                   signal,
                                   np.zeros(n_fft - ir_len - n_signal)))
        signal_fft = fft(signal_padded, n=n_fft)

        # Only the part of the output aligned with the original (unpadded)
        # sound is kept.
        keep = slice(ir_len, ir_len + n_signal)
        channels = []
        for ir in (ir_left, ir_right):
            ir_padded = np.hstack((ir, np.zeros(n_fft - len(ir))))
            ir_fft = fft(ir_padded, n=n_fft)
            filtered = ifft(ir_fft * signal_fft).real
            channels.append(filtered[keep])

        return Sound(tuple(channels), samplerate=self.samplerate)

    # Allow ``hrtf(sound)`` as shorthand for ``hrtf.apply(sound)``.
    __call__ = apply

    def get_fir(self):
        '''
        Returns a fresh array holding the transpose of ``impulse_response``.
        '''
        return np.array(self.impulse_response.T, copy=True)

    fir = property(fget=get_fir)

    def filterbank(self, source, **kwds):
        '''
        Returns an :class:`FIRFilterbank` object that can be used to apply
        the HRTF as part of a chain of filterbanks.
        '''
        return FIRFilterbank(source, self.fir, **kwds)

    def __len__(self):
        # Number of samples = size of the first axis of the impulse response.
        return self.impulse_response.shape[0]
def make_coordinates(**kwds):
    '''
    Creates a numpy record array from the keywords passed to the function.

    Each keyword/value pair should be the name of the coordinate and the
    array of values of that coordinate for each location. Every field of the
    returned array has dtype ``float``, and its length is the length of the
    first coordinate given.

    Returns a numpy record array. For example::

        coords = make_coordinates(azimuth=[0, 30, 60, 0, 30, 60],
                                  elevation=[0, 0, 0, 30, 30, 30])
        print(coords['azimuth'])

    Raises ``ValueError`` if no coordinates are given.
    '''
    # Without this guard the ``next(iter(...))`` below would leak a bare
    # StopIteration, which says nothing about what the caller did wrong.
    if not kwds:
        raise ValueError('make_coordinates requires at least one coordinate')
    dtype = [(name, float) for name in kwds]
    n = len(next(iter(kwds.values())))
    coords = np.zeros(n, dtype=dtype)
    for name, values in kwds.items():
        coords[name] = values
    return coords
class HRTFSet(object):
    '''
    A collection of HRTFs, typically for a single individual.

    Normally this object is created automatically by an
    :class:`HRTFDatabase`.

    **Attributes**

    ``hrtf``
        A list of ``HRTF`` objects for each index.
    ``num_indices``
        The number of HRTF locations. You can also use ``len(hrtfset)``.
    ``num_samples``
        The sample length of each HRTF.
    ``fir_serial``, ``fir_interleaved``
        The impulse responses in a format suitable for using with
        :class:`FIRFilterbank`, in serial (LLLLL...RRRRR....) or interleaved
        (LRLRLR...).

    **Methods**

    .. automethod:: subset
    .. automethod:: filterbank
    .. automethod:: get_index

    You can access an HRTF by index via ``hrtfset[index]``, or
    by its coordinates via ``hrtfset(coord1=val1, coord2=val2)``.

    **Initialisation**

    ``data``
        An array of shape (2, num_indices, num_samples) where data[0,:,:] is
        the left ear and data[1,:,:] is the right ear, num_indices is the
        number of HRTFs for each ear, and num_samples is the length of the
        HRTF.
    ``samplerate``
        The sample rate for the HRTFs (should have units of Hz).
    ``coordinates``
        A record array of length ``num_indices`` giving the coordinates of
        each HRTF. You can use :func:`make_coordinates` to help with this.
    '''

    def __init__(self, data, samplerate, coordinates):
        self.data = data
        self.samplerate = samplerate
        self.coordinates = coordinates
        # Build one HRTF object per location from the left/right rows.
        self.hrtf = []
        for i in range(self.num_indices):
            l = Sound(self.data[0, i, :], samplerate=self.samplerate)
            r = Sound(self.data[1, i, :], samplerate=self.samplerate)
            self.hrtf.append(HRTF(l, r))

    def __getitem__(self, key):
        return self.hrtf[key]

    def get_index(self, **kwds):
        '''
        Return the index of the HRTF with the coords specified by keyword.

        A coordinate matches when it differs from the requested value by less
        than 1e-10. Raises ``IndexError`` if no HRTF, or more than one HRTF,
        matches.
        '''
        I = np.ones(self.num_indices, dtype=bool)
        for key, value in kwds.items():
            I = np.logical_and(I, abs(self.coordinates[key]-value) < 1e-10)
        indices = I.nonzero()[0]
        if len(indices) == 0:
            raise IndexError('No HRTF exists with those coordinates')
        if len(indices) > 1:
            raise IndexError('More than one HRTF exists with those coordinates')
        return indices[0]

    def __call__(self, **kwds):
        return self.hrtf[self.get_index(**kwds)]

    def subset(self, condition):
        '''
        Generates the subset of the set of HRTFs whose coordinates satisfy
        the ``condition``. This should be one of: a boolean array of length
        the number of HRTFs in the set, with values of True/False to indicate
        if the corresponding HRTF should be included or not; an integer array
        with the indices of the HRTFs to keep; or a function whose argument
        names are names of the parameters of the coordinate system, e.g.
        ``condition=lambda azim:azim<pi/2``.

        Only the *parameters* of a function condition are looked up as
        coordinate names; local variables used inside it are ignored. The
        function is first called once with whole coordinate arrays; if that
        fails it is called once per HRTF with scalar coordinates.

        Returns a shallow copy of this object restricted to the selection.
        '''
        if callable(condition):
            fcode = condition.__code__
            # co_varnames lists the parameters first and then every local
            # variable, so it has to be cut down to the arguments only.
            n_args = fcode.co_argcount + getattr(fcode, 'co_kwonlyargcount', 0)
            fvars = fcode.co_varnames[:n_args]
            ns = dict((name, self.coordinates[name]) for name in fvars)
            I = None
            try:
                I = condition(**ns).nonzero()[0]
            except Exception:
                # Vector-based calculation doesn't work for this condition;
                # fall back to evaluating it for each HRTF below. Errors
                # raised there propagate normally.
                I = None
            if I is None:
                n = len(self.coordinates)
                mask = np.array([condition(**dict((name, ns[name][j])
                                                  for name in fvars))
                                 for j in range(n)])
                I = mask.nonzero()[0]
        else:
            # Accept plain sequences as well as arrays.
            condition = np.asarray(condition)
            if condition.dtype == bool:
                I = condition.nonzero()[0]
            else:
                I = condition
        hrtf = [self.hrtf[i] for i in I]
        coords = self.coordinates[I]
        data = self.data[:, I, :]
        obj = copy(self)
        obj.hrtf = hrtf
        obj.coordinates = coords
        obj.data = data
        return obj

    def __len__(self):
        return self.num_indices

    @property
    def num_indices(self):
        return self.data.shape[1]

    @property
    def num_samples(self):
        return self.data.shape[2]

    @property
    def fir_serial(self):
        # Reshaping (2, num_indices, num_samples) row-major gives all left
        # responses followed by all right responses.
        return np.reshape(self.data, (self.num_indices*2, self.num_samples))

    @property
    def fir_interleaved(self):
        # Even rows hold the left ear, odd rows the right ear.
        fir = np.empty((self.num_indices*2, self.num_samples))
        fir[::2, :] = self.data[0, :, :]
        fir[1::2, :] = self.data[1, :, :]
        return fir

    def filterbank(self, source, interleaved=False, **kwds):
        '''
        Returns an :class:`FIRFilterbank` object which applies all of the
        HRTFs in the set. If ``interleaved=False`` then the channels are
        arranged in the order LLLL...RRRR..., otherwise they are arranged in
        the order LRLRLR....
        '''
        if interleaved:
            fir = self.fir_interleaved
        else:
            fir = self.fir_serial
        return FIRFilterbank(source, fir, **kwds)
class HRTFDatabase(object):
    '''
    Base class for databases of HRTFs

    Should have an attribute 'subjects' giving a list of available subjects,
    and a method ``load_subject(subject)`` which returns an ``HRTFSet`` for
    that subject.

    The initialiser should take (optional) keywords:

    ``samplerate``
        The intended samplerate (resampling will be used if it is wrong). If
        left unset, the natural samplerate of the data set will be used.
    '''

    def __init__(self, samplerate=None):
        # This base class cannot be instantiated directly.
        raise NotImplementedError()

    def load_subject(self, subject):
        # No implementation in the base class.
        raise NotImplementedError()