Python 104

Building on the code from last week

import pandas as pd  # Loading the NYC dataset 
import numpy as np   # performing fast calculations to build our simulation
import plotly.express as px # Quick visuals for data exploration
from ipywidgets import widgets # sliders, buttons and visual layout of the dashboard
import plotly.graph_objects as go # plotly objects to place on our graph
import math # python mathematical functions

Recap - Method to calculate the number of people affected

# We'll set some easy defaults for this function so that we immediately have outputs which make some sense when we execute this method.
def calculate_values_SIR(removal_rate = .1
                         , infection_rate = 1.1
                         , population = 1000
                         , days = 150
                         , initial_infection  = 1

                         # we've added this additional parameter which allows us to halve the
                         # infection rate from a given day onwards to simulate drastic policy action.
                         , intervention_day = 50
                        ):
    """Simulate a discrete-time SIR epidemic with a one-off halving of the infection rate.

    Parameters
    ----------
    removal_rate : float
        Fraction of yesterday's infected who are removed today. Values above 1
        are effectively capped at 1 so the infected count cannot go negative.
    infection_rate : float
        Daily infection rate before the intervention.
    population : int
        Size of the population; also the day-0 susceptible count (the initially
        infected are not subtracted from it).
    days : int
        Length of the simulation grid. For ``days >= 2`` the returned series
        hold ``days - 1`` values: day 0 plus days 1 .. ``days - 2``.
    initial_infection : int
        Number of infected people on day 0.
    intervention_day : int
        First day on which the infection rate is halved. It stays halved (once,
        not cumulatively) for every later day.

    Returns
    -------
    dict
        ``'infected'``, ``'removed'`` and ``'susceptible'`` numpy arrays, one
        value per day, rounded to whole people with ``np.rint``.
    """
    # build out starting positions (day 0)
    y_susceptible = [population]
    y_infected = [initial_infection]
    y_removed = [0]

    # Compute the post-intervention rate once from the ORIGINAL rate. Reassigning
    # infection_rate inside the loop would halve it again on every later day.
    post_intervention_rate = infection_rate / 2

    for day in range(1, days - 1):

        # from the intervention day onwards, use the halved infection rate
        current_rate = post_intervention_rate if day >= intervention_day else infection_rate

        prev_susceptible = y_susceptible[day-1]
        prev_infected = y_infected[day-1]

        # new infections can never exceed the number of people still susceptible
        daily_infected = min(current_rate * prev_infected * prev_susceptible / population, prev_susceptible)
        # ...and removals can never exceed the number of people currently infected
        daily_removed = min(removal_rate * prev_infected, prev_infected)

        y_susceptible.append( prev_susceptible - daily_infected )
        y_infected.append(    prev_infected + daily_infected - daily_removed )
        y_removed.append(   y_removed[day-1] + daily_removed )

    return {'infected': np.rint(y_infected)
            , 'removed' : np.rint(y_removed)
            , 'susceptible': np.rint(y_susceptible)}
        
    

Recap - Build widgets to control the model inputs

# This slider works with floating point numbers (hence being called Float Slider) and
# allows us to set a variable with this slider. This is going to be the way we set the
# infection rate.
ir = widgets.FloatSlider(
                value=1.187, # this is the initial value of our slider when it appears
                min=0.0,     # the minimum value we'll allow
                max=5.0,     # the maximum value we'll allow
                step=.001,   # by what increments the slider will change when we move it
                description='Infection_rate:', # the name of the slider
                continuous_update=False # False: the slider waits until it stops moving before
                                        # calling its update function, instead of calling the
                                        # update function continuously while being dragged.
)

rr = widgets.FloatSlider(
                value=.46,
                min=0.1,     # this is set to greater than 0 because this is the denominator in the R0 calculation
                max=2.0,
                step=.01,
                description='Removal_Rate:',
                continuous_update=False
)


ii = widgets.IntSlider(
                value=1,
                min=1,
                max=50,
                step=1,
                description='Initially Infected:',
                continuous_update=False
)

ip = widgets.IntSlider(
                value=1000,
                min=500,
                max=10_000_000,
                step=500,
                description='Initial Population:',
                continuous_update=False
)


iday = widgets.IntSlider(
                value=15,
                min=1,
                # The model is only run for 150 days, so an intervention later than that
                # would have no effect and its marker line would fall off the chart.
                max=150,
                step=1,
                description='Day of intervention (reducing infection rate):',
                continuous_update=False
)


# Rates on the first row; initial conditions and intervention timing on the second.
first_slider_group = widgets.HBox(children=[ir, rr])
second_slider_group = widgets.HBox(children=[ii, ip, iday])

Recap - Create the first graph with default values

# First, we use the method created above to calculate a model using the initial
# values of the sliders we just created. Given that at this point we haven't
# displayed the sliders yet, their values will be the default values we set above.

data = calculate_values_SIR(  removal_rate = rr.value
                            , infection_rate = ir.value
                            , population = ip.value
                            , days = 150
                            , initial_infection  = ii.value
                            , intervention_day = iday.value
                        )

