"""GRU-D classifier wrapper.

Model from: Che et al., "Recurrent Neural Networks for Multivariate Time
Series with Missing Values", Scientific Reports (2018).
"""
from pypots.classification.grud import GRUD
from pyrregular.wrappers.pypots_wrapper import PyPOTSWrapper
[docs]
class GrudWrapper(PyPOTSWrapper):
    """Wrapper that builds and trains a PyPOTS GRU-D classifier at fit time.

    Parameters
    ----------
    model : callable
        Model class (or factory) instantiated in ``_fit``, e.g.
        ``pypots.classification.grud.GRUD``.
    model_params : dict or None
        Extra keyword arguments forwarded to ``model`` together with the
        data-dependent ``n_steps``, ``n_features`` and ``n_classes``.
    random_state : optional
        Forwarded unchanged to :class:`PyPOTSWrapper`.
    """

    def __init__(self, model, model_params, random_state=None):
        super().__init__(model, model_params, random_state)

    def _fit(self, X, y):
        """Instantiate the model for the data's dimensions and train it.

        Reads ``self.n_steps_``, ``self.n_features_`` and ``self.n_classes_``;
        presumably the base class sets these before calling ``_fit`` --
        TODO confirm against ``PyPOTSWrapper``.  The data is split into
        training and validation sets with the base class's ``_split``.

        After this call ``self.model`` holds the fitted model instance.
        """
        factory = self.model
        if not callable(factory):
            # A previous fit already replaced the class with a fitted model
            # instance (presumably not callable). Rebuild from that instance's
            # class so the wrapper can be fitted again instead of raising
            # TypeError on ``instance(...)``.
            factory = type(factory)
        self.model = factory(
            n_steps=self.n_steps_,
            n_features=self.n_features_,
            n_classes=self.n_classes_,
            **(self.model_params or {}),
        )
        X_train, X_val = self._split(X, y)
        self.model.fit(train_set=X_train, val_set=X_val)
# Ready-to-use GRU-D pipeline; ``random_state`` is left at its default (None).
grud_pipeline = GrudWrapper(
    model=GRUD,
    model_params=dict(
        rnn_hidden_size=256,
        batch_size=32,
        epochs=1000,
        patience=50,
        num_workers=0,
        device=None,
    ),
)
"""Default GRU-D pipeline (:class:`GrudWrapper` around PyPOTS ``GRUD``):
hidden size 256, batch size 32, at most 1000 epochs, patience 50,
``num_workers=0`` and ``device=None`` (device selection left to PyPOTS)."""