neural-amp-modeler

Neural network emulator for guitar amplifiers
Log | Files | Refs | README | LICENSE

exportable.py (5679B)


      1 # File: _exportable.py
      2 # Created Date: Tuesday February 8th 2022
      3 # Author: Steven Atkinson ([email protected])
      4 
      5 import abc as _abc
      6 import json as _json
      7 import logging as _logging
      8 from datetime import datetime as _datetime
      9 from enum import Enum as _Enum
     10 from pathlib import Path as _Path
     11 from typing import (
     12     Any as _Any,
     13     Dict as _Dict,
     14     Optional as _Optional,
     15     Sequence as _Sequence,
     16     Tuple as _Tuple,
     17     Union as _Union,
     18 )
     19 
     20 import numpy as _np
     21 
     22 from .metadata import Date as _Date, UserMetadata as _UserMetadata
     23 
# Module-level logger for this file.
logger = _logging.getLogger(__name__)

# Model version is independent from package version as of package version 0.5.2 so that
# the API of the package can iterate at a different pace from that of the model files.
# This string is written to the "version" field of every exported model file.
_MODEL_VERSION = "0.5.4"
     29 
     30 
     31 def _cast_enums(d: _Dict[_Any, _Any]) -> _Dict[_Any, _Any]:
     32     """
     33     Casts enum-type keys to their values
     34     """
     35     out = {}
     36     for key, val in d.items():
     37         if isinstance(val, _Enum):
     38             val = val.value
     39         if isinstance(val, dict):
     40             val = _cast_enums(val)
     41         out[key] = val
     42     return out
     43 
     44 
     45 class Exportable(_abc.ABC):
     46     """
     47     Interface for my custon export format for use in the plugin.
     48     """
     49 
     50     FILE_EXTENSION = ".nam"
     51 
     52     def export(
     53         self,
     54         outdir: _Path,
     55         include_snapshot: bool = False,
     56         basename: str = "model",
     57         user_metadata: _Optional[_UserMetadata] = None,
     58         other_metadata: _Optional[dict] = None,
     59     ):
     60         """
     61         Interface for exporting.
     62         You should create at least a `config.json` containing the two fields:
     63         * "version" (str)
     64         * "architecture" (str)
     65         * "config": (dict w/ other necessary data like tensor shapes etc)
     66 
     67         :param outdir: Assumed to exist. Can be edited inside at will.
     68         :param include_snapshots: If True, outputs `input.npy` and `output.npy`
     69             Containing an example input/output pair that the model creates. This
     70             Can be used to debug e.g. the implementation of the model in the
     71             plugin.
     72         """
     73         model_dict = self._get_export_dict()
     74         model_dict["metadata"].update(
     75             {} if user_metadata is None else _cast_enums(user_metadata.model_dump())
     76         )
     77         if other_metadata is not None:
     78             overwritten_keys = []
     79             for key in other_metadata:
     80                 if key in model_dict["metadata"]:
     81                     overwritten_keys.append(key)
     82             if overwritten_keys:
     83                 logger.warning(
     84                     "other_metadata provided keys that will overwrite existing keys!\n "
     85                     + "\n ".join(overwritten_keys)
     86                 )
     87             model_dict["metadata"].update(_cast_enums(other_metadata))
     88 
     89         training = self.training
     90         self.eval()
     91         with open(_Path(outdir, f"{basename}{self.FILE_EXTENSION}"), "w") as fp:
     92             _json.dump(model_dict, fp)
     93         if include_snapshot:
     94             x, y = self._export_input_output()
     95             x_path = _Path(outdir, "test_inputs.npy")
     96             y_path = _Path(outdir, "test_outputs.npy")
     97             logger.debug(f"Saving snapshot input to {x_path}")
     98             _np.save(x_path, x)
     99             logger.debug(f"Saving snapshot output to {y_path}")
    100             _np.save(y_path, y)
    101 
    102         # And resume training state
    103         self.train(training)
    104 
    105     @_abc.abstractmethod
    106     def export_cpp_header(self, filename: _Path):
    107         """
    108         Export a .h file to compile into the plugin with the weights written right out
    109         as text
    110         """
    111         pass
    112 
    113     def export_onnx(self, filename: _Path):
    114         """
    115         Export model in format for ONNX Runtime
    116         """
    117         raise NotImplementedError(
    118             "Exporting to ONNX is not supported for models of type "
    119             f"{self.__class__.__name__}"
    120         )
    121 
    122     def import_weights(self, weights: _Sequence[float]):
    123         """
    124         Inverse of `._export_weights()
    125         """
    126         raise NotImplementedError(
    127             f"Importing weights for models of type {self.__class__.__name__} isn't "
    128             "implemented yet."
    129         )
    130 
    131     @_abc.abstractmethod
    132     def _export_config(self):
    133         """
    134         Creates the JSON of the model's archtecture hyperparameters (number of layers,
    135         number of units, etc)
    136 
    137         :return: a JSON serializable object
    138         """
    139         pass
    140 
    141     @_abc.abstractmethod
    142     def _export_input_output(self) -> _Tuple[_np.ndarray, _np.ndarray]:
    143         """
    144         Create an input and corresponding output signal to verify its behavior.
    145 
    146         They should be the same length, but the start of the output might have transient
    147         effects. Up to you to interpret.
    148         """
    149         pass
    150 
    151     @_abc.abstractmethod
    152     def _export_weights(self) -> _np.ndarray:
    153         """
    154         Flatten the weights out to a 1D array
    155         """
    156         pass
    157 
    158     def _get_export_dict(self):
    159         return {
    160             "version": _MODEL_VERSION,
    161             "metadata": self._get_non_user_metadata(),
    162             "architecture": self.__class__.__name__,
    163             "config": self._export_config(),
    164             "weights": self._export_weights().tolist(),
    165         }
    166 
    167     def _get_non_user_metadata(self) -> _Dict[str, _Union[str, int, float]]:
    168         """
    169         Get any metadata that's non-user-provided (date, loudness, gain)
    170         """
    171         t = _datetime.now()
    172         return {
    173             "date": _Date(
    174                 year=t.year,
    175                 month=t.month,
    176                 day=t.day,
    177                 hour=t.hour,
    178                 minute=t.minute,
    179                 second=t.second,
    180             ).model_dump()
    181         }