# Next we add all the data traces to the chart. Each bar trace gets exactly one
# x value (day 1 .. len(series)) per y value, so no day is silently dropped.

infected_trace =  go.Bar(x = list(range(1, len(data['infected']) + 1))
              ,y = data['infected']
              , name='Infected'
              , marker = dict(color='red')
              )

susceptible_trace = go.Bar(x = list(range(1, len(data['susceptible']) + 1))
              , y = data['susceptible']
              , name='Susceptible'
              , marker = dict(color='rgba(0,0,255,0.5)')
              , opacity=0.5
              )

removed_trace = go.Bar(x = list(range(1, len(data['removed']) + 1))
              ,y = data['removed']
              , name='Removed'
              , marker = dict(color='rgba(0,128,0,0.5)')
              , opacity=0.5)

# This trace is interesting as it's basically just drawing a straight line on the
# selected intervention day.

intervention_day = go.Scatter(x = [iday.value, iday.value]
              ,y = [0, ip.value]
              , name='Intervention day'
              , marker = dict(color='darkblue')
              , line = dict(width=5)
  )

# We create our figure adding all the traces we created to the data list, and setting some layout values in the layout parameter.
# The intervention halves the infection rate, so the post-intervention R0 is
# infection_rate / (2 * removal_rate).
g = go.FigureWidget(data=[ infected_trace, removed_trace, susceptible_trace, intervention_day ],
                    layout=go.Layout(
                         title={
                                'text': f'R0 = {ir.value / rr.value} <br /> Post-Intervention R0: {ir.value / (2 * rr.value)} <br />Infection_rate={ir.value} Removal_rate={rr.value}',
                                'y':.95,
                                'x':0.5,
                                'xanchor': 'center',
                                'yanchor': 'top'}
                        ,barmode='stack'
                        ,hovermode='x'
                        ,height=900
                        ,xaxis=dict(title='Number of Days')
                        ,yaxis=dict(title='Number of People')
                    ))

# This is to update the x-axis range to show only the days where we have cases.
# If infections never reach zero within the window, np.where finds nothing and
# [0][0] raises IndexError, so fall back to the whole 150-day window instead.
try:
    end_infection = np.where(data['infected']==0)[0][0]
except IndexError:
    end_infection = 150
g.update_xaxes(range=[0, end_infection])

# NOTE(review): leftover scratch assignment; nothing visible in this notebook reads `x`.
x=1

Recap - Connect the widgets to the calculation method to allow the widgets to update the graph

# This method will be called any time one of the sliders is modified. It will re-run our model calculation
# with the new values and update the data for the 4 traces we added to the figure.
def response(change):
    """Re-run the SIR model with the current slider values and redraw the figure ``g``.

    ``change`` is the notification ipywidgets passes to observers. It is not
    used, because every slider's current value is read directly.
    """
    num_days=150

    # recalculate the model using the new values defined by the sliders
    pop_values = calculate_values_SIR(removal_rate = rr.value
                                      , infection_rate=ir.value
                                      , initial_infection=ii.value
                                      , population=ip.value
                                      , days=num_days
                                      , intervention_day = iday.value)

    # Try to find the first day where we have no more infections; if there is
    # none, [0][0] raises IndexError and we show every simulated day instead.
    # This keeps the x-axis focused on the curve.
    try:
        end_infection = np.where(pop_values['infected']==0)[0][0]
    except IndexError:
        end_infection = num_days

    # batch_update collects every change below and sends them to the chart in
    # one go, which is much faster than redrawing after each assignment.
    with g.batch_update():
        # update the y-axis values from the model
        g.data[0].y = pop_values['infected']
        g.data[1].y = pop_values['removed']
        g.data[2].y = pop_values['susceptible']

        # update the x-axis values (the model returns num_days - 1 values per series)
        g.data[0].x = list(range(1,num_days))
        g.data[1].x = list(range(1,num_days))
        g.data[2].x = list(range(1,num_days))

        # Add the intervention day line
        g.data[3].y = [0         , ip.value]
        g.data[3].x = [iday.value, iday.value]

        # Update only the title text: assigning a whole dict to g.layout.title would
        # replace the title object and lose the 'y' position set when g was built.
        # The intervention halves the infection rate, so the post-intervention R0
        # is infection_rate / (2 * removal_rate).
        g.layout.title.text = f'R0 = {ir.value / rr.value} <br /> Post-Intervention R0: {ir.value / (2 * rr.value)} <br /> Infection_rate={ir.value} Removal_rate={rr.value}'

        # Change only the x-axis range; assigning dict(range=...) to g.layout.xaxis
        # would replace the axis object and drop its 'Number of Days' title.
        g.layout.xaxis.range = [0,end_infection]
        
        
        
# Hook response() up as the 'value' observer of every model-input slider, so
# moving any of them recomputes the model and redraws the chart.
for _slider in (ir, rr, ii, ip, iday):
    _slider.observe(response, names='value')
del _slider

