From 0d1657ae38a5067fd76f1bb49166c21cbe28214d Mon Sep 17 00:00:00 2001 From: Rajath Bharadwaj Date: Sat, 1 Jul 2023 23:37:39 -0400 Subject: [PATCH] Fixes TypeError: torch.Size() takes an iterable of 'int' (item 1 is 'NoneType') Error When using Transformers4Rec, whilst creating the `tabular_inputs` from `tr.TabularSequenceFeatures.from_schema`, it throws a TypeError. After a bit of inspection, the following changes solved the issue. --- transformers4rec/torch/features/tabular.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/transformers4rec/torch/features/tabular.py b/transformers4rec/torch/features/tabular.py index 83845b80cb..3b1896468c 100644 --- a/transformers4rec/torch/features/tabular.py +++ b/transformers4rec/torch/features/tabular.py @@ -168,6 +168,7 @@ def from_schema( # type: ignore if continuous_soft_embeddings: maybe_continuous_module = cls.SOFT_EMBEDDING_MODULE_CLASS.from_schema( schema, + max_sequence_length=max_sequence_length, tags=continuous_tags, **kwargs, ) @@ -177,7 +178,7 @@ ) if categorical_tags: maybe_categorical_module = cls.EMBEDDING_MODULE_CLASS.from_schema( - schema, tags=categorical_tags, **kwargs + schema, max_sequence_length=max_sequence_length, tags=categorical_tags, **kwargs ) if pretrained_embeddings_tags: maybe_pretrained_module = cls.PRETRAINED_EMBEDDING_MODULE_CLASS.from_schema(