diff --git a/src/Server.zig b/src/Server.zig index c38fe89..9478735 100644 --- a/src/Server.zig +++ b/src/Server.zig @@ -18,6 +18,7 @@ const Ast = std.zig.Ast; const tracy = @import("tracy.zig"); const uri_utils = @import("uri.zig"); const data = @import("data/data.zig"); +const diff = @import("diff.zig"); // Server fields @@ -2228,17 +2229,46 @@ fn formattingHandler(server: *Server, writer: anytype, id: types.RequestId, req: .Exited => |code| if (code == 0) { if (std.mem.eql(u8, handle.document.text, stdout_bytes)) return try respondGeneric(writer, id, null_result_response); - return try send(writer, server.arena.allocator(), types.Response{ - .id = id, - .result = .{ - .TextEdits = &[1]types.TextEdit{ - .{ - .range = try offsets.documentRange(handle.document, server.offset_encoding), - .newText = stdout_bytes, + var edits = diff.edits(server.allocator, handle.document.text, stdout_bytes) catch { + // If there was an error trying to diff the text, return the formatted response + // as the new text for the entire range of the document + return try send(writer, server.arena.allocator(), types.Response{ + .id = id, + .result = .{ + .TextEdits = &[1]types.TextEdit{ + .{ + .range = try offsets.documentRange(handle.document, server.offset_encoding), + .newText = stdout_bytes, + }, }, }, + }); + }; + defer { + for (edits.items) |item| item.newText.deinit(); + edits.deinit(); + } + + // Convert from `[]diff.Edit` to `[]types.TextEdit` + var text_edits = try std + .ArrayList(types.TextEdit) + .initCapacity(server.allocator, edits.items.len); + defer text_edits.deinit(); + for (edits.items) |edit| { + try text_edits.append(.{ + .range = edit.range, + .newText = edit.newText.items, + }); + } + + return try send( + writer, + server.arena.allocator(), + types.Response{ + .id = id, + .result = .{ .TextEdits = text_edits.items }, }, - }); + ); }, else => {}, } diff --git a/src/diff.zig b/src/diff.zig new file mode 100644 index 0000000..c40e745 --- /dev/null +++ b/src/diff.zig @@ 
const std = @import("std");
const types = @import("types.zig");

/// A single replacement to apply to the original text.
///
/// This is essentially the same as `types.TextEdit`, but `newText` is an
/// `ArrayList(u8)` so that whoever owns the `Edit` can free the memory later.
pub const Edit = struct {
    range: types.Range,
    newText: std.ArrayList(u8),
};

/// Whether a `Change` is an addition, a deletion, or no change from the
/// original string to the new string.
const Operation = enum { Deletion, Addition, Nothing };

/// A single character difference between two strings.
///
/// `pos` is an index into the (trimmed) original string: for `.Nothing` and
/// `.Deletion` it is the index of the affected byte, for `.Addition` it is the
/// number of original bytes that precede the inserted byte. `value` is the
/// added or deleted byte and is `null` for `.Nothing`.
const Change = struct {
    operation: Operation,
    pos: usize,
    value: ?u8,
};

/// Given two input strings, `a` and `b`, return a list of `Edit`s that
/// describe the changes from `a` to `b`.
///
/// The caller owns the returned list and the `newText` buffer of every item;
/// all of them are allocated with `allocator`.
///
/// The LCS table needs `(len(a') + 1) * (len(b') + 1)` `usize` cells, where
/// `a'` and `b'` are the inputs with their common prefix/suffix removed, so
/// this can fail with `error.OutOfMemory` for large, heavily changed inputs.
pub fn edits(
    allocator: std.mem.Allocator,
    a: []const u8,
    b: []const u8,
) !std.ArrayList(Edit) {
    // Skip the leading and trailing bytes that `a` and `b` have in common.
    // This decreases the size of the LCS table and makes the diff comparison
    // more efficient.
    var a_trim: []const u8 = a;
    var b_trim: []const u8 = b;
    const a_trim_offset = trim_input(&a_trim, &b_trim);

    const rows = a_trim.len + 1;
    const cols = b_trim.len + 1;

    var lcs = try Array2D.new(allocator, rows, cols);
    defer lcs.deinit();

    calculate_lcs(&lcs, a_trim, b_trim);

    return try get_changes(
        &lcs,
        a,
        a_trim_offset,
        a_trim,
        b_trim,
        allocator,
    );
}

