import pickle
import numpy as np
import pyproj
from scipy.ndimage import map_coordinates
from abc import abstractmethod
from pathlib import Path
from opendrift.readers.interpolation.structured import ReaderBlock
from .variables import Variables
import logging
logger = logging.getLogger(__name__)
[docs]
class StructuredReader(Variables):
    """
    A structured reader. Data is gridded on a regular grid. Used by e.g.:
    :class:`opendrift.readers.reader_netCDF_CF_generic.Reader`.

    Attributes:
        projected: is `False` if :class:`.fakeproj.fakeproj` is used because of missing projection information, and `True` otherwise. With `fakeproj`, x and y are pixel indices and the data points are assumed to be approximately equidistant on the surface (i.e. in meters).
        clipped: pixels to remove along boundary (e.g. in case of bad data).

    .. seealso::

        :py:mod:`opendrift.readers`
    """
    # TODO: should the variables below not be instance variables, and not class variables?
    # Number of boundary pixels excluded when computing bounding boxes.
    clipped = 0
    # Grid coordinate vectors; filled in by `__init__` (fakeproj case) or by
    # `_make_projected_grid_`.
    x = None
    y = None
    # Name of the horizontal interpolation scheme handed to `ReaderBlock`.
    interpolation = 'linearNDFast'
    convolve = None # Convolution kernel or kernel size
    # set these in reader to save interpolators to file
    save_interpolator = None
    interpolator_filename = None
    # Used to enable and track status of parallel coordinate transformations.
    __lonlat2xy_parallel__ = None
    __disable_parallel__ = False
def __init__(self):
    """
    Finish reader set-up once the subclass has defined its projection.

    Reads `self.proj`, `self.proj4` and `self.name`, which must already be
    set. If neither a projection nor a usable proj4 string exists, the
    reader falls back to `fakeproj`: x and y become pixel indices, and two
    `LinearNDInterpolator` objects mapping (lon, lat) to x and y are built
    from the 2D arrays `self.lon` and `self.lat` (or loaded from a pickle
    file when `save_interpolator` is set and the file exists).
    """
    if self.proj is None and (self.proj4 is None
                              or self.proj4 == 'fakeproj'):
        logger.warning(f"No proj string or projection could be derived for reader {self.name}, using 'fakeproj'.")
        from scipy.interpolate import LinearNDInterpolator
        import copy
        from . import fakeproj
        # `projected` is set to False when `fakeproj` is used (and to True
        # in the else-branch below, when a real projection exists).
        self.projected = None
        self.shape = None
        self.proj4 = 'None'
        self.proj = fakeproj.fakeproj()
        self.projected = False
        # With fakeproj the grid coordinates are simply pixel indices.
        self.xmin = self.ymin = 0.
        self.delta_x = self.delta_y = 1.
        self.xmax = self.lon.shape[1] - 1
        self.ymax = self.lon.shape[0] - 1
        # NOTE(review): numx/numy are the largest index here, whereas
        # `_make_projected_grid_` stores the number of points -- confirm
        # that this difference is intended.
        self.numx = self.xmax
        self.numy = self.ymax
        self.x = np.arange(0, self.xmax+1)
        self.y = np.arange(0, self.ymax+1)
        # Making interpolator (lon, lat) -> x
        # save to speed up next time
        if self.save_interpolator and self.interpolator_filename is not None:
            interpolator_filename = Path(self.interpolator_filename).with_suffix('.pickle')
        else:
            # Default file name, relative to the current working directory.
            interpolator_filename = f'{self.name}_interpolators.pickle'
        if self.save_interpolator and Path(interpolator_filename).is_file():
            logger.info('Loading previously saved interpolator for lon,lat to x,y conversion.')
            # SECURITY: unpickling executes arbitrary code; only point
            # `interpolator_filename` at files written by a trusted run.
            with open(interpolator_filename, 'rb') as file_handle:
                interp_dict = pickle.load(file_handle)
                spl_x = interp_dict["spl_x"]
                spl_y = interp_dict["spl_y"]
        else:
            logger.info('Making interpolator for lon,lat to x,y conversion...')
            block_x, block_y = np.mgrid[self.xmin:self.xmax + 1,
                                        self.ymin:self.ymax + 1]
            # Transpose so that the index grids line up with self.lon/self.lat.
            block_x, block_y = block_x.T, block_y.T
            spl_x = LinearNDInterpolator(
                (self.lon.ravel(), self.lat.ravel()),
                block_x.ravel(),
                fill_value=np.nan)
            # Reusing x-interpolator (deepcopy) with data for y
            spl_y = copy.deepcopy(spl_x)
            spl_y.values[:, 0] = block_y.ravel()
            # Call interpolator to avoid threading-problem:
            # https://github.com/scipy/scipy/issues/8856
            spl_x((0, 0)), spl_y((0, 0))
            if self.save_interpolator:
                logger.info('Saving interpolator for lon,lat to x,y conversion.')
                interp_dict = {"spl_x": spl_x, "spl_y": spl_y}
                with open(interpolator_filename, 'wb') as f:
                    pickle.dump(interp_dict, f)
        self.spl_x = spl_x
        self.spl_y = spl_y
    else:
        self.projected = True
    super().__init__()
    # Dictionaries to store blocks of data for reuse (buffering)
    self.var_block_before = {} # Data for last timestep before present
    self.var_block_after = {} # Data for first timestep after present
