From 4e33e1d61fadbbb0c9856e2670f1114a03da9227 Mon Sep 17 00:00:00 2001 From: Jeffery Stager Date: Wed, 10 Aug 2022 22:47:33 -0400 Subject: [PATCH 1/3] Working diff for formatting --- src/Server.zig | 38 +++++-- src/diff.zig | 299 +++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 329 insertions(+), 8 deletions(-) create mode 100644 src/diff.zig diff --git a/src/Server.zig b/src/Server.zig index bcedadf..9a3641e 100644 --- a/src/Server.zig +++ b/src/Server.zig @@ -18,6 +18,7 @@ const Ast = std.zig.Ast; const tracy = @import("tracy.zig"); const uri_utils = @import("uri.zig"); const data = @import("data/data.zig"); +const diff = @import("diff.zig"); // Server fields @@ -2213,16 +2214,37 @@ fn formattingHandler(server: *Server, writer: anytype, id: types.RequestId, req: .Exited => |code| if (code == 0) { if (std.mem.eql(u8, handle.document.text, stdout_bytes)) return try respondGeneric(writer, id, null_result_response); + var edits = try diff.edits(server.allocator, handle.document.text, stdout_bytes); + defer edits.deinit(); + defer for (edits.items) |item| item.newText.deinit(); + + var text_edits = try std + .ArrayList(types.TextEdit) + .initCapacity(server.allocator, edits.items.len); + defer text_edits.deinit(); + + for (edits.items) |edit| { + try text_edits.append(.{ + .range = edit.range, + .newText = edit.newText.items, + }); + } + + const result = types.ResponseParams{ + .TextEdits = text_edits.items, + }; + return try send(writer, server.arena.allocator(), types.Response{ .id = id, - .result = .{ - .TextEdits = &[1]types.TextEdit{ - .{ - .range = try offsets.documentRange(handle.document, server.offset_encoding), - .newText = stdout_bytes, - }, - }, - }, + .result = result, + // .result = .{ + // .TextEdits = &[1]types.TextEdit{ + // .{ + // .range = try offsets.documentRange(handle.document, server.offset_encoding), + // .newText = stdout_bytes, + // }, + // }, + // }, }); }, else => {}, diff --git a/src/diff.zig b/src/diff.zig new file mode 100644 index 0000000..927b2f9 --- /dev/null +++ b/src/diff.zig @@ -0,0 +1,299 @@ +const std = @import("std"); +const types = @import("types.zig"); + +// pub const Position = struct { line: usize, character: usize }; +// pub const Range = struct { start: Position, end: Position }; +pub const Edit = struct { + range: types.Range, + newText: std.ArrayList(u8), +}; + +const Operation = enum { Deletion, Addition, Nothing }; +const Change = struct { + operation: Operation, + pos: usize, + value: ?u8, +}; + +/// Given two input strings, `a` and `b`, return a list of Edits that +/// describe the changes from `a` to `b` +pub fn edits( + allocator: std.mem.Allocator, + a: []const u8, + b: []const u8, +) !std.ArrayList(Edit) { + // Given the input strings A and B, we skip over the first N characters + // where A[0..N] == B[0..N]. We want to trim the start (and end) of the + // strings that have the same text. This decreases the size of the LCS + // table and makes the diff comparison more efficient + var a_trim: []const u8 = a; + var b_trim: []const u8 = b; + const a_trim_offset = trim_input(&a_trim, &b_trim); + _ = a_trim_offset; + + const rows = a_trim.len + 1; + const cols = b_trim.len + 1; + + var lcs = try Array2D.new(allocator, rows, cols); + defer lcs.deinit(); + + calculate_lcs(&lcs, a_trim, b_trim); + + return try get_changes( + &lcs, + a, + a_trim_offset, + a_trim, + b_trim, + allocator, + ); +} + +fn trim_input(a_out: *[]const u8, b_out: *[]const u8) usize { + if (a_out.len == 0 or b_out.len == 0) return 0; + + var a: []const u8 = a_out.*; + var b: []const u8 = b_out.*; + + // Trim the beginning of the string + var start: usize = 0; + while (start < a.len and start < b.len and a[start] == b[start]) : ({ + start += 1; + }) {} + + // Trim the end of the string + var end: usize = 1; + while (end < a.len and end < b.len and a[a.len - end] == b[b.len - end]) : ({ + end += 1; + }) {} + end -= 1; + + // Sanity check to make sure our range isn't negative + if (start < a.len - end and start < b.len - end) { + a_out.* = a[start .. a.len - end]; + b_out.* = b[start .. b.len - end]; + return start; + } + + a_out.* = a_out.*; + b_out.* = b_out.*; + return 0; +} + +pub const Array2D = struct { + const Self = @This(); + + data: [*]usize, + allocator: std.mem.Allocator, + rows: usize, + cols: usize, + + pub fn new( + allocator: std.mem.Allocator, + rows: usize, + cols: usize, + ) !Self { + const data = try allocator.alloc(usize, rows * cols); + + return Self{ + .data = data.ptr, + .allocator = allocator, + .rows = rows, + .cols = cols, + }; + } + + pub fn deinit(self: *Self) void { + self.allocator.free(self.data[0 .. self.rows * self.cols]); + } + + pub fn get(self: *Self, row: usize, col: usize) *usize { + return @ptrCast(*usize, self.data + (row * self.cols) + col); + } +}; + +fn calculate_lcs( + lcs: *Array2D, + astr: []const u8, + bstr: []const u8, +) void { + const rows = astr.len + 1; + const cols = bstr.len + 1; + + std.mem.set(usize, lcs.data[0 .. rows * cols], 0); + + var i: usize = 1; + while (i < rows) : (i += 1) { + var j: usize = 1; + while (j < cols) : (j += 1) { + if (astr[i - 1] == bstr[j - 1]) { + lcs.get(i, j).* = lcs.get(i - 1, j - 1).* + 1; + } else { + lcs.get(i, j).* = std.math.max( + lcs.get(i - 1, j).*, + lcs.get(i, j - 1).*, + ); + } + } + } +} + +pub fn get_changes( + lcs: *Array2D, + a: []const u8, + a_trim_offset: usize, + a_trim: []const u8, + b_trim: []const u8, + allocator: std.mem.Allocator, +) !std.ArrayList(Edit) { + var changes = try std.ArrayList(Change).initCapacity(allocator, a_trim.len); + defer changes.deinit(); + try recur_changes( + lcs, + &changes, + a_trim, + b_trim, + @intCast(i64, a_trim.len), + @intCast(i64, b_trim.len), + ); + + var groups = std.ArrayList([]Change).init(allocator); + defer groups.deinit(); + var active_change: ?[]Change = null; + for (changes.items) |ch, i| { + switch (ch.operation) { + .Addition, .Deletion => { + if (active_change == null) { + active_change = changes.items[i..]; + } + }, + .Nothing => { + if (active_change) |*ac| { + ac.* = ac.*[0..(i - (changes.items.len - ac.*.len))]; + try groups.append(ac.*); + active_change = null; + } + }, + } + } + if (active_change) |*ac| { + ac.* = ac.*[0..(changes.items.len - (changes.items.len - ac.*.len))]; + try groups.append(ac.*); + } + + var a_lines = std.mem.split(u8, a, "\n"); + std.mem.reverse([]Change, groups.items); + for (groups.items) |group| std.mem.reverse(Change, group); + + var edit_results = std.ArrayList(Edit).init(allocator); + errdefer edit_results.deinit(); + + for (groups.items) |group| { + var range_start = group[0].pos; + var range_len: usize = 0; + var newText = std.ArrayList(u8).init(allocator); + _ = range_start; + _ = range_len; + for (group) |ch| { + switch (ch.operation) { + .Addition => try newText.append(ch.value.?), + .Deletion => range_len += 1, + else => {}, + } + } + var range = try char_pos_to_range( + &a_lines, + a_trim_offset + range_start, + a_trim_offset + range_start + range_len, + ); + a_lines.reset(); + try edit_results.append(Edit{ + .range = range, + .newText = newText, + }); + } + + return edit_results; +} + +fn recur_changes( + lcs: *Array2D, + changes: *std.ArrayList(Change), + a: []const u8, + b: []const u8, + i: i64, + j: i64, +) anyerror!void { + const ii = @intCast(usize, i); + const jj = @intCast(usize, j); + + if (i > 0 and j > 0 and a[ii - 1] == b[jj - 1]) { + try changes.append(.{ + .operation = .Nothing, + .pos = ii - 1, + .value = null, + }); + try recur_changes(lcs, changes, a, b, i - 1, j - 1); + } else if (j > 0 and (i == 0 or lcs.get(ii, jj - 1).* >= lcs.get(ii - 1, jj).*)) { + try changes.append(.{ + .operation = .Addition, + .pos = ii, + .value = b[jj - 1], + }); + try recur_changes(lcs, changes, a, b, i, j - 1); + } else if (i > 0 and (j == 0 or lcs.get(ii, jj - 1).* < lcs.get(ii - 1, jj).*)) { + try changes.append(.{ + .operation = .Deletion, + .pos = ii - 1, + .value = a[ii - 1], + }); + try recur_changes(lcs, changes, a, b, i - 1, j); + } +} + +/// Accept a range that is solely based on buffer/character position and +/// convert it to line number & character position range +fn char_pos_to_range( + lines: *std.mem.SplitIterator(u8), + start: usize, + end: usize, +) !types.Range { + var char_pos: usize = 0; + var line_pos: usize = 0; + var result_start_pos: ?types.Position = null; + var result_end_pos: ?types.Position = null; + + while (lines.next()) |line| : ({ + char_pos += line.len + 1; + line_pos += 1; + }) { + if (start >= char_pos and start <= char_pos + line.len) { + result_start_pos = .{ + .line = @intCast(i64, line_pos), + .character = @intCast(i64, start - char_pos), + }; + } + if (end >= char_pos and end <= char_pos + line.len) { + result_end_pos = .{ + .line = @intCast(i64, line_pos), + .character = @intCast(i64, end - char_pos), + }; + } + } + + if (result_start_pos == null) return error.InvalidRange; + if (result_end_pos == null) { + result_end_pos = types.Position{ + .line = @intCast(i64, line_pos), + .character = @intCast(i64, char_pos), + }; + } + if (result_start_pos == null or result_end_pos == null) { + return error.InvalidRange; + } + + return types.Range{ + .start = result_start_pos.?, + .end = result_end_pos.?, + }; +} From 254353a9f4e76b7caa2e67505bd10a7f0ca018e0 Mon Sep 17 00:00:00 2001 From: Jeffery Stager Date: Sat, 13 Aug 2022 17:16:53 -0400 Subject: [PATCH 2/3] Add fallback to old behavior on diff failure --- src/Server.zig | 39 +++++++++++++++---------- src/diff.zig | 77 +++++++++++++++++++++++++++++++++++++++++--------- 2 files changed, 88 insertions(+), 28 deletions(-) diff --git a/src/Server.zig b/src/Server.zig index 9a3641e..044459c 100644 --- a/src/Server.zig +++ b/src/Server.zig @@ -2214,9 +2214,25 @@ fn formattingHandler(server: *Server, writer: anytype, id: types.RequestId, req: .Exited => |code| if (code == 0) { if (std.mem.eql(u8, handle.document.text, stdout_bytes)) return try respondGeneric(writer, id, null_result_response); - var edits = try diff.edits(server.allocator, handle.document.text, stdout_bytes); - defer edits.deinit(); - defer for (edits.items) |item| item.newText.deinit(); + var edits = diff.edits(server.allocator, handle.document.text, stdout_bytes) catch { + // If there was an error trying to diff the text, return the formatted response + // as the new text for the entire range of the document + return try send(writer, server.arena.allocator(), types.Response{ + .id = id, + .result = .{ + .TextEdits = &[1]types.TextEdit{ + .{ + .range = try offsets.documentRange(handle.document, server.offset_encoding), + .newText = stdout_bytes, + }, + }, + }, + }); + }; + defer { + for (edits.items) |item| item.newText.deinit(); + edits.deinit(); + } var text_edits = try std .ArrayList(types.TextEdit) @@ -2234,18 +2250,11 @@ fn formattingHandler(server: *Server, writer: anytype, id: types.RequestId, req: .TextEdits = text_edits.items, }; - return try send(writer, server.arena.allocator(), types.Response{ - .id = id, - .result = result, - // .result = .{ - // .TextEdits = &[1]types.TextEdit{ - // .{ - // .range = try offsets.documentRange(handle.document, server.offset_encoding), - // .newText = stdout_bytes, - // }, - // }, - // }, - }); + return try send( + writer, + server.arena.allocator(), + types.Response{ .id = id, .result = result }, + ); }, else => {}, } diff --git a/src/diff.zig b/src/diff.zig index 927b2f9..c40e745 100644 --- a/src/diff.zig +++ b/src/diff.zig @@ -1,14 +1,18 @@ const std = @import("std"); const types = @import("types.zig"); -// pub const Position = struct { line: usize, character: usize }; -// pub const Range = struct { start: Position, end: Position }; +// This is essentially the same as `types.TextEdit`, but we use an +// ArrayList(u8) here to be able to clean up the memory later on pub const Edit = struct { range: types.Range, newText: std.ArrayList(u8), }; +// Whether the `Change` is an addition, deletion, or no change from the +// original string to the new string const Operation = enum { Deletion, Addition, Nothing }; + +/// A single character difference between two strings const Change = struct { operation: Operation, pos: usize, @@ -68,18 +72,45 @@ fn trim_input(a_out: *[]const u8, b_out: *[]const u8) usize { }) {} end -= 1; - // Sanity check to make sure our range isn't negative - if (start < a.len - end and start < b.len - end) { - a_out.* = a[start .. a.len - end]; - b_out.* = b[start .. b.len - end]; - return start; + var a_start = start; + var a_end = a.len - end; + var b_start = start; + var b_end = b.len - end; + + // In certain situations, the trimmed range can be "negative" where + // `a_start` ends up being after `a_end` in the byte stream. If you + // consider the following inputs: + // a: "xx gg xx" + // b: "xx gg xx" + // + // This will lead to the following calculations: + // a_start: 4 + // a_end: 4 + // b_start: 4 + // b_end: 2 + // + // In negative range situations, we add the absolute value of the + // the negative range's length (`b_start - b_end` in this case) to the + // other range's length (a_end + (b_start - b_end)), and then set the + // negative range end to the negative range start (b_end = b_start) + if (a_start > a_end) { + const difference = a_start - a_end; + a_end = a_start; + b_end += difference; + } + if (b_start > b_end) { + const difference = b_start - b_end; + b_end = b_start; + a_end += difference; } - a_out.* = a_out.*; - b_out.* = b_out.*; - return 0; + a_out.* = a[a_start..a_end]; + b_out.* = b[b_start..b_end]; + + return start; } +/// A 2D array that is addressable as a[row, col] pub const Array2D = struct { const Self = @This(); @@ -112,6 +143,7 @@ pub const Array2D = struct { } }; +/// Build a Longest Common Subsequence table fn calculate_lcs( lcs: *Array2D, astr: []const u8, @@ -122,6 +154,13 @@ fn calculate_lcs( std.mem.set(usize, lcs.data[0 .. rows * cols], 0); + // This approach is a dynamic programming technique to calculate the + // longest common subsequence between two strings, `a` and `b`. We start + // at 1 for `i` and `j` because the first column and first row are always + // set to zero + // + // You can find more information about this at the following url: + // https://en.wikipedia.org/wiki/Longest_common_subsequence_problem var i: usize = 1; while (i < rows) : (i += 1) { var j: usize = 1; @@ -146,6 +185,8 @@ pub fn get_changes( b_trim: []const u8, allocator: std.mem.Allocator, ) !std.ArrayList(Edit) { + // First we get a list of changes between strings at the character level: + // "addition", "deletion", and "no change" for each character var changes = try std.ArrayList(Change).initCapacity(allocator, a_trim.len); defer changes.deinit(); try recur_changes( @@ -157,6 +198,9 @@ pub fn get_changes( @intCast(i64, b_trim.len), ); + // We want to group runs of deletions and additions, and separate them by + // runs of `.Nothing` changes. This will allow us to calculate the + // `TextEdit` ranges var groups = std.ArrayList([]Change).init(allocator); defer groups.deinit(); var active_change: ?[]Change = null; @@ -181,6 +225,8 @@ pub fn get_changes( try groups.append(ac.*); } + // The LCS algorithm works "in reverse", so we're putting everything back + // in ascending order var a_lines = std.mem.split(u8, a, "\n"); std.mem.reverse([]Change, groups.items); for (groups.items) |group| std.mem.reverse(Change, group); @@ -188,6 +234,7 @@ pub fn get_changes( var edit_results = std.ArrayList(Edit).init(allocator); errdefer edit_results.deinit(); + // Convert our grouped changes into `Edit`s for (groups.items) |group| { var range_start = group[0].pos; var range_len: usize = 0; @@ -224,6 +271,10 @@ fn recur_changes( i: i64, j: i64, ) anyerror!void { + // This function recursively works backwards through the LCS table in + // order to figure out what kind of changes took place to transform `a` + // into `b` + const ii = @intCast(usize, i); const jj = @intCast(usize, j); @@ -282,15 +333,15 @@ fn char_pos_to_range( } if (result_start_pos == null) return error.InvalidRange; + + // If we did not find an end position, it is outside the range of the + // string for some reason so clamp it to the string end position if (result_end_pos == null) { result_end_pos = types.Position{ .line = @intCast(i64, line_pos), .character = @intCast(i64, char_pos), }; } - if (result_start_pos == null or result_end_pos == null) { - return error.InvalidRange; - } return types.Range{ .start = result_start_pos.?, From 1fbf1c5427d98197deb01488bac9c49cd76c2833 Mon Sep 17 00:00:00 2001 From: Jeffery Stager Date: Sat, 13 Aug 2022 17:23:39 -0400 Subject: [PATCH 3/3] Minor cleanup in formatting function --- src/Server.zig | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/src/Server.zig b/src/Server.zig index 044459c..edff367 100644 --- a/src/Server.zig +++ b/src/Server.zig @@ -2234,11 +2234,11 @@ fn formattingHandler(server: *Server, writer: anytype, id: types.RequestId, req: edits.deinit(); } + // Convert from `[]diff.Edit` to `[]types.TextEdit` var text_edits = try std .ArrayList(types.TextEdit) .initCapacity(server.allocator, edits.items.len); defer text_edits.deinit(); - for (edits.items) |edit| { try text_edits.append(.{ .range = edit.range, @@ -2246,14 +2246,13 @@ fn formattingHandler(server: *Server, writer: anytype, id: types.RequestId, req: }); } - const result = types.ResponseParams{ - .TextEdits = text_edits.items, - }; - return try send( writer, server.arena.allocator(), - types.Response{ .id = id, .result = result }, + types.Response{ + .id = id, + .result = .{ .TextEdits = text_edits.items }, + }, ); }, else => {},