3 changes: 3 additions & 0 deletions algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py
@@ -13,6 +13,8 @@
from algoperf.workloads.criteo1tb.criteo1tb_jax import models
from algoperf.workloads.criteo1tb.workload import \
    BaseCriteo1TbDlrmSmallWorkload
from custom_pytorch_jax_converter import use_pytorch_weights_inplace, use_pytorch_weights_inplace_mnist



class Criteo1TbDlrmSmallWorkload(BaseCriteo1TbDlrmSmallWorkload):
@@ -103,6 +105,7 @@ def init_model_fn(
        {'params': params_rng, 'dropout': dropout_rng},
        jnp.ones(input_shape, jnp.float32))
    initial_params = initial_variables['params']
    initial_params = use_pytorch_weights_inplace(initial_params, file_name="/results/pytorch_base_model_criteo1tb_24_june.pth")
    self._param_shapes = param_utils.jax_param_shapes(initial_params)
    self._param_types = param_utils.jax_param_types(self._param_shapes)
    return jax_utils.replicate(initial_params), None
1 change: 1 addition & 0 deletions algoperf/workloads/criteo1tb/criteo1tb_pytorch/workload.py
@@ -88,6 +88,7 @@ def init_model_fn(
        dropout_rate=dropout_rate,
        use_layer_norm=self.use_layer_norm,
        embedding_init_multiplier=self.embedding_init_multiplier)
    torch.save(model.state_dict(), "/results/pytorch_base_model_criteo1tb_24_june.pth")
    self._param_shapes = param_utils.pytorch_param_shapes(model)
    self._param_types = param_utils.pytorch_param_types(self._param_shapes)
    model.to(DEVICE)
4 changes: 3 additions & 1 deletion algoperf/workloads/mnist/mnist_jax/workload.py
@@ -13,7 +13,7 @@
from algoperf import param_utils
from algoperf import spec
from algoperf.workloads.mnist.workload import BaseMnistWorkload

from custom_pytorch_jax_converter import use_pytorch_weights_inplace, use_pytorch_weights_inplace_mnist

class _Model(nn.Module):

@@ -42,8 +42,10 @@ def init_model_fn(
    del aux_dropout_rate
    init_val = jnp.ones((1, 28, 28, 1), jnp.float32)
    self._model = _Model()

    initial_params = self._model.init({'params': rng}, init_val,
                                      train=True)['params']
    initial_params = use_pytorch_weights_inplace_mnist(initial_params, file_name="/results/pytorch_base_model_mnist_24june.pth")
    self._param_shapes = param_utils.jax_param_shapes(initial_params)
    self._param_types = param_utils.jax_param_types(self._param_shapes)
    return jax_utils.replicate(initial_params), None
1 change: 1 addition & 0 deletions algoperf/workloads/mnist/mnist_pytorch/workload.py
@@ -135,6 +135,7 @@ def init_model_fn(

    torch.random.manual_seed(rng[0])
    self._model = _Model()
    torch.save(self._model.state_dict(), "/results/pytorch_base_model_mnist_24june.pth")
    self._param_shapes = param_utils.pytorch_param_shapes(self._model)
    self._param_types = param_utils.pytorch_param_types(self._param_shapes)
    self._model.to(DEVICE)
186 changes: 186 additions & 0 deletions custom_pytorch_jax_converter.py
@@ -0,0 +1,186 @@
import torch
import numpy as np
import jax
import jax.numpy as jnp
import logging
import copy
from jax.tree_util import tree_map
"""
Jax default parameter structure:
dict_keys(['Dense_0', 'Dense_1', 'Dense_2', 'Dense_3', 'Dense_4', 'Dense_5', 'Dense_6', 'Dense_7', 'embedding_table'])

Pytorch stateduct structure:
dict_keys(['embedding_chunk_0', 'embedding_chunk_1', 'embedding_chunk_2', 'embedding_chunk_3', 'bot_mlp.0.weight', 'bot_mlp.0.bias', 'bot_mlp.2.weight', 'bot_mlp.2.bias', 'bot_mlp.4.weight', 'bot_mlp.4.bias', 'top_mlp.0.weight', 'top_mlp.0.bias', 'top_mlp.2.weight', 'top_mlp.2.bias', 'top_mlp.4.weight', 'top_mlp.4.bias', 'top_mlp.6.weight', 'top_mlp.6.bias', 'top_mlp.8.weight', 'top_mlp.8.bias'])



The following function converts the PyTorch weights to the Jax format
and assigns them to the Jax model parameters.
The function assumes that the Jax model parameters are already initialized
and that the PyTorch weights are in the correct format.
"""
def use_pytorch_weights_inplace(jax_params, file_name=None, replicate=False):

  # Load PyTorch state_dict
  state_dict = torch.load(file_name)
  print(state_dict.keys())
  # Convert PyTorch tensors to NumPy arrays
  numpy_weights = {k: v.cpu().numpy() for k, v in state_dict.items()}

  # --- Embedding Table ---
  embedding_table = np.concatenate([
      numpy_weights[f'embedding_chunk_{i}'] for i in range(4)
  ], axis=0)  # adjust axis depending on chunking direction

  jax_params['embedding_table'] = jnp.array(embedding_table)

  # --- Bot MLP: Dense_0 to Dense_2 ---
  for i, j in zip([0, 2, 4], range(3)):
    jax_params[f'Dense_{j}']['kernel'] = jnp.array(numpy_weights[f'bot_mlp.{i}.weight'].T)
    jax_params[f'Dense_{j}']['bias'] = jnp.array(numpy_weights[f'bot_mlp.{i}.bias'])

  # --- Top MLP: Dense_3 to Dense_7 ---
  for i, j in zip([0, 2, 4, 6, 8], range(3, 8)):
    jax_params[f'Dense_{j}']['kernel'] = jnp.array(numpy_weights[f'top_mlp.{i}.weight'].T)
    jax_params[f'Dense_{j}']['bias'] = jnp.array(numpy_weights[f'top_mlp.{i}.bias'])
  # jax_params = tree_map(lambda x: jnp.array(x), jax_params)
  del state_dict
  return jax_params
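

# Hypothetical sanity check, not used by the workloads above: it only illustrates
# why the `.T` transpose is applied when copying Linear weights into Flax Dense
# kernels. torch.nn.Linear stores its weight as (out_features, in_features) and
# computes x @ W.T + b, while flax.linen.Dense stores its kernel as
# (in_features, out_features) and computes x @ kernel + bias, so kernel = W.T.
# The layer sizes here are made up for the demo.
def _check_linear_to_dense_transpose(in_features=13, out_features=512):
  linear = torch.nn.Linear(in_features, out_features)
  pt_weight = linear.weight.detach().cpu().numpy()  # shape (out_features, in_features)
  pt_bias = linear.bias.detach().cpu().numpy()
  kernel = pt_weight.T  # shape (in_features, out_features), as used above
  x = np.random.randn(1, in_features).astype(np.float32)
  out_pytorch = linear(torch.from_numpy(x)).detach().numpy()
  out_jax_style = x @ kernel + pt_bias
  assert np.allclose(out_pytorch, out_jax_style, atol=1e-5)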


def use_pytorch_weights_inplace_mnist(jax_params, file_name=None, replicate=False):
  # Load the PyTorch checkpoint
  ckpt = torch.load(file_name)
  state_dict = ckpt['state_dict'] if 'state_dict' in ckpt else ckpt

  print("Loaded PyTorch keys:", state_dict.keys())

  # Convert to numpy
  numpy_weights = {k: v.cpu().numpy() for k, v in state_dict.items()}

  # Mapping PyTorch keys → JAX Dense layers
  layer_map = {
      'net.layer1': 'Dense_0',
      'net.layer2': 'Dense_1',
  }

  for pt_name, jax_name in layer_map.items():
    weight_key = f"{pt_name}.weight"
    bias_key = f"{pt_name}.bias"

    if weight_key not in numpy_weights or bias_key not in numpy_weights:
      raise KeyError(f"Missing keys: {weight_key} or {bias_key} in PyTorch weights")

    jax_params[jax_name]['kernel'] = jnp.array(numpy_weights[weight_key].T)  # Transpose!
    jax_params[jax_name]['bias'] = jnp.array(numpy_weights[bias_key])

  return jax_params


# def are_weights_equal(params1, params2, atol=1e-6, rtol=1e-6):
#   """Compares two JAX PyTrees of weights and prints where they differ."""
#   all_equal = True

#   def compare_fn(p1, p2):
#     nonlocal all_equal
#     # if not jnp.allclose(p1, p2):
#     if not jnp.allclose(p1, p2, atol=atol, rtol=rtol):
#       logging.info("❌ Mismatch found:")
#       logging.info(f"Shape 1: {p1.shape}, Shape 2: {p2.shape}")
#       logging.info(f"Max diff: {jnp.max(jnp.abs(p1 - p2))}")
#       all_equal = False
#     return jnp.allclose(p1, p2, atol=atol, rtol=rtol)

#   try:
#     _ = jax.tree_util.tree_map(compare_fn, params1, params2)
#   except Exception as e:
#     logging.info("❌ Structure mismatch or error during comparison:", e)
#     return False

#   if all_equal:
#     logging.info("✅ All weights are equal (within tolerance)")
#   return all_equal


def maybe_unreplicate(pytree):
  """If leading axis matches device count, strip it assuming it's pmap replication."""
  num_devices = jax.device_count()
  return jax.tree_util.tree_map(
      lambda x: x[0] if isinstance(x, jax.Array) and x.shape[0] == num_devices else x,
      pytree
  )

def move_to_cpu(tree):
  return jax.tree_util.tree_map(lambda x: jax.device_put(x, device=jax.devices("cpu")[0]), tree)


def are_weights_equal(params1, params2, atol=1e-6, rtol=1e-6):
  """Compares two JAX PyTrees of weights and logs where they differ, safely handling pmap replication."""
  # Attempt to unreplicate if needed
  params1 = maybe_unreplicate(params1)
  params2 = maybe_unreplicate(params2)

  params1 = move_to_cpu(params1)
  params2 = move_to_cpu(params2)

  all_equal = True

  def compare_fn(p1, p2):
    nonlocal all_equal
    if not jnp.allclose(p1, p2, atol=atol, rtol=rtol):
      logging.info("❌ Mismatch found:")
      logging.info(f"Shape 1: {p1.shape}, Shape 2: {p2.shape}")
      logging.info(f"Max diff: {jnp.max(jnp.abs(p1 - p2))}")
      all_equal = False
    return jnp.allclose(p1, p2, atol=atol, rtol=rtol)

  try:
    jax.tree_util.tree_map(compare_fn, params1, params2)
  except Exception as e:
    logging.info("❌ Structure mismatch or error during comparison: %s", e, exc_info=True)
    return False

  if all_equal:
    logging.info("✅ All weights are equal (within tolerance)")
  return all_equal
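

# Hypothetical usage sketch, not called anywhere in this PR: a minimal check
# that the converter preserves the tree structure of the initialized params and
# that loading the same checkpoint twice is deterministic. The checkpoint path
# defaults to the one saved by the PyTorch Criteo workload above; everything
# else here is illustrative.
def _verify_conversion(initial_params,
                       ckpt="/results/pytorch_base_model_criteo1tb_24_june.pth"):
  converted = use_pytorch_weights_inplace(copy.deepcopy(initial_params), file_name=ckpt)
  # The converted tree should keep the structure of the initialized params...
  assert (jax.tree_util.tree_structure(converted)
          == jax.tree_util.tree_structure(initial_params))
  # ...and converting a second time from the same checkpoint should give
  # identical weights, which also exercises are_weights_equal.
  reconverted = use_pytorch_weights_inplace(copy.deepcopy(initial_params), file_name=ckpt)
  return are_weights_equal(converted, reconverted)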



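# Variant of use_pytorch_weights_inplace that first deep-copies the JAX tree to
# host (CPU) memory and writes the PyTorch weights into that copy instead of
# mutating the original params; the checkpoint is likewise loaded with
# map_location='cpu', so nothing is staged on an accelerator.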
def use_pytorch_weights2(jax_params, file_name=None, replicate=False):

  def deep_copy_to_cpu(pytree):
    return tree_map(
        lambda x: jax.device_put(jnp.array(copy.deepcopy(x)), device=jax.devices("cpu")[0]),
        pytree)

  jax_copy = deep_copy_to_cpu(jax_params)
  # Load PyTorch state_dict lazily to CPU
  state_dict = torch.load(file_name, map_location='cpu')
  print(state_dict.keys())
  # Convert PyTorch tensors to NumPy arrays
  numpy_weights = {k: v.cpu().numpy() for k, v in state_dict.items()}

  # --- Embedding Table ---
  embedding_table = np.concatenate([
      numpy_weights[f'embedding_chunk_{i}'] for i in range(4)
  ], axis=0)  # adjust axis depending on chunking direction

  jax_copy['embedding_table'] = jnp.array(embedding_table)

  # --- Bot MLP: Dense_0 to Dense_2 ---
  for i, j in zip([0, 2, 4], range(3)):
    jax_copy[f'Dense_{j}']['kernel'] = jnp.array(numpy_weights[f'bot_mlp.{i}.weight'].T)
    jax_copy[f'Dense_{j}']['bias'] = jnp.array(numpy_weights[f'bot_mlp.{i}.bias'])

  # --- Top MLP: Dense_3 to Dense_7 ---
  for i, j in zip([0, 2, 4, 6, 8], range(3, 8)):
    jax_copy[f'Dense_{j}']['kernel'] = jnp.array(numpy_weights[f'top_mlp.{i}.weight'].T)
    jax_copy[f'Dense_{j}']['bias'] = jnp.array(numpy_weights[f'top_mlp.{i}.bias'])
  # jax_copy = tree_map(lambda x: jnp.array(x), jax_copy)
  del state_dict

  return jax_copy
