data.py (28905B)
# File: data.py
# Created Date: Saturday February 5th 2022
# Author: Steven Atkinson ([email protected])

"""
Functions and classes for working with audio data with NAM
"""

import abc as _abc
import logging as _logging
from collections import namedtuple as _namedtuple
from copy import deepcopy as _deepcopy
from dataclasses import dataclass as _dataclass
from enum import Enum as _Enum
from pathlib import Path as _Path
from typing import (
    Any as _Any,
    Callable as _Callable,
    Dict as _Dict,
    Optional as _Optional,
    Sequence as _Sequence,
    Tuple as _Tuple,
    Union as _Union,
)

import numpy as _np
import torch as _torch
import wavio as _wavio
from scipy.interpolate import interp1d as _interp1d
from torch.utils.data import Dataset as _Dataset
from tqdm import tqdm as _tqdm

from ._core import InitializableFromConfig as _InitializableFromConfig

logger = _logging.getLogger(__name__)

_REQUIRED_CHANNELS = 1  # Mono


class Split(_Enum):
    """
    Which partition of the data a data set is for.

    The values are the keys that `init_dataset` looks up in the data config.
    """

    TRAIN = "train"
    VALIDATION = "validation"


@_dataclass
class WavInfo:
    """
    Basic metadata of a WAV file.
    """

    # Sample width in bytes (`wav_to_np` normalizes by 2 ** (8 * sampwidth - 1)).
    sampwidth: int
    # Sample rate of the audio.
    rate: int


class DataError(Exception):
    """
    Parent class for all special exceptions raised by NAM data sets
    """

    pass


class AudioShapeMismatchError(ValueError, DataError):
    """
    Exception where the shape (number of samples, number of channels) of two audio files
    don't match but were supposed to.
    """

    def __init__(self, shape_expected, shape_actual, *args, **kwargs):
        """
        :param shape_expected: The shape that was required
        :param shape_actual: The shape that was actually found
        Remaining arguments are passed on to the parent exception.
        """
        super().__init__(*args, **kwargs)
        self._shape_expected = shape_expected
        self._shape_actual = shape_actual

    @property
    def shape_expected(self):
        return self._shape_expected

    @property
    def shape_actual(self):
        return self._shape_actual


def wav_to_np(
    filename: _Union[str, _Path],
    rate: _Optional[int] = None,
    require_match: _Optional[_Union[str, _Path]] = None,
    required_shape: _Optional[_Tuple[int, ...]] = None,
    required_wavinfo: _Optional[WavInfo] = None,
    preroll: _Optional[int] = None,
    info: bool = False,
) -> _Union[_np.ndarray, _Tuple[_np.ndarray, WavInfo]]:
    """
    Load a mono WAV file as a 1D array, normalized by its sample width.

    :param filename: Where to load from
    :param rate: Expected sample rate. `None` allows for anything.
    :param require_match: If not `None`, assert that the data you get matches the shape
        and other characteristics of another audio file at the provided location.
        Mutually exclusive with `required_shape` and `required_wavinfo`.
    :param required_shape: If not `None`, assert that the audio loaded is of shape
        `(num_samples, num_channels)`. The check happens after `preroll` is dropped.
    :param required_wavinfo: If not `None`, assert that the WAV info of the loaded audio
        matches that provided. Only the sample rate is compared.
    :param preroll: Drop this many samples off the front
    :param info: If `True`, also return the WAV info of this file.

    :return: The audio, shape `(num_samples,)`; if `info`, a tuple of the audio and
        its `WavInfo`.

    :raises RuntimeError: If `rate` is given and the file's rate differs.
    :raises ValueError: If the rate doesn't match `required_wavinfo`.
    :raises AudioShapeMismatchError: If the shape doesn't match the required shape.
    """
    x_wav = _wavio.read(str(filename))
    assert x_wav.data.shape[1] == _REQUIRED_CHANNELS, "Mono"
    if rate is not None and x_wav.rate != rate:
        raise RuntimeError(
            f"Explicitly expected sample rate of {rate}, but found {x_wav.rate} in "
            f"file {filename}!"
        )

    if require_match is not None:
        # The requirements are derived from the other file, so they can't also be
        # given explicitly.
        assert required_shape is None
        assert required_wavinfo is None
        y_wav = _wavio.read(str(require_match))
        required_shape = y_wav.data.shape
        required_wavinfo = WavInfo(y_wav.sampwidth, y_wav.rate)
    if required_wavinfo is not None:
        if x_wav.rate != required_wavinfo.rate:
            raise ValueError(
                f"Mismatched rates {x_wav.rate} versus {required_wavinfo.rate}"
            )
    # Scale the integer samples by the magnitude implied by the sample width.
    arr_premono = x_wav.data[preroll:] / (2.0 ** (8 * x_wav.sampwidth - 1))
    if required_shape is not None:
        if arr_premono.shape != required_shape:
            raise AudioShapeMismatchError(
                required_shape,  # Expected
                arr_premono.shape,  # Actual
                f"Mismatched shapes. Expected {required_shape}, but this is "
                f"{arr_premono.shape}!",
            )
    # sampwidth fine--we're just casting to 32-bit float anyways
    arr = arr_premono[:, 0]
    return arr if not info else (arr, WavInfo(x_wav.sampwidth, x_wav.rate))


