/brz/remove-bazaar

To get this branch, use:
bzr branch http://gegoxaren.bato24.eu/bzr/brz/remove-bazaar

« back to all changes in this revision

Viewing changes to bzrlib/merge3.py

  • Committer: Martin Pool
  • Date: 2008-05-08 04:12:06 UTC
  • mto: This revision was merged to the branch mainline in revision 3415.
  • Revision ID: mbp@sourcefrog.net-20080508041206-tkrr8ucmcyrlzkum
Some review cleanups for assertion removal

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
# Copyright (C) 2004, 2005 Canonical Ltd
 
2
#
 
3
# This program is free software; you can redistribute it and/or modify
 
4
# it under the terms of the GNU General Public License as published by
 
5
# the Free Software Foundation; either version 2 of the License, or
 
6
# (at your option) any later version.
 
7
#
 
8
# This program is distributed in the hope that it will be useful,
 
9
# but WITHOUT ANY WARRANTY; without even the implied warranty of
 
10
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 
11
# GNU General Public License for more details.
 
12
#
 
13
# You should have received a copy of the GNU General Public License
 
14
# along with this program; if not, write to the Free Software
 
15
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
 
16
 
 
17
 
 
18
# mbp: "you know that thing where cvs gives you conflict markers?"
 
19
# s: "i hate that."
 
20
 
 
21
 
 
22
from bzrlib.errors import CantReprocessAndShowBase
 
23
import bzrlib.patiencediff
 
24
from bzrlib.textfile import check_text_lines
 
25
 
 
26
 
 
27
def intersect(ra, rb):
 
28
    """Given two ranges return the range where they intersect or None.
 
29
 
 
30
    >>> intersect((0, 10), (0, 6))
 
31
    (0, 6)
 
32
    >>> intersect((0, 10), (5, 15))
 
33
    (5, 10)
 
34
    >>> intersect((0, 10), (10, 15))
 
35
    >>> intersect((0, 9), (10, 15))
 
36
    >>> intersect((0, 9), (7, 15))
 
37
    (7, 9)
 
38
    """
 
39
    sa = max(ra[0], rb[0])
 
40
    sb = min(ra[1], rb[1])
 
41
    if sa < sb:
 
42
        return sa, sb
 
43
    else:
 
44
        return None
 
45
 
 
46
 
 
47
def compare_range(a, astart, aend, b, bstart, bend):
 
48
    """Compare a[astart:aend] == b[bstart:bend], without slicing.
 
49
    """
 
50
    if (aend-astart) != (bend-bstart):
 
51
        return False
 
52
    for ia, ib in zip(xrange(astart, aend), xrange(bstart, bend)):
 
53
        if a[ia] != b[ib]:
 
54
            return False
 
55
    else:
 
56
        return True
 
57
        
 
58
 
 
59
 
 
60
 
 
61
class Merge3(object):
 
62
    """3-way merge of texts.
 
63
 
 
64
    Given BASE, OTHER, THIS, tries to produce a combined text
 
65
    incorporating the changes from both BASE->OTHER and BASE->THIS.
 
66
    All three will typically be sequences of lines."""
 
67
    def __init__(self, base, a, b, is_cherrypick=False):
 
68
        check_text_lines(base)
 
69
        check_text_lines(a)
 
70
        check_text_lines(b)
 
71
        self.base = base
 
72
        self.a = a
 
73
        self.b = b
 
74
        self.is_cherrypick = is_cherrypick
 
75
 
 
76
    def merge_lines(self,
 
77
                    name_a=None,
 
78
                    name_b=None,
 
79
                    name_base=None,
 
80
                    start_marker='<<<<<<<',
 
81
                    mid_marker='=======',
 
82
                    end_marker='>>>>>>>',
 
83
                    base_marker=None,
 
84
                    reprocess=False):
 
85
        """Return merge in cvs-like form.
 
86
        """
 
87
        newline = '\n'
 
88
        if len(self.a) > 0:
 
89
            if self.a[0].endswith('\r\n'):
 
90
                newline = '\r\n'
 
91
            elif self.a[0].endswith('\r'):
 
92
                newline = '\r'
 
93
        if base_marker and reprocess:
 
94
            raise CantReprocessAndShowBase()
 
95
        if name_a:
 
96
            start_marker = start_marker + ' ' + name_a
 
97
        if name_b:
 
98
            end_marker = end_marker + ' ' + name_b
 
99
        if name_base and base_marker:
 
100
            base_marker = base_marker + ' ' + name_base
 
101
        merge_regions = self.merge_regions()
 
102
        if reprocess is True:
 
103
            merge_regions = self.reprocess_merge_regions(merge_regions)
 
104
        for t in merge_regions:
 
105
            what = t[0]
 
106
            if what == 'unchanged':
 
