# Source code for autoqasm.converters.assignments
# Copyright Amazon.com Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Converters for assignment nodes."""
import ast
from malt.core import ag_ctx, converter
from malt.pyct import templates
from autoqasm.operators.assignments import assign_for_output
[docs]
class AssignTransformer(converter.Base):
    """AST transformer that routes assignments through ``ag__.assign_stmt``."""

    def visit_AugAssign(self, node: ast.stmt) -> ast.stmt:
        """Converts augmented assignment operations (``+=``, ``-=``, etc.) into
        regular assignments so they flow through ``assign_stmt``.

        ``val += expr`` is desugared to ``val = val + expr`` and then
        transformed by ``visit_Assign``. Augmented assignments whose target is
        not a plain name (e.g. ``a[i] += 1`` or ``obj.x += 1``) are left as
        augmented assignments; only their children are visited.

        Args:
            node (ast.stmt): AST node to transform.

        Returns:
            ast.stmt: Transformed node.
        """
        if not isinstance(node.target, ast.Name):
            # Subscript and attribute targets carry no ``id``, so they cannot
            # be desugared by name. ``visit_Assign`` also leaves non-name
            # targets untransformed, so keep the statement in its original
            # form rather than failing with an AttributeError.
            return self.generic_visit(node)

        # Read the current value of the target: same name, but in Load context.
        current_value = ast.Name(id=node.target.id, ctx=ast.Load())
        ast.copy_location(current_value, node.target)

        new_value = ast.BinOp(
            left=current_value,
            op=node.op,
            right=node.value,
        )
        ast.copy_location(new_value, node)

        assign_node = ast.Assign(
            targets=[node.target],
            value=new_value,
        )
        ast.copy_location(assign_node, node)
        return self.visit_Assign(assign_node)

    def visit_Assign(self, node: ast.stmt) -> ast.stmt:
        """Converts assignment operations to their AutoQASM counterpart.

        Supports assignment to a single variable. Operator declares the
        ``oq`` variable, or sets variable's value if it's already declared.
        Assignments whose single target is not a plain name are returned
        with only their children visited.

        Args:
            node (ast.stmt): AST node to transform.

        Returns:
            ast.stmt: Transformed node.

        Raises:
            NotImplementedError: If the assignment has more than one target
                (e.g. ``a = b = 1``).
        """
        template = """
            tar_ = ag__.assign_stmt(tar_name_, val_)
        """

        # Assignments for main function return statements have already been
        # handled, so return early. The value is only such a call when it has
        # a ``func`` with an ``attr``; anything else falls through.
        called = getattr(node.value, "func", None)
        if getattr(called, "attr", None) == assign_for_output.__name__:
            return node

        node = self.generic_visit(node)

        # TODO: implement when target has multiple variable
        if len(node.targets) > 1:
            raise NotImplementedError(
                "Assignment to multiple targets is not supported."
            )

        target = node.targets[0]
        if not isinstance(target, ast.Name):
            return node

        # The variable name is passed to ``assign_stmt`` as a string constant.
        target_name = ast.Constant(target.id, None)
        return templates.replace(
            template,
            tar_name_=target_name,
            tar_=target,
            val_=node.value,
            original=node,
        )
[docs]
def transform(node: ast.stmt, ctx: ag_ctx.ControlStatusCtx) -> ast.stmt:
    """Run ``AssignTransformer`` over an AST node.

    Args:
        node (ast.stmt): AST node to transform.
        ctx (ag_ctx.ControlStatusCtx): Transformer context.

    Returns:
        ast.stmt: The result of visiting ``node`` with the transformer.
    """
    transformer = AssignTransformer(ctx)
    return transformer.visit(node)