def wav_to_tensor(
    *args, info: bool = False, **kwargs
) -> _Union[_torch.Tensor, _Tuple[_torch.Tensor, WavInfo]]:
    """
    Like `wav_to_np`, but the audio is returned as a `torch.Tensor`.

    :param info: If `True`, also return the WAV info of the file.
    All other arguments are passed on to `wav_to_np`.
    """
    out = wav_to_np(*args, info=info, **kwargs)
    if info:
        # Use a distinct name so that the boolean `info` argument isn't shadowed.
        arr, wav_info = out
        return _torch.Tensor(arr), wav_info
    return _torch.Tensor(out)


def tensor_to_wav(x: _torch.Tensor, *args, **kwargs):
    """
    Write a tensor to a WAV file. See `np_to_wav` for the remaining arguments.
    """
    np_to_wav(x.detach().cpu().numpy(), *args, **kwargs)


def _version_tuple(version: str) -> _Tuple[int, ...]:
    """
    Parse the leading numeric components of a dotted version string.

    Parsing stops at the first component that doesn't start with a digit, so e.g.
    "0.0.10" gives `(0, 0, 10)` and "0.0.5rc1" gives `(0, 0, 5)`.
    """
    parts = []
    for piece in version.split("."):
        digits = ""
        for char in piece:
            if not char.isdigit():
                break
            digits += char
        if not digits:
            break
        parts.append(int(digits))
    return tuple(parts)


def np_to_wav(
    x: _np.ndarray,
    filename: _Union[str, _Path],
    rate: int = 48_000,
    sampwidth: int = 3,
    scale=None,
    **kwargs,
):
    """
    Write an array of audio to a WAV file.

    :param x: The audio. Values are clipped to [-1, 1] before being converted to
        integers.
    :param filename: Where to write to
    :param rate: Sample rate of the file
    :param sampwidth: Sample width, in bytes
    :param scale: Passed on to `wavio.write`. If `None` with wavio 0.0.4 or older,
        "none" is used instead.
    Remaining keyword arguments are passed on to `wavio.write`.
    """
    # Compare versions numerically; comparing the raw strings would order e.g.
    # "0.0.10" before "0.0.4".
    if _version_tuple(_wavio.__version__) <= (0, 0, 4) and scale is None:
        scale = "none"
    _wavio.write(
        str(filename),
        (_np.clip(x, -1.0, 1.0) * (2 ** (8 * sampwidth - 1))).astype(_np.int32),
        rate,
        scale=scale,
        sampwidth=sampwidth,
        **kwargs,
    )


class AbstractDataset(_Dataset, _abc.ABC):
    """
    Interface for the data sets defined in this module.
    """

    @_abc.abstractmethod
    def __getitem__(self, idx: int):
        """
        Get input and output audio segment for training / evaluation.
        :return:
        """
        pass


class _DelayInterpolationMethod(_Enum):
    """
    :param LINEAR: Linear interpolation
    :param CUBIC: Cubic spline interpolation
    """

    # Note: these match scipy.interpolate.interp1d kwarg "kind"
    LINEAR = "linear"
    CUBIC = "cubic"