# Stack the two slider rows above the chart; as the last expression of the
# cell, this box is what the notebook displays.
widgets.VBox([first_slider_group, second_slider_group, g])

Add an intervention effectiveness measure

We’re going to add an additional slider to allow us to control how effective the intervention measure is.


# We'll set some easy defaults for this function so that we immediately have outputs which make some sense when we execute this method.
def calculate_values_SIR(removal_rate = .1
                         , infection_rate = 1.1
                         , population = 1000
                         , days = 150
                         , initial_infection  = 1

                         # we've added these additional parameters which allow us to reduce the
                         # infection rate from a given day onwards to simulate drastic policy action.
                         , intervention_day = 50
                         , intervention_effectiveness = .46
                        ):
    """Simulate a discrete-time SIR epidemic with a one-off intervention.

    Parameters
    ----------
    removal_rate : float
        Fraction of yesterday's infected who are removed today. Values above 1
        are effectively capped at 1 so the infected count cannot go negative.
    infection_rate : float
        Daily infection rate before the intervention.
    population : int
        Size of the population; also the day-0 susceptible count (the initially
        infected are not subtracted from it).
    days : int
        Length of the simulation grid. For ``days >= 2`` the returned series
        hold ``days - 1`` values: day 0 plus days 1 .. ``days - 2``.
    initial_infection : int
        Number of infected people on day 0.
    intervention_day : int
        First day on which the reduced infection rate applies.
    intervention_effectiveness : float
        Fraction of the infection rate removed by the intervention. From
        ``intervention_day`` on, the rate is ``infection_rate * max(1 -
        intervention_effectiveness, 0)``, applied once and not compounded.

    Returns
    -------
    dict
        ``'infected'``, ``'removed'`` and ``'susceptible'`` numpy arrays, one
        value per day, rounded to whole people with ``np.rint``.
    """
    # build out starting positions (day 0)
    y_susceptible = [population]
    y_infected = [initial_infection]
    y_removed = [0]

    # Compute the post-intervention rate once from the ORIGINAL rate. Reassigning
    # infection_rate inside the loop would shrink it again on every later day.
    post_intervention_rate = infection_rate * max((1-intervention_effectiveness),0)

    for day in range(1, days - 1):

        # from the intervention day onwards, use the reduced infection rate
        current_rate = post_intervention_rate if day >= intervention_day else infection_rate

        prev_susceptible = y_susceptible[day-1]
        prev_infected = y_infected[day-1]

        # new infections can never exceed the number of people still susceptible
        daily_infected = min(current_rate * prev_infected * prev_susceptible / population, prev_susceptible)
        # ...and removals can never exceed the number of people currently infected
        daily_removed = min(removal_rate * prev_infected, prev_infected)

        y_susceptible.append( prev_susceptible - daily_infected )
        y_infected.append(    prev_infected + daily_infected - daily_removed )
        y_removed.append(   y_removed[day-1] + daily_removed )

    return {'infected': np.rint(y_infected)
            , 'removed' : np.rint(y_removed)
            , 'susceptible': np.rint(y_susceptible)}
        
    
# This slider works with floating point numbers (hence being called Float Slider) and
# allows us to set a variable with this slider. This is going to be the way we set the
# infection rate.
ir = widgets.FloatSlider(
                value=1.187, # this is the initial value of our slider when it appears
                min=0.0,     # the minimum value we'll allow
                max=5.0,     # the maximum value we'll allow
                step=.001,   # by what increments the slider will change when we move it
                description='Infection_rate:', # the name of the slider
                continuous_update=False # False: the slider waits until it stops moving before
                                        # calling its update function, instead of calling the
                                        # update function continuously while being dragged.
)

rr = widgets.FloatSlider(
                value=.46,
                min=0.1,     # this is set to greater than 0 because this is the denominator in the R0 calculation
                max=2.0,
                step=.01,
                description='Removal_Rate:',
                continuous_update=False
)


ii = widgets.IntSlider(
                value=1,
                min=1,
                max=50,
                step=1,
                description='Initially Infected:',
                continuous_update=False
)

ip = widgets.IntSlider(
                value=1000,
                min=500,
                max=10_000_000,
                step=500,
                description='Initial Population:',
                continuous_update=False
)


iday = widgets.IntSlider(
                value=15,
                min=1,
                # The model is only run for 150 days, so an intervention later than that
                # would have no effect and its marker line would fall off the chart.
                max=150,
                step=1,
                description='Day of intervention (reducing infection rate):',
                continuous_update=False
)

# Fraction of the infection rate removed by the intervention (0 = no effect,
# 1 = no new infections after the intervention day).
ie = widgets.FloatSlider(
                value=.46,
                min=0.0,
                max=1.0,
                step=.01,
                description='Intervention effectiveness:',
                continuous_update=False
)



# Rates on the first row; initial conditions and intervention timing on the second.
first_slider_group = widgets.HBox(children=[ir, rr, ie])
second_slider_group = widgets.HBox(children=[ii, ip, iday])


