Compute SST fitting

[1]:
import numpy as N
import xarray as xr
import pandas as pd
import matplotlib.pyplot as plt
import sys
import dask
from os import path
from dask.distributed import Client, LocalCluster, progress
import time
from tools import my_percentile
from scipy.optimize import least_squares,curve_fit
import dask_hpcconfig
from dask_jobqueue import PBSCluster
import glob

Set up Dask workers in local mode (use all 28 CPUs of the node)

[2]:
# LOCAL: start the "datarmor-local" cluster with 28 single-threaded workers
overrides = {
    "cluster.n_workers": 28,
    "cluster.threads_per_worker": 1,
}
cluster = dask_hpcconfig.cluster("datarmor-local", **overrides)
# Print the dashboard link so the run can be monitored
print(cluster.dashboard_link)
/user/mcaillau/proxy/8787/status
[3]:
# Explicitly connect a distributed Client to the cluster created in the previous cell
client = Client(cluster)

Read croco grid

[4]:
# Read the CROCO grid file
ds_grid = xr.open_dataset('/home/shom_simuref/CROCO/ODC/CONFIGS/MEDITERRANEE_GLOBALE/CROCO_FILES/test2.nc')

# Mapping from the grid's index dimensions to the coordinate names used below
coord_dict = {"xi_rho": "X", "eta_rho": "Y", "xi_u": "X_U", "eta_v": "Y_V"}

# Use the first row of lon_rho and the first column of lat_rho as 1-D X/Y
# coordinates, then make them the dimensions of the dataset
ds_grid = ds_grid.assign_coords(
    X=ds_grid.lon_rho[0, :],
    Y=ds_grid.lat_rho[:, 0],
).swap_dims(coord_dict)

# rho-point mask (compared against 1 when plotting)
mask = ds_grid.mask_rho

Read the statistics NetCDF file produced by another script

[5]:
# Directory of the SST statistics files read and written by this notebook.
# Note: it ends with "/", so the f-strings below produce paths containing "//".
stat_dir='/home/shom_simuref/CROCO/ODC/POSTPROC/SST/'
[6]:
# Open the mean-SST file and keep its `temp` variable
ds=xr.open_dataset(f'{stat_dir}/mean_sst.nc')
sst_mean=ds.temp
[7]:
# Drop the last time step (the original note says this leaves only 12 months).
# Slice from ds.temp rather than from sst_mean itself, so that re-running this
# cell does not remove one more time step on every execution.
sst_mean=ds.temp.isel(time=slice(0,-1))

Define fitting function

[8]:
n_months = sst_mean.shape[0]      # length of the first axis of sst_mean = number of points of the time series (read as a global by model_1freq_trend)

### curve fit model
def model_curve(t,a,b,c) :
    """Single-harmonic model with a 12-unit period: ``a*cos(2*pi*t/12 - b) + c``.

    Parameters
    ----------
    t : scalar or array
        Time coordinate (the notebook passes a month index).
    a : amplitude of the cosine.
    b : phase shift, in radians.
    c : constant offset.

    Returns
    -------
    The model evaluated at ``t`` (same shape as ``t``).

    A variant with an extra linear term ``d*t`` existed in the original cell
    but was disabled.
    """
    omega = 2*N.pi*1/12   # angular frequency of a 12-point cycle
    return a*N.cos(omega*t-b)+c

### least square model (for comparison only) ####
def model_1freq_trend(x) :
    """Evaluate the 12-point-period cosine model on ``0 .. n_months-1``.

    ``x`` is indexable with ``x[0]`` = amplitude, ``x[1]`` = phase (radians)
    and ``x[2]`` = offset. The length of the returned series comes from the
    module-level ``n_months``. Despite the name, no trend term is included.
    """
    amplitude, phase, offset = x[0], x[1], x[2]
    months = N.arange(0,n_months)
    return amplitude*N.cos(2*N.pi*1/12*months-phase)+offset

