import datetime as dt
import xarray as xr
import pandas as pd
import math

import plotly.io as pio
import plotly.graph_objects as go
import dash_core_components as dcc


# Helper functions to calculate a bounding box around a given latitude/longitude coordinate, adapted from Stack Overflow:
# https://stackoverflow.com/questions/238260/how-to-calculate-the-bounding-box-for-a-given-lat-lng-location

def deg2rad(degrees):
    """Return the angle ``degrees`` expressed in radians."""
    half_turn_in_degrees = 180.0
    return math.pi * degrees / half_turn_in_degrees


def rad2deg(radians):
    """Return the angle ``radians`` expressed in degrees."""
    half_turn_in_degrees = 180.0
    return half_turn_in_degrees * radians / math.pi


# Semi-axes of the WGS-84 reference ellipsoid, in metres
WGS84_a = 6_378_137.0  # semi-major (equatorial) axis [m]
WGS84_b = 6_356_752.3  # semi-minor (polar) axis [m], rounded to 0.1 m


def wgs84_earth_radius(lat):
    """
    Radius of the WGS-84 ellipsoid [m] at latitude ``lat`` (given in radians).

    Evaluates R = sqrt(((a^2 cos(lat))^2 + (b^2 sin(lat))^2) / ((a cos(lat))^2 + (b sin(lat))^2))
    with a = WGS84_a and b = WGS84_b,
    see: http://en.wikipedia.org/wiki/Earth_radius
    """
    cos_lat = math.cos(lat)
    sin_lat = math.sin(lat)

    # numerator terms: a^2 cos(lat) and b^2 sin(lat)
    num_cos = WGS84_a * WGS84_a * cos_lat
    num_sin = WGS84_b * WGS84_b * sin_lat
    # denominator terms: a cos(lat) and b sin(lat)
    den_cos = WGS84_a * cos_lat
    den_sin = WGS84_b * sin_lat

    numerator = num_cos * num_cos + num_sin * num_sin
    denominator = den_cos * den_cos + den_sin * den_sin
    return math.sqrt(numerator / denominator)


def bounding_box(latitude_in_degrees, longitude_in_degrees, half_side_in_km):
    """
    Square box of side ``2 * half_side_in_km`` centred on the given point.

    The Earth surface is locally approximated by a sphere whose radius is the WGS-84 radius at
    the point's latitude; the longitude half-width uses the radius of the parallel at that
    latitude.

    :param latitude_in_degrees: latitude of the centre point [deg]
    :param longitude_in_degrees: longitude of the centre point [deg]
    :param half_side_in_km: half of the side length of the box [km]
    :return: tuple ``(lat_min, lon_min, lat_max, lon_max)`` in degrees. Values are neither
             clamped nor wrapped, so they can leave the +/-90 / +/-180 ranges, and the longitude
             span grows without bound towards the poles (cos(lat) -> 0).
    """
    centre_lat = deg2rad(latitude_in_degrees)
    centre_lon = deg2rad(longitude_in_degrees)
    half_side_m = 1000 * half_side_in_km

    earth_radius = wgs84_earth_radius(centre_lat)
    parallel_radius = earth_radius * math.cos(centre_lat)

    delta_lat = half_side_m / earth_radius
    delta_lon = half_side_m / parallel_radius

    return (
        rad2deg(centre_lat - delta_lat),
        rad2deg(centre_lon - delta_lon),
        rad2deg(centre_lat + delta_lat),
        rad2deg(centre_lon + delta_lon),
    )


def assign_source(x):
    """Return the observation source ``x`` of a fire event, or ``'Media'`` when it is ``None``."""
    return 'Media' if x is None else x


def to_datetime(x):
    """Parse the ``'DATE'`` entry of record ``x`` (formatted ``YYYY-MM-DD``) into a datetime."""
    date_text = x['DATE']
    return dt.datetime.strptime(date_text, "%Y-%m-%d")


