#!/usr/bin/env python

import os

import click
import librosa
import numpy as np
import pretty_midi
import scipy
import scipy.fft
import scipy.io.wavfile

from midify.find_peaks import find_peaks, find_first_peak
from midify.notes import freqs
from midify.resample import resample

# Developer notes for `midify` (kept as comments because click renders the
# function docstring as the command's --help text):
#
# Pipeline: load audio -> mono -> optional L2 normalisation -> resample to
# `sample_rate` -> detect onset peaks -> take a fixed-size window after each
# peak -> FFT the window -> first spectral peak gives the note frequency ->
# nearest entry in `freqs` gives the pitch -> write one fixed-length MIDI note
# per peak.
#
# Raises click.ClickException when the input extension is neither .wav nor .mp3.
# Side effect: writes the MIDI file to `output`.
@click.command(no_args_is_help=True, context_settings={'show_default': True})
@click.argument('input')
@click.option('--output', help='Output MIDI file', default='output.mid')
@click.option('--sample-rate', help='Sample rate to make the analysis (not the sample rate of the input file!)', default=6000)
@click.option('--normalize-input/--no-normalize-input', help='Normalize input before MIDI converting', default=True)
@click.option('--peak-min-height', help='Minimum size of a peak to be considered a individual note', default=0.002)
@click.option('--peak-threshold', help='Minimum distance from the neighbours to be considered a peak', default=0.00001)
@click.option('--peak-min-distance', help='Minimum distance between peaks', default=600)
@click.option('--note-sample-size', help='How many samples use to identify a note', default=1024)
@click.option('--note-attack-shift', help='How many samples ignore from the start of a note (ignore attack)', default=20)
@click.option('--note-frequency-peak-min-height', help='Minimum height of a peak in the frequency spectrum to be considered a note', default=0.20)
@click.option('--note-frequency-threshold', help='Minimum distance between the neighbours of a frequency peak', default=0.01)
@click.option('--midi-note-duration', help='The amount of time the MIDI notes generated should last', default=0.5)
@click.option('--verbose/--no-verbose', help='Verbose', default=False)
def midify(input, output, normalize_input, sample_rate, peak_min_height, peak_threshold, peak_min_distance, note_sample_size, note_attack_shift, note_frequency_peak_min_height, note_frequency_threshold, midi_note_duration, verbose):
    """
    This command transform a input audio file in the MIDI format identifying notes positions and frequencies.
    INPUT: Input file to convert to MIDI. Acceptable file formats: wav,mp3
    """

    def log(message):
        # Progress messages are only shown with --verbose.
        if verbose:
            click.echo(message)

    # Compare the extension case-insensitively so "SONG.WAV" / "track.MP3" are
    # accepted; the error message still shows the extension as typed.
    _, extension = os.path.splitext(input)
    kind = extension.lower()

    if kind == '.wav':
        rate, data = scipy.io.wavfile.read(input)
    elif kind == '.mp3':
        data, rate = librosa.load(input)
    else:
        raise click.ClickException(f"Invalid file format: ({extension}). Supported audio files: [mp3,wav]")

    data = np.array(data)
    log(f"File sample rate {rate}")
    log(f"Data Shape: {data.shape}")

    if data.ndim == 2:
        # Multi-channel audio: average the channels (axis 1) into one signal.
        log("Converting data to mono channel...")
        data = data.mean(axis=1)
        log(f"New Data Shape: {data.shape}")

    if normalize_input:
        log("Normalizing input...")
        norm = np.linalg.norm(data)
        # A silent (all-zero) file has norm 0; dividing would fill the signal
        # with NaN, so leave it untouched in that case.
        if norm > 0:
            data = data / norm

    log(f"Resampling data to rate {sample_rate}")
    resampled = resample(data, rate, new_rate=sample_rate)

    peaks = find_peaks(resampled, height=peak_min_height, threshold=peak_threshold, min_distance=peak_min_distance)
    log(f"Found {len(peaks)} peaks")

    # Peaks whose analysis window would start outside the signal have no
    # samples to analyse, so they cannot be assigned a pitch and are dropped.
    total_samples = len(resampled)
    usable_peaks = [p for p in peaks if 0 <= p + note_attack_shift < total_samples]
    if len(usable_peaks) != len(peaks):
        log(f"Skipping {len(peaks) - len(usable_peaks)} peaks too close to the end of the signal")

    # One fixed-size row per note. Windows cut short by the end of the signal
    # are zero-padded so every FFT has exactly `note_sample_size` points, which
    # is what the bin -> Hz conversion below assumes.
    notes_samples = np.zeros((len(usable_peaks), note_sample_size))
    for row, p in zip(notes_samples, usable_peaks):
        begin = p + note_attack_shift
        window = resampled[begin: begin + note_sample_size]
        row[:len(window)] = window

    log("Calculating FFT for each note...")
    # Keep only the first half of the spectrum: for a real signal the second
    # half mirrors the first. Derived from the window size rather than
    # hard-coded so non-default --note-sample-size values work.
    half_spectrum = note_sample_size // 2
    notes_fft = np.array([np.abs(scipy.fft.fft(n))[0:half_spectrum] for n in notes_samples])

    log("Inferring notes frequencies...")
    # FFT bin index -> frequency in Hz: bin * sample_rate / window_length.
    notes_freqs = np.array([find_first_peak(x, height=note_frequency_peak_min_height, threshold=note_frequency_threshold) * sample_rate / note_sample_size for x in notes_fft])

    log("Inferring notes from frequencies...")
    # The index of the closest entry in `freqs` is used directly as the MIDI
    # pitch (assumes `freqs` is indexed by MIDI note number -- TODO confirm).
    notes = np.array([np.argmin(np.abs(freqs - f)) for f in notes_freqs])

    log("Generating MIDI file...")
    mf = pretty_midi.PrettyMIDI()
    piano_program = pretty_midi.instrument_name_to_program('Acoustic Grand Piano')
    piano = pretty_midi.Instrument(program=piano_program)

    for note_number, peak in zip(notes, usable_peaks):
        # Peak position is in resampled samples; convert it to seconds.
        start = float(peak / sample_rate)
        note = pretty_midi.Note(velocity=100, pitch=int(note_number), start=start, end=start + midi_note_duration)
        piano.notes.append(note)

    mf.instruments.append(piano)
    log("Saving MIDI File...")
    mf.write(output)

    log("Done!")

# Script entry point: run the `midify` click command only when this file is
# executed directly, not when it is imported as a module.
if __name__ == "__main__":
    midify()
