678
807
bytes = self._request.read_bytes(bytes_wanted)
679
808
_body_decoder.accept_bytes(bytes)
680
809
for body_bytes in iter(_body_decoder.read_next_chunk, None):
681
if 'hpss' in debug.debug_flags:
810
if 'hpss' in debug.debug_flags and type(body_bytes) is str:
682
811
mutter(' %d byte chunk read',
685
814
self._request.finished_reading()
817
def build_server_protocol_three(backing_transport, write_func,
                                root_client_path):
    """Build a protocol-three decoder wired up to a server request handler.

    The returned decoder passes decoded message parts to a
    ``message.ConventionalRequestHandler``, which pairs a
    ``request.SmartServerRequestHandler`` with a ``ProtocolThreeResponder``
    that writes through ``write_func``.

    :param backing_transport: passed to SmartServerRequestHandler.
    :param write_func: callable given to ProtocolThreeResponder for output.
    :param root_client_path: passed to SmartServerRequestHandler as its
        ``root_client_path`` keyword argument.
    :return: a ProtocolThreeDecoder.
    """
    request_handler = request.SmartServerRequestHandler(
        backing_transport, commands=request.request_handlers,
        root_client_path=root_client_path)
    responder = ProtocolThreeResponder(write_func)
    message_handler = message.ConventionalRequestHandler(
        request_handler, responder)
    return ProtocolThreeDecoder(message_handler)
827
class ProtocolThreeDecoder(_StatefulDecoder):
829
response_marker = RESPONSE_VERSION_THREE
830
request_marker = REQUEST_VERSION_THREE
832
def __init__(self, message_handler, expect_version_marker=False):
833
_StatefulDecoder.__init__(self)
834
self._has_dispatched = False
836
if expect_version_marker:
837
self.state_accept = self._state_accept_expecting_protocol_version
838
# We're expecting at least the protocol version marker + some
840
self._number_needed_bytes = len(MESSAGE_VERSION_THREE) + 4
842
self.state_accept = self._state_accept_expecting_headers
843
self._number_needed_bytes = 4
844
self.decoding_failed = False
845
self.request_handler = self.message_handler = message_handler
847
def accept_bytes(self, bytes):
848
self._number_needed_bytes = None
850
_StatefulDecoder.accept_bytes(self, bytes)
851
except KeyboardInterrupt:
853
except errors.SmartMessageHandlerError, exception:
854
# We do *not* set self.decoding_failed here. The message handler
855
# has raised an error, but the decoder is still able to parse bytes
856
# and determine when this message ends.
857
log_exception_quietly()
858
self.message_handler.protocol_error(exception.exc_value)
859
# The state machine is ready to continue decoding, but the
860
# exception has interrupted the loop that runs the state machine.
861
# So we call accept_bytes again to restart it.
862
self.accept_bytes('')
863
except Exception, exception:
864
# The decoder itself has raised an exception. We cannot continue
866
self.decoding_failed = True
867
if isinstance(exception, errors.UnexpectedProtocolVersionMarker):
868
# This happens during normal operation when the client tries a
869
# protocol version the server doesn't understand, so no need to
870
# log a traceback every time.
871
# Note that this can only happen when
872
# expect_version_marker=True, which is only the case on the
876
log_exception_quietly()
877
self.message_handler.protocol_error(exception)
879
def _extract_length_prefixed_bytes(self):
880
if len(self._in_buffer) < 4:
881
# A length prefix by itself is 4 bytes, and we don't even have that
883
raise _NeedMoreBytes(4)
884
(length,) = struct.unpack('!L', self._in_buffer[:4])
885
end_of_bytes = 4 + length
886
if len(self._in_buffer) < end_of_bytes:
887
# We haven't yet read as many bytes as the length-prefix says there
889
raise _NeedMoreBytes(end_of_bytes)
890
# Extract the bytes from the buffer.
891
bytes = self._in_buffer[4:end_of_bytes]
892
self._in_buffer = self._in_buffer[end_of_bytes:]
895
def _extract_prefixed_bencoded_data(self):
896
prefixed_bytes = self._extract_length_prefixed_bytes()
898
decoded = bdecode(prefixed_bytes)
900
raise errors.SmartProtocolError(
901
'Bytes %r not bencoded' % (prefixed_bytes,))
904
def _extract_single_byte(self):
905
if self._in_buffer == '':
906
# The buffer is empty
907
raise _NeedMoreBytes(1)
908
one_byte = self._in_buffer[0]
909
self._in_buffer = self._in_buffer[1:]
912
def _state_accept_expecting_protocol_version(self):
913
needed_bytes = len(MESSAGE_VERSION_THREE) - len(self._in_buffer)
915
# We don't have enough bytes to check if the protocol version
916
# marker is right. But we can check if it is already wrong by
917
# checking that the start of MESSAGE_VERSION_THREE matches what
919
# [In fact, if the remote end isn't bzr we might never receive
920
# len(MESSAGE_VERSION_THREE) bytes. So if the bytes we have so far
921
# are wrong then we should just raise immediately rather than
923
if not MESSAGE_VERSION_THREE.startswith(self._in_buffer):
924
# We have enough bytes to know the protocol version is wrong
925
raise errors.UnexpectedProtocolVersionMarker(self._in_buffer)
926
raise _NeedMoreBytes(len(MESSAGE_VERSION_THREE))
927
if not self._in_buffer.startswith(MESSAGE_VERSION_THREE):
928
raise errors.UnexpectedProtocolVersionMarker(self._in_buffer)
929
self._in_buffer = self._in_buffer[len(MESSAGE_VERSION_THREE):]
930
self.state_accept = self._state_accept_expecting_headers
932
def _state_accept_expecting_headers(self):
933
decoded = self._extract_prefixed_bencoded_data()
934
if type(decoded) is not dict:
935
raise errors.SmartProtocolError(
936
'Header object %r is not a dict' % (decoded,))
937
self.state_accept = self._state_accept_expecting_message_part
939
self.message_handler.headers_received(decoded)
941
raise errors.SmartMessageHandlerError(sys.exc_info())
943
def _state_accept_expecting_message_part(self):
944
message_part_kind = self._extract_single_byte()
945
if message_part_kind == 'o':
946
self.state_accept = self._state_accept_expecting_one_byte
947
elif message_part_kind == 's':
948
self.state_accept = self._state_accept_expecting_structure
949
elif message_part_kind == 'b':
950
self.state_accept = self._state_accept_expecting_bytes
951
elif message_part_kind == 'e':
954
raise errors.SmartProtocolError(
955
'Bad message kind byte: %r' % (message_part_kind,))
957
def _state_accept_expecting_one_byte(self):
958
byte = self._extract_single_byte()
959
self.state_accept = self._state_accept_expecting_message_part
961
self.message_handler.byte_part_received(byte)
963
raise errors.SmartMessageHandlerError(sys.exc_info())
965
def _state_accept_expecting_bytes(self):
966
# XXX: this should not buffer whole message part, but instead deliver
967
# the bytes as they arrive.
968
prefixed_bytes = self._extract_length_prefixed_bytes()
969
self.state_accept = self._state_accept_expecting_message_part
971
self.message_handler.bytes_part_received(prefixed_bytes)
973
raise errors.SmartMessageHandlerError(sys.exc_info())
975
def _state_accept_expecting_structure(self):
976
structure = self._extract_prefixed_bencoded_data()
977
self.state_accept = self._state_accept_expecting_message_part
979
self.message_handler.structure_part_received(structure)
981
raise errors.SmartMessageHandlerError(sys.exc_info())
984
self.unused_data = self._in_buffer
986
self.state_accept = self._state_accept_reading_unused
988
self.message_handler.end_received()
990
raise errors.SmartMessageHandlerError(sys.exc_info())
992
def _state_accept_reading_unused(self):
993
self.unused_data += self._in_buffer
996
def next_read_size(self):
997
if self.state_accept == self._state_accept_reading_unused:
999
elif self.decoding_failed:
1000
# An exception occured while processing this message, probably from
1001
# self.message_handler. We're not sure that this state machine is
1002
# in a consistent state, so just signal that we're done (i.e. give
1006
if self._number_needed_bytes is not None:
1007
return self._number_needed_bytes - len(self._in_buffer)
1009
raise AssertionError("don't know how many bytes are expected!")
1012
class _ProtocolThreeEncoder(object):
    """Writes the individual parts of a version-three message.

    Every part is emitted by calling ``write_func`` with a byte string.
    """

    response_marker = request_marker = MESSAGE_VERSION_THREE

    def __init__(self, write_func):
        """:param write_func: callable invoked with each chunk of output."""
        self._write_func = write_func

    def _serialise_offsets(self, offsets):
        """Serialise a readv offset list.

        :param offsets: iterable of (start, length) integer pairs.
        :return: one 'start,length' line per pair, joined with newlines and
            with no trailing newline.
        """
        txt = []
        for start, length in offsets:
            txt.append('%d,%d' % (start, length))
        return '\n'.join(txt)

    def _write_protocol_version(self):
        """Write the MESSAGE_VERSION_THREE marker."""
        self._write_func(MESSAGE_VERSION_THREE)

    def _write_prefixed_bencode(self, structure):
        """Write ``structure`` bencoded, preceded by its 4-byte length."""
        encoded = bencode(structure)
        # '!L': unsigned 32-bit length in network byte order.
        self._write_func(struct.pack('!L', len(encoded)))
        self._write_func(encoded)

    def _write_headers(self, headers):
        """Write the headers as a length-prefixed bencoded value."""
        self._write_prefixed_bencode(headers)

    def _write_structure(self, args):
        """Write a structure part: 's' followed by the bencoded args.

        Unicode arguments are encoded as UTF-8 first; all other arguments
        are passed to the bencoder unchanged.
        """
        self._write_func('s')
        utf8_args = []
        for arg in args:
            if type(arg) is unicode:
                utf8_args.append(arg.encode('utf8'))
            else:
                utf8_args.append(arg)
        self._write_prefixed_bencode(utf8_args)

    def _write_end(self):
        """Write the end-of-message marker, 'e'."""
        self._write_func('e')

    def _write_prefixed_body(self, bytes):
        """Write a bytes part: 'b', a 4-byte length, then the bytes."""
        self._write_func('b')
        self._write_func(struct.pack('!L', len(bytes)))
        self._write_func(bytes)

    def _write_error_status(self):
        """Write the single-byte part 'E' (error status)."""
        self._write_func('oE')

    def _write_success_status(self):
        """Write the single-byte part 'S' (success status)."""
        self._write_func('oS')
1062
class ProtocolThreeResponder(_ProtocolThreeEncoder):
    """Writes a single version-three response through ``write_func``."""

    def __init__(self, write_func):
        """:param write_func: callable invoked with each chunk of output."""
        _ProtocolThreeEncoder.__init__(self, write_func)
        self.response_sent = False
        self._headers = {'Software version': bzrlib.__version__}

    def send_error(self, exception):
        """Send an error response describing ``exception``.

        An ``errors.UnknownSmartMethod`` is sent as a failed response of
        ('UnknownMethod', verb); anything else is sent with error status and
        the structure ('error', str(exception)).

        :raises AssertionError: if a response has already been sent.
        """
        if self.response_sent:
            raise AssertionError(
                "send_error(%s) called, but response already sent."
                % (exception,))
        if isinstance(exception, errors.UnknownSmartMethod):
            failure = request.FailedSmartServerResponse(
                ('UnknownMethod', exception.verb))
            self.send_response(failure)
            # send_response has written the complete message.
            return
        self.response_sent = True
        self._write_protocol_version()
        self._write_headers(self._headers)
        self._write_error_status()
        self._write_structure(('error', str(exception)))
        self._write_end()

    def send_response(self, response):
        """Send ``response``: status, args, then any body.

        :param response: object providing ``is_successful()``, ``args``,
            ``body`` and ``body_stream``.  ``body`` is written as one bytes
            part; otherwise each chunk of ``body_stream`` is written as its
            own bytes part.
        :raises AssertionError: if a response has already been sent.
        """
        if self.response_sent:
            raise AssertionError(
                "send_response(%r) called, but response already sent."
                % (response,))
        self.response_sent = True
        self._write_protocol_version()
        self._write_headers(self._headers)
        if response.is_successful():
            self._write_success_status()
        else:
            self._write_error_status()
        self._write_structure(response.args)
        if response.body is not None:
            self._write_prefixed_body(response.body)
        elif response.body_stream is not None:
            for chunk in response.body_stream:
                self._write_prefixed_body(chunk)
        self._write_end()
1107
class ProtocolThreeRequester(_ProtocolThreeEncoder, Requester):
    """Writes version-three requests to a medium request."""

    def __init__(self, medium_request):
        """:param medium_request: object whose ``accept_bytes`` receives the
            encoded request and whose ``finished_writing`` is called once
            the request is complete.
        """
        _ProtocolThreeEncoder.__init__(self, medium_request.accept_bytes)
        self._medium_request = medium_request
        # Start with no headers so the call* methods can be used before
        # set_headers has been called.
        self._headers = {}

    def set_headers(self, headers):
        """Use a copy of ``headers`` for subsequent requests."""
        self._headers = headers.copy()

    def call(self, *args):
        """Make a remote call of ``args`` with no body."""
        if 'hpss' in debug.debug_flags:
            mutter('hpss call: %s', repr(args)[1:-1])
            base = getattr(self._medium_request._medium, 'base', None)
            if base is not None:
                mutter(' (to %s)', base)
            self._request_start_time = time.time()
        self._write_protocol_version()
        self._write_headers(self._headers)
        self._write_structure(args)
        self._write_end()
        self._medium_request.finished_writing()

    def call_with_body_bytes(self, args, body):
        """Make a remote call of args with body bytes 'body'.

        After calling this, call read_response_tuple to find the result out.
        """
        if 'hpss' in debug.debug_flags:
            mutter('hpss call w/body: %s (%r...)', repr(args)[1:-1], body[:20])
            path = getattr(self._medium_request._medium, '_path', None)
            if path is not None:
                mutter(' (to %s)', path)
            mutter(' %d bytes', len(body))
            self._request_start_time = time.time()
        self._write_protocol_version()
        self._write_headers(self._headers)
        self._write_structure(args)
        self._write_prefixed_body(body)
        self._write_end()
        self._medium_request.finished_writing()

    def call_with_body_readv_array(self, args, body):
        """Make a remote call with a readv array.

        The body is encoded with one line per readv offset pair. The numbers in
        each pair are separated by a comma, and no trailing newline is emitted.
        """
        if 'hpss' in debug.debug_flags:
            mutter('hpss call w/readv: %s', repr(args)[1:-1])
            path = getattr(self._medium_request._medium, '_path', None)
            if path is not None:
                mutter(' (to %s)', path)
            self._request_start_time = time.time()
        self._write_protocol_version()
        self._write_headers(self._headers)
        self._write_structure(args)
        readv_bytes = self._serialise_offsets(body)
        if 'hpss' in debug.debug_flags:
            mutter(' %d bytes in readv request', len(readv_bytes))
        self._write_prefixed_body(readv_bytes)
        self._write_end()
        self._medium_request.finished_writing()