[docs]
@abstractmethod
def get_variables(self, variables, time=None, x=None, y=None, z=None):
    """
    Obtain a _block_ of values of the requested variables at all positions
    (x, y, z) closest to given time. These will be stored in
    :class:`opendrift.readers.interpolation.structured.ReaderBlock` and
    accessed from there.

    This method is abstract: every concrete structured reader must
    implement it. Within this class its result is passed through
    `__convolve_block__` and then wrapped in a `ReaderBlock`.

    Arguments:
        variables: list of variables.
        time: datetime or None, time at which data are requested.
        x, y: float or ndarrays; coordinates of requested points.
        z: float or ndarray; vertical position (in meters, positive up)

    Returns:
        Dictionary
            keywords: variables (string)
            values: 2D ndarray bounding x and y.

        NOTE(review): `__convolve_block__` skips the keys 'x', 'y', 'z'
        and 'time' and also handles 3D values, so the dictionary
        presumably carries those coordinate entries and may hold 3D
        arrays as well -- confirm against the implementing readers.
    """
[docs]
def prepare(self, extent, start_time, end_time, max_speed):
    """
    Prepare reader for given simulation coverage in time and space.

    Empties the cached before/after data blocks, sizes the read buffer
    via `set_buffer_size`, and then delegates to the parent `prepare`.

    Arguments:
        extent: spatial coverage, forwarded unchanged to the parent class.
        start_time: simulation start, or None.
        end_time: simulation end.
        max_speed: maximum element speed, forwarded to `set_buffer_size`.
    """
    logger.debug('Clearing cache for reader %s before starting new simulation' % self.name)
    self.var_block_before, self.var_block_after = {}, {}

    # Without a reader time step the buffer must hold for the whole run.
    needs_full_run_buffer = self.time_step is None and start_time is not None
    if needs_full_run_buffer:
        logger.debug('Time step is None for %s, setting buffer size large nough for whole simulation' % self.name)
        buffer_duration = end_time - start_time
    else:
        buffer_duration = self.time_step
    self.set_buffer_size(max_speed, buffer_duration)

    super().prepare(extent, start_time, end_time, max_speed)
[docs]
def set_convolution_kernel(self, convolve):
    """
    Store the convolution kernel that `__convolve_block__` applies to read variables.

    Arguments:
        convolve: either an integer N (an N x N array of ones, normalised,
            is then used) or an explicit kernel array; None disables
            convolution.
    """
    self.convolve = convolve
[docs]
def __convolve_block__(self, env):
    """
    Smooth every data array in `env` with the reader's convolution kernel.

    Nothing happens when `self.convolve` is None. An integer N is expanded
    to a normalised N x N kernel of ones; any other value is used as the
    kernel directly. The entries 'x', 'y', 'z' and 'time' are left alone,
    and only 2D and 3D arrays are convolved.

    Arguments:
        env: dictionary of arrays; modified in place.

    Returns:
        The same `env` dictionary.
    """
    if self.convolve is None:
        return env

    from scipy import ndimage

    spec = self.convolve
    if isinstance(spec, (int, np.integer)):
        kernel = np.ones((spec, spec))
        kernel = kernel / kernel.sum()
    else:
        kernel = spec
    logger.debug('Convolving variables with kernel: %s' % kernel)

    coordinate_keys = ['x', 'y', 'z', 'time']
    for name in env:
        if name in coordinate_keys:
            continue
        rank = env[name].ndim
        if rank == 2:
            env[name] = ndimage.convolve(env[name], kernel, mode='nearest')
        elif rank == 3:
            # NOTE(review): kernel[:, :, None] has shape (N, N, 1), i.e. it
            # smooths along the first two axes of the 3D array -- confirm
            # that these are the horizontal axes of the block.
            env[name] = ndimage.convolve(env[name],
                                         kernel[:, :, None],
                                         mode='nearest')
    return env
