
Commit a593b50

Kade Lin authored and committed
Implement the task4_3 to 4_5
1 parent 4efa3dc commit a593b50

14 files changed: +2456 -64 lines

.DS_Store

8 KB
Binary file not shown.

a.txt

Lines changed: 0 additions & 8 deletions
This file was deleted.

minitorch/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -1,3 +1,5 @@
+"""Public packages for minitorch"""
+
 from .testing import MathTest, MathTestVariable  # type: ignore # noqa: F401,F403
 from .datasets import *  # noqa: F401,F403
 from .optim import *  # noqa: F401,F403

minitorch/fast_conv.py

Lines changed: 71 additions & 16 deletions
@@ -1,14 +1,11 @@
 from typing import Tuple, TypeVar, Any

-import numpy as np
 from numba import prange
 from numba import njit as _njit

 from .autodiff import Context
 from .tensor import Tensor
 from .tensor_data import (
-    MAX_DIMS,
-    Index,
     Shape,
     Strides,
     Storage,
@@ -22,6 +19,26 @@


 def njit(fn: Fn, **kwargs: Any) -> Fn:
+    """Compile a Python function into a Numba-optimized function using `njit`.
+
+    This function compiles the given Python function `fn` into a Numba-optimized
+    version with just-in-time compilation. The `inline="always"` option ensures
+    that the compiled function is always inlined into the calling code, improving
+    performance by reducing function call overhead.
+
+    Parameters
+    ----------
+    fn : Fn
+        The Python function to be compiled.
+    **kwargs : Any
+        Additional arguments for the Numba `njit` compiler.
+
+    Returns
+    -------
+    Fn
+        The Numba-optimized version of the input function.
+
+    """
     return _njit(inline="always", **kwargs)(fn)  # type: ignore
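As a side note, a minimal standalone sketch of what this wrapper does, assuming only that numba is installed (the `add` function below is hypothetical and not part of the commit):

import numpy as np  # not required here, kept only to mirror the module's dependencies
from numba import njit as _njit

def add(a: float, b: float) -> float:
    # plain Python function to be compiled
    return a + b

# same pattern as the wrapper above: compile with inlining always enabled
fast_add = _njit(inline="always")(add)
print(fast_add(2.0, 3.0))  # 5.0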

@@ -98,7 +115,7 @@ def _tensor_conv1d(
                 for ic in range(in_channels):
                     for k in range(kw):
                         iw = w + k if not reverse else w - k
-                        if 0 <= iw < width:
+                        if 0 <= iw < width:
                             input_idx = b * s1[0] + ic * s1[1] + iw * s1[2]
                             weight_idx = oc * s2[0] + ic * s2[1] + k * s2[2]
                             acc += input[input_idx] * weight[weight_idx]
@@ -139,6 +156,25 @@ def forward(ctx: Context, input: Tensor, weight: Tensor) -> Tensor:

     @staticmethod
     def backward(ctx: Context, grad_output: Tensor) -> Tuple[Tensor, Tensor]:
+        """Compute the gradients for 1D Convolution.
+
+        Parameters
+        ----------
+        ctx : Context
+            Context object containing saved tensors from the forward pass.
+        grad_output : Tensor
+            Gradient of the loss with respect to the output of the convolution.
+
+        Returns
+        -------
+        Tuple[Tensor, Tensor]
+            A tuple containing:
+            - grad_input: Gradient of the loss with respect to the input tensor.
+              Shape: [batch, in_channels, width]
+            - grad_weight: Gradient of the loss with respect to the weight tensor.
+              Shape: [out_channels, in_channels, kernel_width]
+
+        """
         input, weight = ctx.saved_values
         batch, in_channels, w = input.shape
         out_channels, in_channels, kw = weight.shape
@@ -215,7 +251,8 @@ def _tensor_conv2d(
         reverse (bool): anchor weight at top-left or bottom-right

     """
-    batch_, out_channels, _, _ = out_shape
+    # batch_, out_channels, _, _ = out_shape
+    batch_, out_channels, out_height, out_width = out_shape
     batch, in_channels, height, width = input_shape
     out_channels_, in_channels_, kh, kw = weight_shape

@@ -232,32 +269,30 @@ def _tensor_conv2d(
     s20, s21, s22, s23 = s2[0], s2[1], s2[2], s2[3]

     # TODO: Implement for Task 4.2.
-    o_s0, o_s1, o_s2, o_s3 = out_strides
+    s30, s31, s32, s33 = out_strides

     for b in prange(batch):
         for oc in range(out_channels):
-            for h in range(height):
-                for w in range(width):
+            for oh in range(out_height):
+                for ow in range(out_width):
                     acc = 0.0
                     for ic in range(in_channels):
                         for kh_idx in range(kh):
                             for kw_idx in range(kw):
-                                ih = h + kh_idx if not reverse else h - kh_idx
-                                iw = w + kw_idx if not reverse else w - kw_idx
+                                ih = oh + kh_idx if not reverse else oh - kh_idx
+                                iw = ow + kw_idx if not reverse else ow - kw_idx
+
                                 if 0 <= ih < height and 0 <= iw < width:
-                                    input_idx = (
-                                        b * s10 + ic * s11 + ih * s12 + iw * s13
-                                    )
+                                    input_idx = b * s10 + ic * s11 + ih * s12 + iw * s13
                                     weight_idx = (
                                         oc * s20
                                         + ic * s21
                                         + kh_idx * s22
                                         + kw_idx * s23
                                     )
                                     acc += input[input_idx] * weight[weight_idx]
-                    out_idx = (
-                        b * o_s0 + oc * o_s1 + h * o_s2 + w * o_s3
-                    )
+
+                    out_idx = b * s30 + oc * s31 + oh * s32 + ow * s33
                     out[out_idx] = acc


@@ -292,6 +327,26 @@ def forward(ctx: Context, input: Tensor, weight: Tensor) -> Tensor:

     @staticmethod
     def backward(ctx: Context, grad_output: Tensor) -> Tuple[Tensor, Tensor]:
+        """Compute the gradients for 2D Convolution.
+
+        Parameters
+        ----------
+        ctx : Context
+            Context object containing saved tensors from the forward pass.
+        grad_output : Tensor
+            Gradient of the loss with respect to the output of the convolution.
+            Shape: [batch, out_channels, height, width]
+
+        Returns
+        -------
+        Tuple[Tensor, Tensor]
+            A tuple containing:
+            - grad_input: Gradient of the loss with respect to the input tensor.
+              Shape: [batch, in_channels, height, width]
+            - grad_weight: Gradient of the loss with respect to the weight tensor.
+              Shape: [out_channels, in_channels, kernel_height, kernel_width]
+
+        """
         input, weight = ctx.saved_values
         batch, in_channels, h, w = input.shape
         out_channels, in_channels, kh, kw = weight.shape
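
As a cross-check on the index arithmetic in `_tensor_conv2d`, the sketch below redoes the same loop in plain NumPy on dense arrays instead of raw storage and strides. It assumes, as in this kernel, that the output keeps the input's spatial size and that out-of-range taps contribute zero; `conv2d_reference` is a name invented for illustration and is not part of the commit. The same pattern, dropped down one dimension, is what `_tensor_conv1d` computes.

import numpy as np

def conv2d_reference(inp: np.ndarray, weight: np.ndarray, reverse: bool = False) -> np.ndarray:
    """inp: [batch, in_channels, height, width]; weight: [out_channels, in_channels, kh, kw]."""
    batch, in_channels, height, width = inp.shape
    out_channels, _, kh, kw = weight.shape
    out = np.zeros((batch, out_channels, height, width))
    for b in range(batch):
        for oc in range(out_channels):
            for oh in range(height):
                for ow in range(width):
                    acc = 0.0
                    for ic in range(in_channels):
                        for ki in range(kh):
                            for kj in range(kw):
                                # forward: kernel anchored at the top-left;
                                # reverse: kernel anchored at the bottom-right
                                ih = oh + ki if not reverse else oh - ki
                                iw = ow + kj if not reverse else ow - kj
                                if 0 <= ih < height and 0 <= iw < width:
                                    acc += inp[b, ic, ih, iw] * weight[oc, ic, ki, kj]
                    out[b, oc, oh, ow] = acc
    return out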

minitorch/nn.py

Lines changed: 218 additions & 1 deletion
@@ -36,7 +36,224 @@ def tile(input: Tensor, kernel: Tuple[int, int]) -> Tuple[Tensor, int, int]:
     assert height % kh == 0
     assert width % kw == 0
     # TODO: Implement for Task 4.3.
-    raise NotImplementedError("Need to implement for Task 4.3")
+    input = input.contiguous()
+
+    new_height = height // kh
+    new_width = width // kw
+    out = input.view(batch, channel, new_height, kh, new_width, kw)
+    out = out.permute(0, 1, 2, 4, 3, 5)
+    out = out.contiguous()
+    out = out.view(batch, channel, new_height, new_width, kh * kw)
+
+    return out, new_height, new_width


 # TODO: Implement for Task 4.3.
+def avgpool2d(input: Tensor, kernel: Tuple[int, int]) -> Tensor:
+    """Apply average pooling on the 2D input tensor.
+
+    Parameters
+    ----------
+    input : Tensor
+        Input tensor with shape [batch, channels, height, width].
+    kernel : Tuple[int, int]
+        Tuple specifying the height and width of the pooling kernel.
+
+    Returns
+    -------
+    Tensor
+        Tensor after applying average pooling, with reduced height and width
+        depending on the kernel size.
+        Shape: [batch, channels, new_height, new_width]
+
+    """
+    tiled, new_height, new_width = tile(input, kernel)
+    pooled = tiled.mean(dim=4)
+    return pooled.view(input.shape[0], input.shape[1], new_height, new_width)
+
+
+class Max(Function):
+    @staticmethod
+    def forward(ctx: Context, t: Tensor, dim: Tensor) -> Tensor:
+        """Compute the maximum values along a specified dimension.
+
+        Parameters
+        ----------
+        ctx : Context
+            Context object for storing intermediate values for the backward pass.
+        t : Tensor
+            Input tensor.
+        dim : Tensor
+            Dimension along which to compute the maximum.
+
+        Returns
+        -------
+        Tensor
+            A tensor containing the maximum values along the specified dimension.
+
+        """
+        d = int(dim.item())
+        res = FastOps.reduce(operators.max, start=-1e30)(t, d)
+        ctx.save_for_backward(t, dim, res)
+        return res
+
+    @staticmethod
+    def backward(ctx: Context, grad_output: Tensor) -> Tuple[Tensor, float]:
+        """Compute the gradient of the max operation.
+
+        Parameters
+        ----------
+        ctx : Context
+            Context object containing saved values from the forward pass.
+        grad_output : Tensor
+            Gradient of the loss with respect to the output of the max operation.
+
+        Returns
+        -------
+        Tuple[Tensor, float]
+            - Gradient of the loss with respect to the input tensor.
+            - A float representing the gradient with respect to the dimension, which is always 0.
+
+        """
+        t, dim, max_val = ctx.saved_values
+        d = int(dim.item())
+        mask = t == max_val
+        sum_mask = mask.sum(dim=d)
+        grad_input = mask * (grad_output / sum_mask)
+        return grad_input, 0.0
+
+
+def max(t: Tensor, dim: int) -> Tensor:
+    """Apply the max function along a specified dimension.
+
+    Parameters
+    ----------
+    t : Tensor
+        Input tensor.
+    dim : int
+        Dimension along which to compute the maximum.
+
+    Returns
+    -------
+    Tensor
+        Tensor containing the maximum values along the specified dimension.
+
+    """
+    return Max.apply(t, tensor(dim))
+
+
+def argmax(t: Tensor, dim: int) -> Tensor:
+    """Compute the indices of the maximum values along a specified dimension.
+
+    Parameters
+    ----------
+    t : Tensor
+        Input tensor.
+    dim : int
+        Dimension along which to compute the indices of the maximum.
+
+    Returns
+    -------
+    Tensor
+        Tensor containing one-hot encoded indices of the maximum values along the specified dimension.
+
+    """
+    m = max(t, dim)
+    expand_shape = list(m.shape)
+    expand_shape.insert(dim, t.shape[dim])
+    mask = t == m
+    return mask
+
+
+def softmax(t: Tensor, dim: int) -> Tensor:
+    """Compute the softmax along a specified dimension.
+
+    Parameters
+    ----------
+    t : Tensor
+        Input tensor.
+    dim : int
+        Dimension along which to compute the softmax.
+
+    Returns
+    -------
+    Tensor
+        Tensor containing the softmax probabilities along the specified dimension.
+
+    """
+    exp_t = t.exp()
+    sum_exp = exp_t.sum(dim=dim)
+    return exp_t / sum_exp
+
+
+def logsoftmax(t: Tensor, dim: int) -> Tensor:
+    """Compute the log of the softmax along a specified dimension.
+
+    Parameters
+    ----------
+    t : Tensor
+        Input tensor.
+    dim : int
+        Dimension along which to compute the logsoftmax.
+
+    Returns
+    -------
+    Tensor
+        Tensor containing the log of the softmax probabilities along the specified dimension.
+
+    """
+    m = max(t, dim=dim)
+    log_sum_exp = ((t - m).exp().sum(dim=dim)).log() + m
+    return t - log_sum_exp
+
+
+def maxpool2d(input: Tensor, kernel: Tuple[int, int]) -> Tensor:
+    """Apply max pooling on the 2D input tensor.
+
+    Parameters
+    ----------
+    input : Tensor
+        Input tensor with shape [batch, channels, height, width].
+    kernel : Tuple[int, int]
+        Tuple specifying the height and width of the pooling kernel.
+
+    Returns
+    -------
+    Tensor
+        Tensor after applying max pooling, with reduced height and width
+        depending on the kernel size.
+        Shape: [batch, channels, new_height, new_width]
+
+    """
+    tiled, new_height, new_width = tile(input, kernel)
+    pooled = max(tiled, dim=4)
+    return pooled.view(input.shape[0], input.shape[1], new_height, new_width)
+
+
+def dropout(input: Tensor, p: float = 0.5, ignore: bool = False) -> Tensor:
+    """Apply dropout regularization to the input tensor.
+
+    Parameters
+    ----------
+    input : Tensor
+        Input tensor.
+    p : float, optional
+        Dropout probability (default is 0.5).
+    ignore : bool, optional
+        If True, bypass dropout (default is False).
+
+    Returns
+    -------
+    Tensor
+        Tensor with randomly zeroed elements scaled by 1 / (1 - p) to maintain expected value.
+
+    """
+    if p == 1.0:
+        if not ignore:
+            return input.zeros(input.shape)
+        else:
+            return input
+    if ignore:
+        return input
+    mask = rand(input.shape, backend=input.backend) > p
+    return input * mask * (1.0 / (1 - p))
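
As an aside on the `tile` trick above: the view, permute, view sequence turns every kh x kw window into the trailing axis, so pooling becomes a single reduction over that axis. A NumPy illustration of the same idea (illustration only, not part of the commit; shapes are arbitrary):

import numpy as np

x = np.arange(16, dtype=float).reshape(1, 1, 4, 4)   # [batch, channel, height, width]
kh, kw = 2, 2
b, c, h, w = x.shape

# reshape -> transpose -> reshape, as in tile(): each kh x kw window becomes the last axis
tiled = (
    x.reshape(b, c, h // kh, kh, w // kw, kw)
     .transpose(0, 1, 2, 4, 3, 5)
     .reshape(b, c, h // kh, w // kw, kh * kw)
)

print(tiled.mean(axis=-1))  # average pooling, as in avgpool2d
print(tiled.max(axis=-1))   # max pooling, as in maxpool2d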

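The `Max.backward` rule routes the incoming gradient to the positions that attained the maximum and splits it evenly across ties, which is what the mask / sum_mask arithmetic above does. A minimal NumPy sketch of that rule (illustration only):

import numpy as np

t = np.array([1.0, 3.0, 3.0, 2.0])
grad_output = 1.0                       # gradient flowing into max(t)

mask = (t == t.max()).astype(float)     # 1.0 where t attains the maximum
grad_input = mask * (grad_output / mask.sum())
print(grad_input)                       # [0.  0.5 0.5 0. ]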
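The `logsoftmax` implementation relies on the shift identity log(sum(exp(t))) = m + log(sum(exp(t - m))) with m = max(t), which keeps the exponentials from overflowing even where a naive softmax would not be safe. A quick NumPy check (illustration only):

import numpy as np

t = np.array([1000.0, 1001.0, 1002.0])          # naive exp(t) overflows to inf
m = t.max()
log_sum_exp = np.log(np.exp(t - m).sum()) + m   # stable, as in logsoftmax above
print(t - log_sum_exp)                          # roughly [-2.41, -1.41, -0.41]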
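Finally, the `dropout` above is the inverted variant: surviving elements are scaled by 1 / (1 - p) so that the expected value of each element is unchanged at training time. A NumPy sketch of the same idea (illustration only; the seed and shapes are arbitrary):

import numpy as np

rng = np.random.default_rng(0)
x = np.ones((1000, 100))
p = 0.5

mask = rng.random(x.shape) > p          # keep each element with probability 1 - p
y = x * mask / (1.0 - p)                # rescale the survivors

print(x.mean(), round(y.mean(), 3))     # both close to 1.0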