/// Shrink `a_out` and `b_out` in place to the parts that remain after removing
/// their common prefix and common suffix.
///
/// Returns the length of the removed prefix, i.e. the offset of the trimmed
/// `a` within the original `a`. If either input is empty nothing is trimmed
/// and 0 is returned.
fn trim_input(a_out: *[]const u8, b_out: *[]const u8) usize {
    if (a_out.*.len == 0 or b_out.*.len == 0) return 0;

    const a: []const u8 = a_out.*;
    const b: []const u8 = b_out.*;

    // Length of the common prefix.
    var start: usize = 0;
    while (start < a.len and start < b.len and a[start] == b[start]) start += 1;

    // Length of the common suffix. The scan deliberately stops one byte short
    // of the shorter input, so the suffix never spans a whole input.
    var end: usize = 0;
    while (end + 1 < a.len and
        end + 1 < b.len and
        a[a.len - 1 - end] == b[b.len - 1 - end]) end += 1;

    const a_start = start;
    var a_end = a.len - end;
    const b_start = start;
    var b_end = b.len - end;

    // The prefix and suffix are measured independently, so they can overlap
    // and produce a "negative" range where the start lies after the end.
    // Consider the following inputs:
    //   a: "xx  xx"   (two spaces)
    //   b: "xx xx"    (one space)
    //
    // This leads to the following calculations:
    //   a_start: 3
    //   a_end:   3
    //   b_start: 3
    //   b_end:   2
    //
    // In negative range situations, we add the absolute value of the negative
    // range's length (`b_start - b_end` in this case) to the other range's
    // end (`a_end + (b_start - b_end)`), and then set the negative range's end
    // to its start (`b_end = b_start`). In the example this yields
    // a[3..4] == " " and an empty `b`: a single deleted space.
    if (a_start > a_end) {
        const difference = a_start - a_end;
        a_end = a_start;
        b_end += difference;
    }
    if (b_start > b_end) {
        const difference = b_start - b_end;
        b_end = b_start;
        a_end += difference;
    }

    a_out.* = a[a_start..a_end];
    b_out.* = b[b_start..b_end];

    return start;
}

/// A 2D array of `usize` that is addressable as a[row, col] and stored in
/// row-major order.
pub const Array2D = struct {
    const Self = @This();

    data: [*]usize,
    allocator: std.mem.Allocator,
    rows: usize,
    cols: usize,

    /// Allocate a `rows` x `cols` table. The cells are NOT initialized.
    /// Fails with `error.OutOfMemory` if the allocation fails or if
    /// `rows * cols` does not fit in a `usize`.
    pub fn new(
        allocator: std.mem.Allocator,
        rows: usize,
        cols: usize,
    ) !Self {
        // A table whose cell count overflows `usize` can never be allocated,
        // so report it the same way as any other failed allocation.
        const len = std.math.mul(usize, rows, cols) catch return error.OutOfMemory;
        const data = try allocator.alloc(usize, len);

        return Self{
            .data = data.ptr,
            .allocator = allocator,
            .rows = rows,
            .cols = cols,
        };
    }

    /// Release the table's memory.
    pub fn deinit(self: *Self) void {
        self.allocator.free(self.data[0 .. self.rows * self.cols]);
    }

    /// Return a pointer to the cell at (`row`, `col`).
    pub fn get(self: *Self, row: usize, col: usize) *usize {
        std.debug.assert(row < self.rows and col < self.cols);
        return &self.data[row * self.cols + col];
    }
};

