# Source code for pyrregular.models.grud

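"""GRU-D classifier wrapper.

Exposes the PyPOTS implementation of GRU-D, the model from "Recurrent Neural
Networks for Multivariate Time Series with Missing Values" (Che et al., 2018).
"""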

from pypots.classification.grud import GRUD

from pyrregular.wrappers.pypots_wrapper import PyPOTSWrapper


class GrudWrapper(PyPOTSWrapper):
    """Adapter that trains a PyPOTS classifier (e.g. ``GRUD``) via ``PyPOTSWrapper``.

    Parameters
    ----------
    model : callable
        The PyPOTS model class (or factory) instantiated at fit time. After a
        successful fit, ``self.model`` holds the fitted model instance.
    model_params : dict
        Extra keyword arguments forwarded to the model constructor.
    random_state : optional
        Forwarded unchanged to ``PyPOTSWrapper``.
    """

    def __init__(self, model, model_params, random_state=None):
        # Kept explicit so the constructor signature lists these parameters;
        # NOTE(review): presumably relied on for parameter introspection
        # (sklearn-style get_params) by the base class — confirm.
        super().__init__(model, model_params, random_state)

    def _fit(self, X, y):
        """Instantiate the model from the data dimensions and train it.

        Reads ``self.n_steps_``, ``self.n_features_`` and ``self.n_classes_``
        (assumed to be set by the base class before ``_fit`` is called — TODO
        confirm) and uses ``self._split`` to obtain train/validation sets.
        """
        # ``self.model`` is replaced by the fitted instance below. On a refit it
        # is therefore an instance rather than the class/factory; calling it
        # would fail, so fall back to its class in that case. Class and factory
        # callables are used directly, exactly as before.
        model_factory = self.model if callable(self.model) else type(self.model)
        self.model = model_factory(
            n_steps=self.n_steps_,
            n_features=self.n_features_,
            n_classes=self.n_classes_,
            **self.model_params,
        )
        X_train, X_val = self._split(X, y)
        self.model.fit(train_set=X_train, val_set=X_val)
grud_pipeline = GrudWrapper(
    model=GRUD,
    model_params=dict(
        rnn_hidden_size=256,
        batch_size=32,
        epochs=1000,
        patience=50,
        num_workers=0,
        device=None,
    ),
)
"""Ready-made GRU-D pipeline: ``GrudWrapper`` around PyPOTS ``GRUD`` with fixed
hyperparameters (hidden size 256, batch size 32, 1000 epochs, patience 50)."""