diff --git a/java/mvnw b/java/mvnw old mode 100644 new mode 100755 diff --git a/java/src/main/java/com/github/copilot/rpc/ToolDefer.java b/java/src/main/java/com/github/copilot/rpc/ToolDefer.java index 1955f02ec..ba888ca97 100644 --- a/java/src/main/java/com/github/copilot/rpc/ToolDefer.java +++ b/java/src/main/java/com/github/copilot/rpc/ToolDefer.java @@ -21,6 +21,27 @@ */ public enum ToolDefer { + /** + * No deferral preference set. This is an annotation-only sentinel used + * as the default for {@code @CopilotTool(defer = ToolDefer.NONE)}. + *

+ * This constant must not be passed to {@link ToolDefinition} factory + * methods. The annotation processor and {@code ToolDefinition.fromObject()} + * must map {@code NONE} to a {@code null} field reference so that + * {@code @JsonInclude(NON_NULL)} on {@link ToolDefinition} omits the + * {@code defer} key from the JSON-RPC wire payload entirely (matching the + * nullable/optional semantics used by all other SDKs). + *

+ * As a secondary safety net, {@link #getValue()} returns {@code null} for this + * constant. Note that this alone does not cause field omission: if a + * non-null {@code NONE} reference reaches a {@link ToolDefinition} field, + * Jackson's {@code @JsonInclude(NON_NULL)} will still emit the field (as + * {@code "defer": null}) because the field reference itself is not null. The + * primary protection is mapping {@code NONE} to a null field reference before + * constructing the {@link ToolDefinition}. + */ + NONE(""), + /** The tool can be deferred and surfaced through tool search. */ AUTO("auto"), @@ -35,12 +56,18 @@ public enum ToolDefer { /** * Returns the JSON value for this deferral mode. + *

+ * Returns {@code null} for {@link #NONE} to avoid emitting an empty string + * ({@code "defer": ""}) if this sentinel accidentally reaches serialization. + * With {@code null}, the worst-case leak becomes {@code "defer": null} rather + * than an invalid empty string. * - * @return the string value used in JSON serialization + * @return the string value used in JSON serialization, or {@code null} for + * {@link #NONE} */ @JsonValue public String getValue() { - return value; + return this == NONE ? null : value; } /** diff --git a/java/src/main/java/com/github/copilot/rpc/ToolDefinition.java b/java/src/main/java/com/github/copilot/rpc/ToolDefinition.java index 23b7fe30d..b3fa2bc53 100644 --- a/java/src/main/java/com/github/copilot/rpc/ToolDefinition.java +++ b/java/src/main/java/com/github/copilot/rpc/ToolDefinition.java @@ -4,11 +4,21 @@ package com.github.copilot.rpc; +import java.lang.reflect.Method; +import java.lang.reflect.Modifier; +import java.util.Arrays; +import java.util.List; import java.util.Map; +import java.util.stream.Collectors; import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.databind.DeserializationFeature; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.SerializationFeature; +import com.fasterxml.jackson.datatype.jsr310.JavaTimeModule; +import com.github.copilot.CopilotExperimental; /** * Defines a tool that can be invoked by the AI assistant. @@ -163,4 +173,101 @@ public static ToolDefinition createWithDefer(String name, String description, Ma ToolHandler handler, ToolDefer defer) { return new ToolDefinition(name, description, schema, handler, null, null, defer); } + + /** + * Discovers tool definitions from an object whose methods are annotated with + * {@code @CopilotTool}. Requires that the {@code CopilotToolProcessor} + * annotation processor ran at compile time (generating the + * {@code $$CopilotToolMeta} companion class). + * + * @param instance + * the object containing {@code @CopilotTool}-annotated methods + * @return list of tool definitions with working invocation handlers + * @throws IllegalStateException + * if the generated {@code $$CopilotToolMeta} class is not found + * (annotation processor did not run) + * @since 1.0.2 + */ + @CopilotExperimental + public static List fromObject(Object instance) { + if (instance == null) { + throw new IllegalArgumentException("instance must not be null"); + } + Class clazz = instance.getClass(); + return loadDefinitions(clazz, instance); + } + + /** + * Discovers tool definitions from a class with static + * {@code @CopilotTool}-annotated methods. Requires that the + * {@code CopilotToolProcessor} annotation processor ran at compile time + * (generating the {@code $$CopilotToolMeta} companion class). + * + * @param clazz + * the class containing static {@code @CopilotTool}-annotated methods + * @return list of tool definitions with working invocation handlers + * @throws IllegalStateException + * if the generated {@code $$CopilotToolMeta} class is not found + * (annotation processor did not run) + * @since 1.0.2 + */ + @CopilotExperimental + public static List fromClass(Class clazz) { + if (clazz == null) { + throw new IllegalArgumentException("clazz must not be null"); + } + List instanceMethods = Arrays.stream(clazz.getDeclaredMethods()) + .filter(m -> m.isAnnotationPresent(com.github.copilot.tool.CopilotTool.class)) + .filter(m -> !Modifier.isStatic(m.getModifiers())).map(Method::getName).collect(Collectors.toList()); + if (!instanceMethods.isEmpty()) { + throw new IllegalArgumentException( + "fromClass() requires all @CopilotTool methods to be static, but found instance methods: " + + instanceMethods + ". Use fromObject(new " + clazz.getSimpleName() + "()) instead."); + } + return loadDefinitions(clazz, null); + } + + @SuppressWarnings("unchecked") + private static List loadDefinitions(Class clazz, Object instance) { + String metaClassName = clazz.getName() + "$$CopilotToolMeta"; + try { + Class metaClass = Class.forName(metaClassName, true, clazz.getClassLoader()); + var provider = (com.github.copilot.tool.CopilotToolMetadataProvider) metaClass + .getDeclaredConstructor().newInstance(); + return provider.definitions(instance, getConfiguredMapper()); + } catch (ClassNotFoundException e) { + throw new IllegalStateException("Generated class " + metaClassName + " not found. " + + "Ensure the CopilotToolProcessor annotation processor ran during compilation. " + + "Add the copilot-sdk-java dependency to your annotation processor path.", e); + } catch (ReflectiveOperationException e) { + throw new IllegalStateException("Failed to invoke " + metaClassName + ".definitions()", e); + } + } + + /** + * Returns the SDK-configured ObjectMapper for tool argument/result + * serialization. Configuration mirrors + * {@code JsonRpcClient.createObjectMapper()}. + */ + private static ObjectMapper getConfiguredMapper() { + return ConfiguredMapperHolder.INSTANCE; + } + + /** + * Lazy holder for the configured ObjectMapper (thread-safe, initialized on + * first access). + */ + private static final class ConfiguredMapperHolder { + static final ObjectMapper INSTANCE = createMapper(); + + private static ObjectMapper createMapper() { + // Configuration must match JsonRpcClient.createObjectMapper() + var mapper = new ObjectMapper(); + mapper.registerModule(new JavaTimeModule()); + mapper.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false); + mapper.configure(SerializationFeature.WRITE_DATES_AS_TIMESTAMPS, false); + mapper.setDefaultPropertyInclusion(JsonInclude.Include.NON_NULL); + return mapper; + } + } } diff --git a/java/src/main/java/com/github/copilot/tool/CopilotTool.java b/java/src/main/java/com/github/copilot/tool/CopilotTool.java new file mode 100644 index 000000000..9cde49b20 --- /dev/null +++ b/java/src/main/java/com/github/copilot/tool/CopilotTool.java @@ -0,0 +1,52 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +package com.github.copilot.tool; + +import java.lang.annotation.Documented; +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +import com.github.copilot.CopilotExperimental; +import com.github.copilot.rpc.ToolDefer; + +/** + * Marks a method as a Copilot tool. The annotated method will be exposed to the + * model as a callable tool during a session. + * + *

+ * Example usage: + * + *

+ * @CopilotTool("Get weather for a location")
+ * public CompletableFuture<String> getWeather(@Param(value = "City name", required = true) String location) {
+ * 	return CompletableFuture.completedFuture("Sunny in " + location);
+ * }
+ * 
+ * + * @since 1.0.2 + */ +@Documented +@Retention(RetentionPolicy.RUNTIME) +@Target(ElementType.METHOD) +@CopilotExperimental +public @interface CopilotTool { + + /** Tool description (sent to the model). */ + String value(); + + /** Tool name. Defaults to method name converted to snake_case. */ + String name() default ""; + + /** Whether this tool overrides a built-in tool. */ + boolean overridesBuiltInTool() default false; + + /** Whether to skip permission checks. */ + boolean skipPermission() default false; + + /** Defer configuration for this tool. */ + ToolDefer defer() default ToolDefer.NONE; +} diff --git a/java/src/main/java/com/github/copilot/tool/CopilotToolMetadataProvider.java b/java/src/main/java/com/github/copilot/tool/CopilotToolMetadataProvider.java new file mode 100644 index 000000000..25194626e --- /dev/null +++ b/java/src/main/java/com/github/copilot/tool/CopilotToolMetadataProvider.java @@ -0,0 +1,42 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +package com.github.copilot.tool; + +import java.util.List; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.github.copilot.CopilotExperimental; +import com.github.copilot.rpc.ToolDefinition; + +/** + * Contract for classes that provide {@link ToolDefinition} metadata for + * {@code @CopilotTool}-annotated methods. + * + *

+ * The {@link CopilotToolProcessor} annotation processor generates an + * implementation of this interface as a {@code $$CopilotToolMeta} companion + * class. Users may also implement this interface directly for full manual + * control over tool registration without using annotation processing. + * + * @param + * the tool class whose methods are described by this provider + * @since 1.0.2 + */ +@CopilotExperimental +public interface CopilotToolMetadataProvider { + + /** + * Returns tool definitions for the given instance. + * + * @param instance + * the object containing tool methods, or {@code null} for static + * methods + * @param mapper + * the SDK-configured {@link ObjectMapper} for argument + * deserialization + * @return list of tool definitions with working invocation handlers + */ + List definitions(T instance, ObjectMapper mapper); +} diff --git a/java/src/main/java/com/github/copilot/tool/CopilotToolProcessor.java b/java/src/main/java/com/github/copilot/tool/CopilotToolProcessor.java new file mode 100644 index 000000000..51e20c5b0 --- /dev/null +++ b/java/src/main/java/com/github/copilot/tool/CopilotToolProcessor.java @@ -0,0 +1,636 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +package com.github.copilot.tool; + +import java.io.IOException; +import java.io.PrintWriter; +import java.util.ArrayList; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import javax.annotation.processing.AbstractProcessor; +import javax.annotation.processing.RoundEnvironment; +import javax.annotation.processing.SupportedAnnotationTypes; +import javax.annotation.processing.SupportedSourceVersion; +import javax.lang.model.SourceVersion; +import javax.lang.model.element.Element; +import javax.lang.model.element.ElementKind; +import javax.lang.model.element.ExecutableElement; +import javax.lang.model.element.Modifier; +import javax.lang.model.element.TypeElement; +import javax.lang.model.element.VariableElement; +import javax.lang.model.type.DeclaredType; +import javax.lang.model.type.TypeKind; +import javax.lang.model.type.TypeMirror; +import javax.tools.Diagnostic; +import javax.tools.JavaFileObject; + +import com.github.copilot.CopilotExperimental; + +/** + * JSR 269 annotation processor that finds {@link CopilotTool}-annotated methods + * and generates {@code $$CopilotToolMeta} companion classes containing tool + * definitions, JSON Schema, and invocation lambdas. + * + *

+ * For a class {@code com.example.MyTools} containing {@code @CopilotTool} + * methods, this processor generates + * {@code com.example.MyTools$$CopilotToolMeta} in the same package. + * + * @since 1.0.2 + */ +@SupportedAnnotationTypes("com.github.copilot.tool.CopilotTool") +@SupportedSourceVersion(SourceVersion.RELEASE_17) +@CopilotExperimental +public class CopilotToolProcessor extends AbstractProcessor { + + private final SchemaGenerator schemaGenerator = new SchemaGenerator(); + + @Override + public boolean process(Set annotations, RoundEnvironment roundEnv) { + List annotatedElements = getCopilotToolAnnotatedElements(roundEnv); + for (Element element : annotatedElements) { + if (element.getKind() != ElementKind.METHOD) { + continue; + } + ExecutableElement method = (ExecutableElement) element; + + // Validate: private methods are not allowed + if (method.getModifiers().contains(Modifier.PRIVATE)) { + processingEnv.getMessager().printMessage(Diagnostic.Kind.ERROR, + "@CopilotTool methods must not be private", method); + continue; + } + + // Validate @Param conflicts + for (VariableElement param : method.getParameters()) { + Param paramAnnotation = param.getAnnotation(Param.class); + if (paramAnnotation != null && paramAnnotation.required() + && !paramAnnotation.defaultValue().isEmpty()) { + processingEnv.getMessager().printMessage(Diagnostic.Kind.ERROR, + "@Param cannot have both required=true and a non-empty defaultValue", param); + } + } + } + + // Group methods by enclosing type + Map> methodsByClass = new LinkedHashMap<>(); + for (Element element : annotatedElements) { + if (element.getKind() != ElementKind.METHOD) { + continue; + } + ExecutableElement method = (ExecutableElement) element; + if (method.getModifiers().contains(Modifier.PRIVATE)) { + continue; + } + TypeElement enclosingType = (TypeElement) method.getEnclosingElement(); + methodsByClass.computeIfAbsent(enclosingType, k -> new ArrayList<>()).add(method); + } + + // Generate $$CopilotToolMeta for each class + for (Map.Entry> entry : methodsByClass.entrySet()) { + generateMetaClass(entry.getKey(), entry.getValue()); + } + + return false; + } + + private List getCopilotToolAnnotatedElements(RoundEnvironment roundEnv) { + TypeElement copilotToolType = processingEnv.getElementUtils() + .getTypeElement("com.github.copilot.tool.CopilotTool"); + if (copilotToolType != null) { + return new ArrayList<>(roundEnv.getElementsAnnotatedWith(copilotToolType)); + } + return new ArrayList<>(roundEnv.getElementsAnnotatedWith(CopilotTool.class)); + } + + private void generateMetaClass(TypeElement classElement, List methods) { + String packageName = processingEnv.getElementUtils().getPackageOf(classElement).getQualifiedName().toString(); + String simpleClassName = classElement.getSimpleName().toString(); + String metaClassName = simpleClassName + "$$CopilotToolMeta"; + String qualifiedMetaClassName = packageName.isEmpty() ? metaClassName : packageName + "." + metaClassName; + + try { + JavaFileObject sourceFile = processingEnv.getFiler().createSourceFile(qualifiedMetaClassName, classElement); + try (PrintWriter out = new PrintWriter(sourceFile.openWriter())) { + writeMetaClass(out, packageName, simpleClassName, metaClassName, methods); + } + } catch (IOException e) { + processingEnv.getMessager().printMessage(Diagnostic.Kind.ERROR, + "Failed to generate " + metaClassName + ": " + e.getMessage(), classElement); + } + } + + private void writeMetaClass(PrintWriter out, String packageName, String simpleClassName, String metaClassName, + List methods) { + out.println("// GENERATED by CopilotToolProcessor — do not edit"); + + if (!packageName.isEmpty()) { + out.println("package " + packageName + ";"); + out.println(); + } + + out.println("import com.github.copilot.rpc.ToolDefinition;"); + out.println("import com.github.copilot.rpc.ToolDefer;"); + out.println("import com.github.copilot.tool.CopilotToolMetadataProvider;"); + out.println("import com.fasterxml.jackson.databind.ObjectMapper;"); + out.println("import java.util.*;"); + out.println("import java.util.concurrent.CompletableFuture;"); + out.println(); + + out.println("public final class " + metaClassName + " implements CopilotToolMetadataProvider<" + simpleClassName + + "> {"); + out.println(); + + // Helper method for adding description/default to schema maps + if (needsWithMetaHelper(methods)) { + out.println( + " private static Map withMeta(Map base, String description, Object defaultValue) {"); + out.println(" var result = new LinkedHashMap(base);"); + out.println(" if (description != null) result.put(\"description\", description);"); + out.println(" if (defaultValue != null) result.put(\"default\", defaultValue);"); + out.println(" return Collections.unmodifiableMap(result);"); + out.println(" }"); + out.println(); + } + + // definitions method + out.println(" @Override"); + out.println(" @SuppressWarnings({\"unchecked\", \"rawtypes\"})"); + out.println( + " public List definitions(" + simpleClassName + " instance, ObjectMapper mapper) {"); + out.println(" return List.of("); + + for (int i = 0; i < methods.size(); i++) { + ExecutableElement method = methods.get(i); + writeToolDefinition(out, method); + if (i < methods.size() - 1) { + out.println(","); + } else { + out.println(); + } + } + + out.println(" );"); + out.println(" }"); + out.println("}"); + } + + private boolean needsWithMetaHelper(List methods) { + for (ExecutableElement method : methods) { + for (VariableElement param : method.getParameters()) { + Param paramAnnotation = param.getAnnotation(Param.class); + if (paramAnnotation != null + && (!paramAnnotation.value().isEmpty() || !paramAnnotation.defaultValue().isEmpty())) { + return true; + } + } + } + return false; + } + + private void writeToolDefinition(PrintWriter out, ExecutableElement method) { + CopilotTool annotation = method.getAnnotation(CopilotTool.class); + String toolName = annotation.name().isEmpty() + ? toSnakeCase(method.getSimpleName().toString()) + : annotation.name(); + String description = annotation.value(); + boolean overridesBuiltIn = annotation.overridesBuiltInTool(); + boolean skipPermission = annotation.skipPermission(); + com.github.copilot.rpc.ToolDefer defer = annotation.defer(); + + // Generate schema with @Param metadata (descriptions, names, defaults) + String schemaSource = generateSchemaWithParamMetadata(method.getParameters()); + + // Generate invocation lambda + String lambdaBody = generateLambdaBody(method); + + // Use the record constructor directly so all flags apply independently + String overridesArg = overridesBuiltIn ? "Boolean.TRUE" : "null"; + String skipPermArg = skipPermission ? "Boolean.TRUE" : "null"; + String deferArg = defer != com.github.copilot.rpc.ToolDefer.NONE ? "ToolDefer." + defer.name() : "null"; + + out.println(" new ToolDefinition("); + out.println(" \"" + escapeJava(toolName) + "\","); + out.println(" \"" + escapeJava(description) + "\","); + out.println(" " + schemaSource + ","); + out.println(" invocation -> {"); + out.println(" " + lambdaBody); + out.println(" },"); + out.println(" " + overridesArg + ","); + out.println(" " + skipPermArg + ","); + out.println(" " + deferArg); + out.print(" )"); + } + + private String generateSchemaWithParamMetadata(List parameters) { + if (parameters.isEmpty()) { + return "Map.of(\"type\", \"object\", \"properties\", Map.of(), \"required\", List.of())"; + } + + List propertyEntries = new ArrayList<>(); + List requiredNames = new ArrayList<>(); + + for (VariableElement param : parameters) { + String paramName = getParamName(param); + TypeMirror paramType = param.asType(); + Param paramAnnotation = param.getAnnotation(Param.class); + + // Generate the type schema for this parameter + String typeSchema = schemaGenerator.generateSchemaSource(paramType, processingEnv.getTypeUtils(), + processingEnv.getElementUtils()); + + // Build property schema with description and default if present + String propertySchema = buildPropertySchema(typeSchema, paramAnnotation, paramType); + + // Cast to Map via raw type for consistent Map.ofEntries typing + propertyEntries.add("Map.entry(\"" + paramName + "\", (Map)(Map) " + propertySchema + ")"); + + // Determine if required (Optional* types are never required) + boolean isOptionalType = paramType.getKind() == TypeKind.DECLARED && Set + .of("java.util.Optional", "java.util.OptionalInt", "java.util.OptionalLong", + "java.util.OptionalDouble") + .contains(((TypeElement) ((DeclaredType) paramType).asElement()).getQualifiedName().toString()); + if (!isOptionalType && (paramAnnotation == null || paramAnnotation.required())) { + requiredNames.add("\"" + paramName + "\""); + } + } + + String properties = "Map.ofEntries(" + String.join(", ", propertyEntries) + ")"; + String required = "List.of(" + String.join(", ", requiredNames) + ")"; + + return "Map.of(\"type\", \"object\", \"properties\", " + properties + ", \"required\", " + required + ")"; + } + + private String buildPropertySchema(String typeSchema, Param paramAnnotation, TypeMirror paramType) { + if (paramAnnotation == null) { + return typeSchema; + } + + String desc = paramAnnotation.value(); + String defaultValue = paramAnnotation.defaultValue(); + + boolean hasDescription = !desc.isEmpty(); + boolean hasDefault = !defaultValue.isEmpty(); + + if (!hasDescription && !hasDefault) { + return typeSchema; + } + + // Use the withMeta helper method in the generated class + String descArg = hasDescription ? "\"" + escapeJava(desc) + "\"" : "null"; + String defaultArg = hasDefault ? generateDefaultLiteral(paramType, defaultValue) : "null"; + + return "withMeta(" + typeSchema + ", " + descArg + ", " + defaultArg + ")"; + } + + private String generateLambdaBody(ExecutableElement method) { + List params = method.getParameters(); + StringBuilder sb = new StringBuilder(); + + // Generate argument extraction + if (!params.isEmpty()) { + sb.append("Map args = invocation.getArguments();\n"); + + // Check if single-record-parameter shortcut applies + if (params.size() == 1 && isRecord(params.get(0).asType())) { + String typeName = getTypeString(params.get(0).asType()); + String paramName = params.get(0).getSimpleName().toString(); + sb.append(" ").append(typeName).append(" ").append(paramName) + .append(" = mapper.convertValue(args, ").append(typeName).append(".class);\n"); + } else { + for (VariableElement param : params) { + String paramName = getParamName(param); + String varName = param.getSimpleName().toString(); + TypeMirror paramType = param.asType(); + + // Handle default values + Param paramAnnotation = param.getAnnotation(Param.class); + boolean hasDefault = paramAnnotation != null && !paramAnnotation.defaultValue().isEmpty(); + + if (hasDefault) { + String defaultValue = paramAnnotation.defaultValue(); + sb.append(" Object ").append(varName).append("Raw = args.containsKey(\"") + .append(paramName).append("\") ? args.get(\"").append(paramName).append("\") : ") + .append(generateDefaultLiteral(paramType, defaultValue)).append(";\n"); + sb.append(" ").append(getTypeString(paramType)).append(" ").append(varName) + .append(" = ").append(generateArgExtraction(varName + "Raw", paramType)).append(";\n"); + } else if (isOptionalType(paramType)) { + generateOptionalExtraction(sb, paramName, varName, paramType); + } else { + sb.append(" ").append(getTypeString(paramType)).append(" ").append(varName) + .append(" = ").append(generateArgExtractionFromMap(paramName, paramType)).append(";\n"); + } + } + } + } + + // Generate method invocation based on return type + TypeMirror returnType = method.getReturnType(); + String callTarget = method.getModifiers().contains(Modifier.STATIC) + ? ((TypeElement) method.getEnclosingElement()).getQualifiedName().toString() + : "instance"; + String methodCall = callTarget + "." + method.getSimpleName() + "(" + generateArgList(params) + ")"; + + if (returnType.getKind() == TypeKind.VOID) { + sb.append(" ").append(methodCall).append(";\n"); + sb.append(" return CompletableFuture.completedFuture(\"Success\");"); + } else if (isCompletableFuture(returnType)) { + TypeMirror typeArg = getCompletableFutureTypeArg(returnType); + if (typeArg != null && isStringType(typeArg)) { + // CompletableFuture -> CompletableFuture via thenApply + sb.append(" return ").append(methodCall).append(".thenApply(r -> (Object) r);"); + } else { + // CompletableFuture -> serialize to JSON + sb.append(" return ").append(methodCall) + .append(".thenApply(r -> { try { return (Object) mapper.writeValueAsString(r); }") + .append(" catch (Exception e) { throw new RuntimeException(e); } });"); + } + } else if (isStringType(returnType)) { + sb.append(" return CompletableFuture.completedFuture(").append(methodCall).append(");"); + } else { + sb.append(" try { return CompletableFuture.completedFuture(mapper.writeValueAsString(") + .append(methodCall).append(")); } catch (Exception e) { throw new RuntimeException(e); }"); + } + + return sb.toString(); + } + + private String generateArgList(List params) { + StringBuilder sb = new StringBuilder(); + for (int i = 0; i < params.size(); i++) { + if (i > 0) { + sb.append(", "); + } + sb.append(params.get(i).getSimpleName().toString()); + } + return sb.toString(); + } + + private String generateArgExtractionFromMap(String paramName, TypeMirror type) { + if (type.getKind().isPrimitive()) { + return generatePrimitiveExtraction("args.get(\"" + paramName + "\")", type); + } + if (type.getKind() == TypeKind.DECLARED) { + TypeElement typeElement = (TypeElement) ((DeclaredType) type).asElement(); + String qualifiedName = typeElement.getQualifiedName().toString(); + if ("java.lang.String".equals(qualifiedName)) { + return "(String) args.get(\"" + paramName + "\")"; + } + if (isBoxedNumeric(qualifiedName)) { + return generateBoxedNumericExtraction("args.get(\"" + paramName + "\")", qualifiedName); + } + if ("java.lang.Boolean".equals(qualifiedName)) { + return "(Boolean) args.get(\"" + paramName + "\")"; + } + // Complex types: enums, records, POJOs + return "mapper.convertValue(args.get(\"" + paramName + "\"), " + qualifiedName + ".class)"; + } + return "(Object) args.get(\"" + paramName + "\")"; + } + + private String generateArgExtraction(String varExpr, TypeMirror type) { + if (type.getKind().isPrimitive()) { + return generatePrimitiveExtraction(varExpr, type); + } + if (type.getKind() == TypeKind.DECLARED) { + TypeElement typeElement = (TypeElement) ((DeclaredType) type).asElement(); + String qualifiedName = typeElement.getQualifiedName().toString(); + if ("java.lang.String".equals(qualifiedName)) { + return "(String) " + varExpr; + } + if (isBoxedNumeric(qualifiedName)) { + return generateBoxedNumericExtraction(varExpr, qualifiedName); + } + if ("java.lang.Boolean".equals(qualifiedName)) { + return "(Boolean) " + varExpr; + } + return "mapper.convertValue(" + varExpr + ", " + qualifiedName + ".class)"; + } + return "(Object) " + varExpr; + } + + private String generatePrimitiveExtraction(String expr, TypeMirror type) { + switch (type.getKind()) { + case INT : + return "((Number) " + expr + ").intValue()"; + case LONG : + return "((Number) " + expr + ").longValue()"; + case DOUBLE : + return "((Number) " + expr + ").doubleValue()"; + case FLOAT : + return "((Number) " + expr + ").floatValue()"; + case SHORT : + return "((Number) " + expr + ").shortValue()"; + case BYTE : + return "((Number) " + expr + ").byteValue()"; + case BOOLEAN : + return "(Boolean) " + expr; + case CHAR : + return "((String) " + expr + ").charAt(0)"; + default : + return "(" + type + ") " + expr; + } + } + + private boolean isOptionalType(TypeMirror type) { + if (type.getKind() != TypeKind.DECLARED) { + return false; + } + TypeElement typeElement = (TypeElement) ((DeclaredType) type).asElement(); + String name = typeElement.getQualifiedName().toString(); + return "java.util.Optional".equals(name) || "java.util.OptionalInt".equals(name) + || "java.util.OptionalLong".equals(name) || "java.util.OptionalDouble".equals(name); + } + + private void generateOptionalExtraction(StringBuilder sb, String paramName, String varName, TypeMirror paramType) { + TypeElement typeElement = (TypeElement) ((DeclaredType) paramType).asElement(); + String qualifiedName = typeElement.getQualifiedName().toString(); + + sb.append(" Object ").append(varName).append("Raw = args.get(\"").append(paramName) + .append("\");\n"); + + switch (qualifiedName) { + case "java.util.OptionalInt" : + sb.append(" java.util.OptionalInt ").append(varName).append(" = ").append(varName) + .append("Raw != null ? java.util.OptionalInt.of(((Number) ").append(varName) + .append("Raw).intValue()) : java.util.OptionalInt.empty();\n"); + break; + case "java.util.OptionalLong" : + sb.append(" java.util.OptionalLong ").append(varName).append(" = ").append(varName) + .append("Raw != null ? java.util.OptionalLong.of(((Number) ").append(varName) + .append("Raw).longValue()) : java.util.OptionalLong.empty();\n"); + break; + case "java.util.OptionalDouble" : + sb.append(" java.util.OptionalDouble ").append(varName).append(" = ").append(varName) + .append("Raw != null ? java.util.OptionalDouble.of(((Number) ").append(varName) + .append("Raw).doubleValue()) : java.util.OptionalDouble.empty();\n"); + break; + default : + // java.util.Optional — unwrap the type argument + List typeArgs = ((DeclaredType) paramType).getTypeArguments(); + if (!typeArgs.isEmpty()) { + TypeMirror innerType = typeArgs.get(0); + String innerExtraction = generateArgExtraction(varName + "Raw", innerType); + sb.append(" java.util.Optional ").append(varName).append(" = ").append(varName) + .append("Raw != null ? java.util.Optional.of(").append(innerExtraction) + .append(") : java.util.Optional.empty();\n"); + } else { + sb.append(" java.util.Optional ").append(varName).append(" = ").append(varName) + .append("Raw != null ? java.util.Optional.of(").append(varName) + .append("Raw) : java.util.Optional.empty();\n"); + } + break; + } + } + + private boolean isBoxedNumeric(String qualifiedName) { + return "java.lang.Integer".equals(qualifiedName) || "java.lang.Long".equals(qualifiedName) + || "java.lang.Double".equals(qualifiedName) || "java.lang.Float".equals(qualifiedName) + || "java.lang.Short".equals(qualifiedName) || "java.lang.Byte".equals(qualifiedName); + } + + private String generateBoxedNumericExtraction(String expr, String qualifiedName) { + switch (qualifiedName) { + case "java.lang.Integer" : + return "((Number) " + expr + ").intValue()"; + case "java.lang.Long" : + return "((Number) " + expr + ").longValue()"; + case "java.lang.Double" : + return "((Number) " + expr + ").doubleValue()"; + case "java.lang.Float" : + return "((Number) " + expr + ").floatValue()"; + case "java.lang.Short" : + return "((Number) " + expr + ").shortValue()"; + case "java.lang.Byte" : + return "((Number) " + expr + ").byteValue()"; + default : + return "(" + qualifiedName + ") " + expr; + } + } + + private String generateDefaultLiteral(TypeMirror type, String defaultValue) { + if (type.getKind().isPrimitive()) { + switch (type.getKind()) { + case INT : + case LONG : + case SHORT : + case BYTE : + return defaultValue; + case DOUBLE : + case FLOAT : + return defaultValue; + case BOOLEAN : + return defaultValue; + case CHAR : + return "\"" + escapeJava(defaultValue) + "\""; + default : + return "\"" + escapeJava(defaultValue) + "\""; + } + } + if (type.getKind() == TypeKind.DECLARED) { + TypeElement typeElement = (TypeElement) ((DeclaredType) type).asElement(); + String qualifiedName = typeElement.getQualifiedName().toString(); + if ("java.lang.String".equals(qualifiedName)) { + return "\"" + escapeJava(defaultValue) + "\""; + } + if (isBoxedNumeric(qualifiedName) || "java.lang.Boolean".equals(qualifiedName)) { + return defaultValue; + } + } + return "\"" + escapeJava(defaultValue) + "\""; + } + + private String getParamName(VariableElement param) { + Param paramAnnotation = param.getAnnotation(Param.class); + if (paramAnnotation != null && !paramAnnotation.name().isEmpty()) { + return paramAnnotation.name(); + } + return param.getSimpleName().toString(); + } + + private String getTypeString(TypeMirror type) { + if (type.getKind().isPrimitive()) { + return type.toString(); + } + if (type.getKind() == TypeKind.DECLARED) { + TypeElement typeElement = (TypeElement) ((DeclaredType) type).asElement(); + return typeElement.getQualifiedName().toString(); + } + return type.toString(); + } + + private boolean isRecord(TypeMirror type) { + if (type.getKind() != TypeKind.DECLARED) { + return false; + } + TypeElement typeElement = (TypeElement) ((DeclaredType) type).asElement(); + return typeElement.getKind() == ElementKind.RECORD; + } + + private boolean isCompletableFuture(TypeMirror type) { + if (type.getKind() != TypeKind.DECLARED) { + return false; + } + TypeElement typeElement = (TypeElement) ((DeclaredType) type).asElement(); + return "java.util.concurrent.CompletableFuture".equals(typeElement.getQualifiedName().toString()); + } + + private TypeMirror getCompletableFutureTypeArg(TypeMirror type) { + if (type.getKind() != TypeKind.DECLARED) { + return null; + } + DeclaredType declaredType = (DeclaredType) type; + List typeArgs = declaredType.getTypeArguments(); + if (typeArgs.isEmpty()) { + return null; + } + return typeArgs.get(0); + } + + private boolean isStringType(TypeMirror type) { + if (type.getKind() != TypeKind.DECLARED) { + return false; + } + TypeElement typeElement = (TypeElement) ((DeclaredType) type).asElement(); + return "java.lang.String".equals(typeElement.getQualifiedName().toString()); + } + + /** + * Converts a camelCase method name to snake_case. + * + * @param name + * the method name + * @return the snake_case tool name + */ + static String toSnakeCase(String name) { + if (name == null || name.isEmpty()) { + return name; + } + StringBuilder sb = new StringBuilder(); + for (int i = 0; i < name.length(); i++) { + char c = name.charAt(i); + if (Character.isUpperCase(c)) { + if (i > 0) { + sb.append('_'); + } + sb.append(Character.toLowerCase(c)); + } else { + sb.append(c); + } + } + return sb.toString(); + } + + private static String escapeJava(String s) { + if (s == null) { + return ""; + } + return s.replace("\\", "\\\\").replace("\"", "\\\"").replace("\n", "\\n").replace("\r", "\\r").replace("\t", + "\\t"); + } +} diff --git a/java/src/main/java/com/github/copilot/tool/Param.java b/java/src/main/java/com/github/copilot/tool/Param.java new file mode 100644 index 000000000..aaef04947 --- /dev/null +++ b/java/src/main/java/com/github/copilot/tool/Param.java @@ -0,0 +1,49 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +package com.github.copilot.tool; + +import java.lang.annotation.Documented; +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +import com.github.copilot.CopilotExperimental; + +/** + * Annotates a parameter of a {@link CopilotTool}-annotated method to provide + * metadata about the parameter that is sent to the model. + * + *

+ * Example usage: + * + *

+ * @CopilotTool("Search for issues")
+ * public CompletableFuture<String> searchIssues(@Param(value = "Search query", required = true) String query,
+ * 		@Param(value = "Max results", required = false, defaultValue = "10") int limit) {
+ * 	// ...
+ * }
+ * 
+ * + * @since 1.0.2 + */ +@Documented +@Retention(RetentionPolicy.RUNTIME) +@Target(ElementType.PARAMETER) +@CopilotExperimental +public @interface Param { + + /** Parameter description (sent to the model). */ + String value() default ""; + + /** Parameter name override. Defaults to the actual parameter name. */ + String name() default ""; + + /** Whether this parameter is required. Default true. */ + boolean required() default true; + + /** Optional default value when the argument is omitted. */ + String defaultValue() default ""; +} diff --git a/java/src/main/java/com/github/copilot/tool/SchemaGenerator.java b/java/src/main/java/com/github/copilot/tool/SchemaGenerator.java new file mode 100644 index 000000000..fb321ae9d --- /dev/null +++ b/java/src/main/java/com/github/copilot/tool/SchemaGenerator.java @@ -0,0 +1,392 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +package com.github.copilot.tool; + +import java.util.ArrayList; +import java.util.List; +import java.util.stream.Collectors; + +import javax.lang.model.element.Element; +import javax.lang.model.element.ElementKind; +import javax.lang.model.element.RecordComponentElement; +import javax.lang.model.element.TypeElement; +import javax.lang.model.element.VariableElement; +import javax.lang.model.type.ArrayType; +import javax.lang.model.type.DeclaredType; +import javax.lang.model.type.TypeKind; +import javax.lang.model.type.TypeMirror; +import javax.lang.model.util.Elements; +import javax.lang.model.util.Types; + +import com.github.copilot.CopilotExperimental; + +/** + * Compile-time utility that maps {@code javax.lang.model} types to JSON Schema + * represented as Java source code literals ({@code Map.of(...)} expressions). + * + *

+ * This class is invoked by the annotation processor and operates exclusively + * with the {@code javax.lang.model} API. It does NOT use + * {@code java.lang.reflect}. + * + * @since 1.0.2 + */ +@CopilotExperimental +public class SchemaGenerator { + + /** + * Given a {@link TypeMirror} from the annotation processing environment, + * returns a {@code String} containing Java source code for a {@code Map} + * literal representing the JSON Schema of that type. + * + * @param type + * the type to generate schema for + * @param typeUtils + * the {@link Types} utility from the processing environment + * @param elementUtils + * the {@link Elements} utility from the processing environment + * @return a Java source code string representing the JSON Schema + */ + public String generateSchemaSource(TypeMirror type, Types typeUtils, Elements elementUtils) { + return generateSchema(type, typeUtils, elementUtils); + } + + /** + * Generates the full "parameters" schema source for a method's parameters. + * Produces a + * {@code Map.of("type", "object", "properties", Map.of(...), "required", List.of(...))}. + * + * @param parameters + * the method parameters to generate schema for + * @param typeUtils + * the {@link Types} utility from the processing environment + * @param elementUtils + * the {@link Elements} utility from the processing environment + * @return a Java source code string representing the parameters JSON Schema + */ + public String generateParametersSchemaSource(List parameters, Types typeUtils, + Elements elementUtils) { + if (parameters.isEmpty()) { + return "Map.of(\"type\", \"object\", \"properties\", Map.of(), \"required\", List.of())"; + } + + List propertyEntries = new ArrayList<>(); + List requiredNames = new ArrayList<>(); + + for (VariableElement param : parameters) { + String paramName = param.getSimpleName().toString(); + TypeMirror paramType = param.asType(); + + boolean isOptional = isOptionalType(paramType); + String schema; + if (isOptional) { + schema = generateSchema(unwrapOptional(paramType, typeUtils), typeUtils, elementUtils); + } else { + schema = generateSchema(paramType, typeUtils, elementUtils); + } + + propertyEntries.add("Map.entry(\"" + paramName + "\", " + schema + ")"); + + if (!isOptional) { + Param paramAnnotation = param.getAnnotation(Param.class); + if (paramAnnotation == null || paramAnnotation.required()) { + requiredNames.add("\"" + paramName + "\""); + } + } + } + + String properties = "Map.ofEntries(" + String.join(", ", propertyEntries) + ")"; + String required = "List.of(" + String.join(", ", requiredNames) + ")"; + + return "Map.of(\"type\", \"object\", \"properties\", " + properties + ", \"required\", " + required + ")"; + } + + private String generateSchema(TypeMirror type, Types typeUtils, Elements elementUtils) { + // Handle primitive types + if (type.getKind().isPrimitive()) { + return generatePrimitiveSchema(type.getKind()); + } + + // Handle array types + if (type.getKind() == TypeKind.ARRAY) { + ArrayType arrayType = (ArrayType) type; + TypeMirror componentType = arrayType.getComponentType(); + String itemsSchema = generateSchema(componentType, typeUtils, elementUtils); + return "Map.of(\"type\", \"array\", \"items\", " + itemsSchema + ")"; + } + + // Handle declared types (classes, interfaces, enums, records) + if (type.getKind() == TypeKind.DECLARED) { + return generateDeclaredTypeSchema((DeclaredType) type, typeUtils, elementUtils); + } + + // Fallback: any + return "Map.of()"; + } + + private String generatePrimitiveSchema(TypeKind kind) { + switch (kind) { + case INT : + case LONG : + case BYTE : + case SHORT : + return "Map.of(\"type\", \"integer\")"; + case DOUBLE : + case FLOAT : + return "Map.of(\"type\", \"number\")"; + case BOOLEAN : + return "Map.of(\"type\", \"boolean\")"; + case CHAR : + return "Map.of(\"type\", \"string\")"; + default : + return "Map.of()"; + } + } + + private String generateDeclaredTypeSchema(DeclaredType type, Types typeUtils, Elements elementUtils) { + TypeElement typeElement = (TypeElement) type.asElement(); + String qualifiedName = typeElement.getQualifiedName().toString(); + + // String + if ("java.lang.String".equals(qualifiedName)) { + return "Map.of(\"type\", \"string\")"; + } + + // Boxed primitives + if ("java.lang.Integer".equals(qualifiedName) || "java.lang.Long".equals(qualifiedName) + || "java.lang.Byte".equals(qualifiedName) || "java.lang.Short".equals(qualifiedName)) { + return "Map.of(\"type\", \"integer\")"; + } + if ("java.lang.Double".equals(qualifiedName) || "java.lang.Float".equals(qualifiedName)) { + return "Map.of(\"type\", \"number\")"; + } + if ("java.lang.Boolean".equals(qualifiedName)) { + return "Map.of(\"type\", \"boolean\")"; + } + if ("java.lang.Character".equals(qualifiedName)) { + return "Map.of(\"type\", \"string\")"; + } + + // UUID + if ("java.util.UUID".equals(qualifiedName)) { + return "Map.of(\"type\", \"string\", \"format\", \"uuid\")"; + } + + // Date-time types (ISO-8601 format hints for the model) + if ("java.time.OffsetDateTime".equals(qualifiedName) || "java.time.LocalDateTime".equals(qualifiedName) + || "java.time.Instant".equals(qualifiedName) || "java.time.ZonedDateTime".equals(qualifiedName)) { + return "Map.of(\"type\", \"string\", \"format\", \"date-time\")"; + } + if ("java.time.LocalDate".equals(qualifiedName)) { + return "Map.of(\"type\", \"string\", \"format\", \"date\")"; + } + if ("java.time.LocalTime".equals(qualifiedName)) { + return "Map.of(\"type\", \"string\", \"format\", \"time\")"; + } + + // JsonNode (any) + if ("com.fasterxml.jackson.databind.JsonNode".equals(qualifiedName)) { + return "Map.of()"; + } + + // Object (any) + if ("java.lang.Object".equals(qualifiedName)) { + return "Map.of()"; + } + + // Optional types + if ("java.util.Optional".equals(qualifiedName)) { + List typeArgs = type.getTypeArguments(); + if (!typeArgs.isEmpty()) { + return generateSchema(typeArgs.get(0), typeUtils, elementUtils); + } + return "Map.of()"; + } + if ("java.util.OptionalInt".equals(qualifiedName)) { + return "Map.of(\"type\", \"integer\")"; + } + if ("java.util.OptionalDouble".equals(qualifiedName)) { + return "Map.of(\"type\", \"number\")"; + } + if ("java.util.OptionalLong".equals(qualifiedName)) { + return "Map.of(\"type\", \"integer\")"; + } + + // List / Collection + if (isCollectionType(qualifiedName)) { + List typeArgs = type.getTypeArguments(); + if (!typeArgs.isEmpty()) { + String itemsSchema = generateSchema(typeArgs.get(0), typeUtils, elementUtils); + return "Map.of(\"type\", \"array\", \"items\", " + itemsSchema + ")"; + } + return "Map.of(\"type\", \"array\")"; + } + + // Map + if (isMapType(qualifiedName)) { + List typeArgs = type.getTypeArguments(); + if (typeArgs.size() == 2) { + TypeMirror valueType = typeArgs.get(1); + if (valueType.getKind() == TypeKind.DECLARED) { + TypeElement valueElement = (TypeElement) ((DeclaredType) valueType).asElement(); + String valueQName = valueElement.getQualifiedName().toString(); + if ("java.lang.Object".equals(valueQName)) { + return "Map.of(\"type\", \"object\")"; + } + } + String valueSchema = generateSchema(valueType, typeUtils, elementUtils); + return "Map.of(\"type\", \"object\", \"additionalProperties\", " + valueSchema + ")"; + } + return "Map.of(\"type\", \"object\")"; + } + + // Enum types + if (typeElement.getKind() == ElementKind.ENUM) { + List constants = typeElement.getEnclosedElements().stream() + .filter(e -> e.getKind() == ElementKind.ENUM_CONSTANT) + .map(e -> "\"" + e.getSimpleName().toString() + "\"").collect(Collectors.toList()); + return "Map.of(\"type\", \"string\", \"enum\", List.of(" + String.join(", ", constants) + "))"; + } + + // Record types + if (typeElement.getKind() == ElementKind.RECORD) { + return generateRecordSchema(typeElement, typeUtils, elementUtils); + } + + // POJO / class types — treat as object with fields + if (typeElement.getKind() == ElementKind.CLASS) { + return generateClassSchema(typeElement, typeUtils, elementUtils); + } + + // Sealed interfaces — oneOf via permitted subclasses + if (typeElement.getKind() == ElementKind.INTERFACE) { + return generateSealedSchema(typeElement, typeUtils, elementUtils); + } + + return "Map.of()"; + } + + private String generateRecordSchema(TypeElement typeElement, Types typeUtils, Elements elementUtils) { + List propertyEntries = new ArrayList<>(); + List requiredNames = new ArrayList<>(); + + for (Element enclosed : typeElement.getEnclosedElements()) { + if (enclosed.getKind() == ElementKind.RECORD_COMPONENT) { + RecordComponentElement component = (RecordComponentElement) enclosed; + String name = component.getSimpleName().toString(); + TypeMirror componentType = component.asType(); + + boolean isOptional = isOptionalType(componentType); + String schema; + if (isOptional) { + schema = generateSchema(unwrapOptional(componentType, typeUtils), typeUtils, elementUtils); + } else { + schema = generateSchema(componentType, typeUtils, elementUtils); + requiredNames.add("\"" + name + "\""); + } + + propertyEntries.add("Map.entry(\"" + name + "\", " + schema + ")"); + } + } + + String properties = "Map.ofEntries(" + String.join(", ", propertyEntries) + ")"; + String required = "List.of(" + String.join(", ", requiredNames) + ")"; + + return "Map.of(\"type\", \"object\", \"properties\", " + properties + ", \"required\", " + required + ")"; + } + + private String generateClassSchema(TypeElement typeElement, Types typeUtils, Elements elementUtils) { + List propertyEntries = new ArrayList<>(); + List requiredNames = new ArrayList<>(); + + for (Element enclosed : typeElement.getEnclosedElements()) { + if (enclosed.getKind() == ElementKind.FIELD) { + VariableElement field = (VariableElement) enclosed; + // Skip static fields + if (field.getModifiers().contains(javax.lang.model.element.Modifier.STATIC)) { + continue; + } + String name = field.getSimpleName().toString(); + TypeMirror fieldType = field.asType(); + + boolean isOptional = isOptionalType(fieldType); + String schema; + if (isOptional) { + schema = generateSchema(unwrapOptional(fieldType, typeUtils), typeUtils, elementUtils); + } else { + schema = generateSchema(fieldType, typeUtils, elementUtils); + requiredNames.add("\"" + name + "\""); + } + + propertyEntries.add("Map.entry(\"" + name + "\", " + schema + ")"); + } + } + + if (propertyEntries.isEmpty()) { + return "Map.of(\"type\", \"object\")"; + } + + String properties = "Map.ofEntries(" + String.join(", ", propertyEntries) + ")"; + String required = "List.of(" + String.join(", ", requiredNames) + ")"; + + return "Map.of(\"type\", \"object\", \"properties\", " + properties + ", \"required\", " + required + ")"; + } + + private String generateSealedSchema(TypeElement typeElement, Types typeUtils, Elements elementUtils) { + List permittedSubclasses = typeElement.getPermittedSubclasses(); + if (permittedSubclasses != null && !permittedSubclasses.isEmpty()) { + List schemas = permittedSubclasses.stream().map(sub -> generateSchema(sub, typeUtils, elementUtils)) + .collect(Collectors.toList()); + return "Map.of(\"oneOf\", List.of(" + String.join(", ", schemas) + "))"; + } + return "Map.of(\"type\", \"object\")"; + } + + private boolean isOptionalType(TypeMirror type) { + if (type.getKind() != TypeKind.DECLARED) { + return false; + } + DeclaredType declaredType = (DeclaredType) type; + TypeElement element = (TypeElement) declaredType.asElement(); + String name = element.getQualifiedName().toString(); + return "java.util.Optional".equals(name) || "java.util.OptionalInt".equals(name) + || "java.util.OptionalDouble".equals(name) || "java.util.OptionalLong".equals(name); + } + + private TypeMirror unwrapOptional(TypeMirror type, Types typeUtils) { + if (type.getKind() != TypeKind.DECLARED) { + return type; + } + DeclaredType declaredType = (DeclaredType) type; + TypeElement element = (TypeElement) declaredType.asElement(); + String name = element.getQualifiedName().toString(); + + if ("java.util.Optional".equals(name)) { + List typeArgs = declaredType.getTypeArguments(); + if (!typeArgs.isEmpty()) { + return typeArgs.get(0); + } + } + if ("java.util.OptionalInt".equals(name)) { + return typeUtils.getPrimitiveType(TypeKind.INT); + } + if ("java.util.OptionalDouble".equals(name)) { + return typeUtils.getPrimitiveType(TypeKind.DOUBLE); + } + if ("java.util.OptionalLong".equals(name)) { + return typeUtils.getPrimitiveType(TypeKind.LONG); + } + return type; + } + + private boolean isCollectionType(String qualifiedName) { + return "java.util.List".equals(qualifiedName) || "java.util.Collection".equals(qualifiedName) + || "java.util.Set".equals(qualifiedName); + } + + private boolean isMapType(String qualifiedName) { + return "java.util.Map".equals(qualifiedName); + } +} diff --git a/java/src/main/java/module-info.java b/java/src/main/java/module-info.java index 9f48b3747..38bc1f93d 100644 --- a/java/src/main/java/module-info.java +++ b/java/src/main/java/module-info.java @@ -19,11 +19,13 @@ exports com.github.copilot.generated; exports com.github.copilot.generated.rpc; exports com.github.copilot.rpc; + exports com.github.copilot.tool; opens com.github.copilot to com.fasterxml.jackson.databind; opens com.github.copilot.generated to com.fasterxml.jackson.databind; opens com.github.copilot.generated.rpc to com.fasterxml.jackson.databind; opens com.github.copilot.rpc to com.fasterxml.jackson.databind; - provides javax.annotation.processing.Processor with com.github.copilot.CopilotExperimentalProcessor; + provides javax.annotation.processing.Processor + with com.github.copilot.CopilotExperimentalProcessor, com.github.copilot.tool.CopilotToolProcessor; } diff --git a/java/src/main/resources/META-INF/services/javax.annotation.processing.Processor b/java/src/main/resources/META-INF/services/javax.annotation.processing.Processor index 1e7feda8c..3b2e17d2f 100644 --- a/java/src/main/resources/META-INF/services/javax.annotation.processing.Processor +++ b/java/src/main/resources/META-INF/services/javax.annotation.processing.Processor @@ -1 +1,2 @@ com.github.copilot.CopilotExperimentalProcessor +com.github.copilot.tool.CopilotToolProcessor diff --git a/java/src/test/java/com/github/copilot/CopilotSessionTest.java b/java/src/test/java/com/github/copilot/CopilotSessionTest.java index 44a7373ec..eb061b029 100644 --- a/java/src/test/java/com/github/copilot/CopilotSessionTest.java +++ b/java/src/test/java/com/github/copilot/CopilotSessionTest.java @@ -756,8 +756,27 @@ void testShouldGetLastSessionId() throws Exception { ctx.configureForTest("session", "should_get_last_session_id"); try (CopilotClient client = ctx.createClient()) { - CopilotSession session = client - .createSession(new SessionConfig().setOnPermissionRequest(PermissionHandler.APPROVE_ALL)).get(); + CopilotSession session = null; + for (int attempt = 1; attempt <= 2; attempt++) { + CompletableFuture createFuture = client + .createSession(new SessionConfig().setOnPermissionRequest(PermissionHandler.APPROVE_ALL)); + try { + session = createFuture.get(45, TimeUnit.SECONDS); + break; + } catch (java.util.concurrent.TimeoutException e) { + createFuture.cancel(true); + if (attempt == 2) { + throw e; + } + } catch (java.util.concurrent.ExecutionException e) { + if (e.getCause() instanceof java.util.concurrent.TimeoutException && attempt < 2) { + createFuture.cancel(true); + continue; + } + throw e; + } + } + assertNotNull(session, "Session should be created"); session.sendAndWait(new MessageOptions().setPrompt("Say hello")).get(60, TimeUnit.SECONDS); String sessionId = session.getSessionId(); diff --git a/java/src/test/java/com/github/copilot/e2e/ErgonomicTestTools$$CopilotToolMeta.java b/java/src/test/java/com/github/copilot/e2e/ErgonomicTestTools$$CopilotToolMeta.java new file mode 100644 index 000000000..703a6b010 --- /dev/null +++ b/java/src/test/java/com/github/copilot/e2e/ErgonomicTestTools$$CopilotToolMeta.java @@ -0,0 +1,49 @@ +// Hand-written test fixture mimicking CopilotToolProcessor output. +package com.github.copilot.e2e; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.github.copilot.rpc.ToolDefinition; +import com.github.copilot.tool.CopilotToolMetadataProvider; + +import java.util.*; +import java.util.concurrent.CompletableFuture; + +public final class ErgonomicTestTools$$CopilotToolMeta implements CopilotToolMetadataProvider { + + private static Map withMeta(Map base, String description, Object defaultValue) { + var result = new LinkedHashMap(base); + if (description != null) + result.put("description", description); + if (defaultValue != null) + result.put("default", defaultValue); + return Collections.unmodifiableMap(result); + } + + @Override + @SuppressWarnings({"unchecked", "rawtypes"}) + public List definitions(ErgonomicTestTools instance, ObjectMapper mapper) { + return List.of(new ToolDefinition("set_current_phase", "Sets the current phase of the agent", + Map.of("type", "object", "properties", + Map.ofEntries(Map.entry("phase", + (Map) (Map) withMeta(Map.of("type", "string"), + "The phase to transition to", null))), + "required", List.of("phase")), + invocation -> { + Map args = invocation.getArguments(); + String phase = (String) args.get("phase"); + return CompletableFuture.completedFuture(instance.setCurrentPhase(phase)); + }, null, null, null), + new ToolDefinition( + "search_items", "Search for items by keyword", Map + .of("type", "object", "properties", + Map.ofEntries(Map.entry("keyword", + (Map) (Map) withMeta(Map.of("type", "string"), + "Search keyword", null))), + "required", List.of("keyword")), + invocation -> { + Map args = invocation.getArguments(); + String keyword = (String) args.get("keyword"); + return CompletableFuture.completedFuture(instance.searchItems(keyword)); + }, null, null, null)); + } +} diff --git a/java/src/test/java/com/github/copilot/e2e/ErgonomicTestTools.java b/java/src/test/java/com/github/copilot/e2e/ErgonomicTestTools.java new file mode 100644 index 000000000..35f191db9 --- /dev/null +++ b/java/src/test/java/com/github/copilot/e2e/ErgonomicTestTools.java @@ -0,0 +1,32 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +package com.github.copilot.e2e; + +import com.github.copilot.tool.CopilotTool; +import com.github.copilot.tool.Param; + +/** + * Tool fixture for the ergonomic {@code @CopilotTool} E2E integration test. + * + *

+ * This class exercises the annotation-based tool definition API, producing + * identical wire-level tool schemas to the low-level + * {@code ToolDefinition.create()} API. + */ +class ErgonomicTestTools { + + String currentPhase; + + @CopilotTool("Sets the current phase of the agent") + public String setCurrentPhase(@Param("The phase to transition to") String phase) { + currentPhase = phase; + return "Phase set to " + phase; + } + + @CopilotTool("Search for items by keyword") + public String searchItems(@Param("Search keyword") String keyword) { + return "Found: " + keyword + " -> item_alpha, item_beta"; + } +} diff --git a/java/src/test/java/com/github/copilot/e2e/ErgonomicToolDefinitionIT.java b/java/src/test/java/com/github/copilot/e2e/ErgonomicToolDefinitionIT.java new file mode 100644 index 000000000..c74e94544 --- /dev/null +++ b/java/src/test/java/com/github/copilot/e2e/ErgonomicToolDefinitionIT.java @@ -0,0 +1,85 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +package com.github.copilot.e2e; + +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.util.List; +import java.util.concurrent.TimeUnit; + +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; + +import com.github.copilot.CopilotClient; +import com.github.copilot.CopilotSession; +import com.github.copilot.E2ETestContext; +import com.github.copilot.generated.AssistantMessageEvent; +import com.github.copilot.rpc.MessageOptions; +import com.github.copilot.rpc.PermissionHandler; +import com.github.copilot.rpc.SessionConfig; +import com.github.copilot.rpc.ToolDefinition; +import com.github.copilot.rpc.ToolSet; + +/** + * Failsafe integration test for the ergonomic {@code @CopilotTool} + + * {@code ToolDefinition.fromObject()} API. + * + *

+ * This test proves that the ergonomic annotation-based API produces identical + * wire behavior to the low-level {@code ToolDefinition.create()} API tested in + * {@code LowLevelToolDefinitionIT}. + * + * @see Snapshot: tools/ergonomic_tool_definition + */ +class ErgonomicToolDefinitionIT { + + private static E2ETestContext ctx; + + @BeforeAll + static void setup() throws Exception { + ctx = E2ETestContext.create(); + } + + @AfterAll + static void teardown() throws Exception { + if (ctx != null) { + ctx.close(); + } + } + + @Test + void ergonomicToolDefinition() throws Exception { + ctx.configureForTest("tools", "ergonomic_tool_definition"); + + ErgonomicTestTools tools = new ErgonomicTestTools(); + List toolDefs = ToolDefinition.fromObject(tools); + + try (CopilotClient client = ctx.createClient()) { + CopilotSession session = client + .createSession(new SessionConfig().setOnPermissionRequest(PermissionHandler.APPROVE_ALL) + .setAvailableTools(new ToolSet().addCustom("*").addBuiltIn("web_fetch")).setTools(toolDefs)) + .get(30, TimeUnit.SECONDS); + + try { + AssistantMessageEvent response = session.sendAndWait(new MessageOptions().setPrompt( + "First, set the current phase to 'analyzing'. Then search for items with keyword 'copilot'. Report the phase and search results."), + 60_000).get(90, TimeUnit.SECONDS); + + assertNotNull(response, "Expected a response from the assistant"); + String content = response.getData().content().toLowerCase(); + assertTrue(content.contains("analyzing"), + "Response should contain the updated phase: " + response.getData().content()); + assertTrue(content.contains("item_alpha") || content.contains("item_beta"), + "Response should contain search results: " + response.getData().content()); + assertTrue("analyzing".equals(tools.currentPhase), + "Expected currentPhase to be 'analyzing' but was: " + tools.currentPhase); + } finally { + session.close(); + } + } + } +} diff --git a/java/src/test/java/com/github/copilot/rpc/ToolDefinitionFromObjectTest.java b/java/src/test/java/com/github/copilot/rpc/ToolDefinitionFromObjectTest.java new file mode 100644 index 000000000..25765e057 --- /dev/null +++ b/java/src/test/java/com/github/copilot/rpc/ToolDefinitionFromObjectTest.java @@ -0,0 +1,324 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +package com.github.copilot.rpc; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.util.List; +import java.util.Map; + +import org.junit.jupiter.api.Test; + +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.databind.DeserializationFeature; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.SerializationFeature; +import com.fasterxml.jackson.databind.node.JsonNodeFactory; +import com.fasterxml.jackson.databind.node.ObjectNode; +import com.fasterxml.jackson.datatype.jsr310.JavaTimeModule; +import com.github.copilot.AllowCopilotExperimental; +import com.github.copilot.rpc.fixtures.ArgCoercionTools; +import com.github.copilot.rpc.fixtures.DateTimeTools; +import com.github.copilot.rpc.fixtures.DefaultValueTools; +import com.github.copilot.rpc.fixtures.MultiReturnTools; +import com.github.copilot.rpc.fixtures.OptionalParamTools; +import com.github.copilot.rpc.fixtures.OverrideTools; +import com.github.copilot.rpc.fixtures.SimpleTools; +import com.github.copilot.rpc.fixtures.StaticTools; + +/** + * End-to-end tests for {@link ToolDefinition#fromObject(Object)}. + *

+ * These tests use hand-written {@code $$CopilotToolMeta} companion classes + * under {@code com.github.copilot.rpc.fixtures} that mimic + * {@link com.github.copilot.tool.CopilotToolProcessor} output. + */ +@AllowCopilotExperimental +class ToolDefinitionFromObjectTest { + + // ── Test 1: Basic end-to-end ──────────────────────────────────────────────── + + @Test + void fromObject_returnsCorrectNumberOfTools() { + var tools = ToolDefinition.fromObject(new SimpleTools()); + assertEquals(2, tools.size()); + } + + @Test + void fromObject_toolNamesAndDescriptions() { + var tools = ToolDefinition.fromObject(new SimpleTools()); + var tool1 = findTool(tools, "greet_user"); + assertNotNull(tool1); + assertEquals("Greets a user by name", tool1.description()); + + var tool2 = findTool(tools, "add_numbers"); + assertNotNull(tool2); + assertEquals("Adds two numbers together", tool2.description()); + } + + @Test + void fromObject_toolParameterSchema() { + var tools = ToolDefinition.fromObject(new SimpleTools()); + var tool = findTool(tools, "greet_user"); + assertNotNull(tool); + @SuppressWarnings("unchecked") + var schema = (Map) tool.parameters(); + assertEquals("object", schema.get("type")); + @SuppressWarnings("unchecked") + var properties = (Map) schema.get("properties"); + assertTrue(properties.containsKey("name")); + @SuppressWarnings("unchecked") + var required = (List) schema.get("required"); + assertTrue(required.contains("name")); + } + + @Test + void fromObject_handlerInvocation() throws Exception { + var instance = new SimpleTools(); + var tools = ToolDefinition.fromObject(instance); + var tool = findTool(tools, "greet_user"); + assertNotNull(tool); + + var result = tool.handler().invoke(createInvocation("greet_user", Map.of("name", "Alice"))).get(); + assertEquals("Hello, Alice!", result); + } + + // ── Test 2: Handler return type patterns ──────────────────────────────────── + + @Test + void fromObject_stringReturn() throws Exception { + var tools = ToolDefinition.fromObject(new MultiReturnTools()); + var tool = findTool(tools, "string_method"); + assertNotNull(tool); + var result = tool.handler().invoke(createInvocation("string_method", Map.of())).get(); + assertEquals("hello", result); + } + + @Test + void fromObject_voidReturn() throws Exception { + var tools = ToolDefinition.fromObject(new MultiReturnTools()); + var tool = findTool(tools, "void_method"); + assertNotNull(tool); + var result = tool.handler().invoke(createInvocation("void_method", Map.of())).get(); + assertEquals("Success", result); + } + + @Test + void fromObject_asyncReturn() throws Exception { + var tools = ToolDefinition.fromObject(new MultiReturnTools()); + var tool = findTool(tools, "async_method"); + assertNotNull(tool); + var result = tool.handler().invoke(createInvocation("async_method", Map.of())).get(); + assertEquals("async result", result); + } + + // ── Test 3: Argument coercion ─────────────────────────────────────────────── + + @Test + void fromObject_argumentCoercion() throws Exception { + var instance = new ArgCoercionTools(); + var tools = ToolDefinition.fromObject(instance); + var tool = findTool(tools, "mixed_args"); + assertNotNull(tool); + + var result = tool.handler().invoke( + createInvocation("mixed_args", Map.of("text", "hello", "count", 5, "flag", true, "color", "RED"))) + .get(); + assertEquals("hello-5-true-RED", result); + } + + // ── Test 4: Default value ─────────────────────────────────────────────────── + + @Test + void fromObject_defaultValue() throws Exception { + var instance = new DefaultValueTools(); + var tools = ToolDefinition.fromObject(instance); + var tool = findTool(tools, "with_default"); + assertNotNull(tool); + + // Omit "count" key — should use default value 42 + var result = tool.handler().invoke(createInvocation("with_default", Map.of("label", "test"))).get(); + assertEquals("test:42", result); + } + + // ── Test 5: Error case — missing generated class ──────────────────────────── + + @Test + void fromObject_throwsOnMissingMetaClass() { + // A class that was never processed by CopilotToolProcessor + var ex = assertThrows(IllegalStateException.class, () -> ToolDefinition.fromObject("a plain String")); + assertTrue(ex.getMessage().contains("not found")); + assertTrue(ex.getMessage().contains("CopilotToolProcessor")); + } + + // ── Test 5b: fromClass rejects instance methods ───────────────────────────── + + @Test + void fromClass_throwsOnInstanceMethods() { + // SimpleTools has instance (non-static) @CopilotTool methods + var ex = assertThrows(IllegalArgumentException.class, () -> ToolDefinition.fromClass(SimpleTools.class)); + assertTrue(ex.getMessage().contains("fromClass()")); + assertTrue(ex.getMessage().contains("static")); + assertTrue(ex.getMessage().contains("fromObject")); + } + + // ── Test 6: java.time argument ────────────────────────────────────────────── + + @Test + void fromObject_javaTimeArgument() throws Exception { + var instance = new DateTimeTools(); + var tools = ToolDefinition.fromObject(instance); + var tool = findTool(tools, "schedule_event"); + assertNotNull(tool); + + var result = tool.handler().invoke(createInvocation("schedule_event", Map.of("when", "2024-06-15T10:30:00"))) + .get(); + assertEquals("Scheduled at 2024-06-15T10:30", result); + } + + // ── Test 7: Override tool ──────────────────────────────────────────────────── + + @Test + void fromObject_overrideTool() { + var tools = ToolDefinition.fromObject(new OverrideTools()); + var tool = findTool(tools, "grep"); + assertNotNull(tool); + assertEquals(Boolean.TRUE, tool.overridesBuiltInTool()); + } + + // ── Test 8: ToolDefer.NONE → null mapping (defer absent from JSON) ────────── + + @Test + void fromObject_deferNone_absentFromJson() throws Exception { + var tools = ToolDefinition.fromObject(new SimpleTools()); + var tool = findTool(tools, "greet_user"); + assertNotNull(tool); + // The defer field should be null (NONE maps to null) + assertNull(tool.defer()); + + // Serialize to JSON and verify "defer" key is absent + var mapper = new ObjectMapper(); + mapper.registerModule(new JavaTimeModule()); + mapper.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false); + mapper.configure(SerializationFeature.WRITE_DATES_AS_TIMESTAMPS, false); + mapper.setDefaultPropertyInclusion(JsonInclude.Include.NON_NULL); + + String json = mapper.writeValueAsString(tool); + var node = (ObjectNode) mapper.readTree(json); + assertFalse(node.has("defer"), "defer key should be absent from JSON, got: " + json); + } + + // ── Test 9: fromClass with static methods invokes handler without NPE ───── + + @Test + void fromClass_staticToolInvocation() throws Exception { + var tools = ToolDefinition.fromClass(StaticTools.class); + assertEquals(1, tools.size()); + var tool = findTool(tools, "greet"); + assertNotNull(tool); + + // This should NOT throw NPE — static methods don't need an instance + var result = tool.handler().invoke(createInvocation("greet", Map.of("name", "World"))).get(); + assertEquals("Hi, World!", result); + } + + // ── Test 10: Optional parameter handling ──────────────────────────────────── + + @Test + void fromObject_optionalStringPresent() throws Exception { + var instance = new OptionalParamTools(); + var tools = ToolDefinition.fromObject(instance); + var tool = findTool(tools, "greet_with_title"); + assertNotNull(tool); + + var result = tool.handler() + .invoke(createInvocation("greet_with_title", Map.of("name", "Alice", "title", "Dr."))).get(); + assertEquals("Dr. Alice", result); + } + + @Test + void fromObject_optionalStringAbsent() throws Exception { + var instance = new OptionalParamTools(); + var tools = ToolDefinition.fromObject(instance); + var tool = findTool(tools, "greet_with_title"); + assertNotNull(tool); + + var result = tool.handler().invoke(createInvocation("greet_with_title", Map.of("name", "Alice"))).get(); + assertEquals("Alice", result); + } + + @Test + void fromObject_optionalIntPresent() throws Exception { + var instance = new OptionalParamTools(); + var tools = ToolDefinition.fromObject(instance); + var tool = findTool(tools, "multiply"); + assertNotNull(tool); + + var result = tool.handler().invoke(createInvocation("multiply", Map.of("base", 5, "factor", 3))).get(); + assertEquals("15", result); + } + + @Test + void fromObject_optionalIntAbsent() throws Exception { + var instance = new OptionalParamTools(); + var tools = ToolDefinition.fromObject(instance); + var tool = findTool(tools, "multiply"); + assertNotNull(tool); + + var result = tool.handler().invoke(createInvocation("multiply", Map.of("base", 5))).get(); + assertEquals("5", result); + } + + @Test + void fromObject_optionalDoublePresent() throws Exception { + var instance = new OptionalParamTools(); + var tools = ToolDefinition.fromObject(instance); + var tool = findTool(tools, "scale"); + assertNotNull(tool); + + var result = tool.handler().invoke(createInvocation("scale", Map.of("value", 2.0, "ratio", 3.5))).get(); + assertEquals("7.0", result); + } + + @Test + void fromObject_optionalLongPresent() throws Exception { + var instance = new OptionalParamTools(); + var tools = ToolDefinition.fromObject(instance); + var tool = findTool(tools, "offset"); + assertNotNull(tool); + + var result = tool.handler().invoke(createInvocation("offset", Map.of("base", 100, "delta", 50))).get(); + assertEquals("150", result); + } + + @Test + void fromObject_optionalLongAbsent() throws Exception { + var instance = new OptionalParamTools(); + var tools = ToolDefinition.fromObject(instance); + var tool = findTool(tools, "offset"); + assertNotNull(tool); + + var result = tool.handler().invoke(createInvocation("offset", Map.of("base", 100))).get(); + assertEquals("100", result); + } + + // ── Helpers ───────────────────────────────────────────────────────────────── + + private static ToolDefinition findTool(List tools, String name) { + return tools.stream().filter(t -> name.equals(t.name())).findFirst().orElse(null); + } + + private static ToolInvocation createInvocation(String toolName, Map args) { + ObjectNode argsNode = JsonNodeFactory.instance.objectNode(); + ObjectMapper mapper = new ObjectMapper(); + argsNode.setAll((ObjectNode) mapper.valueToTree(args)); + return new ToolInvocation().setToolName(toolName).setArguments(argsNode); + } +} diff --git a/java/src/test/java/com/github/copilot/rpc/fixtures/ArgCoercionTools$$CopilotToolMeta.java b/java/src/test/java/com/github/copilot/rpc/fixtures/ArgCoercionTools$$CopilotToolMeta.java new file mode 100644 index 000000000..882c0555f --- /dev/null +++ b/java/src/test/java/com/github/copilot/rpc/fixtures/ArgCoercionTools$$CopilotToolMeta.java @@ -0,0 +1,50 @@ +// Hand-written test fixture mimicking CopilotToolProcessor output. +package com.github.copilot.rpc.fixtures; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.github.copilot.rpc.ToolDefinition; +import com.github.copilot.tool.CopilotToolMetadataProvider; + +import java.util.*; +import java.util.concurrent.CompletableFuture; + +public final class ArgCoercionTools$$CopilotToolMeta implements CopilotToolMetadataProvider { + + private static Map withMeta(Map base, String description, Object defaultValue) { + var result = new LinkedHashMap(base); + if (description != null) + result.put("description", description); + if (defaultValue != null) + result.put("default", defaultValue); + return Collections.unmodifiableMap(result); + } + + @Override + @SuppressWarnings({"unchecked", "rawtypes"}) + public List definitions(ArgCoercionTools instance, ObjectMapper mapper) { + return List + .of(new ToolDefinition("mixed_args", "Method with mixed argument types", Map.of( + "type", "object", "properties", Map + .ofEntries( + Map.entry("text", + (Map) (Map) withMeta(Map.of("type", "string"), + "Text input", null)), + Map.entry("count", + (Map) (Map) withMeta(Map.of("type", "integer"), + "A count", null)), + Map.entry("flag", + (Map) (Map) withMeta(Map.of("type", "boolean"), + "A flag", null)), + Map.entry("color", + (Map) (Map) withMeta(Map.of("type", "string", "enum", + List.of("RED", "GREEN", "BLUE")), "A color", null))), + "required", List.of("text", "count", "flag", "color")), invocation -> { + Map args = invocation.getArguments(); + String text = (String) args.get("text"); + int count = ((Number) args.get("count")).intValue(); + boolean flag = (Boolean) args.get("flag"); + ArgCoercionTools.Color color = ArgCoercionTools.Color.valueOf((String) args.get("color")); + return CompletableFuture.completedFuture(instance.mixedArgs(text, count, flag, color)); + }, null, null, null)); + } +} diff --git a/java/src/test/java/com/github/copilot/rpc/fixtures/ArgCoercionTools.java b/java/src/test/java/com/github/copilot/rpc/fixtures/ArgCoercionTools.java new file mode 100644 index 000000000..7f85bd2c7 --- /dev/null +++ b/java/src/test/java/com/github/copilot/rpc/fixtures/ArgCoercionTools.java @@ -0,0 +1,24 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +package com.github.copilot.rpc.fixtures; + +import com.github.copilot.tool.CopilotTool; +import com.github.copilot.tool.Param; + +/** + * Fixture testing argument coercion with multiple types including an enum. + */ +public class ArgCoercionTools { + + public enum Color { + RED, GREEN, BLUE + } + + @CopilotTool("Method with mixed argument types") + public String mixedArgs(@Param("Text input") String text, @Param("A count") int count, + @Param("A flag") boolean flag, @Param("A color") Color color) { + return text + "-" + count + "-" + flag + "-" + color.name(); + } +} diff --git a/java/src/test/java/com/github/copilot/rpc/fixtures/DateTimeTools$$CopilotToolMeta.java b/java/src/test/java/com/github/copilot/rpc/fixtures/DateTimeTools$$CopilotToolMeta.java new file mode 100644 index 000000000..700133650 --- /dev/null +++ b/java/src/test/java/com/github/copilot/rpc/fixtures/DateTimeTools$$CopilotToolMeta.java @@ -0,0 +1,38 @@ +// Hand-written test fixture mimicking CopilotToolProcessor output. +package com.github.copilot.rpc.fixtures; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.github.copilot.rpc.ToolDefinition; +import com.github.copilot.tool.CopilotToolMetadataProvider; + +import java.time.LocalDateTime; +import java.util.*; +import java.util.concurrent.CompletableFuture; + +public final class DateTimeTools$$CopilotToolMeta implements CopilotToolMetadataProvider { + + private static Map withMeta(Map base, String description, Object defaultValue) { + var result = new LinkedHashMap(base); + if (description != null) + result.put("description", description); + if (defaultValue != null) + result.put("default", defaultValue); + return Collections.unmodifiableMap(result); + } + + @Override + @SuppressWarnings({"unchecked", "rawtypes"}) + public List definitions(DateTimeTools instance, ObjectMapper mapper) { + return List.of(new ToolDefinition("schedule_event", "Schedule an event at a given time", + Map.of("type", "object", "properties", + Map.ofEntries(Map.entry("when", + (Map) (Map) withMeta(Map.of("type", "string", "format", "date-time"), + "When to schedule", null))), + "required", List.of("when")), + invocation -> { + Map args = invocation.getArguments(); + LocalDateTime when = mapper.convertValue(args.get("when"), LocalDateTime.class); + return CompletableFuture.completedFuture(instance.scheduleEvent(when)); + }, null, null, null)); + } +} diff --git a/java/src/test/java/com/github/copilot/rpc/fixtures/DateTimeTools.java b/java/src/test/java/com/github/copilot/rpc/fixtures/DateTimeTools.java new file mode 100644 index 000000000..541c2c6d8 --- /dev/null +++ b/java/src/test/java/com/github/copilot/rpc/fixtures/DateTimeTools.java @@ -0,0 +1,24 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +package com.github.copilot.rpc.fixtures; + +import java.time.LocalDateTime; + +import com.github.copilot.tool.CopilotTool; +import com.github.copilot.tool.Param; + +/** + * Fixture testing java.time argument deserialization via ObjectMapper with + * JavaTimeModule. + */ +public class DateTimeTools { + + @CopilotTool("Schedule an event at a given time") + public String scheduleEvent(@Param(value = "When to schedule", required = true) LocalDateTime when) { + return "Scheduled at " + when.getYear() + "-" + String.format("%02d", when.getMonthValue()) + "-" + + String.format("%02d", when.getDayOfMonth()) + "T" + String.format("%02d", when.getHour()) + ":" + + String.format("%02d", when.getMinute()); + } +} diff --git a/java/src/test/java/com/github/copilot/rpc/fixtures/DefaultValueTools$$CopilotToolMeta.java b/java/src/test/java/com/github/copilot/rpc/fixtures/DefaultValueTools$$CopilotToolMeta.java new file mode 100644 index 000000000..ee5369b84 --- /dev/null +++ b/java/src/test/java/com/github/copilot/rpc/fixtures/DefaultValueTools$$CopilotToolMeta.java @@ -0,0 +1,45 @@ +// Hand-written test fixture mimicking CopilotToolProcessor output. +package com.github.copilot.rpc.fixtures; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.github.copilot.rpc.ToolDefinition; +import com.github.copilot.tool.CopilotToolMetadataProvider; + +import java.util.*; +import java.util.concurrent.CompletableFuture; + +public final class DefaultValueTools$$CopilotToolMeta implements CopilotToolMetadataProvider { + + private static Map withMeta(Map base, String description, Object defaultValue) { + var result = new LinkedHashMap(base); + if (description != null) + result.put("description", description); + if (defaultValue != null) + result.put("default", defaultValue); + return Collections.unmodifiableMap(result); + } + + @Override + @SuppressWarnings({"unchecked", "rawtypes"}) + public List definitions(DefaultValueTools instance, ObjectMapper mapper) { + return List + .of(new ToolDefinition( + "with_default", "Method with a default value parameter", Map + .of("type", "object", "properties", + Map.ofEntries( + Map.entry("label", + (Map) (Map) withMeta(Map.of("type", "string"), + "A label", null)), + Map.entry("count", + (Map) (Map) withMeta(Map.of("type", "integer"), + "A count", 42))), + "required", List.of("label")), + invocation -> { + Map args = invocation.getArguments(); + String label = (String) args.get("label"); + Object countRaw = args.containsKey("count") ? args.get("count") : 42; + int count = ((Number) countRaw).intValue(); + return CompletableFuture.completedFuture(instance.withDefault(label, count)); + }, null, null, null)); + } +} diff --git a/java/src/test/java/com/github/copilot/rpc/fixtures/DefaultValueTools.java b/java/src/test/java/com/github/copilot/rpc/fixtures/DefaultValueTools.java new file mode 100644 index 000000000..6e2c3106e --- /dev/null +++ b/java/src/test/java/com/github/copilot/rpc/fixtures/DefaultValueTools.java @@ -0,0 +1,20 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +package com.github.copilot.rpc.fixtures; + +import com.github.copilot.tool.CopilotTool; +import com.github.copilot.tool.Param; + +/** + * Fixture testing default parameter values. + */ +public class DefaultValueTools { + + @CopilotTool("Method with a default value parameter") + public String withDefault(@Param(value = "A label", required = true) String label, + @Param(value = "A count", required = false, defaultValue = "42") int count) { + return label + ":" + count; + } +} diff --git a/java/src/test/java/com/github/copilot/rpc/fixtures/MultiReturnTools$$CopilotToolMeta.java b/java/src/test/java/com/github/copilot/rpc/fixtures/MultiReturnTools$$CopilotToolMeta.java new file mode 100644 index 000000000..e1ac5c38d --- /dev/null +++ b/java/src/test/java/com/github/copilot/rpc/fixtures/MultiReturnTools$$CopilotToolMeta.java @@ -0,0 +1,29 @@ +// Hand-written test fixture mimicking CopilotToolProcessor output. +package com.github.copilot.rpc.fixtures; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.github.copilot.rpc.ToolDefinition; +import com.github.copilot.tool.CopilotToolMetadataProvider; + +import java.util.*; +import java.util.concurrent.CompletableFuture; + +public final class MultiReturnTools$$CopilotToolMeta implements CopilotToolMetadataProvider { + + @Override + @SuppressWarnings({"unchecked", "rawtypes"}) + public List definitions(MultiReturnTools instance, ObjectMapper mapper) { + return List.of(new ToolDefinition("string_method", "Returns a string", + Map.of("type", "object", "properties", Map.of(), "required", List.of()), invocation -> { + return CompletableFuture.completedFuture(instance.stringMethod()); + }, null, null, null), new ToolDefinition("void_method", "Void method", + Map.of("type", "object", "properties", Map.of(), "required", List.of()), invocation -> { + instance.voidMethod(); + return CompletableFuture.completedFuture("Success"); + }, null, null, null), + new ToolDefinition("async_method", "Async method", + Map.of("type", "object", "properties", Map.of(), "required", List.of()), invocation -> { + return instance.asyncMethod().thenApply(r -> (Object) r); + }, null, null, null)); + } +} diff --git a/java/src/test/java/com/github/copilot/rpc/fixtures/MultiReturnTools.java b/java/src/test/java/com/github/copilot/rpc/fixtures/MultiReturnTools.java new file mode 100644 index 000000000..62a6a2500 --- /dev/null +++ b/java/src/test/java/com/github/copilot/rpc/fixtures/MultiReturnTools.java @@ -0,0 +1,30 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +package com.github.copilot.rpc.fixtures; + +import java.util.concurrent.CompletableFuture; + +import com.github.copilot.tool.CopilotTool; + +/** + * Fixture testing different return type patterns. + */ +public class MultiReturnTools { + + @CopilotTool("Returns a string") + public String stringMethod() { + return "hello"; + } + + @CopilotTool("Void method") + public void voidMethod() { + // side-effect only + } + + @CopilotTool("Async method") + public CompletableFuture asyncMethod() { + return CompletableFuture.completedFuture("async result"); + } +} diff --git a/java/src/test/java/com/github/copilot/rpc/fixtures/OptionalParamTools$$CopilotToolMeta.java b/java/src/test/java/com/github/copilot/rpc/fixtures/OptionalParamTools$$CopilotToolMeta.java new file mode 100644 index 000000000..df6c39fd6 --- /dev/null +++ b/java/src/test/java/com/github/copilot/rpc/fixtures/OptionalParamTools$$CopilotToolMeta.java @@ -0,0 +1,101 @@ +// Hand-written test fixture mimicking CopilotToolProcessor output for Optional parameters. +package com.github.copilot.rpc.fixtures; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.github.copilot.rpc.ToolDefinition; +import com.github.copilot.tool.CopilotToolMetadataProvider; + +import java.util.*; +import java.util.concurrent.CompletableFuture; + +public final class OptionalParamTools$$CopilotToolMeta implements CopilotToolMetadataProvider { + + private static Map withMeta(Map base, String description, Object defaultValue) { + var result = new LinkedHashMap(base); + if (description != null) + result.put("description", description); + if (defaultValue != null) + result.put("default", defaultValue); + return Collections.unmodifiableMap(result); + } + + @Override + @SuppressWarnings({"unchecked", "rawtypes"}) + public List definitions(OptionalParamTools instance, ObjectMapper mapper) { + return List.of(new ToolDefinition( + "greet_with_title", "Greet with optional title", Map + .of("type", "object", "properties", + Map.ofEntries( + Map.entry("name", + (Map) (Map) withMeta(Map.of("type", "string"), "Name", + null)), + Map.entry("title", + (Map) (Map) withMeta(Map.of("type", "string"), + "Optional title", null))), + "required", List.of("name")), + invocation -> { + Map args = invocation.getArguments(); + String name = (String) args.get("name"); + Object titleRaw = args.get("title"); + Optional title = titleRaw != null ? Optional.of((String) titleRaw) : Optional.empty(); + return CompletableFuture.completedFuture(instance.greetWithTitle(name, title)); + }, null, null, null), + new ToolDefinition("multiply", "Multiply with optional factor", + Map.of("type", "object", "properties", + Map.ofEntries( + Map.entry("base", + (Map) (Map) withMeta(Map.of("type", "integer"), + "Base value", null)), + Map.entry("factor", + (Map) (Map) withMeta(Map.of("type", "integer"), + "Optional factor", null))), + "required", List.of("base")), + invocation -> { + Map args = invocation.getArguments(); + int base = ((Number) args.get("base")).intValue(); + Object factorRaw = args.get("factor"); + OptionalInt factor = factorRaw != null + ? OptionalInt.of(((Number) factorRaw).intValue()) + : OptionalInt.empty(); + return CompletableFuture.completedFuture(instance.multiply(base, factor)); + }, null, null, null), + new ToolDefinition("scale", "Scale with optional ratio", + Map.of("type", "object", "properties", + Map.ofEntries( + Map.entry("value", + (Map) (Map) withMeta(Map.of("type", "number"), "Value", + null)), + Map.entry("ratio", + (Map) (Map) withMeta(Map.of("type", "number"), + "Optional ratio", null))), + "required", List.of("value")), + invocation -> { + Map args = invocation.getArguments(); + double value = ((Number) args.get("value")).doubleValue(); + Object ratioRaw = args.get("ratio"); + OptionalDouble ratio = ratioRaw != null + ? OptionalDouble.of(((Number) ratioRaw).doubleValue()) + : OptionalDouble.empty(); + return CompletableFuture.completedFuture(instance.scale(value, ratio)); + }, null, null, null), + new ToolDefinition("offset", "Offset with optional delta", + Map.of("type", "object", "properties", + Map.ofEntries( + Map.entry("base", + (Map) (Map) withMeta(Map.of("type", "integer"), "Base", + null)), + Map.entry("delta", + (Map) (Map) withMeta(Map.of("type", "integer"), + "Optional delta", null))), + "required", List.of("base")), + invocation -> { + Map args = invocation.getArguments(); + long base = ((Number) args.get("base")).longValue(); + Object deltaRaw = args.get("delta"); + OptionalLong delta = deltaRaw != null + ? OptionalLong.of(((Number) deltaRaw).longValue()) + : OptionalLong.empty(); + return CompletableFuture.completedFuture(instance.offset(base, delta)); + }, null, null, null)); + } +} diff --git a/java/src/test/java/com/github/copilot/rpc/fixtures/OptionalParamTools.java b/java/src/test/java/com/github/copilot/rpc/fixtures/OptionalParamTools.java new file mode 100644 index 000000000..98e7dda62 --- /dev/null +++ b/java/src/test/java/com/github/copilot/rpc/fixtures/OptionalParamTools.java @@ -0,0 +1,40 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +package com.github.copilot.rpc.fixtures; + +import java.util.Optional; +import java.util.OptionalDouble; +import java.util.OptionalInt; +import java.util.OptionalLong; + +import com.github.copilot.tool.CopilotTool; +import com.github.copilot.tool.Param; + +/** + * Tool fixture with Optional parameter types for testing correct argument + * extraction (null-check + wrapping instead of mapper.convertValue). + */ +public class OptionalParamTools { + + @CopilotTool("Greet with optional title") + public String greetWithTitle(@Param("Name") String name, @Param("Optional title") Optional title) { + return title.map(t -> t + " " + name).orElse(name); + } + + @CopilotTool("Multiply with optional factor") + public String multiply(@Param("Base value") int base, @Param("Optional factor") OptionalInt factor) { + return String.valueOf(base * factor.orElse(1)); + } + + @CopilotTool("Scale with optional ratio") + public String scale(@Param("Value") double value, @Param("Optional ratio") OptionalDouble ratio) { + return String.valueOf(value * ratio.orElse(1.0)); + } + + @CopilotTool("Offset with optional delta") + public String offset(@Param("Base") long base, @Param("Optional delta") OptionalLong delta) { + return String.valueOf(base + delta.orElse(0L)); + } +} diff --git a/java/src/test/java/com/github/copilot/rpc/fixtures/OverrideTools$$CopilotToolMeta.java b/java/src/test/java/com/github/copilot/rpc/fixtures/OverrideTools$$CopilotToolMeta.java new file mode 100644 index 000000000..127dc922b --- /dev/null +++ b/java/src/test/java/com/github/copilot/rpc/fixtures/OverrideTools$$CopilotToolMeta.java @@ -0,0 +1,39 @@ +// Hand-written test fixture mimicking CopilotToolProcessor output. +package com.github.copilot.rpc.fixtures; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.github.copilot.rpc.ToolDefinition; +import com.github.copilot.tool.CopilotToolMetadataProvider; + +import java.util.*; +import java.util.concurrent.CompletableFuture; + +public final class OverrideTools$$CopilotToolMeta implements CopilotToolMetadataProvider { + + private static Map withMeta(Map base, String description, Object defaultValue) { + var result = new LinkedHashMap(base); + if (description != null) + result.put("description", description); + if (defaultValue != null) + result.put("default", defaultValue); + return Collections.unmodifiableMap(result); + } + + @Override + @SuppressWarnings({"unchecked", "rawtypes"}) + public List definitions(OverrideTools instance, ObjectMapper mapper) { + return List + .of(new ToolDefinition( + "grep", "Custom grep implementation", Map + .of("type", "object", "properties", + Map.ofEntries(Map.entry("pattern", + (Map) (Map) withMeta(Map.of("type", "string"), + "Search pattern", null))), + "required", List.of("pattern")), + invocation -> { + Map args = invocation.getArguments(); + String pattern = (String) args.get("pattern"); + return CompletableFuture.completedFuture(instance.customGrep(pattern)); + }, Boolean.TRUE, null, null)); + } +} diff --git a/java/src/test/java/com/github/copilot/rpc/fixtures/OverrideTools.java b/java/src/test/java/com/github/copilot/rpc/fixtures/OverrideTools.java new file mode 100644 index 000000000..5fbb432f9 --- /dev/null +++ b/java/src/test/java/com/github/copilot/rpc/fixtures/OverrideTools.java @@ -0,0 +1,19 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +package com.github.copilot.rpc.fixtures; + +import com.github.copilot.tool.CopilotTool; +import com.github.copilot.tool.Param; + +/** + * Fixture testing tool override flag. + */ +public class OverrideTools { + + @CopilotTool(value = "Custom grep implementation", name = "grep", overridesBuiltInTool = true) + public String customGrep(@Param(value = "Search pattern", required = true) String pattern) { + return "Found: " + pattern; + } +} diff --git a/java/src/test/java/com/github/copilot/rpc/fixtures/SimpleTools$$CopilotToolMeta.java b/java/src/test/java/com/github/copilot/rpc/fixtures/SimpleTools$$CopilotToolMeta.java new file mode 100644 index 000000000..0b52bd1ef --- /dev/null +++ b/java/src/test/java/com/github/copilot/rpc/fixtures/SimpleTools$$CopilotToolMeta.java @@ -0,0 +1,51 @@ +// Hand-written test fixture mimicking CopilotToolProcessor output. +package com.github.copilot.rpc.fixtures; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.github.copilot.rpc.ToolDefinition; +import com.github.copilot.tool.CopilotToolMetadataProvider; + +import java.util.*; +import java.util.concurrent.CompletableFuture; + +public final class SimpleTools$$CopilotToolMeta implements CopilotToolMetadataProvider { + + private static Map withMeta(Map base, String description, Object defaultValue) { + var result = new LinkedHashMap(base); + if (description != null) + result.put("description", description); + if (defaultValue != null) + result.put("default", defaultValue); + return Collections.unmodifiableMap(result); + } + + @Override + @SuppressWarnings({"unchecked", "rawtypes"}) + public List definitions(SimpleTools instance, ObjectMapper mapper) { + return List.of(new ToolDefinition("greet_user", "Greets a user by name", + Map.of("type", "object", "properties", Map.ofEntries(Map.entry("name", + (Map) (Map) withMeta(Map.of("type", "string"), "The user's name", null))), + "required", List.of("name")), + invocation -> { + Map args = invocation.getArguments(); + String name = (String) args.get("name"); + return CompletableFuture.completedFuture(instance.greetUser(name)); + }, null, null, null), + new ToolDefinition("add_numbers", "Adds two numbers together", + Map.of("type", "object", "properties", + Map.ofEntries( + Map.entry("a", + (Map) (Map) withMeta(Map.of("type", "integer"), + "First number", null)), + Map.entry("b", + (Map) (Map) withMeta(Map.of("type", "integer"), + "Second number", null))), + "required", List.of("a", "b")), + invocation -> { + Map args = invocation.getArguments(); + int a = ((Number) args.get("a")).intValue(); + int b = ((Number) args.get("b")).intValue(); + return CompletableFuture.completedFuture(instance.addNumbers(a, b)); + }, null, null, null)); + } +} diff --git a/java/src/test/java/com/github/copilot/rpc/fixtures/SimpleTools.java b/java/src/test/java/com/github/copilot/rpc/fixtures/SimpleTools.java new file mode 100644 index 000000000..5bdee36e5 --- /dev/null +++ b/java/src/test/java/com/github/copilot/rpc/fixtures/SimpleTools.java @@ -0,0 +1,24 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +package com.github.copilot.rpc.fixtures; + +import com.github.copilot.tool.CopilotTool; +import com.github.copilot.tool.Param; + +/** + * Simple tool fixture with basic String-returning methods. + */ +public class SimpleTools { + + @CopilotTool("Greets a user by name") + public String greetUser(@Param(value = "The user's name", required = true) String name) { + return "Hello, " + name + "!"; + } + + @CopilotTool("Adds two numbers together") + public String addNumbers(@Param(value = "First number") int a, @Param(value = "Second number") int b) { + return String.valueOf(a + b); + } +} diff --git a/java/src/test/java/com/github/copilot/rpc/fixtures/StaticTools$$CopilotToolMeta.java b/java/src/test/java/com/github/copilot/rpc/fixtures/StaticTools$$CopilotToolMeta.java new file mode 100644 index 000000000..842547b68 --- /dev/null +++ b/java/src/test/java/com/github/copilot/rpc/fixtures/StaticTools$$CopilotToolMeta.java @@ -0,0 +1,37 @@ +// Hand-written test fixture mimicking CopilotToolProcessor output for static methods. +package com.github.copilot.rpc.fixtures; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.github.copilot.rpc.ToolDefinition; +import com.github.copilot.tool.CopilotToolMetadataProvider; + +import java.util.*; +import java.util.concurrent.CompletableFuture; + +public final class StaticTools$$CopilotToolMeta implements CopilotToolMetadataProvider { + + private static Map withMeta(Map base, String description, Object defaultValue) { + var result = new LinkedHashMap(base); + if (description != null) + result.put("description", description); + if (defaultValue != null) + result.put("default", defaultValue); + return Collections.unmodifiableMap(result); + } + + @Override + @SuppressWarnings({"unchecked", "rawtypes"}) + public List definitions(StaticTools instance, ObjectMapper mapper) { + return List.of(new ToolDefinition("greet", "Returns a greeting for the given name", + Map.of("type", "object", "properties", Map.ofEntries(Map.entry("name", + (Map) (Map) withMeta(Map.of("type", "string"), "The name to greet", null))), + "required", List.of("name")), + invocation -> { + Map args = invocation.getArguments(); + String name = (String) args.get("name"); + // Mimics what the processor now generates for static methods: + // QualifiedClassName.method(...) instead of instance.method(...) + return CompletableFuture.completedFuture(StaticTools.greet(name)); + }, null, null, null)); + } +} diff --git a/java/src/test/java/com/github/copilot/rpc/fixtures/StaticTools.java b/java/src/test/java/com/github/copilot/rpc/fixtures/StaticTools.java new file mode 100644 index 000000000..7e681aa46 --- /dev/null +++ b/java/src/test/java/com/github/copilot/rpc/fixtures/StaticTools.java @@ -0,0 +1,20 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +package com.github.copilot.rpc.fixtures; + +import com.github.copilot.tool.CopilotTool; +import com.github.copilot.tool.Param; + +/** + * Tool fixture with a static {@code @CopilotTool} method, used to test + * {@code ToolDefinition.fromClass()} invocation path. + */ +public class StaticTools { + + @CopilotTool("Returns a greeting for the given name") + public static String greet(@Param(value = "The name to greet", required = true) String name) { + return "Hi, " + name + "!"; + } +} diff --git a/java/src/test/java/com/github/copilot/tool/CopilotToolAnnotationTest.java b/java/src/test/java/com/github/copilot/tool/CopilotToolAnnotationTest.java new file mode 100644 index 000000000..9052c6b1c --- /dev/null +++ b/java/src/test/java/com/github/copilot/tool/CopilotToolAnnotationTest.java @@ -0,0 +1,154 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +package com.github.copilot.tool; + +import static org.junit.jupiter.api.Assertions.*; + +import java.io.InputStream; +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; +import java.lang.reflect.Method; +import java.lang.reflect.Parameter; +import java.nio.charset.StandardCharsets; +import java.util.concurrent.CompletableFuture; + +import org.junit.jupiter.api.Test; + +import com.github.copilot.CopilotExperimental; +import com.github.copilot.rpc.ToolDefer; + +/** + * Unit tests for {@link CopilotTool} and {@link Param} annotations. + */ +public class CopilotToolAnnotationTest { + + // --- @CopilotTool attribute verification --- + + @Test + void copilotToolHasRuntimeRetention() { + Retention retention = CopilotTool.class.getAnnotation(Retention.class); + assertNotNull(retention); + assertEquals(RetentionPolicy.RUNTIME, retention.value()); + } + + @Test + void copilotToolTargetsMethod() { + Target target = CopilotTool.class.getAnnotation(Target.class); + assertNotNull(target); + assertArrayEquals(new ElementType[]{ElementType.METHOD}, target.value()); + } + + @Test + void copilotExperimentalTargetsTypeForAnnotationDeclarations() { + Target expTarget = CopilotExperimental.class.getAnnotation(Target.class); + assertNotNull(expTarget); + boolean includesType = false; + for (ElementType et : expTarget.value()) { + if (et == ElementType.TYPE) { + includesType = true; + break; + } + } + assertTrue(includesType, "@CopilotExperimental must target TYPE to be applicable to annotation declarations"); + } + + @Test + void copilotToolDeclaresCopilotExperimentalInClassFile() throws Exception { + String classFileResourcePath = "/" + CopilotTool.class.getName().replace('.', '/') + ".class"; + try (InputStream classFile = CopilotTool.class.getResourceAsStream(classFileResourcePath)) { + assertNotNull(classFile, "CopilotTool class file must be readable as a resource"); + String classFileText = new String(classFile.readAllBytes(), StandardCharsets.ISO_8859_1); + assertTrue(classFileText.contains("com/github/copilot/CopilotExperimental")); + } + } + + @Test + void copilotToolDefaultValues() throws Exception { + Method nameMethod = CopilotTool.class.getDeclaredMethod("name"); + assertEquals("", nameMethod.getDefaultValue()); + + Method overridesMethod = CopilotTool.class.getDeclaredMethod("overridesBuiltInTool"); + assertEquals(false, overridesMethod.getDefaultValue()); + + Method skipMethod = CopilotTool.class.getDeclaredMethod("skipPermission"); + assertEquals(false, skipMethod.getDefaultValue()); + + Method deferMethod = CopilotTool.class.getDeclaredMethod("defer"); + assertEquals(ToolDefer.NONE, deferMethod.getDefaultValue()); + } + + // --- @Param attribute verification --- + + @Test + void paramHasRuntimeRetention() { + Retention retention = Param.class.getAnnotation(Retention.class); + assertNotNull(retention); + assertEquals(RetentionPolicy.RUNTIME, retention.value()); + } + + @Test + void paramTargetsParameter() { + Target target = Param.class.getAnnotation(Target.class); + assertNotNull(target); + assertArrayEquals(new ElementType[]{ElementType.PARAMETER}, target.value()); + } + + @Test + void paramDefaultValues() throws Exception { + Method valueMethod = Param.class.getDeclaredMethod("value"); + assertEquals("", valueMethod.getDefaultValue()); + + Method nameMethod = Param.class.getDeclaredMethod("name"); + assertEquals("", nameMethod.getDefaultValue()); + + Method requiredMethod = Param.class.getDeclaredMethod("required"); + assertEquals(true, requiredMethod.getDefaultValue()); + + Method defaultValueMethod = Param.class.getDeclaredMethod("defaultValue"); + assertEquals("", defaultValueMethod.getDefaultValue()); + } + + // --- Applicability test --- + + @SuppressWarnings("unused") + static class SampleToolHolder { + + @CopilotTool(value = "Get weather for a location", name = "get_weather", defer = ToolDefer.AUTO) + public CompletableFuture getWeather(@Param(value = "City name", required = true) String location, + @Param(value = "Temperature unit", required = false, defaultValue = "celsius") String unit) { + return CompletableFuture.completedFuture("Sunny in " + location); + } + } + + @Test + void annotationsAreAccessibleViaReflection() throws Exception { + Method method = SampleToolHolder.class.getDeclaredMethod("getWeather", String.class, String.class); + + CopilotTool toolAnnotation = method.getAnnotation(CopilotTool.class); + assertNotNull(toolAnnotation); + assertEquals("Get weather for a location", toolAnnotation.value()); + assertEquals("get_weather", toolAnnotation.name()); + assertFalse(toolAnnotation.overridesBuiltInTool()); + assertFalse(toolAnnotation.skipPermission()); + assertEquals(ToolDefer.AUTO, toolAnnotation.defer()); + + Parameter[] params = method.getParameters(); + assertEquals(2, params.length); + + Param locationParam = params[0].getAnnotation(Param.class); + assertNotNull(locationParam); + assertEquals("City name", locationParam.value()); + assertTrue(locationParam.required()); + assertEquals("", locationParam.defaultValue()); + + Param unitParam = params[1].getAnnotation(Param.class); + assertNotNull(unitParam); + assertEquals("Temperature unit", unitParam.value()); + assertFalse(unitParam.required()); + assertEquals("celsius", unitParam.defaultValue()); + } +} diff --git a/java/src/test/java/com/github/copilot/tool/CopilotToolProcessorTest.java b/java/src/test/java/com/github/copilot/tool/CopilotToolProcessorTest.java new file mode 100644 index 000000000..12077f718 --- /dev/null +++ b/java/src/test/java/com/github/copilot/tool/CopilotToolProcessorTest.java @@ -0,0 +1,728 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +package com.github.copilot.tool; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.io.File; +import java.io.FilterWriter; +import java.io.IOException; +import java.io.Writer; +import java.net.URI; +import java.nio.file.Path; +import java.security.CodeSource; +import java.util.ArrayList; +import java.util.LinkedHashMap; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import javax.tools.Diagnostic; +import javax.tools.DiagnosticCollector; +import javax.tools.FileObject; +import javax.tools.ForwardingJavaFileManager; +import javax.tools.ForwardingJavaFileObject; +import javax.tools.JavaCompiler; +import javax.tools.JavaFileObject; +import javax.tools.SimpleJavaFileObject; +import javax.tools.StandardJavaFileManager; +import javax.tools.StandardLocation; +import javax.tools.ToolProvider; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; + +/** + * Tests that {@link CopilotToolProcessor} correctly generates + * {@code $$CopilotToolMeta} companion classes and emits compile errors for + * invalid usages. + */ +class CopilotToolProcessorTest { + + @TempDir + java.nio.file.Path tempDir; + + // ── Test: Basic generation ────────────────────────────────────────────────── + + @Test + void generatesMetaClass_withCorrectToolNames() { + String source = """ + package test; + import com.github.copilot.tool.CopilotTool; + import com.github.copilot.tool.Param; + public class MyTools { + @CopilotTool("Sets the current phase") + public String setCurrentPhase(@Param("The phase") String phase) { + return "done"; + } + @CopilotTool("Search for items") + public String searchItems(@Param("Keyword") String keyword) { + return "found"; + } + @CopilotTool(value = "Custom grep", name = "grep") + public String grepOverride(@Param("Query") String query) { + return "result"; + } + } + """; + + CompilationResult result = compileWithProcessor(List.of(inMemorySource("test.MyTools", source))); + + assertNoErrors(result); + // Verify generated source contains the expected tool names + String generated = result.getGeneratedSource("test.MyTools$$CopilotToolMeta"); + assertTrue(generated != null, "Expected $$CopilotToolMeta to be generated"); + assertTrue(generated.contains("\"set_current_phase\""), "Expected snake_case name: set_current_phase"); + assertTrue(generated.contains("\"search_items\""), "Expected snake_case name: search_items"); + assertTrue(generated.contains("\"grep\""), "Expected explicit name: grep"); + } + + // ── Test: Compile error for private methods ───────────────────────────────── + + @Test + void emitsError_forPrivateMethods() { + String source = """ + package test; + import com.github.copilot.tool.CopilotTool; + public class PrivateTools { + @CopilotTool("Private tool") + private String doSomething() { + return "done"; + } + } + """; + + CompilationResult result = compileWithProcessor(List.of(inMemorySource("test.PrivateTools", source))); + + assertTrue(hasErrorContaining(result, "must not be private"), + "Expected compile error for private @CopilotTool method, got: " + result.diagnostics); + } + + // ── Test: Compile error for required + defaultValue conflict ───────────── + + @Test + void emitsError_forRequiredWithDefaultValue() { + String source = """ + package test; + import com.github.copilot.tool.CopilotTool; + import com.github.copilot.tool.Param; + public class ConflictTools { + @CopilotTool("Conflicting params") + public String doSomething(@Param(value = "desc", required = true, defaultValue = "hello") String param) { + return "done"; + } + } + """; + + CompilationResult result = compileWithProcessor(List.of(inMemorySource("test.ConflictTools", source))); + + assertTrue(hasErrorContaining(result, "required=true"), + "Expected compile error for required+defaultValue conflict, got: " + result.diagnostics); + } + + // ── Test: Return type handling ────────────────────────────────────────────── + + @Test + void generatesCorrectCode_forStringReturnType() { + String source = """ + package test; + import com.github.copilot.tool.CopilotTool; + import com.github.copilot.tool.Param; + public class StringReturn { + @CopilotTool("Returns string") + public String doSomething(@Param("Input") String input) { + return input; + } + } + """; + + CompilationResult result = compileWithProcessor(List.of(inMemorySource("test.StringReturn", source))); + assertNoErrors(result); + String generated = result.getGeneratedSource("test.StringReturn$$CopilotToolMeta"); + assertTrue(generated.contains("CompletableFuture.completedFuture(instance.doSomething("), + "Expected completedFuture wrapping for String return, got:\n" + generated); + } + + @Test + void generatesCorrectCode_forVoidReturnType() { + String source = """ + package test; + import com.github.copilot.tool.CopilotTool; + import com.github.copilot.tool.Param; + public class VoidReturn { + @CopilotTool("Void method") + public void doSomething(@Param("Input") String input) { + } + } + """; + + CompilationResult result = compileWithProcessor(List.of(inMemorySource("test.VoidReturn", source))); + assertNoErrors(result); + String generated = result.getGeneratedSource("test.VoidReturn$$CopilotToolMeta"); + assertTrue(generated.contains("instance.doSomething("), "Expected method call in generated code"); + assertTrue(generated.contains("CompletableFuture.completedFuture(\"Success\")"), + "Expected 'Success' return for void methods, got:\n" + generated); + } + + @Test + void generatesCorrectCode_forCompletableFutureStringReturnType() { + String source = """ + package test; + import com.github.copilot.tool.CopilotTool; + import com.github.copilot.tool.Param; + import java.util.concurrent.CompletableFuture; + public class AsyncReturn { + @CopilotTool("Async method") + public CompletableFuture doSomething(@Param("Input") String input) { + return CompletableFuture.completedFuture(input); + } + } + """; + + CompilationResult result = compileWithProcessor(List.of(inMemorySource("test.AsyncReturn", source))); + assertNoErrors(result); + String generated = result.getGeneratedSource("test.AsyncReturn$$CopilotToolMeta"); + assertTrue(generated.contains("return instance.doSomething("), + "Expected direct return for CompletableFuture, got:\n" + generated); + assertTrue(generated.contains("thenApply(r -> (Object) r)"), + "Expected thenApply cast for CompletableFuture, got:\n" + generated); + } + + @Test + void generatesCorrectCode_forIntReturnType() { + String source = """ + package test; + import com.github.copilot.tool.CopilotTool; + import com.github.copilot.tool.Param; + public class IntReturn { + @CopilotTool("Returns int") + public int doSomething(@Param("Input") String input) { + return 42; + } + } + """; + + CompilationResult result = compileWithProcessor(List.of(inMemorySource("test.IntReturn", source))); + assertNoErrors(result); + String generated = result.getGeneratedSource("test.IntReturn$$CopilotToolMeta"); + assertTrue(generated.contains("mapper.writeValueAsString(instance.doSomething("), + "Expected JSON serialization for int return type, got:\n" + generated); + } + + // ── Test: Argument coercion ───────────────────────────────────────────────── + + @Test + void generatesCorrectArgExtraction_forPrimitiveAndStringTypes() { + String source = """ + package test; + import com.github.copilot.tool.CopilotTool; + import com.github.copilot.tool.Param; + public class ArgTypes { + @CopilotTool("Mixed args") + public String doSomething( + @Param("Name") String name, + @Param("Count") int count, + @Param("Flag") boolean flag) { + return "done"; + } + } + """; + + CompilationResult result = compileWithProcessor(List.of(inMemorySource("test.ArgTypes", source))); + assertNoErrors(result); + String generated = result.getGeneratedSource("test.ArgTypes$$CopilotToolMeta"); + assertTrue(generated.contains("(String) args.get(\"name\")"), + "Expected String cast for String param, got:\n" + generated); + assertTrue(generated.contains("((Number) args.get(\"count\")).intValue()"), + "Expected Number cast for int param, got:\n" + generated); + assertTrue(generated.contains("(Boolean) args.get(\"flag\")"), + "Expected Boolean cast for boolean param, got:\n" + generated); + } + + // ── Test: snake_case conversion ───────────────────────────────────────────── + + @Test + void snakeCaseConversion() { + assertEquals("set_current_phase", CopilotToolProcessor.toSnakeCase("setCurrentPhase")); + assertEquals("search_items", CopilotToolProcessor.toSnakeCase("searchItems")); + assertEquals("grep", CopilotToolProcessor.toSnakeCase("grep")); + assertEquals("get_u_r_l", CopilotToolProcessor.toSnakeCase("getURL")); + assertEquals("a", CopilotToolProcessor.toSnakeCase("a")); + assertEquals("", CopilotToolProcessor.toSnakeCase("")); + } + + // ── Test: Processor registration ──────────────────────────────────────────── + + @Test + void processorIsRegisteredInMetaInfServices() throws Exception { + var resource = getClass().getClassLoader() + .getResource("META-INF/services/javax.annotation.processing.Processor"); + assertTrue(resource != null, "META-INF/services/javax.annotation.processing.Processor should exist"); + String content = new String(resource.openStream().readAllBytes()); + assertTrue(content.contains("com.github.copilot.tool.CopilotToolProcessor"), + "Service file should contain CopilotToolProcessor"); + } + + // ── Test: Schema generation in generated code ─────────────────────────────── + + @Test + void generatesCorrectSchema() { + String source = """ + package test; + import com.github.copilot.tool.CopilotTool; + import com.github.copilot.tool.Param; + public class SchemaTools { + @CopilotTool("Search items") + public String search( + @Param(value = "Query", required = true) String query, + @Param(value = "Limit", required = false) int limit) { + return "done"; + } + } + """; + + CompilationResult result = compileWithProcessor(List.of(inMemorySource("test.SchemaTools", source))); + assertNoErrors(result); + String generated = result.getGeneratedSource("test.SchemaTools$$CopilotToolMeta"); + // Verify the schema contains the expected keys + assertTrue(generated.contains("\"type\", \"object\""), "Expected object type in schema"); + assertTrue(generated.contains("\"properties\""), "Expected properties in schema"); + assertTrue(generated.contains("\"required\""), "Expected required in schema"); + assertTrue(generated.contains("\"query\""), "Expected query property"); + } + + // ── Test: Typed default values in schema ──────────────────────────────────── + + @Test + void emitsTypedDefaultValuesInSchema() { + String source = """ + package test; + import com.github.copilot.tool.CopilotTool; + import com.github.copilot.tool.Param; + public class DefaultTools { + @CopilotTool("Tool with defaults") + public String doWork( + @Param(value = "Limit", required = false, defaultValue = "10") int limit, + @Param(value = "Enabled", required = false, defaultValue = "true") boolean enabled, + @Param(value = "Label", required = false, defaultValue = "hello") String label) { + return "done"; + } + } + """; + + CompilationResult result = compileWithProcessor(List.of(inMemorySource("test.DefaultTools", source))); + assertNoErrors(result); + String generated = result.getGeneratedSource("test.DefaultTools$$CopilotToolMeta"); + assertNotNull(generated, "Expected generated source for DefaultTools$$CopilotToolMeta"); + + // Numeric default should be an unquoted literal, not a string + assertTrue(generated.contains("withMeta(") && generated.contains(", 10)"), + "Expected numeric default 10 as typed literal, not string. Generated:\n" + generated); + // Boolean default should be an unquoted literal + assertTrue(generated.contains(", true)"), + "Expected boolean default true as typed literal, not string. Generated:\n" + generated); + // String default should remain a quoted string + assertTrue(generated.contains(", \"hello\")"), + "Expected string default \"hello\" as quoted string. Generated:\n" + generated); + } + + // ── Test: package-private methods are allowed ─────────────────────────────── + + @Test + void allowsPackagePrivateMethods() { + String source = """ + package test; + import com.github.copilot.tool.CopilotTool; + public class PackagePrivateTools { + @CopilotTool("Package private tool") + String doSomething() { + return "done"; + } + } + """; + + CompilationResult result = compileWithProcessor(List.of(inMemorySource("test.PackagePrivateTools", source))); + assertNoErrors(result); + } + + // ── Test: protected methods are allowed ───────────────────────────────────── + + @Test + void allowsProtectedMethods() { + String source = """ + package test; + import com.github.copilot.tool.CopilotTool; + public class ProtectedTools { + @CopilotTool("Protected tool") + protected String doSomething() { + return "done"; + } + } + """; + + CompilationResult result = compileWithProcessor(List.of(inMemorySource("test.ProtectedTools", source))); + assertNoErrors(result); + } + + // ── Test: overridesBuiltInTool generates createOverride ───────────────────── + + @Test + void generatesCreateOverride_whenOverridesBuiltInTool() { + String source = """ + package test; + import com.github.copilot.tool.CopilotTool; + import com.github.copilot.tool.Param; + public class OverrideTools { + @CopilotTool(value = "Custom grep", name = "grep", overridesBuiltInTool = true) + public String grep(@Param("Query") String query) { + return "result"; + } + } + """; + + CompilationResult result = compileWithProcessor(List.of(inMemorySource("test.OverrideTools", source))); + assertNoErrors(result); + String generated = result.getGeneratedSource("test.OverrideTools$$CopilotToolMeta"); + assertTrue(generated.contains("new ToolDefinition("), "Expected record constructor, got:\n" + generated); + assertTrue(generated.contains("Boolean.TRUE"), + "Expected Boolean.TRUE for overridesBuiltInTool, got:\n" + generated); + } + + // ── Test: Combined flags all apply independently ──────────────────────────── + + @Test + void generatesCombinedFlags() { + String source = """ + package test; + import com.github.copilot.tool.CopilotTool; + import com.github.copilot.rpc.ToolDefer; + public class CombinedTools { + @CopilotTool(value = "Combined", overridesBuiltInTool = true, skipPermission = true, defer = ToolDefer.AUTO) + public String doAll() { + return "done"; + } + } + """; + + CompilationResult result = compileWithProcessor(List.of(inMemorySource("test.CombinedTools", source))); + assertNoErrors(result); + String generated = result.getGeneratedSource("test.CombinedTools$$CopilotToolMeta"); + assertNotNull(generated, "Expected generated source for CombinedTools$$CopilotToolMeta"); + assertTrue(generated.contains("new ToolDefinition("), "Expected record constructor, got:\n" + generated); + // All three flags must be present — not silently dropped + assertTrue(generated.contains("Boolean.TRUE"), + "Expected Boolean.TRUE for override/skipPermission, got:\n" + generated); + assertTrue(generated.contains("ToolDefer.AUTO"), "Expected ToolDefer.AUTO, got:\n" + generated); + // Count Boolean.TRUE occurrences — should be 2 (overridesBuiltInTool + + // skipPermission) + long boolCount = generated.lines().filter(l -> l.contains("Boolean.TRUE")).count(); + assertEquals(2, boolCount, + "Expected 2 Boolean.TRUE lines (overridesBuiltInTool + skipPermission), got:\n" + generated); + } + + // ── Test: ToolDefer.NONE results in regular create ────────────────────────── + + @Test + void generatesCreate_whenDeferIsNone() { + String source = """ + package test; + import com.github.copilot.tool.CopilotTool; + import com.github.copilot.rpc.ToolDefer; + public class DeferNoneTools { + @CopilotTool(value = "Simple tool", defer = ToolDefer.NONE) + public String doSomething() { + return "done"; + } + } + """; + + CompilationResult result = compileWithProcessor(List.of(inMemorySource("test.DeferNoneTools", source))); + assertNoErrors(result); + String generated = result.getGeneratedSource("test.DeferNoneTools$$CopilotToolMeta"); + assertTrue(generated.contains("new ToolDefinition("), + "Expected record constructor for NONE, got:\n" + generated); + assertFalse(generated.contains("ToolDefer."), "Should NOT reference ToolDefer for NONE, got:\n" + generated); + } + + // ── Test: ToolDefer.AUTO results in createWithDefer ────────────────────────── + + @Test + void generatesCreateWithDefer_whenDeferIsAuto() { + String source = """ + package test; + import com.github.copilot.tool.CopilotTool; + import com.github.copilot.rpc.ToolDefer; + public class DeferAutoTools { + @CopilotTool(value = "Deferrable tool", defer = ToolDefer.AUTO) + public String doSomething() { + return "done"; + } + } + """; + + CompilationResult result = compileWithProcessor(List.of(inMemorySource("test.DeferAutoTools", source))); + assertNoErrors(result); + String generated = result.getGeneratedSource("test.DeferAutoTools$$CopilotToolMeta"); + assertTrue(generated.contains("new ToolDefinition("), + "Expected record constructor for AUTO, got:\n" + generated); + assertTrue(generated.contains("ToolDefer.AUTO"), "Expected ToolDefer.AUTO argument, got:\n" + generated); + } + + // ── Test: Optional parameter extraction ───────────────────────────────────── + + @Test + void generatesCorrectOptionalExtraction() { + String source = """ + package test; + import com.github.copilot.tool.CopilotTool; + import com.github.copilot.tool.Param; + import java.util.Optional; + import java.util.OptionalInt; + import java.util.OptionalLong; + import java.util.OptionalDouble; + public class OptionalTools { + @CopilotTool("Tool with optional string") + public String withOptionalString(@Param("A name") Optional name) { + return name.orElse("default"); + } + @CopilotTool("Tool with optional int") + public String withOptionalInt(@Param("A count") OptionalInt count) { + return String.valueOf(count.orElse(0)); + } + @CopilotTool("Tool with optional long") + public String withOptionalLong(@Param("A timestamp") OptionalLong ts) { + return String.valueOf(ts.orElse(0L)); + } + @CopilotTool("Tool with optional double") + public String withOptionalDouble(@Param("A ratio") OptionalDouble ratio) { + return String.valueOf(ratio.orElse(0.0)); + } + } + """; + + CompilationResult result = compileWithProcessor(List.of(inMemorySource("test.OptionalTools", source))); + assertNoErrors(result); + String generated = result.getGeneratedSource("test.OptionalTools$$CopilotToolMeta"); + assertNotNull(generated, "Expected $$CopilotToolMeta to be generated"); + + // Optional should use null-check + Optional.of wrapping + assertTrue(generated.contains("Optional.of(") || generated.contains("java.util.Optional.of("), + "Expected Optional.of() wrapping for Optional, got:\n" + generated); + assertTrue(generated.contains("Optional.empty()") || generated.contains("java.util.Optional.empty()"), + "Expected Optional.empty() fallback, got:\n" + generated); + + // OptionalInt should use OptionalInt.of(((Number)...).intValue()) + assertTrue(generated.contains("OptionalInt.of(((Number)"), + "Expected OptionalInt.of(((Number)...).intValue()), got:\n" + generated); + assertTrue(generated.contains("OptionalInt.empty()"), + "Expected OptionalInt.empty() fallback, got:\n" + generated); + + // OptionalLong should use OptionalLong.of(((Number)...).longValue()) + assertTrue(generated.contains("OptionalLong.of(((Number)"), + "Expected OptionalLong.of(((Number)...).longValue()), got:\n" + generated); + assertTrue(generated.contains("OptionalLong.empty()"), + "Expected OptionalLong.empty() fallback, got:\n" + generated); + + // OptionalDouble should use OptionalDouble.of(((Number)...).doubleValue()) + assertTrue(generated.contains("OptionalDouble.of(((Number)"), + "Expected OptionalDouble.of(((Number)...).doubleValue()), got:\n" + generated); + assertTrue(generated.contains("OptionalDouble.empty()"), + "Expected OptionalDouble.empty() fallback, got:\n" + generated); + + // Should NOT use mapper.convertValue for Optional types + assertFalse(generated.contains("mapper.convertValue(args.get(\"name\"), java.util.Optional.class)"), + "Should NOT use mapper.convertValue for Optional, got:\n" + generated); + } + + // ── Helpers ───────────────────────────────────────────────────────────────── + + private CompilationResult compileWithProcessor(List sources) { + JavaCompiler compiler = ToolProvider.getSystemJavaCompiler(); + DiagnosticCollector diagnostics = new DiagnosticCollector<>(); + + String classpath = resolveClasspath(); + List options = new ArrayList<>(); + options.add("-proc:full"); + options.addAll(List.of("-processor", "com.github.copilot.tool.CopilotToolProcessor")); + options.addAll(List.of("-classpath", classpath)); + options.addAll(List.of("-d", tempDir.toString())); + options.addAll(List.of("-s", tempDir.toString())); + // Allow experimental APIs during test compilation + options.add("-Acopilot.experimental.allowed=true"); + + try (StandardJavaFileManager fileManager = compiler.getStandardFileManager(diagnostics, null, null)) { + fileManager.setLocation(StandardLocation.SOURCE_OUTPUT, List.of(tempDir.toFile())); + fileManager.setLocation(StandardLocation.CLASS_OUTPUT, List.of(tempDir.toFile())); + CollectingFileManager collectingFileManager = new CollectingFileManager(fileManager); + + JavaCompiler.CompilationTask task = compiler.getTask(null, collectingFileManager, diagnostics, options, + null, sources); + task.call(); + + List generatedSources = collectingFileManager.getGeneratedSources(); + if (generatedSources.isEmpty()) { + // Fallback for file-manager implementations that only materialize on disk. + collectGeneratedFiles(tempDir, generatedSources); + } + + return new CompilationResult(diagnostics.getDiagnostics(), generatedSources, tempDir); + } catch (Exception e) { + throw new RuntimeException("Compilation setup failed", e); + } + } + + private void collectGeneratedFiles(java.nio.file.Path dir, List files) { + try (var stream = java.nio.file.Files.walk(dir)) { + stream.filter(p -> p.toString().endsWith(".java")).forEach(p -> { + try { + files.add(java.nio.file.Files.readString(p)); + } catch (java.io.IOException e) { + // ignore read errors for generated file collection + } + }); + } catch (java.io.IOException e) { + // ignore walk errors + } + } + + private static String resolveClasspath() { + // Collect classpath entries from CodeSource of key classes needed for + // compiling both the source and the generated $$CopilotToolMeta code. + Set paths = new LinkedHashSet<>(); + + // Add system classpath entries (may include manifest-only jars) + String systemCp = System.getProperty("java.class.path", ""); + if (!systemCp.isEmpty()) { + for (String p : systemCp.split(java.util.regex.Pattern.quote(File.pathSeparator))) { + if (!p.isEmpty()) { + paths.add(p); + } + } + } + + // Also resolve CodeSource paths for key classes (SDK + Jackson + RPC types) + Class[] keyClasses = {CopilotTool.class, com.fasterxml.jackson.databind.ObjectMapper.class, + com.fasterxml.jackson.core.JsonFactory.class, com.fasterxml.jackson.annotation.JsonProperty.class, + com.github.copilot.rpc.ToolDefinition.class}; + for (Class cls : keyClasses) { + try { + CodeSource cs = cls.getProtectionDomain().getCodeSource(); + if (cs != null && cs.getLocation() != null) { + paths.add(Path.of(cs.getLocation().toURI()).toString()); + } + } catch (Exception e) { + // skip this class + } + } + + return paths.isEmpty() ? "." : String.join(File.pathSeparator, paths); + } + + private static JavaFileObject inMemorySource(String className, String code) { + return new SimpleJavaFileObject(URI.create("string:///" + className.replace('.', '/') + ".java"), + JavaFileObject.Kind.SOURCE) { + @Override + public CharSequence getCharContent(boolean ignoreEncodingErrors) { + return code; + } + }; + } + + private static void assertNoErrors(CompilationResult result) { + List> errors = result.diagnostics.stream() + .filter(d -> d.getKind() == Diagnostic.Kind.ERROR).toList(); + assertTrue(errors.isEmpty(), "Expected no errors, got: " + errors); + } + + private static boolean hasErrorContaining(CompilationResult result, String substring) { + return result.diagnostics.stream() + .anyMatch(d -> d.getKind() == Diagnostic.Kind.ERROR && d.getMessage(null).contains(substring)); + } + + private static class CompilationResult { + final List> diagnostics; + final List generatedSources; + final java.nio.file.Path outputDir; + + CompilationResult(List> diagnostics, List generatedSources, + java.nio.file.Path outputDir) { + this.diagnostics = diagnostics; + this.generatedSources = generatedSources; + this.outputDir = outputDir; + } + + String getGeneratedSource(String qualifiedName) { + String fileName = qualifiedName.replace('.', '/') + ".java"; + java.nio.file.Path filePath = outputDir.resolve(fileName); + try { + if (java.nio.file.Files.exists(filePath)) { + return java.nio.file.Files.readString(filePath); + } + } catch (java.io.IOException e) { + // fall through + } + // Also check in collected sources + String simpleName = qualifiedName.substring(qualifiedName.lastIndexOf('.') + 1); + for (String source : generatedSources) { + if (source.contains("class " + simpleName)) { + return source; + } + } + return null; + } + } + + private static class CollectingFileManager extends ForwardingJavaFileManager { + private final Map generatedByClass = new LinkedHashMap<>(); + + CollectingFileManager(StandardJavaFileManager fileManager) { + super(fileManager); + } + + @Override + public JavaFileObject getJavaFileForOutput(Location location, String className, JavaFileObject.Kind kind, + FileObject sibling) throws IOException { + JavaFileObject delegate = super.getJavaFileForOutput(location, className, kind, sibling); + if (kind != JavaFileObject.Kind.SOURCE) { + return delegate; + } + StringBuilder captured = new StringBuilder(); + generatedByClass.put(className, captured); + return new ForwardingJavaFileObject<>(delegate) { + @Override + public Writer openWriter() throws IOException { + Writer target = delegate.openWriter(); + return new FilterWriter(target) { + @Override + public void write(char[] cbuf, int off, int len) throws IOException { + captured.append(cbuf, off, len); + super.write(cbuf, off, len); + } + + @Override + public void write(int c) throws IOException { + captured.append((char) c); + super.write(c); + } + + @Override + public void write(String str, int off, int len) throws IOException { + captured.append(str, off, off + len); + super.write(str, off, len); + } + }; + } + }; + } + + List getGeneratedSources() { + return generatedByClass.values().stream().map(StringBuilder::toString).toList(); + } + } +} diff --git a/java/src/test/java/com/github/copilot/tool/SchemaGeneratorTest.java b/java/src/test/java/com/github/copilot/tool/SchemaGeneratorTest.java new file mode 100644 index 000000000..00bb1d969 --- /dev/null +++ b/java/src/test/java/com/github/copilot/tool/SchemaGeneratorTest.java @@ -0,0 +1,762 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +package com.github.copilot.tool; + +import static org.junit.jupiter.api.Assertions.*; + +import java.io.IOException; +import java.net.URI; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Set; + +import javax.annotation.processing.AbstractProcessor; +import javax.annotation.processing.ProcessingEnvironment; +import javax.annotation.processing.RoundEnvironment; +import javax.annotation.processing.SupportedAnnotationTypes; +import javax.annotation.processing.SupportedSourceVersion; +import javax.lang.model.SourceVersion; +import javax.lang.model.element.Element; +import javax.lang.model.element.ElementKind; +import javax.lang.model.element.ExecutableElement; +import javax.lang.model.element.TypeElement; +import javax.lang.model.element.VariableElement; +import javax.lang.model.type.TypeMirror; +import javax.lang.model.util.Elements; +import javax.lang.model.util.Types; +import javax.tools.DiagnosticCollector; +import javax.tools.JavaCompiler; +import javax.tools.JavaFileObject; +import javax.tools.SimpleJavaFileObject; +import javax.tools.StandardJavaFileManager; +import javax.tools.StandardLocation; +import javax.tools.ToolProvider; + +import org.junit.jupiter.api.Test; + +/** + * Tests for {@link SchemaGenerator} using the compilation-testing approach. A + * test annotation processor exercises SchemaGenerator during compilation of + * small source snippets. + */ +public class SchemaGeneratorTest { + + /** + * In-memory Java source file for compilation testing. + */ + private static class InMemorySource extends SimpleJavaFileObject { + + private final String code; + + InMemorySource(String className, String code) { + super(URI.create("string:///" + className.replace('.', '/') + Kind.SOURCE.extension), Kind.SOURCE); + this.code = code; + } + + @Override + public CharSequence getCharContent(boolean ignoreEncodingErrors) throws IOException { + return code; + } + } + + /** + * Test processor that captures schema generation results. + */ + @SupportedAnnotationTypes("*") + @SupportedSourceVersion(SourceVersion.RELEASE_17) + public static class SchemaCapturingProcessor extends AbstractProcessor { + + static final List capturedSchemas = new ArrayList<>(); + static final List capturedParameterSchemas = new ArrayList<>(); + + private Types typeUtils; + private Elements elementUtils; + + @Override + public synchronized void init(ProcessingEnvironment processingEnv) { + super.init(processingEnv); + this.typeUtils = processingEnv.getTypeUtils(); + this.elementUtils = processingEnv.getElementUtils(); + } + + @Override + public boolean process(Set annotations, RoundEnvironment roundEnv) { + if (roundEnv.processingOver()) { + return false; + } + + SchemaGenerator generator = new SchemaGenerator(); + + for (Element rootElement : roundEnv.getRootElements()) { + if (rootElement.getKind() == ElementKind.CLASS || rootElement.getKind() == ElementKind.RECORD + || rootElement.getKind() == ElementKind.INTERFACE + || rootElement.getKind() == ElementKind.ENUM) { + // Find methods named "schemaTarget" to capture schemas for their return type + for (Element enclosed : rootElement.getEnclosedElements()) { + if (enclosed.getKind() == ElementKind.METHOD) { + ExecutableElement method = (ExecutableElement) enclosed; + String methodName = method.getSimpleName().toString(); + if (methodName.startsWith("schemaTarget")) { + TypeMirror returnType = method.getReturnType(); + String schema = generator.generateSchemaSource(returnType, typeUtils, elementUtils); + capturedSchemas.add(methodName + "=" + schema); + } + if ("parametersTarget".equals(methodName)) { + List params = method.getParameters(); + String schema = generator.generateParametersSchemaSource(params, typeUtils, + elementUtils); + capturedParameterSchemas.add(schema); + } + } + } + + // For record/enum types, generate schema for the type itself + TypeElement typeElement = (TypeElement) rootElement; + String typeName = typeElement.getSimpleName().toString(); + if (typeName.startsWith("TestRecord") || typeName.startsWith("TestEnum") + || typeName.startsWith("TestSealed")) { + String schema = generator.generateSchemaSource(typeElement.asType(), typeUtils, elementUtils); + capturedSchemas.add(typeName + "=" + schema); + } + } + } + + return false; + } + } + + private static final Path CLASS_OUTPUT_DIR = Path.of("target", "test-schema-classes"); + + /** + * Creates a StandardJavaFileManager that writes compiled .class files to + * target/test-schema-classes/ instead of the working directory. + */ + private StandardJavaFileManager createFileManager(JavaCompiler compiler, + DiagnosticCollector diagnostics) throws IOException { + Files.createDirectories(CLASS_OUTPUT_DIR); + StandardJavaFileManager fm = compiler.getStandardFileManager(diagnostics, null, null); + fm.setLocation(StandardLocation.CLASS_OUTPUT, List.of(CLASS_OUTPUT_DIR.toFile())); + return fm; + } + + private List compileAndCapture(String... sources) { + return compileAndCapture(Arrays.asList(sources)); + } + + private List compileAndCapture(List sourceTexts) { + SchemaCapturingProcessor.capturedSchemas.clear(); + SchemaCapturingProcessor.capturedParameterSchemas.clear(); + + JavaCompiler compiler = ToolProvider.getSystemJavaCompiler(); + assertNotNull(compiler, "System Java compiler not available"); + + DiagnosticCollector diagnostics = new DiagnosticCollector<>(); + + List compilationUnits = new ArrayList<>(); + for (String sourceText : sourceTexts) { + // Extract class name from source + String className = extractClassName(sourceText); + compilationUnits.add(new InMemorySource(className, sourceText)); + } + + try (StandardJavaFileManager fm = createFileManager(compiler, diagnostics)) { + // Compile with the processor on classpath + JavaCompiler.CompilationTask task = compiler.getTask(null, // writer + fm, // file manager + diagnostics, // diagnostics + List.of("--add-modules", "ALL-MODULE-PATH"), // options + null, // annotation classes + compilationUnits); + + task.setProcessors(List.of(new SchemaCapturingProcessor())); + boolean success = task.call(); + + if (!success) { + // Try without module options for simpler environments + diagnostics = new DiagnosticCollector<>(); + try (StandardJavaFileManager fm2 = createFileManager(compiler, diagnostics)) { + task = compiler.getTask(null, fm2, diagnostics, null, null, compilationUnits); + task.setProcessors(List.of(new SchemaCapturingProcessor())); + success = task.call(); + } + } + + assertTrue(success, "Compilation failed: " + diagnostics.getDiagnostics()); + } catch (IOException e) { + fail("Failed to create file manager: " + e.getMessage()); + } + return new ArrayList<>(SchemaCapturingProcessor.capturedSchemas); + } + + private List compileAndCaptureParams(String source) { + SchemaCapturingProcessor.capturedSchemas.clear(); + SchemaCapturingProcessor.capturedParameterSchemas.clear(); + + JavaCompiler compiler = ToolProvider.getSystemJavaCompiler(); + assertNotNull(compiler, "System Java compiler not available"); + + DiagnosticCollector diagnostics = new DiagnosticCollector<>(); + + String className = extractClassName(source); + List compilationUnits = List.of(new InMemorySource(className, source)); + + try (StandardJavaFileManager fm = createFileManager(compiler, diagnostics)) { + JavaCompiler.CompilationTask task = compiler.getTask(null, fm, diagnostics, null, null, compilationUnits); + task.setProcessors(List.of(new SchemaCapturingProcessor())); + boolean success = task.call(); + + assertTrue(success, "Compilation failed: " + diagnostics.getDiagnostics()); + } catch (IOException e) { + fail("Failed to create file manager: " + e.getMessage()); + } + return new ArrayList<>(SchemaCapturingProcessor.capturedParameterSchemas); + } + + private String extractClassName(String source) { + // Simple extraction: find "class X", "record X", "enum X", or "interface X" + for (String keyword : new String[]{"class ", "record ", "enum ", "interface "}) { + int idx = source.indexOf(keyword); + if (idx >= 0) { + int start = idx + keyword.length(); + int end = start; + while (end < source.length() && Character.isJavaIdentifierPart(source.charAt(end))) { + end++; + } + return source.substring(start, end); + } + } + return "Unknown"; + } + + // --- Type mapping tests --- + + @Test + void stringType() { + String source = """ + public class TestStringHolder { + public String schemaTargetString() { return null; } + } + """; + List schemas = compileAndCapture(source); + assertContainsSchema(schemas, "schemaTargetString", "Map.of(\"type\", \"string\")"); + } + + @Test + void intPrimitiveType() { + String source = """ + public class TestIntHolder { + public int schemaTargetInt() { return 0; } + } + """; + List schemas = compileAndCapture(source); + assertContainsSchema(schemas, "schemaTargetInt", "Map.of(\"type\", \"integer\")"); + } + + @Test + void integerBoxedType() { + String source = """ + public class TestIntegerHolder { + public Integer schemaTargetInteger() { return null; } + } + """; + List schemas = compileAndCapture(source); + assertContainsSchema(schemas, "schemaTargetInteger", "Map.of(\"type\", \"integer\")"); + } + + @Test + void longType() { + String source = """ + public class TestLongHolder { + public long schemaTargetLong() { return 0L; } + } + """; + List schemas = compileAndCapture(source); + assertContainsSchema(schemas, "schemaTargetLong", "Map.of(\"type\", \"integer\")"); + } + + @Test + void doubleType() { + String source = """ + public class TestDoubleHolder { + public double schemaTargetDouble() { return 0.0; } + } + """; + List schemas = compileAndCapture(source); + assertContainsSchema(schemas, "schemaTargetDouble", "Map.of(\"type\", \"number\")"); + } + + @Test + void floatType() { + String source = """ + public class TestFloatHolder { + public float schemaTargetFloat() { return 0.0f; } + } + """; + List schemas = compileAndCapture(source); + assertContainsSchema(schemas, "schemaTargetFloat", "Map.of(\"type\", \"number\")"); + } + + @Test + void booleanPrimitiveType() { + String source = """ + public class TestBooleanHolder { + public boolean schemaTargetBoolean() { return false; } + } + """; + List schemas = compileAndCapture(source); + assertContainsSchema(schemas, "schemaTargetBoolean", "Map.of(\"type\", \"boolean\")"); + } + + @Test + void booleanBoxedType() { + String source = """ + public class TestBooleanBoxedHolder { + public Boolean schemaTargetBooleanBoxed() { return null; } + } + """; + List schemas = compileAndCapture(source); + assertContainsSchema(schemas, "schemaTargetBooleanBoxed", "Map.of(\"type\", \"boolean\")"); + } + + @Test + void byteBoxedType() { + String source = """ + public class TestByteHolder { + public Byte schemaTargetByte() { return null; } + } + """; + List schemas = compileAndCapture(source); + assertContainsSchema(schemas, "schemaTargetByte", "Map.of(\"type\", \"integer\")"); + } + + @Test + void shortBoxedType() { + String source = """ + public class TestShortHolder { + public Short schemaTargetShort() { return null; } + } + """; + List schemas = compileAndCapture(source); + assertContainsSchema(schemas, "schemaTargetShort", "Map.of(\"type\", \"integer\")"); + } + + @Test + void characterBoxedType() { + String source = """ + public class TestCharHolder { + public Character schemaTargetChar() { return null; } + } + """; + List schemas = compileAndCapture(source); + assertContainsSchema(schemas, "schemaTargetChar", "Map.of(\"type\", \"string\")"); + } + + @Test + void stringArrayType() { + String source = """ + public class TestArrayHolder { + public String[] schemaTargetArray() { return null; } + } + """; + List schemas = compileAndCapture(source); + assertContainsSchema(schemas, "schemaTargetArray", + "Map.of(\"type\", \"array\", \"items\", Map.of(\"type\", \"string\"))"); + } + + @Test + void enumType() { + String source = """ + public enum TestEnumColor { RED, GREEN, BLUE } + """; + List schemas = compileAndCapture(source); + assertContainsSchema(schemas, "TestEnumColor", + "Map.of(\"type\", \"string\", \"enum\", List.of(\"RED\", \"GREEN\", \"BLUE\"))"); + } + + @Test + void listOfStringType() { + String source = """ + import java.util.List; + public class TestListHolder { + public List schemaTargetList() { return null; } + } + """; + List schemas = compileAndCapture(source); + assertContainsSchema(schemas, "schemaTargetList", + "Map.of(\"type\", \"array\", \"items\", Map.of(\"type\", \"string\"))"); + } + + @Test + void mapStringStringType() { + String source = """ + import java.util.Map; + public class TestMapHolder { + public Map schemaTargetMap() { return null; } + } + """; + List schemas = compileAndCapture(source); + assertContainsSchema(schemas, "schemaTargetMap", + "Map.of(\"type\", \"object\", \"additionalProperties\", Map.of(\"type\", \"string\"))"); + } + + @Test + void mapStringObjectType() { + String source = """ + import java.util.Map; + public class TestMapObjectHolder { + public Map schemaTargetMapObject() { return null; } + } + """; + List schemas = compileAndCapture(source); + assertContainsSchema(schemas, "schemaTargetMapObject", "Map.of(\"type\", \"object\")"); + } + + @Test + void mapStringBooleanType() { + String source = """ + import java.util.Map; + public class TestMapBoolHolder { + public Map schemaTargetMapBool() { return null; } + } + """; + List schemas = compileAndCapture(source); + assertContainsSchema(schemas, "schemaTargetMapBool", + "Map.of(\"type\", \"object\", \"additionalProperties\", Map.of(\"type\", \"boolean\"))"); + } + + @Test + void mapStringLongType() { + String source = """ + import java.util.Map; + public class TestMapLongHolder { + public Map schemaTargetMapLong() { return null; } + } + """; + List schemas = compileAndCapture(source); + assertContainsSchema(schemas, "schemaTargetMapLong", + "Map.of(\"type\", \"object\", \"additionalProperties\", Map.of(\"type\", \"integer\"))"); + } + + @Test + void optionalStringType() { + String source = """ + import java.util.Optional; + public class TestOptionalHolder { + public Optional schemaTargetOptional() { return null; } + } + """; + List schemas = compileAndCapture(source); + assertContainsSchema(schemas, "schemaTargetOptional", "Map.of(\"type\", \"string\")"); + } + + @Test + void optionalIntType() { + String source = """ + import java.util.OptionalInt; + public class TestOptionalIntHolder { + public OptionalInt schemaTargetOptionalInt() { return null; } + } + """; + List schemas = compileAndCapture(source); + assertContainsSchema(schemas, "schemaTargetOptionalInt", "Map.of(\"type\", \"integer\")"); + } + + @Test + void optionalLongType() { + String source = """ + import java.util.OptionalLong; + public class TestOptionalLongHolder { + public OptionalLong schemaTargetOptionalLong() { return null; } + } + """; + List schemas = compileAndCapture(source); + assertContainsSchema(schemas, "schemaTargetOptionalLong", "Map.of(\"type\", \"integer\")"); + } + + @Test + void optionalDoubleType() { + String source = """ + import java.util.OptionalDouble; + public class TestOptionalDoubleHolder { + public OptionalDouble schemaTargetOptionalDouble() { return null; } + } + """; + List schemas = compileAndCapture(source); + assertContainsSchema(schemas, "schemaTargetOptionalDouble", "Map.of(\"type\", \"number\")"); + } + + @Test + void uuidType() { + String source = """ + import java.util.UUID; + public class TestUuidHolder { + public UUID schemaTargetUuid() { return null; } + } + """; + List schemas = compileAndCapture(source); + assertContainsSchema(schemas, "schemaTargetUuid", "Map.of(\"type\", \"string\", \"format\", \"uuid\")"); + } + + @Test + void offsetDateTimeType() { + String source = """ + import java.time.OffsetDateTime; + public class TestDateTimeHolder { + public OffsetDateTime schemaTargetDateTime() { return null; } + } + """; + List schemas = compileAndCapture(source); + assertContainsSchema(schemas, "schemaTargetDateTime", + "Map.of(\"type\", \"string\", \"format\", \"date-time\")"); + } + + @Test + void localDateTimeType() { + String source = """ + import java.time.LocalDateTime; + public class TestLocalDateTimeHolder { + public LocalDateTime schemaTargetLocalDateTime() { return null; } + } + """; + List schemas = compileAndCapture(source); + assertContainsSchema(schemas, "schemaTargetLocalDateTime", + "Map.of(\"type\", \"string\", \"format\", \"date-time\")"); + } + + @Test + void instantType() { + String source = """ + import java.time.Instant; + public class TestInstantHolder { + public Instant schemaTargetInstant() { return null; } + } + """; + List schemas = compileAndCapture(source); + assertContainsSchema(schemas, "schemaTargetInstant", "Map.of(\"type\", \"string\", \"format\", \"date-time\")"); + } + + @Test + void zonedDateTimeType() { + String source = """ + import java.time.ZonedDateTime; + public class TestZonedDateTimeHolder { + public ZonedDateTime schemaTargetZonedDateTime() { return null; } + } + """; + List schemas = compileAndCapture(source); + assertContainsSchema(schemas, "schemaTargetZonedDateTime", + "Map.of(\"type\", \"string\", \"format\", \"date-time\")"); + } + + @Test + void localDateType() { + String source = """ + import java.time.LocalDate; + public class TestLocalDateHolder { + public LocalDate schemaTargetLocalDate() { return null; } + } + """; + List schemas = compileAndCapture(source); + assertContainsSchema(schemas, "schemaTargetLocalDate", "Map.of(\"type\", \"string\", \"format\", \"date\")"); + } + + @Test + void localTimeType() { + String source = """ + import java.time.LocalTime; + public class TestLocalTimeHolder { + public LocalTime schemaTargetLocalTime() { return null; } + } + """; + List schemas = compileAndCapture(source); + assertContainsSchema(schemas, "schemaTargetLocalTime", "Map.of(\"type\", \"string\", \"format\", \"time\")"); + } + + @Test + void recordType() { + String source = """ + public record TestRecordPerson(String name, int age, boolean active) {} + """; + List schemas = compileAndCapture(source); + String expected = "Map.of(\"type\", \"object\", \"properties\", " + + "Map.ofEntries(Map.entry(\"name\", Map.of(\"type\", \"string\")), " + + "Map.entry(\"age\", Map.of(\"type\", \"integer\")), " + + "Map.entry(\"active\", Map.of(\"type\", \"boolean\"))), " + + "\"required\", List.of(\"name\", \"age\", \"active\"))"; + assertContainsSchema(schemas, "TestRecordPerson", expected); + } + + @Test + void recordWithOptionalField() { + String source = """ + import java.util.Optional; + public record TestRecordWithOptional(String name, Optional nickname) {} + """; + List schemas = compileAndCapture(source); + String expected = "Map.of(\"type\", \"object\", \"properties\", " + + "Map.ofEntries(Map.entry(\"name\", Map.of(\"type\", \"string\")), " + + "Map.entry(\"nickname\", Map.of(\"type\", \"string\"))), " + "\"required\", List.of(\"name\"))"; + assertContainsSchema(schemas, "TestRecordWithOptional", expected); + } + + @Test + void recordWithMoreThanTenFields() { + String source = """ + public record TestRecordLarge( + String f1, String f2, String f3, String f4, String f5, + String f6, String f7, String f8, String f9, String f10, + String f11) {} + """; + List schemas = compileAndCapture(source); + // Verify the schema contains all 11 fields and uses Map.ofEntries + String schema = schemas.stream().filter(s -> s.startsWith("TestRecordLarge=")).findFirst().orElse(""); + assertFalse(schema.isEmpty(), "Expected schema for TestRecordLarge"); + assertTrue(schema.contains("Map.ofEntries("), "Should use Map.ofEntries for >10 fields: " + schema); + assertTrue(schema.contains("Map.entry(\"f1\""), "Should have f1: " + schema); + assertTrue(schema.contains("Map.entry(\"f11\""), "Should have f11: " + schema); + // Verify the generated source expression is compilable by re-compiling it + String schemaExpr = schema.substring(schema.indexOf('=') + 1); + String validationSource = "import java.util.Map;\nimport java.util.List;\n" + + "public class LargeRecordValidation {\n" + " @SuppressWarnings(\"unchecked\")\n" + + " public Object schema() { return " + schemaExpr + "; }\n}\n"; + JavaCompiler compiler = ToolProvider.getSystemJavaCompiler(); + DiagnosticCollector diagnostics = new DiagnosticCollector<>(); + List units = List.of(new InMemorySource("LargeRecordValidation", validationSource)); + try (StandardJavaFileManager fm = createFileManager(compiler, diagnostics)) { + JavaCompiler.CompilationTask task = compiler.getTask(null, fm, diagnostics, null, null, units); + boolean success = task.call(); + assertTrue(success, "Generated schema for >10-field record does not compile: " + + diagnostics.getDiagnostics() + "\nSource:\n" + validationSource); + } catch (IOException e) { + fail("Failed to create file manager: " + e.getMessage()); + } + } + + @Test + void parametersSchema() { + String source = """ + public class TestParamsHolder { + public void parametersTarget(String query, int limit, boolean verbose) {} + } + """; + List paramSchemas = compileAndCaptureParams(source); + assertFalse(paramSchemas.isEmpty(), "Expected parameter schemas"); + String schema = paramSchemas.get(0); + assertTrue(schema.contains("\"type\", \"object\""), "Should be object type: " + schema); + assertTrue(schema.contains("Map.entry(\"query\", Map.of(\"type\", \"string\"))"), + "Should have query property: " + schema); + assertTrue(schema.contains("Map.entry(\"limit\", Map.of(\"type\", \"integer\"))"), + "Should have limit property: " + schema); + assertTrue(schema.contains("Map.entry(\"verbose\", Map.of(\"type\", \"boolean\"))"), + "Should have verbose property: " + schema); + assertTrue(schema.contains("\"required\", List.of("), "Should have required list: " + schema); + } + + @Test + void generatedSourceIsValidJava() { + // Verify that generated schema source code compiles when embedded in a method + // body + String source = """ + import java.util.List; + import java.util.Map; + import java.util.Optional; + public class TestValidJavaHolder { + public String schemaTargetStr() { return null; } + public List schemaTargetListStr() { return null; } + public Map schemaTargetMapStr() { return null; } + public Optional schemaTargetOpt() { return null; } + } + """; + List schemas = compileAndCapture(source); + assertFalse(schemas.isEmpty()); + + // Build a Java source that uses the generated schema expressions + StringBuilder validationSource = new StringBuilder(); + validationSource.append("import java.util.Map;\n"); + validationSource.append("import java.util.List;\n"); + validationSource.append("public class SchemaValidation {\n"); + validationSource.append(" @SuppressWarnings(\"unchecked\")\n"); + validationSource.append(" public void validate() {\n"); + for (int i = 0; i < schemas.size(); i++) { + String schema = schemas.get(i); + String schemaExpr = schema.substring(schema.indexOf('=') + 1); + validationSource.append(" Object s" + i + " = " + schemaExpr + ";\n"); + } + validationSource.append(" }\n"); + validationSource.append("}\n"); + + // Compile the validation source to verify syntactic validity + JavaCompiler compiler = ToolProvider.getSystemJavaCompiler(); + DiagnosticCollector diagnostics = new DiagnosticCollector<>(); + List compilationUnits = List + .of(new InMemorySource("SchemaValidation", validationSource.toString())); + + try (StandardJavaFileManager fm = createFileManager(compiler, diagnostics)) { + JavaCompiler.CompilationTask task = compiler.getTask(null, fm, diagnostics, null, null, compilationUnits); + boolean success = task.call(); + + assertTrue(success, "Generated schema source code is not valid Java: " + diagnostics.getDiagnostics() + + "\nSource:\n" + validationSource); + } catch (IOException e) { + fail("Failed to create file manager: " + e.getMessage()); + } + } + + @Test + void nestedMapListType() { + String source = """ + import java.util.List; + import java.util.Map; + public class TestNestedHolder { + public Map> schemaTargetNestedMap() { return null; } + } + """; + List schemas = compileAndCapture(source); + String expected = "Map.of(\"type\", \"object\", \"additionalProperties\", " + + "Map.of(\"type\", \"array\", \"items\", Map.of(\"type\", \"string\")))"; + assertContainsSchema(schemas, "schemaTargetNestedMap", expected); + } + + @Test + void objectType() { + String source = """ + public class TestObjectHolder { + public Object schemaTargetObject() { return null; } + } + """; + List schemas = compileAndCapture(source); + assertContainsSchema(schemas, "schemaTargetObject", "Map.of()"); + } + + @Test + void sealedInterfaceType() { + String sealedInterface = """ + public sealed interface TestSealedShape permits TestSealedCircle, TestSealedRect {} + """; + String circle = """ + public record TestSealedCircle(double radius) implements TestSealedShape {} + """; + String rect = """ + public record TestSealedRect(double width, double height) implements TestSealedShape {} + """; + List schemas = compileAndCapture(sealedInterface, circle, rect); + String expected = "Map.of(\"oneOf\", List.of(" + "Map.of(\"type\", \"object\", \"properties\", " + + "Map.ofEntries(Map.entry(\"radius\", Map.of(\"type\", \"number\"))), " + + "\"required\", List.of(\"radius\")), " + "Map.of(\"type\", \"object\", \"properties\", " + + "Map.ofEntries(Map.entry(\"width\", Map.of(\"type\", \"number\")), " + + "Map.entry(\"height\", Map.of(\"type\", \"number\"))), " + + "\"required\", List.of(\"width\", \"height\"))))"; + assertContainsSchema(schemas, "TestSealedShape", expected); + } + + private void assertContainsSchema(List schemas, String methodName, String expectedSchema) { + String expected = methodName + "=" + expectedSchema; + assertTrue(schemas.stream().anyMatch(s -> s.equals(expected)), + "Expected schema '" + expected + "' not found in: " + schemas); + } +} diff --git a/test/snapshots/session/should_abort_a_session.yaml b/test/snapshots/session/should_abort_a_session.yaml index dbbbd32aa..f1217f7f6 100644 --- a/test/snapshots/session/should_abort_a_session.yaml +++ b/test/snapshots/session/should_abort_a_session.yaml @@ -50,3 +50,31 @@ conversations: content: What is 2+2? - role: assistant content: 2 + 2 = 4 + - messages: + - role: system + content: ${system} + - role: user + content: run the shell command 'sleep 100' (note this works on both bash and PowerShell) + - role: assistant + content: I'll run the sleep command for 100 seconds. + tool_calls: + - id: toolcall_0 + type: function + function: + name: report_intent + arguments: '{"intent":"Running sleep command"}' + - id: toolcall_1 + type: function + function: + name: ${shell} + arguments: '{"command":"sleep 100","description":"Run sleep 100 command","mode":"sync","initial_wait":105}' + - role: tool + tool_call_id: toolcall_0 + content: The execution of this tool, or a previous tool was interrupted. + - role: tool + tool_call_id: toolcall_1 + content: The execution of this tool, or a previous tool was interrupted. + - role: user + content: What is 2+2? + - role: assistant + content: 2 + 2 = 4 diff --git a/test/snapshots/tools/ergonomic_tool_definition.yaml b/test/snapshots/tools/ergonomic_tool_definition.yaml new file mode 100644 index 000000000..ebb05ce1b --- /dev/null +++ b/test/snapshots/tools/ergonomic_tool_definition.yaml @@ -0,0 +1,33 @@ +models: + - claude-sonnet-4.5 +conversations: + - messages: + - role: system + content: ${system} + - role: user + content: + First, set the current phase to 'analyzing'. Then search for items with keyword 'copilot'. Report the phase and + search results. + - role: assistant + content: I'll set the phase and run the search now. + tool_calls: + - id: toolcall_0 + type: function + function: + name: set_current_phase + arguments: '{"phase":"analyzing"}' + - id: toolcall_1 + type: function + function: + name: search_items + arguments: '{"keyword":"copilot"}' + - role: tool + tool_call_id: toolcall_0 + content: Phase set to analyzing + - role: tool + tool_call_id: toolcall_1 + content: "Found: copilot -> item_alpha, item_beta" + - role: assistant + content: |- + Current phase: analyzing + Search results: item_alpha, item_beta