PyTorch-Forecasting一個新的時間序列預測庫

deephub 發佈 2023-12-25T21:42:40.574999+00:00

時間序列預測在金融、天氣預報、銷售預測和需求預測等各個領域發揮著至關重要的作用。PyTorch- forecasting是一個建立在PyTorch之上的開源Python包,專門用於簡化和增強時間序列的工作。

時間序列預測在金融、天氣預報、銷售預測和需求預測等各個領域發揮著至關重要的作用。PyTorch- forecasting是一個建立在Pytorch之上的開源Python包,專門用於簡化和增強時間序列的工作。在本文中我們介紹PyTorch-Forecasting的特性和功能,並進行示例代碼演示。

PyTorch-Forecasting的安裝非常簡單:

pip install pytorch-forecasting

但是需要注意的是,他目前現在只支持Pytorch 1.7以上,但是2.0是否支持我沒有測試。

PyTorch-Forecasting提供了幾個方面的功能:

1、提供了一個高級接口,抽象了時間序列建模的複雜性,可以使用幾行代碼來定義預測任務,使得使用不同的模型和技術進行實驗變得容易。

2、支持多個預測模型,包括自回歸模型(AR, ARIMA),狀態空間模型(SARIMAX),神經網絡(LSTM, GRU)和集成方法(Prophet, N-Beats)。這種多樣化的模型集確保了為您的時間序列數據選擇最合適方法的靈活性。

3、提供各種數據預處理工具來處理常見的時間序列任務,包括:缺失值輸入、縮放、特徵提取和滾動窗口轉換等。除了一些數據的預處理的工具外,還提供了一個名為 TimeSeriesDataSet 的Pytorch的DS,這樣可以方便的處理時間序列數據。

4、通過統一的接口方便模評估:實現了QuantileLoss,SMAPE 等時間序列的損失函數和驗證指標,支持Pytorch Lighting 這樣可以直接使用早停和交叉驗證等訓練方法

使用方法也很簡單:

from pytorch_forecasting import TimeSeriesDataSet, TemporalFusionTransformer
# Load and preprocess the data
dataset = TimeSeriesDataSet.from_csv('data.csv', target='target', time_idx='time', group_ids=['id'])
dataset.prepare_training(split_into_train_val_test=[0.8, 0.1, 0.1])
# Initialize and train the model
model = TemporalFusionTransformer.from_dataset(dataset)
trainer = pl.Trainer()
trainer.fit(model, dataset.train_dataloader())
# Generate predictions
predictions = model.predict(dataset.test_dataloader())
# Evaluate the model
metric = dataset.target_normalizer.metrics['mse']
print(f'Test MSE: {metric(predictions, dataset.test_dataloader())}')

如果需要分類編碼,可以這樣用:

from pytorch_forecasting.data import GroupNormalizer
# Load and preprocess the data with categorical variables
dataset = TimeSeriesDataSet.from_pandas(data, target='target', time_idx='time', group_ids=['id'], 
categorical_encoders={'cat_variable': GroupNormalizer()})
dataset.prepare_training(...)
# Initialize and train the model
model = TemporalFusionTransformer.from_dataset(dataset)
trainer.fit(model, dataset.train_dataloader())
# Generate predictions
predictions = model.predict(dataset.test_dataloader())
# Evaluate the model
print(f'Test MSE: {metric(predictions, dataset.test_dataloader())}')

PyTorch-Forecasting是一個非常好用的工具包,就算你不使用它所有的功能,也可以將他提供的一些功能當作鞏工具來整合到自己的項目中,如果你對使用PyTorch處理時序數據感興趣,也可以看看他的代碼當作學習的參考,他的文檔還是比較全面的,並且也提供了很多的示例。

關鍵字: