#!/usr/bin/env python
# Copyright (C) 2015 Chris Pankow
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 2 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
# General Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; see the file COPYING. If not, write to the Free
# Software Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA
# 02111-1307  USA
#
__program__ = 'pyburst_excesspower'
import sys
import math
from argparse import ArgumentParser

import numpy
import scipy.stats

from glue.ligolw import lsctables, utils, ligolw
lsctables.use_in(ligolw.LIGOLWContentHandler)
from glue.lal import LIGOTimeGPS

import glue
from glue.ligolw.utils.process import register_to_xmldoc
from glue.ligolw.utils.search_summary import append_search_summary
from glue.segments import segment
from glue.lal import LIGOTimeGPS

import lal, lalburst

from pycbc import frame, strain, psd, filter, types, DYN_RANGE_FAC
from pycbc.fft import fft, ifft
import logging

def construct_args():
    """
    Build and return the command line parser for the excess power search.

    The tiling and time-frequency options are declared here. The strain and PSD option groups are appended by the pycbc strain and psd modules.
    """
    parser = ArgumentParser()
    add_option = parser.add_argument

    # Tile significance and verbosity
    add_option("--tile-fap", type=float, default=1e-7,
        help="Tile false alarm probability threshold in Gaussian noise. Default is 1e-7")
    add_option("--verbose", action="store_true",
        help="Be verbose")

    # Extent and resolution of the time-frequency plane
    add_option("--min-frequency", type=float, default=0,
        help="Lowest frequency of the filter bank, default is 0 Hz.")
    add_option("--max-frequency", type=float, default=None,
        help="Highest frequency of the filter bank, default is None, meaning use Nyquist.")
    add_option("--max-duration", type=float, default=None,
        help="Longest duration tile to compute.")
    add_option("--tile-bandwidth", type=float, default=None,
        help="Bandwidth of the finest filters. Default is None, and would be inferred from the data bandwidth and number of channels.")
    add_option("--channels", type=int, default=None,
        help="Number of frequency channels to use. Default is None, and would be inferred from the data bandwidth and tile bandwidth.")

    # Optional restriction of the analyzed time span
    add_option("--analysis-start-time", type=int, default=None,
        help="Start analysis from this GPS time instead of the --gps-start-time")
    add_option("--analysis-end-time", type=int, default=None,
        help="End analysis at this GPS time instead of the --gps-end-time")

    # Options owned by pycbc
    strain.insert_strain_option_group(parser)
    psd.insert_psd_option_group(parser)
    return parser

def calculate_spectral_correlation(fft_window_len, wtype='hann', window_fraction=None):
    """
    Create a window of fft_window_len samples and compute the two point spectral correlation that windowing introduces in the frequency domain.

    wtype selects the window: 'hann' or 'tukey'. window_fraction is only passed on for wtype='tukey'. Any other wtype raises ValueError.

    Returns a tuple (window, spectral_correlation), where window is the LAL window object.
    """
    # FIXME Dict -> window function map
    if wtype not in ('hann', 'tukey'):
        raise ValueError("Can't handle window type %s" % wtype)

    if wtype == 'tukey':
        window = lal.CreateTukeyREAL8Window(fft_window_len, window_fraction)
    else:
        window = lal.CreateHannREAL8Window(fft_window_len)

    # The FFT plan is sized from the window that was actually created
    n_samples = len(window.data.data)
    fft_plan = lal.CreateForwardREAL8FFTPlan(n_samples, 1)
    spec_corr = lal.REAL8WindowTwoPointSpectralCorrelation(window, fft_plan)

    return window, spec_corr

def create_filter_bank(delta_f, flow, band, nchan, psd, spec_corr):
    """
    Create nchan excess power filters of bandwidth band, the i-th one requested at frequency flow + i*band.

    Returns a tuple (filter_fseries, lal_filters): filter_fseries is a complex array of shape (nchan, int(2*band/delta_f)+1) holding the frequency domain samples of each filter, and lal_filters is the list of LAL filter objects they were copied from.
    """
    samples_per_filter = int(2*band/delta_f)+1
    lal_psd = psd.lal()

    lal_filters = []
    for chan in range(nchan):
        lal_filters.append(lalburst.CreateExcessPowerFilter(flow + chan*band, band, lal_psd, spec_corr))

    # Copy the filter samples out of the LAL structures into one numpy array
    filter_fseries = numpy.zeros((nchan, samples_per_filter), dtype=numpy.complex128)
    for chan, lal_filter in enumerate(lal_filters):
        filter_fseries[chan,:] = lal_filter.data.data

    return filter_fseries, lal_filters

def compute_filter_ips_self(lal_filters, spec_corr, psd=None):
    """
    Compute the inner product of each input filter with itself and return them as an array, one entry per filter. If the psd argument is given, the unwhitened filter inner products are returned.
    """
    inner_products = []
    for filt in lal_filters:
        inner_products.append(lalburst.ExcessPowerFilterInnerProduct(filt, filt, spec_corr, psd))
    return numpy.array(inner_products)

