import geopandas as gpd
from shapely.geometry import shape
import pandas as pd
import json
import math

# Load every POI once at import time. build_isochrone_df reads this
# module-level GeoDataFrame directly (for its CRS and for the spatial join).
all_pois = gpd.read_file("./poidata/dc_poi_merged.geojson")
# Canonical amenity category labels used throughout the analysis.
amenity_type = ['Theater', 'Cafe', 'Restaurant', 'Hospitals', 'Parks', 'Pharmacy', 'Schools', 'Stadium',
                'Worship', 'Fitness', 'Market']
# Maps raw values of the POI file's 'amenity_type' column to the canonical
# labels above. Series.replace only rewrites exact matches; any other value is
# left unchanged (and later gets weight 0 because it is not a key of `beta`).
dict_replace = {
    'school': 'Schools',
    'restaurant': 'Restaurant',
    'stadium': 'Stadium',
    'supermarket': 'Market',
    'place_of_worship': 'Worship',
    'pharmacy': 'Pharmacy',
    'fitness_centre': 'Fitness',
    'theatre': 'Theater',
    'cafe': 'Cafe',
    'park': 'Parks',
    'hospital': 'Hospitals'
}

all_pois['amenity_type'] = all_pois['amenity_type'].replace(dict_replace)
# Per-category importance weights; the numerators sum to 23, so the weights sum
# to 1. apply_weight uses 0 for any category missing here. calc_mi builds its
# per-amenity columns from these keys, so their order is the output column order.
beta = {
    'Theater': 1/23, 'Cafe': 1/23, 'Restaurant': 2/23, 'Hospitals': 3/23, 'Parks': 2/23, 'Pharmacy': 3/23,
    'Schools': 3/23, 'Stadium': 1/23, 'Worship': 2/23, 'Fitness': 2/23, 'Market': 3/23
}

output_path = r'./isochrone/dc/' # directory holding the precomputed isochrone files: {output_path}{mode}/{name}_dc_{time}.json
# boun_cent = boun_cent  # Your boun_cent DataFrame
# all_pois = all_pois  # Your POI GeoDataFrame
# modes = ['drive', 'bicycle', 'walk', 'transit']
# Candidate travel-time budgets (presumably minutes); not referenced by the code in this section.
time_ranges = [10, 20, 30, 40, 50, 60]
# Travel-time budget per mode for the single scenario being evaluated.
# calc_mi prices 'drive' and 'bicycle' from drive_time / bicycle_time, and
# calc_gini records all four values in its result row.
drive_time = 20
bicycle_time = 10
walk_time = 10
transit_time = 30
neigh_name = 'NBH_NAMES' # name of the neighbourhood-name column in the DC boundary layer

def _clean_names(names):
    """Return *names* (a string Series) with '-' and '/' replaced by spaces."""
    return names.str.replace('-', ' ', regex=False).str.replace('/', ' ', regex=False)


def calc_boun(boundary_path, borough=None):
    """Load neighbourhood boundaries and derive polygon and centroid tables.

    Parameters
    ----------
    boundary_path : str
        Any vector file readable by ``geopandas.read_file``; it must contain the
        ``neigh_name`` column and a CRS.
    borough : optional
        When given, keep only rows whose ``'borough'`` column equals this value.

    Returns
    -------
    boun_poly : GeoDataFrame
        Columns ``[neigh_name, 'geometry']``, reprojected to EPSG:4326.
    boun_cent : DataFrame
        Columns ``[neigh_name, 'lat', 'lon']``: the centroid of each polygon.

    In both outputs, '-' and '/' in neighbourhood names become spaces, so the
    names can be embedded in file paths (build_isochrone_df uses them that way).
    """
    boundary = gpd.read_file(boundary_path)
    if borough is not None:
        boundary = boundary[boundary['borough'] == borough]
    # Reproject the whole frame so the frame-level CRS matches the geometry.
    boundary = boundary.to_crs(epsg=4326)

    # .copy(): the assignment below must act on an independent frame, not on a
    # view of `boundary` (avoids pandas SettingWithCopyWarning).
    boun_poly = boundary[[neigh_name, 'geometry']].copy()
    boun_poly[neigh_name] = _clean_names(boun_poly[neigh_name])

    # NOTE(review): centroids are computed in geographic coordinates
    # (EPSG:4326), for which geopandas warns the result is approximate. Kept
    # as is to preserve the existing outputs; consider computing them in a
    # projected CRS and converting back.
    centroids = boundary.geometry.centroid
    boun_cent = pd.DataFrame({
        neigh_name: _clean_names(boundary[neigh_name]),
        'lat': centroids.y,
        'lon': centroids.x,
    })
    return boun_poly, boun_cent