/// Build a Longest Common Subsequence table for `astr` and `bstr` in `lcs`,
/// which must have `astr.len + 1` rows and `bstr.len + 1` columns.
fn calculate_lcs(
    lcs: *Array2D,
    astr: []const u8,
    bstr: []const u8,
) void {
    const rows = astr.len + 1;
    const cols = bstr.len + 1;
    std.debug.assert(lcs.rows == rows and lcs.cols == cols);

    std.mem.set(usize, lcs.data[0 .. rows * cols], 0);

    // This approach is a dynamic programming technique to calculate the
    // longest common subsequence between two strings, `a` and `b`. We start
    // at 1 for `i` and `j` because the first column and first row are always
    // set to zero.
    //
    // You can find more information about this at the following url:
    // https://en.wikipedia.org/wiki/Longest_common_subsequence_problem
    var i: usize = 1;
    while (i < rows) : (i += 1) {
        var j: usize = 1;
        while (j < cols) : (j += 1) {
            if (astr[i - 1] == bstr[j - 1]) {
                lcs.get(i, j).* = lcs.get(i - 1, j - 1).* + 1;
            } else {
                lcs.get(i, j).* = std.math.max(
                    lcs.get(i - 1, j).*,
                    lcs.get(i, j - 1).*,
                );
            }
        }
    }
}

/// Turn a filled LCS table into a list of `Edit`s.
///
/// `a` is the full original text (used to convert byte offsets into
/// line/character positions), `a_trim`/`b_trim` are the trimmed inputs the
/// table was built from, and `a_trim_offset` is the offset of `a_trim` within
/// `a`.
///
/// The caller owns the returned list and every `newText` in it. On error,
/// everything allocated here is released again.
pub fn get_changes(
    lcs: *Array2D,
    a: []const u8,
    a_trim_offset: usize,
    a_trim: []const u8,
    b_trim: []const u8,
    allocator: std.mem.Allocator,
) !std.ArrayList(Edit) {
    // First we get a list of changes between strings at the character level:
    // "addition", "deletion", and "no change" for each character.
    var changes = try std.ArrayList(Change).initCapacity(allocator, a_trim.len);
    defer changes.deinit();
    try recur_changes(
        lcs,
        &changes,
        a_trim,
        b_trim,
        @intCast(i64, a_trim.len),
        @intCast(i64, b_trim.len),
    );

    // We want to group runs of deletions and additions, and separate them by
    // runs of `.Nothing` changes. This will allow us to calculate the
    // `TextEdit` ranges. `changes` is not resized past this point, so the
    // group slices into `changes.items` stay valid.
    var groups = std.ArrayList([]Change).init(allocator);
    defer groups.deinit();
    var group_start: ?usize = null;
    for (changes.items) |ch, i| {
        switch (ch.operation) {
            .Addition, .Deletion => {
                if (group_start == null) group_start = i;
            },
            .Nothing => {
                if (group_start) |first| {
                    try groups.append(changes.items[first..i]);
                    group_start = null;
                }
            },
        }
    }
    if (group_start) |first| try groups.append(changes.items[first..]);

    // The LCS walk works "in reverse", so we're putting everything back in
    // ascending order.
    std.mem.reverse([]Change, groups.items);
    for (groups.items) |group| std.mem.reverse(Change, group);

    var a_lines = std.mem.split(u8, a, "\n");

    var edit_results = std.ArrayList(Edit).init(allocator);
    errdefer {
        // The list owns the text buffers of the edits already stored in it.
        for (edit_results.items) |item| item.newText.deinit();
        edit_results.deinit();
    }

    // Convert our grouped changes into `Edit`s.
    for (groups.items) |group| {
        const range_start = group[0].pos;
        var range_len: usize = 0;
        var new_text = std.ArrayList(u8).init(allocator);
        // Only covers this iteration: once the edit is appended below, the
        // buffer belongs to `edit_results`.
        errdefer new_text.deinit();

        for (group) |ch| {
            switch (ch.operation) {
                .Addition => try new_text.append(ch.value.?),
                .Deletion => range_len += 1,
                .Nothing => {},
            }
        }

        // The iterator is shared between groups; rewind it before every use.
        a_lines.reset();
        const range = try char_pos_to_range(
            &a_lines,
            a_trim_offset + range_start,
            a_trim_offset + range_start + range_len,
        );
        try edit_results.append(Edit{
            .range = range,
            .newText = new_text,
        });
    }

    return edit_results;
}