# First, we use the method created above to calculate a model using the initial
# values of the sliders we just created. Given that at this point we haven't
# displayed the sliders yet, their values will be the default values we set above.

data = calculate_values_SIR(  removal_rate = rr.value
                            , infection_rate = ir.value
                            , population = ip.value
                            , days = 150
                            , initial_infection  = ii.value
                            , intervention_day = iday.value
                            , intervention_effectiveness= ie.value
                        )

# Next we add all the data traces to the chart. Each bar trace gets exactly one
# x value (day 1 .. len(series)) per y value, so no day is silently dropped.

infected_trace =  go.Bar(x = list(range(1, len(data['infected']) + 1))
              ,y = data['infected']
              , name='Infected'
              , marker = dict(color='red')
              )

susceptible_trace = go.Bar(x = list(range(1, len(data['susceptible']) + 1))
              , y = data['susceptible']
              , name='Susceptible'
              , marker = dict(color='rgba(0,0,255,0.5)')
              , opacity=0.5
              )

removed_trace = go.Bar(x = list(range(1, len(data['removed']) + 1))
              ,y = data['removed']
              , name='Removed'
              , marker = dict(color='rgba(0,128,0,0.5)')
              , opacity=0.5)

# This trace is interesting as it's basically just drawing a straight line on the
# selected intervention day.

intervention_day = go.Scatter(x = [iday.value, iday.value]
              ,y = [0, ip.value]
              , name='Intervention day'
              , marker = dict(color='darkblue')
              , line = dict(width=5)
  )

# We create our figure adding all the traces we created to the data list, and setting some layout values in the layout parameter.
# The post-intervention R0 uses the same reduced rate as the model:
# infection_rate * max(1 - effectiveness, 0), divided by the removal rate.
g = go.FigureWidget(data=[ infected_trace, removed_trace, susceptible_trace, intervention_day ],
                    layout=go.Layout(
                         title={
                                'text': f'R0 = {ir.value / rr.value} <br /> Post-Intervention R0: {ir.value * max((1-ie.value),0) / rr.value} <br />Infection_rate={ir.value} Removal_rate={rr.value}',
                                'y':.95,
                                'x':0.5,
                                'xanchor': 'center',
                                'yanchor': 'top'}
                        ,barmode='stack'
                        ,hovermode='x'
                        ,height=900
                        ,xaxis=dict(title='Number of Days')
                        ,yaxis=dict(title='Number of People')
                    ))

# This is to update the x-axis range to show only the days where we have cases.
# If infections never reach zero within the window, np.where finds nothing and
# [0][0] raises IndexError, so fall back to the whole 150-day window instead.
try:
    end_infection = np.where(data['infected']==0)[0][0]
except IndexError:
    end_infection = 150
g.update_xaxes(range=[0, end_infection])

# NOTE(review): leftover scratch assignment; nothing visible in this notebook reads `x`.
x=1


# This method will be called any time one of the sliders is modified. It will re-run our model calculation
# with the new values and update the data for the 4 traces we added to the figure.
def response(change):
    """Re-run the SIR model with the current slider values and redraw the figure ``g``.

    ``change`` is the notification ipywidgets passes to observers. It is not
    used, because every slider's current value is read directly.
    """
    num_days=150

    # recalculate the model using the new values defined by the sliders
    pop_values = calculate_values_SIR(removal_rate = rr.value
                                      , infection_rate=ir.value
                                      , initial_infection=ii.value
                                      , population=ip.value
                                      , days=num_days
                                      , intervention_day = iday.value
                                      , intervention_effectiveness= ie.value
                                      )

    # Try to find the first day where we have no more infections; if there is
    # none, [0][0] raises IndexError and we show every simulated day instead.
    # This keeps the x-axis focused on the curve.
    try:
        end_infection = np.where(pop_values['infected']==0)[0][0]
    except IndexError:
        end_infection = num_days

    # batch_update collects every change below and sends them to the chart in
    # one go, which is much faster than redrawing after each assignment.
    with g.batch_update():
        # update the y-axis values from the model
        g.data[0].y = pop_values['infected']
        g.data[1].y = pop_values['removed']
        g.data[2].y = pop_values['susceptible']

        # update the x-axis values (the model returns num_days - 1 values per series)
        g.data[0].x = list(range(1,num_days))
        g.data[1].x = list(range(1,num_days))
        g.data[2].x = list(range(1,num_days))

        # Add the intervention day line
        g.data[3].y = [0         , ip.value]
        g.data[3].x = [iday.value, iday.value]

        # post-intervention infection rate, computed the same way as in the model
        ie_rate = ir.value * max((1-ie.value),0)

        # Update only the title text: assigning a whole dict to g.layout.title would
        # replace the title object and lose the 'y' position set when g was built.
        g.layout.title.text = f'R0 = {ir.value / rr.value} <br /> Post-Intervention R0: {ie_rate / rr.value} <br /> Infection_rate={ir.value} Removal_rate={rr.value}'

        # Change only the x-axis range; assigning dict(range=...) to g.layout.xaxis
        # would replace the axis object and drop its 'Number of Days' title.
        g.layout.xaxis.range = [0,end_infection]
        
        
        