#cost function
def cost_function_1freq(x,sst_signal):
    """Residuals (model minus signal) for the parameter vector ``x``.

    ``sst_signal`` must be broadcastable against the output of
    ``model_1freq_trend(x)``.
    """
    residuals = model_1freq_trend(x)-sst_signal
    return residuals

# fitting with least square function
def optimize_1freq(x_0,sst_signal,loss):
    """Fit the cosine model to ``sst_signal`` with ``scipy.optimize.least_squares``.

    Parameters
    ----------
    x_0 : initial guess (amplitude, phase, offset).
    sst_signal : series handed to ``cost_function_1freq``.
    loss : loss name forwarded to ``least_squares`` (e.g. ``"linear"``).

    Returns
    -------
    (x, cost) : the optimised parameters and the final cost reported by scipy.
    """
    result = least_squares(
        cost_function_1freq,
        x_0,
        args=(sst_signal,),   # extra positional argument of the cost function
        method='trf',
        loss=loss,
        f_scale=1.0,
    )
    return result.x, result.cost

CROCO FITTING

Pre-process the data before fitting

[11]:
# Build the fitting inputs and an initial guess taken from one grid point
print('calc init parameters')

# Integer index 0..n_months-1 used as the fitting coordinate
xdata = N.arange(0, n_months)

# Attach that index along "time" and make it the dimension ("month")
sst2 = sst_mean.assign_coords(month=("time", xdata)).swap_dims({"time": "month"})

# Use small spatial chunks: the xarray curvefit computation does not seem to be
# parallelised by itself, so many chunks of the domain are more efficient.
# "month" must stay in a single chunk because the fit is applied along it.
sst2 = sst2.chunk({"month": -1, "Y": 100, "X": 100})

# Mask non-positive values as NaN (land points, per the original note) so
# that they are not fitted
sst2 = sst2.where(sst2 > 0)

# Parameter bounds for curvefit: amplitude and offset are kept non-negative
bounds = {
    "a": (0, N.inf),
    "b": (-N.inf, N.inf),
    "c": (0, N.inf),
}

# Initial guess (not mandatory) computed from the series at index [400, 900]:
# (max - min, 0, mean)
signal = sst2[:, 400, 900].compute()
x_0 = N.max(signal.data)-N.min(signal.data),0,N.mean(signal.data)
p0 = dict(zip(("a", "b", "c"), x_0))
print(x_0)
calc init parameters
(12.0451565, 0, 20.337656)

Define the fit with the xarray method `curvefit` (which wraps `scipy.optimize.curve_fit`)

[13]:
%%time
# Set up the per-point fit of model_curve along "month" with the bounds
# defined above; the initial guess p0 is intentionally not passed here.
sst_fit = sst2.curvefit(
    coords="month",
    func=model_curve,
    skipna=True,
    bounds=bounds,
)


CPU times: user 12 ms, sys: 0 ns, total: 12 ms
Wall time: 14.1 ms

Do the computation

[14]:
%%time
## about 4 min 30 s on the local dask client with 28 workers
sst_fit=sst_fit.compute()
CPU times: user 54.2 s, sys: 8.98 s, total: 1min 3s
Wall time: 4min 30s

record data to netcdf file

[1]:
# Write the CROCO fit result to NetCDF (mode="w").
# NOTE(review): the NameError recorded under this cell carries execution count [1],
# i.e. it was apparently run on a fresh kernel before `sst_fit` existed; run the
# cells above first.
sst_fit.to_netcdf(f'{stat_dir}/sst_fit_croco.nc',mode="w")
---------------------------------------------------------------------------
NameError                                 Traceback (most recent call last)
Input In [1], in <cell line: 2>()
      1 #record data to netcdf file
----> 2 sst_fit.to_netcdf(f'{stat_dir}/sst_fit_croco.nc',mode="w")

NameError: name 'sst_fit' is not defined

SEVIRI FITTING

[16]:
# Open the SEVIRI mean-SST file from the same statistics directory
ds_sevi=xr.open_dataset(f'{stat_dir}/sst_sevi_mean.nc')