def create_hover_label(x):
    """Build the hover text for a fire event point from the ``'SOURCE'`` entry of record ``x``."""
    return 'Fire Event reported by {}'.format(x["SOURCE"])


def get_weekday(x):
    """Return the day of the week of timestamp ``x`` (``x.weekday()``: Monday=0 ... Sunday=6)."""
    day_index = x.weekday()
    return day_index


def get_hour(x):
    """Return the hour of timestamp ``x`` as a zero-padded two-character string (``'%H'``)."""
    hour_format = '%H'
    return x.strftime(hour_format)


def get_date_and_hour(x):
    """Return the ``'MM DD HH:MM'`` label (no year) of the ``'time'`` entry of row ``x``."""
    moment = x['time']
    return moment.strftime('%m %d %H:%M')


def get_historical_baseline(ds, timestamp, pollutant, fire_mask_flag=True,
                            fire_mask_path="/var/www/FlaresApp/FlaresApp/data/cams/fire_mask.nc"):
    """
    Get the average concentration for the area of interest during the same period in other years.

    For every year from 2015 to 2020 except the year of the fire, the window of +/- 7 days
    around the same calendar date is selected from ``ds``. It is optionally masked with the fire
    mask and averaged over latitude and longitude. Each yearly series is aligned on its offset
    from the anchor date of its own year. The series are then averaged per offset and placed on
    the fire year's timeline. Aligning on offsets, rather than on 'MM DD HH:MM' labels, keeps
    windows that cross New Year in chronological order. It also works for any time step of the
    data.

    :param ds: xarray DataArray with ``time``, ``latitude`` and ``longitude`` dimensions
    :param timestamp: datetime of the fire event
    :param pollutant: name of the pollutant; not needed for the computation, kept for
                      backwards compatibility
    :param fire_mask_flag: if True, values are kept only where the fire mask is truthy
                           (``DataArray.where``) before averaging
    :param fire_mask_path: NetCDF file holding the ``fire_mask`` variable
    :return: DataFrame with columns ``pollutant_conc`` and ``time``; empty when no other year
             has data in the window
    """
    window = dt.timedelta(days=7)
    window_start = timestamp.replace(hour=0, minute=0)
    all_years = [2015, 2016, 2017, 2018, 2019, 2020]

    # (anchor date in that year, data within +/- window of it) for every other year with data
    windows = []
    for other_year in all_years:
        if other_year == timestamp.year:
            continue
        try:
            anchor = window_start.replace(year=other_year)
        except ValueError:
            # 29 February does not exist in a non-leap year: fall back to 28 February
            anchor = window_start.replace(year=other_year, day=28)
        subset = ds.sel(time=slice(anchor - window, anchor + window))
        if subset.time.size != 0:
            windows.append((anchor, subset))

    if fire_mask_flag and windows:
        # load the mask eagerly so that the file handle is released right away
        with xr.open_dataset(fire_mask_path) as mask_file:
            fire_mask = mask_file['fire_mask'].load()
        windows = [(anchor, subset.where(fire_mask)) for anchor, subset in windows]

    frames = []
    for anchor, subset in windows:
        spatial_mean = subset.mean(dim=['latitude', 'longitude']).to_series()
        frames.append(pd.DataFrame({
            'offset': spatial_mean.index - pd.Timestamp(anchor),
            'concentration': spatial_mean.to_numpy(),
        }))

    if not frames:
        # no other year covers this period: return an empty (but well-typed) frame
        return pd.DataFrame({
            'pollutant_conc': pd.Series(dtype='float64'),
            'time': pd.Series(dtype='datetime64[ns]'),
        })

    per_offset = pd.concat(frames, ignore_index=True).groupby('offset')['concentration'].mean()
    return pd.DataFrame({
        'pollutant_conc': per_offset.to_numpy(),
        'time': pd.Timestamp(window_start) + per_offset.index,
    })