107
                for i in range(t[1], t[2]):
 
108
                    yield self.base[i]
 
109
            elif what == 'a' or what == 'same':
 
110
                for i in range(t[1], t[2]):
 
111
                    yield self.a[i]
 
112
            elif what == 'b':
 
113
                for i in range(t[1], t[2]):
 
114
                    yield self.b[i]
 
115
            elif what == 'conflict':
 
116
                yield start_marker + newline
 
117
                for i in range(t[3], t[4]):
 
118
                    yield self.a[i]
 
119
                if base_marker is not None:
 
120
                    yield base_marker + newline
 
121
                    for i in range(t[1], t[2]):
 
122
                        yield self.base[i]
 
123
                yield mid_marker + newline
 
124
                for i in range(t[5], t[6]):
 
125
                    yield self.b[i]
 
126
                yield end_marker + newline
 
127
            else:
 
128
                raise ValueError(what)
 
129
 
 
130
    def merge_annotated(self):
 
131
        """Return merge with conflicts, showing origin of lines.
 
132
 
 
133
        Most useful for debugging merge.        
 
134
        """
 
135
        for t in self.merge_regions():
 
136
            what = t[0]
 
137
            if what == 'unchanged':
 
138
                for i in range(t[1], t[2]):
 
139
                    yield 'u | ' + self.base[i]
 
140
            elif what == 'a' or what == 'same':
 
141
                for i in range(t[1], t[2]):
 
142
                    yield what[0] + ' | ' + self.a[i]
 
143
            elif what == 'b':
 
144
                for i in range(t[1], t[2]):
 
145
                    yield 'b | ' + self.b[i]
 
146
            elif what == 'conflict':
 
147
                yield '<<<<\n'
 
148
                for i in range(t[3], t[4]):
 
149
                    yield 'A | ' + self.a[i]
 
150
                yield '----\n'
 
151
                for i in range(t[5], t[6]):
 
152
                    yield 'B | ' + self.b[i]
 
153
                yield '>>>>\n'
 
154
            else:
 
155
                raise ValueError(what)
 
156
 
 
157
    def merge_groups(self):
 
158
        """Yield sequence of line groups.  Each one is a tuple:
 
159
 
 
160
        'unchanged', lines
 
161
             Lines unchanged from base
 
162
 
 
163
        'a', lines
 
164
             Lines taken from a
 
165
 
 
166
        'same', lines
 
167
             Lines taken from a (and equal to b)
 
168
 
 
169
        'b', lines
 
170
             Lines taken from b
 
171
 
 
172
        'conflict', base_lines, a_lines, b_lines
 
173
             Lines from base were changed to either a or b and conflict.
 
174
        """
 
175
        for t in self.merge_regions():
 
176
            what = t[0]
 
177
            if what == 'unchanged':
 
178
                yield what, self.base[t[1]:t[2]]
 
179
            elif what == 'a' or what == 'same':
 
180
                yield what, self.a[t[1]:t[2]]
 
181
            elif what == 'b':
 
182
                yield what, self.b[t[1]:t[2]]
 
183
            elif what == 'conflict':
 
184
                yield (what,
 
185
                       self.base[t[1]:t[2]],
 
186
                       self.a[t[3]:t[4]],
 
187
                       self.b[t[5]:t[6]])
 
188
            else:
 
189
                raise ValueError(what)
 
190
 
 
191
    def merge_regions(self):
 
192
        """Return sequences of matching and conflicting regions.
 
193
 
 
194
        This returns tuples, where the first value says what kind we
 
195
        have:
 
196
 
 
197
        'unchanged', start, end
 
198
             Take a region of base[start:end]
 
199
 
 
200
        'same', astart, aend
 
201
             b and a are different from base but give the same result
 
202
 
 
203
        'a', start, end
 
204
             Non-clashing insertion from a[start:end]
 
205
 
 
206
        Method is as follows:
 
207
 
 
208
        The two sequences align only on regions which match the base
 
209
        and both descendents.  These are found by doing a two-way diff
 
210
        of each one against the base, and then finding the
 
211
        intersections between those regions.  These "sync regions"
 
212
        are by definition unchanged in both and easily dealt with.
 
213
 
 
214
        The regions in between can be in any of three cases:
 
215
        conflicted, or changed on only one side.
 
216
        """
 
217
 
 
218
        # section a[0:ia] has been disposed of, etc
 
219
        iz = ia = ib = 0
 
220
        
 
221
        for zmatch, zend, amatch, aend, bmatch, bend in self.find_sync_regions():
 
222
            matchlen = zend - zmatch
 
223
            len_a = amatch - ia
 
224
            len_b = bmatch - ib
 
225
            len_base = zmatch - iz
 
226
            #print 'unmatched a=%d, b=%d' % (len_a, len_b)
 