[17]:
# Keep the `sea_surface_temperature` variable of the SEVIRI dataset
sst_sevi=ds_sevi.sea_surface_temperature
[18]:
# Attach the month index (xdata, built for the CROCO series, so both series
# must have the same length) along "time" and make it the dimension.
sst_sevi2 = sst_sevi.assign_coords(month=("time", xdata)).swap_dims({"time": "month"})

# Subtract 273.15 (presumably Kelvin -> degrees Celsius) out of place.
# The original used `-=`, which modifies the array buffer in place; that
# buffer may be shared with `sst_sevi`, so re-running the cell could subtract
# the offset twice. Note that out-of-place arithmetic does not keep attrs.
sst_sevi2 = sst_sevi2 - 273.15
[21]:
# Keep "month" in a single chunk (the fit runs along it) and tile lat/lon
sst_sevi2=sst_sevi2.chunk({"month":-1,"lat":60,"lon":100})
[22]:
# Boolean array: True where the SEVIRI SST is not null
not_null=sst_sevi2.notnull()
[23]:
# Keep only the points that are valid for every month; all others become NaN
test=sst_sevi2.where(not_null.all(dim="month"))
[24]:
%%time
# Fit model_curve along "month" on the fully valid SEVIRI points and compute
# immediately. The initial guess p0 comes from the CROCO sample point; no
# bounds are passed here (unlike the CROCO fit), so the fitted amplitude is
# not constrained to be positive.
sst_fit_sevi = test.curvefit(coords="month", func=model_curve, p0=p0).compute()

CPU times: user 1.57 s, sys: 224 ms, total: 1.8 s
Wall time: 5.3 s

write results into a Netcdf file

[25]:
# Write the SEVIRI fit result to NetCDF (mode="w")
sst_fit_sevi.to_netcdf(f'{stat_dir}/sst_fit_sevi.nc',mode="w")

DEBUG

[ ]:
#check curve fitting
[ ]:
# Fitted amplitude (parameter "a") of the SEVIRI fit
amp=sst_fit_sevi.curvefit_coefficients.sel(param="a")
[ ]:
# Indices of the points whose fitted amplitude is negative
N.argwhere((amp.data<0))
[ ]:
# Load the SEVIRI series at index [16, 740] into memory: sst_sevi2 is
# dask-chunked, and the numpy/scipy calls below expect concrete arrays
# (the CROCO initial-guess cell does the same with .compute()).
signal = sst_sevi2[:,16,740].compute()
# Initial guess for least squares: (max - min, 0, mean)
x_0 = N.max(signal.data)-N.min(signal.data),0,N.mean(signal.data)
ydata=signal
[ ]:
# 1) scipy curve_fit with lower bound 0 and upper bounds [30, 365, 30] for (a, b, c)
# NOTE(review): model_curve takes b in radians; an upper bound of 365 looks
# like a day count - confirm it is intended.
popt, pcov = curve_fit(
    model_curve,
    xdata,
    ydata,
    bounds=(0, [30, 365, 30]),
)

# 2) coefficients already obtained with the xarray curvefit at the same point
coeffs = sst_fit_sevi.curvefit_coefficients[16, 740].data

# 3) least-squares fit started from x_0, then the fitted series
x, cost = optimize_1freq(x_0, ydata.data, loss="linear")
sig_fitted = model_1freq_trend(x)
[ ]:
# Show the three parameter sets side by side: xarray curvefit, scipy curve_fit, least squares
coeffs,popt,x
[ ]:
# Overlay the raw series and the three fits for the selected SEVIRI point
fig, ax = plt.subplots(1, 1, figsize=(10, 8))
ax.plot(xdata, ydata, 'b-', label='data')
for params, fmt, label in (
    (popt, 'r-', 'scipy curve fit'),
    (coeffs, 'g+', 'xarray wrapper curve fit'),
):
    ax.plot(xdata, model_curve(xdata, *params), fmt, label=label)
ax.plot(xdata, sig_fitted, 'k*', markersize=2, label='least squares')
ax.legend()