def compute_filter_ips_adjacent(lal_filters, spec_corr, psd=None):
    """
    Compute the inner products between each input filter and the next one in the list. If the psd argument is given, the unwhitened filter inner products are returned.

    Entry i of the returned array is the inner product of lal_filters[i] with lal_filters[i+1] --- assumed to be the frequency adjacent filter --- so the result has one entry fewer than there are filters.
    """
    inner_products = []
    for idx in range(len(lal_filters) - 1):
        lower, upper = lal_filters[idx], lal_filters[idx+1]
        inner_products.append(lalburst.ExcessPowerFilterInnerProduct(lower, upper, spec_corr, psd))
    return numpy.array(inner_products)

def compute_channel_renomalization(nc_sum, lal_filters, spec_corr, verbose=True):
    """
    Compute the renormalization for the base filters up to a given bandwidth.

    nc_sum - number of additional channels added to each base channel (a virtual channel is the sum of nc_sum+1 base channels)
    lal_filters - list of base LAL filters, one per channel
    spec_corr - two point spectral correlation of the data window
    verbose - if True, write progress to stdout

    Returns an array with one entry per base filter. Every entry starts as (nc_sum+1) times the whitened inner product of the filter with itself. Only the entries at indices nc_sum, 2*nc_sum+1, ... (the last channel of each group of nc_sum+1) additionally receive twice the adjacent filter inner products of their group.
    """
    # The channel count comes from the filters handed in, rather than from a
    # module level global that happens to have the same value
    nchans = len(lal_filters)
    min_band = (len(lal_filters[0].data.data)-1) * lal_filters[0].deltaF / 2

    if verbose:
        sys.stdout.write("Calculating renormalization for resolution level containing %d %f Hz channels" % (nc_sum+1, min_band))
    mu_sq = (nc_sum+1)*numpy.array([lalburst.ExcessPowerFilterInnerProduct(f, f, spec_corr, None) for f in lal_filters])

    # Use range(nc_sum, nchans) instead to get all possible frequency
    # renormalizations
    for n in range(nc_sum, nchans, nc_sum+1): # channel position index
        for k in range(nc_sum): # channel sum index
            # FIXME: We've precomputed this, so use it instead
            mu_sq[n] += 2*lalburst.ExcessPowerFilterInnerProduct(lal_filters[n-k], lal_filters[n-1-k], spec_corr, None)

    if verbose:
        # Terminates the progress line started above
        sys.stdout.write(" done.\n")

    return mu_sq

def measure_hrss(z_j_b, uw_ss_ii, uw_ss_ij, w_ss_ij, delta_f, delta_t, filter_len, dof):
    """
    Approximation of unwhitened sum of squares signal energy in a given EP tile. See T1200125 for equation number reference.
    z_j_b - time frequency map block which the constructed tile covers
    uw_ss_ii - unwhitened filter inner products
    uw_ss_ij - unwhitened adjacent filter inner products
    w_ss_ij - whitened adjacent filter inner products
    delta_f - frequency binning of EP filters
    delta_t - native time resolution of the time frequency map
    filter_len - number of samples in a filter
    dof - degrees of freedom in the tile (twice the time-frequency area)

    Returns the square root of the noise-subtracted energy estimate. Raises ValueError (from math.sqrt) when that estimate is negative.
    """
    # filter_len is normally an integer sample count. Without this, Python 2
    # evaluates `2 * 2 / filter_len` with integer division and the adjacent
    # filter cross term below is silently multiplied by zero.
    filter_len = float(filter_len)

    s_j_b_avg = uw_ss_ii * delta_f / 2
    # unwhitened sum of squares of wide virtual filter
    s_j_nb_avg = uw_ss_ii.sum() / 2 + uw_ss_ij.sum()
    s_j_nb_avg *= delta_f

    s_j_nb_denom = s_j_b_avg.sum() + 2 * 2 / filter_len * \
        numpy.sum(numpy.sqrt(s_j_b_avg[:-1] * s_j_b_avg[1:]) * w_ss_ij)

    # eqn. 62
    uw_ups_ratio = s_j_nb_avg / s_j_nb_denom

    # eqn. 63 -- approximation of unwhitened signal energy time series
    # FIXME: The sum in this equation is over nothing, but indexed by frequency
    # I'll make that assumption here too.
    s_j_nb = numpy.sum(z_j_b.T * numpy.sqrt(s_j_b_avg), axis=0)
    s_j_nb *= numpy.sqrt(uw_ups_ratio / filter_len * 2)
    # eqn. 64 -- approximate unwhitened signal energy minus noise contribution
    # FIXME: correct axis of summation?
    return math.sqrt(numpy.sum(numpy.absolute(s_j_nb)**2) * delta_t - s_j_nb_avg * dof * delta_t)

