-
Notifications
You must be signed in to change notification settings - Fork 11
Expand file tree
/
Copy pathanim.py
More file actions
152 lines (122 loc) · 5.55 KB
/
anim.py
File metadata and controls
152 lines (122 loc) · 5.55 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
"""
Copyright (c) 2022, InterDigital R&D France. All rights reserved. This source
code is made available under the license found in the LICENSE.txt at the root
directory of the repository.
"""
# Project
import util
# Misc
import torch
class Topology:
    """
    Immutable description of a skeleton's joint graph.

    The topology is built from an ordered ``hierarchy`` of ``(joint, parents)``
    pairs, where ``joint`` is a joint name and ``parents`` an iterable of parent
    joint names (empty for root joints). Joint indices follow the order of
    ``hierarchy``. Only the first parent of each joint defines the chain returned
    by :meth:`hierarchy`; every parent link contributes one entry to :meth:`bones`.
    """

    def __init__(self, hierarchy: list):
        """
        Build the topology and precompute its lookup tables.

        :param hierarchy: ordered list of ``(joint, parents)`` pairs.
        :raises KeyError: if a parent name is not itself listed as a joint.
        :raises ValueError: if the parent graph contains a cycle (including joints
            that cannot be reached from any root).
        """
        from collections import deque

        # joints: name <-> index lookups; indices follow the input order
        self._joints = tuple(joint for joint, _ in hierarchy)
        self._joints_idxs = {joint: index for index, joint in enumerate(self._joints)}
        # parents, by name and by index
        self._parents = tuple(tuple(parents) for _, parents in hierarchy)
        self._parents_idxs = tuple(tuple(self.index(parent) for parent in parents) for parents in self._parents)
        # children: inverse of the parent relation
        children = tuple(set() for _ in self._joints)
        for joint, parents_idxs in zip(self._joints, self._parents_idxs):
            for pidx in parents_idxs:
                children[pidx].add(joint)
        self._children = tuple(frozenset(kids) for kids in children)
        self._children_idxs = tuple(frozenset(self.index(child) for child in kids) for kids in self._children)
        # roots: joints without any parent
        self._roots = tuple(self[jidx] for jidx, parents in enumerate(self._parents) if len(parents) == 0)
        self._roots_idxs = tuple(self.index(root) for root in self._roots)
        # depths: every path length from a root to each joint (a joint with several
        # parents can be reached at several depths)
        depths = [set() for _ in self._joints]
        pending = deque((ridx, 0) for ridx in self._roots_idxs)
        while pending:
            jidx, depth = pending.popleft()
            # A root-to-joint path without repetition visits at most njoints joints,
            # so a longer one necessarily loops: fail instead of iterating forever.
            if depth >= self.njoints:
                raise ValueError("Topology contains a cycle through joint '{}'.".format(self._joints[jidx]))
            depths[jidx].add(depth)
            pending.extend((cidx, depth + 1) for cidx in self._children_idxs[jidx])
        # A non-root joint that no root reaches must sit on (or below) a cycle.
        unreachable = [joint for joint, jdepths in zip(self._joints, depths) if not jdepths]
        if unreachable:
            raise ValueError("Joints not reachable from any root (cyclic hierarchy): {}.".format(unreachable))
        self._depths = tuple(frozenset(jdepths) for jdepths in depths)
        # hierarchy, i.e. list of (joint, first parent) pairs sorted by the parent's
        # maximal depth, so every parent's own entry precedes its children's entries
        first_parent_pairs = [(joint, parents[0]) for joint, parents in zip(self._joints, self._parents) if parents]
        self._hierarchy = sorted(first_parent_pairs, key=lambda pair: max(self._depths[self.index(pair[1])]))
        self._hierarchy_idxs = [(self.index(j), self.index(p)) for j, p in self._hierarchy]
        # bones: one (joint, parent) pair per parent link
        self._bones = [(joint, parent) for joint, parents in zip(self._joints, self._parents) for parent in parents]
        self._bones_idxs = [(self.index(joint), self.index(parent)) for joint, parent in self._bones]

    def __len__(self) -> int:
        """Return the number of joints."""
        return self.njoints

    def __iter__(self):
        """Iterate over joint names in index order."""
        return iter(self._joints)

    @property
    def njoints(self) -> int:
        """Number of joints."""
        return len(self._joints)

    @property
    def nbones(self) -> int:
        """Number of (joint, parent) links, counting every parent of every joint."""
        return len(self._bones)

    def __contains__(self, joint: str) -> bool:
        """Return whether ``joint`` is a joint name of this topology."""
        return joint in self._joints

    def __getitem__(self, index: int) -> str:
        """Return the name of the joint at ``index``."""
        return self._joints[index]

    def isroot(self, joint: str) -> bool:
        """Return whether ``joint`` (name or index) has no parent."""
        return self.parent(joint, index=False) is None

    def isequiv(self, other) -> bool:
        """
        Return whether ``other`` has the same structure as this topology, ignoring names.

        Joints are matched by index: the topologies are equivalent when they have
        the same number of joints and each joint has the same parent indices.
        """
        assert isinstance(other, Topology), "Expected {} but got {}.".format(Topology, type(other))
        # Renaming every joint to its index and comparing reduces to comparing the
        # parent-index tuples (tuple equality also checks the joint count).
        return self._parents_idxs == other._parents_idxs

    def issubset(self, other) -> bool:
        """Return whether every joint name of this topology also exists in ``other``."""
        assert isinstance(other, Topology), "Expected {} but got {}.".format(Topology, type(other))
        return all(joint in other for joint in self)

    def index(self, joint: str) -> int:
        """Return the index of ``joint``; raises ``KeyError`` for unknown names."""
        return self._joints_idxs[joint]

    def joints(self):
        """Return the tuple of joint names in index order."""
        return self._joints

    def parent(self, joint, index=None):
        """
        Return the first parent of ``joint`` (given by name or index).

        :param index: if true, return the parent's index (``-1`` for a root);
            otherwise its name (``None`` for a root). Defaults to returning an
            index exactly when ``joint`` is an ``int``.
        """
        if index is None:
            index = isinstance(joint, int)
        if not isinstance(joint, int):
            joint = self.index(joint)
        parents = self._parents_idxs[joint] if index else self._parents[joint]
        return (-1 if index else None) if len(parents) == 0 else parents[0]

    def parents(self, indices=False):
        """Return all parents of every joint, as names or as indices."""
        return self._parents_idxs if indices else self._parents

    def roots(self, indices=False):
        """Return the root joints, as names or as indices."""
        return self._roots_idxs if indices else self._roots

    def hierarchy(self, indices=False):
        """Return a copy of the (joint, first parent) pairs, parents' entries first."""
        return self._hierarchy_idxs.copy() if indices else self._hierarchy.copy()

    def bones(self, indices=False):
        """Return a copy of all (joint, parent) links, as names or as indices."""
        return self._bones_idxs.copy() if indices else self._bones.copy()

    def lutable(self, other) -> bool:
        """
        Return whether every joint of ``other`` also exists in this topology.

        :raises TypeError: if ``other`` is not a :class:`Topology`.
        """
        if not isinstance(other, Topology):
            raise TypeError("Expected type '{}' but got '{}'.".format(Topology, type(other)))
        return set(other._joints).issubset(self._joints)

    def lut(self, other) -> list:
        """
        Return, for each joint of ``other`` (in its order), that joint's index here.

        :raises ValueError: if ``other`` has joints missing from this topology.
        """
        if not self.lutable(other):
            raise ValueError("Target topology must be a subset.")
        return [self.index(joint) for joint in other._joints]

    def mirrored(self):
        """
        Return a new topology with the substrings "left" and "right" swapped in
        every joint and parent name (case-sensitive); joint order is preserved.
        """
        def mirrored(joint: str):
            # swap through a placeholder so the second replacement does not undo the first
            return joint.replace("left", "<placeholder>").replace("right", "left").replace("<placeholder>", "right")
        return Topology([(mirrored(j), [mirrored(p) for p in parents]) for j, parents in zip(self._joints, self._parents)])
