Skip to content

Commit fbcc5db

Browse files
added typehints to midimessage and init
1 parent c4136b9 commit fbcc5db

File tree

2 files changed

+66
-32
lines changed

2 files changed

+66
-32
lines changed

adafruit_midi/__init__.py

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,10 @@
2525
https://github.com/adafruit/circuitpython/releases
2626
2727
"""
28+
try:
29+
from typing import Union, Tuple, Any, List, Optional, Dict
30+
except ImportError:
31+
pass
2832

2933
from .midi_message import MIDIMessage
3034

@@ -54,13 +58,13 @@ class MIDI:
5458

5559
def __init__(
5660
self,
57-
midi_in=None,
58-
midi_out=None,
61+
midi_in: Optional[Any] = None,
62+
midi_out: Optional[Any] = None,
5963
*,
60-
in_channel=None,
61-
out_channel=0,
62-
in_buf_size=30,
63-
debug=False
64+
in_channel: Optional[Union[int, Tuple[int, ...]]] = None,
65+
out_channel: int = 0,
66+
in_buf_size: int = 30,
67+
debug: bool = False
6468
):
6569
if midi_in is None and midi_out is None:
6670
raise ValueError("No midi_in or midi_out provided")
@@ -78,7 +82,7 @@ def __init__(
7882
self._skipped_bytes = 0
7983

8084
@property
81-
def in_channel(self):
85+
def in_channel(self) -> Optional[Union[int, Tuple[int, ...]]]:
8286
"""The incoming MIDI channel. Must be 0-15. Correlates to MIDI channels 1-16, e.g.
8387
``in_channel = 3`` will listen on MIDI channel 4.
8488
Can also listen on multiple channels, e.g. ``in_channel = (0,1,2)``
@@ -87,7 +91,7 @@ def in_channel(self):
8791
return self._in_channel
8892

8993
@in_channel.setter
90-
def in_channel(self, channel):
94+
def in_channel(self, channel: Optional[Union[str, int, Tuple[int, ...]]]) -> None:
9195
if channel is None or channel == "ALL":
9296
self._in_channel = tuple(range(16))
9397
elif isinstance(channel, int) and 0 <= channel <= 15:
@@ -98,18 +102,19 @@ def in_channel(self, channel):
98102
raise RuntimeError("Invalid input channel")
99103

100104
@property
101-
def out_channel(self):
105+
def out_channel(self) -> int:
102106
"""The outgoing MIDI channel. Must be 0-15. Correlates to MIDI channels 1-16, e.g.
103107
``out_channel = 3`` will send to MIDI channel 4. Default is 0 (MIDI channel 1)."""
104108
return self._out_channel
105109

106110
@out_channel.setter
107-
def out_channel(self, channel):
111+
def out_channel(self, channel: Optional[int]) -> None:
112+
assert channel is not None
108113
if not 0 <= channel <= 15:
109114
raise RuntimeError("Invalid output channel")
110115
self._out_channel = channel
111116

112-
def receive(self):
117+
def receive(self) -> Optional[MIDIMessage]:
113118
"""Read messages from MIDI port, store them in internal read buffer, then parse that data
114119
and return the first MIDI message (event).
115120
This maintains the blocking characteristics of the midi_in port.
@@ -120,6 +125,7 @@ def receive(self):
120125
# If the buffer here is not full then read as much as we can fit from
121126
# the input port
122127
if len(self._in_buf) < self._in_buf_size:
128+
assert self._midi_in is not None
123129
bytes_in = self._midi_in.read(self._in_buf_size - len(self._in_buf))
124130
if bytes_in:
125131
if self._debug:
@@ -140,7 +146,7 @@ def receive(self):
140146
# msg could still be None at this point, e.g. in middle of monster SysEx
141147
return msg
142148

143-
def send(self, msg, channel=None):
149+
def send(self, msg: MIDIMessage, channel: Optional[int] = None) -> None:
144150
"""Sends a MIDI message.
145151
146152
:param msg: Either a MIDIMessage object or a sequence (list) of MIDIMessage objects.
@@ -164,7 +170,8 @@ def send(self, msg, channel=None):
164170

165171
self._send(data, len(data))
166172

167-
def _send(self, packet, num):
173+
def _send(self, packet: bytes, num: int) -> None:
168174
if self._debug:
169175
print("Sending: ", [hex(i) for i in packet[:num]])
176+
assert self._midi_out is not None
170177
self._midi_out.write(packet, num)