def calc_census(census_path, boun_poly):
    """Aggregate census-unit income and population to neighbourhoods.

    Each census geometry is reduced to its centroid. The centroid is assigned to
    the ``boun_poly`` polygon that contains it (spatial join, predicate
    'within'). Per neighbourhood, income is averaged and population is summed.

    Parameters
    ----------
    census_path : str
        Vector file with 'GEOID', 'B19013_001E' (renamed to 'income') and
        'B01003_001E' (renamed to 'population') columns and a CRS.
    boun_poly : GeoDataFrame
        Neighbourhood polygons with a ``neigh_name`` column and a CRS, e.g. as
        returned by calc_boun.

    Returns
    -------
    (min_income, max_income, merged_geoid_comm_data)
        The min and max of the per-neighbourhood mean income, and a GeoDataFrame
        with columns 'name' (renamed from ``neigh_name``), 'income' (mean),
        'geometry', 'population' (sum), 'lat' and 'lon' (polygon centroid).
        Neighbourhoods containing no census centroid are not included.
    """
    census_data = gpd.read_file(census_path).to_crs(epsg=4326)
    census_data = census_data.rename(columns={'B19013_001E': 'income', 'B01003_001E': 'population'})

    # NOTE(review): centroids in EPSG:4326 are approximate (geopandas warns);
    # kept to preserve the existing point-in-polygon assignment.
    census_data['cent'] = census_data.geometry.centroid
    census_points = (
        census_data[['GEOID', 'cent', 'income', 'population']]
        .set_geometry('cent')
        .to_crs(boun_poly.crs)
    )

    # `predicate` replaces the `op` keyword, which was deprecated in geopandas
    # 0.10 and removed in 1.0.
    geoid_comm_data = gpd.sjoin(census_points, boun_poly, how="left", predicate='within')

    by_community = geoid_comm_data.groupby(neigh_name)
    average_income = by_community['income'].mean().reset_index()
    total_population = by_community['population'].sum().reset_index()

    merged_geoid_comm_data = gpd.GeoDataFrame(
        pd.merge(average_income, boun_poly, on=neigh_name, how='left'),
        geometry='geometry',
    )
    merged_geoid_comm_data = pd.merge(merged_geoid_comm_data, total_population, on=neigh_name, how='left')

    min_income = merged_geoid_comm_data['income'].min()
    max_income = merged_geoid_comm_data['income'].max()

    polygon_centroids = merged_geoid_comm_data.geometry.centroid
    merged_geoid_comm_data['lat'] = polygon_centroids.y
    merged_geoid_comm_data['lon'] = polygon_centroids.x

    merged_geoid_comm_data = merged_geoid_comm_data.rename(columns={neigh_name: 'name'})
    return min_income, max_income, merged_geoid_comm_data