# Hook response() up as the 'value' observer of every model-input slider,
# including the new effectiveness slider, so any change redraws the chart.
for _slider in (ir, rr, ii, ip, iday, ie):
    _slider.observe(response, names='value')
del _slider

# Stack the two slider rows above the chart; as the last expression of the
# cell, this box is what the notebook displays.
widgets.VBox([first_slider_group, second_slider_group, g])

Display the intervention's effectiveness against what would have happened with no intervention

Update the calculation method

To show the effect of the intervention alongside what happens after it, we first need to update the calculation method so that it computes both the intervened and the un-intervened infection figures at the same time. This is simpler than it sounds: all we need to do is keep track of how many additional people would have been infected or removed had there been no intervention, i.e. the difference between the un-intervened and the intervened numbers.

To that end, we need to keep track of two additional numbers for every day:

- The number of people we need to add to the infected group to get how many would be infected if there were no intervention.
- The number of people we need to add to the removed group to get how many would be removed if there were no intervention.

To mark these non-intervention figures in the code, we add the suffix ‘_noi’ to the relevant variable names so that we can keep track of them.

# We'll set some easy defaults for this function so that we immediately have outputs which make some sense when we execute this method.
def calculate_values_SIR(removal_rate = .1
                         , infection_rate = 1.1
                         , population = 1000
                         , days = 150
                         , initial_infection  = 1
                           
                         # NOTE(review): incubation_period, fatality_rate and immunity_duration are
                         # accepted but not used anywhere in the body below.
                         , incubation_period = 10
                         , fatality_rate = .03
                         , immunity_duration = 30
                         
                         # intervention_day: first day on which the reduced infection rate applies.
                         # intervention_effectiveness: fraction of the infection rate removed by the
                         # intervention (the rate is multiplied by max(1 - effectiveness, 0)).
                         , intervention_day = 50
                         , intervention_effectiveness = .46
                        ):
    """Simulate an SIR epidemic with an intervention and its no-intervention counterfactual.

    Alongside the intervened susceptible/infected/removed series, the model
    tracks how many *additional* people would have been infected and removed
    from ``intervention_day`` onwards had the infection rate not been reduced.

    Returns
    -------
    dict
        numpy arrays rounded with ``np.rint``, one value per day (``days - 1``
        values when ``days >= 2``):

        - ``'infected'``, ``'removed'``: the intervened scenario.
        - ``'susceptible'``: reduced each day by ``daily_infected +
          daily_infected_noi``, i.e. by the hypothetical no-intervention total of
          new infections once the intervention is active.
        - ``'infected_noi'``, ``'removed_noi'``: extra infected/removed people in
          the no-intervention scenario (0 before ``intervention_day``).
    """
    
    # build out starting positions
    # x_days is a float grid with step days / (days - 1); for the indices iterated
    # below, int(day) equals the loop index, which is used as the list position.
    x_days = np.linspace(0,days, days)
    y_susceptible = [population]
    y_infected = [initial_infection]
    y_removed = [0]
    
    # extra infected/removed people in the no-intervention scenario
    y_infected_noi = [0]
    y_removed_noi = [0]
    
    
    for day in x_days[1:days-1]:
        
        # On and after the intervention day, reduce the infection rate by the
        # intervention's effectiveness (never below 0). The rate is recomputed from
        # infection_rate every day, so the reduction is not compounded. This compares
        # the float grid value; for integer intervention days it selects the same
        # days as the integer check further down.
        if day >= intervention_day:
            inter_infection_rate = infection_rate * max((1-intervention_effectiveness),0)
        else:
            inter_infection_rate = infection_rate
            
        day = int(day)
        
        # Susceptible pool for the intervened scenario: y_susceptible (which also has
        # the hypothetical extra infections subtracted) plus those extra infected people.
        # NOTE(review): the extra removed people (y_removed_noi) are not added back
        # here; confirm that approximation is intended.
        pday_susceptible = y_susceptible[day-1] + y_infected_noi[day-1]
        
        # intervened new infections, capped at the susceptible pool
        daily_infected = min((inter_infection_rate * y_infected[day-1] * (pday_susceptible)/population), pday_susceptible)
        # NOTE(review): not capped, so removal_rate > 1 can make y_infected negative.
        daily_removed = removal_rate * y_infected[day-1]
        
        # If we're after the intervention, calculate what would have happened and update
        # the hypothetical values
        if day >= intervention_day:
            
            # hypothetical number of infected people without the intervention
            pday_infected_noi = y_infected[day-1] + y_infected_noi[day-1]
            
            # hypothetical new infections (at the un-reduced rate) minus the ones that
            # actually happened, i.e. the extra daily infections; this can be negative
            daily_infected_noi = min((infection_rate * (pday_infected_noi) * ( y_susceptible[day-1])/population), y_susceptible[day-1]) - daily_infected
            
            # hypothetical daily removals from the larger infected pool
            daily_removed_noi = (removal_rate * (pday_infected_noi)) 
            
            # extra removals beyond the intervened ones, floored at 0
            noi_removed = max(daily_removed_noi - daily_removed, 0 )
            
            # Keep track of our unintervened (extra) infected and removed values
            y_infected_noi.append(y_infected_noi[day-1] + (daily_infected_noi) - (noi_removed))
            y_removed_noi.append( y_removed_noi[day-1] + (noi_removed))
            
        else:
            # before the intervention both scenarios are identical
            daily_infected_noi = 0
            daily_removed_noi = 0
            y_infected_noi.append(0)
            y_removed_noi.append(0)
        
        y_susceptible.append( y_susceptible[day-1] - (daily_infected + (daily_infected_noi)))
        y_infected.append(    y_infected[day-1] + daily_infected - daily_removed )
        y_removed.append(   y_removed[day-1] + daily_removed )
    
    return {'infected': np.rint(y_infected)
            , 'removed' : np.rint(y_removed)
            , 'susceptible': np.rint(y_susceptible)
            , 'infected_noi': np.rint(y_infected_noi)
            , 'removed_noi': np.rint(y_removed_noi) }
    