def _interpolate_delay(
    x: _torch.Tensor, delay: float, method: _DelayInterpolationMethod
) -> _torch.Tensor:
    """
    Resample a signal at time points shifted by a (possibly fractional) delay.

    NOTE: This breaks the gradient tape!

    :param x: The signal, 1D.
    :param delay: The shift, in samples. If zero, `x` is returned as-is.
    :param method: How to interpolate between samples.

    :return: The shifted signal. For nonzero delay, it is shorter than `x` by
        `ceil(abs(delay))` samples.
    """
    if delay == 0.0:
        return x
    t_in = _np.arange(len(x))
    n_out = len(x) - int(_np.ceil(_np.abs(delay)))
    if delay > 0:
        # Query points start `delay` after the first sample.
        t_out = _np.arange(n_out) + delay
    else:
        # Query points end `abs(delay)` before the last sample.
        t_out = _np.arange(len(x) - n_out, len(x)) - _np.abs(delay)

    return _torch.Tensor(
        _interp1d(t_in, x.detach().cpu().numpy(), kind=method.value)(t_out)
    )


class XYError(ValueError, DataError):
    """
    Exceptions related to invalid x and y provided for data sets
    """

    pass


class StartStopError(ValueError, DataError):
    """
    Exceptions related to invalid start and stop arguments
    """

    pass


class StartError(StartStopError):
    pass


class StopError(StartStopError):
    pass


# In seconds. Can't be 0.5 or else v1.wav is invalid! Oops!
_DEFAULT_REQUIRE_INPUT_PRE_SILENCE = 0.4


def _sample_to_time(s, rate):
    """
    Format a number of samples as a human-readable duration.

    :param s: Number of samples
    :param rate: Sample rate (samples per second)
    :return: "H:MM:SS and R samples", where R is the leftover samples that don't make
        up a whole second.
    """
    seconds, remainder = divmod(s, rate)
    hours, seconds = divmod(seconds, 3600)
    minutes, seconds = divmod(seconds, 60)
    return f"{hours}:{minutes:02d}:{seconds:02d} and {remainder} samples"


