From 94fed3bb01f00abff3bea748b689be3dbc4e1d9b Mon Sep 17 00:00:00 2001 From: tqtg Date: Sat, 19 Apr 2025 18:40:06 +0000 Subject: [PATCH] TensorFlow v1 to v2 migration for NCF models --- cornac/models/ncf/backend_tf.py | 165 ++++++++++++----------- cornac/models/ncf/recom_gmf.py | 88 ++++++------ cornac/models/ncf/recom_mlp.py | 93 ++++++------- cornac/models/ncf/recom_ncf_base.py | 105 ++++++++++----- cornac/models/ncf/recom_neumf.py | 201 +++++++++++++--------------- cornac/models/ncf/requirements.txt | 2 +- 6 files changed, 328 insertions(+), 326 deletions(-) diff --git a/cornac/models/ncf/backend_tf.py b/cornac/models/ncf/backend_tf.py index 2cf0c5926..0ff2bcc90 100644 --- a/cornac/models/ncf/backend_tf.py +++ b/cornac/models/ncf/backend_tf.py @@ -13,15 +13,8 @@ # limitations under the License. # ============================================================================ -import warnings -# disable annoying tensorflow deprecated API warnings -warnings.filterwarnings("ignore", category=UserWarning) - -import tensorflow.compat.v1 as tf - -tf.logging.set_verbosity(tf.logging.ERROR) -tf.disable_v2_behavior() +import tensorflow as tf act_functions = { @@ -35,88 +28,98 @@ } -def loss_fn(labels, logits): - cross_entropy = tf.reduce_mean( - tf.nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=logits) - ) - reg_loss = tf.losses.get_regularization_loss() - return cross_entropy + reg_loss - - -def train_fn(loss, learning_rate, learner): +def get_optimizer(learning_rate, learner): if learner.lower() == "adagrad": - opt = tf.train.AdagradOptimizer(learning_rate=learning_rate, name="optimizer") + return tf.keras.optimizers.Adagrad(learning_rate=learning_rate) elif learner.lower() == "rmsprop": - opt = tf.train.RMSPropOptimizer(learning_rate=learning_rate, name="optimizer") + return tf.keras.optimizers.RMSprop(learning_rate=learning_rate) elif learner.lower() == "adam": - opt = tf.train.AdamOptimizer(learning_rate=learning_rate, name="optimizer") + return tf.keras.optimizers.Adam(learning_rate=learning_rate) else: - opt = tf.train.GradientDescentOptimizer( - learning_rate=learning_rate, name="optimizer" - ) - - return opt.minimize(loss) - - -def emb( - uid, iid, num_users, num_items, emb_size, reg_user, reg_item, seed=None, scope="emb" -): - with tf.variable_scope(scope): - user_emb = tf.get_variable( - "user_emb", - shape=[num_users, emb_size], - dtype=tf.float32, - initializer=tf.random_normal_initializer(stddev=0.01, seed=seed), - regularizer=tf.keras.regularizers.L2(reg_user), + return tf.keras.optimizers.SGD(learning_rate=learning_rate) + + +class GMFLayer(tf.keras.layers.Layer): + def __init__(self, num_users, num_items, emb_size, reg_user, reg_item, seed=None, **kwargs): + super(GMFLayer, self).__init__(**kwargs) + self.num_users = num_users + self.num_items = num_items + self.emb_size = emb_size + self.reg_user = reg_user + self.reg_item = reg_item + self.seed = seed + + # Initialize embeddings + self.user_embedding = tf.keras.layers.Embedding( + num_users, + emb_size, + embeddings_initializer=tf.keras.initializers.RandomNormal(stddev=0.01, seed=seed), + embeddings_regularizer=tf.keras.regularizers.L2(reg_user), + name="user_embedding" ) - item_emb = tf.get_variable( - "item_emb", - shape=[num_items, emb_size], - dtype=tf.float32, - initializer=tf.random_normal_initializer(stddev=0.01, seed=seed), - regularizer=tf.keras.regularizers.L2(reg_item), - ) - - return tf.nn.embedding_lookup(user_emb, uid), tf.nn.embedding_lookup(item_emb, iid) - - -def gmf(uid, iid, num_users, num_items, 
emb_size, reg_user, reg_item, seed=None): - with tf.variable_scope("GMF") as scope: - user_emb, item_emb = emb( - uid=uid, - iid=iid, - num_users=num_users, - num_items=num_items, - emb_size=emb_size, - reg_user=reg_user, - reg_item=reg_item, - seed=seed, - scope=scope, + + self.item_embedding = tf.keras.layers.Embedding( + num_items, + emb_size, + embeddings_initializer=tf.keras.initializers.RandomNormal(stddev=0.01, seed=seed), + embeddings_regularizer=tf.keras.regularizers.L2(reg_item), + name="item_embedding" ) + + def call(self, inputs): + user_ids, item_ids = inputs + user_emb = self.user_embedding(user_ids) + item_emb = self.item_embedding(item_ids) return tf.multiply(user_emb, item_emb) -def mlp(uid, iid, num_users, num_items, layers, reg_layers, act_fn, seed=None): - with tf.variable_scope("MLP") as scope: - user_emb, item_emb = emb( - uid=uid, - iid=iid, - num_users=num_users, - num_items=num_items, - emb_size=int(layers[0] / 2), - reg_user=reg_layers[0], - reg_item=reg_layers[0], - seed=seed, - scope=scope, +class MLPLayer(tf.keras.layers.Layer): + def __init__(self, num_users, num_items, layers, reg_layers, act_fn, seed=None, **kwargs): + super(MLPLayer, self).__init__(**kwargs) + self.num_users = num_users + self.num_items = num_items + self.layers = layers + self.reg_layers = reg_layers + self.act_fn = act_fn + self.seed = seed + + # Initialize embeddings + self.user_embedding = tf.keras.layers.Embedding( + num_users, + int(layers[0] / 2), + embeddings_initializer=tf.keras.initializers.RandomNormal(stddev=0.01, seed=seed), + embeddings_regularizer=tf.keras.regularizers.L2(reg_layers[0]), + name="user_embedding" ) - interaction = tf.concat([user_emb, item_emb], axis=-1) - for i, layer in enumerate(layers[1:]): - interaction = tf.layers.dense( - interaction, - units=layer, - name="layer{}".format(i + 1), - activation=act_functions.get(act_fn, tf.nn.relu), - kernel_initializer=tf.initializers.lecun_uniform(seed), - kernel_regularizer=tf.keras.regularizers.L2(reg_layers[i + 1]), + + self.item_embedding = tf.keras.layers.Embedding( + num_items, + int(layers[0] / 2), + embeddings_initializer=tf.keras.initializers.RandomNormal(stddev=0.01, seed=seed), + embeddings_regularizer=tf.keras.regularizers.L2(reg_layers[0]), + name="item_embedding" + ) + + # Define dense layers + self.dense_layers = [] + for i, layer_size in enumerate(layers[1:]): + self.dense_layers.append( + tf.keras.layers.Dense( + layer_size, + activation=act_functions.get(act_fn, tf.nn.relu), + kernel_initializer=tf.keras.initializers.LecunUniform(seed=seed), + kernel_regularizer=tf.keras.regularizers.L2(reg_layers[i + 1]), + name=f"layer{i+1}" + ) ) + + def call(self, inputs): + user_ids, item_ids = inputs + user_emb = self.user_embedding(user_ids) + item_emb = self.item_embedding(item_ids) + interaction = tf.concat([user_emb, item_emb], axis=-1) + + for layer in self.dense_layers: + interaction = layer(interaction) + return interaction diff --git a/cornac/models/ncf/recom_gmf.py b/cornac/models/ncf/recom_gmf.py index f55ec7eef..7b1a6c304 100644 --- a/cornac/models/ncf/recom_gmf.py +++ b/cornac/models/ncf/recom_gmf.py @@ -111,55 +111,45 @@ def __init__( ######################## ## TensorFlow backend ## ######################## - def _build_graph_tf(self): - import tensorflow.compat.v1 as tf - from .backend_tf import gmf, loss_fn, train_fn - - self.graph = tf.Graph() - with self.graph.as_default(): - tf.set_random_seed(self.seed) - - self.user_id = tf.placeholder(shape=[None], dtype=tf.int32, name="user_id") - 
self.item_id = tf.placeholder(shape=[None], dtype=tf.int32, name="item_id")
-            self.labels = tf.placeholder(
-                shape=[None, 1], dtype=tf.float32, name="labels"
-            )
-
-            self.interaction = gmf(
-                uid=self.user_id,
-                iid=self.item_id,
-                num_users=self.num_users,
-                num_items=self.num_items,
-                emb_size=self.num_factors,
-                reg_user=self.reg,
-                reg_item=self.reg,
-                seed=self.seed,
-            )
-
-            logits = tf.layers.dense(
-                self.interaction,
-                units=1,
-                name="logits",
-                kernel_initializer=tf.initializers.lecun_uniform(self.seed),
-            )
-            self.prediction = tf.nn.sigmoid(logits)
-
-            self.loss = loss_fn(labels=self.labels, logits=logits)
-            self.train_op = train_fn(
-                self.loss, learning_rate=self.lr, learner=self.learner
-            )
-
-            self.initializer = tf.global_variables_initializer()
-            self.saver = tf.train.Saver()
-
-        self._sess_init_tf()
-
-    def _score_tf(self, user_idx, item_idx):
-        feed_dict = {
-            self.user_id: [user_idx],
-            self.item_id: np.arange(self.num_items) if item_idx is None else [item_idx],
-        }
-        return self.sess.run(self.prediction, feed_dict=feed_dict)
+    def _build_model_tf(self):
+        import tensorflow as tf
+        from .backend_tf import GMFLayer
+
+        # Define inputs
+        user_input = tf.keras.layers.Input(shape=(1,), dtype=tf.int32, name="user_input")
+        item_input = tf.keras.layers.Input(shape=(1,), dtype=tf.int32, name="item_input")
+
+        # GMF layer
+        gmf_layer = GMFLayer(
+            num_users=self.num_users,
+            num_items=self.num_items,
+            emb_size=self.num_factors,
+            reg_user=self.reg,
+            reg_item=self.reg,
+            seed=self.seed,
+            name="gmf_layer"
+        )
+
+        # Get embeddings and element-wise product
+        gmf_vector = gmf_layer([user_input, item_input])
+
+        # Output layer
+        logits = tf.keras.layers.Dense(
+            1,
+            kernel_initializer=tf.keras.initializers.LecunUniform(seed=self.seed),
+            name="logits"
+        )(gmf_vector)
+
+        prediction = tf.keras.layers.Activation('sigmoid', name="prediction")(logits)
+
+        # Create model (single sigmoid prediction output)
+        model = tf.keras.Model(
+            inputs=[user_input, item_input],
+            outputs=prediction,
+            name="GMF"
+        )
+
+        return model
 
     #####################
     ## PyTorch backend ##
diff --git a/cornac/models/ncf/recom_mlp.py b/cornac/models/ncf/recom_mlp.py
index 6901b91c4..3b2f68801 100644
--- a/cornac/models/ncf/recom_mlp.py
+++ b/cornac/models/ncf/recom_mlp.py
@@ -116,60 +116,45 @@ def __init__(
 
     ########################
     ## TensorFlow backend ##
     ########################
-    def _build_graph_tf(self):
-        import tensorflow.compat.v1 as tf
-        from .backend_tf import mlp, loss_fn, train_fn
-
-        self.graph = tf.Graph()
-        with self.graph.as_default():
-            tf.set_random_seed(self.seed)
-
-            self.user_id = tf.placeholder(shape=[None], dtype=tf.int32, name="user_id")
-            self.item_id = tf.placeholder(shape=[None], dtype=tf.int32, name="item_id")
-            self.labels = tf.placeholder(
-                shape=[None, 1], dtype=tf.float32, name="labels"
-            )
-
-            self.interaction = mlp(
-                uid=self.user_id,
-                iid=self.item_id,
-                num_users=self.num_users,
-                num_items=self.num_items,
-                layers=self.layers,
-                reg_layers=[self.reg] * len(self.layers),
-                act_fn=self.act_fn,
-                seed=self.seed,
-            )
-            logits = tf.layers.dense(
-                self.interaction,
-                units=1,
-                name="logits",
-                kernel_initializer=tf.initializers.lecun_uniform(self.seed),
-            )
-            self.prediction = tf.nn.sigmoid(logits)
-
-            self.loss = loss_fn(labels=self.labels, logits=logits)
-            self.train_op = train_fn(
-                self.loss, learning_rate=self.lr, learner=self.learner
-            )
-
-            self.initializer = tf.global_variables_initializer()
-            self.saver = tf.train.Saver()
-
-        self._sess_init_tf()
-
-    def 
_score_tf(self, user_idx, item_idx): - if item_idx is None: - feed_dict = { - self.user_id: np.ones(self.num_items) * user_idx, - self.item_id: np.arange(self.num_items), - } - else: - feed_dict = { - self.user_id: [user_idx], - self.item_id: [item_idx], - } - return self.sess.run(self.prediction, feed_dict=feed_dict) + def _build_model_tf(self): + import tensorflow as tf + from .backend_tf import MLPLayer + + # Define inputs + user_input = tf.keras.layers.Input(shape=(1,), dtype=tf.int32, name="user_input") + item_input = tf.keras.layers.Input(shape=(1,), dtype=tf.int32, name="item_input") + + # MLP layer + mlp_layer = MLPLayer( + num_users=self.num_users, + num_items=self.num_items, + layers=self.layers, + reg_layers=[self.reg] * len(self.layers), + act_fn=self.act_fn, + seed=self.seed, + name="mlp_layer" + ) + + # Get MLP vector + mlp_vector = mlp_layer([user_input, item_input]) + + # Output layer + logits = tf.keras.layers.Dense( + 1, + kernel_initializer=tf.keras.initializers.LecunUniform(seed=self.seed), + name="logits" + )(mlp_vector) + + prediction = tf.keras.layers.Activation('sigmoid', name="prediction")(logits) + + # Create model + model = tf.keras.Model( + inputs=[user_input, item_input], + outputs=prediction, + name="MLP" + ) + + return model ##################### ## PyTorch backend ## diff --git a/cornac/models/ncf/recom_ncf_base.py b/cornac/models/ncf/recom_ncf_base.py index 05bac60a4..c8ac894cc 100644 --- a/cornac/models/ncf/recom_ncf_base.py +++ b/cornac/models/ncf/recom_ncf_base.py @@ -141,28 +141,34 @@ def fit(self, train_set, val_set=None): ######################## ## TensorFlow backend ## ######################## - def _build_graph_tf(self): + def _build_model_tf(self): raise NotImplementedError() - def _sess_init_tf(self): - import tensorflow.compat.v1 as tf - - config = tf.ConfigProto() - config.gpu_options.allow_growth = True - self.sess = tf.Session(graph=self.graph, config=config) - self.sess.run(self.initializer) - - def _get_feed_dict(self, batch_users, batch_items, batch_ratings): - return { - self.user_id: batch_users, - self.item_id: batch_items, - self.labels: batch_ratings.reshape(-1, 1), - } - def _fit_tf(self, train_set, val_set): - if not hasattr(self, "graph"): - self._build_graph_tf() - + import tensorflow as tf + + # Set random seed for reproducibility + if self.seed is not None: + tf.random.set_seed(self.seed) + np.random.seed(self.seed) + + # Configure GPU memory growth to avoid OOM errors + gpus = tf.config.experimental.list_physical_devices('GPU') + if gpus: + try: + for gpu in gpus: + tf.config.experimental.set_memory_growth(gpu, True) + except RuntimeError as e: + print(e) + + # Build the model + self.model = self._build_model_tf() + + # Get optimizer + from .backend_tf import get_optimizer + optimizer = get_optimizer(learning_rate=self.lr, learner=self.learner) + + # Training loop loop = trange(self.num_epochs, disable=not self.verbose) for _ in loop: count = 0 @@ -172,17 +178,33 @@ def _fit_tf(self, train_set, val_set): self.batch_size, shuffle=True, binary=True, num_zeros=self.num_neg ) ): - _, _loss = self.sess.run( - [self.train_op, self.loss], - feed_dict=self._get_feed_dict( - batch_users, batch_items, batch_ratings - ), - ) + batch_ratings = batch_ratings.reshape(-1, 1, 1) + + # Convert to tensors + batch_users = tf.convert_to_tensor(batch_users, dtype=tf.int32) + batch_items = tf.convert_to_tensor(batch_items, dtype=tf.int32) + batch_ratings = tf.convert_to_tensor(batch_ratings, dtype=tf.float32) + + # Training step + with 
tf.GradientTape() as tape: + predictions = self.model([batch_users, batch_items], training=True) + cross_entropy = tf.keras.losses.binary_crossentropy( + y_true=batch_ratings, + y_pred=predictions, + from_logits=False # predictions are already probabilities + ) + cross_entropy = tf.reduce_mean(cross_entropy) + loss_value = cross_entropy + tf.reduce_sum(self.model.losses) + + # Apply gradients + grads = tape.gradient(loss_value, self.model.trainable_variables) + optimizer.apply_gradients(zip(grads, self.model.trainable_variables)) + count += len(batch_users) - sum_loss += len(batch_users) * _loss + sum_loss += len(batch_users) * loss_value.numpy() if i % 10 == 0: loop.set_postfix(loss=(sum_loss / count)) - + if self.early_stopping is not None and self.early_stop( train_set, val_set, **self.early_stopping ): @@ -190,7 +212,24 @@ def _fit_tf(self, train_set, val_set): loop.close() def _score_tf(self, user_idx, item_idx): - raise NotImplementedError() + """Score function for TensorFlow models.""" + import tensorflow as tf + + if item_idx is None: + # Score all items for a given user + user_tensor = tf.convert_to_tensor([user_idx], dtype=tf.int32) + item_tensor = tf.convert_to_tensor(np.arange(self.num_items), dtype=tf.int32) + + # Broadcast user_idx to match the shape of item_tensor + user_tensor = tf.broadcast_to(user_tensor, shape=item_tensor.shape) + else: + # Score a specific item for a given user + user_tensor = tf.convert_to_tensor([user_idx], dtype=tf.int32) + item_tensor = tf.convert_to_tensor([item_idx], dtype=tf.int32) + + # Get predictions + predictions = self.model([user_tensor, item_tensor], training=False) + return predictions.numpy().squeeze() ##################### ## PyTorch backend ## @@ -271,7 +310,9 @@ def save(self, save_dir=None): model_file = Recommender.save(self, save_dir) if self.backend == "tensorflow": - self.saver.save(self.sess, model_file.replace(".pkl", ".cpt")) + # Save the TensorFlow model + if hasattr(self, "model"): + self.model.save_weights(model_file.replace(".pkl", ".h5")) elif self.backend == "pytorch": # TODO: implement model saving for PyTorch raise NotImplementedError() @@ -301,8 +342,10 @@ def load(model_path, trainable=False): model.pretrained = False if model.backend == "tensorflow": - model._build_graph() - model.saver.restore(model.sess, model.load_from.replace(".pkl", ".cpt")) + # Build the model + model.model = model._build_model_tf() + # Load weights + model.model.load_weights(model.load_from.replace(".pkl", ".h5")) elif model.backend == "pytorch": # TODO: implement model loading for PyTorch raise NotImplementedError() diff --git a/cornac/models/ncf/recom_neumf.py b/cornac/models/ncf/recom_neumf.py index 760048d0f..8e3f9ffa4 100644 --- a/cornac/models/ncf/recom_neumf.py +++ b/cornac/models/ncf/recom_neumf.py @@ -157,121 +157,102 @@ def from_pretrained(self, pretrained_gmf, pretrained_mlp, alpha=0.5): ######################## ## TensorFlow backend ## ######################## - def _build_graph_tf(self): - import tensorflow.compat.v1 as tf - from .backend_tf import gmf, mlp, loss_fn, train_fn - - self.graph = tf.Graph() - with self.graph.as_default(): - tf.set_random_seed(self.seed) - - self.gmf_user_id = tf.placeholder( - shape=[None], dtype=tf.int32, name="gmf_user_id" - ) - self.mlp_user_id = tf.placeholder( - shape=[None], dtype=tf.int32, name="mlp_user_id" - ) - self.item_id = tf.placeholder(shape=[None], dtype=tf.int32, name="item_id") - self.labels = tf.placeholder( - shape=[None, 1], dtype=tf.float32, name="labels" - ) - - gmf_feat = gmf( 
- uid=self.gmf_user_id, - iid=self.item_id, - num_users=self.num_users, - num_items=self.num_items, - emb_size=self.num_factors, - reg_user=self.reg, - reg_item=self.reg, - seed=self.seed, - ) - mlp_feat = mlp( - uid=self.mlp_user_id, - iid=self.item_id, - num_users=self.num_users, - num_items=self.num_items, - layers=self.layers, - reg_layers=[self.reg] * len(self.layers), - act_fn=self.act_fn, - seed=self.seed, - ) - - self.interaction = tf.concat([gmf_feat, mlp_feat], axis=-1) - logits = tf.layers.dense( - self.interaction, - units=1, - name="logits", - kernel_initializer=tf.initializers.lecun_uniform(self.seed), - ) - self.prediction = tf.nn.sigmoid(logits) - - self.loss = loss_fn(labels=self.labels, logits=logits) - self.train_op = train_fn( - self.loss, learning_rate=self.lr, learner=self.learner - ) - - self.initializer = tf.global_variables_initializer() - self.saver = tf.train.Saver() - - self._sess_init_tf() - + def _build_model_tf(self): + import tensorflow as tf + from .backend_tf import GMFLayer, MLPLayer + + # Define inputs + user_input = tf.keras.layers.Input(shape=(1,), dtype=tf.int32, name="user_input") + item_input = tf.keras.layers.Input(shape=(1,), dtype=tf.int32, name="item_input") + + # GMF layer + gmf_layer = GMFLayer( + num_users=self.num_users, + num_items=self.num_items, + emb_size=self.num_factors, + reg_user=self.reg, + reg_item=self.reg, + seed=self.seed, + name="gmf_layer" + ) + + # MLP layer + mlp_layer = MLPLayer( + num_users=self.num_users, + num_items=self.num_items, + layers=self.layers, + reg_layers=[self.reg] * len(self.layers), + act_fn=self.act_fn, + seed=self.seed, + name="mlp_layer" + ) + + # Get embeddings and element-wise product + gmf_vector = gmf_layer([user_input, item_input]) + mlp_vector = mlp_layer([user_input, item_input]) + + # Concatenate GMF and MLP vectors + concat_vector = tf.keras.layers.Concatenate(axis=-1)([gmf_vector, mlp_vector]) + + # Output layer + logits = tf.keras.layers.Dense( + 1, + kernel_initializer=tf.keras.initializers.LecunUniform(seed=self.seed), + name="logits" + )(concat_vector) + + prediction = tf.keras.layers.Activation('sigmoid', name="prediction")(logits) + + # Create model + model = tf.keras.Model( + inputs=[user_input, item_input], + outputs=prediction, + name="NeuMF" + ) + + # Handle pretrained models if self.pretrained: - gmf_kernel = self.pretrained_gmf.sess.run( - self.pretrained_gmf.sess.graph.get_tensor_by_name("logits/kernel:0") + # Get GMF and MLP models + gmf_model = self.pretrained_gmf.model + mlp_model = self.pretrained_mlp.model + + # Copy GMF embeddings + model.get_layer('gmf_layer').user_embedding.set_weights( + gmf_model.get_layer('gmf_layer').user_embedding.get_weights() ) - gmf_bias = self.pretrained_gmf.sess.run( - self.pretrained_gmf.sess.graph.get_tensor_by_name("logits/bias:0") + model.get_layer('gmf_layer').item_embedding.set_weights( + gmf_model.get_layer('gmf_layer').item_embedding.get_weights() ) - mlp_kernel = self.pretrained_mlp.sess.run( - self.pretrained_mlp.sess.graph.get_tensor_by_name("logits/kernel:0") + + # Copy MLP embeddings and layers + model.get_layer('mlp_layer').user_embedding.set_weights( + mlp_model.get_layer('mlp_layer').user_embedding.get_weights() ) - mlp_bias = self.pretrained_mlp.sess.run( - self.pretrained_mlp.sess.graph.get_tensor_by_name("logits/bias:0") + model.get_layer('mlp_layer').item_embedding.set_weights( + mlp_model.get_layer('mlp_layer').item_embedding.get_weights() ) - logits_kernel = np.concatenate( - [self.alpha * gmf_kernel, (1 - self.alpha) * 
mlp_kernel] - ) - logits_bias = self.alpha * gmf_bias + (1 - self.alpha) * mlp_bias - - for v in self.graph.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES): - if v.name.startswith("GMF"): - sess = self.pretrained_gmf.sess - self.sess.run( - tf.assign(v, sess.run(sess.graph.get_tensor_by_name(v.name))) - ) - elif v.name.startswith("MLP"): - sess = self.pretrained_mlp.sess - self.sess.run( - tf.assign(v, sess.run(sess.graph.get_tensor_by_name(v.name))) - ) - elif v.name.startswith("logits/kernel"): - self.sess.run(tf.assign(v, logits_kernel)) - elif v.name.startswith("logits/bias"): - self.sess.run(tf.assign(v, logits_bias)) - - def _get_feed_dict(self, batch_users, batch_items, batch_ratings): - return { - self.gmf_user_id: batch_users, - self.mlp_user_id: batch_users, - self.item_id: batch_items, - self.labels: batch_ratings.reshape(-1, 1), - } - - def _score_tf(self, user_idx, item_idx): - if item_idx is None: - feed_dict = { - self.gmf_user_id: [user_idx], - self.mlp_user_id: np.ones(self.num_items) * user_idx, - self.item_id: np.arange(self.num_items), - } - else: - feed_dict = { - self.gmf_user_id: [user_idx], - self.mlp_user_id: [user_idx], - self.item_id: [item_idx], - } - return self.sess.run(self.prediction, feed_dict=feed_dict) + + # Copy dense layers in MLP + for i, layer in enumerate(model.get_layer('mlp_layer').dense_layers): + layer.set_weights(mlp_model.get_layer('mlp_layer').dense_layers[i].get_weights()) + + # Combine weights for output layer + gmf_logits_weights = gmf_model.get_layer('logits').get_weights() + mlp_logits_weights = mlp_model.get_layer('logits').get_weights() + + # Combine kernel weights + combined_kernel = np.concatenate([ + self.alpha * gmf_logits_weights[0], + (1.0 - self.alpha) * mlp_logits_weights[0] + ], axis=0) + + # Combine bias weights + combined_bias = self.alpha * gmf_logits_weights[1] + (1.0 - self.alpha) * mlp_logits_weights[1] + + # Set combined weights to output layer + model.get_layer('logits').set_weights([combined_kernel, combined_bias]) + + return model ##################### ## PyTorch backend ## diff --git a/cornac/models/ncf/requirements.txt b/cornac/models/ncf/requirements.txt index c108aa3bf..972a009a8 100644 --- a/cornac/models/ncf/requirements.txt +++ b/cornac/models/ncf/requirements.txt @@ -1,2 +1,2 @@ -tensorflow==2.12.0 +tensorflow>=2.12.0 torch>=0.4.1 \ No newline at end of file
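
Reviewer note, not part of the diff: below is a minimal end-to-end smoke test for the migrated TF2 backend, sketched against Cornac's public experiment API. The from_pretrained() hook comes from the recom_neumf.py hunk above; the dataset choice, the backend="tensorflow" keyword, and all hyperparameters are illustrative assumptions rather than values mandated by this change.

# Hedged sketch: exercises GMF/MLP training plus the rewritten NeuMF
# pretrained-weight path (Keras set_weights instead of reading tensors out
# of a tf.compat.v1 session). Dataset, split, and hyperparameters are
# placeholders.
import cornac
from cornac.datasets import movielens
from cornac.eval_methods import RatioSplit
from cornac.models import GMF, MLP, NeuMF

data = movielens.load_feedback(variant="100K")
split = RatioSplit(data=data, test_size=0.2, exclude_unknowns=True, seed=123)

gmf = GMF(num_factors=8, num_epochs=1, backend="tensorflow", seed=123)
mlp = MLP(layers=(32, 16, 8), num_epochs=1, backend="tensorflow", seed=123)

# NeuMF reuses the two models above; Experiment fits models in list order,
# so GMF and MLP are already trained by the time NeuMF builds its Keras
# model and copies their weights inside _build_model_tf().
neumf = NeuMF(num_factors=8, layers=(32, 16, 8), num_epochs=1,
              backend="tensorflow", seed=123)
neumf.from_pretrained(gmf, mlp, alpha=0.5)

cornac.Experiment(
    eval_method=split,
    models=[gmf, mlp, neumf],
    metrics=[cornac.metrics.AUC(), cornac.metrics.Recall(k=20)],
).run()

If the alpha-weighted fusion works as the patch intends, NeuMF's initial logits kernel is the concatenation of alpha * GMF's kernel and (1 - alpha) * MLP's kernel, so its first-epoch loss should start near the blend of the two pretrained models rather than from random initialization.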