def read_isochrone(file_path):
    """Return the geometry of the first feature in a GeoJSON isochrone file.

    Parameters
    ----------
    file_path : str
        Path to a GeoJSON FeatureCollection.

    Returns
    -------
    shapely geometry, or None
        None when the collection has no features or the first feature's
        geometry is null. build_isochrone_df skips None results.

    Raises
    ------
    FileNotFoundError, json.JSONDecodeError, KeyError
        Missing file, invalid JSON, or a document without a 'features' member.
    """
    # GeoJSON is UTF-8 by specification; do not rely on the platform default.
    with open(file_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    features = data['features']
    if not features or features[0].get('geometry') is None:
        return None
    return shape(features[0]['geometry'])

def apply_weight(row):
    """Return the beta-weighted POI count for one row of a POI-count table.

    ``row`` must provide 'amenity_type' and 'poi_count'. Amenity types that are
    not keys of the module-level ``beta`` dict get weight 0.
    """
    return row['poi_count'] * beta.get(row['amenity_type'], 0)

def _as_time_list(value):
    """Return *value* if it is a non-string iterable of times, else ``[value]``."""
    if isinstance(value, (str, bytes)):
        return [value]
    try:
        iter(value)
    except TypeError:
        return [value]
    return value


def build_isochrone_df(boun_cent, modes, times, output_path=output_path):
    """Count POIs by amenity type inside each neighbourhood's isochrones.

    For every neighbourhood in ``boun_cent`` and every (mode, travel time)
    pair, this reads ``f"{output_path}{mode}/{name}_dc_{time}.json"`` via
    read_isochrone. It then spatially joins the module-level ``all_pois``
    against the loaded polygons (predicate 'within').

    Parameters
    ----------
    boun_cent : DataFrame
        Must contain the ``neigh_name`` column (as returned by calc_boun).
    modes : sequence of str
        Travel modes, e.g. ['drive', 'bicycle', 'walk', 'transit']. Each one is
        also a sub-directory of ``output_path``.
    times : sequence
        One entry per mode, in the same order as ``modes``. Each entry is a
        single travel time or an iterable of travel times.
    output_path : str
        Directory prefix (including trailing separator) of the isochrone files.

    Returns
    -------
    DataFrame
        Columns 'name', 'mode', 'time_range', 'amenity_type', 'poi_count'.

    Raises
    ------
    ValueError
        If no isochrone polygon could be loaded.
    """
    records = []
    for name in boun_cent[neigh_name]:
        for mode, mode_times in zip(modes, times):
            for ti in _as_time_list(mode_times):
                polygon = read_isochrone(f"{output_path}{mode}/{name}_dc_{ti}.json")
                if polygon is not None:
                    records.append({'name': name, 'mode': mode, 'time_range': ti, 'geometry': polygon})

    if not records:
        raise ValueError("No isochrone polygons were loaded; check output_path, modes and times.")

    # Build all isochrones in one frame, in the POI CRS so the join is consistent.
    isochrones = gpd.GeoDataFrame(records, geometry='geometry', crs=all_pois.crs)

    # Only the amenity type is needed from the POI side. Restricting the columns
    # avoids suffixing clashes if the POI file has its own 'name'/'mode' columns.
    pois = all_pois[['amenity_type', all_pois.geometry.name]]
    joined = gpd.sjoin(pois, isochrones, how='inner', predicate='within')

    poi_counts = (
        joined.groupby(['name', 'mode', 'time_range', 'amenity_type'])
        .size()
        .reset_index(name='poi_count')
    )
    return poi_counts

def calc_gini(combined_data, df, times=None):
    """Compute the population-weighted Gini of MI and append a result row.

    The Gini is::

        G = sum_i sum_j w_i w_j |MI_i - MI_j| / (2 * sum_i w_i * sum_i w_i MI_i)

    with ``w = combined_data['wi']`` and ``MI = combined_data['MI']``.
    MEM is ``1 - G``. The result row is printed.

    Parameters
    ----------
    combined_data : DataFrame
        Must contain numeric 'MI' and 'wi' columns (as produced by calc_mi).
        Any index is accepted.
    df : DataFrame
        Accumulated results. A new frame with one extra row is returned; ``df``
        itself is not modified.
    times : mapping, optional
        Travel times to record, keyed by 'drive', 'bicycle', 'walk' and
        'transit'. Defaults to the module-level drive_time, bicycle_time,
        walk_time and transit_time.

    Returns
    -------
    DataFrame
        ``df`` plus a row with keys 'time.drive', 'time.bicycle', 'time.walk',
        'time.transit', 'weighted.gini' and 'MEM'. 'weighted.gini' is NaN when
        the denominator is zero (no rows, or every weighted MI is zero).
    """
    if times is None:
        times = {'drive': drive_time, 'bicycle': bicycle_time,
                 'walk': walk_time, 'transit': transit_time}

    # Positional lists: the old mix of label-based .loc[i] and positional
    # .iloc[i] only agreed when the index was 0..n-1.
    mi_values = combined_data['MI'].tolist()
    weights = combined_data['wi'].tolist()
    pairs = list(zip(weights, mi_values))

    # Same i-major / j-minor summation order as the original double loop.
    gini_numerator = sum(
        w_i * w_j * abs(mi_i - mi_j)
        for w_i, mi_i in pairs
        for w_j, mi_j in pairs
    )
    gini_denominator = (
        2
        * combined_data['wi'].sum()
        * (combined_data['wi'] * combined_data['MI']).sum()
    )

    if gini_denominator == 0:
        weighted_gini = math.nan
    else:
        weighted_gini = gini_numerator / gini_denominator

    mem = 1 - weighted_gini
    record = {'time.drive': times['drive'], 'time.bicycle': times['bicycle'],
              'time.walk': times['walk'], 'time.transit': times['transit'],
              'weighted.gini': weighted_gini, 'MEM': mem}
    print(record)
    return pd.concat([df, pd.DataFrame([record])], ignore_index=True)

def min_max_scaling(x, min_income, max_income):
    """Scale *x* into [0, 1] inversely: ``max_income`` maps to 0, ``min_income`` to 1.

    Values outside the bounds fall outside [0, 1]. If ``max_income ==
    min_income`` the expression divides by zero.
    """
    income_span = max_income - min_income
    gap_below_top = max_income - x
    return gap_below_top / income_span

def calculate_drive_cost(time_range):
    """Return the driving cost charged for a trip of ``time_range``.

    Distance is ``time_range / 60 * 25``: presumably minutes converted to hours
    at an assumed 25 mph. It is priced at 0.67 per mile, and the total is then
    scaled by 0.75.
    NOTE(review): the sources of 25, 0.67 and 0.75 are not documented here.
    """
    trip_hours = time_range / 60
    trip_miles = trip_hours * 25
    return trip_miles * 0.67 * 0.75

def calculate_bicycle_cost(time_range):
    """Return the bike-share cost charged for a ride of ``time_range``.

    The base charge is ``95 / 365`` (an annual fee spread per day). Rides up to
    45 time units are covered by the base charge; each unit beyond 45 adds 0.05.
    """
    included_span = 45
    daily_share = 95 / 365
    if time_range <= included_span:
        return daily_share
    overage_charge = (time_range - included_span) * 0.05
    return daily_share + overage_charge

def _mode_cost(mode, cost):
    """Return the travel cost charged for one mode.

    'drive' and 'bicycle' are priced from the module-level drive_time and
    bicycle_time. Every other mode is looked up in ``cost`` (KeyError if it is
    missing).
    """
    if mode == 'drive':
        return calculate_drive_cost(drive_time)
    if mode == 'bicycle':
        return calculate_bicycle_cost(bicycle_time)
    return cost[mode]


def _community_mobility(community_data, kappa, cost):
    """Return ``(MI, per-amenity share dict)`` for one community's POI counts.

    ``community_data`` holds that community's rows of poi_counts, or None when
    it has no POIs in any isochrone; in that case MI is 0 and all shares are 0.
    Shares are keyed by the ``beta`` categories.
    """
    shares = dict.fromkeys(beta, 0)
    community_mi = 0
    if community_data is None:
        return community_mi, shares

    for mode in community_data['mode'].unique():
        mode_data = community_data[community_data['mode'] == mode].copy()
        mode_data['w_poi_count'] = mode_data.apply(apply_weight, axis=1)
        # Cost discount: exp(-kappa * cost of the mode).
        discount = math.exp(-kappa * _mode_cost(mode, cost))
        mode_mi = discount * mode_data['w_poi_count'].sum()
        community_mi += mode_mi

        # Sum per amenity type first. When a mode has several time_range rows,
        # the same amenity appears more than once, and a plain dict(zip(...))
        # would keep only the last row.
        weighted_by_type = mode_data.groupby('amenity_type')['w_poi_count'].sum()
        for amenity, weighted in weighted_by_type.items():
            # Types outside beta have weight 0 and no share column; skip them
            # rather than raising KeyError.
            if amenity in shares:
                shares[amenity] += weighted * discount

    # Turn absolute contributions into shares of MI. When MI is 0 there is
    # nothing to apportion, so avoid dividing by zero.
    if community_mi:
        shares = {amenity: value / community_mi for amenity, value in shares.items()}
    return community_mi, shares


def calc_mi(poi_counts, merged_geoid_comm_data, min_income, max_income, kappa_inc, cost,
            csv_path="./dc_mi.csv"):
    """Compute the mobility index (MI) per community and save it to CSV.

    For each community::

        kappa = min_max_scaling(income) * kappa_inc / 4030
        MI    = sum over modes of exp(-kappa * mode_cost) * sum(beta-weighted POI counts)

    One column per ``beta`` category holds that category's share of MI.

    Parameters
    ----------
    poi_counts : DataFrame
        Output of build_isochrone_df: 'name', 'mode', 'time_range',
        'amenity_type', 'poi_count'.
    merged_geoid_comm_data : DataFrame
        Output of calc_census: needs 'name', 'income', 'population',
        'geometry', 'lat' and 'lon'.
    min_income, max_income : float
        Bounds for the (inverted) income scaling.
    kappa_inc : float
        Cost-sensitivity scale.
    cost : mapping
        Fixed cost for every mode other than 'drive' and 'bicycle'.
    csv_path : str or None
        Where to write the result; None skips writing.

    Returns
    -------
    DataFrame
        Columns 'name', 'MI', one column per beta category, the remaining
        columns of ``merged_geoid_comm_data`` except 'geometry', 'lat' and
        'lon', and 'wi' (the community's share of total population). NaN MI and
        income are filled with 0.

    NOTE(review): all time_range rows of a mode are summed, while the mode's
    cost uses the module-level drive_time / bicycle_time. Confirm that
    poi_counts holds a single time per mode.
    """
    pois_by_community = {name: group for name, group in poi_counts.groupby('name')}

    mi_results = []
    for _, row in merged_geoid_comm_data.iterrows():
        community_name = row['name']
        # Inverted scaling: the lowest-income community gets 1, the highest 0.
        normalized_income = min_max_scaling(row['income'], min_income, max_income)
        # NOTE(review): 4030 is an undocumented normalization constant; confirm its source.
        kappa = normalized_income * (kappa_inc / 4030)
        community_mi, shares = _community_mobility(
            pois_by_community.get(community_name), kappa, cost)
        mi_results.append({'Community': community_name, 'MI': community_mi, **shares})

    mi_df = pd.DataFrame(mi_results).rename(columns={'Community': 'name'})
    combined_data = pd.merge(mi_df, merged_geoid_comm_data, on='name', how='inner')

    combined_data['wi'] = combined_data['population'] / combined_data['population'].sum()

    # Assign back instead of the chained `fillna(inplace=True)`, which does not
    # update the frame under pandas Copy-on-Write.
    combined_data['MI'] = combined_data['MI'].fillna(0)
    combined_data['income'] = combined_data['income'].fillna(0)

    combined_data = combined_data.drop(columns=['geometry', 'lat', 'lon'])
    if csv_path is not None:
        combined_data.to_csv(csv_path, index=False)
    return combined_data