Python xarray module: open_dataset() example source code
We collected the following 47 code examples from open-source Python projects to illustrate how to use xarray.open_dataset().
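Before the project snippets, here is a minimal, self-contained sketch of the basic call; the file name and variable used here are placeholders and do not come from the projects below.
import xarray as xr

# Open a NetCDF file lazily; values are only read from disk when accessed.
ds = xr.open_dataset('example.nc')    # hypothetical file name
print(ds.data_vars)                   # list the variables stored in the file
t2m = ds['t2m'].isel(time=0)          # hypothetical variable: select the first time step
ds.close()                            # or use: with xr.open_dataset('example.nc') as ds: ...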
def test_triangles(self):
"""Test the creation of triangles"""
ds = psyd.open_dataset(os.path.join(bt.test_dir, 'icon_test.nc'))
decoder = psyd.CFDecoder(ds)
var = ds.t2m[0, 0]
var.attrs.pop('grid_type', None)
self.assertTrue(decoder.is_triangular(var))
self.assertTrue(decoder.is_unstructured(var))
triangles = decoder.get_triangles(var)
self.assertEqual(len(triangles.triangles), var.size)
# Test for correct falsification
ds = psyd.open_dataset(os.path.join(bt.test_dir, 'test-t2m-u-v.nc'))
decoder = psyd.CFDecoder(ds)
self.assertFalse(decoder.is_triangular(ds.t2m[0, 0]))
self.assertFalse(decoder.is_unstructured(ds.t2m[0, 0]))
def test_update_01_isel(self):
"""test the update of a single array through the isel method"""
ds = psyd.open_dataset(bt.get_file('test-t2m-u-v.nc'))
arr = ds.psy.t2m.psy[0, 0, 0]
arr.attrs['test'] = 4
self.assertNotIn('test', ds.t2m.attrs)
self.assertIs(arr.psy.base, ds)
self.assertEqual(dict(arr.psy.idims), {'time': 0, 'lev': 0, 'lat': 0,
'lon': slice(None)})
# update to next time step
arr.psy.update(time=1)
self.assertEqual(arr.time, ds.time[1])
self.assertEqual(arr.values.tolist(),
ds.t2m[1, 0, 0].values.tolist())
self.assertEqual(dict(arr.psy.idims), {'time': 1, 'lev': 0, 'lat': 0,
'lon': slice(None)})
self.assertNotIn('test', ds.t2m.attrs)
self.assertIn('test', arr.attrs)
self.assertEqual(arr.test, 4)
def test_update_02_sel(self):
"""test the update of a single array through the sel method"""
ds = psyd.open_dataset(bt.get_file('test-t2m-u-v.nc'))
arr = ds.psy.t2m.psy[0, 0, 0]
self.assertEqual(dict(arr.psy.idims), {'time': 0, 'lev': 0, 'lat': 0,
'lon': slice(None)})
# update to next time step
arr.psy.update(time='1979-02-28T18:00', method='nearest')
self.assertEqual(arr.time, ds.time[1])
def test_update_03_isel_concat(self):
"""test the update of a concatenated array through the isel method"""
ds = psyd.open_dataset(bt.get_file('test-t2m-u-v.nc'))[['t2m', 'u']]
arr = ds.psy.to_array().psy.isel(time=0, lev=0, lat=0)
arr.attrs['test'] = 4
self.assertNotIn('test', ds.t2m.attrs)
arr.name = 'something'
self.assertIs(arr.psy.base, ds)
self.assertEqual(dict(arr.psy.idims), {'time': 0, 'lev': 0, 'lat': 0,
'lon': slice(None)})
self.assertEqual(arr.coords['variable'].values.tolist(), ['t2m', 'u'])
# update to next time step
arr.psy.update(time=1)
self.assertEqual(arr.time, ds.time[1])
self.assertEqual(arr.coords['variable'].values.tolist(), ['t2m', 'u'])
self.assertEqual(arr.values.tolist(),
ds[['t2m', 'u']].to_array()[
:, 1, 0, 0].values.tolist())
self.assertIn('test', arr.attrs)
self.assertEqual(arr.test, 4)
self.assertEqual(arr.name, 'something')
def test_update_05_1variable(self):
"""Test to change the variable"""
ds = psyd.open_dataset(bt.get_file('test-t2m-u-v.nc'))
arr = ds.psy.t2m.psy[0, 0, 0]
self.assertEqual(dict(arr.psy.idims), {'time': 0, 'lev': 0, 'lat': 0,
'lon': slice(None)})
# update to next time step
arr.psy.update(name='u', time=1)
self.assertEqual(arr.time, ds.time[1])
self.assertEqual(arr.name, 'u')
self.assertEqual(arr.values.tolist(),
ds.u[1, 0, 0].values.tolist())
def test_update_06_2variables(self):
"""test the change of the variable of a concatenated array"""
ds = psyd.open_dataset(bt.get_file('test-t2m-u-v.nc'))
arr = ds[['t2m', 'u']].to_array().isel(time=0, lev=0, lat=0)
arr.name = 'something'
arr.psy.base = ds
self.assertEqual(dict(arr.psy.idims), {'time': 0, 'lev': 0, 'lat': 0,
'lon': slice(None)})
self.assertEqual(arr.coords['variable'].values.tolist(), ['t2m', 'u'])
# update to next time step
arr.psy.update(time=1, name=['u', 'v'])
self.assertEqual(arr.time, ds.time[1])
self.assertEqual(arr.coords['variable'].values.tolist(), ['u', 'v'])
self.assertEqual(arr.values.tolist(),
ds[['u', 'v']].to_array()[
:, 1, 0, 0].values.tolist())
self.assertEqual(arr.name, 'something')
def test_netcdf_monitor_single_time_all_vars():
try:
assert not os.path.isfile('out.nc')
monitor = NetCDFMonitor('out.nc')
monitor.store(state)
assert not os.path.isfile('out.nc') # not set to write on store
monitor.write()
assert os.path.isfile('out.nc')
with xr.open_dataset('out.nc') as ds:
assert len(ds.data_vars.keys()) == 2
assert 'air_temperature' in ds.data_vars.keys()
assert ds.data_vars['air_temperature'].attrs['units'] == 'degK'
assert tuple(ds.data_vars['air_temperature'].shape) == (1, nx, ny, nz)
assert 'air_pressure' in ds.data_vars.keys()
assert ds.data_vars['air_pressure'].attrs['units'] == 'Pa'
assert tuple(ds.data_vars['air_pressure'].shape) == (1, nz)
assert len(ds['time']) == 1
assert ds['time'][0] == np.datetime64(state['time'])
finally: # make sure we remove the output file
if os.path.isfile('out.nc'):
os.remove('out.nc')
def test_netcdf_monitor_single_write_on_store():
try:
assert not os.path.isfile('out.nc')
monitor = NetCDFMonitor('out.nc', write_on_store=True)
monitor.store(state)
assert os.path.isfile('out.nc')
with xr.open_dataset('out.nc') as ds:
assert len(ds.data_vars.keys()) == 2
assert 'air_temperature' in ds.data_vars.keys()
assert ds.data_vars['air_temperature'].attrs['units'] == 'degK'
assert tuple(ds.data_vars['air_temperature'].shape) == (1, nx, ny, nz)
assert len(ds['time']) == 1
assert ds['time'][0] == np.datetime64(state['time'])
finally: # make sure we remove the output file
if os.path.isfile('out.nc'):
os.remove('out.nc')
def test_fractional_cover(sr_filepath, fc_filepath):
print(sr_filepath)
print(fc_filepath)
sr_dataset = open_dataset(sr_filepath)
measurements = [
{'name': 'PV', 'dtype': 'int8', 'nodata': -1, 'units': 'percent'},
{'name': 'NPV', 'dtype': 'int8', 'nodata': -1, 'units': 'percent'},
{'name': 'BS', 'dtype': 'int8', 'nodata': -1, 'units': 'percent'},
{'name': 'UE', 'units': '1'}
]
fc_dataset = fractional_cover(sr_dataset, measurements)
assert set(fc_dataset.data_vars.keys()) == {m['name'] for m in measurements}
validation_ds = open_dataset(fc_filepath)
assert validation_ds.equals(fc_dataset)
def read_netcdf(data_handle, domain=None, iter_dims=['lat', 'lon'],
start=None, stop=None, calendar='standard',
var_dict=None) -> xr.Dataset:
"""Read in a NetCDF file"""
ds = xr.open_dataset(data_handle)
if var_dict is not None:
ds.rename(var_dict, inplace=True)
if start is not None and stop is not None:
ds = ds.sel(time=slice(start, stop))
dates = ds.indexes['time']
ds['day_of_year'] = xr.Variable(('time', ), dates.dayofyear)
if domain is not None:
ds = ds.sel(**{d: domain[d] for d in iter_dims})
out = ds.load()
ds.close()
return out
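A short usage sketch of the read_netcdf helper above; the file name, domain, dates and rename mapping are illustrative only.
# Hypothetical usage of read_netcdf as defined above
forcing = read_netcdf('forcing.nc',
                      domain={'lat': slice(35, 45), 'lon': slice(235, 255)},
                      start='1990-01-01', stop='1990-12-31',
                      var_dict={'temp': 'tair'})
print(forcing['day_of_year'][:5])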
def convert_hdf5_to_netcdf4(input_file, output_file):
ds = xr.open_dataset(input_file)
chl = ds['chlor_a'].to_dataset()
chl.to_netcdf(output_file, format='NETCDF4_CLASSIC')
def test_version_metadata_with_streaming(self, api, opener):
np.random.seed(123)
times = pd.date_range('2000-01-01', '2001-12-31', name='time')
annual_cycle = np.sin(2 * np.pi * (times.dayofyear / 365.25 - 0.28))
base = 10 + 15 * np.array(annual_cycle).reshape(-1, 1)
tmin_values = base + 3 * np.random.randn(annual_cycle.size, 3)
tmax_values = base + 3 * np.random.randn(annual_cycle.size, 3)
ds = xr.Dataset({'tmin': (('time', 'location'), tmin_values),
'tmax': (('time', 'location'), tmax_values)},
{'time': times, 'location': ['IA', 'IN', 'IL']})
var = api.create('streaming_test')
with var.get_local_path(
bumpversion='patch',
dependencies={'arch1': '0.1.0', 'arch2': '0.2.0'}) as f:
ds.to_netcdf(f)
ds.close()
assert var.get_history()[-1]['dependencies']['arch2'] == '0.2.0'
tmin_values = base + 10 * np.random.randn(annual_cycle.size, 3)
ds.update({'tmin': (('time', 'location'), tmin_values)})
with var.get_local_path(
bumpversion='patch',
dependencies={'arch2': '1.2.0'}) as f:
with xr.open_dataset(f) as ds:
mem = ds.load()
ds.close()
mem.to_netcdf(f)
assert var.get_history()[-1]['dependencies']['arch2'] == '1.2.0'
assert var.get_history()[-1][
'checksum'] != var.get_history()[-2]['checksum']
def decode_coords(ds, gridfile=None, inplace=True):
"""
Sets the coordinates and bounds in a dataset
This static method sets those coordinates and bounds that are
marked in the netCDF attributes as coordinates in :attr:`ds` (without
deleting them from the variable attributes because this information is
necessary for visualizing the data correctly)
Parameters
----------
ds: xarray.Dataset
The dataset to decode
gridfile: str
The path to a separate grid file or an xarray.Dataset instance which
may store the coordinates used in `ds`
inplace: bool, optional
If True, `ds` is modified in place
Returns
-------
xarray.Dataset
`ds` with additional coordinates"""
def add_attrs(obj):
if 'coordinates' in obj.attrs:
extra_coords.update(obj.attrs['coordinates'].split())
if 'bounds' in obj.attrs:
extra_coords.add(obj.attrs['bounds'])
if gridfile is not None and not isinstance(gridfile, xr.Dataset):
gridfile = open_dataset(gridfile)
extra_coords = set(ds.coords)
for k, v in six.iteritems(ds.variables):
add_attrs(v)
add_attrs(ds)
if gridfile is not None:
ds = ds.update({k: v for k, v in six.iteritems(gridfile.variables)
if k in extra_coords}, inplace=inplace)
ds = ds.set_coords(extra_coords.intersection(ds.variables),
inplace=inplace)
return ds
def decode_coords(ds, gridfile=None, inplace=True):
"""
Reimplemented to set the mesh variables as coordinates
Parameters
----------
%(CFDecoder.decode_coords.parameters)s
Returns
-------
%(CFDecoder.decode_coords.returns)s"""
extra_coords = set(ds.coords)
for var in six.itervalues(ds.variables):
if 'mesh' in var.attrs:
mesh = var.attrs['mesh']
if mesh not in extra_coords:
extra_coords.add(mesh)
try:
mesh_var = ds.variables[mesh]
except KeyError:
warn('Could not find mesh variable %s' % mesh)
continue
if 'node_coordinates' in mesh_var.attrs:
extra_coords.update(
mesh_var.attrs['node_coordinates'].split())
if 'face_node_connectivity' in mesh_var.attrs:
extra_coords.add(
mesh_var.attrs['face_node_connectivity'])
if gridfile is not None and not isinstance(gridfile, xr.Dataset):
gridfile = open_dataset(gridfile)
ds = ds.update({k: v for k, v in six.iteritems(gridfile.variables)
if k in extra_coords}, inplace=inplace)
ds = ds.set_coords(extra_coords.intersection(ds.variables),
inplace=inplace)
return ds
def open_dataset(filename_or_obj, decode_cf=True, decode_times=True,
decode_coords=True, engine=None, gridfile=None, **kwargs):
"""
Open an instance of :class:`xarray.Dataset`.
This method has the same functionality as the :func:`xarray.open_dataset`
method except that it supports an additional 'gdal' engine to open
gdal Rasters (e.g. GeoTiffs) and that it supports absolute time units like
``'day as %Y%m%d.%f'`` (if `decode_cf` and `decode_times` are True).
Parameters
----------
%(xarray.open_dataset.parameters.no_engine)s
engine: {'netcdf4', 'scipy', 'pydap', 'h5netcdf', 'gdal'}, optional
Engine to use when reading netCDF files. If not provided, the default
engine is chosen based on available dependencies, with a preference for
'netcdf4'.
%(CFDecoder.decode_coords.parameters.gridfile)s
Returns
-------
xarray.Dataset
The dataset that contains the variables from `filename_or_obj`"""
# use the absolute path name (is safer when saving the project)
if isstring(filename_or_obj) and os.path.exists(filename_or_obj):
filename_or_obj = os.path.abspath(filename_or_obj)
if engine == 'gdal':
from psyplot.gdal_store import GdalStore
filename_or_obj = GdalStore(filename_or_obj)
engine = None
ds = xr.open_dataset(filename_or_obj, decode_cf=decode_cf,
decode_coords=False, engine=engine,
decode_times=decode_times, **kwargs)
if decode_cf:
ds = CFDecoder.decode_ds(
ds, decode_coords=decode_coords, decode_times=decode_times,
gridfile=gridfile, inplace=True)
return ds
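A brief usage sketch of this wrapper; both file names are placeholders, and the 'gdal' branch assumes that psyplot.gdal_store is available.
# Hypothetical usage of the open_dataset wrapper defined above
ds_nc = open_dataset('t2m_monthly.nc')                  # regular NetCDF file, default engine
ds_tif = open_dataset('elevation.tif', engine='gdal')   # GeoTiff opened through GdalStore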
def _open_ds_from_store(fname, store_mod=None, store_cls=None, **kwargs):
"""Open a dataset and return it"""
if isinstance(fname, xr.Dataset):
return fname
if store_mod is not None and store_cls is not None:
fname = getattr(import_module(store_mod), store_cls)(fname)
return open_dataset(fname, **kwargs)
def _test_coord(self, func_name, name, uname=None, name2d=False,
circ_name=None):
def check_ds(name):
self.assertEqual(getattr(d, func_name)(ds.t2m).name, name)
if name2d:
self.assertEqual(getattr(d, func_name)(ds.t2m_2d).name, name)
else:
self.assertIsNone(getattr(d, func_name)(ds.t2m_2d))
if six.PY3:
# Test whether the warning is raised if the decoder finds
# multiple dimensions
with self.assertWarnsRegex(RuntimeWarning,
'multiple matches'):
coords = 'time lat lon lev x y latitude longitude'.split()
ds.t2m.attrs.pop('coordinates', None)
for dim in 'xytz':
getattr(d, dim).update(coords)
for coord in set(coords).intersection(ds.coords):
ds.coords[coord].attrs.pop('axis', None)
getattr(d, func_name)(ds.t2m)
uname = uname or name
circ_name = circ_name or name
ds = psyd.open_dataset(os.path.join(bt.test_dir, 'test-t2m-u-v.nc'))
d = psyd.CFDecoder(ds)
check_ds(name)
ds.close()
ds = psyd.open_dataset(os.path.join(bt.test_dir, 'icon_test.nc'))
d = psyd.CFDecoder(ds)
check_ds(uname)
ds.close()
ds = psyd.open_dataset(
os.path.join(bt.test_dir, 'circumpolar_test.nc'))
d = psyd.CFDecoder(ds)
check_ds(circ_name)
ds.close()
def test_standardization(self):
"""Test the :meth:`psyplot.data.CFDecoder.standardize_dims` method"""
ds = psyd.open_dataset(os.path.join(bt.test_dir, 'test-t2m-u-v.nc'))
decoder = psyd.CFDecoder(ds)
dims = {'time': 1, 'lat': 2, 'lon': 3, 'lev': 4}
replaced = decoder.standardize_dims(ds.t2m, dims)
for dim, rep in [('time', 't'), ('lat', 'y'), ('lon', 'x'),
('lev', 'z')]:
self.assertIn(rep, replaced)
self.assertEqual(replaced[rep], dims[dim],
msg="Wrong value for %s (%s-) dimension" % (
dim, rep))
def test_idims(self):
"""Test the extraction of the slicers of the dimensions"""
ds = psyd.open_dataset(bt.get_file('test-t2m-u-v.nc'))
arr = ds.t2m[1:, 1]
arr.psy.init_accessor(base=ds)
dims = arr.psy.idims
for dim in ['time', 'lev', 'lat', 'lon']:
self.assertEqual(
psyd.safe_list(ds[dim][dims[dim]]),
psyd.safe_list(arr.coords[dim]),
msg="Slice %s for dimension %s is wrong!" % (dims[dim], dim))
# test with unknown dimensions
if xr.__version__ >= '0.9':
ds = ds.drop('time')
arr = ds.t2m[1:, 1]
arr.psy.init_accessor(base=ds)
if not six.PY2:
with self.assertWarnsRegex(UserWarning, 'time'):
dims = arr.psy.idims
l = psyd.ArrayList.from_dataset(
ds, name='t2m', time=slice(1, None), lev=85000., method='sel')
arr = l[0]
dims = arr.psy.idims
for dim in ['time', 'lon']:
if dim == 'time':
self.assertEqual(dims[dim], slice(1, 5, 1))
else:
self.assertEqual(
psyd.safe_list(ds[dim][dims[dim]]),
psyd.safe_list(arr.coords[dim]),
msg="Slice %s for dimension %s is wrong!" % (dims[dim],
dim))
def test_is_circumpolar(self):
"""Test whether the is_circumpolar method works"""
ds = psyd.open_dataset(os.path.join(bt.test_dir,
'circumpolar_test.nc'))
decoder = psyd.CFDecoder(ds)
self.assertTrue(decoder.is_circumpolar(ds.t2m))
# test for correct falsification
ds = psyd.open_dataset(os.path.join(bt.test_dir, 'icon_test.nc'))
decoder = psyd.CFDecoder(ds)
self.assertFalse(decoder.is_circumpolar(ds.t2m))
def test_get_decoder(self):
"""Test to get the right decoder"""
ds = psyd.open_dataset(bt.get_file('simple_triangular_grid_si0.nc'))
d = psyd.CFDecoder.get_decoder(ds, ds.Mesh2_fcvar)
self.assertIsInstance(d, psyd.UGridDecoder)
return ds, d
def test_auto_update(self):
"""Test the :attr:`psyplot.plotter.Plotter.no_auto_update` attribute"""
ds = psyd.open_dataset(bt.get_file('test-t2m-u-v.nc'))
arr = ds.psy.t2m.psy[0, 0]
arr.psy.init_accessor(auto_update=False)
arr.psy.update(time=1)
self.assertEqual(arr.time, ds.time[0])
arr.psy.start_update()
self.assertEqual(arr.time, ds.time[1])
arr.psy.no_auto_update = False
arr.psy.update(time=2)
self.assertEqual(arr.time, ds.time[2])
def test_array_info(self):
variables, coords = self._from_dataset_test_variables
variables['v4'] = variables['v3'].copy()
ds = xr.Dataset(variables, coords)
fname = osp.relpath(bt.get_file('test-t2m-u-v.nc'), '.')
ds2 = xr.open_dataset(fname)
l = ds.psy.create_list(
name=[['v1', ['v3', 'v4']], ['v1', 'v2']], prefer_list=True)
l.extend(ds2.psy.create_list(name=['t2m'], x=0, t=1),
new_name=True)
self.assertEqual(l.array_info(engine='netCDF4'), OrderedDict([
# first list containing an array with two variables
('arr0', OrderedDict([
('arr0', {'dims': {'t': slice(None), 'x': slice(None)},
'attrs': OrderedDict(), 'store': (None,
'name': 'v1', 'fname': None}),
('arr1', {'dims': {'y': slice(None)},
'name': [['v3',
('attrs', OrderedDict())])),
# second list with two arrays containing each one variable
('arr1', {'dims': {'y': slice(None),
'name': 'v2',
# last array from real dataset
('arr2', {'dims': {'z': slice(None), 'y': slice(None),
't': 1, 'x': 0},
'attrs': ds2.t2m.attrs,
'store': ('xarray.backends.netCDF4_',
'NetCDF4DataStore'),
'name': 't2m', 'fname': fname}),
('attrs', OrderedDict())]))
return l
def test_open_dataset(self):
fname = self.test_to_netcdf()
ref_ds = self._test_ds
ds = psyd.open_dataset(fname)
self.assertEqual(
pd.to_datetime(ds.time.values).tolist(),
pd.to_datetime(ref_ds.time.values).tolist())
def _test_engine(self, engine):
from importlib import import_module
fname = self.fname
ds = psyd.open_dataset(fname, engine=engine).load()
self.assertEqual(ds.psy.filename, fname)
store_mod, store = ds.psy.data_store
# try to load the dataset
mod = import_module(store_mod)
ds2 = psyd.open_dataset(getattr(mod, store)(fname))
ds.close()
ds2.close()
ds.psy.filename = None
dumped_fname, dumped_store_mod, dumped_store = psyd.get_filename_ds(
ds, dump=True, paths=True)
self.assertTrue(dumped_fname)
self.assertTrue(osp.exists(dumped_fname),
msg='Missing %s' % fname)
self.assertEqual(dumped_store_mod, store_mod)
self.assertEqual(dumped_store, store)
ds.close()
ds.psy.filename = None
os.remove(dumped_fname)
dumped_fname, dumped_store_mod, dumped_store = psyd.get_filename_ds(
ds, dump=True, paths=dumped_fname)
self.assertTrue(dumped_fname)
self.assertTrue(osp.exists(dumped_fname),
msg='Missing %s' % dumped_fname)
self.assertEqual(dumped_store_mod, store_mod)
self.assertEqual(dumped_store, store)
ds.close()
os.remove(dumped_fname)
def test_netcdf_monitor_multiple_times_batched_all_vars():
time_list = [
datetime(2013, 7, 20, 0),
datetime(2013, 7, 20, 6),
datetime(2013, 7, 20, 12),
]
current_state = state.copy()
try:
assert not os.path.isfile('out.nc')
monitor = NetCDFMonitor('out.nc')
for time in time_list:
current_state['time'] = time
monitor.store(current_state)
assert not os.path.isfile('out.nc') # not set to write on store
monitor.write()
assert os.path.isfile('out.nc')
with xr.open_dataset('out.nc') as ds:
assert len(ds.data_vars.keys()) == 2
assert 'air_temperature' in ds.data_vars.keys()
assert ds.data_vars['air_temperature'].attrs['units'] == 'degK'
assert tuple(ds.data_vars['air_temperature'].shape) == (
len(time_list), nx, ny, nz)
assert 'air_pressure' in ds.data_vars.keys()
assert ds.data_vars['air_pressure'].attrs['units'] == 'Pa'
assert tuple(ds.data_vars['air_pressure'].shape) == (
len(time_list), nz)
assert len(ds['time']) == len(time_list)
assert np.all(
ds['time'].values == [np.datetime64(time) for time in time_list])
finally: # make sure we remove the output file
if os.path.isfile('out.nc'):
os.remove('out.nc')
def test_netcdf_monitor_multiple_times_sequential_all_vars():
time_list = [
datetime(2013, 7, 20, 0),
datetime(2013, 7, 20, 6),
datetime(2013, 7, 20, 12),
]
current_state = state.copy()
try:
assert not os.path.isfile('out.nc')
monitor = NetCDFMonitor('out.nc')
for time in time_list:
current_state['time'] = time
monitor.store(current_state)
monitor.write()
assert os.path.isfile('out.nc')
with xr.open_dataset('out.nc') as ds:
assert len(ds.data_vars.keys()) == 2
assert 'air_temperature' in ds.data_vars.keys()
assert ds.data_vars['air_temperature'].attrs['units'] == 'degK'
assert tuple(ds.data_vars['air_temperature'].shape) == (
len(time_list), nx, ny, nz)
assert len(ds['time']) == len(time_list)
assert np.all(
ds['time'].values == [np.datetime64(time) for time in time_list])
finally: # make sure we remove the output file
if os.path.isfile('out.nc'):
os.remove('out.nc')
def test_netcdf_monitor_multiple_times_sequential_all_vars_timedelta():
time_list = [
timedelta(hours=0),
timedelta(hours=6),
timedelta(hours=12),
]
current_state = state.copy()
try:
assert not os.path.isfile('out.nc')
monitor = NetCDFMonitor('out.nc')
for time in time_list:
current_state['time'] = time
monitor.store(current_state)
monitor.write()
assert os.path.isfile('out.nc')
with xr.open_dataset('out.nc') as ds:
assert len(ds['time']) == len(time_list)
assert np.all(
ds['time'].values == [np.timedelta64(time) for time in time_list])
finally: # make sure we remove the output file
if os.path.isfile('out.nc'):
os.remove('out.nc')
def test_netcdf_monitor_multiple_times_batched_single_var():
time_list = [
datetime(2013, 7, 20, 0),
datetime(2013, 7, 20, 6),
datetime(2013, 7, 20, 12),
]
current_state = state.copy()
try:
assert not os.path.isfile('out.nc')
monitor = NetCDFMonitor('out.nc', store_names=['air_temperature'])
for time in time_list:
current_state['time'] = time
monitor.store(current_state)
assert not os.path.isfile('out.nc') # not set to write on store
monitor.write()
assert os.path.isfile('out.nc')
with xr.open_dataset('out.nc') as ds:
assert len(ds.data_vars.keys()) == 1
assert 'air_temperature' in ds.data_vars.keys()
assert ds.data_vars['air_temperature'].attrs['units'] == 'degK'
assert tuple(ds.data_vars['air_temperature'].shape) == (
len(time_list), nx, ny, nz)
assert len(ds['time']) == len(time_list)
assert np.all(
ds['time'].values == [np.datetime64(time) for time in time_list])
finally: # make sure we remove the output file
if os.path.isfile('out.nc'):
os.remove('out.nc')
def test_netcdf_monitor_multiple_write_on_store():
time_list = [
datetime(2013, 7, 20, 0),
datetime(2013, 7, 20, 6),
datetime(2013, 7, 20, 12),
]
current_state = state.copy()
try:
assert not os.path.isfile('out.nc')
monitor = NetCDFMonitor('out.nc', write_on_store=True)
for time in time_list:
current_state['time'] = time
monitor.store(current_state)
assert os.path.isfile('out.nc')
with xr.open_dataset('out.nc') as ds:
assert len(ds.data_vars.keys()) == 2
assert 'air_temperature' in ds.data_vars.keys()
assert ds.data_vars['air_temperature'].attrs['units'] == 'degK'
assert tuple(ds.data_vars['air_temperature'].shape) == (
len(time_list), nx, ny, nz)
assert len(ds['time']) == len(time_list)
assert np.all(
ds['time'].values == [np.datetime64(time) for time in time_list])
finally: # make sure we remove the output file
if os.path.isfile('out.nc'):
os.remove('out.nc')
def load_dictionary(filename):
dataset = xr.open_dataset(filename, engine='scipy')
return dict(dataset.data_vars)
def load(self):
"""
Load the state from the restart file.
Returns
-------
state : dict
The model state stored in the restart file.
"""
dataset = xr.open_dataset(self._filename)
state = {}
for name, value in dataset.data_vars.items():
state[name] = DataArray(value[0, :]) # remove time axis
state['time'] = datetime64_to_datetime(dataset['time'][0])
return state
def open_dataset(file_path):
ds = xr.open_dataset(file_path, mask_and_scale=False, drop_variables='crs')
ds.attrs['crs'] = datacube.utils.geometry.CRS('epsg:32754')
return ds
def __init__(self, path):
self.path = path
if isinstance(path, string_types):
self.ds = xr.open_dataset(path)
def load(path: Path) -> DataSet:
"""
Loads a data set from the specified NetCDF4 file.
Parameters
----------
path: pathlib.Path
Path to the file which should be loaded.
Returns
-------
DataSet
The data set loaded from the specified file
"""
log = logging.getLogger(__name__)
log.info("loading data set from %s", path)
data = xr.open_dataset(str(path)) # type: xr.Dataset
# restore data types
data[_DataVar.FILENAME] = data[_DataVar.FILENAME].astype(np.object).fillna(None)
data[_DataVar.CHUNK_NR] = data[_DataVar.CHUNK_NR].astype(np.object).fillna(None)
data[_DataVar.CV_FOLDS] = data[_DataVar.CV_FOLDS].astype(np.object).fillna(None)
data[_DataVar.PARTITION] = data[_DataVar.PARTITION].astype(np.object).fillna(None)
data[_DataVar.LABEL_NOMINAL] = data[_DataVar.LABEL_NOMINAL].astype(np.object).fillna(None)
data[_DataVar.LABEL_NUMERIC] = data[_DataVar.LABEL_NUMERIC].astype(np.object)
data[_DataVar.FEATURES] = data[_DataVar.FEATURES].astype(np.float32)
return DataSet(data=data,
mutable=False)
def load_netcdf_array(datafile, Meta, layer_specs=None):
'''
Loads Metadata for NetCDF
Parameters:
:datafile: str: Path on disk to NetCDF file
:Meta: dict: netcdf Metadata object
:variables: dict<str:str>,list<str>: list of variables to load
Returns:
:new_es: xr.Dataset
'''
logger.debug('load_netcdf_array: {}'.format(datafile))
ds = xr.open_dataset(datafile)
if layer_specs:
data = []
if isinstance(layer_specs, dict):
data = { k: ds[getattr(v, 'name', v)] for k, v in layer_specs.items() }
layer_spec = tuple(layer_specs.values())[0]
if isinstance(layer_specs, (list, tuple)):
data = {getattr(v, 'name', v): ds[getattr(v, 'name', v)]
for v in layer_specs }
layer_spec = layer_specs[0]
data = OrderedDict(data)
else:
data = OrderedDict([(v, ds[v]) for v in Meta['variables']])
layer_spec = None
geo_transform = take_geo_transform_from_Meta(layer_spec=layer_spec,
required=True,
**Meta)
for b, sub_dataset_name in zip(Meta['layer_Meta'], data):
b['geo_transform'] = Meta['geo_transform'] = geo_transform
b['sub_dataset_name'] = sub_dataset_name
new_es = xr.Dataset(data,
coords=_normalize_coords(ds),
attrs=Meta)
return new_es
def __call__(self,filename=None,varname=None):
if self.array_type == 'numpy':
out = Dataset(filename).variables[varname][:].squeeze()
elif self.array_type == 'xarray':
ds = xr.open_dataset(filename,chunks=self.chunks,lock=False)
out = ds[varname]
elif self.array_type == 'dask_from_numpy':
d = Dataset(filename).variables[varname][:].squeeze()
out = da.from_array(np.array(d), chunks=self.chunks)
elif self.array_type == 'dask_from_netcdf':
d = Dataset(filename).variables[varname]
out = da.from_array(d, chunks=self.chunks)
return out
def return_xarray_dataset(filename,chunks=None,**kwargs):
"""Return an xarray dataset corresponding to filename.
Parameters
----------
filename : str
path to the netcdf file from which to create a xarray dataset
chunks : dict-like
dictionary of chunk sizes for creating the xarray.Dataset.
Returns
-------
ds : xarray.Dataset
"""
return xr.open_dataset(filename,chunks=chunks,**kwargs)
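For large files, passing a chunks dictionary makes xarray back the variables with dask arrays so that computations stay lazy; the file and variable names below are illustrative.
# Hypothetical usage of return_xarray_dataset defined above
ds = return_xarray_dataset('ocean_model_output.nc', chunks={'time': 100})
sst_mean = ds['sst'].mean(dim='time')   # builds a lazy dask graph
result = sst_mean.compute()             # triggers the actual read and computation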
def open_data(self, **kwargs):
data = self.get_path()
if self.iTRACE_flag:
ico = xr.open_mfdataset(data['ico'], **kwargs).sortby('time')
ice = xr.open_mfdataset(data['ice'], **kwargs).sortby('time')
igo = xr.open_mfdataset(data['igo'], **kwargs).sortby('time')
igom = xr.open_mfdataset(data['igom'], **kwargs).sortby('time')
return ice, ico, igo, igom
else:
if len(data) > 1:
return xr.open_mfdataset(data, **kwargs).sortby('time')
else:
return xr.open_dataset(data[0], **kwargs).sortby('time')
def __init__(self, dataset_path):
self.dataset_path = dataset_path
try:
self.dataset = xr.open_dataset(self.dataset_path)
except OSError:
print('File not found.')
exit()
def read(self, file_path):
self.file_path = file_path
self.dataset = xr.open_dataset(self.file_path)
self.var_names = self.get_var_names(self.dataset)
def get_xarray(self):
self.dataset = xr.open_dataset(self.dataset_path)
return self.dataset
def open_mfdataset(paths, decode_cf=True, decode_times=True,
decode_coords=True, engine=None, gridfile=None,
t_format=None, **kwargs):
"""
Open multiple files as a single dataset.
This function is essentially the same as the :func:`xarray.open_mfdataset`
function but (as the :func:`open_dataset`) supports additional decoding
and the ``'gdal'`` engine.
You can further specify the `t_format` parameter to get the time
information from the files and use the results to concatenate the files
Parameters
----------
%(xarray.open_mfdataset.parameters.no_engine)s
%(open_dataset.parameters.engine)s
%(get_tdata.parameters.t_format)s
%(CFDecoder.decode_coords.parameters.gridfile)s
Returns
-------
xarray.Dataset
The dataset that contains the variables from `filename_or_obj`"""
if t_format is not None or engine == 'gdal':
if isinstance(paths, six.string_types):
paths = sorted(glob(paths))
if not paths:
raise IOError('no files to open')
if t_format is not None:
time, paths = get_tdata(t_format, paths)
kwargs['concat_dim'] = time
if engine == 'gdal':
from psyplot.gdal_store import GdalStore
paths = list(map(GdalStore, paths))
engine = None
kwargs['lock'] = False
ds = xr.open_mfdataset(
paths, decode_cf=decode_cf, decode_times=decode_times,
engine=engine, decode_coords=False, **kwargs)
if decode_cf:
return CFDecoder.decode_ds(ds, gridfile=gridfile, inplace=True,
decode_coords=decode_coords,
decode_times=decode_times)
return ds
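A usage sketch of this open_mfdataset wrapper, assuming daily files whose dates are encoded in the file names; the glob pattern and time format are illustrative.
# Hypothetical usage of the open_mfdataset wrapper defined above
ds = open_mfdataset('data/t2m_*.nc', t_format='t2m_%Y%m%d.nc')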
def get_nldas_fora_X_and_vic_y(year, month, day, hour,
vic_or_fora, band_order=None,
prefix=None, data_arrs=None,
keep_columns=None):
'''Load data from VIC for NLDAS Forcing A Grib files
Parameters:
year: year of forecast time
month: month of forecast time
day: day of forecast time
vic_or_fora: string indicating which NLDAS data source
band_order: list of DataArray names already loaded
prefix: add a prefix to the DataArray name from Grib
data_arrs: Add the DataArrays to an existing dict
keep_columns: Retain only the DataArrays in this list,if given
Returns:
tuple or (data_arrs,band_order) where data_arrs is
an OrderedDict of DataArrays and band_order is their
order when they are flattened from rasters to a single
2-D matrix
'''
data_arrs = data_arrs or OrderedDict()
band_order = band_order or []
path = get_file(year, dset=vic_or_fora)
dset = xr.open_dataset(path, engine='pynio')
for k in dset.data_vars:
if keep_columns and k not in keep_columns:
continue
arr = getattr(dset, k)
if sorted(arr.dims) != ['lat_110', 'lon_110']:
continue
#print('Model: ',f,'Param:',k,'Detail:',arr.long_name)
lon, lat = arr.lon_110, arr.lat_110
geo_transform = [lon.Lo1, lon.Di, 0.0,
lat.La1, 0.0, lat.Dj]
shp = arr.shape
canvas = Canvas(geo_transform, shp[1], shp[0], arr.dims)
arr.attrs['canvas'] = canvas
if prefix:
band_name = '{}_{}'.format(prefix, k)
else:
band_name = k
data_arrs[band_name] = arr
band_order.append(band_name)
return data_arrs, band_order
def get_filelist(pattern, date_range=None, timevar='time', calendar=None):
'''given a glob pattern, return a list of files within date_range'''
files = glob.glob(pattern)
if date_range is not None:
date_range = pd.to_datetime(list(date_range)).values
sublist = []
for f in files:
try:
kwargs = dict(mask_and_scale=False, concat_characters=False,
decode_coords=False)
if calendar:
ds = xr.open_dataset(f, decode_cf=False,
decode_times=False, **kwargs)
if (('XTIME' in ds) and
('calendar' not in ds['XTIME'].attrs)):
ds['XTIME'].attrs['calendar'] = calendar
elif 'calendar' not in ds[timevar].attrs:
ds[timevar].attrs['calendar'] = calendar
# else decode using calendar attribute in file
ds = xr.decode_cf(ds, **kwargs)
else:
ds = xr.open_dataset(f,
**kwargs)
except Exception as e:
warnings.warn('Failed to open {}: {}'.format(f, e))
continue
try:
ds[timevar] = ds['XTIME']
except KeyError:
pass
if CHECK_TIMEVARS:
try:
check_times(ds[timevar].values, f=f)
except ValueError as e:
warnings.warn(
'time check raised an error for file %s: %s' % (f, e))
start = ds[timevar].values[0]
end = ds[timevar].values[-1]
ds.close()
if (((start >= date_range[0]) and (start <= date_range[1])) or
((end >= date_range[0]) and (end <= date_range[1])) or
(start <= date_range[0]) and (end >= date_range[1])):
sublist.append(f)
files = sublist
files.sort()
return files
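A usage sketch of get_filelist; the glob pattern, date range and time variable are illustrative.
# Hypothetical usage of get_filelist defined above
files = get_filelist('wrfout_d01_*.nc',
                     date_range=('1990-01-01', '1990-12-31'),
                     timevar='XTIME')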
def load_data(file, varname, extent=None, period=None, **kwargs):
"""
Loads netCDF files and extracts data given a spatial extend and time period
of interest.
"""
# Open either a single- or multi-file dataset depending on whether a list or wildcard was given
if "*" in file or isinstance(file, list):
ds = xr.open_mfdataset(file, decode_times=False)
else:
ds = xr.open_dataset(file, decode_times=False)
# Construct condition based on spatial extents
if extent:
n, e, s, w = extent
ds = ds.sel(lat=(ds.lat >= s) & (ds.lat <= n))
# Account for extent crossing Greenwich
if w > e:
ds = ds.sel(lon=(ds.lon >= w) | (ds.lon <= e))
else:
ds = ds.sel(lon=(ds.lon >= w) & (ds.lon <= e))
# Construct condition base on time period
if period:
t1 = date2num(datetime(*period[0]), ds.time.units, ds.time.calendar)
t2 = date2num(datetime(*period[1]), ds.time.units, ds.time.calendar)
ds = ds.sel(time=(ds.time >= t1) & (ds.time <= t2))
# Extra keyword arguments to select from additional dimensions (e.g. plev)
if kwargs:
ds = ds.sel(**kwargs)
# Load in the data to a numpy array
dates = num2date(ds.time, ds.time.units, ds.time.calendar)
arr = ds[varname].values
lat = ds.lat.values
lon = ds.lon.values
# Convert pr units to mm/day
if ds[varname].units == 'kg m-2 s-1':
arr *= 86400
# Convert tas units from K to degC
elif ds[varname].units == 'K':
arr -= 273.15
return arr, lat, lon, dates
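A usage sketch of load_data; the file pattern, variable, extent and period are illustrative, with the extent ordered north, east, south, west as in the function above.
# Hypothetical usage of load_data defined above
arr, lat, lon, dates = load_data('pr_*.nc', 'pr',
                                 extent=(60, 40, 30, -10),             # n, e, s, w
                                 period=((1981, 1, 1), (2010, 12, 31)))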
def load_variable(var_name, path_to_file, squeeze=False,
fix_times=True, **extr_kwargs):
""" Interface for loading an extracted variable into memory,using
either iris or xarray. If `path_to_file` is instead a raw dataset,
then the entire contents of the file will be loaded!
Parameters
----------
var_name : string
The name of the variable to load
path_to_file : string
Location of file containing variable
squeeze : bool
Load only the requested field (ignore all others) and
associated dims
fix_times : bool
Correct the timestamps to the middle of the bounds
in the variable metadata (CESM puts them at the right
boundary which sucks!)
extr_kwargs : dict
Additional keyword arguments to pass to the extractor
"""
logger.info("Loading %s from %s" % (var_name, path_to_file))
ds = xr.open_dataset(path_to_file, **extr_kwargs)
# Todo: Revise this logic as part of generalizing time post-processing.
# Fix time unit,if necessary
# interval,timestamp = ds.time.units.split(" since ")
# timestamp = timestamp.split(" ")
# yr,mm,dy = timestamp[0].split("-")
#
# if int(yr) < 1650:
# yr = 2001
# yr = str(yr)
#
# # Re-construct at Jan 01,2001 and re-set
# timestamp[0] = "-".join([yr,dy])
# new_units = " ".join([interval,"since"] + timestamp)
# ds.time.attrs['units'] = new_units
# Todo: Generalize time post-processing.
# if fix_times:
# assert hasattr(ds,'time_bnds')
# bnds = ds.time_bnds.values
# mean_times = np.mean(bnds,axis=1)
#
# ds.time.values = mean_times
# Be pedantic and check that we don't have a "missing_value" attr
for field in ds:
if hasattr(ds[field], 'missing_value'):
del ds[field].attrs['missing_value']
# Lazy decode CF
# Todo: There's potentially a bug where decode_cf eagerly loads dask arrays
# ds = xr.decode_cf(ds)
return ds
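A usage sketch of load_variable; the variable name and file path are placeholders, and extra keyword arguments such as chunks are passed through to xarray.open_dataset.
# Hypothetical usage of load_variable defined above
ds = load_variable('T', 'case.cam.h0.nc', chunks={'time': 12})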