Skip to content

Commit

Permalink
Address review
Browse files Browse the repository at this point in the history
  • Loading branch information
sobolevn committed Sep 20, 2024
1 parent 69251ac commit 5ec3a08
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 9 deletions.
13 changes: 10 additions & 3 deletions Lib/dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -1551,15 +1551,22 @@ class C(Base):
annotations[name] = tp

def annotate_method(format):
typing = sys.modules.get("typing")
if typing is None and format == annotationlib.Format.FORWARDREF:
typing_any = annotationlib.ForwardRef("Any", module="typing")
return {
ann: typing_any if t is any_marker else t
for ann, t in annotations.items()
}

from typing import Any, _convert_to_source
ann_dict = {
ann: Any if t is any_marker else t
for ann, t in annotations.items()
}
if format == 1 or format == 2:
return ann_dict
else:
if format == annotationlib.Format.SOURCE:
return _convert_to_source(ann_dict)
return ann_dict

# Update 'ns' with the user-supplied namespace plus our calculated values.
def exec_body_callback(ns):
Expand Down
29 changes: 23 additions & 6 deletions Lib/test/test_dataclasses/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from dataclasses import *

import abc
import annotationlib
import io
import pickle
import inspect
Expand All @@ -23,6 +24,7 @@
import dataclasses # Needed for the string "dataclasses.InitVar[int]" to work as an annotation.

from test import support
from test.support import import_helper

# Just any custom exception we can catch.
class CustomError(Exception): pass
Expand Down Expand Up @@ -3667,7 +3669,6 @@ class A(WithDictSlot): ...
@support.cpython_only
def test_dataclass_slot_dict_ctype(self):
# https://github.com/python/cpython/issues/123935
from test.support import import_helper
# Skips test if `_testcapi` is not present:
_testcapi = import_helper.import_module('_testcapi')

Expand Down Expand Up @@ -4171,23 +4172,39 @@ def test_no_types(self):
'z': typing.Any})

def test_no_types_get_annotations(self):
from annotationlib import Format, get_annotations

C = make_dataclass('C', ['x', ('y', int), 'z'])

self.assertEqual(
get_annotations(C, format=Format.VALUE),
annotationlib.get_annotations(C, format=annotationlib.Format.VALUE),
{'x': typing.Any, 'y': int, 'z': typing.Any},
)
self.assertEqual(
get_annotations(C, format=Format.FORWARDREF),
annotationlib.get_annotations(
C, format=annotationlib.Format.FORWARDREF),
{'x': typing.Any, 'y': int, 'z': typing.Any},
)
self.assertEqual(
get_annotations(C, format=Format.SOURCE),
annotationlib.get_annotations(
C, format=annotationlib.Format.SOURCE),
{'x': 'typing.Any', 'y': 'int', 'z': 'typing.Any'},
)

def test_no_types_no_typing_import(self):
import sys

C = make_dataclass('C', ['x', ('y', int)])

with import_helper.CleanImport('typing'):
self.assertEqual(
annotationlib.get_annotations(
C, format=annotationlib.Format.FORWARDREF),
{
'x': annotationlib.ForwardRef('Any', module='typing'),
'y': int,
},
)
self.assertNotIn('typing', sys.modules)

def test_module_attr(self):
self.assertEqual(ByMakeDataClass.__module__, __name__)
self.assertEqual(ByMakeDataClass(1).__module__, __name__)
Expand Down

0 comments on commit 5ec3a08

Please sign in to comment.