# < s^2_j(f_1, b) > = 1 / 2 / N * \delta_t EPIP{\Theta, \Theta; P}
def uw_sum_sq(filter1, filter2, spec_corr, psd):
    """
    Return the excess power filter inner product of filter1 with filter2, given the spectral correlation and the PSD (or None).
    """
    inner_product = lalburst.ExcessPowerFilterInnerProduct(filter1, filter2, spec_corr, psd)
    return inner_product

def measure_hrss_slowly(z_j_b, lal_filters, spec_corr, psd, delta_t, dof):
    """
    Approximation of unwhitened sum of squares signal energy in a given EP tile. See T1200125 for equation number reference.

    NOTE: This function is deprecated in favor of measure_hrss, since it requires recomputation of many inner products, making it particularly slow.

    z_j_b - time frequency map block which the constructed tile covers, one row per finest band
    lal_filters - LAL filters of the finest bands; the first len(z_j_b) are used
    spec_corr - two point spectral correlation of the data window
    psd - PSD handed to the unwhitened inner products
    delta_t - native time resolution of the time frequency map
    dof - degrees of freedom in the tile

    Raises ValueError (from math.sqrt in measure_hrss) when the noise-subtracted energy estimate is negative.
    """
    # FIXME: Make sure you sum in time correctly
    # Number of finest bands in given tile
    nb = len(z_j_b)
    # eqn. 56 -- unwhitened mean square of filter with itself
    uw_ss_ii = numpy.array([uw_sum_sq(lal_filters[i], lal_filters[i], spec_corr, psd) for i in range(nb)])
    # eqn. 57 -- unwhitened mean square of filter with adjacent filter
    uw_ss_ij = numpy.array([uw_sum_sq(lal_filters[i], lal_filters[i+1], spec_corr, psd) for i in range(nb-1)])
    # eqn. 61 -- whitened mean square of filter with adjacent filter
    w_ss_ij = numpy.array([uw_sum_sq(lal_filters[i], lal_filters[i+1], spec_corr, None) for i in range(nb-1)])

    # The filter length is passed as a float: it is an integer sample count,
    # and dividing small integers by it under Python 2 truncates to zero,
    # which used to drop the adjacent filter cross term of eqn. 61 here.
    filter_len = float(len(lal_filters[0].data.data))

    # eqns. 62 - 64 are identical to the fast implementation, so share it
    return measure_hrss(z_j_b, uw_ss_ii, uw_ss_ij, w_ss_ij, lal_filters[0].deltaF, delta_t, filter_len, dof)

def measure_hrss_poorly(tile_energy, sub_psd):
    """
    Crude amplitude estimate for a tile: the square root of half the tile energy divided by the average of the inverse PSD over the samples in sub_psd.
    """
    mean_inverse_psd = numpy.average(1.0 / sub_psd)
    scaled_energy = tile_energy / mean_inverse_psd
    return math.sqrt(scaled_energy / 2)

def trigger_list_from_map(tfmap, event_list, threshold, start_time, start_freq, duration, band, df, dt, psd=None):
    """
    Append one event to event_list for every sample of tfmap that is above threshold.

    tfmap - tile energies, first index is the frequency step and second index is the time step
    event_list - table receiving the events; rows are created with event_list.RowType()
    threshold - tile energy above which an event is recorded
    start_time - time of the first time step of tfmap
    start_freq - frequency of the first frequency step of tfmap
    duration, band - duration and bandwidth assigned to every event
    df, dt - frequency and time step between adjacent tfmap samples
    psd - optional PSD; when given, the event amplitude is estimated with measure_hrss_poorly, otherwise the amplitude is left as None

    Nothing is returned; event_list is modified in place.
    """
    have_psd = psd is not None
    # FIXME: If we don't convert this the calculation takes forever --- but we should convert it once and handle deltaF better later
    npy_psd = psd.numpy() if have_psd else None

    map_epoch = LIGOTimeGPS(float(start_time))
    ndof = 2 * duration * band

    for chan, samp in zip(*numpy.where(tfmap > threshold)):
        tile_energy = tfmap[chan,samp]
        event = event_list.RowType()

        # The points are summed forward in time and thus a `summed point' is the
        # sum of the previous N points. If this point is above threshold, it
        # corresponds to a tile which spans the previous N points. However, the
        # 0th point (due to the convolution specifier 'valid') is actually
        # already a duration from the start time. All of this means, the +
        # duration and the - duration cancels, and the tile 'start' is, by
        # definition, the start of the time frequency map if the time index is 0
        # FIXME: I think this needs a + dt/2 to center the tile properly
        event.set_start(map_epoch + float(samp * dt))
        event.set_peak(event.get_start() + duration / 2)
        event.central_freq = start_freq + chan * df + 0.5 * band

        event.duration, event.bandwidth, event.chisq_dof = duration, band, ndof

        event.snr = math.sqrt(tile_energy / event.chisq_dof - 1)
        # FIXME: Magic number 0.62 should be determine empircally
        event.confidence = -lal.LogChisqCCDF(event.snr * 0.62, event.chisq_dof * 0.62)

        if not have_psd:
            event.amplitude = None
        else:
            # NOTE: I think the pycbc PSDs always start at 0 Hz --- check
            psd_idx_min = int((event.central_freq - event.bandwidth / 2) / psd.delta_f)
            psd_idx_max = int((event.central_freq + event.bandwidth / 2) / psd.delta_f)

            # FIXME: heuristically this works better with E - D -- it's all
            # going away with the better h_rss calculation soon anyway
            event.amplitude = measure_hrss_poorly(tile_energy - event.chisq_dof, npy_psd[psd_idx_min:psd_idx_max])

        # The process id is filled in by the caller once it is known
        event.process_id = None
        event.event_id = event_list.get_next_id()
        event_list.append(event)

