Source code for pyrregular.datasets.geolife_supervised

import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from xarray import DataArray

from pyrregular.data_utils import data_original_folder
from pyrregular.io_utils import read_csv
from pyrregular.reader_interface import ReaderInterface


class GeolifeSupervised(ReaderInterface):
    """Reader for the supervised Geolife dataset.

    ``read_original_version`` loads the raw CSV, and
    ``_fix_intermediate_version`` decorates the loaded array with a default
    train/test split and two target encodings.
    """

    @staticmethod
    def read_original_version(verbose=False):
        """Load the original data by delegating to ``read_geolife_supervised``.

        Args:
            verbose: Forwarded unchanged to ``read_geolife_supervised``.

        Returns:
            Whatever ``read_geolife_supervised`` returns.
        """
        return read_geolife_supervised(verbose=verbose)

    @staticmethod
    def _fix_intermediate_version(data: DataArray, verbose) -> DataArray:
        """Attach split and target coordinates along the ``ts_id`` dimension.

        Three coordinates are added:

        * ``split_default``: ``"train"`` / ``"test"`` from a 70/30 split
          stratified on ``label`` with ``random_state=42``.
        * ``class_default``: ``label`` encoded with ``LabelEncoder``.
        * ``class_simplified``: binary target, 0 for private transportation
          and 1 otherwise (see ``mapping`` below).

        Args:
            data: Array exposing ``ts_id`` and ``label`` coordinates.
            verbose: Accepted for interface compatibility; not used here.

        Returns:
            ``data`` with the three coordinates above assigned.

        Raises:
            KeyError: If a ``label`` value is not a key of ``mapping``.
        """
        ts_id_train, _, _, _ = train_test_split(
            data["ts_id"],
            data["label"],
            test_size=0.3,
            stratify=data["label"],
            random_state=42,
        )

        # ``x in <array-like>`` is a linear scan, so testing every ts_id
        # against the split result was quadratic. Hash the train ids once;
        # ``tolist()`` on both sides keeps the compared values the same
        # native Python types.
        train_ids = set(np.asarray(ts_id_train).tolist())
        train_or_test = [
            "train" if i in train_ids else "test"
            for i in data["ts_id"].values.tolist()
        ]
        data = data.assign_coords(split_default=("ts_id", train_or_test))

        # Preparing target class for 2/2 task
        data = data.assign_coords(
            class_default=("ts_id", LabelEncoder().fit_transform(data["label"]))
        )

        mapping = {  # 0 private transportation, 1 otherwise
            "subway": 1,
            "airplane": 1,
            "train": 1,
            "taxi": 1,
            "bus": 1,
            "motorcycle": 0,
            "run": 0,
            "walk": 0,
            "boat": 0,
            "car": 0,
            "bike": 0,
        }
        data = data.assign_coords(
            class_simplified=("ts_id", [mapping[x] for x in data["label"].values]),
        )
        return data
def read_geolife_supervised(verbose=False):
    """Read the Geolife supervised CSV through ``read_csv``.

    The file is looked up under ``data_original_folder()`` and parsed by
    ``_read_geolife_supervised``; ``user`` and ``label`` are declared as
    belonging to the ``ts_id`` dimension.

    Args:
        verbose: Forwarded unchanged to ``read_csv``.

    Returns:
        Whatever ``read_csv`` returns for this configuration.
    """
    csv_path = data_original_folder() / "geolife_supervised/geolife_supervised.csv"
    # Only the time-series dimension carries extra columns.
    extra_columns_by_dim = {
        "ts_id": ["user", "label"],
        "signal_id": [],
        "time_id": [],
    }
    return read_csv(
        filenames=csv_path,
        ts_id="tid",
        time_id="time",
        signal_id="variable",
        value_id="value",
        dims=extra_columns_by_dim,
        reader_fun=_read_geolife_supervised,
        verbose=verbose,
    )
def _read_geolife_supervised(filenames: list):
    """Yield long-format records from the first file in ``filenames``.

    The CSV is loaded with pandas and its ``lat`` / ``lon`` columns are
    melted, so each yielded dict has the keys ``tid``, ``label``, ``user``,
    ``time``, ``variable`` and ``value``.

    Args:
        filenames: Sequence of paths; only ``filenames[0]`` is read.

    Yields:
        One dict per row of the melted frame.
    """
    long_format = pd.read_csv(filenames[0]).melt(
        id_vars=["tid", "label", "user", "time"],
        value_vars=["lat", "lon"],
    )
    yield from long_format.to_dict(orient="records")


if __name__ == "__main__":
    # Earlier manual invocations, kept for reference:
    # GeolifeSupervised.save_unfixed(True)
    # GeolifeSupervised.save_fixed(True)
    #
    # print(GeolifeSupervised.read(False))
    # print(GeolifeSupervised.load_unfixed())
    # print(GeolifeSupervised.load_fixed())
    # GeolifeSupervised.save_fixed()
    df = GeolifeSupervised.load_final_version()