Skip to content

Commit 70b6ae4

Browse files
committed
WIP: use context in transport/UVHandle directly
1 parent 7975b7b commit 70b6ae4

File tree

9 files changed

+101
-17
lines changed

9 files changed

+101
-17
lines changed

tests/test_context.py

Lines changed: 77 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,26 @@
33
import decimal
44
import random
55
import socket
6+
import unittest
67
import weakref
78

89
from uvloop import _testbase as tb
910

1011

1112
class _Protocol(asyncio.Protocol):
12-
def __init__(self, *, loop=None):
13+
def __init__(self, cvar, *, loop=None):
14+
self.cvar = cvar
15+
self.transport = None
16+
self.connection_made_fut = asyncio.Future(loop=loop)
17+
self.connection_lost_ctx = None
1318
self.done = asyncio.Future(loop=loop)
1419

20+
def connection_made(self, transport):
21+
self.transport = transport
22+
self.connection_made_fut.set_result(self.cvar.get())
23+
1524
def connection_lost(self, exc):
25+
self.connection_lost_ctx = self.cvar.get()
1626
if exc is None:
1727
self.done.set_result(None)
1828
else:
@@ -141,7 +151,7 @@ async def main():
141151
def test_create_server_protocol_factory_context(self):
142152
cvar = contextvars.ContextVar('cvar', default='outer')
143153
factory_called_future = self.loop.create_future()
144-
proto = _Protocol(loop=self.loop)
154+
proto = _Protocol(cvar, loop=self.loop)
145155

146156
def factory():
147157
try:
@@ -172,6 +182,71 @@ async def test():
172182

173183
self.loop.run_until_complete(test())
174184

185+
def test_create_server_connection_made_lost(self):
186+
cvar = contextvars.ContextVar('cvar', default='outer')
187+
proto = _Protocol(cvar, loop=self.loop)
188+
189+
async def test():
190+
cvar.set('inner')
191+
port = tb.find_free_port()
192+
srv = await self.loop.create_server(
193+
lambda: proto, '127.0.0.1', port,
194+
)
195+
196+
s = socket.socket(socket.AF_INET)
197+
with s:
198+
s.setblocking(False)
199+
await self.loop.sock_connect(s, ('127.0.0.1', port))
200+
201+
try:
202+
inner = await proto.connection_made_fut
203+
self.assertEqual(inner, "inner")
204+
205+
await proto.done
206+
self.assertEqual(proto.connection_lost_ctx, "inner")
207+
finally:
208+
srv.close()
209+
await srv.wait_closed()
210+
211+
self.loop.run_until_complete(test())
212+
213+
def test_create_server_manual_connection_lost(self):
214+
if self.implementation == 'asyncio':
215+
raise unittest.SkipTest('this seems to be a bug in asyncio')
216+
217+
cvar = contextvars.ContextVar('cvar', default='outer')
218+
proto = _Protocol(cvar, loop=self.loop)
219+
220+
async def close():
221+
cvar.set('closing')
222+
proto.transport.close()
223+
224+
async def test():
225+
cvar.set('inner')
226+
port = tb.find_free_port()
227+
srv = await self.loop.create_server(
228+
lambda: proto, '127.0.0.1', port,
229+
)
230+
231+
s = socket.socket(socket.AF_INET)
232+
s.setblocking(False)
233+
await self.loop.sock_connect(s, ('127.0.0.1', port))
234+
235+
try:
236+
inner = await proto.connection_made_fut
237+
self.assertEqual(inner, "inner")
238+
239+
await asyncio.ensure_future(close())
240+
241+
await proto.done
242+
self.assertEqual(proto.connection_lost_ctx, "inner")
243+
finally:
244+
s.close()
245+
srv.close()
246+
await srv.wait_closed()
247+
248+
self.loop.run_until_complete(test())
249+
175250

176251
class Test_UV_Context(_ContextBaseTests, tb.UVTestCase):
177252
pass

uvloop/handles/basetransport.pyx

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,15 +26,15 @@ cdef class UVBaseTransport(UVSocketHandle):
2626
new_MethodHandle(self._loop,
2727
"UVTransport._call_connection_made",
2828
<method_t>self._call_connection_made,
29-
None,
29+
self.context,
3030
self))
3131

3232
cdef inline _schedule_call_connection_lost(self, exc):
3333
self._loop._call_soon_handle(
3434
new_MethodHandle1(self._loop,
3535
"UVTransport._call_connection_lost",
3636
<method1_t>self._call_connection_lost,
37-
None,
37+
self.context,
3838
self, exc))
3939

4040
cdef _fatal_error(self, exc, throw, reason=None):

uvloop/handles/handle.pxd

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ cdef class UVHandle:
55
readonly _source_traceback
66
bint _closed
77
bint _inited
8+
object context
89

910
# Added to enable current UDPTransport implementation,
1011
# which doesn't use libuv handles.

uvloop/handles/pipe.pyx

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,9 +73,10 @@ cdef class UnixServer(UVStreamServer):
7373

7474
self._mark_as_open()
7575