class Dataset(AbstractDataset, _InitializableFromConfig):
    """
    Take a pair of matched audio files and serve input + output pairs.
    """

    def __init__(
        self,
        x: _torch.Tensor,
        y: _torch.Tensor,
        nx: int,
        ny: _Optional[int],
        start: _Optional[int] = None,
        stop: _Optional[int] = None,
        start_samples: _Optional[int] = None,
        stop_samples: _Optional[int] = None,
        start_seconds: _Optional[_Union[int, float]] = None,
        stop_seconds: _Optional[_Union[int, float]] = None,
        delay: _Optional[_Union[int, float]] = None,
        delay_interpolation_method: _Union[
            str, _DelayInterpolationMethod
        ] = _DelayInterpolationMethod.CUBIC,
        y_scale: float = 1.0,
        x_path: _Optional[_Union[str, _Path]] = None,
        y_path: _Optional[_Union[str, _Path]] = None,
        input_gain: float = 0.0,
        sample_rate: _Optional[float] = None,
        require_input_pre_silence: _Optional[
            float
        ] = _DEFAULT_REQUIRE_INPUT_PRE_SILENCE,
    ):
        """
        :param x: The input signal. A 1D array.
        :param y: The associated output from the model. A 1D array.
        :param nx: The number of samples required as input for the model. For example,
            for a ConvNet, this would be the receptive field.
        :param ny: How many samples to provide as the output array for a single "datum".
            It's usually more computationally-efficient to provide a larger `ny` than 1
            so that the forward pass can process more audio all at once. However, this
            shouldn't be too large or else you won't be able to provide a large batch
            size (where each input-output pair could be something substantially
            different and improve batch diversity). If `None`, the whole signal is
            served as a single datum.
        :param start: [DEPRECATED; use start_samples instead.] In samples; clip x and y
            at this point. Negative values are taken from the end of the audio.
        :param stop: [DEPRECATED; use stop_samples instead.] In samples; clip x and y at
            this point. Negative values are taken from the end of the audio.
        :param start_samples: Clip x and y at this point. Negative values are taken from
            the end of the audio.
        :param stop_samples: Clip x and y at this point. Negative values are taken from
            the end of the audio.
        :param start_seconds: Clip x and y at this point. Negative values are taken from
            the end of the audio. Requires providing `sample_rate`.
        :param stop_seconds: Clip x and y at this point. Negative values are taken from
            the end of the audio. Requires providing `sample_rate`.
        :param delay: In samples. Positive means we get rid of the start of x, end of y
            (i.e. we are correcting for an alignment error in which y is delayed behind
            x). If a non-integer delay is provided, then y is interpolated, with
            the extra sample removed.
        :param delay_interpolation_method: How to interpolate y for a non-integer
            delay. Either a `_DelayInterpolationMethod` or its string value.
        :param y_scale: Multiplies the output signal by a factor (e.g. if the data are
            too quiet).
        :param x_path: Where x was loaded from, if known.
        :param y_path: Where y was loaded from, if known. Included in the error message
            if the output is found to be clipped.
        :param input_gain: In dB. If the input signal wasn't fed to the amp at unity
            gain, you can indicate the gain here. The data set will multiply the raw
            audio file by the specified gain so that the true input signal amplitude
            experienced by the signal chain will be provided as input to the model. If
            you are using a reamping setup, you can estimate this by reamping a
            completely dry signal (i.e. connecting the interface output directly back
            into the input with which the guitar was originally recorded.)
        :param sample_rate: Sample rate for the data
        :param require_input_pre_silence: If provided, require that this much time (in
            seconds) preceding the start of the data set (`start`) have a silent input.
            If it's not, then raise an exception because the output due to it will leak
            into the data set that we're trying to use. If `None`, don't assert.
        """
        self._validate_x_y(x, y)
        self._sample_rate = sample_rate
        start, stop = self._validate_start_stop(
            x,
            y,
            start,
            stop,
            start_samples,
            stop_samples,
            start_seconds,
            stop_seconds,
            self.sample_rate,
        )
        if not isinstance(delay_interpolation_method, _DelayInterpolationMethod):
            delay_interpolation_method = _DelayInterpolationMethod(
                delay_interpolation_method
            )
        if require_input_pre_silence is not None:
            self._validate_preceding_silence(
                x, start, require_input_pre_silence, self.sample_rate
            )
        x, y = [z[start:stop] for z in (x, y)]
        if delay is not None and delay != 0:
            x, y = self._apply_delay(x, y, delay, delay_interpolation_method)
        # Convert the gain from dB to an amplitude factor.
        x_scale = 10.0 ** (input_gain / 20.0)
        x = x * x_scale
        y = y * y_scale
        # The paths must be set before validating since the validation reads `_y_path`
        # to build its error message.
        self._x_path = x_path
        self._y_path = y_path
        self._validate_inputs_after_processing(x, y, nx, ny)
        self._x = x
        self._y = y
        self._nx = nx
        self._ny = ny if ny is not None else len(x) - nx + 1

    def __getitem__(self, idx: int) -> _Tuple[_torch.Tensor, _torch.Tensor]:
        """
        :return:
            Input (NX+NY-1,)
            Output (NY,)
        """
        if idx >= len(self):
            raise IndexError(f"Attempted to access datum {idx}, but len is {len(self)}")
        i = idx * self._ny
        j = i + self.y_offset
        return self.x[i : i + self._nx + self._ny - 1], self.y[j : j + self._ny]

    def __len__(self) -> int:
        """
        How many non-overlapping output segments of length `ny` are available.
        """
        n = len(self.x)
        # If ny were 1
        single_pairs = n - self._nx + 1
        return single_pairs // self._ny

    @property
    def ny(self) -> int:
        return self._ny

    @property
    def sample_rate(self) -> _Optional[float]:
        return self._sample_rate

    @property
    def x(self) -> _torch.Tensor:
        """
        The input audio data

        :return: (N,)
        """
        return self._x

    @property
    def y(self) -> _torch.Tensor:
        """
        The output audio data

        :return: (N,)
        """
        return self._y

    @property
    def y_offset(self) -> int:
        """
        Index offset from the start of an input segment to the start of its output.
        """
        return self._nx - 1

    @classmethod
    def parse_config(cls, config):
        """
        :param config:
            Must contain:
                x_path (path-like)
                y_path (path-like)
            May contain:
                sample_rate (int)
                y_preroll (int)
                allow_unequal_lengths (bool)
            Must NOT contain:
                x (torch.Tensor) - loaded from x_path
                y (torch.Tensor) - loaded from y_path
            Everything else is passed on to __init__

        :raises DataError: If either file is empty (when unequal lengths are allowed)
            or the files' shapes don't match (when they aren't).
        """
        config = _deepcopy(config)
        sample_rate = config.pop("sample_rate", None)
        x, x_wavinfo = wav_to_tensor(config.pop("x_path"), info=True, rate=sample_rate)
        # From here on, use the rate that was actually found in the input file.
        sample_rate = x_wavinfo.rate
        if config.pop("allow_unequal_lengths", False):
            y = wav_to_tensor(
                config.pop("y_path"),
                rate=sample_rate,
                preroll=config.pop("y_preroll", None),
                required_wavinfo=x_wavinfo,
            )
            # Truncate to the shorter of the two
            if len(x) == 0:
                raise DataError("Input is zero-length!")
            if len(y) == 0:
                raise DataError("Output is zero-length!")
            n = min(len(x), len(y))
            if n < len(x):
                print(f"Truncating input to {_sample_to_time(n, sample_rate)}")
            if n < len(y):
                print(f"Truncating output to {_sample_to_time(n, sample_rate)}")
            x, y = [z[:n] for z in (x, y)]
        else:
            try:
                y = wav_to_tensor(
                    config.pop("y_path"),
                    rate=sample_rate,
                    preroll=config.pop("y_preroll", None),
                    required_shape=(len(x), 1),
                    required_wavinfo=x_wavinfo,
                )
            except AudioShapeMismatchError as e:
                # Really verbose message since users see this.
                x_samples, x_channels = e.shape_expected
                y_samples, y_channels = e.shape_actual
                msg = "Your audio files aren't the same shape as each other!"
                if x_channels != y_channels:
                    channels_to_stereo_mono = {1: "mono", 2: "stereo"}

                    def channels_name(num_channels):
                        # Fall back to a generic name so that unusual channel
                        # counts don't hide the real problem behind a KeyError.
                        return channels_to_stereo_mono.get(
                            num_channels, f"{num_channels}-channel"
                        )

                    msg += f"\n * The input is {channels_name(x_channels)}, but the output is {channels_name(y_channels)}!"
                if x_samples != y_samples:
                    msg += f"\n * The input is {_sample_to_time(x_samples, sample_rate)} long"
                    msg += f"\n * The output is {_sample_to_time(y_samples, sample_rate)} long"
                msg += f"\n\nOriginal exception:\n{e}"
                raise DataError(msg) from e
        return {"x": x, "y": y, "sample_rate": sample_rate, **config}

    @classmethod
    def _apply_delay(
        cls,
        x: _torch.Tensor,
        y: _torch.Tensor,
        delay: _Union[int, float],
        method: _DelayInterpolationMethod,
    ) -> _Tuple[_torch.Tensor, _torch.Tensor]:
        """
        Align x and y by the given delay, trimming both to the same length.

        :raises TypeError: If `delay` is neither an int nor a float.
        """
        # Check for floats that could be treated like ints (simpler algorithm)
        if isinstance(delay, float) and int(delay) == delay:
            delay = int(delay)
        if isinstance(delay, int):
            return cls._apply_delay_int(x, y, delay)
        elif isinstance(delay, float):
            return cls._apply_delay_float(x, y, delay, method)
        else:
            raise TypeError(type(delay))

    @classmethod
    def _apply_delay_int(
        cls, x: _torch.Tensor, y: _torch.Tensor, delay: int
    ) -> _Tuple[_torch.Tensor, _torch.Tensor]:
        """
        Apply a whole-sample delay by slicing.

        Positive delay drops the end of x and the start of y; negative delay drops the
        start of x and the end of y.
        """
        if delay > 0:
            x = x[:-delay]
            y = y[delay:]
        elif delay < 0:
            x = x[-delay:]
            y = y[:delay]
        return x, y

    @classmethod
    def _apply_delay_float(
        cls,
        x: _torch.Tensor,
        y: _torch.Tensor,
        delay: float,
        method: _DelayInterpolationMethod,
    ) -> _Tuple[_torch.Tensor, _torch.Tensor]:
        """
        Apply a fractional delay by interpolating y and trimming x to match.
        """
        # Length of y after interpolation (see `_interpolate_delay`).
        n_out = len(y) - int(_np.ceil(_np.abs(delay)))
        if delay > 0:
            x = x[:n_out]
        elif delay < 0:
            x = x[-n_out:]
        y = _interpolate_delay(y, delay, method)
        return x, y

    @classmethod
    def _validate_start_stop(
        cls,
        x: _torch.Tensor,
        y: _torch.Tensor,
        start: _Optional[int] = None,
        stop: _Optional[int] = None,
        start_samples: _Optional[int] = None,
        stop_samples: _Optional[int] = None,
        start_seconds: _Optional[_Union[int, float]] = None,
        stop_seconds: _Optional[_Union[int, float]] = None,
        sample_rate: _Optional[int] = None,
    ) -> _Tuple[_Optional[int], _Optional[int]]:
        """
        Parse the requested start and stop trim points.

        These may be valid indices in Python, but probably point to invalid usage, so
        we will raise an exception if something fishy is going on (e.g. starting after
        the end of the file, etc)

        :return: parsed start/stop (if valid).

        :raises ValueError: If more than one way of specifying the start (or stop) is
            used, if seconds are given without a sample rate, or if x and y have
            different lengths.
        :raises StartError: If the start is outside of the audio.
        :raises StopError: If the stop is outside of the audio.
        :raises StartStopError: If no data would be left.
        """

        def parse_start_stop(s, samples, seconds, rate):
            # Assumes validated inputs
            if s is not None:
                return s
            if samples is not None:
                return samples
            if seconds is not None:
                return int(seconds * rate)
            # else
            return None

        # Resolve different ways of asking for start/stop...
        if start is not None:
            logger.warning("Using `start` is deprecated; use `start_samples` instead.")
        if stop is not None:
            logger.warning("Using `stop` is deprecated; use `stop_samples` instead.")
        if (
            int(start is not None)
            + int(start_samples is not None)
            + int(start_seconds is not None)
            >= 2
        ):
            raise ValueError(
                "More than one start provided. Use only one of `start`, `start_samples`, or `start_seconds`!"
            )
        if (
            int(stop is not None)
            + int(stop_samples is not None)
            + int(stop_seconds is not None)
            >= 2
        ):
            raise ValueError(
                "More than one stop provided. Use only one of `stop`, `stop_samples`, or `stop_seconds`!"
            )
        if start_seconds is not None and sample_rate is None:
            raise ValueError(
                "Provided `start_seconds` without sample rate; cannot resolve into samples!"
            )
        if stop_seconds is not None and sample_rate is None:
            raise ValueError(
                "Provided `stop_seconds` without sample rate; cannot resolve into samples!"
            )

        # By this point, we should have a valid, unambiguous way of asking.
        start = parse_start_stop(start, start_samples, start_seconds, sample_rate)
        stop = parse_start_stop(stop, stop_samples, stop_seconds, sample_rate)
        # And only use start/stop from this point.

        # We could do this whole thing with `if len(x[start: stop]==0`, but being more
        # explicit makes the error messages better for users.
        if start is None and stop is None:
            return start, stop
        if len(x) != len(y):
            raise ValueError(
                f"Input and output are different length. Input has {len(x)} samples, "
                f"and output has {len(y)}"
            )
        n = len(x)
        if start is not None:
            # Start after the files' end?
            if start >= n:
                raise StartError(
                    f"Arrays are only {n} samples long, but start was provided as {start}, "
                    "which is beyond the end of the array!"
                )
            # Start before the files' beginning?
            if start < -n:
                raise StartError(
                    f"Arrays are only {n} samples long, but start was provided as {start}, "
                    "which is before the beginning of the array!"
                )
        if stop is not None:
            # Stop after the files' end?
            if stop > n:
                raise StopError(
                    f"Arrays are only {n} samples long, but stop was provided as {stop}, "
                    "which is beyond the end of the array!"
                )
            # Stop before the files' beginning?
            if stop <= -n:
                raise StopError(
                    f"Arrays are only {n} samples long, but stop was provided as {stop}, "
                    "which is before the beginning of the array!"
                )
        # Just in case...
        if len(x[start:stop]) == 0:
            raise StartStopError(
                f"Array length {n} with start={start} and stop={stop} would get "
                "rid of all of the data!"
            )
        return start, stop

    @classmethod
    def _validate_x_y(cls, x, y):
        """
        Check the raw input and output before any processing.

        :raises XYError: If x and y have different lengths or are empty.
        """
        if len(x) != len(y):
            raise XYError(
                f"Input and output aren't the same lengths! ({len(x)} vs {len(y)})"
            )
        # TODO channels
        n = len(x)
        if n == 0:
            raise XYError("Input and output are empty!")

    def _validate_inputs_after_processing(self, x, y, nx, ny):
        """
        Check x and y after trimming, delay, and scaling have been applied.

        Reads `self._y_path` to say where a clipped output came from.

        :raises RuntimeError: If the input is shorter than the receptive field.
        :raises ValueError: If the output reaches or exceeds full scale.
        """
        assert x.ndim == 1
        assert y.ndim == 1
        assert len(x) == len(y)
        if nx > len(x):
            raise RuntimeError(  # TODO XYError?
                f"Input of length {len(x)}, but receptive field is {nx}."
            )
        if ny is not None:
            assert ny <= len(y) - nx + 1
        if _torch.abs(y).max() >= 1.0:
            msg = "Output clipped."
            if self._y_path is not None:
                msg += f" Source is {self._y_path}"
            raise ValueError(msg)

    @classmethod
    def _validate_preceding_silence(
        cls,
        x: _torch.Tensor,
        start: _Optional[int],
        silent_seconds: float,
        sample_rate: _Optional[float],
    ):
        """
        Make sure that the input is silent before the starting index.
        If it's not, then the output from that non-silent input will leak into the data
        set and couldn't be predicted!

        This assumes that silence is indeed required. If it's not, then don't call this!

        See: Issue #252

        :param x: Input
        :param start: Where the data starts. If `None`, nothing is checked.
        :param silent_seconds: How long (in seconds) the input is expected to be silent
        :param sample_rate: Sample rate of the input; used to convert the seconds into
            samples.

        :raises ValueError: If no sample rate is provided.
        :raises XYError: If the input isn't silent before the start.
        """
        if sample_rate is None:
            raise ValueError(
                f"Pre-silence was required for {silent_seconds} seconds, but no sample "
                "rate was provided!"
            )
        silent_samples = int(silent_seconds * sample_rate)
        if start is None:
            return
        raw_check_start = start - silent_samples
        # Keep the check window's start on the same side of zero as `start` so that
        # the slice below doesn't flip between front- and end-relative indexing.
        check_start = max(raw_check_start, 0) if start >= 0 else min(raw_check_start, 0)
        check_end = start
        if not _torch.all(x[check_start:check_end] == 0.0):
            raise XYError(
                f"Input provided isn't silent for at least {silent_samples} samples "
                "before the starting index. Responses to this non-silent input may "
                "leak into the dataset!"
            )