[docs]
def lon_range(self):
    """
    Tell which longitude convention a globally covering reader uses.

    Returns:
        '-180to180' if `self.xmin` is negative, otherwise '0to360'.

    Raises:
        ValueError: if the reader does not have global coverage.
    """
    if not self.global_coverage():
        raise ValueError('Only valid for readers with global coverage')
    return '-180to180' if self.xmin < 0 else '0to360'
[docs]
def _get_variables_interpolated_(self, variables, profiles, profiles_depth,
                                 time, reader_x, reader_y, z):
    """
    Interpolate `variables` onto the given positions and time.

    Data blocks for the reader times just before and just after `time`
    are taken from the cache (`var_block_before` / `var_block_after`) or
    fetched with `get_variables`, passed through `__convolve_block__` and
    wrapped in a `ReaderBlock`. Both blocks are interpolated in space and
    the two results are then weighted linearly in time.

    Arguments:
        variables: iterable of variable names to interpolate.
        profiles: None, or iterable of variable names for which vertical
            profiles are requested.
        profiles_depth: depth range of the profiles; `-profiles_depth` is
            used as the lower bound of the fetched block.
        time: time at which values are requested.
        reader_x, reader_y: positions in the reader's coordinate system.
        z: vertical positions.

    Returns:
        Tuple `(env, env_profiles)`: a dictionary of interpolated values
        per variable, and a dictionary of profiles (None when `profiles`
        is None).
    """
    # For global readers, we shift coordinates to match actual lon range
    if self.global_coverage():
        lon_range = self.lon_range()
        if lon_range == '-180to180':
            logger.debug('Shifting coordinates to -180-180')
            reader_x = np.mod(reader_x + 180, 360) - 180
        elif lon_range == '0to360':
            logger.debug('Shifting coordinates to 0-360')
            reader_x = np.mod(reader_x, 360)
    elif self.proj.crs.is_geographic and self.xmin > 0:
        # The reader name is interpolated into the message (the original
        # logged the literal text "self.name").
        logger.debug('Modulating longitudes to 0-360 for %s' % self.name)
        reader_x = np.mod(reader_x, 360)

    # Find reader time_before/time_after
    _time_nearest, time_before, time_after, _i1, _i2, _i3 = \
        self.nearest_time(time)
    logger.debug('Reader time:\n\t\t%s (before)\n\t\t%s (after)' %
                 (time_before, time_after))
    # For variables which are not time dependent, we do not care about time
    static_variables = [
        'sea_floor_depth_below_sea_level', 'land_binary_mask'
    ]
    if time == time_before or all(v in static_variables
                                  for v in variables):
        time_after = None

    if profiles is not None:
        # If profiles are requested for any parameters, we
        # add two fake points at the end of array to make sure that the
        # requested block has the depth range required for profiles
        mx = np.append(reader_x, [reader_x[-1], reader_x[-1]])
        my = np.append(reader_y, [reader_y[-1], reader_y[-1]])
        mz = np.append(z, [0, -profiles_depth])
    else:
        mx = reader_x
        my = reader_y
        mz = z

    # Look for cached blocks that contain all requested variables. The
    # cache keys are strings, so `v in blockvars` is a substring test.
    # If nothing matches, the defaults below are used as the new key.
    block_before = block_after = None
    blockvariables_before = variables
    blockvars_before = str(variables)
    blockvariables_after = variables
    blockvars_after = str(variables)
    for blockvars in self.var_block_before:
        if all(v in blockvars for v in variables):
            block_before = self.var_block_before[blockvars]
            blockvariables_before = block_before.data_dict.keys()
            blockvars_before = blockvars
            break
    for blockvars in self.var_block_after:
        if all(v in blockvars for v in variables):
            block_after = self.var_block_after[blockvars]
            blockvariables_after = block_after.data_dict.keys()
            blockvars_after = blockvars
            break

    # Swap before- and after-blocks if matching times
    if block_before is not None and block_after is not None:
        if block_before.time != time_before:
            if block_after.time == time_before:
                block_before = block_after
                self.var_block_before[blockvars_before] = block_before
        if block_after.time != time_after:
            # NOTE(review): this compares against time_before, so the
            # reused block is refetched below unless time_after happens
            # to match -- confirm whether time_after was intended.
            if block_before.time == time_before:
                block_after = block_before
                self.var_block_after[blockvars_after] = block_after

    # Fetch data, if no buffer is available
    if block_before is None or \
            block_before.time != time_before:
        reader_data_dict = \
            self.__convolve_block__(
                self.get_variables(blockvariables_before, time_before,
                                   mx, my, mz)
            )
        self.var_block_before[blockvars_before] = \
            ReaderBlock(reader_data_dict,
                        interpolation_horizontal=self.interpolation,
                        wrap_x=self.global_coverage())
        try:
            len_z = len(self.var_block_before[blockvars_before].z)
        except (AttributeError, TypeError):
            # No z attribute, or z has no length: report a single layer.
            len_z = 1
        logger.debug(
            ('Fetched env-block (size %ix%ix%i) ' + 'for time before (%s)')
            % (len(self.var_block_before[blockvars_before].x),
               len(self.var_block_before[blockvars_before].y), len_z,
               time_before))
        block_before = self.var_block_before[blockvars_before]

    if block_after is None or block_after.time != time_after:
        if time_after is None:
            # No later time needed: reuse the before-block.
            self.var_block_after[blockvars_after] = block_before
        else:
            reader_data_dict = self.__convolve_block__(
                self.get_variables(blockvariables_after, time_after, mx,
                                   my, mz))
            self.var_block_after[blockvars_after] = \
                ReaderBlock(
                    reader_data_dict,
                    interpolation_horizontal=self.interpolation,
                    wrap_x=self.global_coverage())
            try:
                len_z = len(self.var_block_after[blockvars_after].z)
            except (AttributeError, TypeError):
                len_z = 1
            logger.debug(('Fetched env-block (size %ix%ix%i) ' +
                          'for time after (%s)') %
                         (len(self.var_block_after[blockvars_after].x),
                          len(self.var_block_after[blockvars_after].y),
                          len_z, time_after))
        block_after = self.var_block_after[blockvars_after]

    if (block_before is not None and block_before.covers_positions(
            reader_x, reader_y) is False) or (
            block_after is not None and block_after.covers_positions(
                reader_x, reader_y) is False):
        logger.warning('Data block from %s not large enough to '
                       'cover element positions within timestep. '
                       'Buffer size (%s) must be increased. See `Variables.set_buffer_size`.' %
                       (self.name, str(self.buffer)))
        # TODO: could add dynamic increase of buffer size here

    ############################################################
    # Interpolate before/after blocks onto particles in space
    ############################################################
    self.timer_start('interpolation')
    logger.debug('Interpolating before (%s) in space (%s)' %
                 (block_before.time, self.interpolation))
    env_before, env_profiles_before = block_before.interpolate(
        reader_x, reader_y, z, variables, profiles, profiles_depth)
    if (time_after is not None) and (time_before != time):
        logger.debug('Interpolating after (%s) in space (%s)' %
                     (block_after.time, self.interpolation))
        env_after, env_profiles_after = block_after.interpolate(
            reader_x, reader_y, z, variables, profiles, profiles_depth)
    self.timer_end('interpolation')

    #######################
    # Time interpolation
    #######################
    self.timer_start('interpolation_time')
    env_profiles = None
    if (time_after is not None) and (time_before != time) and self.always_valid is False:
        weight_after = ((time - time_before).total_seconds() /
                        (time_after - time_before).total_seconds())
        logger.debug(('Interpolating before (%s, weight %.2f) and'
                      '\n\t\t after (%s, weight %.2f) in time') %
                     (block_before.time, 1 - weight_after,
                      block_after.time, weight_after))
        env = {}
        for var in variables:
            # Weighting together, and masking invalid entries
            env[var] = np.ma.masked_invalid(
                (env_before[var] * (1 - weight_after) +
                 env_after[var] * weight_after))
        # Interpolating vertical profiles in time
        if profiles is not None:
            env_profiles = {}
            logger.debug('Interpolating profiles in time')
            # Truncating layers not present both before and after
            numlayers = np.minimum(len(env_profiles_before['z']),
                                   len(env_profiles_after['z']))
            env_profiles['z'] = env_profiles_before['z'][0:numlayers]
            for var in env_profiles_before.keys():
                if var == 'z':
                    continue
                env_profiles_before[var] = np.atleast_2d(
                    env_profiles_before[var])
                env_profiles_after[var] = np.atleast_2d(
                    env_profiles_after[var])
                env_profiles[var] = (
                    env_profiles_before[var][0:numlayers, :] *
                    (1 - weight_after) +
                    env_profiles_after[var][0:numlayers, :] * weight_after)
    else:
        logger.debug('No time interpolation needed - right on time.')
        env = env_before
        if profiles is not None:
            # `env_profiles_before` is always bound by the spatial
            # interpolation above, so it is used directly.
            env_profiles = env_profiles_before
    self.timer_end('interpolation_time')
    return env, env_profiles