76-
cdef UVStream _make_new_transport(self, object protocol, object waiter):
76+
cdef UVStream _make_new_transport(self, object protocol, object waiter,
77+
object context):
7778
cdef UnixTransport tr
78-
tr = UnixTransport.new(self._loop, protocol, self._server, waiter)
79+
tr = UnixTransport.new(self._loop, protocol, self._server, waiter) # TODO: context
7980
return <UVStream>tr
8081

8182

uvloop/handles/streamserver.pxd

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,4 +23,5 @@ cdef class UVStreamServer(UVSocketHandle):
2323
cdef inline listen(self)
2424
cdef inline _on_listen(self)
2525

26-
cdef UVStream _make_new_transport(self, object protocol, object waiter)
26+
cdef UVStream _make_new_transport(self, object protocol, object waiter,
27+
object context)

uvloop/handles/streamserver.pyx

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,8 @@ cdef class UVStreamServer(UVSocketHandle):
7070
protocol = self.listen_context.run(self.protocol_factory)
7171

7272
if self.ssl is None:
73-
client = self._make_new_transport(protocol, None)
73+
client = self._make_new_transport(protocol, None,
74+
self.listen_context)
7475

7576
else:
7677
waiter = self._loop._new_future()
@@ -83,7 +84,8 @@ cdef class UVStreamServer(UVSocketHandle):
8384
ssl_handshake_timeout=self.ssl_handshake_timeout,
8485
ssl_shutdown_timeout=self.ssl_shutdown_timeout)
8586

86-
client = self._make_new_transport(ssl_protocol, None)
87+
client = self._make_new_transport(ssl_protocol, None,
88+
self.listen_context)
8789

8890
waiter.add_done_callback(
8991
ft_partial(self.__on_ssl_connected, client))
@@ -112,7 +114,8 @@ cdef class UVStreamServer(UVSocketHandle):
112114
cdef inline _mark_as_open(self):
113115
self.opened = 1
114116

115-
cdef UVStream _make_new_transport(self, object protocol, object waiter):
117+
cdef UVStream _make_new_transport(self, object protocol, object waiter,
118+
object context):
116119
raise NotImplementedError
117120

118121
def __on_ssl_connected(self, transport, fut):

uvloop/handles/tcp.pxd

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,4 +23,4 @@ cdef class TCPTransport(UVStream):
2323

2424
@staticmethod
2525
cdef TCPTransport new(Loop loop, object protocol, Server server,
26-
object waiter)
26+
object waiter, object context)

uvloop/handles/tcp.pyx

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -91,9 +91,11 @@ cdef class TCPServer(UVStreamServer):
9191
else:
9292
self._mark_as_open()
9393

94-
cdef UVStream _make_new_transport(self, object protocol, object waiter):
94+
cdef UVStream _make_new_transport(self, object protocol, object waiter,
95+
object context):
9596
cdef TCPTransport tr
96-
tr = TCPTransport.new(self._loop, protocol, self._server, waiter)
97+
tr = TCPTransport.new(self._loop, protocol, self._server, waiter,
98+
context)
9799
return <UVStream>tr
98100

99101

@@ -102,7 +104,7 @@ cdef class TCPTransport(UVStream):
102104

103105
@staticmethod
104106
cdef TCPTransport new(Loop loop, object protocol, Server server,
105-
object waiter):
107+
object waiter, object context):
106108

107109
cdef TCPTransport handle
108110
handle = TCPTransport.__new__(TCPTransport)
@@ -111,6 +113,7 @@ cdef class TCPTransport(UVStream):
111113
handle.__peername_set = 0
112114
handle.__sockname_set = 0
113115
handle._set_nodelay()
116+
handle.context = context
114117
return handle
115118

116119
cdef _set_nodelay(self):

uvloop/loop.pyx

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1921,7 +1921,7 @@ cdef class Loop:
19211921
tr = None
19221922
try:
19231923
waiter = self._new_future()
1924-
tr = TCPTransport.new(self, protocol, None, waiter)
1924+
tr = TCPTransport.new(self, protocol, None, waiter, None) # TODO: context
19251925

19261926
if lai is not NULL:
19271927
lai_iter = lai
@@ -1981,7 +1981,7 @@ cdef class Loop:
19811981
sock.setblocking(False)
19821982

19831983
waiter = self._new_future()
1984-
tr = TCPTransport.new(self, protocol, None, waiter)
1984+
tr = TCPTransport.new(self, protocol, None, waiter, None) # TODO: context
19851985
try:
19861986
# libuv will make socket non-blocking
19871987
tr._open(sock.fileno())
@@ -2595,7 +2595,7 @@ cdef class Loop:
25952595
self, protocol, None, transport_waiter)
25962596
elif sock.family in (uv.AF_INET, uv.AF_INET6):
25972597
transport = <UVStream>TCPTransport.new(
2598-
self, protocol, None, transport_waiter)
2598+
self, protocol, None, transport_waiter, None) # TODO: context
25992599

26002600
if transport is None:
26012601
raise ValueError(

0 commit comments

Comments
 (0)