exportable.py (5679B)
# File: _exportable.py
# Created Date: Tuesday February 8th 2022
# Author: Steven Atkinson ([email protected])

import abc as _abc
import json as _json
import logging as _logging
from datetime import datetime as _datetime
from enum import Enum as _Enum
from pathlib import Path as _Path
from typing import (
    Any as _Any,
    Dict as _Dict,
    Optional as _Optional,
    Sequence as _Sequence,
    Tuple as _Tuple,
    Union as _Union,
)

import numpy as _np

from .metadata import Date as _Date, UserMetadata as _UserMetadata

logger = _logging.getLogger(__name__)

# Model version is independent from package version as of package version 0.5.2 so that
# the API of the package can iterate at a different pace from that of the model files.
_MODEL_VERSION = "0.5.4"


def _cast_enum_value(val: _Any) -> _Any:
    """
    Recursively replace enum members inside ``val`` with their ``.value``.

    Dicts are delegated to :func:`_cast_enums`; lists and tuples are rebuilt with
    each element cast. Any other value is returned unchanged.
    """
    if isinstance(val, _Enum):
        # An enum's value may itself be a container, so keep casting.
        return _cast_enum_value(val.value)
    if isinstance(val, dict):
        return _cast_enums(val)
    if isinstance(val, list):
        return [_cast_enum_value(v) for v in val]
    if isinstance(val, tuple):
        return tuple(_cast_enum_value(v) for v in val)
    return val


def _cast_enums(d: _Dict[_Any, _Any]) -> _Dict[_Any, _Any]:
    """
    Return a copy of ``d`` with enum members (keys and values) replaced by their
    ``.value``.

    Nested dicts, lists, and tuples are processed recursively. ``json.dump`` cannot
    serialize plain ``Enum`` members, so metadata is passed through this before it
    is written out.

    :param d: Mapping to convert. It is not modified.
    :return: A new dict with the same structure and enum members replaced.
    """
    return {
        (key.value if isinstance(key, _Enum) else key): _cast_enum_value(val)
        for key, val in d.items()
    }


class Exportable(_abc.ABC):
    """
    Interface for my custom export format for use in the plugin.

    NOTE(review): ``export`` reads ``self.training`` and calls ``self.eval()`` and
    ``self.train(...)``, so this appears to be meant as a mixin alongside
    ``torch.nn.Module`` (or a class with the same API) -- confirm.
    """

    FILE_EXTENSION = ".nam"

    def export(
        self,
        outdir: _Path,
        include_snapshot: bool = False,
        basename: str = "model",
        user_metadata: _Optional[_UserMetadata] = None,
        other_metadata: _Optional[dict] = None,
    ):
        """
        Export the model as JSON to ``{outdir}/{basename}{FILE_EXTENSION}``.

        The JSON object contains the fields:
        * "version" (str): model file format version
        * "metadata" (dict): export date plus user/other metadata
        * "architecture" (str): name of the model's class
        * "config": (dict w/ other necessary data like tensor shapes etc)
        * "weights" (list): the flattened weights

        Metadata precedence (later wins): non-user metadata (the export date), then
        ``user_metadata``, then ``other_metadata``. A warning is logged listing any
        keys that ``other_metadata`` overwrites. Enum members in the provided
        metadata are replaced by their values.

        The model is switched to eval mode while exporting, and its previous
        training state is restored afterwards even if writing fails.

        :param outdir: Assumed to exist. Can be edited inside at will.
        :param include_snapshot: If True, also outputs ``test_inputs.npy`` and
            ``test_outputs.npy`` containing an example input/output pair that the
            model creates. This can be used to debug e.g. the implementation of the
            model in the plugin.
        :param basename: Filename (without extension) of the exported model file.
        :param user_metadata: Optional user-provided metadata to merge in.
        :param other_metadata: Optional extra metadata, merged in last.
        """
        model_dict = self._get_export_dict()
        metadata = model_dict["metadata"]
        if user_metadata is not None:
            metadata.update(_cast_enums(user_metadata.model_dump()))
        if other_metadata is not None:
            # Cast first so the overwrite check sees the keys that will actually be
            # written (e.g. an Enum key becomes its value).
            other = _cast_enums(other_metadata)
            overwritten_keys = [str(key) for key in other if key in metadata]
            if overwritten_keys:
                logger.warning(
                    "other_metadata provided keys that will overwrite existing keys!\n "
                    + "\n ".join(overwritten_keys)
                )
            metadata.update(other)

        training = self.training
        self.eval()
        try:
            with open(
                _Path(outdir, f"{basename}{self.FILE_EXTENSION}"), "w", encoding="utf-8"
            ) as fp:
                _json.dump(model_dict, fp)
            if include_snapshot:
                x, y = self._export_input_output()
                x_path = _Path(outdir, "test_inputs.npy")
                y_path = _Path(outdir, "test_outputs.npy")
                logger.debug("Saving snapshot input to %s", x_path)
                _np.save(x_path, x)
                logger.debug("Saving snapshot output to %s", y_path)
                _np.save(y_path, y)
        finally:
            # And resume training state -- even if writing the files failed.
            self.train(training)

    @_abc.abstractmethod
    def export_cpp_header(self, filename: _Path):
        """
        Export a .h file to compile into the plugin with the weights written right out
        as text
        """
        pass

    def export_onnx(self, filename: _Path):
        """
        Export model in format for ONNX Runtime.

        Not supported by default; subclasses that support ONNX override this.

        :raises NotImplementedError: Always, unless overridden.
        """
        raise NotImplementedError(
            "Exporting to ONNX is not supported for models of type "
            f"{self.__class__.__name__}"
        )

    def import_weights(self, weights: _Sequence[float]):
        """
        Inverse of `._export_weights()`.

        Not implemented by default; subclasses that support it override this.

        :raises NotImplementedError: Always, unless overridden.
        """
        raise NotImplementedError(
            f"Importing weights for models of type {self.__class__.__name__} isn't "
            "implemented yet."
        )

    @_abc.abstractmethod
    def _export_config(self):
        """
        Creates the JSON of the model's architecture hyperparameters (number of layers,
        number of units, etc)

        :return: a JSON serializable object
        """
        pass

    @_abc.abstractmethod
    def _export_input_output(self) -> _Tuple[_np.ndarray, _np.ndarray]:
        """
        Create an input and corresponding output signal to verify its behavior.

        They should be the same length, but the start of the output might have transient
        effects. Up to you to interpret.
        """
        pass

    @_abc.abstractmethod
    def _export_weights(self) -> _np.ndarray:
        """
        Flatten the weights out to a 1D array
        """
        pass

    def _get_export_dict(self) -> _Dict[str, _Any]:
        """
        Assemble the JSON-serializable dict that is written to the model file.

        :return: Dict with keys "version", "metadata", "architecture", "config" and
            "weights" (the flattened weights converted to a Python list).
        """
        return {
            "version": _MODEL_VERSION,
            "metadata": self._get_non_user_metadata(),
            "architecture": self.__class__.__name__,
            "config": self._export_config(),
            "weights": self._export_weights().tolist(),
        }

    def _get_non_user_metadata(self) -> _Dict[str, _Any]:
        """
        Get any metadata that's non-user-provided (date, loudness, gain).

        This base implementation only records the export date, taken from
        ``datetime.now()`` (naive local time) and serialized via ``Date.model_dump()``.
        Subclasses may extend it with further fields.
        """
        t = _datetime.now()
        return {
            "date": _Date(
                year=t.year,
                month=t.month,
                day=t.day,
                hour=t.hour,
                minute=t.minute,
                second=t.second,
            ).model_dump()
        }