[docs]
def __check_env_arrays__(self, env):
    """
    Hand `env` back untouched.

    For the StructuredReader the variables are checked before they enter
    the ReaderBlock interpolator, so this second check is deliberately a
    no-op.

    .. seealso::

        :meth:`.variables.Variables.__check_env_arrays__`.
    """
    return env
[docs]
def xy2lonlat(self, x, y):
    """
    Convert reader coordinates (x, y) to longitude and latitude.

    For projected readers the parent implementation is used. Otherwise x
    and y are pixel indices, and `self.lon` / `self.lat` are sampled by
    linear interpolation at those indices.

    Arguments:
        x, y: scalars or array-likes of pixel coordinates. The inputs are
            copied to float arrays and never modified.

    Returns:
        Tuple `(lon, lat)` of arrays with the same shape as the (at least
        1D) inputs. Coordinates outside [xmin, xmax] x [ymin, ymax] are
        replaced by NaN before sampling, which is intended to give NaN
        output for them.
    """
    if self.projected:
        return super().xy2lonlat(x, y)

    # Float copies: NaN cannot be stored in integer arrays, and the
    # caller's arrays must not be modified.
    x = np.array(x, dtype=float, ndmin=1)
    y = np.array(y, dtype=float, ndmin=1)

    # Silence invalid-value warnings for NaN only within this block,
    # instead of changing numpy's global error state.
    with np.errstate(invalid='ignore'):
        # NB: mask coordinates outside domain
        x[(x < self.xmin) | (x > self.xmax)] = np.nan
        y[(y < self.ymin) | (y > self.ymax)] = np.nan
        lon = map_coordinates(self.lon, [y, x],
                              order=1,
                              cval=np.nan,
                              mode='nearest')
        lat = map_coordinates(self.lat, [y, x],
                              order=1,
                              cval=np.nan,
                              mode='nearest')
    return (lon, lat)