def _to_dataset_longitude(longitude, dataset_longitudes):
    """Express ``longitude`` in the convention (0..360 or -180..180) of ``dataset_longitudes``."""
    if float(dataset_longitudes.max()) > 180:
        return longitude % 360
    return (longitude + 180) % 360 - 180


def _load_area_of_interest(pollutant, lat, lon, half_side_km):
    """Load ground-level CAMS data for ``pollutant`` inside the bounding box around (lat, lon)."""
    # NOTE(review): ``pollutant`` is interpolated into a file path; make sure callers only pass
    # known pollutant names.
    path = f"/var/www/FlaresApp/FlaresApp/data/cams/{pollutant}.nc"
    with xr.open_dataset(path) as dataset:
        data = dataset[pollutant]  # select data belonging to the pollutant of interest
        if 'level' in data.dims:
            data = data.sel(level=0)  # keep only the ground level when other levels are present

        min_lat, min_lon, max_lat, max_lon = bounding_box(lat, lon, half_side_km)
        lon_lo = _to_dataset_longitude(min_lon, data.longitude)
        lon_hi = _to_dataset_longitude(max_lon, data.longitude)
        if lon_lo <= lon_hi:
            mask_lon = (data.longitude >= lon_lo) & (data.longitude <= lon_hi)
        else:
            # the box straddles the seam of the longitude axis (e.g. 359.x -> 0.x)
            mask_lon = (data.longitude >= lon_lo) | (data.longitude <= lon_hi)
        mask_lat = (data.latitude >= min_lat) & (data.latitude <= max_lat)

        # materialise the cropped region so the file can be closed when leaving the block
        return data.where(mask_lon & mask_lat, drop=True).load()


def _with_hour_and_weekday(series, column):
    """Turn a time-indexed Series into a frame with ``time``, ``column``, ``hour`` and ``weekday``."""
    frame = series.rename(column).rename_axis('time').reset_index()
    frame['hour'] = frame['time'].dt.strftime('%H')  # same '%H' string as get_hour
    frame['weekday'] = frame['time'].dt.weekday  # Monday=0, same as get_weekday
    return frame


