14
14
# along with this program; if not, write to the Free Software
15
15
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17
from cStringIO import StringIO
20
from urllib.request import (
24
except ImportError: # python < 3
32
from bzrlib.smart import medium, protocol
33
from bzrlib.tests import http_server
34
from bzrlib.transport import (
37
from ..sixish import (
40
from ..bzr.smart import (
43
from . import http_server
44
from ..transport import chroot
40
47
class HTTPServerWithSmarts(http_server.HttpServer):
51
58
class SmartRequestHandler(http_server.TestingHTTPRequestHandler):
52
59
"""Extend TestingHTTPRequestHandler to support smart client POSTs.
54
XXX: This duplicates a fair bit of the logic in bzrlib.transport.http.wsgi.
61
XXX: This duplicates a fair bit of the logic in breezy.transport.http.wsgi.
58
65
"""Hand the request off to a smart server instance."""
59
backing = get_transport(self.server.test_case_server._home_dir)
66
backing = transport.get_transport_from_path(
67
self.server.test_case_server._home_dir)
60
68
chroot_server = chroot.ChrootServer(backing)
61
69
chroot_server.start_server()
63
t = get_transport(chroot_server.get_url())
71
t = transport.get_transport_from_url(chroot_server.get_url())
64
72
self.do_POST_inner(t)
66
74
chroot_server.stop_server()
83
91
# we have to stop early due to error, but we would also have to use the
84
92
# HTTP trailer facility which may not be widely available.
85
93
request_bytes = self.rfile.read(data_length)
86
protocol_factory, unused_bytes = medium._get_protocol_factory_for_bytes(
88
out_buffer = StringIO()
94
protocol_factory, unused_bytes = (
95
medium._get_protocol_factory_for_bytes(request_bytes))
96
out_buffer = BytesIO()
89
97
smart_protocol_request = protocol_factory(t, out_buffer.write, '/')
90
98
# Perhaps there should be a SmartServerHTTPMedium that takes care of
91
99
# feeding the bytes in the http request to the smart_protocol_request,
106
114
one. This will currently fail if the primary transport is not
107
115
backed by regular disk files.
118
# These attributes can be overriden or parametrized by daughter clasess if
119
# needed, but must exist so that the create_transport_readonly_server()
120
# method (or any method creating an http(s) server) can propagate it.
121
_protocol_version = None
122
_url_protocol = 'http'
110
125
super(TestCaseWithWebserver, self).setUp()
111
126
self.transport_readonly_server = http_server.HttpServer
128
def create_transport_readonly_server(self):
129
server = self.transport_readonly_server(
130
protocol_version=self._protocol_version)
131
server._url_protocol = self._url_protocol
114
135
class TestCaseWithTwoWebservers(TestCaseWithWebserver):
115
136
"""A support class providing readonly urls on two servers that are http://.
136
161
self.start_server(self.__secondary_server)
137
162
return self.__secondary_server
164
def get_secondary_url(self, relpath=None):
165
base = self.get_secondary_server().get_url()
166
return self._adjust_url(base, relpath)
168
def get_secondary_transport(self, relpath=None):
169
t = transport.get_transport_from_url(self.get_secondary_url(relpath))
170
self.assertTrue(t.is_readonly())
140
174
class ProxyServer(http_server.HttpServer):
141
175
"""A proxy test server for http transports."""
184
218
def redirect_to(self, host, port):
185
219
"""Redirect all requests to a specific host:port"""
186
220
self.redirections = [('(.*)',
187
r'http://%s:%s\1' % (host, port) ,
221
r'http://%s:%s\1' % (host, port),
190
224
def is_redirected(self, path):
200
234
for (rsource, rtarget, rcode) in self.redirections:
201
target, match = re.subn(rsource, rtarget, path)
235
target, match = re.subn(rsource, rtarget, path, count=1)
204
break # The first match wins
238
break # The first match wins
207
241
return code, target
210
244
class TestCaseWithRedirectedWebserver(TestCaseWithTwoWebservers):
211
"""A support class providing redirections from one server to another.
213
We set up two webservers to allows various tests involving
215
The 'old' server is redirected to the 'new' server.
218
def create_transport_secondary_server(self):
219
"""Create the secondary server redirecting to the primary server"""
220
new = self.get_readonly_server()
221
redirecting = HTTPServerRedirecting()
222
redirecting.redirect_to(new.host, new.port)
226
super(TestCaseWithRedirectedWebserver, self).setUp()
227
# The redirections will point to the new server
228
self.new_server = self.get_readonly_server()
229
# The requests to the old server will be redirected
230
self.old_server = self.get_secondary_server()
245
"""A support class providing redirections from one server to another.
247
We set up two webservers to allows various tests involving
249
The 'old' server is redirected to the 'new' server.
253
super(TestCaseWithRedirectedWebserver, self).setUp()
254
# The redirections will point to the new server
255
self.new_server = self.get_readonly_server()
256
# The requests to the old server will be redirected to the new server
257
self.old_server = self.get_secondary_server()
259
def create_transport_secondary_server(self):
260
"""Create the secondary server redirecting to the primary server"""
261
new = self.get_readonly_server()
262
redirecting = HTTPServerRedirecting(
263
protocol_version=self._protocol_version)
264
redirecting.redirect_to(new.host, new.port)
265
redirecting._url_protocol = self._url_protocol
268
def get_old_url(self, relpath=None):
269
base = self.old_server.get_url()
270
return self._adjust_url(base, relpath)
272
def get_old_transport(self, relpath=None):
273
t = transport.get_transport_from_url(self.get_old_url(relpath))
274
self.assertTrue(t.is_readonly())
277
def get_new_url(self, relpath=None):
278
base = self.new_server.get_url()
279
return self._adjust_url(base, relpath)
281
def get_new_transport(self, relpath=None):
282
t = transport.get_transport_from_url(self.get_new_url(relpath))
283
self.assertTrue(t.is_readonly())
233
287
class AuthRequestHandler(http_server.TestingHTTPRequestHandler):
243
297
# - auth_header_recv: the header received containing auth
244
298
# - auth_error_code: the error code to indicate auth required
300
def _require_authentication(self):
301
# Note that we must update test_case_server *before*
302
# sending the error or the client may try to read it
303
# before we have sent the whole error back.
304
tcs = self.server.test_case_server
305
tcs.auth_required_errors += 1
306
self.send_response(tcs.auth_error_code)
307
self.send_header_auth_reqed()
308
# We do not send a body
309
self.send_header('Content-Length', '0')
246
313
def do_GET(self):
247
314
if self.authorized():
248
315
return http_server.TestingHTTPRequestHandler.do_GET(self)
250
# Note that we must update test_case_server *before*
251
# sending the error or the client may try to read it
252
# before we have sent the whole error back.
253
tcs = self.server.test_case_server
254
tcs.auth_required_errors += 1
255
self.send_response(tcs.auth_error_code)
256
self.send_header_auth_reqed()
257
# We do not send a body
258
self.send_header('Content-Length', '0')
317
return self._require_authentication()
320
if self.authorized():
321
return http_server.TestingHTTPRequestHandler.do_HEAD(self)
323
return self._require_authentication()
263
326
class BasicAuthRequestHandler(AuthRequestHandler):
273
336
scheme, raw_auth = auth_header.split(' ', 1)
274
337
if scheme.lower() == tcs.auth_scheme:
275
user, password = raw_auth.decode('base64').split(':')
276
return tcs.authorized(user, password)
338
user, password = base64.b64decode(raw_auth).split(b':')
339
return tcs.authorized(user.decode('ascii'),
340
password.decode('ascii'))
410
474
# Recalculate the response_digest to compare with the one
411
475
# sent by the client
412
A1 = '%s:%s:%s' % (user, realm, password)
413
A2 = '%s:%s' % (command, auth['uri'])
415
H = lambda x: osutils.md5(x).hexdigest()
416
KD = lambda secret, data: H("%s:%s" % (secret, data))
476
A1 = ('%s:%s:%s' % (user, realm, password)).encode('utf-8')
477
A2 = ('%s:%s' % (command, auth['uri'])).encode('utf-8')
480
return osutils.md5(x).hexdigest()
482
def KD(secret, data):
483
return H(("%s:%s" % (secret, data)).encode('utf-8'))
418
485
nonce_count = int(auth['nc'], 16)