[docs]
def lonlat2xy(self, lon, lat):
    """
    Convert longitude and latitude to reader coordinates (x, y).

    Projected readers defer to the parent implementation. Otherwise the
    interpolators `self.spl_x` and `self.spl_y` are evaluated: directly
    for up to 10000 elements (or when `__disable_parallel__` is set), and
    chunk-wise through a thread pool for larger inputs. The attribute
    `__lonlat2xy_parallel__` records which path was taken.

    Arguments:
        lon, lat: longitudes and latitudes to convert.

    Returns:
        Tuple `(x, y)`.
    """
    if self.projected:
        self.__lonlat2xy_parallel__ = False
        return super().lonlat2xy(lon, lat)

    # For larger arrays, we split and calculate in parallel
    num_elements = len(np.atleast_1d(lon))
    use_threads = num_elements > 10000 and not self.__disable_parallel__

    if not use_threads:
        logger.debug('Calculating lonlat2xy sequentially')
        self.__lonlat2xy_parallel__ = False
        x = self.spl_x(lon, lat)
        y = self.spl_y(lon, lat)
        return (x, y)

    from multiprocessing import cpu_count
    from concurrent.futures import ThreadPoolExecutor

    self.__lonlat2xy_parallel__ = True
    nproc = cpu_count()
    logger.debug('Running lonlat2xy in parallel using %d threads' %
                 nproc)

    # Chunk arrays
    lon_chunks = np.array_split(lon, nproc)
    lat_chunks = np.array_split(lat, nproc)
    with ThreadPoolExecutor() as pool:
        # Each worker receives one (lon_chunk, lat_chunk) tuple.
        x_parts = list(pool.map(self.spl_x, zip(lon_chunks, lat_chunks)))
        out_x = np.concatenate(x_parts)
        y_parts = list(pool.map(self.spl_y, zip(lon_chunks, lat_chunks)))
        out_y = np.concatenate(y_parts)
    return (out_x, out_y)