/// Walk backwards through the LCS table from cell (`i`, `j`) to (0, 0) and
/// append, in that reverse order, the character-level changes that transform
/// `a` into `b`.
///
/// `i` and `j` must be non-negative and within the table. This is implemented
/// as a loop rather than by recursion so that the stack depth does not grow
/// with `a.len + b.len`.
fn recur_changes(
    lcs: *Array2D,
    changes: *std.ArrayList(Change),
    a: []const u8,
    b: []const u8,
    i: i64,
    j: i64,
) std.mem.Allocator.Error!void {
    var ii = @intCast(usize, i);
    var jj = @intCast(usize, j);

    while (ii > 0 or jj > 0) {
        if (ii > 0 and jj > 0 and a[ii - 1] == b[jj - 1]) {
            try changes.append(.{
                .operation = .Nothing,
                .pos = ii - 1,
                .value = null,
            });
            ii -= 1;
            jj -= 1;
        } else if (jj > 0 and (ii == 0 or lcs.get(ii, jj - 1).* >= lcs.get(ii - 1, jj).*)) {
            try changes.append(.{
                .operation = .Addition,
                .pos = ii,
                .value = b[jj - 1],
            });
            jj -= 1;
        } else {
            // Reaching this branch implies `ii > 0`: either `jj == 0` (and the
            // loop condition guarantees `ii > 0`), or the previous condition
            // failed with `jj > 0`, which requires `ii != 0`.
            try changes.append(.{
                .operation = .Deletion,
                .pos = ii - 1,
                .value = a[ii - 1],
            });
            ii -= 1;
        }
    }
}

/// Accept a range that is solely based on buffer/character position and
/// convert it to a line number & character position range.
///
/// `lines` must be positioned at the first line of the text; it is advanced
/// only as far as needed to locate both ends. `start` and `end` are byte
/// offsets into the text that `lines` splits, and the resulting `character`
/// values are byte offsets within their line.
/// NOTE(review): clients that negotiate a non-UTF-8 position encoding would
/// need these columns converted — confirm the caller accounts for that.
///
/// Returns `error.InvalidRange` if `start` lies outside the text. An `end`
/// outside the text is clamped to the end of the last line.
fn char_pos_to_range(
    lines: *std.mem.SplitIterator(u8),
    start: usize,
    end: usize,
) !types.Range {
    var char_pos: usize = 0;
    var line_pos: usize = 0;
    var result_start_pos: ?types.Position = null;
    var result_end_pos: ?types.Position = null;
    var last_line_end: ?types.Position = null;

    while (lines.next()) |line| : ({
        // `+ 1` accounts for the "\n" that separated this line from the next.
        char_pos += line.len + 1;
        line_pos += 1;
    }) {
        const line_end = char_pos + line.len;

        if (start >= char_pos and start <= line_end) {
            result_start_pos = .{
                .line = @intCast(i64, line_pos),
                .character = @intCast(i64, start - char_pos),
            };
        }
        if (end >= char_pos and end <= line_end) {
            result_end_pos = .{
                .line = @intCast(i64, line_pos),
                .character = @intCast(i64, end - char_pos),
            };
        }
        last_line_end = .{
            .line = @intCast(i64, line_pos),
            .character = @intCast(i64, line.len),
        };

        // Every offset belongs to exactly one line, so nothing further can
        // match once both ends are known.
        if (result_start_pos != null and result_end_pos != null) break;
    }

    const start_pos = result_start_pos orelse return error.InvalidRange;

    // If we did not find an end position, it is outside the range of the
    // string for some reason, so clamp it to the end of the last line. A start
    // position was found, so at least one line was visited and
    // `last_line_end` is set.
    const end_pos = result_end_pos orelse last_line_end.?;

    return types.Range{
        .start = start_pos,
        .end = end_pos,
    };
}