def determine_output_segment(inseg, dt_stride, sample_rate, window_fraction=0.0):
    """
    Given an input data stretch segment inseg, a data block stride dt_stride, the data sample rate, and an optional window_fraction, return the amount of data that can be processed without corruption effects from the window.

    If window_fraction is set to 0 (default), assume no windowing.

    sample_rate does not enter the calculation; it is accepted so that existing callers keep working.
    """
    # Shrink the input by the part of one stride lost to the window
    # NOTE(review): presumably contract() trims this amount from both ends --- confirm against glue.segments
    outseg = inseg.contract(window_fraction * dt_stride / 2)

    # With a given dt_stride, we cannot process the remainder of this data
    remainder = math.fmod(abs(outseg), dt_stride * (1 - window_fraction))
    # ...so make an accounting of it
    return segment(outseg[0], outseg[1] - remainder)

sq_rt_2 = math.sqrt(2)
def make_tiles(tf_map, nc_sum, mu_sq):
    """
    Sum each non-overlapping group of nc_sum + 1 adjacent channels (rows) of tf_map and convert the sums to normalized tile energies.

    tf_map - two dimensional time frequency map, first index is the channel, second index is time
    nc_sum - number of additional channels added to each base channel
    mu_sq - per channel normalization; the entry of the last channel of each group is used

    Returns an array with one row per complete group of channels and one column per time sample, holding |sum|**2 / 2 / mu_sq. Channels left over after the last complete group are dropped. A map without time samples or without a complete group gives an empty result instead of an error.
    """
    stride = nc_sum + 1
    ngroups = tf_map.shape[0] // stride
    ntimes = tf_map.shape[1]

    # Sum and drop correlated tiles: only non-overlapping groups are kept
    # FIXME: don't drop correlated tiles
    grouped = tf_map[:ngroups*stride].reshape(ngroups, stride, ntimes)
    tiles = numpy.absolute(grouped.sum(axis=1)) / sq_rt_2

    return tiles**2 / mu_sq[nc_sum::stride].reshape(-1, 1)

#
# Optimization plan: If we keep the summed complex TF plane in known indices,
# we can save ourselves individual sums at wider frequency resolutions.
# Caveats:
#   1.  We have to keep track of where we're storing things
#   2.  We have to do it from the finest resolution (for *all* t0s) and work our way up
#
# In the end, I think this is a Haar wavelet transform. Look into it
#
def make_indp_tiles(tf_map, nc_sum, mu_sq):
    """
    Create a time frequency map with resolution of tf_map binning divided by nc_sum + 1. All tiles will be independent up to overlap from the original tiling. The mu_sq is applied to the resulting addition to normalize the outputs to be zero-mean unit-variance Gaussian variables (if the input is Gaussian).

    tf_map - two dimensional time frequency map, first index is the channel, second index is time; it is not modified
    nc_sum - number of additional channels added to each base channel
    mu_sq - per channel normalization; the entry of the last channel of each group is used

    Returns a real array with one row per complete group of nc_sum + 1 channels, holding |sum of the group|**2 / mu_sq.
    """
    stride = nc_sum + 1
    # Floor division: only complete groups of channels form a tile. This used
    # to rely on Python 2 dividing two integers with `/`.
    ngroups = tf_map.shape[0] // stride

    # Only the summed rows are stored, rather than a scratch copy of the whole
    # complex map
    power = numpy.empty((ngroups, tf_map.shape[1]))
    for group in range(ngroups):
        summed = tf_map[stride*group:stride*(group+1)].sum(axis=0)
        power[group] = numpy.absolute(summed)**2

    # Do the proper normalization
    return power / mu_sq[nc_sum::stride].reshape(-1, 1)

#
# Preliminaries
#
# Seed numpy's global random number generator with a fixed value
numpy.random.seed(0)

# Set the root logger level to INFO
root = logging.getLogger()
root.setLevel(logging.INFO)

