I created a new struct called HousingData and defined functions such as iterate and length for it. However, when I call collect on a HousingData object, I get the following error:
TypeError: in typeassert, expected Integer, got a value of type Float64
using Random  # provides randperm
import Base: length, size, iterate
struct HousingData
    x
    y
    batchsize::Int
    shuffle::Bool
    num_instances::Int
    function HousingData(
            x, y; batchsize::Int=100, shuffle::Bool=false, dtype::Type=Array{Float64})
        new(convert(dtype,x),convert(dtype,y),batchsize,shuffle,size(y)[end])
    end
end
function length(d::HousingData)
    return ceil(d.num_instances/d.batchsize)
end
function iterate(d::HousingData, state=ifelse(
        d.shuffle, randperm(d.num_instances), collect(1:d.num_instances)))
    if(length(state)==0)
        return nothing
    end
    return ((d.x[:,state[1]],d.y[:,state[1]]),state[2:end])
end
x1 = randn(5, 100); y1 = rand(1, 100);
obj = HousingData(x1,y1; batchsize=20)
collect(obj)
There are multiple problems in your code. The first is that length returns a float rather than an integer, which is what the TypeError is complaining about. This is explained by the behavior of ceil:
julia> ceil(3.8)
4.0 # Notice: 4.0 (Float64) and not 4 (Int)
You can easily fix this:
function length(d::HousingData)
    return Int(ceil(d.num_instances/d.batchsize))
end
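As a side note, the same rounding can be done without a separate Int conversion. The snippet below is only a small illustration with made-up numbers; it relies on the two-argument form of ceil and on cld (ceiling division), both of which are in Base:

```julia
# ceil accepts a target type as its first argument and then returns that type
n1 = ceil(Int, 100 / 20)   # 5
n2 = ceil(Int, 3.8)        # 4

# cld ("ceiling division") works on the integers directly, with no Float64 step
n3 = cld(100, 20)          # 5
n4 = cld(7, 3)             # 3
```

With either form, length could be written as cld(d.num_instances, d.batchsize), which stays in integer arithmetic throughout.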
Another problem lies in the logic of your iteration function, which is not consistent with the advertised length. To take a smaller example than yours:
julia> x1 = [i+j/10 for i in 1:2, j in 1:6]
2×6 Array{Float64,2}:
1.1 1.2 1.3 1.4 1.5 1.6
2.1 2.2 2.3 2.4 2.5 2.6
# As an aside, unless you really want to work with 1xN matrices
# it is more idiomatic in Julia to use 1D Vectors in such situations
julia> y1 = [Float64(j) for i in 1:1, j in 1:6]
1×6 Array{Float64,2}:
1.0 2.0 3.0 4.0 5.0 6.0
julia> obj = HousingData(x1,y1; batchsize=3)
HousingData([1.1 1.2 … 1.5 1.6; 2.1 2.2 … 2.5 2.6], [1.0 2.0 … 5.0 6.0], 3, false, 6)
julia> length(obj)
2
julia> for (i, e) in enumerate(obj)
           println("$i -> $e")
       end
1 -> ([1.1, 2.1], [1.0])
2 -> ([1.2, 2.2], [2.0])
3 -> ([1.3, 2.3], [3.0])
4 -> ([1.4, 2.4], [4.0])
5 -> ([1.5, 2.5], [5.0])
6 -> ([1.6, 2.6], [6.0])
The iterator produces 6 elements, whereas the length of this object is only 2. This explains why collect errors out:
julia> collect(obj)
ERROR: ArgumentError: destination has fewer elements than required
Knowing your code, you're probably the best person to fix its logic.
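That said, if the intent was for each iteration to yield one batch of batchsize columns, one possible fix could look like the sketch below. This is a guess at the intended behavior, not a definitive implementation: the name BatchedHousingData is hypothetical (chosen so it does not clash with your struct), the dtype keyword is dropped for brevity, and the state is assumed to be a tuple of the index order and the next start position.

```julia
using Random: randperm

# Hypothetical variant of HousingData that yields whole batches
struct BatchedHousingData
    x::Matrix{Float64}
    y::Matrix{Float64}
    batchsize::Int
    shuffle::Bool
    num_instances::Int
end

# Keyword constructor mirroring the original one (without dtype)
BatchedHousingData(x, y; batchsize::Int=100, shuffle::Bool=false) =
    BatchedHousingData(convert(Matrix{Float64}, x), convert(Matrix{Float64}, y),
                       batchsize, shuffle, size(y)[end])

# Number of batches; the last one may be smaller than batchsize
Base.length(d::BatchedHousingData) = cld(d.num_instances, d.batchsize)

# State = (order of column indices, position of the next batch's first column)
function Base.iterate(d::BatchedHousingData,
                      state=(d.shuffle ? randperm(d.num_instances) :
                                         collect(1:d.num_instances), 1))
    order, start = state
    start > d.num_instances && return nothing
    stop = min(start + d.batchsize - 1, d.num_instances)
    idx = order[start:stop]
    return (d.x[:, idx], d.y[:, idx]), (order, stop + 1)
end

x1 = [i + j/10 for i in 1:2, j in 1:6]
y1 = [Float64(j) for i in 1:1, j in 1:6]
obj = BatchedHousingData(x1, y1; batchsize=4)
batches = collect(obj)   # two batches: 4 columns, then the remaining 2
```

Here the iterator produces exactly length(obj) elements, so collect no longer runs out of room in its destination array.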