Python vggish_postprocess.Postprocessor() Examples

The following are 2 code examples of vggish_postprocess.Postprocessor(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module vggish_postprocess, or try the search function.
Example #1
Source File: audio_feature_extractor.py    From Tensorflow-Audio-Classification with Apache License 2.0 6 votes vote down vote up
def __init__(self, checkpoint, pca_params, input_tensor_name, output_tensor_name):
        """Set up a private TensorFlow graph/session pair plus a PCA postprocessor.

        Each VGGishExtractor instance owns its own Graph and Session.

        Args:
            checkpoint: forwarded to
                vggish_slim.load_defined_vggish_slim_checkpoint to restore the
                model weights into this instance's session.
            pca_params: forwarded to vggish_postprocess.Postprocessor.
            input_tensor_name: name of the graph tensor stored as
                self.input_tensor.
            output_tensor_name: name of the graph tensor stored as
                self.output_tensor.
        """
        super(VGGishExtractor, self).__init__()

        # Define the VGGish model (inference mode) inside a graph of our own.
        self.graph = graph = tf.Graph()
        with graph.as_default():
            vggish_slim.define_vggish_slim(training=False)

        # Session bound to that graph: soft device placement, GPU memory
        # allocation allowed to grow.
        session_options = tf.ConfigProto(allow_soft_placement=True)
        session_options.gpu_options.allow_growth = True
        self.sess = tf.Session(graph=graph, config=session_options)
        vggish_slim.load_defined_vggish_slim_checkpoint(self.sess, checkpoint)

        # Look up the model's input/output tensors by name.
        self.input_tensor = graph.get_tensor_by_name(input_tensor_name)
        self.output_tensor = graph.get_tensor_by_name(output_tensor_name)

        # Postprocessor built from the PCA parameters.
        self.postprocess = vggish_postprocess.Postprocessor(pca_params)
Example #2
Source File: extract_audioset_embedding.py    From audioset_classification with MIT License 4 votes vote down vote up
def extract_audioset_embedding():
    """Extract postprocessed VGGish embeddings for a sample clip and print shapes.

    Reads 'appendixes/01.wav', converts it to log mel examples, runs them
    through the VGGish model restored from 'vggish_model.ckpt', applies the
    postprocessor built from 'vggish_pca_params.npz', and prints the audio
    length, the log mel shape and the postprocessed embedding shape.

    The model is built in a fresh graph and the session is closed on exit, so
    the function can be called more than once in the same process without
    leaking sessions or colliding with variables in the default graph.

    Returns:
        None.

    Raises:
        Exception: if 'vggish_model.ckpt' or 'vggish_pca_params.npz' is not
            present in the current working directory.
    """
    # Paths (relative to the current working directory).
    audio_path = 'appendixes/01.wav'
    checkpoint_path = 'vggish_model.ckpt'
    pca_params_path = 'vggish_pca_params.npz'

    # Validate prerequisites before doing any other work.
    if not os.path.isfile(checkpoint_path):
        raise Exception('Please download vggish_model.ckpt from '
            'https://storage.googleapis.com/audioset/vggish_model.ckpt '
            'and put it in the root of this codebase. ')

    if not os.path.isfile(pca_params_path):
        raise Exception('Please download vggish_pca_params.npz from '
            'https://storage.googleapis.com/audioset/vggish_pca_params.npz '
            'and put it in the root of this codebase. ')

    # You may modify EXAMPLE_HOP_SECONDS in vggish_params.py to change the
    # hop size.
    sample_rate = vggish_params.SAMPLE_RATE

    # A dedicated graph keeps the default graph clean; the context manager
    # closes the session even if a later step raises.
    with tf.Graph().as_default(), tf.Session() as sess:
        # Load model
        vggish_slim.define_vggish_slim(training=False)
        vggish_slim.load_vggish_slim_checkpoint(sess, checkpoint_path)
        features_tensor = sess.graph.get_tensor_by_name(
            vggish_params.INPUT_TENSOR_NAME)
        embedding_tensor = sess.graph.get_tensor_by_name(
            vggish_params.OUTPUT_TENSOR_NAME)

        pproc = vggish_postprocess.Postprocessor(pca_params_path)

        # Read audio
        (audio, _) = read_audio(audio_path, target_fs=sample_rate)

        # Extract log mel feature
        logmel = vggish_input.waveform_to_examples(audio, sample_rate)

        # Extract embedding feature
        [embedding_batch] = sess.run(
            [embedding_tensor], feed_dict={features_tensor: logmel})

    # PCA
    postprocessed_batch = pproc.postprocess(embedding_batch)

    print('Audio length: {}'.format(len(audio)))
    print('Log mel shape: {}'.format(logmel.shape))
    print('Embedding feature shape: {}'.format(postprocessed_batch.shape))