# Quick check of the counterfactual model: current slider values, a fixed
# intervention on day 15, and the default effectiveness; then show the result.
data = calculate_values_SIR(
    removal_rate=rr.value,
    infection_rate=ir.value,
    population=ip.value,
    days=150,
    initial_infection=ii.value,
    intervention_day=15,
)



data
{'infected': array([  1.,   2.,   3.,   5.,   9.,  15.,  26.,  43.,  71., 113., 169.,
        231., 279., 285., 246., 161., 103.,  64.,  39.,  23.,  14.,   8.,
          5.,   3.,   2.,   1.,   1.,   0.,   0.,   0.,   0.,   0.,   0.,
          0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,
          0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,
          0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,
          0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,
          0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,
          0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,
          0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,
          0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,
          0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,
          0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,
          0.,   0.,   0.,   0.,   0.,   0.]),
 'removed': array([  0.,   0.,   1.,   3.,   5.,   9.,  16.,  28.,  48.,  80., 132.,
        210., 316., 444., 575., 688., 762., 810., 839., 857., 867., 874.,
        877., 880., 881., 882., 882., 882., 882., 882., 882., 883., 883.,
        883., 883., 883., 883., 883., 883., 883., 883., 883., 883., 883.,
        883., 883., 883., 883., 883., 883., 883., 883., 883., 883., 883.,
        883., 883., 883., 883., 883., 883., 883., 883., 883., 883., 883.,
        883., 883., 883., 883., 883., 883., 883., 883., 883., 883., 883.,
        883., 883., 883., 883., 883., 883., 883., 883., 883., 883., 883.,
        883., 883., 883., 883., 883., 883., 883., 883., 883., 883., 883.,
        883., 883., 883., 883., 883., 883., 883., 883., 883., 883., 883.,
        883., 883., 883., 883., 883., 883., 883., 883., 883., 883., 883.,
        883., 883., 883., 883., 883., 883., 883., 883., 883., 883., 883.,
        883., 883., 883., 883., 883., 883., 883., 883., 883., 883., 883.,
        883., 883., 883., 883., 883., 883.]),
 'susceptible': array([1000.,  999.,  997.,  993.,  987.,  977.,  959.,  930.,  882.,
         808.,  700.,  560.,  406.,  272.,  180.,  127.,   99.,   84.,
          76.,   71.,   68.,   66.,   65.,   65.,   64.,   64.,   64.,
          64.,   64.,   64.,   64.,   64.,   64.,   64.,   64.,   64.,
          64.,   64.,   64.,   64.,   64.,   64.,   64.,   64.,   64.,
          64.,   64.,   64.,   64.,   64.,   64.,   64.,   64.,   64.,
          64.,   64.,   64.,   64.,   64.,   64.,   64.,   64.,   64.,
          64.,   64.,   64.,   64.,   64.,   64.,   64.,   64.,   64.,
          64.,   64.,   64.,   64.,   64.,   64.,   64.,   64.,   64.,
          64.,   64.,   64.,   64.,   64.,   64.,   64.,   64.,   64.,
          64.,   64.,   64.,   64.,   64.,   64.,   64.,   64.,   64.,
          64.,   64.,   64.,   64.,   64.,   64.,   64.,   64.,   64.,
          64.,   64.,   64.,   64.,   64.,   64.,   64.,   64.,   64.,
          64.,   64.,   64.,   64.,   64.,   64.,   64.,   64.,   64.,
          64.,   64.,   64.,   64.,   64.,   64.,   64.,   64.,   64.,
          64.,   64.,   64.,   64.,   64.,   64.,   64.,   64.,   64.,
          64.,   64.,   64.,   64.,   64.]),
 'infected_noi': array([ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
         0.,  0., 24., 25., 21., 15., 11.,  8.,  5.,  3.,  2.,  1.,  1.,
         1.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
         0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
         0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
         0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
         0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
         0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
         0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
         0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
         0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
         0.,  0.,  0.,  0.,  0.,  0.]),
 'removed_noi': array([ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
         0.,  0.,  0., 11., 23., 32., 39., 44., 48., 50., 52., 53., 53.,
        54., 54., 54., 54., 55., 55., 55., 55., 55., 55., 55., 55., 55.,
        55., 55., 55., 55., 55., 55., 55., 55., 55., 55., 55., 55., 55.,
        55., 55., 55., 55., 55., 55., 55., 55., 55., 55., 55., 55., 55.,
        55., 55., 55., 55., 55., 55., 55., 55., 55., 55., 55., 55., 55.,
        55., 55., 55., 55., 55., 55., 55., 55., 55., 55., 55., 55., 55.,
        55., 55., 55., 55., 55., 55., 55., 55., 55., 55., 55., 55., 55.,
        55., 55., 55., 55., 55., 55., 55., 55., 55., 55., 55., 55., 55.,
        55., 55., 55., 55., 55., 55., 55., 55., 55., 55., 55., 55., 55.,
        55., 55., 55., 55., 55., 55., 55., 55., 55., 55., 55., 55., 55.,
        55., 55., 55., 55., 55., 55.])}
