"""
Atmospheric forcing preprocessing
"""
import logging
import numpy as np
from ...log import LoggingLevelContext
with LoggingLevelContext(logging.WARNING):
import xarray as xr
from xoa.filter import erode_mask
from ...times import convert_to_julian_day
from ...interp import interp_time, Regridder
from ...phys import (
windstress,
radiativeflux,
celsius2kelvin,
watervapormixingratio,
)
from ...filters import erode_mask_vec
from .__init__ import HYCOM3D_PARAM
from .io import get_ab_file_names
LOGGER = logging.getLogger(__name__)
class AtmFrc(object):
    """Atmospheric forcing preprocessing for HYCOM.

    Methods cover GRIB decoding and time interpolation, derived parameters,
    atmospheric mask, regridding to the ocean grid and HYCOM [a/b] / NetCDF
    output generation.
    """

    def __init__(
        self, insta_files, cumul_files,
    ):
        """
        Parameters
        ----------
        insta_files: list
            List of instantaneous grib files::

                insta_files = [grib_time0, grib_time1, grib_time2, ...]

        cumul_files: list(list)
            List of lists of cumulated grib files.
            The outer level is for dates and the inner one for terms::

                cumul_files = [[grib_date0_term0, grib_date0_term1],
                               [grib_date1_term0, grib_date1_term1],
                               [grib_date2_term0, grib_date2_term1]]

            Note that the cumul is reset at each new date.
        """
        self.insta = insta_files
        self.cumul = cumul_files
        # Lazy %-style arguments: the lists are only formatted at DEBUG level.
        LOGGER.debug("insta files: %s", self.insta)
        LOGGER.debug("cumul files: %s", self.cumul)

    def grib2dataset(self):
        """Decode the GRIB files into two datasets.

        Returns
        -------
        ds_cumul: xarray.Dataset
            Decumulated fields of all dates (see :meth:`decumul`), sorted
            along ``time``.
        ds_insta: xarray.Dataset
            Instantaneous fields, sorted along ``time``.

        In both datasets ``longitude``/``latitude``/``valid_time`` are
        renamed ``lon``/``lat``/``time``.
        """
        LOGGER.info("Convert grib format to dataset")
        LOGGER.info("Cumulated data processing")
        ds_cumul = xr.combine_nested(
            [self._read_cumul_date(grib_at_terms) for grib_at_terms in self.cumul],
            concat_dim="time",
            combine_attrs="override",
        )
        ds_cumul = ds_cumul.sortby("time")

        LOGGER.info("Instantaneous data processing")
        ds_insta = xr.combine_nested(
            [self._read_insta(grib_file) for grib_file in self.insta],
            concat_dim="valid_time",
            combine_attrs="override",
        )
        ds_insta = ds_insta.drop_vars(["step", "time"])
        ds_insta = ds_insta.rename(
            {"longitude": "lon", "latitude": "lat", "valid_time": "time"}
        )
        ds_insta = ds_insta.sortby("time")
        return ds_cumul, ds_insta

    def _read_cumul_date(self, grib_at_terms):
        """Open the term files of one date, merge them and decumulate."""
        LOGGER.debug(grib_at_terms)
        terms = [
            xr.open_dataset(
                grib_file, engine="cfgrib", backend_kwargs={"indexpath": ""}
            )
            for grib_file in grib_at_terms
        ]
        ds = xr.combine_nested(
            terms, concat_dim="valid_time", combine_attrs="override"
        )
        ds = ds.drop_vars(["time", "surface"])
        ds = ds.rename(
            {"longitude": "lon", "latitude": "lat", "valid_time": "time"}
        )
        # decumul expects increasing steps along time
        ds = ds.sortby("time")
        ds = ds.astype("float64")
        return self.decumul(ds)

    @staticmethod
    def _read_insta(grib_file):
        """Open one instantaneous GRIB file, merging GRIB levels 0, 2 and 10."""
        levels = [
            xr.open_dataset(
                grib_file,
                engine="cfgrib",
                backend_kwargs={
                    "indexpath": "",
                    "filter_by_keys": {"edition": 2, "level": level},
                },
                drop_variables=[
                    "heightAboveGround",
                    "meanSea",
                    "surface",
                ],
            )
            for level in [0, 2, 10]
        ]
        return xr.merge(levels, combine_attrs="override")

    @staticmethod
    def decumul(ds_cumul):
        """Convert fields cumulated since the start of a date into rates.

        Parameters
        ----------
        ds_cumul: xarray.Dataset
            Cumulated fields of a single date, sorted along ``time``, with a
            ``step`` coordinate giving the elapsed time since the cumulation
            was reset.

        Returns
        -------
        xarray.Dataset
            Fields divided by the length (in seconds) of their accumulation
            interval, with ``time`` moved to the middle of each interval and
            the ``step`` coordinate removed.
        """
        LOGGER.debug("start decumul")
        if ds_cumul.sizes["time"] == 1:
            LOGGER.debug("one time step")
            # Cumulation starts at step 0, so the interval is the step itself
            diff_time = ds_cumul["step"]
            ds_decumul = ds_cumul / (diff_time / np.timedelta64(1, "s"))
        else:
            LOGGER.debug("more than one time step")
            # First term unchanged, then increments between successive terms
            ds_decumul = xr.concat(
                [ds_cumul.isel(time=0), ds_cumul.diff("time")], "time"
            )
            steps = ds_decumul["step"].values
            # Interval lengths step[k] - step[k-1], with an implicit step 0
            # before the first term (valid for irregular terms too).
            diff_time = np.diff(steps, prepend=np.zeros(1, dtype=steps.dtype))
            # Wrap in a DataArray so the division is aligned on ``time``: a
            # bare ndarray would be broadcast by numpy against the trailing
            # (longitude) axis of each variable.
            seconds = xr.DataArray(
                diff_time / np.timedelta64(1, "s"), dims="time"
            )
            ds_decumul = ds_decumul / seconds
        # Timestamp each rate at the middle of its accumulation interval
        ds_decumul["time"] = ds_decumul.time - 0.5 * diff_time
        if "step" in ds_decumul:
            ds_decumul = ds_decumul.drop_vars(["step"])
        return ds_decumul

    def interp_time(self, time, nc_out):
        """Decode the GRIB files, interpolate them in time and save them.

        Parameters
        ----------
        time:
            Target times, passed unchanged to the module-level
            ``interp_time`` function.
        nc_out: str
            Output NetCDF file.
        """
        ds_cumul, ds_insta = self.grib2dataset()
        diurnal_vars = []
        for pv in HYCOM3D_PARAM.values():
            attrs = pv["attrs"]
            # Equality, not identity: ``is "True"`` depends on string interning
            if "diurnal" in attrs and attrs["diurnal"] == "True":
                diurnal_vars.extend(pv["name"])
        LOGGER.debug(f"Diurnal vars : {diurnal_vars}")
        # Inside a method body the bare name ``interp_time`` resolves to the
        # module-level function imported from ...interp, not to this method.
        LOGGER.info("time interpolation of instantaneous data")
        ds_insta = interp_time(ds_insta, time, diurnal_vars=diurnal_vars)
        LOGGER.info("time interpolation of cumulated data")
        ds_cumul = interp_time(ds_cumul, time, diurnal_vars=diurnal_vars)
        ds_atmfrc = xr.merge(
            [ds_cumul, ds_insta], combine_attrs="override", compat="no_conflicts",
        )
        LOGGER.info(f"Store results in {nc_out}")
        ds_atmfrc.to_netcdf(nc_out)

    @staticmethod
    def parameters(nc_in, nc_out):
        """Compute derived atmospheric parameters and store them.

        Applies celsius2kelvin, windstress (only when ``ewss`` or ``nsss`` is
        missing), radiativeflux and watervapormixingratio, drops ``str``, and
        converts ``tp`` from kg.m-2.s-1 to m.s-1 with negative values (and
        NaNs) set to 0.
        """
        with xr.open_dataset(nc_in) as ds_in:
            ds = celsius2kelvin(ds_in)
            if not {"ewss", "nsss"}.issubset(ds.variables):
                ds = windstress(ds, method="Speich")
            ds = radiativeflux(ds)
            # xarray returns a new object: the result must be reassigned
            ds = ds.drop_vars("str")
            ds = watervapormixingratio(ds)
            ds["tp"] *= 0.001  # [kg.m-2.s-1]-->[m.s-1]
            ds["tp"] = ds["tp"].where(ds["tp"] >= 0, 0)
            LOGGER.info(f"Store results in {nc_out}")
            ds.to_netcdf(nc_out)

    @staticmethod
    def regridmask(nc_mask, nc_in, nc_out, weightsfile):
        """Add an atmos mask to the dataset.

        The mask of ``nc_mask`` is regridded onto the grid of ``nc_in``
        (restricted to the mask's lon/lat extent) and converted to 1 where
        the regridded value is <= 0.5 and 0 elsewhere.

        NOTE(review): ``ds`` is replaced by ``regridder.regrid(ds_mask)``, so
        unless Regridder.regrid merges the target variables the output holds
        only the mask - confirm against the Regridder implementation.
        """
        LOGGER.info("Add atmospheric mask to the dataset")
        with xr.open_dataset(nc_in) as ds_in, xr.open_dataset(nc_mask) as ds_mask:
            # Reverse the latitude order (presumably descending in the input)
            ds = ds_in.reindex({"lat": ds_in.lat[::-1]})
            ds = ds.sel(
                lon=slice(ds_mask.lon.min(), ds_mask.lon.max(), 1),
                lat=slice(ds_mask.lat.min(), ds_mask.lat.max(), 1),
            )
            regridder = Regridder(
                ds_mask, ds, regridder=None, filename=weightsfile)
            ds = regridder.regrid(ds_mask)
            ds["mask"] = (ds["mask"] <= 0.5).astype("i")
            ds["mask"].attrs.update(
                long_name="Land-sea mask with 1 on the ocean and 0 on land")
            LOGGER.debug(ds)
            LOGGER.info(f"Store results in {nc_out}")
            ds.to_netcdf(nc_out)

    @staticmethod
    def regridvar(nc_in, nc_out, hycom_grid, weightsfile):
        """Regrid to ocean grid with appropriate handling of coasts.

        The wind vector is processed by erode_mask_vec; each scalar variable
        is masked where ``mask != 1`` then passed to erode_mask (3x3 kernel,
        ``until=5``) before everything is regridded onto ``hycom_grid``.
        """
        kernel = {"lon": 3, "lat": 3}
        scalar_variables = [
            "t2m",
            "r2",
            "prmsl",
            "ssr",
            "tp",
            "si10",
            "vapmix",
            "radflx",
        ]
        with xr.open_dataset(nc_in) as ds_in:
            ds = erode_mask_vec(ds_in, kernel, param="wind", until=5)
            ds = ds[scalar_variables + ["ewss", "nsss", "mask"]]
            for var in scalar_variables:
                LOGGER.info(f"Erode mask on {var}")
                ds[var] = erode_mask(
                    ds[var].where(ds.mask == 1), kernel=kernel, until=5
                )
            ds = ds.transpose("time", "lat", "lon")
            LOGGER.info("Build regridder")
            regridder = Regridder(
                ds, hycom_grid, regridder=None, filename=weightsfile
            )
            LOGGER.info("Regrid the dataset")
            ds = regridder.regrid(ds)
            LOGGER.info(f"Store dataset in NetCDF file {nc_out}")
            ds.to_netcdf(nc_out)

    @staticmethod
    def rename_vars(ds):
        """Rename variables to their HYCOM names and drop all the others.

        For each HYCOM parameter, the source name is the first entry of its
        ``name`` list after removing the HYCOM name itself (when the list has
        several entries). Renamed variables receive a copy of the parameter
        ``attrs``. HYCOM3D_PARAM itself is left untouched, so the method can
        be called any number of times.
        """
        LOGGER.info("Rename atmfrc variables as expected by Hycom")
        for hyparam, param in HYCOM3D_PARAM.items():
            names = list(param["name"])  # local copy: never mutate the table
            if len(names) > 1 and hyparam in names:
                names.remove(hyparam)
            source = names[0]
            if source in ds:
                ds = ds.rename({source: hyparam})
                ds[hyparam].attrs = dict(param["attrs"])
        extra_vars = [name for name in ds.data_vars if name not in HYCOM3D_PARAM]
        return ds.drop_vars(extra_vars)

    @staticmethod
    def write_abfiles(ds, ab_out="forcing.{var_name}", **kwargs):
        """Write each data variable of ``ds`` as a HYCOM [a/b] file pair.

        Parameters
        ----------
        ds: xarray.Dataset
            Variables indexed as ``[time, :, :]`` with 2-D ``lon``/``lat``
            coordinates (the header writes ``lon.shape[1] lon.shape[0]`` as
            ``II JJ``); each variable needs a ``header`` attribute. NaNs are
            written as 0.
        ab_out: str
            Name pattern passed to get_ab_file_names; the resulting patterns
            are formatted with the local variables (e.g. ``{var_name}``).
        freq: optional keyword
            Written in the .b header.

        Fields are written as big-endian float32. Every .a record, the last
        one included, is zero-padded to a multiple of 16384 bytes (4096
        words).
        """
        LOGGER.info("Write Hycom [a/b] files")
        ds = ds.fillna(0)
        freq = kwargs.get("freq", None)
        fa_pattern, fb_pattern = get_ab_file_names(ab_out)
        header_pattern = (
            "{var_header} a {freq} heures \n\n"
            "ORIGINE:unknow \n"
            " II JJ XLONE XLONW XLATS XLATN \n"
            "{ds.lon.shape[1]} {ds.lon.shape[0]} "
            "{lonmax} {lonmin} "
            "{latmin} {latmax} \n"
        )
        line_pattern = (
            "{var_name}: date,range = {date} {varmin} {varmax}\n"
        )
        # Record length in bytes: float32 field rounded up to 4096 words
        rec = int(np.ceil(float(ds.lon.size * 4) / 2 ** 14) * 2 ** 14)
        coordformat = "{:.3f}"
        lonmin = coordformat.format(ds.lon.values.min())
        lonmax = coordformat.format(ds.lon.values.max())
        latmin = coordformat.format(ds.lat.values.min())
        latmax = coordformat.format(ds.lat.values.max())
        valformat = "{:.8E}"
        for var_name in list(ds.keys()):
            LOGGER.info(f"{var_name}")
            var_header = ds[var_name].attrs["header"]
            header = header_pattern.format(
                ds=ds, var_header=var_header, freq=freq,
                lonmin=lonmin, lonmax=lonmax, latmin=latmin, latmax=latmax,
            )
            # File name patterns come from the caller: keep formatting them
            # with the locals so any placeholder used so far still resolves.
            with open(fa_pattern.format(**locals()), "wb") as fa, \
                    open(fb_pattern.format(**locals()), "w") as fb:
                LOGGER.debug(header)
                fb.write(header)
                for itime, time_value in enumerate(ds.time.values):
                    field = ds[var_name][itime, :, :].values
                    raw = field.astype(">f4").tobytes()
                    fa.write(raw)
                    # Explicit padding: seeking past EOF would leave the
                    # final record short of the fixed record length.
                    fa.write(b"\0" * (rec - len(raw)))
                    fb.write(line_pattern.format(
                        var_name=var_name,
                        date=valformat.format(convert_to_julian_day(time_value)),
                        varmin=valformat.format(field.min()),
                        varmax=valformat.format(field.max()),
                    ))

    @staticmethod
    def write_ncfiles(ds, nc_out="forcing.{var_name}.nc", **kwargs):
        """Write each data variable of ``ds`` to its own NetCDF file.

        ``nc_out`` is formatted with the local variables (e.g.
        ``{var_name}``). Extra keyword arguments are accepted and ignored.
        """
        for var_name in list(ds.keys()):
            path = nc_out.format(**locals())
            LOGGER.info("Create {}".format(path))
            ds[var_name].to_netcdf(path)

    @staticmethod
    def outgen(nc_in, **kwargs):
        """Rename ``nc_in`` variables for HYCOM and write [a/b] and NetCDF files.

        The optional ``freq`` keyword is forwarded to :meth:`write_abfiles`.
        """
        freq = kwargs.get("freq", None)
        with xr.open_dataset(nc_in) as ds_in:
            ds = AtmFrc.rename_vars(ds_in)
            AtmFrc.write_abfiles(ds, freq=freq)
            AtmFrc.write_ncfiles(ds)