# Build the option parser (tiling options plus the pycbc strain and PSD
# groups) and parse the command line
argp = construct_args()
args = argp.parse_args()

#
# Time - frequency settings
#

# Input segment: the --analysis-*-time options take precedence over the
# corresponding --gps-*-time options when they are given
inseg = segment(LIGOTimeGPS(args.analysis_start_time or args.gps_start_time), LIGOTimeGPS(args.analysis_end_time or args.gps_end_time))

# How long to process in a block of data?
dt_stride = args.psd_segment_length
# NOTE(review): this hard-coded value overrides the option read on the
# previous line --- confirm whether that is intentional
dt_stride = 32.0

# same, in samples
ts_stride = int(dt_stride * args.sample_rate)

# Bin spacing of the PSD (and filters), lowest frequency of the first filter
#delta_f, flow = 1.0/dt_stride, 4.0
delta_f, flow = 1.0/32.0, 4.0

#if args.min_frequency < args.strain_high_pass:
    #print >>sys.stderr, "Warning: strain high pass frequency %f is greater than the tile minimum frequency %f --- this is likely to cause strange output below the bandpass frequency" % (args.strain_high_pass, args.min_frequency)

# Default the top of the filter bank to the Nyquist frequency
if args.max_frequency is None:
    args.max_frequency = args.sample_rate / 2.0

# At least one of --tile-bandwidth and --channels is required; the other one is
# derived from the requested frequency span. When both are given, the channel
# count is recomputed from the tile bandwidth.
if args.tile_bandwidth is None and args.channels is None:
    exit("Either --tile-bandwidth or --channels must be specified to set up time-frequency plane")
else:
    assert args.max_frequency >= args.min_frequency
    data_band = args.max_frequency - args.min_frequency

    if args.tile_bandwidth is not None:
        nchans = args.channels = int(data_band / args.tile_bandwidth)
    elif args.channels is not None:
        band = args.tile_bandwidth = data_band / args.channels

    assert args.channels > 1

# NOTE(review): the assignments below replace the band and nchans values
# derived from the command line above with fixed numbers, so the
# --tile-bandwidth and --channels values only reach the stored args --- confirm
# whether that is intentional

# minimum bandwidth
band = 4.0 # Hz

# Number of channels
# (this first value is immediately overwritten by the fixed count below)
nchans = args.sample_rate / 2 / band - 1
#nchans = 64
nchans = 1021

# Read the strain data as requested on the command line
ts_data = strain.from_cli(args, precision='double', dyn_range_fac=DYN_RANGE_FAC)

#
# Estimate PSD
#

# Number of frequency samples of the PSD for the chosen bin spacing
psd_length = int(args.sample_rate/delta_f)//2 + 1
fd_psd = psd.from_cli(args, psd_length, delta_f, flow, ts_data, dyn_range_factor=DYN_RANGE_FAC)
# DEBUG for gaussian noise
#for i in range(len(fd_psd)):
    #fd_psd[i] = 2.0/args.sample_rate

# We need this for the SWIG functions
lal_psd = fd_psd.lal()

#
# Whitening window and spectral correlation
#

# How much data on either side of the tukey window to discard
# NOTE: Nominally, this means that one will lose window_fraction * dt_stride to
# corruption from the window: this is simply discarded
# NOTE 2: This is tuned to give an integer offset when used with dt_stride = 8,
# smaller windows will have fractions of integers, but larger powers of two will
# still preseve this --- probably not a big deal in the end
window_fraction = 0.25

# Tukey window spanning one block of data, and the spectral correlation it
# introduces
lal_window, spec_corr = calculate_spectral_correlation(ts_stride, wtype='tukey', window_fraction=window_fraction)
# From here on only the sample values of the window are used
window = lal_window.data.data
window_sigma_sq = numpy.mean(numpy.square(window))
# Pre scale the window by its root mean squared -- see eqn 11 of EP document
#window /= numpy.sqrt(window_sigma_sq)

#
# Filter generation
#

# TODO: Match these up
# TODO: Match these up
bank_flow = flow+band/2
filter_bank, lal_filters = create_filter_bank(fd_psd.delta_f, bank_flow, band, nchans, fd_psd, spec_corr)

# Whitened inner products (no PSD): the self products are necessary to compute
# the mu^2 normalizations, the adjacent ones are needed for the unwhitened mean
# square sum (hrss)
white_filter_ip = compute_filter_ips_self(lal_filters, spec_corr, psd=None)
white_ss_ip = compute_filter_ips_adjacent(lal_filters, spec_corr, psd=None)

# Unwhitened counterparts, both needed for the hrss calculation
unwhite_filter_ip = compute_filter_ips_self(lal_filters, spec_corr, psd=lal_psd)
unwhite_ss_ip = compute_filter_ips_adjacent(lal_filters, spec_corr, psd=lal_psd)