adafruit_midi/midi_message.py

Lines changed: 46 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,19 @@
2525
__version__ = "0.0.0+auto.0"
2626
__repo__ = "https://github.com/adafruit/Adafruit_CircuitPython_MIDI.git"
2727

28+
try:
29+
from typing import Union, Tuple, Any, List, Optional, Dict
30+
except ImportError:
31+
pass
32+
2833
# From C3 - A and B are above G
2934
# Semitones A B C D E F G
3035
NOTE_OFFSET = [21, 23, 12, 14, 16, 17, 19]
3136

3237

33-
def channel_filter(channel, channel_spec):
38+
def channel_filter(
39+
channel: int, channel_spec: Optional[Union[int, Tuple[int, ...]]]
40+
) -> bool:
3441
"""
3542
Utility function to return True iff the given channel matches channel_spec.
3643
"""
@@ -41,13 +48,12 @@ def channel_filter(channel, channel_spec):
4148
raise ValueError("Incorrect type for channel_spec" + str(type(channel_spec)))
4249

4350

44-
def note_parser(note):
51+
def note_parser(note: Union[int, str]) -> int:
4552
"""If note is a string then it will be parsed and converted to a MIDI note (key) number, e.g.
4653
"C4" will return 60, "C#4" will return 61. If note is not a string it will simply be returned.
4754
4855
:param note: Either 0-127 int or a str representing the note, e.g. "C#4"
4956
"""
50-
midi_note = note
5157
if isinstance(note, str):
5258
if len(note) < 2:
5359
raise ValueError("Bad note format")
@@ -61,7 +67,8 @@ def note_parser(note):
6167
sharpen = -1
6268
# int may throw exception here
6369
midi_note = int(note[1 + abs(sharpen) :]) * 12 + NOTE_OFFSET[noteidx] + sharpen
64-
70+
elif isinstance(note, int):
71+
midi_note = note
6572
return midi_note
6673

6774

