diff --git a/LICENSE b/LICENSE index faaa6cff15b1..07a2ad8da4d4 100644 --- a/LICENSE +++ b/LICENSE @@ -349,4 +349,14 @@ The chronos-forecasting is open source software licensed under the Apache Licens Project page: https://github.com/amazon-science/chronos-forecasting License: https://github.com/amazon-science/chronos-forecasting/blob/main/LICENSE --------------------------------------------------------------------------------- \ No newline at end of file +-------------------------------------------------------------------------------- + +The following files include code modified from uni2ts project. + +./iotdb-core/ainode/iotdb/ainode/core/model/moirai2/* + +The uni2ts is open source software licensed under the Apache License 2.0 +Project page: https://github.com/SalesforceAIResearch/uni2ts +License: https://github.com/SalesforceAIResearch/uni2ts/blob/main/LICENSE.txt + +-------------------------------------------------------------------------------- diff --git a/integration-test/src/test/java/org/apache/iotdb/ainode/utils/AINodeTestUtils.java b/integration-test/src/test/java/org/apache/iotdb/ainode/utils/AINodeTestUtils.java index d620efacc263..5a4dce53666d 100644 --- a/integration-test/src/test/java/org/apache/iotdb/ainode/utils/AINodeTestUtils.java +++ b/integration-test/src/test/java/org/apache/iotdb/ainode/utils/AINodeTestUtils.java @@ -56,7 +56,9 @@ public class AINodeTestUtils { new AbstractMap.SimpleEntry<>( "timer_xl", new FakeModelInfo("timer_xl", "timer", "builtin", "active")), new AbstractMap.SimpleEntry<>( - "chronos2", new FakeModelInfo("chronos2", "t5", "builtin", "active"))) + "chronos2", new FakeModelInfo("chronos2", "t5", "builtin", "active")), + new AbstractMap.SimpleEntry<>( + "moirai2", new FakeModelInfo("moirai2", "moirai", "builtin", "active"))) .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); public static final Map BUILTIN_MODEL_MAP; diff --git a/iotdb-core/ainode/iotdb/ainode/core/constant.py b/iotdb-core/ainode/iotdb/ainode/core/constant.py index 8a83c9814379..68f64a79afc4 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/constant.py +++ b/iotdb-core/ainode/iotdb/ainode/core/constant.py @@ -53,7 +53,7 @@ # TODO: Should be optimized AINODE_INFERENCE_MODEL_MEM_USAGE_MAP = { "sundial": 1036 * 1024**2, # 1036 MiB - "timer": 856 * 1024**2, # 856 MiB + "timer_xl": 856 * 1024**2, # 856 MiB } # the memory usage of each model in bytes AINODE_INFERENCE_MEMORY_USAGE_RATIO = 0.2 # the device space allocated for inference diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/pool_scheduler/basic_pool_scheduler.py b/iotdb-core/ainode/iotdb/ainode/core/inference/pool_scheduler/basic_pool_scheduler.py index 65aa77143939..3e62b5e96545 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/inference/pool_scheduler/basic_pool_scheduler.py +++ b/iotdb-core/ainode/iotdb/ainode/core/inference/pool_scheduler/basic_pool_scheduler.py @@ -59,8 +59,17 @@ def _estimate_shared_pool_size_by_total_mem( # Seize memory usage for each model mem_usages: Dict[str, float] = {} for model_info in all_models: + if model_info.model_id not in MODEL_MEM_USAGE_MAP: + logger.error( + f"[Inference] Model '{model_info.model_id}' not found in MODEL_MEM_USAGE_MAP. " + f"Available types: {list(MODEL_MEM_USAGE_MAP.keys())}" + ) + raise KeyError( + f"Model '{model_info.model_id}' not found in MODEL_MEM_USAGE_MAP. 
" + f"Please add memory usage configuration for '{model_info.model_id}' in constant.py" + ) mem_usages[model_info.model_id] = ( - MODEL_MEM_USAGE_MAP[model_info.model_type] * INFERENCE_EXTRA_MEMORY_RATIO + MODEL_MEM_USAGE_MAP[model_info.model_id] * INFERENCE_EXTRA_MEMORY_RATIO ) # Evaluate system resources and get TOTAL memory diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/chronos2/chronos_bolt.py b/iotdb-core/ainode/iotdb/ainode/core/model/chronos2/chronos_bolt.py index 8b221f5f149d..68880e63edb5 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/model/chronos2/chronos_bolt.py +++ b/iotdb-core/ainode/iotdb/ainode/core/model/chronos2/chronos_bolt.py @@ -596,7 +596,7 @@ def predict( context_tensor = torch.cat([context_tensor, prediction], dim=-1)[ ..., -self.model_context_length : ] - (batch_size, n_quantiles, context_length) = context_tensor.shape + batch_size, n_quantiles, context_length = context_tensor.shape with torch.no_grad(): # Reshape (batch, n_quantiles, context_length) -> (batch * n_quantiles, context_length) diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/model_info.py b/iotdb-core/ainode/iotdb/ainode/core/model/model_info.py index d0da371bfd5c..f253fb1e56f6 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/model/model_info.py +++ b/iotdb-core/ainode/iotdb/ainode/core/model/model_info.py @@ -145,4 +145,17 @@ def __repr__(self): "AutoModelForCausalLM": "model.Chronos2Model", }, ), + "moirai2": ModelInfo( + model_id="moirai2", + category=ModelCategory.BUILTIN, + state=ModelStates.INACTIVE, + model_type="moirai", + pipeline_cls="pipeline_moirai2.Moirai2Pipeline", + repo_id="Salesforce/moirai-2.0-R-small", + auto_map={ + "AutoConfig": "configuration_moirai2.Moirai2Config", + "AutoModelForCausalLM": "modeling_moirai2.Moirai2ForPrediction", + }, + transformers_registered=True, + ), } diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/model_loader.py b/iotdb-core/ainode/iotdb/ainode/core/model/model_loader.py index 289786c8aa3b..1da07cb9fef9 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/model/model_loader.py +++ b/iotdb-core/ainode/iotdb/ainode/core/model/model_loader.py @@ -61,7 +61,6 @@ def load_model(model_info: ModelInfo, **model_kwargs) -> Any: def load_model_from_transformers(model_info: ModelInfo, **model_kwargs): device_map = model_kwargs.get("device_map", "cpu") - trust_remote_code = model_kwargs.get("trust_remote_code", True) train_from_scratch = model_kwargs.get("train_from_scratch", False) model_path = os.path.join( @@ -107,11 +106,9 @@ def load_model_from_transformers(model_info: ModelInfo, **model_kwargs): model_cls = AutoModelForCausalLM if train_from_scratch: - model = model_cls.from_config(config_cls, trust_remote_code=trust_remote_code) + model = model_cls.from_config(config_cls) else: - model = model_cls.from_pretrained( - model_path, trust_remote_code=trust_remote_code - ) + model = model_cls.from_pretrained(model_path) return BACKEND.move_model(model, device_map) diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/moirai2/__init__.py b/iotdb-core/ainode/iotdb/ainode/core/model/moirai2/__init__.py new file mode 100644 index 000000000000..f1ad0c458d3f --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/moirai2/__init__.py @@ -0,0 +1,22 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +# This file is part of the Apache IoTDB project. +# +# This file includes code modified from the uni2ts project (https://github.com/salesforce/uni2ts). +# The original code is licensed under the Apache License 2.0. +# diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/moirai2/common/__init__.py b/iotdb-core/ainode/iotdb/ainode/core/model/moirai2/common/__init__.py new file mode 100644 index 000000000000..2a1e720805f2 --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/moirai2/common/__init__.py @@ -0,0 +1,17 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/moirai2/common/torch_util.py b/iotdb-core/ainode/iotdb/ainode/core/model/moirai2/common/torch_util.py new file mode 100644 index 000000000000..469e8044b991 --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/moirai2/common/torch_util.py @@ -0,0 +1,130 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# + +from typing import Optional + +import numpy as np +import torch +from jaxtyping import Bool, Float, Int + +numpy_to_torch_dtype_dict = { + bool: torch.bool, + np.uint8: torch.uint8, + np.int8: torch.int8, + np.int16: torch.int16, + np.int32: torch.int32, + np.int64: torch.int64, + np.float16: torch.float16, + np.float32: torch.float32, + np.float64: torch.float64, + np.complex64: torch.complex64, + np.complex128: torch.complex128, +} + + +def packed_attention_mask( + sample_id: Int[torch.Tensor, "*batch seq_len"], +) -> Bool[torch.Tensor, "*batch seq_len seq_len"]: + sample_id = sample_id.unsqueeze(-1) + attention_mask = sample_id.eq(sample_id.mT) + return attention_mask + + +def packed_causal_attention_mask( + sample_id: Int[torch.Tensor, "*batch seq_len"], + time_id: Int[torch.Tensor, "*batch seq_len"], +) -> Bool[torch.Tensor, "*batch seq_len seq_len"]: + attention_mask = packed_attention_mask(sample_id) + expanded_id1 = time_id.unsqueeze(-2) + expanded_id2 = time_id.unsqueeze(-1) + compare_res = expanded_id1 <= expanded_id2 + attention_mask = attention_mask * compare_res + return attention_mask + + +def mask_fill( + tensor: Float[torch.Tensor, "*batch dim"], + mask: Bool[torch.Tensor, "*batch"], + value: Float[torch.Tensor, "dim"], +) -> Float[torch.Tensor, "*batch dim"]: + mask = mask.unsqueeze(-1) + return tensor * ~mask + value * mask + + +def safe_div( + numer: torch.Tensor, + denom: torch.Tensor, +) -> torch.Tensor: + return numer / torch.where( + denom == 0, + 1.0, + denom, + ) + + +def size_to_mask( + max_size: int, + sizes: Int[torch.Tensor, "*batch"], +) -> Bool[torch.Tensor, "*batch max_size"]: + mask = torch.arange(max_size, device=sizes.device) + return torch.lt(mask, sizes.unsqueeze(-1)) + + +def fixed_size( + value: Float[torch.Tensor, "*batch max_size"], +) -> Int[torch.Tensor, "*batch"]: + sizes = torch.ones_like(value[..., 0], dtype=torch.long) * value.shape[-1] + return sizes + + +def sized_mean( + value: Float[torch.Tensor, "*batch max_size"], + sizes: Optional[Int[torch.Tensor, "*batch"]], + dim: Optional[int | tuple[int, ...]] = None, + keepdim: bool = False, + size_keepdim: bool = False, + correction: int = 0, +) -> Float[torch.Tensor, "..."]: + value = value * size_to_mask(value.shape[-1], sizes) + div_val = safe_div( + value.sum(dim=-1).sum(dim, keepdim=keepdim), + torch.clamp(sizes.sum(dim, keepdim=keepdim) - correction, min=0), + ) + if size_keepdim: + div_val = div_val.unsqueeze(-1) + return div_val + + +def masked_mean( + value: Float[torch.Tensor, "..."], + mask: Bool[torch.Tensor, "..."], + dim: Optional[int | tuple[int, ...]] = None, + keepdim: bool = False, + correction: int = 0, +) -> Float[torch.Tensor, "..."]: + return safe_div( + (value * mask).sum(dim=dim, keepdim=keepdim), + torch.clamp(mask.float().sum(dim, keepdim=keepdim) - correction, min=0), + ) + + +def unsqueeze_trailing_dims(x: torch.Tensor, shape: torch.Size) -> torch.Tensor: + if x.ndim > len(shape) or x.shape != shape[: x.ndim]: + raise ValueError + dim = (...,) + (None,) * (len(shape) - x.ndim) + return x[dim] diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/moirai2/configuration_moirai2.py b/iotdb-core/ainode/iotdb/ainode/core/model/moirai2/configuration_moirai2.py new file mode 100644 index 000000000000..81eeea189e8e --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/moirai2/configuration_moirai2.py @@ -0,0 +1,61 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +from typing import List, Tuple + +from transformers import PretrainedConfig + + +class Moirai2Config(PretrainedConfig): + model_type = "moirai2" + + def __init__( + self, + d_model: int = 384, + d_ff: int = 1024, + num_layers: int = 6, + patch_size: int = 16, + max_seq_len: int = 512, + attn_dropout_p: float = 0.0, + dropout_p: float = 0.0, + scaling: bool = True, + num_predict_token: int = 4, + quantile_levels: Tuple[float, ...] = ( + 0.1, + 0.2, + 0.3, + 0.4, + 0.5, + 0.6, + 0.7, + 0.8, + 0.9, + ), + **kwargs, + ): + self.d_model = d_model + self.d_ff = d_ff + self.num_layers = num_layers + self.patch_size = patch_size + self.max_seq_len = max_seq_len + self.attn_dropout_p = attn_dropout_p + self.dropout_p = dropout_p + self.scaling = scaling + self.num_predict_token = num_predict_token + self.quantile_levels = quantile_levels + super().__init__(**kwargs) diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/moirai2/modeling_moirai2.py b/iotdb-core/ainode/iotdb/ainode/core/model/moirai2/modeling_moirai2.py new file mode 100644 index 000000000000..c9b7afc98a32 --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/moirai2/modeling_moirai2.py @@ -0,0 +1,1352 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# + +import math +import warnings +from contextlib import contextmanager +from dataclasses import dataclass +from functools import partial +from typing import Generator, List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn.functional as F +from einops import rearrange, reduce, repeat +from huggingface_hub import PyTorchModelHubMixin +from jaxtyping import Bool, Float, Int +from torch import nn +from transformers import PreTrainedModel +from transformers.modeling_outputs import ModelOutput + +from iotdb.ainode.core.model.moirai2.common.torch_util import ( + packed_causal_attention_mask, +) +from iotdb.ainode.core.model.moirai2.configuration_moirai2 import Moirai2Config +from iotdb.ainode.core.model.moirai2.module.norm import RMSNorm +from iotdb.ainode.core.model.moirai2.module.packed_scaler import ( + PackedNOPScaler, + PackedStdScaler, +) +from iotdb.ainode.core.model.moirai2.module.position import ( + BinaryAttentionBias, + QueryKeyProjection, + RotaryProjection, +) +from iotdb.ainode.core.model.moirai2.module.transformer import TransformerEncoder +from iotdb.ainode.core.model.moirai2.module.ts_embed import ResidualBlock +from iotdb.ainode.core.model.moirai2.transform.imputation import CausalMeanImputation + + +@dataclass +class Moirai2Output(ModelOutput): + """ + Output class for Moirai2 model. + + Args: + predictions: Model predictions of shape (batch, seq_len, num_quantiles * patch_size) + scaled_target: Scaled target values (only during training) + loc: Scaling location parameters + scale: Scaling scale parameters + """ + + predictions: torch.FloatTensor = None + scaled_target: Optional[torch.FloatTensor] = None + loc: Optional[torch.FloatTensor] = None + scale: Optional[torch.FloatTensor] = None + + +class Moirai2PreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained models. + """ + + config_class = Moirai2Config + base_model_prefix = "moirai2" + supports_gradient_checkpointing = False + _no_split_modules = ["TransformerEncoderLayer"] + + def _init_weights(self, module): + """Initialize the weights.""" + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=0.02) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=0.02) + + +class Moirai2Model(Moirai2PreTrainedModel, PyTorchModelHubMixin): + """ + Core Moirai2 transformer model. + + This model implements a transformer-based architecture for time series forecasting + with support for patching, scaling, and multi-variate predictions. + + Inherits from PyTorchModelHubMixin to enable loading from HuggingFace Hub + using the from_pretrained() method, which is compatible with uni2ts pretrained models. 
+ """ + + def __init__(self, config: Moirai2Config = None, **kwargs): + # Handle both config and kwargs for PyTorchModelHubMixin compatibility + if config is None: + config = Moirai2Config(**kwargs) + super().__init__(config) + self.config = config + + # Model parameters + self.d_model = config.d_model + self.num_layers = config.num_layers + self.patch_size = config.patch_size + self.num_predict_token = config.num_predict_token + self.max_seq_len = config.max_seq_len + self.scaling = config.scaling + self.quantile_levels = config.quantile_levels + self.num_quantiles = len(config.quantile_levels) + + # Scaler + self.scaler = PackedStdScaler() if self.scaling else PackedNOPScaler() + + # Input projection + self.in_proj = ResidualBlock( + input_dims=self.patch_size * 2, + hidden_dims=self.d_model, + output_dims=self.d_model, + ) + + # Transformer encoder + self.encoder = TransformerEncoder( + self.d_model, + self.num_layers, + num_heads=None, + pre_norm=True, + attn_dropout_p=config.attn_dropout_p, + dropout_p=config.dropout_p, + norm_layer=RMSNorm, + activation=F.silu, + use_glu=True, + use_qk_norm=True, + var_attn_bias_layer=partial(BinaryAttentionBias), + time_qk_proj_layer=partial( + QueryKeyProjection, + proj_layer=RotaryProjection, + kwargs=dict(max_len=self.max_seq_len), + partial_factor=(0.0, 0.5), + ), + shared_var_attn_bias=False, + shared_time_qk_proj=True, + d_ff=config.d_ff, + ) + + # Output projection + self.out_proj = ResidualBlock( + input_dims=self.d_model, + hidden_dims=self.d_model, + output_dims=self.num_predict_token * self.num_quantiles * self.patch_size, + ) + + self.post_init() + + def forward( + self, + target: Float[torch.Tensor, "*batch seq_len patch"], + observed_mask: Bool[torch.Tensor, "*batch seq_len patch"], + sample_id: Int[torch.Tensor, "*batch seq_len"], + time_id: Int[torch.Tensor, "*batch seq_len"], + variate_id: Int[torch.Tensor, "*batch seq_len"], + prediction_mask: Bool[torch.Tensor, "*batch seq_len"], + training_mode: bool = True, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, Moirai2Output]: + """ + Forward pass of Moirai2Model. 
+ + Args: + target: Input time series data + observed_mask: Binary mask for observed values (1=observed, 0=missing) + sample_id: Sample indices for packed sequences + time_id: Time step indices + variate_id: Variable indices + prediction_mask: Binary mask for prediction horizon (1=predict, 0=context) + training_mode: Whether in training mode + return_dict: Whether to return ModelOutput + + Returns: + Moirai2Output or tuple with predictions + """ + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + # Apply scaling + loc, scale = self.scaler( + target, + observed_mask * ~prediction_mask.unsqueeze(-1), + sample_id, + variate_id, + ) + scaled_target = (target - loc) / scale + + # Concatenate scaled values with mask + input_tokens = torch.cat( + [scaled_target, observed_mask.to(torch.float32)], dim=-1 + ) + + # Project to model dimension + reprs = self.in_proj(input_tokens) + + # Apply transformer + reprs = self.encoder( + reprs, + packed_causal_attention_mask(sample_id, time_id), + time_id=time_id, + var_id=variate_id, + ) + + # Project to output + preds = self.out_proj(reprs) + + if training_mode: + scaled_preds = preds + if not return_dict: + return (preds, scaled_target) + return Moirai2Output( + predictions=scaled_preds, + scaled_target=scaled_target, + loc=loc, + scale=scale, + ) + else: + # Rescale predictions + preds = preds * scale + loc + if not return_dict: + return (preds,) + return Moirai2Output( + predictions=preds, + loc=loc, + scale=scale, + ) + + +class Moirai2ForPrediction(Moirai2PreTrainedModel): + """ + Moirai2 model for time series prediction. + + This class wraps Moirai2Model and provides high-level prediction interfaces, + including support for quantile forecasting and various input formats. + """ + + def __init__( + self, + config: Moirai2Config, + prediction_length: int = 96, + target_dim: int = 1, + feat_dynamic_real_dim: int = 0, + past_feat_dynamic_real_dim: int = 0, + context_length: int = 512, + ): + super().__init__(config) + self.config = config + + # Hyperparameters + self.prediction_length = prediction_length + self.target_dim = target_dim + self.feat_dynamic_real_dim = feat_dynamic_real_dim + self.past_feat_dynamic_real_dim = past_feat_dynamic_real_dim + self.context_length = context_length + + # Core model + self.model = Moirai2Model(config) + + self.post_init() + + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path: str, + prediction_length: int = 96, + target_dim: int = 1, + feat_dynamic_real_dim: int = 0, + past_feat_dynamic_real_dim: int = 0, + context_length: int = 512, + **kwargs, + ): + """ + Load a pretrained Moirai2ForPrediction model. + + This method handles the weight loading from uni2ts pretrained models, + which have a different weight key structure (without 'model.' prefix). 
+ + Args: + pretrained_model_name_or_path: Path or HuggingFace repo ID + prediction_length: Length of forecast horizon + target_dim: Number of target variables + feat_dynamic_real_dim: Dimension of future dynamic features + past_feat_dynamic_real_dim: Dimension of past dynamic features + context_length: Length of context window + **kwargs: Additional arguments for loading + + Returns: + Moirai2ForPrediction instance with loaded weights + """ + import os + from pathlib import Path + + # Load config + config = Moirai2Config.from_pretrained(pretrained_model_name_or_path, **kwargs) + + # Create instance normally (this will create a randomly initialized core model) + instance = cls( + config=config, + prediction_length=prediction_length, + target_dim=target_dim, + feat_dynamic_real_dim=feat_dynamic_real_dim, + past_feat_dynamic_real_dim=past_feat_dynamic_real_dim, + context_length=context_length, + ) + + # Now load weights into the core model + # Try to find the weight file + model_dir = Path(pretrained_model_name_or_path) + weight_file = None + + # Check for various weight file formats + for filename in ["model.safetensors", "pytorch_model.bin", "model.bin"]: + candidate = model_dir / filename + if candidate.exists(): + weight_file = candidate + break + + if weight_file is None: + warnings.warn( + f"No weight file found in {pretrained_model_name_or_path}. " + f"Model will use randomly initialized weights." + ) + return instance + + # Load the state dict + try: + if str(weight_file).endswith(".safetensors"): + # Use safetensors if available + try: + from safetensors.torch import load_file + + state_dict = load_file(str(weight_file)) + except ImportError: + # Fallback to torch.load (won't work for safetensors) + warnings.warn( + "safetensors not available, trying torch.load. " + "Please install safetensors for better compatibility." + ) + state_dict = torch.load(str(weight_file), map_location="cpu") + else: + state_dict = torch.load(str(weight_file), map_location="cpu") + + # Load weights into the core model + # The uni2ts weights don't have 'model.' prefix, load directly + missing_keys, unexpected_keys = instance.model.load_state_dict( + state_dict, strict=False + ) + + if missing_keys: + warnings.warn( + f"Missing keys when loading pretrained weights: {missing_keys[:5]}..." + ) + if unexpected_keys: + warnings.warn( + f"Unexpected keys when loading pretrained weights: {unexpected_keys[:5]}..." + ) + + if not missing_keys: + # Successfully loaded all weights + pass + else: + warnings.warn( + f"Some weights were not loaded. Model may not work correctly." + ) + + except Exception as e: + warnings.warn( + f"Failed to load weights from {weight_file}: {e}\n" + f"Model will use randomly initialized weights." 
+ ) + + return instance + + def set_decoder(self, decoder): + """Set the decoder model (for compatibility with TSGenerationMixin).""" + self.model = decoder + + def get_decoder(self): + """Get the decoder model (for compatibility with TSGenerationMixin).""" + return self.model + + @property + def past_length(self) -> int: + """Get the context length.""" + return self.context_length + + @property + def max_patch_size(self) -> int: + """Get the maximum patch size.""" + return self.model.patch_size + + def context_token_length(self, patch_size: int) -> int: + """Calculate the number of tokens in the context.""" + return math.ceil(self.context_length / patch_size) + + def prediction_token_length(self, patch_size: int) -> int: + """Calculate the number of tokens in the prediction.""" + return math.ceil(self.prediction_length / patch_size) + + @contextmanager + def hparams_context( + self, + prediction_length: Optional[int] = None, + target_dim: Optional[int] = None, + feat_dynamic_real_dim: Optional[int] = None, + past_feat_dynamic_real_dim: Optional[int] = None, + context_length: Optional[int] = None, + ) -> Generator["Moirai2ForPrediction", None, None]: + """ + Context manager for temporarily changing hyperparameters. + """ + kwargs = { + "prediction_length": prediction_length, + "target_dim": target_dim, + "feat_dynamic_real_dim": feat_dynamic_real_dim, + "past_feat_dynamic_real_dim": past_feat_dynamic_real_dim, + "context_length": context_length, + } + + # Save old values + old_values = { + "prediction_length": self.prediction_length, + "target_dim": self.target_dim, + "feat_dynamic_real_dim": self.feat_dynamic_real_dim, + "past_feat_dynamic_real_dim": self.past_feat_dynamic_real_dim, + "context_length": self.context_length, + } + + # Set new values + for key, value in kwargs.items(): + if value is not None: + setattr(self, key, value) + + try: + yield self + finally: + # Restore old values + for key, value in old_values.items(): + setattr(self, key, value) + + def forward( + self, + past_target: Float[torch.Tensor, "batch past_time tgt"], + past_observed_target: Bool[torch.Tensor, "batch past_time tgt"], + past_is_pad: Bool[torch.Tensor, "batch past_time"], + feat_dynamic_real: Optional[Float[torch.Tensor, "batch time feat"]] = None, + observed_feat_dynamic_real: Optional[ + Float[torch.Tensor, "batch time feat"] + ] = None, + past_feat_dynamic_real: Optional[ + Float[torch.Tensor, "batch past_time past_feat"] + ] = None, + past_observed_feat_dynamic_real: Optional[ + Float[torch.Tensor, "batch past_time past_feat"] + ] = None, + return_dict: Optional[bool] = None, + ) -> Float[torch.Tensor, "batch num_quantiles future_time *tgt"]: + """ + Forward pass for prediction. 
+ + Args: + past_target: Historical target values + past_observed_target: Mask for observed target values + past_is_pad: Mask for padding in the past + feat_dynamic_real: Future dynamic features (optional) + observed_feat_dynamic_real: Mask for dynamic features (optional) + past_feat_dynamic_real: Past dynamic features (optional) + past_observed_feat_dynamic_real: Mask for past dynamic features (optional) + return_dict: Whether to return ModelOutput + + Returns: + Predictions of shape (batch, num_quantiles, prediction_length, target_dim) + """ + # Convert inputs to model format + ( + target, + observed_mask, + sample_id, + time_id, + variate_id, + prediction_mask, + ) = self._convert( + self.model.patch_size, + past_target, + past_observed_target, + past_is_pad, + feat_dynamic_real=feat_dynamic_real, + observed_feat_dynamic_real=observed_feat_dynamic_real, + past_feat_dynamic_real=past_feat_dynamic_real, + past_observed_feat_dynamic_real=past_observed_feat_dynamic_real, + ) + + per_var_context_token = self.context_token_length(self.model.patch_size) + total_context_token = self.target_dim * per_var_context_token + per_var_predict_token = self.prediction_token_length(self.model.patch_size) + total_predict_token = self.target_dim * per_var_predict_token + + # Initialize prediction tensor + pred_index = torch.arange( + start=per_var_context_token - 1, + end=total_context_token, + step=per_var_context_token, + ) + assign_index = torch.arange( + start=total_context_token, + end=total_context_token + total_predict_token, + step=per_var_predict_token, + ) + + quantile_prediction = repeat( + target, + "... patch_size -> ... num_quantiles patch_size", + num_quantiles=self.model.num_quantiles, + patch_size=self.model.patch_size, + ).clone() + + # Get model predictions + outputs = self.model( + target, + observed_mask, + sample_id, + time_id, + variate_id, + prediction_mask, + training_mode=False, + return_dict=True, + ) + preds = outputs.predictions + + # Process predictions + if per_var_predict_token <= self.model.num_predict_token: + # Single-step prediction + preds, adjusted_assign_index = self._structure_multi_predict( + per_var_predict_token, + pred_index, + assign_index, + preds, + ) + quantile_prediction[..., adjusted_assign_index, :, :] = preds + return self._format_preds( + self.model.num_quantiles, + self.model.patch_size, + quantile_prediction, + self.target_dim, + ) + else: + # Multi-step autoregressive prediction + return self._autoregressive_predict( + target, + observed_mask, + sample_id, + time_id, + variate_id, + prediction_mask, + quantile_prediction, + preds, + pred_index, + assign_index, + per_var_predict_token, + ) + + def _structure_multi_predict( + self, + per_var_predict_token: int, + pred_index: torch.Tensor, + assign_index: torch.Tensor, + preds: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Structure predictions for multiple prediction tokens.""" + preds = rearrange( + preds, + "... (predict_token num_quantiles patch_size) -> ... predict_token num_quantiles patch_size", + predict_token=self.model.num_predict_token, + num_quantiles=self.model.num_quantiles, + patch_size=self.model.patch_size, + ) + preds = rearrange( + preds[..., pred_index, :per_var_predict_token, :, :], + "... pred_index predict_token num_quantiles patch_size -> ... 
(pred_index predict_token) num_quantiles patch_size", + ) + adjusted_assign_index = torch.cat( + [ + torch.arange(start=idx, end=idx + per_var_predict_token) + for idx in assign_index + ] + ) + return preds, adjusted_assign_index + + def _autoregressive_predict( + self, + target: torch.Tensor, + observed_mask: torch.Tensor, + sample_id: torch.Tensor, + time_id: torch.Tensor, + variate_id: torch.Tensor, + prediction_mask: torch.Tensor, + quantile_prediction: torch.Tensor, + preds: torch.Tensor, + pred_index: torch.Tensor, + assign_index: torch.Tensor, + per_var_predict_token: int, + ) -> torch.Tensor: + """Perform autoregressive prediction for long horizons.""" + # Expand tensors for quantiles + expand_target = repeat( + target, + "batch_size ... -> batch_size num_quantiles ...", + num_quantiles=self.model.num_quantiles, + batch_size=target.shape[0], + ).clone() + expand_prediction_mask = repeat( + prediction_mask, + "batch_size ... -> batch_size num_quantiles ...", + num_quantiles=self.model.num_quantiles, + batch_size=target.shape[0], + ).clone() + expand_observed_mask = repeat( + observed_mask, + "batch_size ... -> batch_size num_quantiles ...", + num_quantiles=self.model.num_quantiles, + batch_size=target.shape[0], + ).clone() + expand_sample_id = repeat( + sample_id, + "batch_size ... -> batch_size num_quantiles ...", + num_quantiles=self.model.num_quantiles, + batch_size=target.shape[0], + ).clone() + expand_time_id = repeat( + time_id, + "batch_size ... -> batch_size num_quantiles ...", + num_quantiles=self.model.num_quantiles, + batch_size=target.shape[0], + ).clone() + expand_variate_id = repeat( + variate_id, + "batch_size ... -> batch_size num_quantiles ...", + num_quantiles=self.model.num_quantiles, + batch_size=target.shape[0], + ).clone() + + # First prediction step + preds, adjusted_assign_index = self._structure_multi_predict( + self.model.num_predict_token, + pred_index, + assign_index, + preds, + ) + quantile_prediction[..., adjusted_assign_index, :, :] = preds + + expand_target[..., adjusted_assign_index, :] = rearrange( + preds, + "... predict_token num_quantiles patch_size -> ... num_quantiles predict_token patch_size", + num_quantiles=self.model.num_quantiles, + patch_size=self.model.patch_size, + predict_token=self.model.num_predict_token, + ) + expand_prediction_mask[..., adjusted_assign_index] = False + + # Remaining steps + remain_step = per_var_predict_token - self.model.num_predict_token + while remain_step > 0: + outputs = self.model( + expand_target, + expand_observed_mask, + expand_sample_id, + expand_time_id, + expand_variate_id, + expand_prediction_mask, + training_mode=False, + return_dict=True, + ) + preds = outputs.predictions + + pred_index = assign_index + self.model.num_predict_token - 1 + assign_index = pred_index + 1 + preds, adjusted_assign_index = self._structure_multi_predict( + ( + self.model.num_predict_token + if remain_step > self.model.num_predict_token + else remain_step + ), + pred_index, + assign_index, + preds, + ) + + # Compute quantiles + quantile_prediction_next_step = rearrange( + preds, + "... num_quantiles_prev pred_index num_quantiles patch_size -> ... 
pred_index (num_quantiles_prev num_quantiles) patch_size", + num_quantiles=self.model.num_quantiles, + patch_size=self.model.patch_size, + ) + quantile_prediction_next_step = torch.quantile( + quantile_prediction_next_step, + torch.tensor( + self.model.quantile_levels, + device=target.device, + dtype=torch.float32, + ), + dim=-2, + ) + quantile_prediction[..., adjusted_assign_index, :, :] = rearrange( + quantile_prediction_next_step, + "num_quantiles ... patch_size -> ... num_quantiles patch_size", + ) + + expand_target[..., adjusted_assign_index, :] = rearrange( + quantile_prediction_next_step, + "num_quantiles batch_size predict_token patch_size -> batch_size num_quantiles predict_token patch_size", + num_quantiles=self.model.num_quantiles, + patch_size=self.model.patch_size, + predict_token=len(adjusted_assign_index), + ) + expand_prediction_mask[..., adjusted_assign_index] = False + + remain_step -= self.model.num_predict_token + + return self._format_preds( + self.model.num_quantiles, + self.model.patch_size, + quantile_prediction, + self.target_dim, + ) + + @torch.no_grad() + def predict( + self, + past_target: List[Float[np.ndarray, "past_time tgt"]], + feat_dynamic_real: Optional[List[Float[np.ndarray, "time feat"]]] = None, + past_feat_dynamic_real: Optional[ + List[Float[np.ndarray, "past_time feat"]] + ] = None, + ) -> Float[np.ndarray, "batch num_quantiles future_time *tgt"]: + """ + High-level prediction interface accepting list of numpy arrays. + + Args: + past_target: List of historical time series (numpy arrays) + feat_dynamic_real: List of future dynamic features (optional) + past_feat_dynamic_real: List of past dynamic features (optional) + + Returns: + Predictions as numpy array of shape (batch, num_quantiles, prediction_length, target_dim) + """ + # Prepare data + data_entry = { + "past_target": past_target, + "feat_dynamic_real": feat_dynamic_real, + "past_feat_dynamic_real": past_feat_dynamic_real, + } + + # Create observed masks + data_entry["past_observed_target"] = [~np.isnan(x) for x in past_target] + if feat_dynamic_real: + data_entry["observed_feat_dynamic_real"] = [ + ~np.isnan(x) for x in feat_dynamic_real + ] + else: + data_entry["observed_feat_dynamic_real"] = None + + if past_feat_dynamic_real: + data_entry["past_observed_feat_dynamic_real"] = [ + ~np.isnan(x) for x in past_feat_dynamic_real + ] + else: + data_entry["past_observed_feat_dynamic_real"] = None + + # Impute missing values + impute = CausalMeanImputation() + + def process_sample(sample): + arr = np.asarray(sample) + if arr.ndim == 1: + arr = arr[:, np.newaxis] + if np.issubdtype(arr.dtype, np.number) and np.isnan(arr).any(): + arr = impute(arr) + return arr + + for key, value in data_entry.items(): + if value is not None: + data_entry[key] = [process_sample(sample) for sample in value] + + # Create padding mask + data_entry["past_is_pad"] = np.zeros( + (len(data_entry["past_target"]), self.context_length), dtype=bool + ) + + # Pad or slice to context length + for key in data_entry: + if data_entry[key] is not None and isinstance(data_entry[key], list): + for idx in range(len(data_entry[key])): + if data_entry[key][idx].shape[0] > self.context_length: + data_entry[key][idx] = data_entry[key][idx][ + -self.context_length :, : + ] + else: + pad_length = self.context_length - data_entry[key][idx].shape[0] + pad_block = np.full( + (pad_length, data_entry[key][idx].shape[1]), + data_entry[key][idx][0], + dtype=data_entry[key][idx].dtype, + ) + data_entry[key][idx] = np.concatenate( + [pad_block, 
data_entry[key][idx]], axis=0 + ) + if key == "past_target": + data_entry["past_is_pad"][idx, :pad_length] = True + + # Convert to tensors + device = next(self.parameters()).device + for k in ["past_target", "feat_dynamic_real", "past_feat_dynamic_real"]: + if data_entry[k] is not None: + data_entry[k] = torch.tensor( + np.array(data_entry[k]), device=device, dtype=torch.float32 + ) + + for k in [ + "past_observed_target", + "observed_feat_dynamic_real", + "past_observed_feat_dynamic_real", + "past_is_pad", + ]: + if data_entry[k] is not None: + data_entry[k] = torch.tensor( + np.array(data_entry[k]), device=device, dtype=torch.bool + ) + + # Get predictions + predictions = self(**data_entry).detach().cpu().numpy() + return predictions + + @staticmethod + def _patched_seq_pad( + patch_size: int, + x: torch.Tensor, + dim: int, + left: bool = True, + value: Optional[float] = None, + ) -> torch.Tensor: + """Pad sequence to be divisible by patch_size.""" + if dim >= 0: + dim = -x.ndim + dim + pad_length = -x.size(dim) % patch_size + if left: + pad = (pad_length, 0) + else: + pad = (0, pad_length) + pad = (0, 0) * (abs(dim) - 1) + pad + return torch.nn.functional.pad(x, pad, value=value) + + def _generate_time_id( + self, + patch_size: int, + past_observed_target: Bool[torch.Tensor, "batch past_seq tgt"], + ) -> Tuple[ + Int[torch.Tensor, "batch past_token"], Int[torch.Tensor, "batch future_token"] + ]: + """Generate time IDs for past and future sequences.""" + past_seq_id = reduce( + self._patched_seq_pad(patch_size, past_observed_target, -2, left=True), + "... (seq patch) dim -> ... seq", + "max", + patch=patch_size, + ) + past_seq_id = torch.clamp( + past_seq_id.cummax(dim=-1).values.cumsum(dim=-1) - 1, min=0 + ) + + batch_shape = " ".join(map(str, past_observed_target.shape[:-2])) + future_seq_id = ( + repeat( + torch.arange( + self.prediction_token_length(patch_size), + device=past_observed_target.device, + ), + f"prediction -> {batch_shape} prediction", + ) + + past_seq_id.max(dim=-1, keepdim=True).values + + 1 + ) + return past_seq_id, future_seq_id + + def _convert( + self, + patch_size: int, + past_target: Float[torch.Tensor, "batch past_time tgt"], + past_observed_target: Bool[torch.Tensor, "batch past_time tgt"], + past_is_pad: Bool[torch.Tensor, "batch past_time"], + future_target: Optional[Float[torch.Tensor, "batch future_time tgt"]] = None, + future_observed_target: Optional[ + Bool[torch.Tensor, "batch future_time tgt"] + ] = None, + future_is_pad: Optional[Bool[torch.Tensor, "batch future_time"]] = None, + feat_dynamic_real: Optional[Float[torch.Tensor, "batch time feat"]] = None, + observed_feat_dynamic_real: Optional[ + Float[torch.Tensor, "batch time feat"] + ] = None, + past_feat_dynamic_real: Optional[ + Float[torch.Tensor, "batch past_time past_feat"] + ] = None, + past_observed_feat_dynamic_real: Optional[ + Float[torch.Tensor, "batch past_time past_feat"] + ] = None, + ) -> Tuple[ + Float[torch.Tensor, "batch combine_seq patch"], + Bool[torch.Tensor, "batch combine_seq patch"], + Int[torch.Tensor, "batch combine_seq"], + Int[torch.Tensor, "batch combine_seq"], + Int[torch.Tensor, "batch combine_seq"], + Bool[torch.Tensor, "batch combine_seq"], + ]: + """ + Convert input tensors to packed format for the model. 
+ + Returns: + Tuple of (target, observed_mask, sample_id, time_id, variate_id, prediction_mask) + """ + batch_shape = past_target.shape[:-2] + device = past_target.device + + target = [] + observed_mask = [] + sample_id = [] + time_id = [] + variate_id = [] + prediction_mask = [] + dim_count = 0 + + past_seq_id, future_seq_id = self._generate_time_id( + patch_size, past_observed_target + ) + + # Process target variable + if future_target is None: + future_target = torch.zeros( + batch_shape + (self.prediction_length, past_target.shape[-1]), + dtype=past_target.dtype, + device=device, + ) + + target.extend( + [ + torch.nn.functional.pad( + rearrange( + self._patched_seq_pad(patch_size, past_target, -2, left=True), + "... (seq patch) dim -> ... (dim seq) patch", + patch=patch_size, + ), + (0, 0), + ), + torch.nn.functional.pad( + rearrange( + self._patched_seq_pad( + patch_size, future_target, -2, left=False + ), + "... (seq patch) dim -> ... (dim seq) patch", + patch=patch_size, + ), + (0, 0), + ), + ] + ) + + if future_observed_target is None: + future_observed_target = torch.ones( + batch_shape + (self.prediction_length, past_observed_target.shape[-1]), + dtype=torch.bool, + device=device, + ) + + observed_mask.extend( + [ + torch.nn.functional.pad( + rearrange( + self._patched_seq_pad( + patch_size, past_observed_target, -2, left=True + ), + "... (seq patch) dim -> ... (dim seq) patch", + patch=patch_size, + ), + (0, 0), + ), + torch.nn.functional.pad( + rearrange( + self._patched_seq_pad( + patch_size, future_observed_target, -2, left=False + ), + "... (seq patch) dim -> ... (dim seq) patch", + patch=patch_size, + ), + (0, 0), + ), + ] + ) + + if future_is_pad is None: + future_is_pad = torch.zeros( + batch_shape + (self.prediction_length,), + dtype=torch.long, + device=device, + ) + + sample_id.extend( + [ + repeat( + reduce( + ( + self._patched_seq_pad( + patch_size, past_is_pad, -1, left=True, value=1 + ) + == 0 + ).int(), + "... (seq patch) -> ... seq", + "max", + patch=patch_size, + ), + "... seq -> ... (dim seq)", + dim=past_target.shape[-1], + ), + repeat( + reduce( + ( + self._patched_seq_pad( + patch_size, future_is_pad, -1, left=False, value=1 + ) + == 0 + ).int(), + "... (seq patch) -> ... seq", + "max", + patch=patch_size, + ), + "... seq -> ... 
(dim seq)", + dim=past_target.shape[-1], + ), + ] + ) + + time_id.extend( + [past_seq_id] * past_target.shape[-1] + + [future_seq_id] * past_target.shape[-1] + ) + + variate_id.extend( + [ + repeat( + torch.arange(past_target.shape[-1], device=device) + dim_count, + f"dim -> {' '.join(map(str, batch_shape))} (dim past)", + past=self.context_token_length(patch_size), + ), + repeat( + torch.arange(past_target.shape[-1], device=device) + dim_count, + f"dim -> {' '.join(map(str, batch_shape))} (dim future)", + future=self.prediction_token_length(patch_size), + ), + ] + ) + dim_count += past_target.shape[-1] + + prediction_mask.extend( + [ + torch.zeros( + batch_shape + + (self.context_token_length(patch_size) * past_target.shape[-1],), + dtype=torch.bool, + device=device, + ), + torch.ones( + batch_shape + + ( + self.prediction_token_length(patch_size) + * past_target.shape[-1], + ), + dtype=torch.bool, + device=device, + ), + ] + ) + + # Process dynamic features if provided + if feat_dynamic_real is not None: + if observed_feat_dynamic_real is None: + raise ValueError( + "observed_feat_dynamic_real must be provided if feat_dynamic_real is provided" + ) + + target.extend( + [ + torch.nn.functional.pad( + rearrange( + self._patched_seq_pad( + patch_size, + feat_dynamic_real[..., : self.context_length, :], + -2, + left=True, + ), + "... (seq patch) dim -> ... (dim seq) patch", + patch=patch_size, + ), + (0, 0), + ), + torch.nn.functional.pad( + rearrange( + self._patched_seq_pad( + patch_size, + feat_dynamic_real[..., self.context_length :, :], + -2, + left=False, + ), + "... (seq patch) dim -> ... (dim seq) patch", + patch=patch_size, + ), + (0, 0), + ), + ] + ) + + observed_mask.extend( + [ + torch.nn.functional.pad( + rearrange( + self._patched_seq_pad( + patch_size, + observed_feat_dynamic_real[ + ..., : self.context_length, : + ], + -2, + left=True, + ), + "... (seq patch) dim -> ... (dim seq) patch", + patch=patch_size, + ), + (0, 0), + ), + torch.nn.functional.pad( + rearrange( + self._patched_seq_pad( + patch_size, + observed_feat_dynamic_real[ + ..., self.context_length :, : + ], + -2, + left=False, + ), + "... (seq patch) dim -> ... (dim seq) patch", + patch=patch_size, + ), + (0, 0), + ), + ] + ) + + sample_id.extend( + [ + repeat( + reduce( + ( + self._patched_seq_pad( + patch_size, past_is_pad, -1, left=True + ) + == 0 + ).int(), + "... (seq patch) -> ... seq", + "max", + patch=patch_size, + ), + "... seq -> ... 
(dim seq)", + dim=feat_dynamic_real.shape[-1], + ), + torch.ones( + batch_shape + + ( + self.prediction_token_length(patch_size) + * feat_dynamic_real.shape[-1], + ), + dtype=torch.long, + device=device, + ), + ] + ) + + time_id.extend( + [past_seq_id] * feat_dynamic_real.shape[-1] + + [future_seq_id] * feat_dynamic_real.shape[-1] + ) + + variate_id.extend( + [ + repeat( + torch.arange(feat_dynamic_real.shape[-1], device=device) + + dim_count, + f"dim -> {' '.join(map(str, batch_shape))} (dim past)", + past=self.context_token_length(patch_size), + ), + repeat( + torch.arange(feat_dynamic_real.shape[-1], device=device) + + dim_count, + f"dim -> {' '.join(map(str, batch_shape))} (dim future)", + future=self.prediction_token_length(patch_size), + ), + ] + ) + dim_count += feat_dynamic_real.shape[-1] + + prediction_mask.extend( + [ + torch.zeros( + batch_shape + + ( + self.context_token_length(patch_size) + * feat_dynamic_real.shape[-1], + ), + dtype=torch.bool, + device=device, + ), + torch.zeros( + batch_shape + + ( + self.prediction_token_length(patch_size) + * feat_dynamic_real.shape[-1], + ), + dtype=torch.bool, + device=device, + ), + ] + ) + + if past_feat_dynamic_real is not None: + if past_observed_feat_dynamic_real is None: + raise ValueError( + "past_observed_feat_dynamic_real must be provided if past_feat_dynamic_real is provided" + ) + + target.append( + torch.nn.functional.pad( + rearrange( + self._patched_seq_pad( + patch_size, past_feat_dynamic_real, -2, left=True + ), + "... (seq patch) dim -> ... (dim seq) patch", + patch=patch_size, + ), + (0, 0), + ) + ) + + observed_mask.append( + torch.nn.functional.pad( + rearrange( + self._patched_seq_pad( + patch_size, past_observed_feat_dynamic_real, -2, left=True + ), + "... (seq patch) dim -> ... (dim seq) patch", + patch=patch_size, + ), + (0, 0), + ) + ) + + sample_id.append( + repeat( + reduce( + ( + self._patched_seq_pad( + patch_size, past_is_pad, -1, left=True + ) + == 0 + ).int(), + "... (seq patch) -> ... seq", + "max", + patch=patch_size, + ), + "... seq -> ... 
(dim seq)", + dim=past_feat_dynamic_real.shape[-1], + ) + ) + + time_id.extend([past_seq_id] * past_feat_dynamic_real.shape[-1]) + + variate_id.append( + repeat( + torch.arange(past_feat_dynamic_real.shape[-1], device=device) + + dim_count, + f"dim -> {' '.join(map(str, batch_shape))} (dim past)", + past=self.context_token_length(patch_size), + ) + ) + dim_count += past_feat_dynamic_real.shape[-1] + + prediction_mask.append( + torch.zeros( + batch_shape + + ( + self.context_token_length(patch_size) + * past_feat_dynamic_real.shape[-1], + ), + dtype=torch.bool, + device=device, + ) + ) + + # Concatenate all components + target = torch.cat(target, dim=-2) + observed_mask = torch.cat(observed_mask, dim=-2) + sample_id = torch.cat(sample_id, dim=-1) + time_id = torch.cat(time_id, dim=-1) + variate_id = torch.cat(variate_id, dim=-1) + prediction_mask = torch.cat(prediction_mask, dim=-1) + + return ( + target, + observed_mask, + sample_id, + time_id, + variate_id, + prediction_mask, + ) + + def _format_preds( + self, + num_quantiles: int, + patch_size: int, + preds: Float[torch.Tensor, "batch combine_seq patch"], + target_dim: int, + ) -> Float[torch.Tensor, "batch num_quantiles future_time *tgt"]: + """Format predictions to the expected output shape.""" + start = target_dim * self.context_token_length(patch_size) + end = start + target_dim * self.prediction_token_length(patch_size) + preds = preds[..., start:end, :num_quantiles, :patch_size] + preds = rearrange( + preds, + "... (dim seq) num_quantiles patch -> ... num_quantiles (seq patch) dim", + dim=target_dim, + )[..., : self.prediction_length, :] + return preds.squeeze(-1) diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/moirai2/module/__init__.py b/iotdb-core/ainode/iotdb/ainode/core/model/moirai2/module/__init__.py new file mode 100644 index 000000000000..2a1e720805f2 --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/moirai2/module/__init__.py @@ -0,0 +1,17 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/moirai2/module/attention.py b/iotdb-core/ainode/iotdb/ainode/core/model/moirai2/module/attention.py new file mode 100644 index 000000000000..407b4f4abd99 --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/moirai2/module/attention.py @@ -0,0 +1,366 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +import math +from collections.abc import Callable +from functools import partial +from typing import Optional + +import torch +import torch.nn.functional as F +from einops import rearrange, repeat +from jaxtyping import Bool, Float, Int +from torch import nn + +from iotdb.ainode.core.model.moirai2.module.position import ( + AttentionBias, + QueryKeyProjection, +) + + +def native_scaled_dot_product_attention( + query: Float[torch.Tensor, "*batch group hpg q_len dim"], + key: Float[torch.Tensor, "*batch group hpg kv_len dim"], + value: Float[torch.Tensor, "*batch group hpg kv_len dim"], + attn_mask: Optional[ + Bool[torch.Tensor, "*batch #group #hpg q_len kv_len"] + | Float[torch.Tensor, "*batch #group #hpg q_len kv_len"] + ] = None, + dropout_p: float = 0.0, + scale: Optional[float] = None, +): + scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale + attn_weight = query @ key.transpose(-2, -1) * scale_factor + if attn_mask is not None: + if attn_mask.dtype == torch.bool: + attn_bias = torch.zeros_like(attn_weight) + attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf")) + else: + attn_bias = attn_mask + attn_weight = attn_weight + attn_bias + attn_weight = torch.softmax(attn_weight, dim=-1) + attn_weight = torch.dropout(attn_weight, dropout_p, train=True) + return attn_weight @ value + + +class GroupedQueryAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_groups: int, + bias: bool = True, + norm_layer: Optional[type[nn.Module] | partial[nn.Module]] = nn.LayerNorm, + softmax_scale: Optional[float] = None, + attn_dropout_p: float = 0.0, + var_attn_bias: Optional[Callable[[], AttentionBias]] = None, + time_attn_bias: Optional[Callable[[], AttentionBias]] = None, + var_qk_proj: Optional[Callable[[], QueryKeyProjection]] = None, + time_qk_proj: Optional[Callable[[], QueryKeyProjection]] = None, + ): + super().__init__() + assert num_heads > 0 and dim % num_heads == 0 + assert (num_heads % num_groups == 0) and (num_heads >= num_groups) + + self.num_heads = num_heads + self.num_groups = num_groups + self.head_dim = dim // num_heads + self.heads_per_group = num_heads // num_groups + self.var_attn_bias = var_attn_bias() if var_attn_bias is not None else None + self.time_attn_bias = time_attn_bias() if time_attn_bias is not None else None + self.var_qk_proj = var_qk_proj() if var_qk_proj is not None else None + self.time_qk_proj = time_qk_proj() if time_qk_proj is not None else None + + self.softmax_scale = softmax_scale or 1 / math.sqrt(self.head_dim) + + self.q_proj = nn.Linear(dim, dim, bias=bias) + self.k_proj = nn.Linear(dim, self.head_dim * num_groups, bias=bias) + self.v_proj = nn.Linear(dim, self.head_dim * num_groups, bias=bias) + self.q_norm = ( + norm_layer(self.head_dim) if norm_layer is not None else nn.Identity() + ) + self.k_norm = ( + norm_layer(self.head_dim) if norm_layer is not None else nn.Identity() + ) + self.attn_dropout_p = attn_dropout_p + self.out_proj = nn.Linear(dim, dim, bias=bias) + + def _get_var_id( + self, + query: Float[torch.Tensor, "*batch group hpg q_len dim"], + key: Float[torch.Tensor, 
"*batch group hpg kv_len dim"], + query_var_id: Optional[Int[torch.Tensor, "*batch q_len"]], + kv_var_id: Optional[Int[torch.Tensor, "*batch kv_len"]], + ) -> tuple[ + Optional[Int[torch.Tensor, "*batch #group #hpg q_len"]], + Optional[Int[torch.Tensor, "*batch #group #hpg kv_len"]], + ]: + if self.var_attn_bias is not None or self.var_qk_proj is not None: + if query_var_id is None: + query_var_id = repeat( + torch.zeros((), device=query.device, dtype=torch.long), + f" -> {' '.join(map(str, query.shape[:-4]))} 1 1 {query.shape[-2]}", + ) + else: + query_var_id = rearrange(query_var_id, "... q_len -> ... 1 1 q_len") + + if kv_var_id is None: + kv_var_id = repeat( + torch.zeros((), device=key.device, dtype=torch.long), + f" -> {' '.join(map(str, key.shape[:-4]))} 1 1 {key.shape[-2]}", + ) + else: + kv_var_id = rearrange(kv_var_id, "... kv_len -> ... 1 1 kv_len") + + return query_var_id, kv_var_id + + def _get_time_id( + self, + query: Float[torch.Tensor, "*batch group hpg q_len dim"], + key: Float[torch.Tensor, "*batch group hpg kv_len dim"], + query_time_id: Optional[Int[torch.Tensor, "*batch q_len"]], + kv_time_id: Optional[Int[torch.Tensor, "*batch kv_len"]], + ) -> tuple[ + Optional[Int[torch.Tensor, "*batch 1 1 q_len"]], + Optional[Int[torch.Tensor, "*batch 1 1 kv_len"]], + ]: + if self.time_attn_bias is not None or self.time_qk_proj is not None: + if query_time_id is None: + query_time_id = repeat( + torch.arange( + query.shape[-2], device=query.device, dtype=torch.long + ), + f"q_len -> {' '.join(map(str, query.shape[:-4]))} 1 1 q_len", + ) + else: + query_time_id = rearrange(query_time_id, "... q_len -> ... 1 1 q_len") + + if kv_time_id is None: + kv_time_id = repeat( + torch.arange(key.shape[-2], device=key.device, dtype=torch.long), + f"kv_len -> {' '.join(map(str, key.shape[:-4]))} 1 1 kv_len", + ) + else: + kv_time_id = rearrange(kv_time_id, "... kv_len-> ... 1 1 kv_len") + + return query_time_id, kv_time_id + + def _update_attn_mask( + self, + attn_mask: Optional[Bool[torch.Tensor, "*batch q_len kv_len"]], + query: Float[torch.Tensor, "*batch group hpg q_len dim"], + key: Float[torch.Tensor, "*batch group hpg kv_len dim"], + query_var_id: Optional[Int[torch.Tensor, "*batch 1 1 q_len"]] = None, + kv_var_id: Optional[Int[torch.Tensor, "*batch 1 1 kv_len"]] = None, + query_time_id: Optional[Int[torch.Tensor, "*batch 1 1 q_len"]] = None, + kv_time_id: Optional[Int[torch.Tensor, "*batch 1 1 kv_len"]] = None, + ) -> Optional[ + Bool[torch.Tensor, "*batch #group #hpg q_len kv_len"] + | Float[torch.Tensor, "*batch #group #hpg q_len kv_len"] + ]: + if attn_mask is not None: + attn_mask = rearrange( + attn_mask, + "... q_len kv_len -> ... 
1 1 q_len kv_len", + ) + + attn_bias = 0 + if self.var_attn_bias is not None: + attn_bias = attn_bias + self.var_attn_bias( + query, + key, + query_id=query_var_id, + kv_id=kv_var_id, + ) + + if self.time_attn_bias is not None: + attn_bias = attn_bias + self.time_attn_bias( + query, + key, + query_id=query_time_id, + kv_id=kv_time_id, + ) + + attn_mask = ( + attn_mask + if isinstance(attn_bias, int) + else ( + attn_bias + if attn_mask is None + else attn_bias.masked_fill(attn_mask.logical_not(), float("-inf")) + ) + ) + return attn_mask + + def _qk_proj( + self, + query: Float[torch.Tensor, "*batch group hpg q_len dim"], + key: Float[torch.Tensor, "*batch group hpg kv_len dim"], + query_var_id: Optional[Int[torch.Tensor, "*batch #group #hpg q_len"]], + kv_var_id: Optional[Int[torch.Tensor, "*batch #group #hpg kv_len"]], + query_time_id: Optional[Int[torch.Tensor, "*batch #group #hpg q_len"]], + kv_time_id: Optional[Int[torch.Tensor, "*batch #group #hpg kv_len"]], + ) -> tuple[ + Float[torch.Tensor, "*batch group hpg q_len dim"], + Float[torch.Tensor, "*batch group hpg kv_len dim"], + ]: + if self.var_qk_proj is not None: + query, key = self.var_qk_proj( + query, key, query_id=query_var_id, kv_id=kv_var_id + ) + + if self.time_qk_proj is not None: + query, key = self.time_qk_proj( + query, key, query_id=query_time_id, kv_id=kv_time_id + ) + + return query, key + + def forward( + self, + query: Float[torch.Tensor, "*batch q_len dim"], + key: Float[torch.Tensor, "*batch kv_len dim"], + value: Float[torch.Tensor, "*batch kv_len dim"], + attn_mask: Optional[Bool[torch.Tensor, "*batch q_len kv_len"]] = None, + query_var_id: Optional[Int[torch.Tensor, "*batch q_len"]] = None, + kv_var_id: Optional[Int[torch.Tensor, "*batch kv_len"]] = None, + query_time_id: Optional[Int[torch.Tensor, "*batch q_len"]] = None, + kv_time_id: Optional[Int[torch.Tensor, "*batch kv_len"]] = None, + ) -> Float[torch.Tensor, "*batch q_len dim"]: + query = self.q_proj(query) + key = self.k_proj(key) + value = self.v_proj(value) + + query = self.q_norm( + rearrange( + query, + "... q_len (group hpg dim) -> ... group hpg q_len dim", + group=self.num_groups, + hpg=self.heads_per_group, + ) + ) + key = self.k_norm( + repeat( + key, + "... kv_len (group dim) -> ... group hpg kv_len dim", + group=self.num_groups, + hpg=self.heads_per_group, + ) + ) + value = repeat( + value, + "... kv_len (group dim) -> ... group hpg kv_len dim", + group=self.num_groups, + hpg=self.heads_per_group, + ) + + query_var_id, kv_var_id = self._get_var_id(query, key, query_var_id, kv_var_id) + query_time_id, kv_time_id = self._get_time_id( + query, + key, + query_time_id, + kv_time_id, + ) + + attn_mask = self._update_attn_mask( + attn_mask, + query, + key, + query_var_id=query_var_id, + kv_var_id=kv_var_id, + query_time_id=query_time_id, + kv_time_id=kv_time_id, + ) + + query, key = self._qk_proj( + query, + key, + query_var_id=query_var_id, + kv_var_id=kv_var_id, + query_time_id=query_time_id, + kv_time_id=kv_time_id, + ) + + out = F.scaled_dot_product_attention( + query, + key, + value, + attn_mask=attn_mask, + dropout_p=self.attn_dropout_p, + scale=self.softmax_scale, + ) + out = rearrange(out, "... group hpg q_len dim -> ... 
q_len (group hpg dim)") + return self.out_proj(out) + + +class MultiQueryAttention(GroupedQueryAttention): + def __init__( + self, + dim: int, + num_heads: int, + bias: bool = True, + norm_layer: Optional[type[nn.Module] | partial[nn.Module]] = nn.LayerNorm, + softmax_scale: Optional[float] = None, + attn_dropout_p: float = 0.0, + var_attn_bias: Optional[Callable[[], AttentionBias]] = None, + time_attn_bias: Optional[Callable[[], AttentionBias]] = None, + var_qk_proj: Optional[Callable[[], QueryKeyProjection]] = None, + time_qk_proj: Optional[Callable[[], QueryKeyProjection]] = None, + ): + super().__init__( + dim=dim, + num_heads=num_heads, + num_groups=1, + bias=bias, + norm_layer=norm_layer, + softmax_scale=softmax_scale, + attn_dropout_p=attn_dropout_p, + var_attn_bias=var_attn_bias, + time_attn_bias=time_attn_bias, + var_qk_proj=var_qk_proj, + time_qk_proj=time_qk_proj, + ) + + +class MultiHeadAttention(GroupedQueryAttention): + def __init__( + self, + dim: int, + num_heads: int, + bias: bool = True, + norm_layer: Optional[type[nn.Module] | partial[nn.Module]] = nn.LayerNorm, + softmax_scale: Optional[float] = None, + attn_dropout_p: float = 0.0, + var_attn_bias: Optional[Callable[[], AttentionBias]] = None, + time_attn_bias: Optional[Callable[[], AttentionBias]] = None, + var_qk_proj: Optional[Callable[[], QueryKeyProjection]] = None, + time_qk_proj: Optional[Callable[[], QueryKeyProjection]] = None, + ): + super().__init__( + dim=dim, + num_heads=num_heads, + num_groups=num_heads, + bias=bias, + norm_layer=norm_layer, + softmax_scale=softmax_scale, + attn_dropout_p=attn_dropout_p, + var_attn_bias=var_attn_bias, + time_attn_bias=time_attn_bias, + var_qk_proj=var_qk_proj, + time_qk_proj=time_qk_proj, + ) diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/moirai2/module/ffn.py b/iotdb-core/ainode/iotdb/ainode/core/model/moirai2/module/ffn.py new file mode 100644 index 000000000000..2c737ecc00c8 --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/moirai2/module/ffn.py @@ -0,0 +1,159 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# + +from collections.abc import Callable +from typing import Optional + +import torch +import torch.nn.functional as F +from jaxtyping import Float +from torch import nn + + +class FeedForward(nn.Module): + def __init__( + self, + in_dim: int, + hidden_dim: Optional[int] = None, + out_dim: Optional[int] = None, + activation: Callable[[torch.Tensor], torch.Tensor] = F.gelu, + bias: bool = True, + ffn_dropout_p: float = 0.0, + ): + super().__init__() + hidden_dim = hidden_dim or 4 * in_dim + out_dim = out_dim or in_dim + + self.in_dim = in_dim + self.hidden_dim = hidden_dim + self.out_dim = out_dim + self.bias = bias + self.ffn_dropout_p = ffn_dropout_p + + self.fc1 = nn.Linear(in_dim, hidden_dim, bias=bias) + self.fc2 = nn.Linear(hidden_dim, out_dim, bias=bias) + self.dropout1 = nn.Dropout(ffn_dropout_p) + self.dropout2 = nn.Dropout(ffn_dropout_p) + self.activation = activation + + def forward( + self, + x: Float[torch.Tensor, "... in_dim"], + centroid: Optional[Float[torch.Tensor, "expert in_dim"]] = None, + ) -> Float[torch.Tensor, "... out_dim"]: + x = self._in_proj(x) + return self.dropout2(self.fc2(self.dropout1(x))) + + def _in_proj( + self, x: Float[torch.Tensor, "... in_dim"] + ) -> Float[torch.Tensor, "... out_dim"]: + return self.activation(self.fc1(x)) + + +class GatedLinearUnitFeedForward(FeedForward): + def __init__( + self, + in_dim: int, + hidden_dim: Optional[int] = None, + out_dim: Optional[int] = None, + activation: Callable[[torch.Tensor], torch.Tensor] = F.silu, + bias: bool = True, + ffn_dropout_p: float = 0.0, + ): + super().__init__( + in_dim, + hidden_dim=hidden_dim or self.adjust_hidden_dim(4 * in_dim), + out_dim=out_dim, + activation=activation, + bias=bias, + ffn_dropout_p=ffn_dropout_p, + ) + self.fc_gate = nn.Linear(self.in_dim, self.hidden_dim, bias=self.bias) + + @staticmethod + def adjust_hidden_dim(dim): + return (int(dim * 2 / 3) + 7) // 8 * 8 + + def _in_proj( + self, x: Float[torch.Tensor, "... in_dim"] + ) -> Float[torch.Tensor, "... out_dim"]: + return self.activation(self.fc_gate(x)) * self.fc1(x) + + +class MoEFeedForward(nn.Module): + def __init__( + self, + num_experts: int, + num_experts_per_token: int, + in_dim: int, + hidden_dim: Optional[int] = None, + out_dim: Optional[int] = None, + activation: Callable[[torch.Tensor], torch.Tensor] = F.silu, + bias: bool = True, + ffn_dropout_p: float = 0.0, + ): + super().__init__() + self.num_experts = num_experts + self.num_experts_per_token = num_experts_per_token + + self.experts = nn.ModuleList( + [ + GatedLinearUnitFeedForward( + in_dim=in_dim, + hidden_dim=hidden_dim, + out_dim=out_dim, + activation=activation, + bias=bias, + ffn_dropout_p=ffn_dropout_p, + ) + for _ in range(num_experts) + ] + ) + + def forward( + self, + x: Float[torch.Tensor, "... in_dim"], + centroid: Optional[Float[torch.Tensor, "expert in_dim"]] = None, + ) -> Float[torch.Tensor, "... 
dim"]: + x_squashed = x.view(-1, x.shape[-1]) + + centroid = centroid.to(x.device).type_as(x) + if len(x.shape) > 3: + x_temp = x.view(-1, x.shape[-2], x.shape[-1]) + else: + x_temp = x + centroid = centroid.unsqueeze(0).repeat(x_temp.shape[0], 1, 1) + cdist = torch.cdist(x_temp, centroid) + gate_logits = cdist.view(-1, cdist.shape[-1]) + + weights, selected_experts = torch.topk(gate_logits, self.num_experts_per_token) + weights = nn.functional.softmax( + weights, + dim=1, + dtype=torch.float, + ).type_as(x) + + results = torch.zeros_like(x_squashed) + for i, expert in enumerate(self.experts): + batch_idx, nth_expert = torch.where(selected_experts == i) + results[batch_idx] += weights[batch_idx, nth_expert, None] * expert( + x_squashed[batch_idx] + ) + + results = results.view_as(x) + return results diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/moirai2/module/norm.py b/iotdb-core/ainode/iotdb/ainode/core/model/moirai2/module/norm.py new file mode 100644 index 000000000000..e7ef8eb6932d --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/moirai2/module/norm.py @@ -0,0 +1,62 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +from typing import Optional + +import torch +from jaxtyping import Float +from torch import nn + + +class RMSNorm(nn.Module): + def __init__( + self, + normalized_shape: int | list[int] | torch.Size, + eps: float = 1e-5, + weight: bool = True, + dtype: Optional[torch.dtype] = None, + ): + super().__init__() + if isinstance(normalized_shape, int): + normalized_shape = (normalized_shape,) + + self.normalized_shape = normalized_shape + self.eps = eps + self.mean_dim = tuple(range(-len(normalized_shape), 0)) + + if weight: + self.weight = torch.nn.Parameter(torch.ones(normalized_shape, dtype=dtype)) + else: + self.register_parameter("weight", None) + + def forward( + self, x: Float[torch.Tensor, "*batch normalized_shape"] + ) -> Float[torch.Tensor, "*batch normalized_shape"]: + output = x * torch.rsqrt( + x.pow(2).mean(dim=self.mean_dim, keepdim=True) + self.eps + ) + if self.weight is not None: + return output * self.weight + return output + + def extra_repr(self) -> str: + return ( + f"normalized_shape={self.normalized_shape}, " + f"eps={self.eps}, " + f"weight={self.weight is not None}" + ) diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/moirai2/module/packed_scaler.py b/iotdb-core/ainode/iotdb/ainode/core/model/moirai2/module/packed_scaler.py new file mode 100644 index 000000000000..832145acef36 --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/moirai2/module/packed_scaler.py @@ -0,0 +1,125 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +from typing import Optional + +import torch +from einops import reduce +from jaxtyping import Bool, Float, Int +from torch import nn + +from iotdb.ainode.core.model.moirai2.common.torch_util import safe_div + + +class PackedScaler(nn.Module): + def forward( + self, + target: Float[torch.Tensor, "*batch seq_len #dim"], + observed_mask: Bool[torch.Tensor, "*batch seq_len #dim"] = None, + sample_id: Int[torch.Tensor, "*batch seq_len"] = None, + variate_id: Optional[Int[torch.Tensor, "*batch seq_len"]] = None, + ): + if observed_mask is None: + observed_mask = torch.ones_like(target, dtype=torch.bool) + if sample_id is None: + sample_id = torch.zeros( + target.shape[:-1], dtype=torch.long, device=target.device + ) + if variate_id is None: + variate_id = torch.zeros( + target.shape[:-1], dtype=torch.long, device=target.device + ) + + loc, scale = self._get_loc_scale( + target.double(), observed_mask, sample_id, variate_id + ) + return loc.float(), scale.float() + + def _get_loc_scale( + self, + target: Float[torch.Tensor, "*batch seq_len #dim"], + observed_mask: Bool[torch.Tensor, "*batch seq_len #dim"], + sample_id: Int[torch.Tensor, "*batch seq_len"], + variate_id: Int[torch.Tensor, "*batch seq_len"], + ) -> tuple[ + Float[torch.Tensor, "*batch seq_len #dim"], + Float[torch.Tensor, "*batch seq_len #dim"], + ]: + raise NotImplementedError + + +class PackedNOPScaler(PackedScaler): + def _get_loc_scale( + self, + target: Float[torch.Tensor, "*batch seq_len #dim"], + observed_mask: Bool[torch.Tensor, "*batch seq_len #dim"], + sample_id: Int[torch.Tensor, "*batch seq_len"], + variate_id: Int[torch.Tensor, "*batch seq_len"], + ) -> tuple[ + Float[torch.Tensor, "*batch 1 #dim"], Float[torch.Tensor, "*batch 1 #dim"] + ]: + loc = torch.zeros_like(target, dtype=target.dtype) + scale = torch.ones_like(target, dtype=target.dtype) + return loc, scale + + +class PackedStdScaler(PackedScaler): + def __init__(self, correction: int = 1, minimum_scale: float = 1e-5): + super().__init__() + self.correction = correction + self.minimum_scale = minimum_scale + + def _get_loc_scale( + self, + target: Float[torch.Tensor, "*batch seq_len #dim"], + observed_mask: Bool[torch.Tensor, "*batch seq_len #dim"], + sample_id: Int[torch.Tensor, "*batch seq_len"], + variate_id: Int[torch.Tensor, "*batch seq_len"], + ) -> tuple[ + Float[torch.Tensor, "*batch 1 #dim"], Float[torch.Tensor, "*batch 1 #dim"] + ]: + id_mask = torch.logical_and( + torch.eq(sample_id.unsqueeze(-1), sample_id.unsqueeze(-2)), + torch.eq(variate_id.unsqueeze(-1), variate_id.unsqueeze(-2)), + ) + tobs = reduce( + id_mask * reduce(observed_mask, "... seq dim -> ... 1 seq", "sum"), + "... seq1 seq2 -> ... seq1 1", + "sum", + ) + loc = reduce( + id_mask * reduce(target * observed_mask, "... seq dim -> ... 1 seq", "sum"), + "... seq1 seq2 -> ... 
seq1 1", + "sum", + ) + loc = safe_div(loc, tobs) + var = reduce( + id_mask + * reduce( + ((target - loc) ** 2) * observed_mask, + "... seq dim -> ... 1 seq", + "sum", + ), + "... seq1 seq2 -> ... seq1 1", + "sum", + ) + var = safe_div(var, (tobs - self.correction)) + scale = torch.sqrt(var + self.minimum_scale) + loc[sample_id == 0] = 0 + scale[sample_id == 0] = 1 + return loc, scale diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/moirai2/module/position/__init__.py b/iotdb-core/ainode/iotdb/ainode/core/model/moirai2/module/position/__init__.py new file mode 100644 index 000000000000..4d23d3fd41b9 --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/moirai2/module/position/__init__.py @@ -0,0 +1,46 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +from .additive import LearnedEmbedding, SinusoidalPositionEncoding +from .attn_bias import ( + AttentionBias, + BinaryAttentionBias, + LinearAttentionBias, + RelativeAttentionBias, +) +from .attn_projection import ( + IdentityProjection, + LearnedProjection, + Projection, + QueryKeyProjection, + RotaryProjection, +) + +__all__ = [ + "AttentionBias", + "IdentityProjection", + "RelativeAttentionBias", + "BinaryAttentionBias", + "LearnedEmbedding", + "LearnedProjection", + "LinearAttentionBias", + "Projection", + "QueryKeyProjection", + "RotaryProjection", + "SinusoidalPositionEncoding", +] diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/moirai2/module/position/additive.py b/iotdb-core/ainode/iotdb/ainode/core/model/moirai2/module/position/additive.py new file mode 100644 index 000000000000..8bf1ba723afb --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/moirai2/module/position/additive.py @@ -0,0 +1,81 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +import math + +import torch +from jaxtyping import Float, Int +from torch import nn + + +class SinusoidalPositionEncoding(nn.Module): + def __init__( + self, + *, + width: int, + max_len: int, + normalize: bool = True, + ): + """ + Construct a sinusoidal positional embedding module. 
+ + :param width: + Width of the embedding. + :param max_len: + Maximum length of the embedding. + :param normalize: + Perform L2 normalization of the embedding. + """ + super().__init__() + + position = torch.arange(max_len).unsqueeze(1) + div_term = torch.exp(torch.arange(0, width, 2) * (-math.log(10000.0) / width)) + + pe = torch.zeros(max_len, width) + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + + if normalize: + l2 = torch.linalg.vector_norm(pe, dim=-1) + pe /= l2.unsqueeze(-1) + + self.register_buffer("pe", pe, persistent=False) + + def forward( + self, pos_id: Int[torch.Tensor, "*batch length"] + ) -> Float[torch.Tensor, "*batch length dim"]: + return self.pe[pos_id] + + +class LearnedEmbedding(nn.Module): + def __init__( + self, + *, + width: int, + max_len: int, + ): + super().__init__() + self.pe = nn.Embedding( + max_len, + width, + ) + + def forward( + self, pos_id: Int[torch.Tensor, "*batch length"] + ) -> Float[torch.Tensor, "*batch length dim"]: + return self.pe(pos_id) diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/moirai2/module/position/attn_bias.py b/iotdb-core/ainode/iotdb/ainode/core/model/moirai2/module/position/attn_bias.py new file mode 100644 index 000000000000..49b8139b7c53 --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/moirai2/module/position/attn_bias.py @@ -0,0 +1,113 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +import abc + +import torch +from einops import rearrange +from jaxtyping import Float, Int +from torch import nn + + +class AttentionBias(nn.Module, abc.ABC): + def __init__( + self, + dim: int, + num_heads: int, + num_groups: int, + ): + super().__init__() + assert num_heads > 0 and dim % num_heads == 0 + assert (num_heads % num_groups == 0) and (num_heads >= num_groups) + + self.num_heads = num_heads + self.num_groups = num_groups + self.heads_per_group = num_heads // num_groups + self.head_dim = dim // num_heads + + @abc.abstractmethod + def forward( + self, + query: Float[torch.Tensor, "*batch group hpg q_len dim"], + key: Float[torch.Tensor, "*batch group hpg kv_len dim"], + query_id: Int[torch.Tensor, "*batch 1 1 q_len"], + kv_id: Int[torch.Tensor, "*batch 1 1 kv_len"], + ) -> Float[torch.Tensor, "*batch #group #hpg q_len kv_len"]: ... 
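+# The bias modules below are consumed by GroupedQueryAttention._update_attn_mask,
+# which adds their output to the attention mask/scores. BinaryAttentionBias picks one
+# of two learned per-head offsets depending on whether the query id equals the key id
+# (e.g. same vs. different variate), while LinearAttentionBias scales the signed
+# distance between query and key ids by fixed per-head slopes.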
+ + +class RelativeAttentionBias(AttentionBias): + def __init__(self, num_buckets: int, dim: int, num_heads: int, num_groups: int): + super().__init__(dim, num_heads, num_groups) + self.emb = nn.Embedding( + num_embeddings=num_buckets, embedding_dim=self.num_heads + ) + + def forward( + self, + query: Float[torch.Tensor, "*batch group hpg q_len dim"], + key: Float[torch.Tensor, "*batch group hpg kv_len dim"], + query_id: Int[torch.Tensor, "*batch 1 1 q_len"], + kv_id: Int[torch.Tensor, "*batch 1 1 kv_len"], + ) -> Float[torch.Tensor, "*batch #group #hpg q_len kv_len"]: + raise NotImplementedError + + +class BinaryAttentionBias(AttentionBias): + def __init__(self, dim: int, num_heads: int, num_groups: int): + super().__init__(dim, num_heads, num_groups) + self.emb = nn.Embedding(num_embeddings=2, embedding_dim=self.num_heads) + + def forward( + self, + query: Float[torch.Tensor, "*batch group hpg q_len dim"], + key: Float[torch.Tensor, "*batch group hpg kv_len dim"], + query_id: Int[torch.Tensor, "*batch 1 1 q_len"], + kv_id: Int[torch.Tensor, "*batch 1 1 kv_len"], + ) -> Float[torch.Tensor, "*batch #group #hpg q_len kv_len"]: + ind = torch.eq(query_id.unsqueeze(-1), kv_id.unsqueeze(-2)) + weight = rearrange(self.emb.weight, "two num_heads -> two num_heads 1 1") + bias = rearrange( # try to avoid advanced indexing + ~ind * weight[:1] + ind * weight[1:], + "... 1 (group hpg) q_len kv_len -> ... group hpg q_len kv_len", + group=self.num_groups, + hpg=self.heads_per_group, + ) + return bias + + +class LinearAttentionBias(AttentionBias): + def __init__(self, dim: int, num_heads: int, num_groups: int): + super().__init__(dim, num_heads, num_groups) + m = 0.5 ** ((1 + torch.arange(self.num_heads)) * (8 / self.num_heads)) + m = rearrange( + m, + "(group hpg) -> group hpg 1 1", + group=self.num_groups, + hpg=self.heads_per_group, + ) + self.register_buffer("m", m) + + def forward( + self, + query: Float[torch.Tensor, "*batch group hpg q_len dim"], + key: Float[torch.Tensor, "*batch group hpg kv_len dim"], + query_id: Int[torch.Tensor, "*batch 1 1 q_len"], + kv_id: Int[torch.Tensor, "*batch 1 1 kv_len"], + ) -> Float[torch.Tensor, "*batch #group #hpg q_len kv_len"]: + ind = kv_id.unsqueeze(-2) - query_id.unsqueeze(-1) + return self.m * ind diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/moirai2/module/position/attn_projection.py b/iotdb-core/ainode/iotdb/ainode/core/model/moirai2/module/position/attn_projection.py new file mode 100644 index 000000000000..e5bfb63c6522 --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/moirai2/module/position/attn_projection.py @@ -0,0 +1,215 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# + +import abc +import math +from functools import cached_property +from typing import Any, Optional + +import torch +from einops import einsum, rearrange, repeat +from jaxtyping import Float, Int +from torch import nn + + +class Projection(nn.Module, abc.ABC): + def __init__(self, proj_width: int, num_heads: int, num_groups: int, **kwargs: Any): + super().__init__() + self.proj_width = proj_width + self.num_heads = num_heads + self.num_groups = num_groups + self.heads_per_group = num_heads // num_groups + + @abc.abstractmethod + def forward( + self, + x: Float[torch.Tensor, "*batch group hpg seq dim"], + seq_id: Optional[Int[torch.Tensor, "*batch #group #hpg seq"]], + ) -> Float[torch.Tensor, "*batch group hpg seq dim"]: ... + + +class IdentityProjection(Projection): + def __init__(self, *, proj_width: int, num_heads: int, num_groups: int, **kwargs): + super().__init__(proj_width, num_heads, num_groups) + + def forward( + self, + x: Float[torch.Tensor, "*batch group hpg seq dim"], + seq_id: Optional[Int[torch.Tensor, "*batch #group #hpg seq"]] = None, + ) -> Float[torch.Tensor, "*batch group hpg seq dim"]: + return x + + +class RotaryProjection(Projection): + def __init__( + self, + *, + proj_width: int, + num_heads: int, + num_groups: int, + max_len: int = 512, + base: int = 10000, + ): + super().__init__(proj_width, num_heads, num_groups) + assert ( + self.proj_width % 2 == 0 + ), f"proj_width must be even, got {self.proj_width}" + self.register_buffer( + "theta", + 1.0 + / torch.pow( + base, + torch.arange(0, self.proj_width, 2, dtype=torch.float) + / self.proj_width, + ), + persistent=False, + ) + self.register_buffer("cos", None, persistent=False) + self.register_buffer("sin", None, persistent=False) + self._init_freq(max_len=max_len) + + def _init_freq(self, max_len: int): + if self.cos is None or self.cos.size(-2) < max_len: + position = torch.arange( + max_len, device=self.theta.device, dtype=self.theta.dtype + ) + m_theta = einsum(position, self.theta, "length, width -> length width") + m_theta = repeat(m_theta, "length width -> length (width 2)") + self.register_buffer("cos", torch.cos(m_theta), persistent=False) + self.register_buffer("sin", torch.sin(m_theta), persistent=False) + + @staticmethod + def _rotate(x: Float[torch.Tensor, "... dim"]) -> Float[torch.Tensor, "... dim"]: + x1, x2 = rearrange(x, "... (dim r) -> r ... dim", r=2) + return rearrange([-x2, x1], "r ... dim -> ... (dim r)", r=2) # noqa + + def forward( + self, + x: Float[torch.Tensor, "*batch group hpg seq dim"], + seq_id: Optional[Int[torch.Tensor, "*batch #group #hpg seq"]], + ) -> Float[torch.Tensor, "*batch group hpg seq dim"]: + self._init_freq(max_len=seq_id.max() + 1) + rot_cos = self.cos[seq_id] + rot_sin = self.sin[seq_id] + return rot_cos * x + rot_sin * self._rotate(x) + + +class LearnedProjection(Projection): + def __init__( + self, + *, + proj_width: int, + num_heads: int, + num_groups: int, + max_len: int = 512, + ): + super().__init__(proj_width, num_heads, num_groups) + self.max_len = max_len + self.weight = nn.Parameter( + torch.empty((max_len, self.proj_width, self.proj_width)) + ) + self.reset_parameters() + + def reset_parameters(self): + for idx in range(self.max_len): + nn.init.kaiming_uniform_(self.weight[idx], a=math.sqrt(5)) + + def forward( + self, + x: Float[torch.Tensor, "*batch group hpg seq dim"], + seq_id: Optional[Int[torch.Tensor, "*batch #group #hpg seq"]], + ) -> Float[torch.Tensor, "*batch group hpg seq dim"]: + weight = self.weight[seq_id] + return einsum(weight, x, "... 
out inp, ... inp -> ... out") + + +class QueryKeyProjection(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_groups: int, + proj_layer: type[Projection], + kwargs: Optional[dict[str, Any]] = None, + key_proj_layer: Optional[type[Projection]] = None, + key_kwargs: Optional[dict[str, Any]] = None, + partial_factor: Optional[tuple[float, float]] = None, + ): + super().__init__() + if partial_factor is not None: + assert ( + 0.0 <= partial_factor[0] < partial_factor[1] <= 1.0 + ), f"got {partial_factor[0]}, {partial_factor[1]}" + assert num_heads > 0 and dim % num_heads == 0 + assert (num_heads % num_groups == 0) and (num_heads >= num_groups) + + self.head_dim = dim // num_heads + self.partial_factor = partial_factor + self.query_proj = proj_layer( + proj_width=self.proj_width, + num_heads=num_heads, + num_groups=num_groups, + **(kwargs or {}), + ) + if key_proj_layer is None: + self.key_proj = self.query_proj + else: + self.key_proj = key_proj_layer( + proj_width=self.proj_width, + num_heads=num_heads, + num_groups=num_groups, + **(key_kwargs or {}), + ) + + @cached_property + def proj_width(self) -> int: + if self.partial_factor is None: + return self.head_dim + return int(self.head_dim * (self.partial_factor[1] - self.partial_factor[0])) + + @cached_property + def split_sizes(self) -> tuple[int, int, int]: + if self.partial_factor is None: + return 0, self.head_dim, 0 + return ( + int(self.partial_factor[0] * self.head_dim), + self.proj_width, + int((1.0 - self.partial_factor[1]) * self.head_dim), + ) + + def forward( + self, + query: Float[torch.Tensor, "*batch group hpg q_len dim"], + key: Float[torch.Tensor, "*batch group hpg kv_len dim"], + query_id: Optional[Int[torch.Tensor, "*batch #group #hpg q_len"]], + kv_id: Optional[Int[torch.Tensor, "*batch #group #hpg kv_len"]], + ) -> tuple[ + Float[torch.Tensor, "*batch group hpg seq dim"], + Float[torch.Tensor, "*batch group hpg seq dim"], + ]: + if self.partial_factor is not None: + queries = list(query.split(self.split_sizes, dim=-1)) + keys = list(key.split(self.split_sizes, dim=-1)) + queries[1] = self.query_proj(queries[1], seq_id=query_id) + keys[1] = self.key_proj(keys[1], seq_id=kv_id) + query = torch.cat(queries, dim=-1) + key = torch.cat(keys, dim=-1) + else: + query = self.query_proj(query, seq_id=query_id) + key = self.key_proj(key, seq_id=kv_id) + return query, key diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/moirai2/module/transformer.py b/iotdb-core/ainode/iotdb/ainode/core/model/moirai2/module/transformer.py new file mode 100644 index 000000000000..3f12468128c2 --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/moirai2/module/transformer.py @@ -0,0 +1,246 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# + +from collections.abc import Callable +from functools import partial +from typing import Optional + +import torch +import torch.nn.functional as F +from jaxtyping import Bool, Float, Int +from torch import nn + +from iotdb.ainode.core.model.moirai2.module.attention import GroupedQueryAttention +from iotdb.ainode.core.model.moirai2.module.ffn import ( + FeedForward, + GatedLinearUnitFeedForward, + MoEFeedForward, +) +from iotdb.ainode.core.model.moirai2.module.position import ( + AttentionBias, + QueryKeyProjection, +) + + +class TransformerEncoderLayer(nn.Module): + def __init__( + self, + self_attn: GroupedQueryAttention, + ffn: FeedForward, + norm1: Optional[nn.Module], + norm2: Optional[nn.Module], + post_attn_dropout_p: float = 0.0, + pre_norm: bool = True, + ): + super().__init__() + self.pre_norm = pre_norm + self.dropout_p = post_attn_dropout_p + + self.self_attn = self_attn + self.ffn = ffn + self.norm1 = norm1 or nn.Identity() + self.norm2 = norm2 or nn.Identity() + self.dropout = nn.Dropout(post_attn_dropout_p) + + def forward( + self, + x: Float[torch.Tensor, "*batch time_len dim"], + attn_mask: Optional[Bool[torch.Tensor, "*batch time_len time_len"]] = None, + var_id: Optional[Int[torch.Tensor, "*batch time_len"]] = None, + time_id: Optional[Int[torch.Tensor, "*batch time_len"]] = None, + centroid: Optional[Float[torch.Tensor, "expert dim"]] = None, + ) -> Float[torch.Tensor, "*batch time_len dim"]: + if self.pre_norm: + x = x + self._sa_block( + self.norm1(x), attn_mask, var_id=var_id, time_id=time_id + ) + x = x + self.ffn(self.norm2(x), centroid=centroid) + else: + x = self.norm1( + x + self._sa_block(x, attn_mask, var_id=var_id, time_id=time_id) + ) + x = self.norm2(x + self.ffn(x, centroid=centroid)) + + return x + + def _sa_block( + self, + x: Float[torch.Tensor, "*batch time_len dim"], + attn_mask: Optional[Bool[torch.Tensor, "*batch time_len time_len"]], + var_id: Optional[Int[torch.Tensor, "*batch time_len"]] = None, + time_id: Optional[Int[torch.Tensor, "*batch time_len"]] = None, + ) -> Float[torch.Tensor, "*batch time_len dim"]: + x = self.self_attn( + x, + x, + x, + attn_mask=attn_mask, + query_var_id=var_id, + kv_var_id=var_id, + query_time_id=time_id, + kv_time_id=time_id, + ) + return self.dropout(x) + + +class TransformerEncoder(nn.Module): + def __init__( + self, + d_model: int, + num_layers: int, + num_heads: Optional[int] = None, + num_groups: Optional[int] = None, + pre_norm: bool = True, + attn_dropout_p: float = 0.0, + dropout_p: float = 0.0, + norm_layer: Optional[Callable[[int], nn.Module]] = nn.LayerNorm, + activation: Callable[[torch.Tensor], torch.Tensor] = F.silu, + use_moe: bool = False, + use_glu: bool = True, + use_qk_norm: bool = True, + var_attn_bias_layer: Optional[Callable[[int, int, int], AttentionBias]] = None, + time_attn_bias_layer: Optional[Callable[[int, int, int], AttentionBias]] = None, + var_qk_proj_layer: Optional[ + Callable[[int, int, int], QueryKeyProjection] + ] = None, + time_qk_proj_layer: Optional[ + Callable[[int, int, int], QueryKeyProjection] + ] = None, + shared_var_attn_bias: bool = False, + shared_time_attn_bias: bool = False, + shared_var_qk_proj: bool = False, + shared_time_qk_proj: bool = False, + d_ff: Optional[int] = None, + ): + super().__init__() + self.use_moe = use_moe + num_heads = num_heads or d_model // 64 + num_groups = num_groups or num_heads # defaults to mha + + var_attn_bias = self.get_layer( + d_model, + num_heads, + num_groups, + var_attn_bias_layer, + shared_var_attn_bias, + ) + time_attn_bias = 
self.get_layer( + d_model, + num_heads, + num_groups, + time_attn_bias_layer, + shared_time_attn_bias, + ) + var_qk_proj = self.get_layer( + d_model, num_heads, num_groups, var_qk_proj_layer, shared_var_qk_proj + ) + time_qk_proj = self.get_layer( + d_model, num_heads, num_groups, time_qk_proj_layer, shared_time_qk_proj + ) + + get_self_attn = partial( + GroupedQueryAttention, + dim=d_model, + num_heads=num_heads, + num_groups=num_groups, + bias=False, + norm_layer=norm_layer if use_qk_norm else None, + softmax_scale=None, + attn_dropout_p=attn_dropout_p, + var_attn_bias=var_attn_bias, + time_attn_bias=time_attn_bias, + var_qk_proj=var_qk_proj, + time_qk_proj=time_qk_proj, + ) + if not use_moe: + get_ffn = partial( + GatedLinearUnitFeedForward if use_glu else FeedForward, + in_dim=d_model, + hidden_dim=d_ff, + out_dim=None, + activation=activation, + bias=False, + ffn_dropout_p=dropout_p, + ) + else: + get_ffn = partial( + MoEFeedForward, + num_experts=32, + num_experts_per_token=2, + in_dim=d_model, + hidden_dim=d_ff, + out_dim=None, + activation=activation, + bias=False, + ffn_dropout_p=dropout_p, + ) + self.register_buffer( + "centroid", torch.empty(num_layers, 32, d_model, dtype=torch.float64) + ) + get_encoder_layer_norm = partial(norm_layer, d_model) + + self.layers = nn.ModuleList( + [ + TransformerEncoderLayer( + self_attn=get_self_attn(), + ffn=get_ffn(), + norm1=get_encoder_layer_norm(), + norm2=get_encoder_layer_norm(), + pre_norm=pre_norm, + post_attn_dropout_p=dropout_p, + ) + for _ in range(num_layers) + ] + ) + self.norm = norm_layer(d_model) + + @staticmethod + def get_layer( + dim: int, + num_heads: int, + num_groups: int, + layer: Callable, + shared_layer: bool, + ) -> Optional[Callable[[], nn.Module]]: + if layer is None: + return None + if shared_layer: + module = layer(dim=dim, num_heads=num_heads, num_groups=num_groups) + return lambda: module + return partial(layer, dim=dim, num_heads=num_heads, num_groups=num_groups) + + def forward( + self, + x: Float[torch.Tensor, "*batch time_len dim"], + attn_mask: Optional[Bool[torch.Tensor, "*batch time_len time_len"]] = None, + var_id: Optional[Int[torch.Tensor, "*batch time_len"]] = None, + time_id: Optional[Int[torch.Tensor, "*batch time_len"]] = None, + ) -> Float[torch.Tensor, "*batch time_len dim"]: + if self.use_moe: + for idx, layer in enumerate(self.layers): + x = layer( + x, + attn_mask, + var_id=var_id, + time_id=time_id, + centroid=self.centroid[idx], + ) + else: + for layer in self.layers: + x = layer(x, attn_mask, var_id=var_id, time_id=time_id) + return self.norm(x) diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/moirai2/module/ts_embed.py b/iotdb-core/ainode/iotdb/ainode/core/model/moirai2/module/ts_embed.py new file mode 100644 index 000000000000..5cddbaa1dae2 --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/moirai2/module/ts_embed.py @@ -0,0 +1,294 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +import math +from typing import Optional + +import torch +from einops import einsum, rearrange +from jaxtyping import Float, Int +from torch import nn + +from iotdb.ainode.core.model.moirai2.common.torch_util import size_to_mask + + +def fs2idx( + feat_size: Int[torch.Tensor, "*batch"], feat_sizes: Int[torch.Tensor, "num_feats"] +) -> Int[torch.Tensor, "*batch"]: + return ( + (rearrange(feat_size, "... -> ... 1") == feat_sizes) + .to(torch.long) + .argmax(dim=-1) + ) + + +class MultiInSizeLinear(nn.Module): + def __init__( + self, + in_features_ls: tuple[int, ...], + out_features: int, + bias: bool = True, + dtype: Optional[torch.dtype] = None, + ): + super().__init__() + self.in_features_ls = in_features_ls + self.out_features = out_features + + self.weight = nn.Parameter( + torch.empty( + (len(in_features_ls), out_features, max(in_features_ls)), dtype=dtype + ) + ) + + if bias: + self.bias = nn.Parameter( + torch.empty((len(in_features_ls), out_features), dtype=dtype) + ) + else: + self.register_parameter("bias", None) + + self.reset_parameters() + + self.register_buffer( + "mask", + rearrange( + size_to_mask(max(in_features_ls), torch.as_tensor(in_features_ls)), + "num_feats max_feat -> num_feats 1 max_feat", + ), + persistent=False, + ) + self.register_buffer( + "in_features_buffer", + torch.tensor(in_features_ls), + persistent=False, + ) + + def reset_parameters(self): + for idx, feat_size in enumerate(self.in_features_ls): + nn.init.kaiming_uniform_(self.weight[idx, :, :feat_size], a=math.sqrt(5)) + nn.init.zeros_(self.weight[idx, :, feat_size:]) + if self.bias is not None: + fan_in, _ = nn.init._calculate_fan_in_and_fan_out( + self.weight[idx, :, :feat_size] + ) + bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 + nn.init.uniform_(self.bias[idx], -bound, bound) + + def forward( + self, + x: Float[torch.Tensor, "*batch max_feat"], + in_feat_size: Int[torch.Tensor, "*batch"], + ) -> Float[torch.Tensor, "*batch out_feat"]: + out = 0 + for idx, feat_size in enumerate(self.in_features_ls): + weight = self.weight[idx] * self.mask[idx] + bias = self.bias[idx] if self.bias is not None else 0 + out = out + ( + torch.eq(in_feat_size, feat_size).unsqueeze(-1) + * (einsum(weight, x, "out inp, ... inp -> ... 
out") + bias) + ) + return out + + def extra_repr(self) -> str: + return ( + f"in_features_ls={self.in_features_ls}, " + f"out_features={self.out_features}, " + f"bias={self.bias is not None}, " + f"dtype={self.weight.dtype}" + ) + + +class FeatLinear(nn.Module): + def __init__( + self, + in_features_ls: tuple[int, ...], + out_features: int, + bias: bool = True, + dtype: Optional[torch.dtype] = None, + ): + super().__init__() + self.in_features_ls = in_features_ls + self.out_features = out_features + + self.weight = nn.Parameter( + torch.empty((len(in_features_ls), out_features, out_features), dtype=dtype) + ) + + if bias: + self.bias = nn.Parameter( + torch.empty((len(in_features_ls), out_features), dtype=dtype) + ) + else: + self.register_parameter("bias", None) + + self.reset_parameters() + + self.register_buffer( + "in_features_buffer", + torch.tensor(in_features_ls), + persistent=False, + ) + + def reset_parameters(self): + nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) + if self.bias is not None: + fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight) + bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 + nn.init.uniform_(self.bias, -bound, bound) + + def forward( + self, + x: Float[torch.Tensor, "*batch max_feat"], + in_feat_size: Int[torch.Tensor, "*batch"], + ) -> Float[torch.Tensor, "*batch out_feat"]: + out = 0 + for idx, feat_size in enumerate(self.in_features_ls): + weight = self.weight[idx] + bias = self.bias[idx] if self.bias is not None else 0 + out = out + ( + torch.eq(in_feat_size, feat_size).unsqueeze(-1) + * (einsum(weight, x, "out inp, ... inp -> ... out") + bias) + ) + return out + + def extra_repr(self) -> str: + return ( + f"in_features_ls={self.in_features_ls}, " + f"out_features={self.out_features}, " + f"bias={self.bias is not None}, " + f"dtype={self.weight.dtype}" + ) + + +class MultiOutSizeLinear(nn.Module): + def __init__( + self, + in_features: int, + out_features_ls: tuple[int, ...], + dim: int = 1, + bias: bool = True, + dtype: Optional[torch.dtype] = None, + ): + super().__init__() + self.in_features = in_features + self.out_features_ls = out_features_ls + self.dim = dim + + self.weight = nn.Parameter( + torch.empty( + (len(out_features_ls), max(out_features_ls), in_features), dtype=dtype + ) + ) + + if bias: + self.bias = nn.Parameter( + torch.empty((len(out_features_ls), max(out_features_ls)), dtype=dtype) + ) + else: + self.register_parameter("bias", None) + + self.reset_parameters() + + self.register_buffer( + "mask", + rearrange( + size_to_mask(max(out_features_ls), torch.as_tensor(out_features_ls)), + "num_feats max_feat -> num_feats max_feat 1", + ), + persistent=False, + ) + self.register_buffer( + "out_features_buffer", + torch.tensor(out_features_ls), + persistent=False, + ) + + def reset_parameters(self): + for idx, feat_size in enumerate(self.out_features_ls): + nn.init.kaiming_uniform_(self.weight[idx, :feat_size], a=math.sqrt(5)) + nn.init.zeros_(self.weight[idx, feat_size:]) + if self.bias is not None: + fan_in, _ = nn.init._calculate_fan_in_and_fan_out( + self.weight[idx, :feat_size] + ) + bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 + nn.init.uniform_(self.bias[idx, :feat_size], -bound, bound) + nn.init.zeros_(self.bias[idx, feat_size:]) + + def forward( + self, + x: Float[torch.Tensor, "*batch in_feat"], + out_feat_size: Int[torch.Tensor, "*batch"], + ) -> Float[torch.Tensor, "*batch max_feat"]: + out = 0 + for idx, feat_size in enumerate(self.out_features_ls): + weight = self.weight[idx] * self.mask[idx] + bias = 
self.bias[idx] if self.bias is not None else 0 + out = out + ( + torch.eq(out_feat_size, feat_size // self.dim).unsqueeze(-1) + * (einsum(weight, x, "out inp, ... inp -> ... out") + bias) + ) + return out + + def extra_repr(self) -> str: + return ( + f"in_features={self.in_features}, " + f"out_features_ls={self.out_features_ls}, " + f"bias={self.bias is not None}, " + f"dtype={self.weight.dtype}" + ) + + +class ResidualBlock(nn.Module): + def __init__( + self, + input_dims, + hidden_dims, + output_dims, + ): + super(ResidualBlock, self).__init__() + self.input_dims = input_dims + self.hidden_dims = hidden_dims + self.output_dims = output_dims + + # Hidden Layer + self.hidden_layer = nn.Linear(input_dims, hidden_dims) + self.silu = nn.SiLU() + + # Output Layer + self.output_layer = nn.Linear(hidden_dims, output_dims) + # Residual Layer + self.residual_layer = nn.Linear(input_dims, output_dims) + + self.reset_parameters() + + def reset_parameters(self): + for m in self.children(): + if isinstance(m, nn.Linear): + nn.init.kaiming_uniform_(m.weight, a=math.sqrt(5)) + if m.bias is not None: + fan_in, _ = nn.init._calculate_fan_in_and_fan_out(m.weight) + bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 + nn.init.uniform_(m.bias, -bound, bound) + + def forward(self, x): + hidden = self.hidden_layer(x) + hidden = self.silu(hidden) + output = self.output_layer(hidden) + residual = self.residual_layer(x) + return output + residual diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/moirai2/pipeline_moirai2.py b/iotdb-core/ainode/iotdb/ainode/core/model/moirai2/pipeline_moirai2.py new file mode 100644 index 000000000000..fe2fb6323622 --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/moirai2/pipeline_moirai2.py @@ -0,0 +1,169 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +import numpy as np +import torch + +from iotdb.ainode.core.inference.pipeline.basic_pipeline import ForecastPipeline +from iotdb.ainode.core.log import Logger +from iotdb.ainode.core.model.moirai2.modeling_moirai2 import Moirai2ForPrediction + +logger = Logger() + + +class Moirai2Pipeline(ForecastPipeline): + def __init__(self, model_info, **model_kwargs): + super().__init__(model_info, **model_kwargs) + + def preprocess(self, inputs, **infer_kwargs): + """ + Preprocess input data for moirai2. + + Parameters + ---------- + inputs : list of dict + A list of dictionaries containing input data. Each dictionary contains: + - 'targets': A tensor (1D or 2D) of shape (input_length,) or (target_count, input_length). + + infer_kwargs: Additional keyword arguments for inference, such as: + - `output_length`(int): Prediction length. + + Returns + ------- + list of dict + Processed inputs compatible with moirai2 format (time, features). 
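+            For example, a 2D ``targets`` entry of shape (target_count, input_length)
+            comes back as a ``past_target`` numpy array of shape (input_length, target_count).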
+ """ + super().preprocess(inputs, **infer_kwargs) + # Moirai2.predict() expects past_target in (time, features) format + processed_inputs = [] + for item in inputs: + targets = item.get("targets", None) + if targets is None: + raise ValueError("Input must contain 'targets' key") + + if isinstance(targets, torch.Tensor): + targets = targets.cpu().numpy() + + # Handle different input formats + if targets.ndim == 1: + # 1D: (input_length,) -> (input_length, 1) + targets = targets[:, np.newaxis] + elif targets.ndim == 2: + # 2D: (target_count, input_length) -> (input_length, target_count) + targets = targets.T + + processed_inputs.append({"past_target": targets}) + return processed_inputs + + def forecast(self, inputs, **infer_kwargs) -> list[torch.Tensor]: + """ + Generate forecasts for the input time series. + + Parameters + ---------- + inputs : list of dict + Processed inputs from preprocess method. + + **infer_kwargs : Additional arguments for inference: + - output_length (int): The length of the forecast output (default: 96). + + Returns + ------- + list of torch.Tensor + The model's predictions, each of shape (target_count, num_quantiles, prediction_length). + + Note + ---- + All model parameters (prediction_length, context_length, target_dim) are + dynamically determined from the actual input data and inference requirements. + No pre-specification is needed during model loading. + """ + prediction_length = infer_kwargs.get("output_length", 96) + + # Extract past_target from inputs + past_target_list = [item["past_target"] for item in inputs] + + if isinstance(self.model, Moirai2ForPrediction): + # After preprocess(), data is in (time, features) format + first_sample = past_target_list[0] + actual_context_length = first_sample.shape[0] # time dimension + target_dim = ( + first_sample.shape[1] if first_sample.ndim > 1 else 1 + ) # feature dimension + + # Use hparams_context to dynamically set ALL parameters + with self.model.hparams_context( + prediction_length=prediction_length, + context_length=actual_context_length, + target_dim=target_dim, + ): + predictions = self.model.predict( + past_target=past_target_list, + feat_dynamic_real=None, + past_feat_dynamic_real=None, + ) + + # Convert numpy array to torch tensor + # predictions shape: (batch, num_quantiles, future_time, *tgt) + # Note: past_target is ndarray (not tensor) because: + # 1. Moirai2.predict() follows GluonTS interface standard (uses numpy) + # 2. Internally, predict() automatically converts to tensor and moves to GPU + # 3. This design maintains compatibility with GluonTS ecosystem + predictions_list = [] + for i in range(predictions.shape[0]): + pred = predictions[i] # (num_quantiles, future_time, *tgt) + # Transpose to (target_count, num_quantiles, prediction_length) + if pred.ndim == 3: + pred = pred.transpose( + 2, 0, 1 + ) # (target_count, num_quantiles, future_time) + elif pred.ndim == 2: + pred = pred[np.newaxis, :, :] # (1, num_quantiles, future_time) + predictions_list.append(torch.from_numpy(pred)) + return predictions_list + else: + raise ValueError( + f"Model must be an instance of Moirai2ForPrediction, got {type(self.model)}" + ) + + def postprocess( + self, outputs: list[torch.Tensor], **infer_kwargs + ) -> list[torch.Tensor]: + """ + Postprocesses the model's forecast outputs by selecting the 0.5 quantile or averaging over quantiles. + + Args: + outputs (list[torch.Tensor]): List of forecast outputs, where each output is a 3D-tensor + with shape [target_count, quantile_count, output_length]. 
+ + Returns: + list[torch.Tensor]: Processed list of forecast outputs, each is a 2D-tensor + with shape [target_count, output_length]. + """ + outputs_list = [] + for output in outputs: + # Check if 0.5 quantile is available + if output.shape[1] > 0: + # Get the median quantile (middle quantile) + median_idx = output.shape[1] // 2 + outputs_list.append(output[:, median_idx, :]) + else: + # If no quantiles, get the mean + outputs_list.append(output.mean(dim=1)) + super().postprocess(outputs_list, **infer_kwargs) + return outputs_list diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/moirai2/transform/__init__.py b/iotdb-core/ainode/iotdb/ainode/core/model/moirai2/transform/__init__.py new file mode 100644 index 000000000000..2a1e720805f2 --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/moirai2/transform/__init__.py @@ -0,0 +1,17 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/moirai2/transform/_base.py b/iotdb-core/ainode/iotdb/ainode/core/model/moirai2/transform/_base.py new file mode 100644 index 000000000000..b4aa490ed29b --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/moirai2/transform/_base.py @@ -0,0 +1,71 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +import abc +from dataclasses import dataclass +from typing import Any + + +class Transformation(abc.ABC): + @abc.abstractmethod + def __call__(self, data_entry: dict[str, Any]) -> dict[str, Any]: ... + + def chain(self, other: "Transformation") -> "Chain": + return Chain([self, other]) + + def __add__(self, other: "Transformation") -> "Chain": + return self.chain(other) + + def __radd__(self, other): + if other == 0: + return self + return other + self + + +@dataclass +class Chain(Transformation): + """ + Chain multiple transformations together. 
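+    Nested ``Chain`` instances are flattened and ``Identity`` transformations are
+    dropped in ``__post_init__``; chains can also be composed with ``+`` via
+    ``Transformation.__add__``.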
+ """ + + transformations: list[Transformation] + + def __post_init__(self) -> None: + transformations = [] + + for transformation in self.transformations: + if isinstance(transformation, Identity): + continue + elif isinstance(transformation, Chain): + transformations.extend(transformation.transformations) + else: + assert isinstance(transformation, Transformation) + transformations.append(transformation) + + self.transformations = transformations + self.__init_passed_kwargs__ = {"transformations": transformations} + + def __call__(self, data_entry: dict[str, Any]) -> dict[str, Any]: + for t in self.transformations: + data_entry = t(data_entry) + return data_entry + + +class Identity(Transformation): + def __call__(self, data_entry: dict[str, Any]) -> dict[str, Any]: + return data_entry diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/moirai2/transform/_mixin.py b/iotdb-core/ainode/iotdb/ainode/core/model/moirai2/transform/_mixin.py new file mode 100644 index 000000000000..e5c38a1bf31d --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/moirai2/transform/_mixin.py @@ -0,0 +1,126 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +from collections.abc import Callable +from typing import Any + +import numpy as np + + +class MapFuncMixin: + @staticmethod + def map_func( + func: Callable[[dict[str, Any], str], Any], + data_entry: dict[str, Any], + fields: tuple[str, ...], + optional_fields: tuple[str, ...] = (), + ): + for field in fields: + data_entry[field] = func(data_entry, field) + for field in optional_fields: + if field in data_entry: + data_entry[field] = func(data_entry, field) + + +class ApplyFuncMixin: + @staticmethod + def apply_func( + func: Callable[[dict[str, Any], str], None], + data_entry: dict[str, Any], + fields: tuple[str, ...], + optional_fields: tuple[str, ...] = (), + ): + for field in fields: + func(data_entry, field) + for field in optional_fields: + if field in data_entry: + func(data_entry, field) + + +class CollectFuncMixin: + @staticmethod + def collect_func_list( + func: Callable[[dict[str, Any], str], Any], + data_entry: dict[str, Any], + fields: tuple[str, ...], + optional_fields: tuple[str, ...] = (), + ) -> list[Any]: + collect = [] + for field in fields: + collect.append(func(data_entry, field)) + for field in optional_fields: + if field in data_entry: + collect.append(func(data_entry, field)) + return collect + + @staticmethod + def collect_func_dict( + func: Callable[[dict[str, Any], str], Any], + data_entry: dict[str, Any], + fields: tuple[str, ...], + optional_fields: tuple[str, ...] 
= (), + ) -> dict[str, Any]: + collect = {} + for field in fields: + collect[field] = func(data_entry, field) + for field in optional_fields: + if field in data_entry: + collect[field] = func(data_entry, field) + return collect + + def collect_func( + self, + func: Callable[[dict[str, Any], str], Any], + data_entry: dict[str, Any], + fields: tuple[str, ...], + optional_fields: tuple[str, ...] = (), + ) -> list[Any] | dict[str, Any]: + if not hasattr(self, "collection_type"): + raise NotImplementedError( + f"{self.__class__.__name__} has no attribute 'collection_type', " + "please use collect_func_list or collect_func_dict instead." + ) + + collection_type = getattr(self, "collection_type") + if collection_type == list: + collect_func = self.collect_func_list + elif collection_type == dict: + collect_func = self.collect_func_dict + else: + raise ValueError(f"Unknown collection_type: {collection_type}") + + return collect_func( + func, + data_entry, + fields, + optional_fields=optional_fields, + ) + + +class CheckArrNDimMixin: + def check_ndim(self, name: str, arr: np.ndarray, expected_ndim: int): + if isinstance(arr, list): + self.check_ndim(name, arr[0], expected_ndim - 1) + return + + if arr.ndim != expected_ndim: + raise AssertionError( + f"Array '{name}' for {self.__class__.__name__} " + f"has expected ndim: {expected_ndim}, " + f"but got ndim: {arr.ndim} of shape {arr.shape}." + ) diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/moirai2/transform/imputation.py b/iotdb-core/ainode/iotdb/ainode/core/model/moirai2/transform/imputation.py new file mode 100644 index 000000000000..1f9269a966ad --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/moirai2/transform/imputation.py @@ -0,0 +1,133 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +from dataclasses import dataclass +from typing import Any + +import numpy as np +from jaxtyping import Num + +from iotdb.ainode.core.model.moirai2.transform._base import Transformation +from iotdb.ainode.core.model.moirai2.transform._mixin import ApplyFuncMixin + + +class ImputationMethod: + def __call__( + self, x: Num[np.ndarray, "length *dim"] + ) -> Num[np.ndarray, "length *dim"]: ... 
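# ---------------------------------------------------------------------------
# Editorial sketch (not part of this diff): ImputationMethod above only fixes
# a callable contract -- take an array of shape [length, *dim] that may
# contain NaNs and return an array of the same shape with the gaps filled.
# A custom strategy therefore only needs to implement __call__. The
# LinearInterpolationImputation name and its 1-D-only behaviour below are
# illustrative assumptions, not an API shipped by this patch.
from dataclasses import dataclass

import numpy as np

from iotdb.ainode.core.model.moirai2.transform.imputation import ImputationMethod


@dataclass(frozen=True)
class LinearInterpolationImputation(ImputationMethod):
    value: int | float | complex = 0.0

    def __call__(self, x: np.ndarray) -> np.ndarray:
        out = x.astype(float)  # copy, so the caller's array is left intact
        mask = np.isnan(out)
        if mask.all():  # an all-NaN series falls back to the default value
            out[:] = self.value
            return out
        idx = np.arange(len(out))
        # np.interp fills each NaN position from the nearest known samples.
        out[mask] = np.interp(idx[mask], idx[~mask], out[~mask])
        return out
# ---------------------------------------------------------------------------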
+ + +@dataclass(frozen=True) +class DummyValueImputation(ImputationMethod): + value: int | float | complex = 0.0 + + def __call__( + self, x: Num[np.ndarray, "length *dim"] + ) -> Num[np.ndarray, "length *dim"]: + x[np.isnan(x)] = self.value + return x + + +@dataclass(frozen=True) +class LastValueImputation(ImputationMethod): + value: int | float | complex = 0.0 + + def __call__( + self, x: Num[np.ndarray, "length *dim"] + ) -> Num[np.ndarray, "length *dim"]: + x = x.T + x[0:1][np.isnan(x[0:1])] = self.value + mask = np.isnan(x) + idx = np.arange(len(x)) + if x.ndim == 2: + idx = np.expand_dims(idx, axis=1) + idx = np.where(~mask, idx, 0) + idx = np.maximum.accumulate(idx, axis=0) + if x.ndim == 2: + x = x[idx, np.arange(x.shape[1])] + else: + x = x[idx] + return x.T + + +@dataclass(frozen=True) +class CausalMeanImputation(ImputationMethod): + """ + This class replaces each missing value with the average of all the values + up to this point, ensuring causality. + + - If the first values are missing, they are replaced by the closest non-missing value. + - If an entire sequence is NaN, it is replaced by a predefined value. + """ + + value: int | float | complex = 0.0 + + def __call__( + self, x: Num[np.ndarray, "length *dim"], value: int | float | complex = 0.0 + ) -> Num[np.ndarray, "length *dim"]: + mask = np.isnan(x).T + + # do last value imputation first + last_value_imputation = LastValueImputation(self.value) + x = last_value_imputation(x) + mask[0] = False + x = x.T + + if x.ndim == 1: + adjusted_values_to_causality = np.concatenate((np.repeat(0.0, 1), x[:-1])) + cumsum = np.cumsum(adjusted_values_to_causality) + indices = np.linspace(0, len(x) - 1, len(x)) + indices[0] = 1 + ar_res = cumsum / indices + x[mask] = ar_res[mask] + else: + # compute cumulative sum + adjusted_values_to_causality = np.vstack( + (np.zeros((1, x.shape[1])), x[:-1, :]) + ) + cumsum = np.cumsum(adjusted_values_to_causality, axis=0) + + # compute causal mean + indices = np.linspace(0, len(x) - 1, len(x)).reshape(-1, 1) + indices[0] = 1 + ar_res = cumsum / indices + # impute with causal mean + x[mask] = ar_res[mask] + return x.T + + +@dataclass +class ImputeTimeSeries(ApplyFuncMixin, Transformation): + fields: tuple[str, ...] + optional_fields: tuple[str, ...] = tuple() + imputation_method: ImputationMethod = DummyValueImputation(value=0.0) + + def __call__(self, data_entry: dict[str, Any]) -> dict[str, Any]: + self.apply_func( + self._impute, + data_entry, + self.fields, + optional_fields=self.optional_fields, + ) + return data_entry + + def _impute(self, data_entry: dict[str, Any], field: str): + value = data_entry[field] + nan_entries = np.isnan(value) + if nan_entries.any(): + data_entry[field] = self.imputation_method(value) diff --git a/iotdb-core/ainode/pyproject.toml b/iotdb-core/ainode/pyproject.toml index c3965d9d099a..6a29dbbdb628 100644 --- a/iotdb-core/ainode/pyproject.toml +++ b/iotdb-core/ainode/pyproject.toml @@ -116,6 +116,7 @@ isort = "6.0.1" setuptools = ">=75.3.0" joblib = ">=1.4.2" urllib3 = "2.6.3" +jaxtyping = ">=0.2.24" [tool.poetry.scripts] ainode = "iotdb.ainode.core.script:main"
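For readers skimming the new `transform` package, the sketch below shows how the pieces added above are meant to compose: `ImputeTimeSeries` wraps an `ImputationMethod` and applies it to the named fields of a `data_entry` dict, and transformations combine into a `Chain` via `+`. This is an illustrative usage sketch derived only from the code in this diff, not an additional file from the patch; the sample series and the quoted output values are assumptions.

import numpy as np

from iotdb.ainode.core.model.moirai2.transform._base import Identity
from iotdb.ainode.core.model.moirai2.transform.imputation import (
    CausalMeanImputation,
    ImputeTimeSeries,
)

# A data entry whose "target" series has two gaps. Note that the imputation
# methods may modify the input array in place.
entry = {"target": np.array([1.0, np.nan, 3.0, np.nan])}

# Transformations compose with `+`; Chain.__post_init__ flattens the result
# and drops the Identity.
pipeline = Identity() + ImputeTimeSeries(
    fields=("target",),
    imputation_method=CausalMeanImputation(),
)

entry = pipeline(entry)
# CausalMeanImputation replaces each NaN with the mean of all preceding
# (already imputed) values, so the result is roughly [1.0, 1.0, 3.0, 1.67].
print(entry["target"])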