#
# Compute normalization for virtual channel
#

# Normalizations keyed by the number of additional channels summed; one entry
# per power-of-two resolution level: 0, 1, 3, 7, ...
mu_sq_dict = {}
n_levels = int(math.log(nchans, 2))
for level in range(n_levels):
    nc_sum = 2**level - 1 # nc_sum additional channel adds
    mu_sq_dict[nc_sum] = compute_channel_renomalization(nc_sum, lal_filters, spec_corr)

#
# Convert to TF plane
#

# Table collecting the triggers of all analyzed blocks
event_list = lsctables.New(lsctables.SnglBurstTable, ['start_time', 'start_time_ns', 'peak_time', 'peak_time_ns', 'duration', 'bandwidth', 'central_freq', 'chisq_dof', 'confidence', 'snr', 'amplitude', 'channel', 'ifo', 'process_id', 'event_id', 'search'])

# Complex filter output of one block: one row per channel, one column per
# time sample
tf_map = numpy.zeros((nchans, ts_stride), dtype=numpy.complex128)

# Determine boundaries of stride in time domain
t_idx_min, t_idx_max = 0, ts_stride
if args.analysis_start_time is not None:
    t_idx_off = args.analysis_start_time - ts_data.start_time
    t_idx_off = int(t_idx_off * args.sample_rate)
else:
    t_idx_off = 0
t_idx_min += t_idx_off
t_idx_max += t_idx_off

# Stop at the user requested end time if necessary. This has to test the end
# time option itself: the start and end options are independent, so testing the
# start time failed when only the start was given and ignored the end when only
# the end was given.
if args.analysis_end_time is not None:
    t_idx_max_off = args.analysis_end_time - ts_data.start_time
    t_idx_max_off = int(t_idx_max_off * args.sample_rate)
else:
    t_idx_max_off = len(ts_data)

# Scratch frequency series, as long as the PSD, into which one filter at a time
# is placed
tmp_filter_bank = numpy.zeros(len(fd_psd), dtype=numpy.complex128)

