diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 4523df0..b4d734b 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -28,7 +28,11 @@ jobs: FILE=zig-cache/bin/zls if test -f "$FILE"; then - exit 0 + if zig build test; then + exit 0 + else + exit 1 + fi else exit 1 fi diff --git a/build.zig b/build.zig index 21e64c9..4c40acf 100644 --- a/build.zig +++ b/build.zig @@ -131,7 +131,15 @@ pub fn build(b: *std.build.Builder) !void { const run_step = b.step("run", "Run the app"); run_step.dependOn(&run_cmd.step); - // const configure_step = std.build.Step.init("config", b.allocator, config); const configure_step = b.step("config", "Configure zls"); configure_step.makeFn = config; + + const test_step = b.step("test", "Run all the tests"); + test_step.dependOn(&exe.step); + + var unit_tests = b.addTest("tests/unit_tests.zig"); + unit_tests.addPackage(.{ .name = "analysis", .path = "src/analysis.zig" }); + unit_tests.addPackage(.{ .name = "types", .path = "src/types.zig" }); + unit_tests.setBuildMode(.Debug); + test_step.dependOn(&unit_tests.step); } diff --git a/src/analysis.zig b/src/analysis.zig index 1b76d39..9199822 100644 --- a/src/analysis.zig +++ b/src/analysis.zig @@ -284,6 +284,92 @@ fn resolveReturnType(analysis_ctx: *AnalysisContext, fn_decl: *ast.Node.FnProto) }; } +/// Resolves the child type of an optional type +fn resolveUnwrapOptionalType(analysis_ctx: *AnalysisContext, opt: *ast.Node) ?*ast.Node { + if (opt.cast(ast.Node.PrefixOp)) |prefix_op| { + if (prefix_op.op == .OptionalType) { + return resolveTypeOfNode(analysis_ctx, prefix_op.rhs); + } + } + + return null; +} + +/// Resolves the child type of a defer type +fn resolveDerefType(analysis_ctx: *AnalysisContext, deref: *ast.Node) ?*ast.Node { + if (deref.cast(ast.Node.PrefixOp)) |pop| { + if (pop.op == .PtrType) { + const op_token_id = analysis_ctx.tree().token_ids[pop.op_token]; + switch (op_token_id) { + .Asterisk => return resolveTypeOfNode(analysis_ctx, pop.rhs), + .LBracket, .AsteriskAsterisk => return null, + else => unreachable, + } + } + } + return null; +} + +fn makeSliceType(analysis_ctx: *AnalysisContext, child_type: *ast.Node) ?*ast.Node { + // TODO: Better values for fields, better way to do this? + var slice_type = analysis_ctx.arena.allocator.create(ast.Node.PrefixOp) catch return null; + slice_type.* = .{ + .op_token = child_type.firstToken(), + .op = .{ + .SliceType = .{ + .allowzero_token = null, + .align_info = null, + .const_token = null, + .volatile_token = null, + .sentinel = null, + }, + }, + .rhs = child_type, + }; + + return &slice_type.base; +} + +/// Resolves bracket access type (both slicing and array access) +fn resolveBracketAccessType( + analysis_ctx: *AnalysisContext, + lhs: *ast.Node, + rhs: enum { Single, Range }, +) ?*ast.Node { + if (lhs.cast(ast.Node.PrefixOp)) |pop| { + switch (pop.op) { + .SliceType => { + if (rhs == .Single) return resolveTypeOfNode(analysis_ctx, pop.rhs); + return lhs; + }, + .ArrayType => { + if (rhs == .Single) return resolveTypeOfNode(analysis_ctx, pop.rhs); + return makeSliceType(analysis_ctx, pop.rhs); + }, + .PtrType => { + if (pop.rhs.cast(std.zig.ast.Node.PrefixOp)) |child_pop| { + switch (child_pop.op) { + .ArrayType => { + if (rhs == .Single) { + return resolveTypeOfNode(analysis_ctx, child_pop.rhs); + } + return lhs; + }, + else => {}, + } + } + }, + else => {}, + } + } + return null; +} + +/// Called to remove one level of pointerness before a field access +fn resolveFieldAccessLhsType(analysis_ctx: *AnalysisContext, lhs: *ast.Node) *ast.Node { + return resolveDerefType(analysis_ctx, lhs) orelse lhs; +} + /// Resolves the type of a node pub fn resolveTypeOfNode(analysis_ctx: *AnalysisContext, node: *ast.Node) ?*ast.Node { switch (node.id) { @@ -318,6 +404,28 @@ pub fn resolveTypeOfNode(analysis_ctx: *AnalysisContext, node: *ast.Node) ?*ast. else => decl, }; }, + .SuffixOp => { + const suffix_op = node.cast(ast.Node.SuffixOp).?; + switch (suffix_op.op) { + .UnwrapOptional => { + const left_type = resolveTypeOfNode(analysis_ctx, suffix_op.lhs) orelse return null; + return resolveUnwrapOptionalType(analysis_ctx, left_type); + }, + .Deref => { + const left_type = resolveTypeOfNode(analysis_ctx, suffix_op.lhs) orelse return null; + return resolveDerefType(analysis_ctx, left_type); + }, + .ArrayAccess => { + const left_type = resolveTypeOfNode(analysis_ctx, suffix_op.lhs) orelse return null; + return resolveBracketAccessType(analysis_ctx, left_type, .Single); + }, + .Slice => { + const left_type = resolveTypeOfNode(analysis_ctx, suffix_op.lhs) orelse return null; + return resolveBracketAccessType(analysis_ctx, left_type, .Range); + }, + else => {}, + } + }, .InfixOp => { const infix_op = node.cast(ast.Node.InfixOp).?; switch (infix_op.op) { @@ -327,25 +435,31 @@ pub fn resolveTypeOfNode(analysis_ctx: *AnalysisContext, node: *ast.Node) ?*ast. var rhs_str = nodeToString(analysis_ctx.tree(), infix_op.rhs) orelse return null; // Use the analysis context temporary arena to store the rhs string. rhs_str = std.mem.dupe(&analysis_ctx.arena.allocator, u8, rhs_str) catch return null; - const left = resolveTypeOfNode(analysis_ctx, infix_op.lhs) orelse return null; - const child = getChild(analysis_ctx.tree(), left, rhs_str) orelse return null; + + // If we are accessing a pointer type, remove one pointerness level :) + const left_type = resolveFieldAccessLhsType( + analysis_ctx, + resolveTypeOfNode(analysis_ctx, infix_op.lhs) orelse return null, + ); + + const child = getChild(analysis_ctx.tree(), left_type, rhs_str) orelse return null; return resolveTypeOfNode(analysis_ctx, child); }, + .UnwrapOptional => { + const left_type = resolveTypeOfNode(analysis_ctx, infix_op.lhs) orelse return null; + return resolveUnwrapOptionalType(analysis_ctx, left_type); + }, else => {}, } }, .PrefixOp => { const prefix_op = node.cast(ast.Node.PrefixOp).?; switch (prefix_op.op) { - .SliceType, .ArrayType => return node, - .PtrType => { - const op_token_id = analysis_ctx.tree().token_ids[prefix_op.op_token]; - switch (op_token_id) { - .Asterisk => return resolveTypeOfNode(analysis_ctx, prefix_op.rhs), - .LBracket, .AsteriskAsterisk => return null, - else => unreachable, - } - }, + .SliceType, + .ArrayType, + .OptionalType, + .PtrType, + => return node, .Try => { const rhs_type = resolveTypeOfNode(analysis_ctx, prefix_op.rhs) orelse return null; switch (rhs_type.id) { @@ -433,37 +547,50 @@ pub fn collectImports(import_arr: *std.ArrayList([]const u8), tree: *ast.Tree) ! pub fn getFieldAccessTypeNode( analysis_ctx: *AnalysisContext, tokenizer: *std.zig.Tokenizer, - line_length: usize, ) ?*ast.Node { var current_node = analysis_ctx.in_container; - var current_container = analysis_ctx.in_container; while (true) { - var next = tokenizer.next(); - switch (next.id) { - .Eof => return current_node, + const tok = tokenizer.next(); + switch (tok.id) { + .Eof => return resolveFieldAccessLhsType(analysis_ctx, current_node), .Identifier => { - if (getChildOfSlice(analysis_ctx.tree(), analysis_ctx.scope_nodes, tokenizer.buffer[next.loc.start..next.loc.end])) |child| { - if (resolveTypeOfNode(analysis_ctx, child)) |node_type| { - current_node = node_type; + if (getChildOfSlice(analysis_ctx.tree(), analysis_ctx.scope_nodes, tokenizer.buffer[tok.loc.start..tok.loc.end])) |child| { + if (resolveTypeOfNode(analysis_ctx, child)) |child_type| { + current_node = child_type; } else return null; } else return null; }, .Period => { - var after_period = tokenizer.next(); - if (after_period.id == .Eof or after_period.id == .Comma) { - return current_node; - } else if (after_period.id == .Identifier) { - // TODO: This works for now, maybe we should filter based on the partial identifier ourselves? - if (after_period.loc.end == line_length) return current_node; + const after_period = tokenizer.next(); + switch (after_period.id) { + .Eof => return resolveFieldAccessLhsType(analysis_ctx, current_node), + .Identifier => { + if (after_period.loc.end == tokenizer.buffer.len) return resolveFieldAccessLhsType(analysis_ctx, current_node); - if (getChild(analysis_ctx.tree(), current_node, tokenizer.buffer[after_period.loc.start..after_period.loc.end])) |child| { - if (resolveTypeOfNode(analysis_ctx, child)) |child_type| { + current_node = resolveFieldAccessLhsType(analysis_ctx, current_node); + if (getChild(analysis_ctx.tree(), current_node, tokenizer.buffer[after_period.loc.start..after_period.loc.end])) |child| { + if (resolveTypeOfNode(analysis_ctx, child)) |child_type| { + current_node = child_type; + } else return null; + } else return null; + }, + .QuestionMark => { + if (resolveUnwrapOptionalType(analysis_ctx, current_node)) |child_type| { current_node = child_type; } else return null; - } else return null; + }, + else => { + std.debug.warn("Unrecognized token {} after period.\n", .{after_period.id}); + return null; + }, } }, + .PeriodAsterisk => { + if (resolveDerefType(analysis_ctx, current_node)) |child_type| { + current_node = child_type; + } else return null; + }, .LParen => { switch (current_node.id) { .FnProto => { @@ -472,7 +599,7 @@ pub fn getFieldAccessTypeNode( current_node = ret; // Skip to the right paren var paren_count: usize = 1; - next = tokenizer.next(); + var next = tokenizer.next(); while (next.id != .Eof) : (next = tokenizer.next()) { if (next.id == .RParen) { paren_count -= 1; @@ -481,30 +608,46 @@ pub fn getFieldAccessTypeNode( paren_count += 1; } } else return null; - } else { - return null; - } + } else return null; }, else => {}, } }, - .Keyword_const, .Keyword_var => { - next = tokenizer.next(); - if (next.id == .Identifier) { - next = tokenizer.next(); - if (next.id != .Equal) return null; - continue; - } + .LBracket => { + var brack_count: usize = 1; + var next = tokenizer.next(); + var is_range = false; + while (next.id != .Eof) : (next = tokenizer.next()) { + if (next.id == .RBracket) { + brack_count -= 1; + if (brack_count == 0) break; + } else if (next.id == .LBracket) { + brack_count += 1; + } else if (next.id == .Ellipsis2 and brack_count == 1) { + is_range = true; + } + } else return null; + + if (resolveBracketAccessType( + analysis_ctx, + current_node, + if (is_range) .Range else .Single, + )) |child_type| { + current_node = child_type; + } else return null; + }, + else => { + std.debug.warn("Unimplemented token: {}\n", .{tok.id}); + return null; }, - else => std.debug.warn("Not implemented; {}\n", .{next.id}), } - if (current_node.id == .ContainerDecl or current_node.id == .Root) { - current_container = current_node; + if (current_node.cast(ast.Node.ContainerDecl)) |container_decl| { + analysis_ctx.onContainer(container_decl) catch return null; } } - return current_node; + return resolveFieldAccessLhsType(analysis_ctx, current_node); } pub fn isNodePublic(tree: *ast.Tree, node: *ast.Node) bool { @@ -763,3 +906,137 @@ pub fn getImportStr(tree: *ast.Tree, source_index: usize) ?[]const u8 { } return null; } + +const types = @import("types.zig"); +pub const SourceRange = std.zig.Token.Loc; + +pub const PositionContext = union(enum) { + builtin: SourceRange, + comment, + string_literal: SourceRange, + field_access: SourceRange, + var_access: SourceRange, + enum_literal, + other, + empty, + + pub fn range(self: PositionContext) ?SourceRange { + return switch (self) { + .builtin => |r| r, + .comment => null, + .string_literal => |r| r, + .field_access => |r| r, + .var_access => |r| r, + .enum_literal => null, + .other => null, + .empty => null, + }; + } +}; + +const StackState = struct { + ctx: PositionContext, + stack_id: enum { Paren, Bracket, Global }, +}; + +fn peek(arr: *std.ArrayList(StackState)) !*StackState { + if (arr.items.len == 0) { + try arr.append(.{ .ctx = .empty, .stack_id = .Global }); + } + return &arr.items[arr.items.len - 1]; +} + +fn tokenRangeAppend(prev: SourceRange, token: std.zig.Token) SourceRange { + return .{ + .start = prev.start, + .end = token.loc.end, + }; +} + +pub fn documentPositionContext(allocator: *std.mem.Allocator, document: types.TextDocument, position: types.Position) !PositionContext { + const line = try document.getLine(@intCast(usize, position.line)); + const pos_char = @intCast(usize, position.character) + 1; + const idx = if (pos_char > line.len) line.len else pos_char; + + var arena = std.heap.ArenaAllocator.init(allocator); + defer arena.deinit(); + + var tokenizer = std.zig.Tokenizer.init(line[0..idx]); + var stack = try std.ArrayList(StackState).initCapacity(&arena.allocator, 8); + + while (true) { + const tok = tokenizer.next(); + // Early exits. + switch (tok.id) { + .Invalid, .Invalid_ampersands => { + // Single '@' do not return a builtin token so we check this on our own. + if (line[idx - 1] == '@') { + return PositionContext{ + .builtin = .{ + .start = idx - 1, + .end = idx, + }, + }; + } + return .other; + }, + .LineComment, .DocComment, .ContainerDocComment => return .comment, + .Eof => break, + else => {}, + } + + // State changes + var curr_ctx = try peek(&stack); + switch (tok.id) { + .StringLiteral, .MultilineStringLiteralLine => curr_ctx.ctx = .{ .string_literal = tok.loc }, + .Identifier => switch (curr_ctx.ctx) { + .empty => curr_ctx.ctx = .{ .var_access = tok.loc }, + else => {}, + }, + .Builtin => switch (curr_ctx.ctx) { + .empty => curr_ctx.ctx = .{ .builtin = tok.loc }, + else => {}, + }, + .Period, .PeriodAsterisk => switch (curr_ctx.ctx) { + .empty => curr_ctx.ctx = .enum_literal, + .enum_literal => curr_ctx.ctx = .empty, + .field_access => {}, + .other => {}, + else => curr_ctx.ctx = .{ + .field_access = tokenRangeAppend(curr_ctx.ctx.range().?, tok), + }, + }, + .QuestionMark => switch (curr_ctx.ctx) { + .field_access => {}, + else => curr_ctx.ctx = .empty, + }, + .LParen => try stack.append(.{ .ctx = .empty, .stack_id = .Paren }), + .LBracket => try stack.append(.{ .ctx = .empty, .stack_id = .Bracket }), + .RParen => { + _ = stack.pop(); + if (curr_ctx.stack_id != .Paren) { + (try peek(&stack)).ctx = .empty; + } + }, + .RBracket => { + _ = stack.pop(); + if (curr_ctx.stack_id != .Bracket) { + (try peek(&stack)).ctx = .empty; + } + }, + else => curr_ctx.ctx = .empty, + } + + switch (curr_ctx.ctx) { + .field_access => |r| curr_ctx.ctx = .{ + .field_access = tokenRangeAppend(r, tok), + }, + else => {}, + } + } + + return block: { + if (stack.popOrNull()) |state| break :block state.ctx; + break :block .empty; + }; +} diff --git a/src/main.zig b/src/main.zig index 9f1cfc4..ce88b63 100644 --- a/src/main.zig +++ b/src/main.zig @@ -273,6 +273,20 @@ fn nodeToCompletion( }); }, .PrefixOp => { + const prefix_op = node.cast(std.zig.ast.Node.PrefixOp).?; + switch (prefix_op.op) { + .ArrayType, .SliceType => {}, + .PtrType => { + if (prefix_op.rhs.cast(std.zig.ast.Node.PrefixOp)) |child_pop| { + switch (child_pop.op) { + .ArrayType => {}, + else => return, + } + } else return; + }, + else => return, + } + try list.append(.{ .label = "len", .kind = .Field, @@ -389,7 +403,7 @@ fn hoverDefinitionGlobal(id: i64, pos_index: usize, handle: *DocumentStore.Handl fn getSymbolFieldAccess( analysis_ctx: *DocumentStore.AnalysisContext, position: types.Position, - line_start_idx: usize, + range: analysis.SourceRange, config: Config, ) !?*std.zig.ast.Node { const pos_index = try analysis_ctx.handle.document.positionToIndex(position); @@ -397,12 +411,10 @@ fn getSymbolFieldAccess( if (name.len == 0) return null; const line = try analysis_ctx.handle.document.getLine(@intCast(usize, position.line)); - var tokenizer = std.zig.Tokenizer.init(line[line_start_idx..]); + var tokenizer = std.zig.Tokenizer.init(line[range.start..range.end]); - const line_length = @ptrToInt(name.ptr) - @ptrToInt(line.ptr) + name.len - line_start_idx; name = try std.mem.dupe(&analysis_ctx.arena.allocator, u8, name); - - if (analysis.getFieldAccessTypeNode(analysis_ctx, &tokenizer, line_length)) |container| { + if (analysis.getFieldAccessTypeNode(analysis_ctx, &tokenizer)) |container| { return analysis.getChild(analysis_ctx.tree(), container, name); } return null; @@ -412,14 +424,14 @@ fn gotoDefinitionFieldAccess( id: i64, handle: *DocumentStore.Handle, position: types.Position, - line_start_idx: usize, + range: analysis.SourceRange, config: Config, ) !void { var arena = std.heap.ArenaAllocator.init(allocator); defer arena.deinit(); var analysis_ctx = try document_store.analysisContext(handle, &arena, try handle.document.positionToIndex(position), config.zig_lib_path); - const decl = (try getSymbolFieldAccess(&analysis_ctx, position, line_start_idx, config)) orelse return try respondGeneric(id, null_result_response); + const decl = (try getSymbolFieldAccess(&analysis_ctx, position, range, config)) orelse return try respondGeneric(id, null_result_response); return try gotoDefinitionSymbol(id, &analysis_ctx, decl); } @@ -427,14 +439,14 @@ fn hoverDefinitionFieldAccess( id: i64, handle: *DocumentStore.Handle, position: types.Position, - line_start_idx: usize, + range: analysis.SourceRange, config: Config, ) !void { var arena = std.heap.ArenaAllocator.init(allocator); defer arena.deinit(); var analysis_ctx = try document_store.analysisContext(handle, &arena, try handle.document.positionToIndex(position), config.zig_lib_path); - const decl = (try getSymbolFieldAccess(&analysis_ctx, position, line_start_idx, config)) orelse return try respondGeneric(id, null_result_response); + const decl = (try getSymbolFieldAccess(&analysis_ctx, position, range, config)) orelse return try respondGeneric(id, null_result_response); return try hoverSymbol(id, &analysis_ctx, decl); } @@ -490,7 +502,7 @@ fn completeGlobal(id: i64, pos_index: usize, handle: *DocumentStore.Handle, conf }); } -fn completeFieldAccess(id: i64, handle: *DocumentStore.Handle, position: types.Position, line_start_idx: usize, config: Config) !void { +fn completeFieldAccess(id: i64, handle: *DocumentStore.Handle, position: types.Position, range: analysis.SourceRange, config: Config) !void { var arena = std.heap.ArenaAllocator.init(allocator); defer arena.deinit(); @@ -498,10 +510,9 @@ fn completeFieldAccess(id: i64, handle: *DocumentStore.Handle, position: types.P var completions = std.ArrayList(types.CompletionItem).init(&arena.allocator); const line = try handle.document.getLine(@intCast(usize, position.line)); - var tokenizer = std.zig.Tokenizer.init(line[line_start_idx..]); - const line_length = line.len - line_start_idx; + var tokenizer = std.zig.Tokenizer.init(line[range.start..range.end]); - if (analysis.getFieldAccessTypeNode(&analysis_ctx, &tokenizer, line_length)) |node| { + if (analysis.getFieldAccessTypeNode(&analysis_ctx, &tokenizer)) |node| { try nodeToCompletion(&completions, &analysis_ctx, handle, node, config); } try send(types.Response{ @@ -550,136 +561,6 @@ const builtin_completions = block: { }; }; -const PositionContext = union(enum) { - builtin, - comment, - string_literal, - field_access: usize, - var_access, - other, - empty, -}; - -const token_separators = [_]u8{ - ' ', '\t', '(', ')', '[', ']', - '{', '}', '|', '=', '!', ';', - ',', '?', ':', '%', '+', '*', - '>', '<', '~', '-', '/', '&', -}; - -fn documentPositionContext(doc: types.TextDocument, pos_index: usize) PositionContext { - // First extract the whole current line up to the cursor. - var curr_position = pos_index; - while (curr_position > 0) : (curr_position -= 1) { - if (doc.text[curr_position - 1] == '\n') break; - } - - var line = doc.text[curr_position .. pos_index + 1]; - // Strip any leading whitespace. - var skipped_ws: usize = 0; - while (skipped_ws < line.len and (line[skipped_ws] == ' ' or line[skipped_ws] == '\t')) : (skipped_ws += 1) {} - if (skipped_ws >= line.len) return .empty; - line = line[skipped_ws..]; - - // Quick exit for comment lines and multi line string literals. - if (line.len >= 2 and line[0] == '/' and line[1] == '/') - return .comment; - if (line.len >= 2 and line[0] == '\\' and line[1] == '\\') - return .string_literal; - - // TODO: This does not detect if we are in a string literal over multiple lines. - // Find out what context we are in. - // Go over the current line character by character - // and determine the context. - curr_position = 0; - var expr_start: usize = skipped_ws; - - // std.debug.warn("{}", .{curr_position}); - - if (pos_index != 0 and doc.text[pos_index - 1] == ')') - return .{ .field_access = expr_start }; - - var new_token = true; - var context: PositionContext = .other; - var string_pop_ctx: PositionContext = .other; - while (curr_position < line.len) : (curr_position += 1) { - const c = line[curr_position]; - const next_char = if (curr_position < line.len - 1) line[curr_position + 1] else null; - - if (context != .string_literal and c == '"') { - expr_start = curr_position + skipped_ws; - context = .string_literal; - continue; - } - - if (context == .string_literal) { - // Skip over escaped quotes - if (c == '\\' and next_char != null and next_char.? == '"') { - curr_position += 1; - } else if (c == '"') { - context = string_pop_ctx; - string_pop_ctx = .other; - new_token = true; - } - - continue; - } - - if (c == '/' and next_char != null and next_char.? == '/') { - context = .comment; - break; - } - - if (std.mem.indexOfScalar(u8, &token_separators, c) != null) { - expr_start = curr_position + skipped_ws + 1; - new_token = true; - context = .other; - continue; - } - - if (c == '.' and (!new_token or context == .string_literal)) { - new_token = true; - if (next_char != null and next_char.? == '.') continue; - context = .{ .field_access = expr_start }; - continue; - } - - if (new_token) { - const access_ctx: PositionContext = if (context == .field_access) - .{ .field_access = expr_start } - else - .var_access; - - new_token = false; - - if (c == '_' or std.ascii.isAlpha(c)) { - context = access_ctx; - } else if (c == '@') { - // This checks for @"..." identifiers by controlling - // the context the string will set after it is over. - if (next_char != null and next_char.? == '"') { - string_pop_ctx = access_ctx; - } - context = .builtin; - } else { - context = .other; - } - continue; - } - - if (context == .field_access or context == .var_access or context == .builtin) { - if (c != '_' and !std.ascii.isAlNum(c)) { - context = .other; - } - continue; - } - - context = .other; - } - - return context; -} - fn loadConfig(folder_path: []const u8) ?Config { var folder = std.fs.cwd().openDir(folder_path, .{}) catch return null; defer folder.close(); @@ -851,7 +732,7 @@ fn processJsonRpc(parser: *std.json.Parser, json: []const u8, config: Config) !v }; if (pos.character >= 0) { const pos_index = try handle.document.positionToIndex(pos); - const pos_context = documentPositionContext(handle.document, pos_index); + const pos_context = try analysis.documentPositionContext(allocator, handle.document, pos); const this_config = configFromUriOr(uri, config); switch (pos_context) { @@ -865,7 +746,7 @@ fn processJsonRpc(parser: *std.json.Parser, json: []const u8, config: Config) !v }, }), .var_access, .empty => try completeGlobal(id, pos_index, handle, this_config), - .field_access => |start_idx| try completeFieldAccess(id, handle, pos, start_idx, this_config), + .field_access => |range| try completeFieldAccess(id, handle, pos, range, this_config), else => try respondGeneric(id, no_completions_response), } } else { @@ -894,7 +775,7 @@ fn processJsonRpc(parser: *std.json.Parser, json: []const u8, config: Config) !v }; if (pos.character >= 0) { const pos_index = try handle.document.positionToIndex(pos); - const pos_context = documentPositionContext(handle.document, pos_index); + const pos_context = try analysis.documentPositionContext(allocator, handle.document, pos); switch (pos_context) { .var_access => try gotoDefinitionGlobal( @@ -903,16 +784,18 @@ fn processJsonRpc(parser: *std.json.Parser, json: []const u8, config: Config) !v handle, configFromUriOr(uri, config), ), - .field_access => |start_idx| try gotoDefinitionFieldAccess( + .field_access => |range| try gotoDefinitionFieldAccess( id, handle, pos, - start_idx, + range, configFromUriOr(uri, config), ), .string_literal => try gotoDefinitionString(id, pos_index, handle, config), else => try respondGeneric(id, null_result_response), } + } else { + try respondGeneric(id, null_result_response); } } else if (std.mem.eql(u8, method, "textDocument/hover")) { const document = params.getValue("textDocument").?.Object; @@ -930,7 +813,7 @@ fn processJsonRpc(parser: *std.json.Parser, json: []const u8, config: Config) !v }; if (pos.character >= 0) { const pos_index = try handle.document.positionToIndex(pos); - const pos_context = documentPositionContext(handle.document, pos_index); + const pos_context = try analysis.documentPositionContext(allocator, handle.document, pos); switch (pos_context) { .var_access => try hoverDefinitionGlobal( @@ -939,15 +822,17 @@ fn processJsonRpc(parser: *std.json.Parser, json: []const u8, config: Config) !v handle, configFromUriOr(uri, config), ), - .field_access => |start_idx| try hoverDefinitionFieldAccess( + .field_access => |range| try hoverDefinitionFieldAccess( id, handle, pos, - start_idx, + range, configFromUriOr(uri, config), ), else => try respondGeneric(id, null_result_response), } + } else { + try respondGeneric(id, null_result_response); } } else if (root.Object.getValue("id")) |_| { std.debug.warn("Method with return value not implemented: {}", .{method}); diff --git a/tests/unit_tests.zig b/tests/unit_tests.zig new file mode 100644 index 0000000..86e62e9 --- /dev/null +++ b/tests/unit_tests.zig @@ -0,0 +1,102 @@ +const analysis = @import("analysis"); +const types = @import("types"); + +const std = @import("std"); + +const allocator = std.testing.allocator; + +fn makeDocument(uri: []const u8, text: []const u8) !types.TextDocument { + const mem = try allocator.alloc(u8, text.len); + std.mem.copy(u8, mem, text); + + return types.TextDocument{ + .uri = uri, + .mem = mem, + .text = mem[0..], + }; +} + +fn freeDocument(doc: types.TextDocument) void { + allocator.free(doc.text); +} + +fn makeUnnamedDocument(text: []const u8) !types.TextDocument { + return try makeDocument("test", text); +} + +fn testContext(comptime line: []const u8, comptime tag: var, comptime range: ?[]const u8) !void { + const cursor_idx = comptime std.mem.indexOf(u8, line, "").?; + const final_line = line[0..cursor_idx] ++ line[cursor_idx + "".len ..]; + + const doc = try makeUnnamedDocument(final_line); + defer freeDocument(doc); + + const ctx = try analysis.documentPositionContext(allocator, doc, types.Position{ + .line = 0, + .character = @intCast(i64, cursor_idx - 1), + }); + + if (std.meta.activeTag(ctx) != tag) { + std.debug.warn("Expected tag {}, got {}\n", .{ tag, std.meta.activeTag(ctx) }); + return error.DifferentTag; + } + + if (ctx.range()) |ctx_range| { + if (range == null) { + std.debug.warn("Expected null range, got `{}`\n", .{ + doc.text[ctx_range.start..ctx_range.end], + }); + } else { + const range_start = comptime std.mem.indexOf(u8, final_line, range.?).?; + const range_end = range_start + range.?.len; + + if (range_start != ctx_range.start or range_end != ctx_range.end) { + std.debug.warn("Expected range `{}` ({}..{}), got `{}` ({}..{})\n", .{ + doc.text[range_start..range_end], range_start, range_end, + doc.text[ctx_range.start..ctx_range.end], ctx_range.start, ctx_range.end, + }); + return error.DifferentRange; + } + } + } else if (range != null) { + std.debug.warn("Unexpected null range\n", .{}); + return error.DifferentRange; + } +} + +test "documentPositionContext" { + try testContext( + \\const this_var = identifier; + , + .var_access, + "id", + ); + + try testContext( + \\if (displ.*.?.c.*[0].@"a" == foo) { + , + .field_access, + "displ.*.?.c.*[0].", + ); + + try testContext( + \\const arr = std.ArrayList(SomeStruct(a, b, c, d)).init(allocator); + , + .field_access, + "std.ArrayList(SomeStruct(a, b, c, d)).in", + ); + + try testContext( + \\try erroringFn(the_first[arg], second[a..]); + , + .empty, + null, + ); + + try testContext( + \\ fn add(lhf: lself, rhs: rself) !Se { + , + .var_access, + "Se", + ); +}