[docs]
def pixel_size(self):
    """
    Estimate the size of one pixel.

    Projected readers use the parent implementation. Otherwise the
    geodesic distance (WGS84) between the grid points (xmin, ymin) and
    (xmax, ymin) is divided by `self.shape[0]`.

    Returns:
        The pixel size, in the length unit returned by `pyproj.Geod.inv`.
    """
    if self.projected:
        return super().pixel_size()

    end_lons, end_lats = self.xy2lonlat([self.xmin, self.xmax],
                                        [self.ymin, self.ymin])
    wgs84 = pyproj.Geod(ellps='WGS84')  # Define an ellipsoid
    row_length = wgs84.inv(end_lons[0], end_lats[0],
                           end_lons[1], end_lats[1],
                           radians=False)[2]
    # NOTE(review): the distance is measured along x but divided by
    # shape[0]; `__init__` also sets `self.shape = None` in this case, so
    # a subclass presumably assigns it -- confirm the intended axis.
    return row_length / self.shape[0]
[docs]
def get_ocean_depth_area_volume(self, lonmin, lonmax, latmin, latmax):
    """
    Get depth, area and volume of ocean basin within given coordinates.

    The variable 'sea_floor_depth_below_sea_level' is read for the box
    spanned by the four corners, and grid cells outside the lon/lat limits
    are masked. The volume is the sum of depth times pixel area, and the
    area is that volume divided by the mean depth.

    Arguments:
        lonmin, lonmax: longitude limits of the basin.
        latmin, latmax: latitude limits of the basin.

    Returns:
        Tuple `(min depth, max depth, mean depth, area, volume)`.
    """
    # Extract ocean depth within given boundaries
    background = 'sea_floor_depth_below_sea_level'
    corner_lons = [lonmin, lonmax, lonmax, lonmin]
    corner_lats = [latmin, latmin, latmax, latmax]
    rx, ry = self.lonlat2xy(corner_lons, corner_lats)
    rx = np.linspace(rx.min(), rx.max(), 10)
    ry = np.linspace(ry.min(), ry.max(), 10)
    # NOTE(review): a single name is passed where `get_variables`
    # documents a list -- confirm the readers accept a plain string.
    data = self.get_variables(background, time=None, x=rx, y=ry)

    x, y = np.meshgrid(data['x'], data['y'])
    lon, lat = self.xy2lonlat(x, y)

    # Mask everything outside the requested box, one limit at a time.
    depth = data[background]
    for outside in (lon < lonmin, lon > lonmax, lat < latmin, lat > latmax):
        depth = np.ma.masked_where(outside, depth)

    volume = np.nansum(depth*self.pixel_size()*self.pixel_size())
    mean_depth = np.nanmean(depth)
    area = volume/mean_depth
    return np.nanmin(depth), np.nanmax(depth), np.nanmean(depth), area, volume
[docs]
def _coverage_unit_(self):
    """
    Name the unit of the reader's coverage.

    Returns:
        The parent's unit for projected readers, otherwise "pixels".
    """
    if not self.projected:
        return "pixels"
    return super()._coverage_unit_()
[docs]
def _bbox_(self, x, y):
    """
    Find bounding box on grid containing points (x, y).

    The positions are converted to fractional grid indices, widened by
    `self.buffer` on every side and limited to the grid minus
    `self.clipped` boundary pixels.

    Arguments:
        x, y: positions in reader coordinates.

    Returns:
        Tuple `(ix0, ix1, iy0, iy1)` of integer index bounds.
    """
    pad = self.buffer
    margin = self.clipped

    col = (x - self.xmin) / self.delta_x
    col_lo, col_hi = np.min(col), np.max(col)
    row = (y - self.ymin) / self.delta_y
    row_lo, row_hi = np.min(row), np.max(row)

    # Lower bounds never go below the clipped margin ...
    ix0 = np.max((margin, col_lo - pad)).astype(int)
    iy0 = np.max((margin, row_lo - pad)).astype(int)
    # ... and upper bounds never exceed the grid minus that margin.
    ix1 = np.min((self.numx - margin, col_hi + pad)).astype(int)
    iy1 = np.min((self.numy - margin, row_hi + pad)).astype(int)
    return (ix0, ix1, iy0, iy1)