227
 
 
228
            if len_a or len_b:
 
229
                # try to avoid actually slicing the lists
 
230
                same = compare_range(self.a, ia, amatch,
 
231
                                     self.b, ib, bmatch)
 
232
 
 
233
                if same:
 
234
                    yield 'same', ia, amatch
 
235
                else:
 
236
                    equal_a = compare_range(self.a, ia, amatch,
 
237
                                            self.base, iz, zmatch)
 
238
                    equal_b = compare_range(self.b, ib, bmatch,
 
239
                                            self.base, iz, zmatch)
 
240
                    if equal_a and not equal_b:
 
241
                        yield 'b', ib, bmatch
 
242
                    elif equal_b and not equal_a:
 
243
                        yield 'a', ia, amatch
 
244
                    elif not equal_a and not equal_b:
 
245
                        if self.is_cherrypick:
 
246
                            for node in self._refine_cherrypick_conflict(
 
247
                                                    iz, zmatch, ia, amatch,
 
248
                                                    ib, bmatch):
 
249
                                yield node
 
250
                        else:
 
251
                            yield 'conflict', iz, zmatch, ia, amatch, ib, bmatch
 
252
                    else:
 
253
                        raise AssertionError("can't handle a=b=base but unmatched")
 
254
 
 
255
                ia = amatch
 
256
                ib = bmatch
 
257
            iz = zmatch
 
258
 
 
259
            # if the same part of the base was deleted on both sides
 
260
            # that's OK, we can just skip it.
 
261
 
 
262
            if matchlen > 0:
 
263
                yield 'unchanged', zmatch, zend
 
264
                iz = zend
 
265
                ia = aend
 
266
                ib = bend
 
267
 
 
268
    def _refine_cherrypick_conflict(self, zstart, zend, astart, aend, bstart, bend):
 
269
        """When cherrypicking b => a, ignore matches with b and base."""
 
270
        # Do not emit regions which match, only regions which do not match
 
271
        matches = bzrlib.patiencediff.PatienceSequenceMatcher(None,
 
272
            self.base[zstart:zend], self.b[bstart:bend]).get_matching_blocks()
 
273
        last_base_idx = 0
 
274
        last_b_idx = 0
 
275
        last_b_idx = 0
 
276
        yielded_a = False
 
277
        for base_idx, b_idx, match_len in matches:
 
278
            conflict_z_len = base_idx - last_base_idx
 
279
            conflict_b_len = b_idx - last_b_idx
 
280
            if conflict_b_len == 0: # There are no lines in b which conflict,
 
281
                                    # so skip it
 
282
                pass
 
283
            else:
 
284
                if yielded_a:
 
285
                    yield ('conflict',
 
286
                           zstart + last_base_idx, zstart + base_idx,
 
287
                           aend, aend, bstart + last_b_idx, bstart + b_idx)
 
288
                else:
 
289
                    # The first conflict gets the a-range
 
290
                    yielded_a = True
 
291
                    yield ('conflict', zstart + last_base_idx, zstart +
 
292
                    base_idx,
 
293
                           astart, aend, bstart + last_b_idx, bstart + b_idx)
 
294
            last_base_idx = base_idx + match_len
 
295
            last_b_idx = b_idx + match_len
 
296
        if last_base_idx != zend - zstart or last_b_idx != bend - bstart:
 
297
            if yielded_a:
 
298
                yield ('conflict', zstart + last_base_idx, zstart + base_idx,
 
299
                       aend, aend, bstart + last_b_idx, bstart + b_idx)
 
300
            else:
 
301
                # The first conflict gets the a-range
 
302
                yielded_a = True
 
303
                yield ('conflict', zstart + last_base_idx, zstart + base_idx,
 
304
                       astart, aend, bstart + last_b_idx, bstart + b_idx)
 
305
        if not yielded_a:
 
306
            yield ('conflict', zstart, zend, astart, aend, bstart, bend)
 
307
 
 
308
    def reprocess_merge_regions(self, merge_regions):
 
309
        """Where there are conflict regions, remove the agreed lines.
 
310
 
 
311
        Lines where both A and B have made the same changes are 
 
312
        eliminated.
 
313
        """
 
314
        for region in merge_regions:
 
315
            if region[0] != "conflict":
 
316
                yield region
 
317
                continue
 
318
            type, iz, zmatch, ia, amatch, ib, bmatch = region
 
319
            a_region = self.a[ia:amatch]
 
320
            b_region = self.b[ib:bmatch]
 
321
            matches = bzrlib.patiencediff.PatienceSequenceMatcher(
 
322
                    None, a_region, b_region).get_matching_blocks()
 
323
            next_a = ia
 
324
            next_b = ib
 
325
            for region_ia, region_ib, region_len in matches[:-1]:
 