# Main analysis loop. Each pass handles the ts_stride samples between
# t_idx_min and t_idx_max: window, transform and whiten the data, filter it
# through every channel into tf_map, build tiles at each frequency and time
# resolution, and record the tiles above threshold in event_list. The indices
# then advance by (1 - window_fraction) of a stride, so successive blocks
# overlap.
while t_idx_max <= t_idx_max_off:

    print "Analyzing block %10.9f - %10.9f" % (ts_data.start_time + t_idx_min/float(args.sample_rate), ts_data.start_time + t_idx_max/float(args.sample_rate))

    # DEBUG: for gaussian noise
    #print "Generating random time series"
    #for i in range(t_idx_min, t_idx_max):
        #ts_data[i] = numpy.random.normal(0, 1)
    #for i in range(t_idx_min, t_idx_max):
        #ts_data[i] = 1. if i == (t_idx_max - t_idx_min)/2 else 0.
    #print "Done"

    # Window this block of data and take it to the frequency domain
    tmp_ts_data = types.TimeSeries(ts_data[t_idx_min:t_idx_max]*window, 1.0/args.sample_rate, epoch=ts_data.start_time + float(t_idx_min)/args.sample_rate)
    fs_data = tmp_ts_data.to_frequencyseries()

    # Whiten
    # FIXME: Whiten the filters, not the data
    fs_data.data /= numpy.sqrt(fd_psd) / numpy.sqrt(2 * fd_psd.delta_f)

    #
    # Frequency domain filtering
    #

    # Force planning to happen
    #filter.matched_filter_core(types.FrequencySeries(tmp_filter_bank, delta_f=fd_psd.delta_f), fs_data, h_norm=1, psd=fd_psd, low_frequency_cutoff=lal_filters[0].f0, high_frequency_cutoff=lal_filters[0].f0+2*band)

    import time
    tdiff = time.time()
    print "Beginning filtering: %d" % int(tdiff)
    # Filter the whitened data through one channel at a time; row i of tf_map
    # receives the output of channel i
    for i in range(nchans):

        # Reset filter bank series
        tmp_filter_bank *= 0.0

        #if args.verbose:
            #print "Filter %d: %d FD samples" % (i, len(tmp_filter_bank))
            #print "Beginning filtering %d / %d: %d" % (i, nchans, time.time() - tdiff)

        # Frequency bin range that filter i occupies in the full length series
        f1 = int(lal_filters[i].f0/fd_psd.delta_f)
        f2 = int((lal_filters[i].f0 + 2*band)/fd_psd.delta_f)+1

        # FIXME: Why is there a factor of 2 here?
        tmp_filter_bank[f1:f2] = filter_bank[i] * 2

        # The data were whitened above, so no PSD or normalization is passed
        filtered_series = filter.matched_filter_core(
            types.FrequencySeries(tmp_filter_bank, delta_f=fd_psd.delta_f, copy=False),
            fs_data,
            h_norm=None,
            psd=None,
            low_frequency_cutoff=lal_filters[i].f0,
            high_frequency_cutoff=lal_filters[i].f0+2*band)

        tf_map[i,:] = filtered_series[0].numpy()

    tdiff = time.time() - tdiff
    print "Done filtering: %d, total %f" % (int(time.time()), tdiff)

    #print tf_map.shape
    #print tf_map.std(axis=1)
    #print tf_map.std()

    # FIXME: Debug switch
    # NOTE(review): chan_sum_range is not referenced again within this loop
    if True:
        chan_sum_range = [xrange(0, int(math.log(nchans, 2))+1)]
    else:
        chan_sum_range = [2**l for l in xrange(0, int(math.log(nchans, 2))+1)]

    # Clip the boundaries to remove window corruption
    clip_samples = int(dt_stride * window_fraction * args.sample_rate / 2)

    import time
    level_tdiff = tdiff = time.time()
    print "Beginning tile construction: %d" % int(tdiff)

    from scipy.signal import fftconvolve
    # Loop over the frequency resolution levels, starting from the largest
    # number of summed channels
    for nc_sum in range(0, int(math.log(nchans, 2)))[::-1]: # nc_sum additional channel adds
        # FIXME: the minus one is to undo things later
        #k = nc_sum
        nc_sum = 2**nc_sum - 1
        print "Summing %d narrow band channels" % (nc_sum+1)

        #
        # Construct virtual channels
        #

        # FIXME: What to do here?
        # Okay, what I mean is, if the us rate is not an integer multiple
        # This would invoke a resampling, likely
        #
        # By way of explanation: this is one over the twice the *full* bandwidth
        # e.g. the full amount of frequency content covered by the virtual
        # filter. This is a rate in Hz, and we need the number of samples to
        # skip in the time series for numpy, so this is multiplied by delta_t to
        # get a sample number
        us_rate = int(1.0 / (2 * band*(nc_sum+1) * ts_data.delta_t))
        print >>sys.stderr, "Undersampling rate for this level: %f" % (args.sample_rate/us_rate)

        # Normalization precomputed for this number of summed channels
        mu_sq = mu_sq_dict[nc_sum]

        sys.stderr.write("\t...calculating tiles...")
        if clip_samples > 0: # because [0:-0] does not give the full array
            tiles = make_indp_tiles(tf_map[:,clip_samples:-clip_samples:us_rate], nc_sum, mu_sq)
        else:
            tiles = make_indp_tiles(tf_map[:,::us_rate], nc_sum, mu_sq)
        #tiles = make_indp_tiles(tf_map[:,::us_rate], nc_sum, mu_sq)

        sys.stderr.write(" TF-plane is %dx%s samples... " % tiles.shape)

        print >>sys.stderr, " done"
        print numpy.mean(tiles), numpy.var(tiles)

        # Largest number of degrees of freedom to sum up to at this level
        if args.max_duration is not None:
            max_dof = 2 * args.max_duration * (band * (nc_sum+1))
        else:
            max_dof = 32
        assert max_dof >= 2

        print "\t\t...getting longer durations..."
        # j runs over the powers of two from 2 up to max_dof
        for j in [2**l for l in xrange(1, int(math.log(max_dof, 2))+1)]:
            sys.stderr.write("\t\tSumming DOF = %d ..." % j)

            # Number of 'valid' output samples of the sum filter below, which
            # is 2*j - 1 samples long
            #tlen = tiles.shape[1] - j + 1
            tlen = tiles.shape[1] - 2*j + 1 + 1
            if tlen <= 0:
                print >>sys.stderr, " ...not enough samples."
                continue

            dof_tiles = numpy.zeros((tiles.shape[0], tlen))
            #:sum_filter = numpy.ones(j)
            # FIXME: This is the correct filter for 50% overlap
            # j ones separated by single zeros: adds up every other sample
            sum_filter = numpy.array([1,0] * (j-1) + [1])
            #sum_filter = numpy.array([1,0] * int(math.log(j, 2)-1) + [1])
            for f in range(tiles.shape[0]):
                # Sum and drop correlate tiles
                # FIXME: don't drop correlated tiles
                #output = numpy.convolve(tiles[f,:], sum_filter, 'valid')
                dof_tiles[f] = fftconvolve(tiles[f], sum_filter, 'valid')

            print >>sys.stderr, " done"

            # tdiff still holds the time at which tile construction began
            level_tdiff = time.time() - tdiff
            print >>sys.stderr, "Done with this resolution, total %f" % level_tdiff

            #
            # Trigger finding
            #

            # Current bandwidth of the time-frequency map tiles
            current_band = band * (nc_sum + 1)
            # How much each "step" is in the frequency domain -- almost
            # assuredly the fundamental bandwidth
            df = current_band
            # How much each "step" is in the time domain -- under sampling rate
            dt = 1.0 / 2 / (2 * current_band)
            # Duration is fixed by the NDOF and bandwidth
            duration = j / 2.0 / current_band

            # Tile energy with false alarm probability args.tile_fap for a chi
            # squared distribution with j degrees of freedom
            threshold = scipy.stats.chi2.isf(args.tile_fap, j)

            # Since we clip the data, the start time needs to be adjusted
            # accordingly
            window_offset_epoch = fs_data.epoch + dt_stride * window_fraction / 2

            # No PSD is passed, so the new events get an amplitude of None...
            trigger_list_from_map(dof_tiles, event_list, threshold, window_offset_epoch, lal_filters[0].f0 + band/2, duration, current_band, df, dt, None)
            # ...which is filled in here. The table is walked from the newest
            # event backwards, skipping events that already have an amplitude.
            for event in event_list[::-1]:
                if event.amplitude != None:
                    continue
                # Sample range of the event within the current block
                etime_min_idx = float(event.get_start()) - float(fs_data.epoch)
                etime_min_idx = int(etime_min_idx / tmp_ts_data.delta_t)
                etime_max_idx = float(event.get_start()) - float(fs_data.epoch) + event.duration
                etime_max_idx = int(etime_max_idx / tmp_ts_data.delta_t)
                # (band / 2) to account for sin^2 wings from finest filters
                flow_idx = int((event.central_freq - event.bandwidth / 2 - (band / 2) - flow) / band)
                fhigh_idx = int((event.central_freq + event.bandwidth / 2 + (band / 2) - flow) / band)
                # TODO: Check that the undersampling rate is always commensurate
                # with the indexing: that is to say that
                # mod(etime_min_idx, us_rate) == 0 always
                z_j_b = tf_map[flow_idx:fhigh_idx,etime_min_idx:etime_max_idx:us_rate]
                # FIXME: Deal with negative hrss^2 -- e.g. remove the event
                # measure_hrss takes a math.sqrt, which raises ValueError for a
                # negative argument; such events get an amplitude of 0
                try:
                    event.amplitude = measure_hrss(z_j_b, unwhite_filter_ip[flow_idx:fhigh_idx], unwhite_ss_ip[flow_idx:fhigh_idx-1], white_ss_ip[flow_idx:fhigh_idx-1], fd_psd.delta_f, tmp_ts_data.delta_t, len(lal_filters[0].data.data), event.chisq_dof)
                except ValueError:
                    event.amplitude = 0

            print "Total number of events: %d" % len(event_list)

    tdiff = time.time() - tdiff
    print "Done with this block: total %f" % tdiff

    # Advance to the next block, overlapping the previous one by
    # window_fraction of a stride
    t_idx_min += int(ts_stride * (1 - window_fraction))
    t_idx_max += int(ts_stride * (1 - window_fraction))