class Positions:
    """
    Helpers converting joint positions between global coordinates and offsets
    relative to each joint's first parent (as listed by ``Topology.hierarchy``).
    """

    @staticmethod
    def to_local(positions, topology: Topology):
        """
        Express joint positions relative to their first parent.

        :param positions: tensor of global joint positions, ``[...] x J x 3``.
        :param topology: joint topology; root joints are left unchanged.
        :return: a new tensor of the same shape holding parent-relative offsets.
        """
        output = positions.clone()  # [...] x J x 3
        # Explicit long dtype + reshape(-1, 2) keep a valid (K, 2) index table even
        # when the hierarchy is empty (root-only topology); torch.as_tensor([]) would
        # otherwise be a float tensor of shape (0,) and unbind(1) would fail.
        pairs = torch.as_tensor(topology.hierarchy(indices=True), dtype=torch.long).reshape(-1, 2)
        jidxs, pidxs = pairs.unbind(1)
        # Advanced indexing gathers the parents' values before the write-back, so
        # every joint is offset by its parent's *global* position.
        output[..., jidxs, :] -= output[..., pidxs, :]
        return output  # [...] x J x 3

    @staticmethod
    def to_global(positions, topology: Topology):
        """
        Inverse of ``to_local``: accumulate parent-relative offsets into global positions.

        :param positions: tensor of parent-relative offsets, ``[...] x J x 3``.
        :param topology: joint topology whose ``hierarchy`` lists each parent's
            entry before its children's, so a parent is already global when it is
            added to its children.
        :return: a new tensor of the same shape holding global positions.
        """
        output = positions.clone()  # [...] x J x 3
        for j, p in topology.hierarchy(indices=True):
            output[..., j, :] += output[..., p, :]  # [...] x 3
        return output  # [...] x J x 3
def FK(angles: torch.Tensor, skeleton: torch.Tensor, trajectory: torch.Tensor, topology: Topology) -> torch.Tensor:
    """
    Compute global joint positions (i.e. forward kinematics) from MVNX data representation.

    :param angles: joint rotations, converted with ``util.SU2.to_SO3`` and then made
        parent-relative with ``util.SO3.to_local`` (presumably unit quaternions — TODO confirm).
    :param skeleton: joint positions turned into parent-relative offsets with
        ``Positions.to_local`` (assumed ``[...] x J x 3`` global rest positions — confirm).
    :param trajectory: optional offset added to every output position; ``None`` skips it.
    :param topology: joint topology whose ``hierarchy`` lists parents before children.
    :return: global joint positions extracted with ``util.HMat.tvec``.
    """
    local_rotations = util.SO3.to_local(util.SU2.to_SO3(angles), topology)
    local_offsets = Positions.to_local(skeleton, topology)
    transforms = util.HMat.join(local_rotations, local_offsets)
    # Walk parents before children so each parent transform is already global
    # when it is composed into its child; clones avoid aliasing the views being written.
    for jidx, pidx in topology.hierarchy(indices=True):
        transforms[..., jidx, :, :] = util.HMat.compose(
            transforms[..., jidx, :, :].clone(),
            transforms[..., pidx, :, :].clone(),
        )
    positions = util.HMat.tvec(transforms)
    if trajectory is not None:
        positions = positions + trajectory
    return positions