326
                region_ia += ia
 
327
                region_ib += ib
 
328
                reg = self.mismatch_region(next_a, region_ia, next_b,
 
329
                                           region_ib)
 
330
                if reg is not None:
 
331
                    yield reg
 
332
                yield 'same', region_ia, region_len+region_ia
 
333
                next_a = region_ia + region_len
 
334
                next_b = region_ib + region_len
 
335
            reg = self.mismatch_region(next_a, amatch, next_b, bmatch)
 
336
            if reg is not None:
 
337
                yield reg
 
338
 
 
339
    @staticmethod
 
340
    def mismatch_region(next_a, region_ia,  next_b, region_ib):
 
341
        if next_a < region_ia or next_b < region_ib:
 
342
            return 'conflict', None, None, next_a, region_ia, next_b, region_ib
 
343
 
 
344
    def find_sync_regions(self):
 
345
        """Return a list of sync regions, where both descendents match the base.
 
346
 
 
347
        Generates a list of (base1, base2, a1, a2, b1, b2).  There is
 
348
        always a zero-length sync region at the end of all the files.
 
349
        """
 
350
 
 
351
        ia = ib = 0
 
352
        amatches = bzrlib.patiencediff.PatienceSequenceMatcher(
 
353
                None, self.base, self.a).get_matching_blocks()
 
354
        bmatches = bzrlib.patiencediff.PatienceSequenceMatcher(
 
355
                None, self.base, self.b).get_matching_blocks()
 
356
        len_a = len(amatches)
 
357
        len_b = len(bmatches)
 
358
 
 
359
        sl = []
 
360
 
 
361
        while ia < len_a and ib < len_b:
 
362
            abase, amatch, alen = amatches[ia]
 
363
            bbase, bmatch, blen = bmatches[ib]
 
364
 
 
365
            # there is an unconflicted block at i; how long does it
 
366
            # extend?  until whichever one ends earlier.
 
367
            i = intersect((abase, abase+alen), (bbase, bbase+blen))
 
368
            if i:
 
369
                intbase = i[0]
 
370
                intend = i[1]
 
371
                intlen = intend - intbase
 
372
                asub = amatch + (intbase - abase)
 
373
                bsub = bmatch + (intbase - bbase)
 
374
                aend = asub + intlen
 
375
                bend = bsub + intlen
 
376
                sl.append((intbase, intend,
 
377
                           asub, aend,
 
378
                           bsub, bend))
 
379
            # advance whichever one ends first in the base text
 
380
            if (abase + alen) < (bbase + blen):
 
381
                ia += 1
 
382
            else:
 
383
                ib += 1
 
384
            
 
385
        intbase = len(self.base)
 
386
        abase = len(self.a)
 
387
        bbase = len(self.b)
 
388
        sl.append((intbase, intbase, abase, abase, bbase, bbase))
 
389
 
 
390
        return sl
 
391
 
 
392
    def find_unconflicted(self):
 
393
        """Return a list of ranges in base that are not conflicted."""
 
394
        am = bzrlib.patiencediff.PatienceSequenceMatcher(
 
395
                None, self.base, self.a).get_matching_blocks()
 
396
        bm = bzrlib.patiencediff.PatienceSequenceMatcher(
 
397
                None, self.base, self.b).get_matching_blocks()
 
398
 
 
399
        unc = []
 
400
 
 
401
        while am and bm:
 
402
            # there is an unconflicted block at i; how long does it
 
403
            # extend?  until whichever one ends earlier.
 
404
            a1 = am[0][0]
 
405
            a2 = a1 + am[0][2]
 
406
            b1 = bm[0][0]
 
407
            b2 = b1 + bm[0][2]
 
408
            i = intersect((a1, a2), (b1, b2))
 
409
            if i:
 
410
                unc.append(i)
 
411
 
 
412
            if a2 < b2:
 
413
                del am[0]
 
414
            else:
 
415
                del bm[0]
 
416
                
 
417
        return unc
 
418
 
 
419
 
 
420
def main(argv):
 
421
    # as for diff3 and meld the syntax is "MINE BASE OTHER"
 
422
    a = file(argv[1], 'rt').readlines()
 
423
    base = file(argv[2], 'rt').readlines()
 
424
    b = file(argv[3], 'rt').readlines()
 
425
 
 
426
    m3 = Merge3(base, a, b)
 
427
 
 
428
    #for sr in m3.find_sync_regions():
 
429
    #    print sr
 
430
 
 
431
    # sys.stdout.writelines(m3.merge_lines(name_a=argv[1], name_b=argv[3]))
 
432
    sys.stdout.writelines(m3.merge_annotated())
 
433
 
 
434
 
 
435
if __name__ == '__main__':
 
436
    import sys
 
437
    sys.exit(main(sys.argv))