Skip to content

Commit c3e2999

Browse files
committed
plan
1 parent 7301228 commit c3e2999

File tree

1 file changed

+39
-0
lines changed

1 file changed

+39
-0
lines changed

_doc/articles/2025/2025-11-31-route2025.rst

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,45 @@ Séance 3
4949

5050
* numpy, broadcasting
5151
* implémentation d'un chi-deux sans boucle
52+
* comment implémenter la fonction `repeat_interleave
53+
<https://docs.pytorch.org/docs/stable/generated/torch.repeat_interleave.html>`_
54+
avec :epkg:`numpy` et sans boucle ?
55+
En particulier cet exemple ``torch.repeat_interleave(y, torch.tensor([1, 2]), dim=0)``
56+
57+
Un problème... que fait la fonction suivante ?
58+
59+
.. code-block:: python
60+
61+
def reshape_keep0(arr, new_shape):
62+
orig_shape = arr.shape
63+
final_shape = []
64+
65+
for i, dim in enumerate(new_shape):
66+
if dim == 0:
67+
final_shape.append(orig_shape[i]) # garder dimension originale
68+
else:
69+
final_shape.append(dim)
70+
return arr.reshape(tuple(final_shape))
71+
72+
Comment construire une fonction qui retourne l'argument ``new_shape``
73+
quand on connaît les dimensions de départ et d'arrivée ?
74+
La fonction doit valider les exemples suivants,
75+
chaque dimension sous forme de chaîne de caractères peut prendre n'importe
76+
quelle valeur.
77+
78+
.. code-block:: python
79+
80+
self.assertEqual((0, 1024, -1), align(("d1", 4, 256, "d2"), ("d1", 1024, "d2")))
81+
self.assertEqual((0, 0, 1024), align(("d1", "d2", 4, 256), ("d1", "d2", 1024)))
82+
self.assertEqual((6, -1), align((2, 3, "d1"), ("a", "d1")))
83+
self.assertEqual((6, -1), align((2, 3, "d1"), (6, "d1")))
84+
self.assertEqual((-1, 12, 196, 64), align(("d1", 196, 64), ("d2", 12, 196, 64)))
85+
self.assertEqual((-1, 196, 64), align(("d1", 196, 64), ("d2", 196, 64)))
86+
self.assertEqual((32, 196, 64), align((32, 196, 64), (32, 196, 64)))
87+
self.assertEqual((4, 8, 196, 64), align((32, 196, 64), (4, 8, 196, 64)))
88+
self.assertEqual((32, 196, 64), align((4, 8, 196, 64), (32, 196, 64)))
89+
self.assertEqual((0, 196, 64), align(("d1", 196, 64), ("d1", 196, 64)))
90+
self.assertEqual((0, 196, 2, 32), align(("d1", 196, 64), ("d1", 196, 2, 32)))
5291
5392
Séance 4
5493
++++++++

0 commit comments

Comments
 (0)