# Model-input sliders. continuous_update=False makes each slider report its new
# value (and so trigger a redraw) only once it is released, not while dragging.

# Daily infection rate, adjustable in steps of 0.001.
ir = widgets.FloatSlider(value=1.187, min=0.0, max=5.0, step=.001,
                         description='Infection_rate:', continuous_update=False)

# Daily removal rate; kept above zero because the R0 calculation divides by it.
rr = widgets.FloatSlider(value=.46, min=0.1, max=2.0, step=.01,
                         description='Removal_Rate:', continuous_update=False)

# Number of people infected on day 0.
ii = widgets.IntSlider(value=1, min=1, max=50, step=1,
                       description='Initially Infected:', continuous_update=False)

# Population size.
ip = widgets.IntSlider(value=1000, min=500, max=10_000_000, step=500,
                       description='Initial Population:', continuous_update=False)

# Day from which the reduced infection rate applies (within the 150 simulated days).
iday = widgets.IntSlider(value=15, min=1, max=150, step=1,
                         description='Day of intervention (reducing infection rate):',
                         continuous_update=False)

# Fraction of the infection rate removed by the intervention.
ie = widgets.FloatSlider(value=.46, min=0.0, max=1.0, step=.01,
                         description='Intervention effectiveness:', continuous_update=False)

# Rates on the first row; initial conditions and intervention timing on the second.
first_slider_group = widgets.HBox(children=[ir, rr, ie])
second_slider_group = widgets.HBox(children=[ii, ip, iday])

# First, we use the method created above to calculate a model using the initial
# values of the sliders we just created. Given that at this point we haven't
# displayed the sliders yet, their values will be the default values we set above.

data = calculate_values_SIR(  removal_rate = rr.value
                            , infection_rate = ir.value
                            , population = ip.value
                            , days = 150
                            , initial_infection  = ii.value
                            , intervention_day = iday.value
                            # pass the slider explicitly so the first chart and every
                            # later update use the same effectiveness value
                            , intervention_effectiveness = ie.value
                        )

# Next we add all the data traces to the chart. Each bar trace gets exactly one
# x value (day 1 .. len(series)) per y value, so no day is silently dropped.

infected_trace =  go.Bar(x = list(range(1, len(data['infected']) + 1))
              ,y = data['infected']
              , name='Infected'
              , marker = dict(color='red')
              )

susceptible_trace = go.Bar(x = list(range(1, len(data['susceptible']) + 1))
              , y = data['susceptible']
              , name='Susceptible'
              , marker = dict(color='rgba(0,0,255,0.5)')
              , opacity=0.5
              )

removed_trace = go.Bar(x = list(range(1, len(data['removed']) + 1))
              ,y = data['removed']
              , name='Removed'
              , marker = dict(color='rgba(0,128,0,0.5)')
              , opacity=0.5)

######
# ADDITION STARTS HERE
######

# Counterfactual traces: the extra people who would have been infected / removed
# had there been no intervention (stacked with the other bars).

infected_trace_noi =  go.Bar(x = list(range(1, len(data['infected_noi']) + 1))
              ,y = data['infected_noi']
              , name='Infected No Intervention'
              , marker = dict(color='rgba(225,0,0,0.5)')
              )


removed_trace_noi = go.Bar(x = list(range(1, len(data['removed_noi']) + 1))
              ,y = data['removed_noi']
              , name='Removed No Intervention'
              , marker = dict(color='rgba(0,225,0,0.5)')
              )

######
# ADDITION END HERE
######


# This trace is interesting as it's basically just drawing a straight line on the
# selected intervention day.