class ConcatDataset(AbstractDataset, _InitializableFromConfig):
    """
    Several data sets served back-to-back as a single data set.
    """

    def __init__(self, datasets: _Sequence[Dataset], flatten=True):
        """
        :param datasets: The data sets to concatenate. They must all have the same
            `ny`.
        :param flatten: If `True`, replace any `ConcatDataset` among `datasets` with
            the data sets that it contains.
        """
        if flatten:
            datasets = self._flatten_datasets(datasets)
        self._validate_datasets(datasets)
        self._datasets = datasets
        self._lookup = self._make_lookup()

    def __getitem__(self, idx: int) -> _Tuple[_torch.Tensor, _torch.Tensor]:
        i, j = self._lookup[idx]
        return self.datasets[i][j]

    def __len__(self) -> int:
        """
        Total number of data across all of the data sets in this data set
        """
        return sum(len(d) for d in self._datasets)

    @property
    def datasets(self):
        return self._datasets

    @classmethod
    def parse_config(cls, config):
        """
        :param config:
            Must contain:
                dataset_configs - one config per data set to initialize
            May contain:
                type (str) - name of the registered initializer to use for every
                    data set. Default is "dataset".
        """
        init = _dataset_init_registry[config.get("type", "dataset")]
        return {
            "datasets": tuple(
                init(c) for c in _tqdm(config["dataset_configs"], desc="Loading data")
            )
        }

    def _flatten_datasets(self, datasets):
        """
        If any dataset is a ConcatDataset, pull it out
        """
        flattened = []
        for d in datasets:
            if isinstance(d, ConcatDataset):
                flattened.extend(d.datasets)
            else:
                flattened.append(d)
        return flattened

    def _make_lookup(self):
        """
        For faster __getitem__

        :return: Mapping from the overall index to
            (index of data set, index within that data set).

        :raises RuntimeError: If the walk over the data doesn't finish exactly at the
            end of the last data set.
        """
        lookup = {}
        offset = 0
        j = 0  # Dataset index
        for i in range(len(self)):
            if offset == len(self.datasets[j]):
                # Ran off the end of the current data set; move on to the next one.
                offset -= len(self.datasets[j])
                j += 1
            lookup[i] = (j, offset)
            offset += 1
        # Assert that we got to the last data set
        if j != len(self.datasets) - 1:
            raise RuntimeError(
                f"During lookup population, didn't get to the last dataset (index "
                f"{len(self.datasets)-1}). Instead index ended at {j}."
            )
        if offset != len(self.datasets[-1]):
            raise RuntimeError(
                "During lookup population, didn't end at the index of the last datum "
                f"in the last dataset. Expected index {len(self.datasets[-1])}, got "
                f"{offset} instead."
            )
        return lookup

    @classmethod
    def _validate_datasets(cls, datasets: _Sequence[Dataset]):
        """
        Require that every data set has the same `ny` as the first one.

        :raises ValueError: On the first data set whose `ny` differs.
        """
        Reference = _namedtuple("Reference", ("index", "val"))
        ref_ny = None
        for i, d in enumerate(datasets):
            ref_ny = Reference(i, d.ny) if ref_ny is None else ref_ny
            if d.ny != ref_ny.val:
                raise ValueError(
                    f"Mismatch between ny of datasets {ref_ny.index} ({ref_ny.val}) and {i} ({d.ny})"
                )