[docs]
def _make_projected_grid_(self, lon, lat, eq_eps=1.e-1):
    """
    Build the projected grid from 2D `lon` and `lat` arrays.

    Used when `lon` and `lat` are present as 2D variables but `x` and `y`
    are not. The arrays are converted with `lonlat2xy`; bounds, spacing,
    coordinate vectors and point counts are stored on the reader, and the
    grid is then checked to be approximately equidistant.

    Args:
        lon: 2D array of longitudes.
        lat: 2D array of latitudes.
        eq_eps: tolerance for equidistance checks.

    Raises:
        AssertionError: if `lon` or `lat` is not 2D, or if the grid is
            not equidistant within `eq_eps`.
    """
    already_defined = not (self.x is None and self.y is None)
    if already_defined:
        logger.error("x and y variables already exist!")

    logger.debug("Finding bounds of reader")
    assert len(lon.shape) == 2
    assert len(lat.shape) == 2

    grid_x, grid_y = self.lonlat2xy(lon, lat)
    self.X, self.Y = grid_x, grid_y

    self.xmin, self.xmax = np.min(grid_x[:]), np.max(grid_x[:])
    self.ymin, self.ymax = np.min(grid_y[:]), np.max(grid_y[:])

    # Spacing is taken from the first pair of neighbouring points.
    self.delta_x = np.diff(grid_x).flat[0]
    self.delta_y = np.diff(grid_y, axis=0).flat[0]

    # First row / first column serve as the coordinate vectors.
    self.x = grid_x[0, :]
    self.y = grid_y[:, 0]
    self.numx = len(self.x)
    self.numy = len(self.y)

    self.__validate_projected_grid__(eq_eps)
[docs]
def __validate_projected_grid__(self, eq_eps=1.e-1):
    """
    Check that the projected grid is approximately equidistant.

    Four conditions are tested against `eq_eps`: constant spacing of
    `self.X` along x, constant spacing of `self.Y` along y, every row of
    `self.X` matching `self.x`, and every column of `self.Y` matching
    `self.y`.

    Args:
        eq_eps: tolerance for equidistance checks.

    Raises:
        AssertionError if not equidistant within `eq_eps`.
    """
    spacing_x = np.diff(self.X)
    assert np.all(np.abs(self.delta_x - spacing_x) < eq_eps
                  ), "Grid is not equidistant in X direction"

    spacing_y = np.diff(self.Y, axis=0)
    assert np.all(np.abs(self.delta_y - spacing_y) < eq_eps
                  ), "Grid is not equidistant in Y direction"

    # Repeat the coordinate vectors over the full grid and compare.
    expected_X = np.tile(self.x, (self.X.shape[0], 1))
    assert np.all(
        np.abs(expected_X - self.X) < eq_eps
    ), "X coordinates are not aligned along Y direction"

    expected_Y = np.tile(np.atleast_2d(self.y).T, (1, self.Y.shape[1]))
    assert np.all(
        np.abs(expected_Y - self.Y) < eq_eps
    ), "Y coordinates are not aligned along X direction"
[docs]
def _slice_variable_(self,
                     var,
                     indxTime=None,
                     indy=None,
                     indx=None,
                     indz=None,
                     indrealization=None):
    """
    Slice variable depending on number of dimensions available.

    The index order used per number of dimensions is:
    2: (y, x); 3: (time, y, x); 4: (time, z, y, x);
    5: (time, z, realization, y, x).

    Args:
        var: object with an `ndim` attribute that supports indexing.
        indxTime, indy, indx, indz, indrealization: `slice` objects or
            indexes for the respective dimension.

    Returns:
        `var` sliced using the slices or indexes necessary to use depending
        on number of dimensions available.

    Raises:
        Unsupported number of dimensions (outside 2..5) raises an exception.
    """
    # Index tuple to apply for each supported number of dimensions.
    index_by_ndim = {
        2: (indy, indx),
        3: (indxTime, indy, indx),
        4: (indxTime, indz, indy, indx),
        5: (indxTime, indz, indrealization, indy, indx),  # Ensemble data
    }
    index = index_by_ndim.get(var.ndim)
    if index is None:
        raise Exception('Wrong dimension of variable: %s: %d' %
                        (var, var.ndim))
    return var[index]