M2_SETI/IA/seti_master-master/code/onnx2pytorch.py

329 lines
9 KiB
Python
Raw Normal View History

2023-01-29 16:56:40 +01:00
# Adapted from https://gist.github.com/qinjian623/6aa777037534c1c1dccbb66f832e93b8
import onnx
import struct
import torch
import torch.nn as nn
import warnings
from collections import OrderedDict
# Maps ONNX TensorProto.DataType enum values to
# [struct format character, item size in bytes]; used to decode the
# little-endian `raw_data` payload of ONNX tensors with `struct.unpack`.
data_type_tab = {
    1: ['f', 4],   # FLOAT
    2: ['B', 1],   # UINT8
    3: ['b', 1],   # INT8
    4: ['H', 2],   # UINT16
    5: ['h', 2],   # INT16
    6: ['i', 4],   # INT32
    7: ['q', 8],   # INT64
    9: ['?', 1],   # BOOL (one byte per element)
    10: ['e', 2],  # FLOAT16
    11: ['d', 8],  # DOUBLE
    12: ['I', 4],  # UINT32
    13: ['Q', 8],  # UINT64
}
def empty(x):
    """Identity op: hand the argument back untouched."""
    result = x
    return result
def _slim422(l4):
assert len(l4) == 4
p0, p1 = l4[::2]
if l4[0] == 0: # TODO bad code
p0 = l4[2] // 2
if l4[2] == 1:
p0 = 1
if l4[1] == 0: # TODO bad code
p1 = l4[3] // 2
if l4[3] == 1:
p1 = 1
return p0, p1
def _check_attr(attrs, map):
for attr in attrs:
if attr.name not in map:
warnings.warn("Missing {} in parser's attr_map.".format(attr.name))
def unpack_weights(initializer):
    """Decode ONNX initializer tensors into a ``{name: torch.Tensor}`` dict.

    Payloads in ``raw_data`` are decoded with ``struct`` via ``data_type_tab``.
    Tensors with empty ``raw_data`` are read from the typed repeated field
    that ONNX uses for their data type. Non-scalar tensors are reshaped to
    ``dims``; scalars (empty ``dims``) are left 1-D with one element.

    Tensors whose type cannot be decoded are skipped with a warning. The
    previous code raised ``TypeError`` (``warnings(...)`` called a module),
    or hit ``NameError`` / silently reused the previous tensor's data.

    Args:
        initializer: iterable of ONNX ``TensorProto`` objects.

    Returns:
        dict mapping tensor names to ``torch.Tensor``.
    """
    # ONNX data type -> typed repeated field holding the values when
    # raw_data is empty (per onnx.proto TensorProto).
    typed_fields = {
        1: 'float_data',
        2: 'int32_data', 3: 'int32_data', 4: 'int32_data',
        5: 'int32_data', 6: 'int32_data', 9: 'int32_data',
        10: 'int32_data',
        7: 'int64_data',
        11: 'double_data',
        12: 'uint64_data', 13: 'uint64_data',
    }
    ret = {}
    for tensor in initializer:
        name = tensor.name
        dtype = tensor.data_type
        shape = list(tensor.dims)
        if dtype not in data_type_tab:
            warnings.warn("This data type {} is not supported yet.".format(dtype))
            continue
        fmt, size = data_type_tab[dtype]
        if len(tensor.raw_data) == 0:
            field = typed_fields.get(dtype)
            if field is None:
                warnings.warn("No-raw-data type {} not supported yet.".format(dtype))
                continue
            data_list = list(getattr(tensor, field))
            if dtype == 10:
                # FLOAT16 values are stored as their uint16 bit patterns.
                n = len(data_list)
                data_list = list(struct.unpack('<' + 'e' * n,
                                               struct.pack('<' + 'H' * n, *data_list)))
            elif dtype == 9:
                data_list = [bool(v) for v in data_list]
        else:
            data_list = struct.unpack('<' + fmt * (len(tensor.raw_data) // size),
                                      tensor.raw_data)
        t = torch.tensor(data_list)
        if len(shape) != 0:
            t = t.view(*shape)
        ret[name] = t
    return ret
def rebuild_conv(node, weights, nodes=None):
    """Rebuild an ONNX ``Conv`` node as an ``nn.Conv2d`` module.

    Args:
        node: ONNX node with inputs ``[X, W]`` or ``[X, W, B]`` and one output.
        weights: dict mapping initializer names to tensors; ``W`` and ``B``
            are looked up here.
        nodes: unused. Optional, because ``rebuild_op`` calls rebuilders as
            ``(node, weights, nodes)``, and the old two-argument signature
            made that call raise ``TypeError``.

    Returns:
        ``(conv, [X], outputs)``. Weight and bias are baked into the module,
        so only the data input remains a graph dependency.
    """
    rebuild_conv.conv_attr_map = {
        "pads": "padding",
        "strides": "stride",
        "kernel_shape": "kernel_size",
        "group": "groups",
        "dilations": "dilation"
    }
    assert len(node.output) == 1
    # ONNX marks an omitted optional input with an empty name.
    with_bias = len(node.input) == 3 and node.input[2] != ''
    weight = weights[node.input[1]]
    # Weight layout is (out_channels, in_channels / groups, kH, kW).
    out_channels = weight.shape[0]
    in_channels = weight.shape[1]
    kwargs = {}
    for att in node.attribute:
        key = rebuild_conv.conv_attr_map.get(att.name)
        if key is None:
            # e.g. auto_pad: no nn.Conv2d equivalent (previously a KeyError).
            warnings.warn("Ignoring unsupported Conv attribute {}.".format(att.name))
            continue
        kwargs[key] = att.i if att.name == 'group' else list(att.ints)
    if 'padding' in kwargs:
        kwargs["padding"] = _slim422(kwargs["padding"])
    if 'kernel_size' not in kwargs:
        # kernel_shape is optional in ONNX; infer it from W's spatial dims.
        kwargs['kernel_size'] = tuple(weight.shape[2:])
    groups = kwargs.get('groups', 1)
    in_channels *= groups
    conv = nn.Conv2d(in_channels, out_channels, **kwargs, bias=with_bias)
    conv.weight.data = weight
    if with_bias:
        conv.bias.data = weights[node.input[2]]
    return conv, node.input[:1], node.output
def rebuild_dropout(node, weights, nodes=None):
    """Rebuild an ONNX ``Dropout`` node as element-wise ``nn.Dropout``.

    The ratio comes from the ``ratio`` attribute (older opsets) or from a
    second input found in *weights* (opset >= 12), and defaults to 0.5.
    Only the data input is kept as a dependency; extra inputs (ratio,
    training_mode) would otherwise be passed to the module as arguments.

    Args:
        node: ONNX ``Dropout`` node.
        weights: dict of initializer tensors keyed by name.
        nodes: unused; accepted so ``rebuild_op`` can pass it.

    Returns:
        ``(nn.Dropout, [X], outputs)``.
    """
    ratio = 0.5
    for att in node.attribute:
        if att.name == 'ratio':
            ratio = att.f
    if len(node.input) > 1 and node.input[1] in weights:
        ratio = float(weights[node.input[1]])
    # ONNX Dropout zeroes individual elements; nn.Dropout2d would zero whole
    # channels. NOTE(review): an optional `mask` output is mapped to the same
    # module output, as before.
    return nn.Dropout(p=ratio), node.input[:1], node.output
def rebuild_relu(node, weights, nodes):
    """Map an ONNX ``Relu`` node onto ``nn.ReLU``; *weights* and *nodes* are unused."""
    activation = nn.ReLU()
    inputs, outputs = node.input, node.output
    return activation, inputs, outputs
def rebuild_maxpool(node, weights, nodes=None):
    """Rebuild an ONNX ``MaxPool`` node as ``nn.MaxPool2d``.

    ``pads``, ``strides``, ``kernel_shape``, ``dilations`` and ``ceil_mode``
    are translated. Other attributes (e.g. ``storage_order``, ``auto_pad``)
    are ignored with a warning instead of raising ``KeyError``.

    Args:
        node: ONNX ``MaxPool`` node.
        weights: unused.
        nodes: unused; accepted so ``rebuild_op`` can pass it.

    Returns:
        ``(nn.MaxPool2d, inputs, outputs)``.
    """
    rebuild_maxpool.mp_attr_map = {
        "pads": "padding",
        "strides": "stride",
        "kernel_shape": "kernel_size",
        "dilations": "dilation",
        "ceil_mode": "ceil_mode",
    }
    kwargs = {}
    for att in node.attribute:
        key = rebuild_maxpool.mp_attr_map.get(att.name)
        if key is None:
            warnings.warn("Ignoring unsupported MaxPool attribute {}.".format(att.name))
            continue
        # ceil_mode is an int attribute; the rest are int lists.
        kwargs[key] = bool(att.i) if key == 'ceil_mode' else list(att.ints)
    if 'padding' in kwargs:
        kwargs["padding"] = _slim422(kwargs["padding"])
    mp = nn.MaxPool2d(**kwargs)
    return mp, node.input, node.output
def rebuild_add(node, weights, graph):
    """Rebuild an ONNX ``Add`` node as a function applying ``+`` to its two inputs.

    *weights* and *graph* are unused.
    """
    def add(a, b):
        total = a + b
        return total
    inputs, outputs = node.input, node.output
    return add, inputs, outputs
def rebuild_transpose(node, weights, graph=None):
    """Rebuild an ONNX ``Transpose`` node as a ``permute`` function.

    ``perm`` is looked up by name. When it is absent, all dimensions are
    reversed, as the ONNX spec requires (previously ``attribute[0]`` raised
    ``IndexError``).

    Args:
        node: ONNX ``Transpose`` node.
        weights: unused.
        graph: unused.

    Returns:
        ``(transpose_fn, inputs, outputs)``.
    """
    perm = None
    for att in node.attribute:
        if att.name == 'perm':
            perm = list(att.ints)

    def transpose(x):
        order = perm if perm is not None else list(range(x.dim() - 1, -1, -1))
        x = x.permute(*order)
        return x
    return transpose, node.input, node.output
def rebuild_flatten(node, weights, nodes=None):
    """Rebuild an ONNX ``Flatten`` node.

    ONNX Flatten always produces a 2-D tensor of shape
    ``(prod(shape[:axis]), prod(shape[axis:]))``. The default axis is 1, and
    a negative axis counts from the end. The previous version produced
    ``(d0, ..., d_{axis-1}, -1)``, which is not 2-D for ``axis > 1`` and is
    1-D for ``axis == 0``.

    Args:
        node: ONNX ``Flatten`` node.
        weights: unused.
        nodes: unused; accepted so ``rebuild_op`` can pass it.

    Returns:
        ``(flatten_fn, inputs, outputs)``.
    """
    axis = 1
    for att in node.attribute:
        if att.name == 'axis':
            axis = att.i

    def flatten(x):
        d = axis if axis >= 0 else axis + x.dim()
        outer = 1
        for s in x.shape[:d]:
            outer *= s
        inner = 1
        for s in x.shape[d:]:
            inner *= s
        # reshape (not view) also works for non-contiguous inputs.
        return x.reshape(outer, inner)
    return flatten, node.input, node.output
def rebuild_gemm(node, weights, nodes=None):
    """Rebuild an ONNX ``Gemm`` node (``Y = alpha * A @ B' + beta * C``) as ``nn.Linear``.

    ``transB`` (ONNX default 0), ``alpha`` and ``beta`` are honoured, and the
    bias input ``C`` is optional. The previous version assumed ``transB=1``,
    ``alpha=beta=1`` and a mandatory 1-D bias.

    Args:
        node: ONNX ``Gemm`` node with inputs ``[A, B]`` or ``[A, B, C]``.
        weights: dict of initializer tensors; ``B`` and ``C`` are read here.
        nodes: unused; accepted so ``rebuild_op`` can pass it.

    Returns:
        ``(nn.Linear, [A], outputs)``.

    Raises:
        NotImplementedError: for ``transA=1``, which ``nn.Linear`` cannot express.
    """
    attrs = {att.name: att for att in node.attribute}
    alpha = attrs['alpha'].f if 'alpha' in attrs else 1.0
    beta = attrs['beta'].f if 'beta' in attrs else 1.0
    trans_a = attrs['transA'].i if 'transA' in attrs else 0
    trans_b = attrs['transB'].i if 'transB' in attrs else 0
    if trans_a:
        raise NotImplementedError("Gemm with transA=1 is not supported.")
    weight = weights[node.input[1]]
    if not trans_b:
        # B is (K, N); nn.Linear stores its weight as (N, K).
        weight = weight.t()
    if alpha != 1.0:
        weight = weight * alpha
    weight = weight.contiguous()
    out_feats, in_feats = weight.shape
    with_bias = len(node.input) > 2 and node.input[2] != ''
    linear = nn.Linear(in_features=in_feats, out_features=out_feats, bias=with_bias)
    linear.weight.data = weight
    if with_bias:
        bias = weights[node.input[2]]
        if beta != 1.0:
            bias = bias * beta
        if bias.numel() == 1:
            # A scalar C broadcasts over all output features.
            bias = bias.reshape(1).expand(out_feats).clone()
        else:
            bias = bias.reshape(out_feats)
        linear.bias.data = bias
    return linear, node.input[:1], node.output
def rebuild_matmul(node, weights, nodes=None):
    """Rebuild an ONNX ``MatMul`` node.

    If the second operand is a 2-D initializer ``W``, ``X @ W`` becomes a
    bias-free ``nn.Linear`` whose weight is ``W.T``, so ``W`` is a trainable
    parameter. Otherwise (two activations, or a non-2-D constant) a plain
    ``torch.matmul`` over both inputs is returned.

    No bias is folded in. The ``Add`` node that usually follows stays in the
    graph and applies the bias itself. The previous version copied that
    bias into the Linear as well, so it was added twice. It also crashed when
    no such consumer existed or when a later node had no second input.

    Args:
        node: ONNX ``MatMul`` node with inputs ``[A, B]``.
        weights: dict of initializer tensors keyed by name.
        nodes: unused (kept for call compatibility).

    Returns:
        ``(nn.Linear, [A], outputs)`` or ``(matmul_fn, [A, B], outputs)``.
    """
    weight = weights.get(node.input[1])
    if weight is None or weight.dim() != 2:
        def matmul(a, b):
            return torch.matmul(a, b)
        return matmul, node.input, node.output
    # W is (in_features, out_features); nn.Linear stores (out, in).
    in_feats, out_feats = weight.shape
    linear = nn.Linear(in_features=in_feats, out_features=out_feats, bias=False)
    linear.weight.data = weight.transpose(1, 0)
    return linear, node.input[:1], node.output
def rebuild_pad(node, weights, nodes=None):
    """Rebuild an ONNX ``Pad`` node.

    ONNX ``pads`` lists all begin paddings, then all end paddings:
    ``[x1_begin, x2_begin, ..., x1_end, x2_end, ...]``. The previous version
    misread this layout (it required all begin pads to be 0 and passed the
    end pads of N, C, H, W to ``ConstantPad2d``) and looked attributes up by
    position.

    A 4-D constant pad that leaves N and C untouched becomes
    ``nn.ConstantPad2d((w_begin, w_end, h_begin, h_end), value)``. Any other
    case (other ranks, ``reflect``/``edge`` modes) uses
    ``torch.nn.functional.pad``. Opset >= 11 nodes that pass ``pads`` and
    ``constant_value`` as initializer inputs are also handled.

    Args:
        node: ONNX ``Pad`` node.
        weights: dict of initializer tensors keyed by name.
        nodes: unused; accepted so ``rebuild_op`` can pass it.

    Returns:
        ``(pad_module_or_fn, [X], outputs)``.

    Raises:
        NotImplementedError: for pad modes other than constant/reflect/edge.
    """
    mode = b'constant'
    pads = None
    value = 0.0
    for att in node.attribute:
        if att.name == 'mode':
            mode = att.s
        elif att.name in ('pads', 'paddings'):
            pads = list(att.ints)
        elif att.name == 'value':
            value = att.f
    if pads is None:
        # opset >= 11: pads (and optional constant_value) are inputs.
        pads = [int(p) for p in weights[node.input[1]].reshape(-1).tolist()]
        if len(node.input) > 2 and node.input[2] in weights:
            value = float(weights[node.input[2]])
    if isinstance(mode, bytes):
        mode = mode.decode()
    torch_mode = {'constant': 'constant', 'reflect': 'reflect', 'edge': 'replicate'}.get(mode)
    if torch_mode is None:
        raise NotImplementedError("Pad mode {!r} is not supported.".format(mode))
    rank = len(pads) // 2
    begins, ends = pads[:rank], pads[rank:]
    if (torch_mode == 'constant' and rank == 4
            and not any(begins[:2]) and not any(ends[:2])):
        module = nn.ConstantPad2d((begins[3], ends[3], begins[2], ends[2]), value)
        return module, node.input[:1], node.output
    # F.pad wants (last_begin, last_end, second_last_begin, second_last_end, ...).
    torch_pads = []
    for b, e in zip(reversed(begins), reversed(ends)):
        torch_pads += [b, e]
    # Drop trailing all-zero pairs (leading dims): reflect/replicate only
    # accept padding for the trailing spatial dims.
    while len(torch_pads) > 2 and torch_pads[-2:] == [0, 0]:
        del torch_pads[-2:]
    torch_pads = tuple(torch_pads)

    def pad(x):
        if torch_mode == 'constant':
            return nn.functional.pad(x, torch_pads, mode='constant', value=value)
        return nn.functional.pad(x, torch_pads, mode=torch_mode)
    return pad, node.input[:1], node.output
def rebuild_constant(node, weights, nodes=None):
    """Rebuild an ONNX ``Constant`` node as a zero-argument function.

    A ``value`` tensor is decoded with ``unpack_weights``. That covers both
    ``raw_data`` and typed-field payloads and reshapes the result to the
    tensor's ``dims``; a scalar (empty ``dims``) becomes a 0-d tensor. The
    previous version ignored ``dims`` and typed fields, and turned any
    1-element tensor into a 0-d scalar. The ``value_float(s)`` and
    ``value_int(s)`` attribute forms are also supported.

    Args:
        node: ONNX ``Constant`` node.
        weights: unused.
        nodes: unused.

    Returns:
        ``(constant_fn, [], outputs)``; every call returns a fresh copy so
        in-place use downstream cannot corrupt the constant.

    Raises:
        NotImplementedError: for unsupported attribute kinds or data types.
    """
    att = node.attribute[0]
    if att.name == 'value':
        tensor_proto = att.t
        decoded = unpack_weights([tensor_proto])
        if not decoded:
            raise NotImplementedError(
                "Constant of ONNX data type {} is not supported.".format(tensor_proto.data_type))
        value = next(iter(decoded.values()))
        if len(tensor_proto.dims) == 0:
            value = value.reshape(())
    elif att.name == 'value_float':
        value = torch.tensor(att.f)
    elif att.name == 'value_floats':
        value = torch.tensor(list(att.floats))
    elif att.name == 'value_int':
        value = torch.tensor(att.i)
    elif att.name == 'value_ints':
        value = torch.tensor(list(att.ints))
    else:
        raise NotImplementedError("Constant attribute {} is not supported.".format(att.name))

    def constant():
        return value.clone()
    return constant, [], node.output
def rebuild_sum(node, weights, nodes=None):
    """Rebuild an ONNX ``Sum`` node as a variadic element-wise addition.

    The addition is out-of-place. The previous ``ret += i`` mutated
    ``inputs[0]``, which may be an initializer shared through
    ``DependencyModule.weights`` or another node's output. It also failed
    when a later input broadcast to a larger shape.

    Args:
        node: ONNX ``Sum`` node.
        weights: unused.
        nodes: unused; accepted so ``rebuild_op`` can pass it.

    Returns:
        ``(sum_fn, inputs, outputs)``.
    """
    def sum_inputs(*inputs):
        ret = inputs[0]
        for i in inputs[1:]:
            ret = ret + i
        return ret
    return sum_inputs, node.input, node.output
def rebuild_shape(node, weights, nodes=None):
    """Rebuild an ONNX ``Shape`` node as a function returning ``x.shape`` as an int64 tensor.

    ``dtype`` is fixed to ``torch.long``, so a 0-d input yields an empty
    int64 tensor rather than an empty float tensor.

    Args:
        node: ONNX ``Shape`` node.
        weights: unused.
        nodes: unused; accepted so ``rebuild_op`` can pass it.

    Returns:
        ``(shape_fn, inputs, outputs)``.
    """
    def shape(x):
        return torch.tensor(list(x.shape), dtype=torch.long)
    return shape, node.input, node.output
def rebuild_gather(node, weights, nodes=None):
    """Rebuild an ONNX ``Gather`` node with ONNX (not ``torch.gather``) semantics.

    ONNX Gather selects whole slices along ``axis`` (default 0), giving
    output shape ``data.shape[:axis] + indices.shape + data.shape[axis+1:]``.
    Negative indices count from the end. ``torch.gather``, used previously,
    implements ONNX *GatherElements* instead.

    Args:
        node: ONNX ``Gather`` node with inputs ``[data, indices]``.
        weights: unused.
        nodes: unused; accepted so ``rebuild_op`` can pass it.

    Returns:
        ``(gather_fn, inputs, outputs)``.
    """
    axis = 0
    for att in node.attribute:
        if att.name == 'axis':
            axis = att.i

    def gather(x, idx):
        dim = axis if axis >= 0 else axis + x.dim()
        idx = torch.as_tensor(idx, device=x.device).long()
        idx = torch.where(idx < 0, idx + x.shape[dim], idx)
        picked = torch.index_select(x, dim, idx.reshape(-1))
        return picked.reshape(tuple(x.shape[:dim]) + tuple(idx.shape) + tuple(x.shape[dim + 1:]))
    return gather, node.input, node.output
def _nd_unsqueeze(x, dims):
dims = sorted(dims)
for d in dims:
x = torch.unsqueeze(x, dim=d)
return x
def rebuild_unsqueeze(node, weights, nodes=None):
    """Rebuild an ONNX ``Unsqueeze`` node.

    Older opsets carry ``axes`` as an attribute, and the returned function
    takes only the data tensor. When no ``axes`` attribute is present
    (opset >= 13), ``axes`` is the node's second input, and the returned
    function takes ``(x, axes_tensor)``.

    Args:
        node: ONNX ``Unsqueeze`` node.
        weights: unused.
        nodes: unused; accepted so ``rebuild_op`` can pass it.

    Returns:
        ``(unsqueeze_fn, inputs, outputs)``.
    """
    axes = None
    for att in node.attribute:
        if att.name == 'axes':
            axes = list(att.ints)
    if axes is not None:
        def unsqueeze(x):
            return _nd_unsqueeze(x, axes)
        return unsqueeze, node.input, node.output

    def unsqueeze_dynamic(x, axes_tensor):
        return _nd_unsqueeze(x, [int(a) for a in axes_tensor.reshape(-1).tolist()])
    return unsqueeze_dynamic, node.input, node.output
def rebuild_mul(node, weights, nodes=None):
    """Rebuild an ONNX ``Mul`` node as a function applying ``*`` to its two inputs.

    Args:
        node: ONNX ``Mul`` node.
        weights: unused.
        nodes: unused. Optional, because ``rebuild_op`` passes it; the old
            two-argument signature made that call raise ``TypeError``.

    Returns:
        ``(mul_fn, inputs, outputs)``.
    """
    def mul(a, b):
        return a * b
    return mul, node.input, node.output
def rebuild_softmax(node, weights, nodes=None):
    """Rebuild an ONNX ``Softmax`` node.

    The node's ``axis`` attribute is honoured; when absent, the previous
    hard-coded ``dim=1`` is kept. NOTE(review): the ONNX default is 1 before
    opset 13 and -1 from opset 13; the opset is not visible here. Softmax is
    computed in float64 for stability and cast back to the input dtype
    (previously it was always cast to float32).

    Args:
        node: ONNX ``Softmax`` node.
        weights: unused.
        nodes: unused; accepted so ``rebuild_op`` can pass it.

    Returns:
        ``(softmax_fn, inputs, outputs)``.
    """
    axis = 1
    for att in node.attribute:
        if att.name == 'axis':
            axis = att.i

    def f_softmax(x):
        return x.softmax(dim=axis, dtype=torch.double).to(x.dtype)
    return f_softmax, node.input, node.output
def rebuild_reshape(node, weights, nodes):
    """Rebuild an ONNX ``Reshape`` node as a function of ``(data, shape_tensor)``.

    A 0 entry in the ONNX target shape copies the corresponding dimension of
    the input; all other entries (including -1) are passed to
    ``torch.reshape`` unchanged.
    """
    def reshape(x, s):
        src_dims = x.shape
        target = [src_dims[pos] if dim == 0 else dim
                  for pos, dim in enumerate(s.tolist())]
        return torch.reshape(x, target)
    inputs, outputs = node.input, node.output
    return reshape, inputs, outputs
def rebuild_op(node, weights, nodes):
    """Dispatch *node* to the matching ``rebuild_<op_type>`` function.

    Rebuilders are found among this module's globals by lower-cased ONNX op
    type. Some take ``(node, weights)`` and others ``(node, weights, nodes)``.
    ``nodes`` is passed only when the rebuilder's signature accepts it;
    passing it unconditionally raised ``TypeError`` for the two-argument ones.

    Args:
        node: ONNX node to rebuild.
        weights: dict of initializer tensors keyed by name.
        nodes: all nodes of the graph, for rebuilders that inspect neighbours.

    Returns:
        the rebuilder's ``(op, input_names, output_names)`` triple.

    Raises:
        KeyError: if no rebuilder exists for ``node.op_type``.
    """
    import inspect

    op_type = node.op_type
    try:
        rebuilder = globals()['rebuild_' + op_type.lower()]
    except KeyError:
        raise KeyError("Unsupported ONNX op type {!r} (no rebuild_{} function)."
                       .format(op_type, op_type.lower())) from None
    try:
        inspect.signature(rebuilder).bind(node, weights, nodes)
    except TypeError:
        return rebuilder(node, weights)
    return rebuilder(node, weights, nodes)
def construct_pytorch_nodes(graph, weights):
    """Rebuild every node of an ONNX *graph*, in graph order.

    Returns:
        list of ``(op, input_names, output_names)`` triples, one per node, as
        produced by ``rebuild_op``. The full node list is handed to each
        rebuild so rebuilders can look at neighbouring nodes.
    """
    all_nodes = graph.node
    return [rebuild_op(onnx_node, weights, all_nodes) for onnx_node in all_nodes]
def resolve_deps(name, deps, inter_tensors):
    """Compute tensor *name* (and, transitively, its inputs) into *inter_tensors*.

    Uses an explicit stack instead of recursion, so very deep graphs no
    longer hit Python's recursion limit. Inputs are evaluated in the same
    left-to-right, depth-first order as before. Each tensor is computed at
    most once; tensors already in *inter_tensors* are reused.

    Args:
        name: tensor name to compute.
        deps: dict mapping a tensor name to ``(op, input_names)``.
        inter_tensors: dict of already-available tensors; updated in place.

    Raises:
        KeyError: if a needed tensor is neither available nor produced by
            any entry in *deps*.
        RuntimeError: if the dependencies contain a cycle.
    """
    stack = [name]
    expanding = set()  # tensors whose inputs have been pushed but not yet computed
    while stack:
        current = stack[-1]
        if current in inter_tensors:
            stack.pop()
            continue
        try:
            op, deps_names = deps[current]
        except KeyError:
            raise KeyError("No producer found for tensor {!r}.".format(current)) from None
        missing = [d for d in deps_names if d not in inter_tensors]
        if missing:
            # Revisiting an expanded tensor with inputs still missing means
            # one of its inputs depends on it.
            if current in expanding:
                raise RuntimeError("Dependency cycle detected at tensor {!r}.".format(current))
            expanding.add(current)
            # Reversed so the first input is on top and is resolved first.
            stack.extend(reversed(missing))
            continue
        inter_tensors[current] = op(*[inter_tensors[d] for d in deps_names])
        expanding.discard(current)
        stack.pop()
class DependencyModule(nn.Module):
    """Run an ONNX graph by lazily evaluating rebuilt nodes on demand.

    Each ONNX node is rebuilt into a callable or ``nn.Module``. Modules are
    registered in ``self.layers`` so their parameters are visible to
    PyTorch. ``forward`` seeds a tensor table with the initializers and the
    input, then resolves the graph output through ``resolve_deps``.

    Args:
        onnx_model: loaded ONNX model.
        input_name: graph-input name fed by ``forward``; defaults to the first
            entry of ``onnx_model.graph.input``.
    """
    # TODO: add layers and convert everything to a dict.

    def __init__(self, onnx_model, input_name=None):
        super().__init__()
        # output tensor name -> (op, list of input tensor names)
        self.deps = {}
        # Tensor table of the most recent forward (kept for inspection).
        self.inter_tensors = dict()
        self.weights = unpack_weights(onnx_model.graph.initializer)
        nodes = construct_pytorch_nodes(onnx_model.graph, self.weights)
        # TODO: ModuleDict with proper names.
        layers = OrderedDict()
        for idx, (node, inputs, outputs) in enumerate(nodes):
            if isinstance(node, nn.Module):
                layers['layer' + str(idx)] = node
            for output_name in outputs:
                self.deps[output_name] = (node, inputs)
        # Initializers consumed directly as op inputs (not baked into a module).
        # They live in a plain dict, so `module.to(device)` does not move them.
        self._direct_weight_names = sorted(
            {n for _, ins in self.deps.values() for n in ins if n in self.weights})
        # TODO: only the first graph input and output are supported.
        self.input_name = onnx_model.graph.input[0].name
        self.output_name = onnx_model.graph.output[0].name
        if input_name is not None:
            self.input_name = input_name
        self.layers = nn.ModuleDict(layers)

    def forward(self, input):
        """Evaluate the graph on *input* and return its first output."""
        if torch.is_tensor(input):
            # Keep directly-consumed initializers on the input's device;
            # otherwise e.g. a MatMul->Add bias stays on CPU after `.cuda()`.
            # The moved copy is cached, so this happens once per device change.
            for name in self._direct_weight_names:
                tensor = self.weights[name]
                if tensor.device != input.device:
                    self.weights[name] = tensor.to(input.device)
        self.inter_tensors = self.weights.copy()
        self.inter_tensors[self.input_name] = input
        resolve_deps(self.output_name, self.deps, self.inter_tensors)
        return self.inter_tensors[self.output_name]
def convert(onnx_path, input_name=None):
    """Load the ONNX model at *onnx_path* with ``onnx.load`` and wrap it in a ``DependencyModule``.

    Args:
        onnx_path: model location, passed unchanged to ``onnx.load``.
        input_name: optional graph-input name for ``forward`` to feed;
            defaults to the model's first graph input.

    Returns:
        the reconstructed ``DependencyModule``.
    """
    onnx_model = onnx.load(onnx_path)
    reconstruct_model = DependencyModule(onnx_model, input_name=input_name)
    return reconstruct_model