PLOT THE RESULTS

[ ]:
import cartopy.crs as ccrs
from cartopy.feature import ShapelyFeature
import cartopy.feature as cfeature
[ ]:
# Map projection used by the figures below
proj=ccrs.LambertConformal(central_latitude=38,central_longitude=15)
# CROCO coordinates are taken from the X/Y coordinates of the fit result
lon_croco=sst_fit.X
lat_croco=sst_fit.Y
# SEVIRI coordinates are taken from the SEVIRI SST array
lon_sevi=sst_sevi2.lon
lat_sevi=sst_sevi2.lat

convert phase to days

[ ]:
def convert_phase(phase):
    """Convert a fitted phase in radians to days within a 365-day year.

    One full cycle (2*pi) maps to 365 days, so the result is
    ``phase / (2*pi) * 365`` wrapped into ``[0, 365)``.

    The previous expression ``phase/2*N.pi*365`` multiplied by pi instead of
    dividing by ``2*pi`` (operator precedence). It also wrapped only values
    above 365, and only once, which left negative or multi-cycle phases out
    of range.

    Parameters
    ----------
    phase : array-like supporting arithmetic and ``.min()``/``.max()``
        (xarray.DataArray or numpy array). It is not modified.

    Returns
    -------
    New array of the same type and shape, in days; NaN stays NaN.
    """
    phase_days = (phase / (2 * N.pi) * 365) % 365
    # Kept from the original: print the range as a quick sanity check
    print(phase_days.min(), phase_days.max())
    return phase_days
[ ]:
# CROCO fit coefficients; parameter "b" is the phase of the cosine model
sst_croco=sst_fit.curvefit_coefficients
phase_croco=sst_croco.sel(param="b")
# Convert the phase with convert_phase (which already prints min/max itself)
phase_croco2=convert_phase(phase_croco)
print(phase_croco2.min(),phase_croco2.max())
[ ]:
# SEVIRI fit coefficients, their phase ("b"), and the converted phase
sst_fit_sevi2 = sst_fit_sevi.curvefit_coefficients
phase_sevi = sst_fit_sevi2.sel(param="b")
phase_sevi2 = convert_phase(phase_sevi)
print(phase_sevi2.min(), phase_sevi2.max())

Set up the dictionaries used for plotting

[ ]:
# Per-parameter plotting settings, keyed by the fit parameter name.
# NOTE(review): this rebinds `bounds`, the name used earlier for the curvefit
# parameter bounds; re-running the CROCO curvefit cell after this one would
# pick up these colour limits instead.
bounds = {"a": (3, 7), "b": (0, 30), "c": (17, 22)}   # (vmin, vmax)
names = {"a": "Amplitude", "b": "Phase", "c": "Intercept"}
cmap = {"a": plt.cm.jet, "b": plt.cm.gray, "c": plt.cm.plasma}

# Fields to draw for each source; "b" uses the converted phase
data_croco = {
    "a": sst_croco.sel(param="a"),
    "b": phase_croco2,
    "c": sst_croco.sel(param="c"),
}
data_sevi = {
    "a": sst_fit_sevi2.sel(param="a"),
    "b": phase_sevi2,
    "c": sst_fit_sevi2.sel(param="c"),
}
[ ]:
def plot_data(data,lon,lat,ax,name,model,**kw_plot):
    """Draw one field as a pcolormesh map with coastlines, land and a colorbar.

    Parameters
    ----------
    data : 2-D field to draw.
    lon, lat : coordinates passed to ``pcolormesh`` (PlateCarree by default).
    ax : cartopy GeoAxes to draw on.
    name, model : used for the title ``"{name} SST {model}"``.
    **kw_plot : extra keyword arguments for ``pcolormesh``; they override the
        default ``transform``.

    Returns
    -------
    The mesh returned by ``pcolormesh``.

    The colorbar is attached to the figure that owns ``ax``. The previous
    version used a global ``fig``, which only worked when the caller happened
    to have created a variable of that name for the same figure.
    """
    kwargs_plot = dict(transform=ccrs.PlateCarree())
    kwargs_plot.update(kw_plot)
    ax.set_extent([-7,36,30,45],crs=ccrs.PlateCarree())
    ax.coastlines()
    ax.add_feature(cfeature.LAND, zorder=1, edgecolor='k')
    ax.set_title(f'{name} SST {model}')
    cf = ax.pcolormesh(lon,lat,data,**kwargs_plot)
    ax.figure.colorbar(cf, ax=ax, shrink=0.7)
    return cf
