#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Time interpolation and horizontal regridding utilities
"""
import os
import re
import logging
from enum import IntEnum
import numpy as np
import pandas as pd
import xarray as xr
import xesmf
import xoa.cf
from xoa.misc import XEnumMeta
from .cf import (rename_from_standard_name,
get_lon_coord_name, get_lat_coord_name)
from .filters import erode_coast
from .io import xr_open_concat
from .times import period_vortex_to_pandas
LOGGER = logging.getLogger(__name__)
class InterpFlag(IntEnum, metaclass=XEnumMeta):
    """Flag telling how an output time step was produced"""
PICKED = 1
LINEAR = 2
PERSISTENT = 3
def interp_time(ds, dates, preserve_encoding=False, keep_flag=False,
                diurnal_vars=None, persistency=None):
"""Interpolate a dataset in time
A single interpolation per time step, with persistency.
Parameters
----------
ds : xarray.Dataset
Dataset to interpolate.
dates : single or list of dates
Interpolation dates.
Returns
-------
xarray.Dataset
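
    Example
    -------
    A minimal sketch on a synthetic 6-hourly dataset, assuming the xoa
    CF specs recognize a coordinate simply named ``time`` (names and
    values are illustrative)::

        import numpy as np, pandas as pd, xarray as xr
        time = pd.date_range("2000-01-01", periods=5, freq="6h")
        ds = xr.Dataset({"temp": ("time", np.arange(5.))},
                        coords={"time": time})
        # First date is linearly interpolated, second one is persisted
        dso = interp_time(
            ds, ["2000-01-01 03:00", "2000-01-02 12:00"], keep_flag=True)
        print(dso.flag.values)  # [2 3] i.e. LINEAR, PERSISTENT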
"""
# Time dim
cfs = xoa.cf.get_cf_specs()
time_name = cfs.search_coord(ds, cf_name="time", get="cf_name")
time = ds.indexes[time_name]
begindate_nc = time[0]
enddate_nc = time[-1]
    # Accept a single date as well as a list of dates
    dates = pd.to_datetime(np.atleast_1d(dates))
# For persistency
ds_first = ds.isel({time_name: slice(0, 1)})
ds_last = ds.isel({time_name: slice(-1, None)})
flag = xr.DataArray(np.zeros(len(dates), dtype="i"),
coords=[dates], dims=[time_name], name='flag')
# Loop on target times
dss = []
for date in dates:
# Persistency
if date < begindate_nc or date > enddate_nc:
# New dataset
this_ds = (ds_first if date < begindate_nc else ds_last)
# New time
this_ds = this_ds.assign_coords({time_name: [date]})
this_ds[time_name].attrs.update(ds[time_name].attrs)
            if persistency:
                # Parse specs like "mean_24h": average over the first or
                # last period instead of persisting a single time step
                m = re.search(r"mean_(?P<term>\d+)(?P<units>\w)",
                              persistency)
                if m:
                    ds_tr = ds.copy()
                    timedelta = pd.to_timedelta(int(m["term"]), m["units"])
                    if date > enddate_nc:
                        date_target = enddate_nc - timedelta
                        ds_tr = ds_tr.where(
                            ds_tr[time_name] >= date_target).mean(
                                dim=time_name, keepdims=True)
                    elif date < begindate_nc:
                        date_target = begindate_nc + timedelta
                        ds_tr = ds_tr.where(
                            ds_tr[time_name] <= date_target).mean(
                                dim=time_name, keepdims=True)
                    for vname in ds_tr:
                        this_ds[vname] = this_ds[vname].copy(
                            data=ds_tr[vname].data)
            if diurnal_vars:
                # Preserve the diurnal cycle: copy values from the output
                # time step located 24 hours inside the time range
                if date < begindate_nc:
                    time_target = date + pd.Timedelta("24h")
                elif date > enddate_nc:
                    time_target = date - pd.Timedelta("24h")
                for dss_target in dss:
                    if dss_target[time_name].values[0] == time_target:
                        for vname in this_ds:
                            if vname in diurnal_vars:
                                this_ds[vname] = this_ds[vname].copy(
                                    data=dss_target[vname].data)
flag.loc[date] = InterpFlag.PERSISTENT
# Interpolation
else:
# Pick this date directly if available
if date in time:
this_ds = ds.sel({time_name: [date]})
flag.loc[date] = InterpFlag.PICKED
# Interpolate because not available
else:
# Load a two-time steps dataset
                # Index of the last input time step before the target date
                it = ds.indexes[time_name].get_indexer(
                    [date], method="ffill")[0]
                tmp_ds = ds.isel({time_name: slice(it, it + 2)}).chunk(
                    {time_name: 2})
# Interpolate
this_ds = tmp_ds.interp({time_name: [date]}).load()
flag.loc[date] = InterpFlag.LINEAR
            # Preserve netcdf compression settings of the input variables
            if preserve_encoding:
                for vname, var in ds.items():
                    this_ds[vname].encoding.update(var.encoding)
# New time
this_ds = this_ds.assign_coords({time_name: [date]})
this_ds[time_name].attrs.update(ds[time_name].attrs)
# Append to list
dss.append(this_ds)
del this_ds
# Concatenate
dso = xr.concat(
dss, time_name, compat='no_conflicts', coords='different',
data_vars='all')
flag.attrs['flags'] = str(InterpFlag)
if keep_flag:
dso["flag"] = flag
# Attributes and encoding
dso[time_name].attrs.update(ds[time_name].attrs)
dso[time_name].encoding.update(ds[time_name].encoding)
if preserve_encoding:
for vname, var in ds.items():
dso[vname].encoding.update(var.encoding)
    dso = dso.transpose(time_name, ...)
return dso
def nc_interp_time(ncfiles, dates, ncout, preproc=None, postproc=None):
    """Interpolate netcdf files to a single netcdf file at given dates
Parameters
----------
ncfiles: str, list
Single or list of netcdf files
dates: list
Output dates
    ncout: str
        Output netcdf file path as a :meth:`str.format` pattern; the first
        and last dates are available as ``{date0}`` and ``{date1}``
    preproc: callable
        Callable that operates on the input dataset after reading it
    postproc: callable
        Callable that operates on the output dataset before saving it

    Returns
    -------
str
Output netcdf file
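
    Example
    -------
    A minimal sketch, where the file names and the ``fix_units``
    preprocessing callable are purely illustrative::

        ncfile = nc_interp_time(
            ["model_day1.nc", "model_day2.nc"],
            ["2000-01-01", "2000-01-01 12:00"],
            "out_{date0:%Y%m%d}-{date1:%Y%m%d}.nc", preproc=fix_units)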
"""
# Input
dsi = xr_open_concat(ncfiles)
# Preproc
if preproc:
dsi = preproc(dsi)
# Output dates
    if not isinstance(dates, (list, tuple, pd.DatetimeIndex)):
        dates = [dates]
    dates = pd.to_datetime(dates)
date0 = dates[0]
date1 = dates[-1]
LOGGER.debug("Interpolating to a single file: "
f"from {date0:%Y-%m-%dT%H:%M} to {date1:%Y-%m-%dT%H:%M}")
# Interpolate
dso = interp_time(dsi, dates)
# Callback postprocessing
if postproc is not None:
dso = postproc(dso)
# Write netcdf
ncfile = ncout.format(**locals())
dso.to_netcdf(ncfile)
LOGGER.info(f"Created interpolation file: {ncfile}")
return ncfile
def nc_interp_at_freq_to_daily_nc(
ncfiles, freq, ncfmt, preproc=None, postproc=None):
"""Interpolate netcdf files to daily netcdf files at a given frequency
Parameters
----------
ncfiles: str, list
Single or list of netcdf files
freq: str
Output frequency
    ncfmt: str
        Output netcdf file path as a :meth:`str.format` pattern; the
        current day is available as ``{date}`` (formatted ``YYYYMMDD``)
    preproc: callable
        Callable that operates on the input dataset after reading it
    postproc: callable
        Callable that operates on the output dataset before saving it

    Returns
    -------
list
Effective list of output netcdf files
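
    Example
    -------
    A minimal sketch; the file names are illustrative and ``"PT3H"``
    assumes vortex-style ISO 8601 periods are understood by
    :func:`period_vortex_to_pandas`::

        ncfiles = nc_interp_at_freq_to_daily_nc(
            ["model_day1.nc", "model_day2.nc"], "PT3H",
            "daily_{date}.nc")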
"""
# Input
dsi = xr_open_concat(ncfiles)
# Preproc
if preproc:
dsi = preproc(dsi)
# Output dates
cfs = xoa.cf.get_cf_specs()
time = cfs.coords.get(dsi, "time")
freq = period_vortex_to_pandas(freq)
date0 = pd.to_datetime(time[0].values).ceil(freq)
date1 = pd.to_datetime(time[-1].values).ceil(freq)
day0 = date0.floor("1D")
day1 = date1.floor("1D")
LOGGER.info("Interpolating to daily files: "
f"from {day0:%Y-%m-%dT%H:%M} to {day1:%Y-%m-%dT%H:%M} "
f"at frequency {freq}")
days = pd.date_range(day0, day1, freq='D')
# Loop on dates
ncfiles = []
    for date in days:
        LOGGER.debug(f"Date: {date}")
t = time.sel(time=slice(date, None))
# Hours of the day
if len(t) == 1:
times = [date]
else:
            times = pd.date_range(
                date, date + pd.to_timedelta("24 hours"),
                freq=freq, inclusive='left')
# Interpolate
dso = interp_time(dsi, times, preserve_encoding=True)
# Callback postprocessing
if postproc is not None:
dso = postproc(dso)
        date = date.strftime("%Y%m%d")
# Save
ncfile = ncfmt.format(**locals())
ncfiles.append(ncfile)
dso.to_netcdf(ncfile)
LOGGER.info(f"Created daily interpolation file: {ncfile}")
del dso
return ncfiles
def nc_interp_at_freq_to_nc(
ncfiles, freq, ncout, begindate=None, preproc=None, postproc=None):
"""Interpolate netcdf files to a single netcdf file at a given frequency
Parameters
----------
ncfiles: str, list
Single or list of netcdf files
freq: str
Output frequency
    ncout: str
        Output netcdf file path as a :meth:`str.format` pattern; the first
        and last dates are available as ``{date0}`` and ``{date1}``
    begindate: date
        Start from this date instead of the first input date
    preproc: callable
        Callable that operates on the input dataset after reading it
    postproc: callable
        Callable that operates on the output dataset before saving it

    Returns
    -------
str
Output netcdf file
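
    Example
    -------
    A minimal sketch; the file names are illustrative and ``"PT1H"``
    assumes vortex-style ISO 8601 periods are understood by
    :func:`period_vortex_to_pandas`::

        ncfile = nc_interp_at_freq_to_nc(
            ["model_day1.nc", "model_day2.nc"], "PT1H",
            "out_{date0:%Y%m%d}.nc", begindate="2000-01-01")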
"""
# Input
dsi = xr_open_concat(ncfiles)
# Preproc
if preproc:
dsi = preproc(dsi)
# Output dates
cfs = xoa.cf.get_cf_specs()
time = cfs.coords.get(dsi, "time").values
freq = period_vortex_to_pandas(freq)
date0 = pd.to_datetime(begindate or time[0]).ceil(freq)
date1 = pd.to_datetime(time[-1]).floor(freq)
LOGGER.debug("Interpolating to a single file: "
f"from {date0:%Y-%m-%dT%H:%M} to {date1:%Y-%m-%dT%H:%M} "
f"at frequency {freq}")
dates = pd.date_range(date0, date1, freq=freq)
# Interpolate
dso = interp_time(dsi, dates)
    # Lighten the output: cast all variables to 32-bit floats
    dso = dso.astype("float32")
# Callback postprocessing
if postproc is not None:
dso = postproc(dso)
# Write netcdf
ncfile = ncout.format(**locals())
dso.to_netcdf(ncfile)
LOGGER.info(f"Created interpolation file: {ncfile}")
return ncfile
class Regridder:
    """Horizontal regridder based on the `xesmf` library"""

    def __init__(self, dsi, dso, regridder=None, filename=None,
                 reuse_weights=True, **kwregrid):
        """
        Parameters
        ----------
        dsi : xarray.Dataset
            Dataset on the source grid.
        dso : xarray.Dataset
            Dataset that defines the target grid.
        regridder : xesmf.Regridder, optional
            Reuse an existing regridder instead of building a new one.
        filename : str, optional
            Path of the file where the regridding weights are stored.
        reuse_weights : bool
            Read the weights from ``filename`` when it exists.
        kwregrid :
            Extra arguments passed to :class:`xesmf.Regridder`;
            ``method`` defaults to ``"bilinear"``.
        """
        kwregrid.setdefault("method", "bilinear")
self._dsi = rename_from_standard_name(
dsi, {'longitude': get_lon_coord_name(dso),
'latitude': get_lat_coord_name(dso)})
self._dso_mapping = {}
self._dso = rename_from_standard_name(
dso, {'longitude': get_lon_coord_name(dso),
'latitude': get_lat_coord_name(dso)},
mapping=self._dso_mapping)
if regridder is not None:
self.regridder = regridder
else:
if not filename or not os.path.exists(filename):
reuse_weights = False
self.regridder = xesmf.Regridder(
self._dsi, self._dso, filename=filename,
reuse_weights=reuse_weights, **kwregrid)
    def regrid(self, dai, erode=None):
        """
        Regrid a variable

        Parameters
        ----------
        dai : xarray.DataArray
            Single variable to regrid.
        erode : optional
            If true, erode the coastal mask with :func:`erode_coast`
            before regridding.

        Returns
        -------
        xarray.Dataset
            The regridded variable merged with the target grid dataset.
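
        Example
        -------
        A minimal sketch, where ``ds_src``, ``ds_dst`` and ``temp`` are
        illustrative::

            regridder = Regridder(ds_src, ds_dst, filename="weights.nc")
            ds_out = regridder(ds_src["temp"], erode=True)
            regridder.clean()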
"""
# Erode coast
if erode:
dai = erode_coast(dai)
# Appropriate coord names
dai = rename_from_standard_name(
dai, {'longitude': 'lon', 'latitude': 'lat'})
dao = self.regridder(dai).rename(self._dso_mapping)
return xr.merge([self._dso, dao])
__call__ = regrid
    def clean(self):
"""
Remove the weights file
"""
        # Only available in xesmf versions that write weights to disk
        if hasattr(self, 'regridder') and hasattr(
                self.regridder, 'clean_weight_file'):
            self.regridder.clean_weight_file()