@@ -82,57 +89,70 @@ class MIDIMessage:
8289
This is an *abstract* class.
8390
"""
8491

85-
_STATUS = None
92+
_STATUS: Optional[int] = None
8693
_STATUSMASK = None
87-
LENGTH = None
94+
LENGTH: Optional[int] = None
8895
CHANNELMASK = 0x0F
8996
ENDSTATUS = None
9097

9198
# Commonly used exceptions to save memory
9299
@staticmethod
93-
def _raise_valueerror_oor():
100+
def _raise_valueerror_oor() -> None:
94101
raise ValueError("Out of range")
95102

96103
# Each element is ((status, mask), class)
97104
# order is more specific masks first
98-
_statusandmask_to_class = []
105+
# Add better type hints for status, mask, class referenced above
106+
_statusandmask_to_class: List[
107+
Tuple[Tuple[Optional[bytes], Optional[int]], "MIDIMessage"]
108+
] = []
99109

100-
def __init__(self, *, channel=None):
110+
def __init__(self, *, channel: Optional[int] = None) -> None:
101111
self._channel = channel # dealing with pylint inadequacy
102112
self.channel = channel
103113

104114
@property
105-
def channel(self):
115+
def channel(self) -> Optional[int]:
106116
"""The channel number of the MIDI message where appropriate.
107117
This is *updated* by MIDI.send() method.
108118
"""
109119
return self._channel
110120

111121
@channel.setter
112-
def channel(self, channel):
122+
def channel(self, channel: int) -> None:
113123
if channel is not None and not 0 <= channel <= 15:
114124
raise ValueError("Channel must be 0-15 or None")
115125
self._channel = channel
116126

117127
@classmethod
118-
def register_message_type(cls):
128+
def register_message_type(cls) -> None:
119129
"""Register a new message by its status value and mask.
120130
This is called automagically at ``import`` time for each message.
121131
"""
122132
### These must be inserted with more specific masks first
123133
insert_idx = len(MIDIMessage._statusandmask_to_class)
124134
for idx, m_type in enumerate(MIDIMessage._statusandmask_to_class):
135+
assert cls._STATUSMASK is not None
125136
if cls._STATUSMASK > m_type[0][1]:
126137
insert_idx = idx
127138
break
128139

140+
assert cls._STATUS is not None
141+
assert cls._STATUSMASK is not None
129142
MIDIMessage._statusandmask_to_class.insert(
130143
insert_idx, ((cls._STATUS, cls._STATUSMASK), cls)
131144
)
132145

133146
# pylint: disable=too-many-arguments
134147
@classmethod
135-
def _search_eom_status(cls, buf, eom_status, msgstartidx, msgendidxplusone, endidx):
148+
def _search_eom_status(
149+
cls,
150+
buf: Dict[int, bool],
151+
eom_status: bool,
152+
msgstartidx: int,
153+
msgendidxplusone: int,
154+
endidx: int,
155+
) -> Tuple[int, bool, bool]:
136156
good_termination = False
137157
bad_termination = False
138158

@@ -155,14 +175,17 @@ def _search_eom_status(cls, buf, eom_status, msgstartidx, msgendidxplusone, endi
155175
return (msgendidxplusone, good_termination, bad_termination)
156176

157177
@classmethod
158-
def _match_message_status(cls, buf, msgstartidx, msgendidxplusone, endidx):
178+
def _match_message_status(
179+
cls, buf: bytearray, msgstartidx: int, msgendidxplusone: int, endidx: int
180+
) -> Tuple[Optional[Any], bool, bool, bool, bool, int]:
159181
msgclass = None
160182
status = buf[msgstartidx]
161183
known_msg = False
162184
complete_msg = False
163185
bad_termination = False
164186

165187
# Rummage through our list looking for a status match
188+
assert msgclass is not None
166189
for status_mask, msgclass in MIDIMessage._statusandmask_to_class:
167190
masked_status = status & status_mask[1]
168191
if status_mask[0] == masked_status:
@@ -198,7 +221,9 @@ def _match_message_status(cls, buf, msgstartidx, msgendidxplusone, endidx):
198221

199222
# pylint: disable=too-many-locals,too-many-branches
200223
@classmethod
201-
def from_message_bytes(cls, midibytes, channel_in):
224+
def from_message_bytes(
225+
cls, midibytes: bytearray, channel_in: Optional[Union[int, Tuple[int, ...]]]
226+
) -> Tuple[Optional["MIDIMessage"], int, int]:
202227
"""Create an appropriate object of the correct class for the
203228
first message found in some MIDI bytes filtered by channel_in.
204229
@@ -240,6 +265,7 @@ def from_message_bytes(cls, midibytes, channel_in):
240265
channel_match_orna = True
241266
if complete_message and not bad_termination:
242267
try:
268+
assert msgclass is not None
243269
msg = msgclass.from_bytes(midibytes[msgstartidx:msgendidxplusone])
244270
if msg.channel is not None:
245271
channel_match_orna = channel_filter(msg.channel, channel_in)
@@ -270,17 +296,18 @@ def from_message_bytes(cls, midibytes, channel_in):
270296

271297
# A default method for constructing wire messages with no data.
272298
# Returns an (immutable) bytes with just the status code in.
273-
def __bytes__(self):
299+
def __bytes__(self) -> bytes:
274300
"""Return the ``bytes`` wire protocol representation of the object
275301
with channel number applied where appropriate."""
302+
assert self._STATUS is not None
276303
return bytes([self._STATUS])
277304

278305
# databytes value present to keep interface uniform but unused
279306
# A default method for constructing message objects with no data.
280307
# Returns the new object.
281308
# pylint: disable=unused-argument
282309
@classmethod
283-
def from_bytes(cls, msg_bytes):
310+
def from_bytes(cls, msg_bytes: bytes) -> MIDIMessage:
284311
"""Creates an object from the byte stream of the wire protocol
285312
representation of the MIDI message."""
286313
return cls()
@@ -298,7 +325,7 @@ class MIDIUnknownEvent(MIDIMessage):
298325

299326
LENGTH = -1
300327

301-
def __init__(self, status):
328+
def __init__(self, status: int):
302329
self.status = status
303330
super().__init__()
304331

@@ -316,7 +343,7 @@ class MIDIBadEvent(MIDIMessage):
316343

317344
LENGTH = -1
318345

319-
def __init__(self, msg_bytes, exception):
346+
def __init__(self, msg_bytes: bytearray, exception: Exception):
320347
self.data = bytes(msg_bytes)
321348
self.exception_text = repr(exception)
322349
super().__init__()

0 commit comments

Comments
 (0)