diff --git a/arkoala-arkts/libarkts/src/arkts-api/factory/nodeFactory.ts b/arkoala-arkts/libarkts/src/arkts-api/factory/nodeFactory.ts index eb99b8552f94b38f3023d8e0a8ad20f103678ec6..fb5e4f0ef0e65db21ddaa02040e385b9009a3607 100644 --- a/arkoala-arkts/libarkts/src/arkts-api/factory/nodeFactory.ts +++ b/arkoala-arkts/libarkts/src/arkts-api/factory/nodeFactory.ts @@ -33,6 +33,7 @@ import { ExpressionStatement, FunctionDeclaration, FunctionExpression, + FunctionSignature, Identifier, IfStatement, ImportSpecifier, @@ -300,4 +301,7 @@ export const factory = { get updateAnnotationUsageIr() { return compose(UndefinedLiteral.create) }, + get createFunctionSignature() { + return FunctionSignature.create + }, } diff --git a/arkoala-arkts/libarkts/src/arkts-api/types.ts b/arkoala-arkts/libarkts/src/arkts-api/types.ts index b04cc52b5562182b01ffa4c8b6dfa6ac8ccae7b4..7ead28e935e3133d7edf0d4a022633a3d67387c0 100644 --- a/arkoala-arkts/libarkts/src/arkts-api/types.ts +++ b/arkoala-arkts/libarkts/src/arkts-api/types.ts @@ -364,6 +364,14 @@ export class ETSFunctionType extends AstNode { ) ) } + + get params() { + return unpackNodeArray(global.generatedEs2panda._ETSFunctionTypeIrParamsConst(global.context, this.peer)) + } + + get returnType() { + return unpackNode(global.generatedEs2panda._ETSFunctionTypeIrReturnType(global.context, this.peer)) + } } export class Identifier extends Expression { @@ -757,6 +765,16 @@ export class ETSParameterExpression extends AstNode { ); } + get type(): AstNode | undefined { + return unpackNode(global.generatedEs2panda._ETSParameterExpressionTypeAnnotation(global.context, this.peer)) + } + + set type(t: AstNode | undefined) { + if (t === undefined) + return + global.generatedEs2panda._ETSParameterExpressionSetTypeAnnotation(global.context, this.peer, t.peer) + } + identifier: Identifier } diff --git a/arkoala-arkts/memo-plugin/demo/demo.sts b/arkoala-arkts/memo-plugin/demo/demo.sts index 8e49383b95c5b31af7ff9109c978ce932711b6bc..8012d951070d71a5e7b85623afeb17febe5404da 100644 --- a/arkoala-arkts/memo-plugin/demo/demo.sts +++ b/arkoala-arkts/memo-plugin/demo/demo.sts @@ -1,4 +1,4 @@ -import { GlobalStateManager, memoEntry, StateContext, memo } from "@koalaui/runtime" +import { GlobalStateManager, memoEntry, StateContext, memo, memo_intrinsic } from "@koalaui/runtime" import { memo_foo } from "./stub" @memo @@ -13,10 +13,19 @@ function f(s: string) { y("she") } +@memo +function bar( + arg1: number, + @memo arg2: (x: number) => number, +) { + console.log(arg1, arg2(arg1)) +} + @memo function foo_wrapper() { ETSGLOBAL.f("he") memo_foo("hello") + ETSGLOBAL.bar(1, (x: number): number => { return 3 + x; }) } function main() { diff --git a/arkoala-arkts/memo-plugin/demo/demo.ts b/arkoala-arkts/memo-plugin/demo/demo.ts index 1d5af07338529b78ce9c12918343892e955781c4..7e4d48536237fc70d7778779097e96341f36955c 100644 --- a/arkoala-arkts/memo-plugin/demo/demo.ts +++ b/arkoala-arkts/memo-plugin/demo/demo.ts @@ -16,10 +16,19 @@ function f(s: string) { y("she") } +/** @memo */ +function bar( + arg1: number, + /** @memo */ arg2: (x: number) => number, +) { + console.log(arg1, arg2(arg1)) +} + /** @memo */ function foo_wrapper() { f("he") memo_foo("hello") + bar(1, (x: number): number => { return 3 + x; }) } function main() { diff --git a/arkoala-arkts/memo-plugin/src/FunctionTransformer.ts b/arkoala-arkts/memo-plugin/src/FunctionTransformer.ts index 93217012915ae13f6616ee418d07f42186c422c3..bdb4b0ee6bc3c9d7aadd80b4906239a9d59ea144 100644 --- a/arkoala-arkts/memo-plugin/src/FunctionTransformer.ts +++ b/arkoala-arkts/memo-plugin/src/FunctionTransformer.ts @@ -18,17 +18,12 @@ import { factory } from "./MemoFactory" import { AbstractVisitor } from "./AbstractVisitor" import { PositionalIdTracker, - RuntimeNames + hasMemoAnnotation, + hasMemoIntrinsicAnnotation, } from "./utils" import { ParameterTransformer } from "./ParameterTransformer" import { ReturnTransformer } from "./ReturnTranformer" -function hasMemoAnnotation(node: arkts.ScriptFunction | arkts.ETSParameterExpression) { - return node.annotations.some((it) => - it.expr !== undefined && arkts.isIdentifier(it.expr) && it.expr.name === RuntimeNames.ANNOTATION - ) -} - function updateFunctionBody( node: arkts.BlockStatement, parameters: arkts.ETSParameterExpression[], @@ -93,43 +88,71 @@ export class FunctionTransformer extends AbstractVisitor { super() } + updateScriptFunction( + scriptFunction: arkts.ScriptFunction, + name: string = "", + ): arkts.ScriptFunction { + if (!scriptFunction.body) { + return scriptFunction + } + const [body, memoParametersDeclaration, syntheticReturnStatement] = updateFunctionBody( + scriptFunction.body, + scriptFunction.parameters, + scriptFunction.returnTypeAnnotation, + this.positionalIdTracker.id(name), + ) + const afterParameterTransformer = this.parameterTransformer + .withParameters(scriptFunction.parameters) + .skip(memoParametersDeclaration) + .visitor(body) + const afterReturnTransformer = this.returnTransformer + .skip(syntheticReturnStatement) + .visitor(afterParameterTransformer) + const updatedParameters = scriptFunction.parameters.map((param) => { + if (hasMemoAnnotation(param)) { + if (!(param.type instanceof arkts.ETSFunctionType)) { + throw "ArrowFunctionExpression expected for @memo parameter of @memo_intrinsic function" + } + param.type = arkts.factory.createFunctionType( + arkts.factory.createFunctionSignature( + undefined, + [...factory.createHiddenParameters(), ...param.type.params], + param.type.returnType, + ), + arkts.Es2pandaScriptFunctionFlags.SCRIPT_FUNCTION_FLAGS_ARROW, + ) + } + return param + }) + return arkts.factory.updateScriptFunction( + scriptFunction, + afterReturnTransformer, + scriptFunction.scriptFunctionFlags, + scriptFunction.modifiers, + false, + scriptFunction.ident, + [...factory.createHiddenParameters(), ...updatedParameters], + scriptFunction.typeParamsDecl, + scriptFunction.returnTypeAnnotation + ) + } + visitor(beforeChildren: arkts.AstNode): arkts.AstNode { // TODO: Remove (currently annotations are lost on visitor) const methodDefinitionHasMemoAnnotation = beforeChildren instanceof arkts.MethodDefinition && hasMemoAnnotation(beforeChildren.scriptFunction) + const methodDefinitionHasMemoIntrinsicAnnotation = + beforeChildren instanceof arkts.MethodDefinition && hasMemoIntrinsicAnnotation(beforeChildren.scriptFunction) const node = this.visitEachChild(beforeChildren) if (node instanceof arkts.MethodDefinition && node.scriptFunction.body) { - if (methodDefinitionHasMemoAnnotation) { - const [body, memoParametersDeclaration, syntheticReturnStatement] = updateFunctionBody( - node.scriptFunction.body, - node.scriptFunction.parameters, - node.scriptFunction.returnTypeAnnotation, - this.positionalIdTracker.id(node.name.name), - ) - const afterParameterTransformer = this.parameterTransformer - .withParameters(node.scriptFunction.parameters) - .skip(memoParametersDeclaration) - .visitor(body) - const afterReturnTransformer = this.returnTransformer - .skip(syntheticReturnStatement) - .visitor(afterParameterTransformer) + if (methodDefinitionHasMemoAnnotation || methodDefinitionHasMemoIntrinsicAnnotation) { return arkts.factory.updateMethodDefinition( node, arkts.Es2pandaMethodDefinitionKind.METHOD_DEFINITION_KIND_METHOD, node.name, arkts.factory.createFunctionExpression( - arkts.factory.updateScriptFunction( - node.scriptFunction, - afterReturnTransformer, - node.scriptFunction.scriptFunctionFlags, - node.scriptFunction.modifiers, - false, - node.scriptFunction.ident, - [...factory.createHiddenParameters(), ...node.scriptFunction.parameters], - node.scriptFunction.typeParamsDecl, - node.scriptFunction.returnTypeAnnotation - ) + this.updateScriptFunction(node.scriptFunction, node.name.name), ), node.modifiers, false @@ -139,10 +162,16 @@ export class FunctionTransformer extends AbstractVisitor { if (node instanceof arkts.CallExpression) { const expr = node.expression const decl = arkts.getDecl(expr) - if (decl instanceof arkts.MethodDefinition && hasMemoAnnotation(decl.scriptFunction)) { + if (decl instanceof arkts.MethodDefinition && (hasMemoAnnotation(decl.scriptFunction) || hasMemoIntrinsicAnnotation(decl.scriptFunction))) { const updatedArguments = node.arguments.map((it, index) => { - if (hasMemoAnnotation(decl.scriptFunction.parameters[index])) { - return factory.createComputeExpression(this.positionalIdTracker.id(decl.name.name), it) + if (decl.scriptFunction.parameters[index].type instanceof arkts.ETSFunctionType) { + if (!hasMemoAnnotation(decl.scriptFunction.parameters[index]) && !hasMemoIntrinsicAnnotation(decl.scriptFunction.parameters[index])) { + return factory.createComputeExpression(this.positionalIdTracker.id(decl.name.name), it) + } + if (!(it instanceof arkts.ArrowFunctionExpression)) { + throw "ArrowFunctionExpression expected for @memo argument of @memo function" + } + return this.updateScriptFunction(it.scriptFunction) } return it }) diff --git a/arkoala-arkts/memo-plugin/src/MemoFactory.ts b/arkoala-arkts/memo-plugin/src/MemoFactory.ts index 5c9246d91f3f493cf0a58997d1ec827b81324c0a..e83ea086d98392e8dce38e754c781be33088d2e6 100644 --- a/arkoala-arkts/memo-plugin/src/MemoFactory.ts +++ b/arkoala-arkts/memo-plugin/src/MemoFactory.ts @@ -108,6 +108,20 @@ export class factory { false, ) } + static createMemoParameterAccessMemo(name: string, hash: arkts.NumberLiteral | arkts.StringLiteral, passArgs?: arkts.AstNode[]): arkts.CallExpression { + const updatedArgs = passArgs ? passArgs : [] + return arkts.factory.createCallExpression( + arkts.factory.createMemberExpression( + factory.createMemoParameterIdentifier(name), + arkts.factory.createIdentifier(RuntimeNames.VALUE), + arkts.Es2pandaMemberExpressionKind.MEMBER_EXPRESSION_KIND_GETTER, + false, + false, + ), + undefined, + [...factory.createHiddenArguments(hash), ...updatedArgs], + ) + } // Recache static createScopeDeclaration(returnTypeAnnotation: arkts.AstNode | undefined, hash: arkts.NumberLiteral | arkts.StringLiteral, cnt: number): arkts.VariableDeclaration { diff --git a/arkoala-arkts/memo-plugin/src/MemoTransformer.ts b/arkoala-arkts/memo-plugin/src/MemoTransformer.ts index 8ac6ee906c6ef7537c4d5e4568d653459f05e2d6..876720259b87175c3d30d249142e55f1ac6e9bb6 100644 --- a/arkoala-arkts/memo-plugin/src/MemoTransformer.ts +++ b/arkoala-arkts/memo-plugin/src/MemoTransformer.ts @@ -29,7 +29,7 @@ export default function memoTransformer( ) { return (node: arkts.EtsScript) => { const positionalIdTracker = new PositionalIdTracker(arkts.getFileName(), false) - const parameterTransformer = new ParameterTransformer() + const parameterTransformer = new ParameterTransformer(positionalIdTracker) const returnTransformer = new ReturnTransformer() const functionTransformer = new FunctionTransformer(positionalIdTracker, parameterTransformer, returnTransformer) return functionTransformer.visitor( diff --git a/arkoala-arkts/memo-plugin/src/ParameterTransformer.ts b/arkoala-arkts/memo-plugin/src/ParameterTransformer.ts index cf4d874e3b21c517cfaed2db31e40a06cb7853ee..eaf6fb289caa3554b459050fcd9090372dbe522a 100644 --- a/arkoala-arkts/memo-plugin/src/ParameterTransformer.ts +++ b/arkoala-arkts/memo-plugin/src/ParameterTransformer.ts @@ -17,15 +17,24 @@ import * as arkts from "@koalaui/libarkts" import { factory } from "./MemoFactory" import { AbstractVisitor } from "./AbstractVisitor" import { KPointer } from "@koalaui/interop" -import { isMemoParametersDeclaration } from "./utils" +import { hasMemoAnnotation, hasMemoIntrinsicAnnotation, isMemoParametersDeclaration, PositionalIdTracker } from "./utils" export class ParameterTransformer extends AbstractVisitor { - private rewrites?: Map arkts.MemberExpression> + private rewrites?: Map arkts.CallExpression | arkts.MemberExpression> private skipNode?: arkts.VariableDeclaration + constructor(private positionalIdTracker: PositionalIdTracker) { + super() + } + withParameters(parameters: arkts.ETSParameterExpression[]): ParameterTransformer { this.rewrites = new Map(parameters.map((it) => { - return [it.peer, () => factory.createMemoParameterAccess(it.identifier.name)] + return [it.peer, (passArgs?: arkts.AstNode[]) => { + if (hasMemoAnnotation(it) || hasMemoIntrinsicAnnotation(it)) { + return factory.createMemoParameterAccessMemo(it.identifier.name, this.positionalIdTracker?.id(""), passArgs) + } + return factory.createMemoParameterAccess(it.identifier.name) + }] })) return this } @@ -41,11 +50,21 @@ export class ParameterTransformer extends AbstractVisitor { if (/* beforeChildren === this.skipNode */ isMemoParametersDeclaration(beforeChildren)) { return beforeChildren } + if (beforeChildren instanceof arkts.CallExpression) { + if (beforeChildren.expression instanceof arkts.Identifier) { + const decl = arkts.getDecl(beforeChildren.expression) + if (decl instanceof arkts.ETSParameterExpression && this.rewrites?.has(decl.peer)) { + return this.rewrites.get(decl.peer)!( + beforeChildren.arguments.map((it) => this.visitor(it)) + ) + } + } + } const node = this.visitEachChild(beforeChildren) if (node instanceof arkts.Identifier) { const decl = arkts.getDecl(node) - if (decl instanceof arkts.ETSParameterExpression && this.rewrites?.get(decl.peer)) { - return this.rewrites.get(decl.peer)?.() ?? node + if (decl instanceof arkts.ETSParameterExpression && this.rewrites?.has(decl.peer)) { + return this.rewrites.get(decl.peer)!() } } return node diff --git a/arkoala-arkts/memo-plugin/src/utils.ts b/arkoala-arkts/memo-plugin/src/utils.ts index d936ba040ba8e9e209102689571149bbaa204448..fd1cd4691f08c4ef6c156cfdbad977bffe2bd19c 100644 --- a/arkoala-arkts/memo-plugin/src/utils.ts +++ b/arkoala-arkts/memo-plugin/src/utils.ts @@ -20,6 +20,7 @@ export enum RuntimeNames { __CONTEXT = "__context", __ID = "__id", ANNOTATION = "memo", + ANNOTATION_INTRINSIC = "memo_intrinsic", COMPUTE = "compute", CONTEXT = "__memo_context", CONTEXT_TYPE = "__memo_context_type", @@ -81,6 +82,18 @@ export class PositionalIdTracker { } } +export function hasMemoAnnotation(node: arkts.ScriptFunction | arkts.ETSParameterExpression) { + return node.annotations.some((it) => + it.expr !== undefined && arkts.isIdentifier(it.expr) && it.expr.name === RuntimeNames.ANNOTATION + ) +} + +export function hasMemoIntrinsicAnnotation(node: arkts.ScriptFunction | arkts.ETSParameterExpression) { + return node.annotations.some((it) => + it.expr !== undefined && arkts.isIdentifier(it.expr) && it.expr.name === RuntimeNames.ANNOTATION_INTRINSIC + ) +} + /** * TODO: * @deprecated