Spaces:
Running
Running
Zai
commited on
Commit
•
02fa587
1
Parent(s):
4de03b2
training prototype with notebook
Browse files
README.md
CHANGED
@@ -1 +1,2 @@
|
|
1 |
# Headshot Project
|
|
|
|
1 |
# Headshot Project
|
2 |
+
|
notebooks/.ipynb_checkpoints/detection_pytorch-checkpoint.ipynb
ADDED
@@ -0,0 +1,192 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": 14,
|
6 |
+
"id": "13343cbd-bede-41d9-9506-08ed04e66cf6",
|
7 |
+
"metadata": {},
|
8 |
+
"outputs": [
|
9 |
+
{
|
10 |
+
"name": "stderr",
|
11 |
+
"output_type": "stream",
|
12 |
+
"text": [
|
13 |
+
"C:\\Users\\Myo Win Zaw\\.conda\\envs\\ai_env\\lib\\site-packages\\torchvision\\io\\image.py:13: UserWarning: Failed to load image Python extension: '[WinError 127] The specified procedure could not be found'If you don't plan on using image functionality from `torchvision.io`, you can ignore this warning. Otherwise, there might be something wrong with your environment. Did you have `libjpeg` or `libpng` installed before building `torchvision` from source?\n",
|
14 |
+
" warn(\n"
|
15 |
+
]
|
16 |
+
}
|
17 |
+
],
|
18 |
+
"source": [
|
19 |
+
"import torch\n",
|
20 |
+
"import torch.nn as nn\n",
|
21 |
+
"import torch.optim as optim\n",
|
22 |
+
"from torch.utils.data import DataLoader, Dataset\n",
|
23 |
+
"from torchvision import transforms, datasets\n",
|
24 |
+
"from tqdm import tqdm"
|
25 |
+
]
|
26 |
+
},
|
27 |
+
{
|
28 |
+
"cell_type": "code",
|
29 |
+
"execution_count": 15,
|
30 |
+
"id": "31812e9b-1b04-44fc-891e-1e77c866ff75",
|
31 |
+
"metadata": {},
|
32 |
+
"outputs": [],
|
33 |
+
"source": [
|
34 |
+
"# declaration"
|
35 |
+
]
|
36 |
+
},
|
37 |
+
{
|
38 |
+
"cell_type": "code",
|
39 |
+
"execution_count": 18,
|
40 |
+
"id": "0c3b7de2-b89c-47f8-83d9-d20f27af390a",
|
41 |
+
"metadata": {},
|
42 |
+
"outputs": [],
|
43 |
+
"source": [
|
44 |
+
"# Load image datas\n",
|
45 |
+
"\n",
|
46 |
+
"device = 'cpu'"
|
47 |
+
]
|
48 |
+
},
|
49 |
+
{
|
50 |
+
"cell_type": "code",
|
51 |
+
"execution_count": 19,
|
52 |
+
"id": "58383873-e2c3-4683-ba04-72bad7a6d773",
|
53 |
+
"metadata": {},
|
54 |
+
"outputs": [],
|
55 |
+
"source": [
|
56 |
+
"# dataset \n",
|
57 |
+
"class FaceDataset(Dataset):\n",
|
58 |
+
" def __init__(self,data,labels,transforms=None):\n",
|
59 |
+
" self.tranforms = tranforms\n",
|
60 |
+
" self.data = x_data\n",
|
61 |
+
" self.labels = y_labels\n",
|
62 |
+
"\n",
|
63 |
+
" def __len__(self):\n",
|
64 |
+
" return len(self.data)\n",
|
65 |
+
"\n",
|
66 |
+
" def __getitem__(self, idx):\n",
|
67 |
+
" # Load and preprocess the image at the given index\n",
|
68 |
+
" image = self.data[idx]\n",
|
69 |
+
" label = self.labels[idx]\n",
|
70 |
+
" \n",
|
71 |
+
" if self.transform:\n",
|
72 |
+
" image = self.transform(image)\n",
|
73 |
+
" return image,label"
|
74 |
+
]
|
75 |
+
},
|
76 |
+
{
|
77 |
+
"cell_type": "code",
|
78 |
+
"execution_count": 41,
|
79 |
+
"id": "8d6e52a5-b80a-40fd-9a58-40255efba005",
|
80 |
+
"metadata": {},
|
81 |
+
"outputs": [],
|
82 |
+
"source": [
|
83 |
+
"class Detector(nn.Module):\n",
|
84 |
+
" def __init__(self):\n",
|
85 |
+
" super().__init__()\n",
|
86 |
+
" self.conv1 = nn.Conv2d(72,64,4)\n",
|
87 |
+
"\n",
|
88 |
+
" def forward(self):\n",
|
89 |
+
" pass"
|
90 |
+
]
|
91 |
+
},
|
92 |
+
{
|
93 |
+
"cell_type": "code",
|
94 |
+
"execution_count": 42,
|
95 |
+
"id": "298e5cd7-d247-437d-ae69-271dbbdfdf03",
|
96 |
+
"metadata": {},
|
97 |
+
"outputs": [],
|
98 |
+
"source": [
|
99 |
+
"# optimization\n",
|
100 |
+
"lr = 1e-3\n",
|
101 |
+
"# model = Detector().to(device)\n",
|
102 |
+
"# optimizer = torch.optim.Adam(model)\n",
|
103 |
+
"loss_fn = nn.CrossEntropyLoss()\n",
|
104 |
+
"\n",
|
105 |
+
"num_epochs = 50\n"
|
106 |
+
]
|
107 |
+
},
|
108 |
+
{
|
109 |
+
"cell_type": "code",
|
110 |
+
"execution_count": 43,
|
111 |
+
"id": "89877e6b-9682-4d00-acae-94020722e7e1",
|
112 |
+
"metadata": {},
|
113 |
+
"outputs": [
|
114 |
+
{
|
115 |
+
"name": "stdout",
|
116 |
+
"output_type": "stream",
|
117 |
+
"text": [
|
118 |
+
"xi\n"
|
119 |
+
]
|
120 |
+
},
|
121 |
+
{
|
122 |
+
"ename": "TypeError",
|
123 |
+
"evalue": "CrossEntropyLoss.forward() missing 2 required positional arguments: 'input' and 'target'",
|
124 |
+
"output_type": "error",
|
125 |
+
"traceback": [
|
126 |
+
"\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
|
127 |
+
"\u001b[1;31mTypeError\u001b[0m Traceback (most recent call last)",
|
128 |
+
"Cell \u001b[1;32mIn[43], line 9\u001b[0m\n\u001b[0;32m 6\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m _ \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28menumerate\u001b[39m(dummy_dataset):\n\u001b[0;32m 7\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mxi\u001b[39m\u001b[38;5;124m'\u001b[39m)\n\u001b[1;32m----> 9\u001b[0m loss \u001b[38;5;241m=\u001b[39m \u001b[43mloss_fn\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n",
|
129 |
+
"File \u001b[1;32m~\\.conda\\envs\\ai_env\\lib\\site-packages\\torch\\nn\\modules\\module.py:1511\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1509\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[0;32m 1510\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m-> 1511\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n",
|
130 |
+
"File \u001b[1;32m~\\.conda\\envs\\ai_env\\lib\\site-packages\\torch\\nn\\modules\\module.py:1520\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1515\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[0;32m 1516\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[0;32m 1517\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[0;32m 1518\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[0;32m 1519\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[1;32m-> 1520\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m forward_call(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[0;32m 1522\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m 1523\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
|
131 |
+
"\u001b[1;31mTypeError\u001b[0m: CrossEntropyLoss.forward() missing 2 required positional arguments: 'input' and 'target'"
|
132 |
+
]
|
133 |
+
}
|
134 |
+
],
|
135 |
+
"source": [
|
136 |
+
"# train model\n",
|
137 |
+
"\n",
|
138 |
+
"dummy_dataset= ['hello']\n",
|
139 |
+
"\n",
|
140 |
+
"for i in range(num_epochs):\n",
|
141 |
+
" for img,label in enumerate(dummy_dataset):\n",
|
142 |
+
" optimizer.zero_grad()\n",
|
143 |
+
" outputs = model(img,label)\n",
|
144 |
+
" loss = loss_fn(outputs,labels)\n",
|
145 |
+
" loss.backward()\n",
|
146 |
+
" optimizer.step()\n",
|
147 |
+
" print(f\"epoch {i} done\")"
|
148 |
+
]
|
149 |
+
},
|
150 |
+
{
|
151 |
+
"cell_type": "code",
|
152 |
+
"execution_count": null,
|
153 |
+
"id": "7a3e59e6-7701-4e11-8d29-c6e5468d63ab",
|
154 |
+
"metadata": {},
|
155 |
+
"outputs": [],
|
156 |
+
"source": [
|
157 |
+
"# eval model\n",
|
158 |
+
"def evaluate():\n",
|
159 |
+
" pass"
|
160 |
+
]
|
161 |
+
},
|
162 |
+
{
|
163 |
+
"cell_type": "code",
|
164 |
+
"execution_count": null,
|
165 |
+
"id": "c0ddaf7e-185b-4190-b491-120008d1e1ea",
|
166 |
+
"metadata": {},
|
167 |
+
"outputs": [],
|
168 |
+
"source": []
|
169 |
+
}
|
170 |
+
],
|
171 |
+
"metadata": {
|
172 |
+
"kernelspec": {
|
173 |
+
"display_name": "Python 3 (ipykernel)",
|
174 |
+
"language": "python",
|
175 |
+
"name": "python3"
|
176 |
+
},
|
177 |
+
"language_info": {
|
178 |
+
"codemirror_mode": {
|
179 |
+
"name": "ipython",
|
180 |
+
"version": 3
|
181 |
+
},
|
182 |
+
"file_extension": ".py",
|
183 |
+
"mimetype": "text/x-python",
|
184 |
+
"name": "python",
|
185 |
+
"nbconvert_exporter": "python",
|
186 |
+
"pygments_lexer": "ipython3",
|
187 |
+
"version": "3.10.13"
|
188 |
+
}
|
189 |
+
},
|
190 |
+
"nbformat": 4,
|
191 |
+
"nbformat_minor": 5
|
192 |
+
}
|
notebooks/detection_pytorch.ipynb
ADDED
@@ -0,0 +1,192 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": 14,
|
6 |
+
"id": "13343cbd-bede-41d9-9506-08ed04e66cf6",
|
7 |
+
"metadata": {},
|
8 |
+
"outputs": [
|
9 |
+
{
|
10 |
+
"name": "stderr",
|
11 |
+
"output_type": "stream",
|
12 |
+
"text": [
|
13 |
+
"C:\\Users\\Myo Win Zaw\\.conda\\envs\\ai_env\\lib\\site-packages\\torchvision\\io\\image.py:13: UserWarning: Failed to load image Python extension: '[WinError 127] The specified procedure could not be found'If you don't plan on using image functionality from `torchvision.io`, you can ignore this warning. Otherwise, there might be something wrong with your environment. Did you have `libjpeg` or `libpng` installed before building `torchvision` from source?\n",
|
14 |
+
" warn(\n"
|
15 |
+
]
|
16 |
+
}
|
17 |
+
],
|
18 |
+
"source": [
|
19 |
+
"import torch\n",
|
20 |
+
"import torch.nn as nn\n",
|
21 |
+
"import torch.optim as optim\n",
|
22 |
+
"from torch.utils.data import DataLoader, Dataset\n",
|
23 |
+
"from torchvision import transforms, datasets\n",
|
24 |
+
"from tqdm import tqdm"
|
25 |
+
]
|
26 |
+
},
|
27 |
+
{
|
28 |
+
"cell_type": "code",
|
29 |
+
"execution_count": 15,
|
30 |
+
"id": "31812e9b-1b04-44fc-891e-1e77c866ff75",
|
31 |
+
"metadata": {},
|
32 |
+
"outputs": [],
|
33 |
+
"source": [
|
34 |
+
"# declaration"
|
35 |
+
]
|
36 |
+
},
|
37 |
+
{
|
38 |
+
"cell_type": "code",
|
39 |
+
"execution_count": 18,
|
40 |
+
"id": "0c3b7de2-b89c-47f8-83d9-d20f27af390a",
|
41 |
+
"metadata": {},
|
42 |
+
"outputs": [],
|
43 |
+
"source": [
|
44 |
+
"# Load image datas\n",
|
45 |
+
"\n",
|
46 |
+
"device = 'cpu'"
|
47 |
+
]
|
48 |
+
},
|
49 |
+
{
|
50 |
+
"cell_type": "code",
|
51 |
+
"execution_count": 19,
|
52 |
+
"id": "58383873-e2c3-4683-ba04-72bad7a6d773",
|
53 |
+
"metadata": {},
|
54 |
+
"outputs": [],
|
55 |
+
"source": [
|
56 |
+
"# dataset \n",
|
57 |
+
"class FaceDataset(Dataset):\n",
|
58 |
+
" def __init__(self,data,labels,transforms=None):\n",
|
59 |
+
" self.tranforms = tranforms\n",
|
60 |
+
" self.data = x_data\n",
|
61 |
+
" self.labels = y_labels\n",
|
62 |
+
"\n",
|
63 |
+
" def __len__(self):\n",
|
64 |
+
" return len(self.data)\n",
|
65 |
+
"\n",
|
66 |
+
" def __getitem__(self, idx):\n",
|
67 |
+
" # Load and preprocess the image at the given index\n",
|
68 |
+
" image = self.data[idx]\n",
|
69 |
+
" label = self.labels[idx]\n",
|
70 |
+
" \n",
|
71 |
+
" if self.transform:\n",
|
72 |
+
" image = self.transform(image)\n",
|
73 |
+
" return image,label"
|
74 |
+
]
|
75 |
+
},
|
76 |
+
{
|
77 |
+
"cell_type": "code",
|
78 |
+
"execution_count": 41,
|
79 |
+
"id": "8d6e52a5-b80a-40fd-9a58-40255efba005",
|
80 |
+
"metadata": {},
|
81 |
+
"outputs": [],
|
82 |
+
"source": [
|
83 |
+
"class Detector(nn.Module):\n",
|
84 |
+
" def __init__(self):\n",
|
85 |
+
" super().__init__()\n",
|
86 |
+
" self.conv1 = nn.Conv2d(72,64,4)\n",
|
87 |
+
"\n",
|
88 |
+
" def forward(self):\n",
|
89 |
+
" pass"
|
90 |
+
]
|
91 |
+
},
|
92 |
+
{
|
93 |
+
"cell_type": "code",
|
94 |
+
"execution_count": 42,
|
95 |
+
"id": "298e5cd7-d247-437d-ae69-271dbbdfdf03",
|
96 |
+
"metadata": {},
|
97 |
+
"outputs": [],
|
98 |
+
"source": [
|
99 |
+
"# optimization\n",
|
100 |
+
"lr = 1e-3\n",
|
101 |
+
"# model = Detector().to(device)\n",
|
102 |
+
"# optimizer = torch.optim.Adam(model)\n",
|
103 |
+
"loss_fn = nn.CrossEntropyLoss()\n",
|
104 |
+
"\n",
|
105 |
+
"num_epochs = 50\n"
|
106 |
+
]
|
107 |
+
},
|
108 |
+
{
|
109 |
+
"cell_type": "code",
|
110 |
+
"execution_count": 43,
|
111 |
+
"id": "89877e6b-9682-4d00-acae-94020722e7e1",
|
112 |
+
"metadata": {},
|
113 |
+
"outputs": [
|
114 |
+
{
|
115 |
+
"name": "stdout",
|
116 |
+
"output_type": "stream",
|
117 |
+
"text": [
|
118 |
+
"xi\n"
|
119 |
+
]
|
120 |
+
},
|
121 |
+
{
|
122 |
+
"ename": "TypeError",
|
123 |
+
"evalue": "CrossEntropyLoss.forward() missing 2 required positional arguments: 'input' and 'target'",
|
124 |
+
"output_type": "error",
|
125 |
+
"traceback": [
|
126 |
+
"\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
|
127 |
+
"\u001b[1;31mTypeError\u001b[0m Traceback (most recent call last)",
|
128 |
+
"Cell \u001b[1;32mIn[43], line 9\u001b[0m\n\u001b[0;32m 6\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m _ \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28menumerate\u001b[39m(dummy_dataset):\n\u001b[0;32m 7\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mxi\u001b[39m\u001b[38;5;124m'\u001b[39m)\n\u001b[1;32m----> 9\u001b[0m loss \u001b[38;5;241m=\u001b[39m \u001b[43mloss_fn\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n",
|
129 |
+
"File \u001b[1;32m~\\.conda\\envs\\ai_env\\lib\\site-packages\\torch\\nn\\modules\\module.py:1511\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1509\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[0;32m 1510\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m-> 1511\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n",
|
130 |
+
"File \u001b[1;32m~\\.conda\\envs\\ai_env\\lib\\site-packages\\torch\\nn\\modules\\module.py:1520\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1515\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[0;32m 1516\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[0;32m 1517\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[0;32m 1518\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[0;32m 1519\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[1;32m-> 1520\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m forward_call(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[0;32m 1522\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m 1523\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
|
131 |
+
"\u001b[1;31mTypeError\u001b[0m: CrossEntropyLoss.forward() missing 2 required positional arguments: 'input' and 'target'"
|
132 |
+
]
|
133 |
+
}
|
134 |
+
],
|
135 |
+
"source": [
|
136 |
+
"# train model\n",
|
137 |
+
"\n",
|
138 |
+
"dummy_dataset= ['hello']\n",
|
139 |
+
"\n",
|
140 |
+
"for i in range(num_epochs):\n",
|
141 |
+
" for img,label in enumerate(dummy_dataset):\n",
|
142 |
+
" optimizer.zero_grad()\n",
|
143 |
+
" outputs = model(img,label)\n",
|
144 |
+
" loss = loss_fn(outputs,labels)\n",
|
145 |
+
" loss.backward()\n",
|
146 |
+
" optimizer.step()\n",
|
147 |
+
" print(f\"epoch {i} done\")"
|
148 |
+
]
|
149 |
+
},
|
150 |
+
{
|
151 |
+
"cell_type": "code",
|
152 |
+
"execution_count": null,
|
153 |
+
"id": "7a3e59e6-7701-4e11-8d29-c6e5468d63ab",
|
154 |
+
"metadata": {},
|
155 |
+
"outputs": [],
|
156 |
+
"source": [
|
157 |
+
"# eval model\n",
|
158 |
+
"def evaluate():\n",
|
159 |
+
" pass"
|
160 |
+
]
|
161 |
+
},
|
162 |
+
{
|
163 |
+
"cell_type": "code",
|
164 |
+
"execution_count": null,
|
165 |
+
"id": "c0ddaf7e-185b-4190-b491-120008d1e1ea",
|
166 |
+
"metadata": {},
|
167 |
+
"outputs": [],
|
168 |
+
"source": []
|
169 |
+
}
|
170 |
+
],
|
171 |
+
"metadata": {
|
172 |
+
"kernelspec": {
|
173 |
+
"display_name": "Python 3 (ipykernel)",
|
174 |
+
"language": "python",
|
175 |
+
"name": "python3"
|
176 |
+
},
|
177 |
+
"language_info": {
|
178 |
+
"codemirror_mode": {
|
179 |
+
"name": "ipython",
|
180 |
+
"version": 3
|
181 |
+
},
|
182 |
+
"file_extension": ".py",
|
183 |
+
"mimetype": "text/x-python",
|
184 |
+
"name": "python",
|
185 |
+
"nbconvert_exporter": "python",
|
186 |
+
"pygments_lexer": "ipython3",
|
187 |
+
"version": "3.10.13"
|
188 |
+
}
|
189 |
+
},
|
190 |
+
"nbformat": 4,
|
191 |
+
"nbformat_minor": 5
|
192 |
+
}
|