intervention_day = go.Scatter(x = [iday.value, iday.value]
              ,y = [0, ip.value]
              , name='Intervention day'
              , marker = dict(color='darkblue')
              , line = dict(width=5)
  )

# We create our figure adding all the traces we created to the data list, and setting some layout values in the layout parameter.
# The post-intervention R0 uses the same reduced rate as the model:
# infection_rate * max(1 - effectiveness, 0), divided by the removal rate.
g = go.FigureWidget(data=[ infected_trace
                          , infected_trace_noi
                          , removed_trace
                          , susceptible_trace
                          , intervention_day
                          ,  removed_trace_noi ],
                    layout=go.Layout(
                         title={
                                'text': f'R0 = {ir.value / rr.value} <br /> Post-Intervention R0: {ir.value * max((1-ie.value),0) / rr.value} <br />Infection_rate={ir.value} Removal_rate={rr.value}',
                                'y':.95,
                                'x':0.5,
                                'xanchor': 'center',
                                'yanchor': 'top'}
                        ,barmode='stack'
                        ,hovermode='x'
                        ,height=600
                        ,xaxis=dict(title='Number of Days')
                        ,yaxis=dict(title='Number of People')
                    ))

# Show the days until both the intervened infections and the extra
# no-intervention infections have died out. If either never reaches zero,
# np.where finds nothing and [0][0] raises IndexError, so fall back to the
# whole 150-day window.
try:
    end_infection = max( np.where(data['infected_noi'][iday.value+1:]==0)[0][0] + iday.value +1,
                         np.where(data['infected']==0)[0][0] )
except IndexError:
    end_infection = 150
g.update_xaxes(range=[0, end_infection])

# NOTE(review): leftover scratch assignment; nothing visible in this notebook reads `x`.
x=1

# Callback for every slider: re-run the model with the current slider values
# and push the new series, the intervention-day line, the title and the x-axis
# range into the figure.
def response(change):
    """Recompute the SIR model from the sliders and refresh the figure ``g``.

    Registered with ``widget.observe(response, names='value')``, so it receives
    the change notification as ``change``. That argument is not used; every
    input is read directly from the slider widgets.

    Args:
        change: Change notification passed by ``observe`` (ignored).
    """
    num_days = 150

    # Recalculate the model using the values currently set on the sliders.
    pop_values = calculate_values_SIR(removal_rate = rr.value
                                      , infection_rate=ir.value
                                      , initial_infection=ii.value
                                      , population=ip.value
                                      , days=num_days
                                      , intervention_day = iday.value
                                      , intervention_effectiveness= ie.value
                                      )

    # np.asarray makes `== 0` element-wise even if the model returns plain
    # lists (a list compared to 0 is just False).
    infected = np.asarray(pop_values['infected'])
    infected_noi = np.asarray(pop_values['infected_noi'])

    # End of the visible x-range: the later of the first zero-infection day in
    # each curve (the no-intervention curve is searched only after the
    # intervention day). If either curve never hits zero, np.where yields no
    # index and [0][0] raises IndexError, so fall back to the full horizon.
    try:
        end_infection = max( np.where(infected_noi[iday.value+1:]==0)[0][0] + iday.value +1,
                             np.where(infected==0)[0][0] )
    except IndexError:
        end_infection = num_days

    # Post-intervention infection rate; clamped at 0 so an effectiveness above
    # 1 cannot produce a negative rate.
    ie_rate = ir.value * max((1-ie.value),0)

    # batch_update collects all property changes and sends them to the
    # front-end together when the block exits, which is faster than one
    # update per assignment.
    with g.batch_update():
        # Update the y-values of every model trace (index order matches the
        # trace list the figure was created with).
        g.data[0].y = pop_values['infected']
        g.data[1].y = pop_values['infected_noi']
        g.data[2].y = pop_values['removed']
        g.data[3].y = pop_values['susceptible']
        g.data[5].y = pop_values['removed_noi']

        # Move the intervention-day line to the selected day and population height.
        g.data[4].y = [0         , ip.value]
        g.data[4].x = [iday.value, iday.value]

        # Change only the title text. The position set when the figure was
        # created (x, y, anchors) is kept; assigning a whole new dict to
        # g.layout.title would replace the title object and drop its 'y'.
        g.layout.title.text = f'R0 = {ir.value / rr.value} <br /> Post-Intervention R0: {ie_rate / rr.value} <br /> Infection_rate={ir.value} Removal_rate={rr.value}'

        # Change only the range so the infection curve stays in view. Assigning
        # a new dict to g.layout.xaxis would replace the axis object and lose
        # its 'Number of Days' title.
        g.layout.xaxis.range = [0, end_infection]
        
# Attach the same callback to every slider, in a fixed order: whenever a
# slider's value changes, response() recomputes the model and redraws the chart.
for _slider in (ir, rr, ii, ip, iday, ie):
    _slider.observe(response, names='value')

# Stack both slider groups above the chart. This is left as a bare expression
# so that, as the last line of a notebook cell, the layout gets displayed.
widgets.VBox([first_slider_group, second_slider_group, g])