def create_fe_plot(lat, lon, timestamp, pollutant):
    """
    Create a Dash graph comparing concentrations of a pollutant around a fire event.

    The figure shows four lines:
    - the CAMS grid cell nearest to the fire;
    - the average over the ~100x100 km box around it;
    - a baseline per (weekday, hour) computed over the full time range of that box;
    - the average for the same period in other years (``get_historical_baseline``).

    :param lat: float, latitude of the fire event
    :param lon: float, longitude of the fire event (-180..180 or 0..360)
    :param timestamp: datetime, timestamp of when the fire event occurred
    :param pollutant: name of the pollutant variable (and of its NetCDF file) to plot
    :return: dash ``dcc.Graph`` with id ``fe_plot`` wrapping the plotly figure
    """
    ds = _load_area_of_interest(pollutant, lat, lon, 50)

    begin_fire_event = timestamp - dt.timedelta(hours=24)
    end_fire_event = timestamp + dt.timedelta(hours=24)

    # select data for the two week period around the fire event
    df_fe_aoi = ds.sel(
        time=slice(
            begin_fire_event.replace(hour=0, minute=0) - dt.timedelta(days=6),
            end_fire_event.replace(hour=0, minute=0) + dt.timedelta(days=6)
        )
    )

    # select data from the CAMS grid cell closest to the fire event; the requested longitude must
    # use the same convention as the dataset's longitude axis or "nearest" picks the box edge
    df_fe_loc = df_fe_aoi.sel(
        latitude=lat,
        longitude=_to_dataset_longitude(lon, ds.longitude),
        method='nearest'
    )

    # baseline: average of the area mean per (weekday, hour) over the full time range in ``ds``
    area_mean_df = _with_hour_and_weekday(ds.mean(dim=['latitude', 'longitude']).to_series(), pollutant)
    baseline_df = (
        area_mean_df.groupby(['weekday', 'hour'])[pollutant].mean()
        .rename(f'{pollutant}_mean')
        .reset_index()
    )

    # values of the CAMS cell closest to the fire
    fe_val_per_hour_df = df_fe_loc.to_series().rename(pollutant).rename_axis('time').reset_index()

    # mean value per time step for the area around the fire event, joined with the baseline
    mean_val_per_hour_aoi_df = _with_hour_and_weekday(
        df_fe_aoi.mean(dim=['latitude', 'longitude']).to_series(), pollutant
    ).merge(baseline_df, on=['hour', 'weekday'], how='left')

    # average concentration levels for the area of interest during the other years
    hist_baseline = get_historical_baseline(ds, timestamp, pollutant, fire_mask_flag=True)

    # create the plotly figure to plot the line graph of the concentration levels for the pollutant
    fig = go.Figure()
    fig.update_layout(
        template=pio.templates["plotly_dark"],  # use the standard plotly dark template
        xaxis_title="Date",
        yaxis_title="Pollutant Concentration µg m<sup>-3</sup>",
        legend=dict(yanchor="top", y=0.99, xanchor="left", x=0.01),  # position the legend
    )

    colors = ["#29bf12", "#abff4f", "#08bdbd", "#f21b3f", "#ff9914"]

    fig.add_trace(go.Scatter(  # CAMS concentration in the grid cell closest to the fire event
        x=fe_val_per_hour_df['time'],
        y=fe_val_per_hour_df[pollutant],
        mode='lines',
        name='Exact Location',
        line={'color': colors[3]},
    ))
    fig.add_trace(go.Scatter(  # average CAMS concentration for the surrounding area
        x=mean_val_per_hour_aoi_df['time'],
        y=mean_val_per_hour_aoi_df[pollutant],
        mode='lines',
        name='Surrounding Area (100x100km)',
        line={'color': colors[4]},
    ))
    fig.add_trace(go.Scatter(  # weekday/hour baseline
        x=mean_val_per_hour_aoi_df['time'],
        y=mean_val_per_hour_aoi_df[f'{pollutant}_mean'],
        mode='lines',
        name='Baseline',
        line={'color': colors[2]},
    ))
    fig.add_trace(go.Scatter(  # same period in other years
        x=hist_baseline['time'],
        y=hist_baseline['pollutant_conc'],
        mode='lines',
        name='Average Concentration Levels for the same period (2015-2020)',
        line={'color': colors[1]},
    ))

    # y-axis range from all plotted values; Series.max()/min() skip NaN, so an empty or all-NaN
    # trace does not turn the whole range into NaN
    plotted = [
        fe_val_per_hour_df[pollutant],
        mean_val_per_hour_aoi_df[pollutant],
        mean_val_per_hour_aoi_df[f'{pollutant}_mean'],
        hist_baseline['pollutant_conc'],
    ]
    max_val = pd.Series([values.max() for values in plotted], dtype='float64').max()
    min_val = pd.Series([values.min() for values in plotted], dtype='float64').min()

    val_range = max_val - min_val
    offset = val_range * 0.45  # padding to prevent the plot from being too packed together

    fig.update_layout(yaxis_range=[min_val - offset, max_val + offset])

    # add a rectangle to illustrate the observation period related to the fire
    fig.add_vrect(x0=begin_fire_event, x1=end_fire_event,
                  fillcolor="orange", opacity=0.25, line_width=0)
    # add text to indicate that the rectangle illustrates the fire
    fig.add_annotation(text='Fire', x=begin_fire_event + dt.timedelta(hours=24), y=min_val - (offset * 0.5),
                       showarrow=False)

    # remove the margins
    fig.update_layout(margin={"r": 0, "t": 0, "l": 0, "b": 0})

    # return the figure as a dcc graph object
    return dcc.Graph(id='fe_plot', figure=fig)
