WIP API for Yolact, python package #323
base: master
@@ -0,0 +1,13 @@

    import torch, torchvision
    import torch.nn as nn


    class Concat(nn.Module):
        def __init__(self, nets, extra_params):
            super().__init__()

            self.nets = nn.ModuleList(nets)
            self.extra_params = extra_params

        def forward(self, x):
            # Concat each along the channel dimension
            return torch.cat([net(x) for net in self.nets], dim=1, **self.extra_params)
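For context, a minimal usage sketch of the Concat wrapper above. Everything here except the Concat class itself (the two branch convs, the input sizes) is made up for illustration and is not part of the PR:

```python
import torch
import torch.nn as nn

# Assumes the Concat class from the diff above is in scope.
# Hypothetical example: combine two small conv branches via Concat.
branch_a = nn.Conv2d(3, 8, kernel_size=3, padding=1)
branch_b = nn.Conv2d(3, 16, kernel_size=3, padding=1)

concat = Concat([branch_a, branch_b], extra_params={})

x = torch.randn(1, 3, 64, 64)
y = concat(x)
print(y.shape)  # torch.Size([1, 24, 64, 64]) -- 8 + 16 channels stacked along dim=1
```

The next file in the diff adds the mask-IoU head.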
@@ -0,0 +1,33 @@

    import torch
    from torch import nn
    import torch.nn.functional as F

    # locals
    from data.config import Config
    from utils.functions import make_net


    # As of March 10, 2019, Pytorch DataParallel still doesn't support JIT Script Modules
    use_jit = torch.cuda.device_count() <= 1
    if not use_jit:
        print('Multiple GPUs detected! Turning off JIT.')

    ScriptModuleWrapper = torch.jit.ScriptModule if use_jit else nn.Module
    script_method_wrapper = torch.jit.script_method if use_jit else lambda fn, _rcn=None: fn
Review comment: This part should probably go in its own file, since it should only be loaded once and might need to be used in multiple places.

Author reply: not sure why that's important, but like this? fccd89b
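Since the use_jit / ScriptModuleWrapper block is duplicated verbatim in this file and in the FPN file below, the suggestion is to hoist it into a single shared module. A minimal sketch of what that could look like, assuming a new file name such as jit_utils.py (the name and layout are hypothetical; not necessarily what commit fccd89b actually does):

```python
# jit_utils.py -- hypothetical shared module; the file name and layout are an
# assumption, not necessarily what commit fccd89b introduced.
import torch
from torch import nn

# As of March 10, 2019, PyTorch DataParallel still doesn't support JIT ScriptModules,
# so only enable JIT when at most one GPU is visible.
use_jit = torch.cuda.device_count() <= 1
if not use_jit:
    print('Multiple GPUs detected! Turning off JIT.')

# Modules subclass ScriptModuleWrapper and decorate forward with script_method_wrapper;
# with JIT disabled these fall back to a plain nn.Module and an identity decorator.
ScriptModuleWrapper = torch.jit.ScriptModule if use_jit else nn.Module
script_method_wrapper = torch.jit.script_method if use_jit else lambda fn, _rcn=None: fn
```

Both FastMaskIoUNet and FPN could then do `from jit_utils import ScriptModuleWrapper, script_method_wrapper` instead of repeating the block, which also ensures the GPU check runs only once per process. The diff then continues with the FastMaskIoUNet class itself.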
    class FastMaskIoUNet(ScriptModuleWrapper):

        def __init__(self, config: Config):
            super().__init__()

            cfg = config
            input_channels = 1
            last_layer = [(cfg.num_classes - 1, 1, {})]
            self.maskiou_net, _ = make_net(input_channels, cfg.maskiou_net + last_layer, include_last_relu=True)

        def forward(self, x):
            x = self.maskiou_net(x)
            maskiou_p = F.max_pool2d(x, kernel_size=x.size()[2:]).squeeze(-1).squeeze(-1)
            return maskiou_p
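The maskiou head maps a single-channel mask input to one score per non-background class via global max pooling. A rough standalone sketch of that shape contract, using a made-up conv stack in place of make_net(input_channels, cfg.maskiou_net + last_layer); the layer widths, the 138x138 mask size, and the 81-class count are assumptions for illustration only:

```python
import torch
from torch import nn
import torch.nn.functional as F

NUM_CLASSES = 81  # assumption: COCO-style config (80 classes + background)

# Illustrative stand-in for make_net(1, cfg.maskiou_net + [(NUM_CLASSES - 1, 1, {})]);
# the real layer spec comes from cfg.maskiou_net in data.config.
maskiou_net = nn.Sequential(
    nn.Conv2d(1, 8, kernel_size=3, stride=2, padding=1), nn.ReLU(inplace=True),
    nn.Conv2d(8, 16, kernel_size=3, stride=2, padding=1), nn.ReLU(inplace=True),
    nn.Conv2d(16, NUM_CLASSES - 1, kernel_size=1), nn.ReLU(inplace=True),
)

masks = torch.randn(4, 1, 138, 138)  # a batch of predicted single-channel masks
x = maskiou_net(masks)
maskiou_p = F.max_pool2d(x, kernel_size=x.size()[2:]).squeeze(-1).squeeze(-1)
print(maskiou_p.shape)  # torch.Size([4, 80]) -- one mask-IoU score per non-background class
```

The last file in the diff adds the FPN module.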
@@ -0,0 +1,119 @@

    import torch
    from torch import nn
    import torch.nn.functional as F

    from typing import List

    # local imports
    from data.config import Config

    # As of March 10, 2019, Pytorch DataParallel still doesn't support JIT Script Modules
    use_jit = torch.cuda.device_count() <= 1
    if not use_jit:
        print('Multiple GPUs detected! Turning off JIT.')

    ScriptModuleWrapper = torch.jit.ScriptModule if use_jit else nn.Module
    script_method_wrapper = torch.jit.script_method if use_jit else lambda fn, _rcn=None: fn
breznak marked this conversation as resolved.
    class FPN(ScriptModuleWrapper):
        """
        Implements a general version of the FPN introduced in
        https://arxiv.org/pdf/1612.03144.pdf

        Parameters (in cfg.fpn):
            - num_features (int): The number of output features in the fpn layers.
            - interpolation_mode (str): The mode to pass to F.interpolate.
            - num_downsample (int): The number of downsampled layers to add onto the selected layers.
                                    These extra layers are downsampled from the last selected layer.

        Args:
            - in_channels (list): For each conv layer you supply in the forward pass,
                                  how many features will it have?
        """
        __constants__ = ['interpolation_mode', 'num_downsample', 'use_conv_downsample', 'relu_pred_layers',
                         'lat_layers', 'pred_layers', 'downsample_layers', 'relu_downsample_layers']

        def __init__(self, in_channels, config: Config):
            super().__init__()

            cfg = config

            self.lat_layers = nn.ModuleList([
                nn.Conv2d(x, cfg.fpn.num_features, kernel_size=1)
                for x in reversed(in_channels)
            ])

            # This is here for backwards compatibility
            padding = 1 if cfg.fpn.pad else 0
            self.pred_layers = nn.ModuleList([
                nn.Conv2d(cfg.fpn.num_features, cfg.fpn.num_features, kernel_size=3, padding=padding)
                for _ in in_channels
            ])

            if cfg.fpn.use_conv_downsample:
                self.downsample_layers = nn.ModuleList([
                    nn.Conv2d(cfg.fpn.num_features, cfg.fpn.num_features, kernel_size=3, padding=1, stride=2)
                    for _ in range(cfg.fpn.num_downsample)
                ])

            self.interpolation_mode = cfg.fpn.interpolation_mode
            self.num_downsample = cfg.fpn.num_downsample
            self.use_conv_downsample = cfg.fpn.use_conv_downsample
            self.relu_downsample_layers = cfg.fpn.relu_downsample_layers
            self.relu_pred_layers = cfg.fpn.relu_pred_layers

        @script_method_wrapper
        def forward(self, convouts: List[torch.Tensor]):
            """
            Args:
                - convouts (list): A list of convouts for the corresponding layers in in_channels.
            Returns:
                - A list of FPN convouts in the same order as x with extra downsample layers if requested.
            """
            out = []
            x = torch.zeros(1, device=convouts[0].device)
            for i in range(len(convouts)):
                out.append(x)

            # For backward compatibility, the conv layers are stored in reverse but the input and output are
            # given in the correct order. Thus, use j = -i - 1 for the input and output and i for the conv layers.
            j = len(convouts)
            for lat_layer in self.lat_layers:
                j -= 1

                if j < len(convouts) - 1:
                    _, _, h, w = convouts[j].size()
                    x = F.interpolate(x, size=(h, w), mode=self.interpolation_mode, align_corners=False)

                x = x + lat_layer(convouts[j])
                out[j] = x

            # This janky second loop is here because TorchScript.
            j = len(convouts)
            for pred_layer in self.pred_layers:
                j -= 1
                out[j] = pred_layer(out[j])

                if self.relu_pred_layers:
                    F.relu(out[j], inplace=True)

            cur_idx = len(out)

            # In the original paper, this takes care of P6
            if self.use_conv_downsample:
                for downsample_layer in self.downsample_layers:
                    out.append(downsample_layer(out[-1]))
            else:
                for idx in range(self.num_downsample):
                    # Note: this is an untested alternative to out.append(out[-1][:, :, ::2, ::2]). Thanks TorchScript.
                    out.append(nn.functional.max_pool2d(out[-1], 1, stride=2))

            if self.relu_downsample_layers:
                for idx in range(len(out) - cur_idx):
                    out[idx] = F.relu(out[idx + cur_idx], inplace=False)

            return out
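For completeness, a rough sketch of how this FPN could be exercised in isolation. The cfg.fpn values and the backbone feature shapes below are stand-ins that mirror common YOLACT defaults (ResNet C3-C5 for a 550x550 input), not values taken from this PR's data.config; depending on your PyTorch version you may also need to force use_jit = False, since the old-style ScriptModule / script_method API has changed over time.

```python
import torch
from types import SimpleNamespace

# Assumes the FPN class defined above is in scope.
# Stand-in for data.config.Config: only the cfg.fpn fields that FPN.__init__ reads.
# The values mirror YOLACT's usual defaults but are assumptions, not this PR's config.
cfg = SimpleNamespace(fpn=SimpleNamespace(
    num_features=256,
    pad=True,
    interpolation_mode='bilinear',
    num_downsample=2,
    use_conv_downsample=True,
    relu_downsample_layers=False,
    relu_pred_layers=True,
))

# Backbone feature maps shaped like ResNet C3, C4, C5 for a 550x550 input (illustrative).
convouts = [
    torch.randn(1, 512, 69, 69),
    torch.randn(1, 1024, 35, 35),
    torch.randn(1, 2048, 18, 18),
]

fpn = FPN(in_channels=[512, 1024, 2048], config=cfg)
outs = fpn(convouts)
for p in outs:
    print(p.shape)  # 256 channels each; spatial sizes 69, 35, 18, plus downsampled 9 and 5
```

The three lateral outputs keep their input resolutions, while the two extra levels come from the stride-2 downsample convs (the P6-style levels mentioned in the comment above).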