# Output document with a single LIGO_LW root element
xmldoc = ligolw.Document()
xmldoc.appendChild(ligolw.LIGO_LW())

# The instrument is the part of the channel name before the colon
ifo = args.channel_name.split(":")[0]
proc_row = register_to_xmldoc(xmldoc, __program__, args.__dict__, ifos=[ifo],version=glue.git_version.id, cvs_repository=glue.git_version.branch, cvs_entry_time=glue.git_version.date)

# Figure out the data we actually analyzed
outseg = determine_output_segment(inseg, dt_stride, args.sample_rate, window_fraction)

# Record the same instrument as the process table, rather than a fixed "H1"
# that is wrong for data from any other detector
ss = append_search_summary(xmldoc, proc_row, ifos=(ifo,), inseg=inseg, outseg=outseg)

# Fill in the columns that are common to every event
for sb in event_list:
    sb.process_id = proc_row.process_id
    sb.search = proc_row.program
    sb.ifo, sb.channel = args.channel_name.split(":")

def make_filename(ifo, seg, tag="excesspower", ext="xml.gz"):
    """
    Build an output file name of the form IFO-TAG-START-DURATION.EXT.

    ifo is either a string, or an iterable of strings which are concatenated. The start of seg is rounded down and its end rounded up to integers, and DURATION is the difference of the two.
    """
    ifostr = ifo if isinstance(ifo, str) else "".join(ifo)
    gps_start = int(math.floor(seg[0]))
    gps_end = int(math.ceil(seg[1]))
    return "%s-%s-%d-%d.%s" % (ifostr, tag, gps_start, gps_end - gps_start, ext)

# Attach the event table to the LIGO_LW element and write the document out
xmldoc.childNodes[0].appendChild(event_list)
# Name the file after the instrument parsed from the channel name, rather than
# a fixed "H1" that mislabels data from any other detector
fname = make_filename(ifo, inseg)
utils.write_filename(xmldoc, fname, gz=fname.endswith("gz"), verbose=True)