_dataset_init_registry = {"dataset": Dataset.init_from_config}


def register_dataset_initializer(
    name: str, constructor: _Callable[[_Any], AbstractDataset], overwrite=False
):
    """
    If you have other data set types, you can register their initializer by name using
    this.

    For example, the basic NAM is registered by default under the name "dataset", but if
    it weren't, you could register it like this:

    >>> from nam import data
    >>> data.register_dataset_initializer("parametric", MyParametricDataset.init_from_config)

    :param name: The name that'll be used in the config to ask for the data set type
    :param constructor: The constructor that'll be fed the config.
    :param overwrite: If `True`, allow replacing an initializer that's already
        registered under `name`.

    :raises KeyError: If `name` is already registered and `overwrite` is `False`.
    """
    if name in _dataset_init_registry and not overwrite:
        raise KeyError(
            f"A constructor for dataset name '{name}' is already registered!"
        )
    _dataset_init_registry[name] = constructor


def init_dataset(config, split: Split) -> AbstractDataset:
    """
    Initialize the data set for one split from a data config.

    :param config:
        Must contain an entry under the split's value ("train" or "validation") that
        is either a single data set config (dict) or a list of them. A list is
        initialized as a `ConcatDataset`.
        May contain:
            type (str) - name of the registered initializer. Default is "dataset".
            common (dict) - entries shared by every data set config; the individual
                configs take precedence.
    :param split: Which split to initialize
    """
    name = config.get("type", "dataset")
    base_config = config[split.value]
    common = config.get("common", {})
    if isinstance(base_config, dict):
        init = _dataset_init_registry[name]
        return init({**common, **base_config})
    elif isinstance(base_config, list):
        return ConcatDataset.init_from_config(
            {
                "type": name,
                "dataset_configs": [{**common, **c} for c in base_config],
            }
        )