[ ]:

[ ]:
# One row per fitted parameter: CROCO on the left, SEVIRI on the right
fig, axes = plt.subplots(
    nrows=3,
    ncols=2,
    figsize=(14, 12),
    subplot_kw=dict(projection=proj),
)
for row, param in enumerate(("a", "b", "c")):
    print(param)
    # Same colour limits and colormap for both sources
    vmin, vmax = bounds[param]
    kw_plot = dict(vmin=vmin, vmax=vmax, cmap=cmap[param])
    # CROCO field restricted to points where the grid mask equals 1
    data1 = data_croco[param].where(mask==1)
    data2 = data_sevi[param]
    plot_data(data1, lon_croco, lat_croco, axes[row, 0], names[param], "CROCO", **kw_plot)
    plot_data(data2, lon_sevi, lat_sevi, axes[row, 1], names[param], "SEVIRI", **kw_plot)
[ ]:
# Time-mean of the CROCO SST, kept only where the grid mask equals 1
data = sst_mean.mean('time').where(mask==1)

[ ]:
# Map of the time-mean CROCO SST held in `data`.
fig,ax=plt.subplots(1,1,figsize=(10,8),subplot_kw=dict(projection=proj))
kwargs_plot=dict(transform=ccrs.PlateCarree(),cmap=plt.cm.jet,vmin=14,vmax=22)
ax.set_extent([-7,36,30,45],crs=ccrs.PlateCarree())
ax.coastlines()
ax.add_feature(cfeature.LAND, zorder=1, edgecolor='k')
# The field is the time mean of the SST, not the fitted amplitude
ax.set_title('Mean SST CROCO')
# `lon`/`lat` were never defined in this notebook (NameError on a fresh run);
# use the X/Y coordinates carried by the field itself.
cf=ax.pcolormesh(data.X,data.Y,data,**kwargs_plot)
cbar = fig.colorbar(cf, ax=ax, shrink=0.7)

check the results with scipy

[ ]:
# Load the CROCO series at index [400, 600] into memory: sst2 is dask-chunked,
# and the numpy/scipy calls below expect concrete arrays (the initial-guess
# cell does the same with .compute()).
signal = sst2[:,400,600].compute()
ydata=signal
# Initial guess for least squares: (max - min, 0, mean)
x_0 = N.max(signal.data)-N.min(signal.data),0,N.mean(signal.data)
#curve fit from scipy
popt, pcov = curve_fit(model_curve, xdata, ydata)
#least square fit
x,cost = optimize_1freq(x_0,signal.data,loss="linear")
sig_fitted = model_1freq_trend(x)
[ ]:
%%time
# Same point fitted through the xarray curvefit wrapper (no bounds, no p0),
# then the coefficients are pulled into memory
fit = sst2.isel(Y=400, X=600).curvefit(coords="month", func=model_curve)
coeffs = fit.curvefit_coefficients.data.compute()
[ ]:
# Display the coefficients obtained with the xarray curvefit wrapper
coeffs
[ ]:
# plot
fig,ax=plt.subplots(1,1,figsize=(10,8))
ax.plot(xdata, ydata, 'b-', label='data')
ax.plot(xdata,model_curve(xdata,*popt),'r-',label='scipy curve fit')
ax.plot(xdata,model_curve(xdata,*coeffs),'g+',label='xarray wrapper curve fit')
ax.plot(xdata,sig_fitted,'k*',markersize=2,label='least squares')
plt.legend()

[ ]: